http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/49fef8e2/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/trainers/KMeansTrainer.java ---------------------------------------------------------------------- diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/trainers/KMeansTrainer.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/trainers/KMeansTrainer.java deleted file mode 100644 index e4ad34e..0000000 --- a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/trainers/KMeansTrainer.java +++ /dev/null @@ -1,163 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.lens.ml.spark.trainers; - -import java.util.List; - -import org.apache.lens.api.LensConf; -import org.apache.lens.api.LensException; -import org.apache.lens.ml.*; -import org.apache.lens.ml.spark.HiveTableRDD; -import org.apache.lens.ml.spark.models.KMeansClusteringModel; - -import org.apache.hadoop.hive.conf.HiveConf; -import org.apache.hadoop.hive.metastore.api.FieldSchema; -import org.apache.hadoop.hive.ql.metadata.Hive; -import org.apache.hadoop.hive.ql.metadata.Table; -import org.apache.hadoop.io.WritableComparable; -import org.apache.hive.hcatalog.data.HCatRecord; -import org.apache.spark.api.java.JavaPairRDD; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.mllib.clustering.KMeans; -import org.apache.spark.mllib.clustering.KMeansModel; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.linalg.Vectors; - -import scala.Tuple2; - -/** - * The Class KMeansTrainer. - */ -@Algorithm(name = "spark_kmeans_trainer", description = "Spark MLLib KMeans trainer") -public class KMeansTrainer implements MLTrainer { - - /** The conf. */ - private transient LensConf conf; - - /** The spark context. */ - private JavaSparkContext sparkContext; - - /** The part filter. */ - @TrainerParam(name = "partition", help = "Partition filter to be used while constructing table RDD") - private String partFilter = null; - - /** The k. */ - @TrainerParam(name = "k", help = "Number of cluster") - private int k; - - /** The max iterations. */ - @TrainerParam(name = "maxIterations", help = "Maximum number of iterations", defaultValue = "100") - private int maxIterations = 100; - - /** The runs. */ - @TrainerParam(name = "runs", help = "Number of parallel run", defaultValue = "1") - private int runs = 1; - - /** The initialization mode. */ - @TrainerParam(name = "initializationMode", - help = "initialization model, either \"random\" or \"k-means||\" (default).", defaultValue = "k-means||") - private String initializationMode = "k-means||"; - - @Override - public String getName() { - return getClass().getAnnotation(Algorithm.class).name(); - } - - @Override - public String getDescription() { - return getClass().getAnnotation(Algorithm.class).description(); - } - - /* - * (non-Javadoc) - * - * @see org.apache.lens.ml.MLTrainer#configure(org.apache.lens.api.LensConf) - */ - @Override - public void configure(LensConf configuration) { - this.conf = configuration; - } - - @Override - public LensConf getConf() { - return conf; - } - - /* - * (non-Javadoc) - * - * @see org.apache.lens.ml.MLTrainer#train(org.apache.lens.api.LensConf, java.lang.String, java.lang.String, - * java.lang.String, java.lang.String[]) - */ - @Override - public MLModel train(LensConf conf, String db, String table, String modelId, String... params) throws LensException { - List<String> features = TrainerArgParser.parseArgs(this, params); - final int[] featurePositions = new int[features.size()]; - final int NUM_FEATURES = features.size(); - - JavaPairRDD<WritableComparable, HCatRecord> rdd = null; - try { - // Map feature names to positions - Table tbl = Hive.get(toHiveConf(conf)).getTable(db, table); - List<FieldSchema> allCols = tbl.getAllCols(); - int f = 0; - for (int i = 0; i < tbl.getAllCols().size(); i++) { - String colName = allCols.get(i).getName(); - if (features.contains(colName)) { - featurePositions[f++] = i; - } - } - - rdd = HiveTableRDD.createHiveTableRDD(sparkContext, toHiveConf(conf), db, table, partFilter); - JavaRDD<Vector> trainableRDD = rdd.map(new Function<Tuple2<WritableComparable, HCatRecord>, Vector>() { - @Override - public Vector call(Tuple2<WritableComparable, HCatRecord> v1) throws Exception { - HCatRecord hCatRecord = v1._2(); - double[] arr = new double[NUM_FEATURES]; - for (int i = 0; i < NUM_FEATURES; i++) { - Object val = hCatRecord.get(featurePositions[i]); - arr[i] = val == null ? 0d : (Double) val; - } - return Vectors.dense(arr); - } - }); - - KMeansModel model = KMeans.train(trainableRDD.rdd(), k, maxIterations, runs, initializationMode); - return new KMeansClusteringModel(modelId, model); - } catch (Exception e) { - throw new LensException("KMeans trainer failed for " + db + "." + table, e); - } - } - - /** - * To hive conf. - * - * @param conf the conf - * @return the hive conf - */ - private HiveConf toHiveConf(LensConf conf) { - HiveConf hiveConf = new HiveConf(); - for (String key : conf.getProperties().keySet()) { - hiveConf.set(key, conf.getProperties().get(key)); - } - return hiveConf; - } -}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/49fef8e2/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/trainers/LogisticRegressionTrainer.java ---------------------------------------------------------------------- diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/trainers/LogisticRegressionTrainer.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/trainers/LogisticRegressionTrainer.java deleted file mode 100644 index b12e2be..0000000 --- a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/trainers/LogisticRegressionTrainer.java +++ /dev/null @@ -1,86 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.lens.ml.spark.trainers; - -import java.util.Map; - -import org.apache.lens.api.LensException; -import org.apache.lens.ml.Algorithm; -import org.apache.lens.ml.TrainerParam; -import org.apache.lens.ml.spark.models.BaseSparkClassificationModel; -import org.apache.lens.ml.spark.models.LogitRegressionClassificationModel; - -import org.apache.spark.mllib.classification.LogisticRegressionModel; -import org.apache.spark.mllib.classification.LogisticRegressionWithSGD; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.rdd.RDD; - -/** - * The Class LogisticRegressionTrainer. - */ -@Algorithm(name = "spark_logistic_regression", description = "Spark logistic regression trainer") -public class LogisticRegressionTrainer extends BaseSparkTrainer { - - /** The iterations. */ - @TrainerParam(name = "iterations", help = "Max number of iterations", defaultValue = "100") - private int iterations; - - /** The step size. */ - @TrainerParam(name = "stepSize", help = "Step size", defaultValue = "1.0d") - private double stepSize; - - /** The min batch fraction. */ - @TrainerParam(name = "minBatchFraction", help = "Fraction for batched learning", defaultValue = "1.0d") - private double minBatchFraction; - - /** - * Instantiates a new logistic regression trainer. - * - * @param name the name - * @param description the description - */ - public LogisticRegressionTrainer(String name, String description) { - super(name, description); - } - - /* - * (non-Javadoc) - * - * @see org.apache.lens.ml.spark.trainers.BaseSparkTrainer#parseTrainerParams(java.util.Map) - */ - @Override - public void parseTrainerParams(Map<String, String> params) { - iterations = getParamValue("iterations", 100); - stepSize = getParamValue("stepSize", 1.0d); - minBatchFraction = getParamValue("minBatchFraction", 1.0d); - } - - /* - * (non-Javadoc) - * - * @see org.apache.lens.ml.spark.trainers.BaseSparkTrainer#trainInternal(java.lang.String, org.apache.spark.rdd.RDD) - */ - @Override - protected BaseSparkClassificationModel trainInternal(String modelId, RDD<LabeledPoint> trainingRDD) - throws LensException { - LogisticRegressionModel lrModel = LogisticRegressionWithSGD.train(trainingRDD, iterations, stepSize, - minBatchFraction); - return new LogitRegressionClassificationModel(modelId, lrModel); - } -} http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/49fef8e2/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/trainers/NaiveBayesTrainer.java ---------------------------------------------------------------------- diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/trainers/NaiveBayesTrainer.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/trainers/NaiveBayesTrainer.java deleted file mode 100644 index 4eb50c9..0000000 --- a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/trainers/NaiveBayesTrainer.java +++ /dev/null @@ -1,73 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.lens.ml.spark.trainers; - -import java.util.Map; - -import org.apache.lens.api.LensException; -import org.apache.lens.ml.Algorithm; -import org.apache.lens.ml.TrainerParam; -import org.apache.lens.ml.spark.models.BaseSparkClassificationModel; -import org.apache.lens.ml.spark.models.NaiveBayesClassificationModel; - -import org.apache.spark.mllib.classification.NaiveBayes; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.rdd.RDD; - -/** - * The Class NaiveBayesTrainer. - */ -@Algorithm(name = "spark_naive_bayes", description = "Spark Naive Bayes classifier trainer") -public class NaiveBayesTrainer extends BaseSparkTrainer { - - /** The lambda. */ - @TrainerParam(name = "lambda", help = "Lambda parameter for naive bayes learner", defaultValue = "1.0d") - private double lambda = 1.0; - - /** - * Instantiates a new naive bayes trainer. - * - * @param name the name - * @param description the description - */ - public NaiveBayesTrainer(String name, String description) { - super(name, description); - } - - /* - * (non-Javadoc) - * - * @see org.apache.lens.ml.spark.trainers.BaseSparkTrainer#parseTrainerParams(java.util.Map) - */ - @Override - public void parseTrainerParams(Map<String, String> params) { - lambda = getParamValue("lambda", 1.0d); - } - - /* - * (non-Javadoc) - * - * @see org.apache.lens.ml.spark.trainers.BaseSparkTrainer#trainInternal(java.lang.String, org.apache.spark.rdd.RDD) - */ - @Override - protected BaseSparkClassificationModel trainInternal(String modelId, RDD<LabeledPoint> trainingRDD) - throws LensException { - return new NaiveBayesClassificationModel(modelId, NaiveBayes.train(trainingRDD, lambda)); - } -} http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/49fef8e2/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/trainers/SVMTrainer.java ---------------------------------------------------------------------- diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/trainers/SVMTrainer.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/trainers/SVMTrainer.java deleted file mode 100644 index cf7a7c9..0000000 --- a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/trainers/SVMTrainer.java +++ /dev/null @@ -1,90 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.lens.ml.spark.trainers; - -import java.util.Map; - -import org.apache.lens.api.LensException; -import org.apache.lens.ml.Algorithm; -import org.apache.lens.ml.TrainerParam; -import org.apache.lens.ml.spark.models.BaseSparkClassificationModel; -import org.apache.lens.ml.spark.models.SVMClassificationModel; - -import org.apache.spark.mllib.classification.SVMModel; -import org.apache.spark.mllib.classification.SVMWithSGD; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.rdd.RDD; - -/** - * The Class SVMTrainer. - */ -@Algorithm(name = "spark_svm", description = "Spark SVML classifier trainer") -public class SVMTrainer extends BaseSparkTrainer { - - /** The min batch fraction. */ - @TrainerParam(name = "minBatchFraction", help = "Fraction for batched learning", defaultValue = "1.0d") - private double minBatchFraction; - - /** The reg param. */ - @TrainerParam(name = "regParam", help = "regularization parameter for gradient descent", defaultValue = "1.0d") - private double regParam; - - /** The step size. */ - @TrainerParam(name = "stepSize", help = "Iteration step size", defaultValue = "1.0d") - private double stepSize; - - /** The iterations. */ - @TrainerParam(name = "iterations", help = "Number of iterations", defaultValue = "100") - private int iterations; - - /** - * Instantiates a new SVM trainer. - * - * @param name the name - * @param description the description - */ - public SVMTrainer(String name, String description) { - super(name, description); - } - - /* - * (non-Javadoc) - * - * @see org.apache.lens.ml.spark.trainers.BaseSparkTrainer#parseTrainerParams(java.util.Map) - */ - @Override - public void parseTrainerParams(Map<String, String> params) { - minBatchFraction = getParamValue("minBatchFraction", 1.0); - regParam = getParamValue("regParam", 1.0); - stepSize = getParamValue("stepSize", 1.0); - iterations = getParamValue("iterations", 100); - } - - /* - * (non-Javadoc) - * - * @see org.apache.lens.ml.spark.trainers.BaseSparkTrainer#trainInternal(java.lang.String, org.apache.spark.rdd.RDD) - */ - @Override - protected BaseSparkClassificationModel trainInternal(String modelId, RDD<LabeledPoint> trainingRDD) - throws LensException { - SVMModel svmModel = SVMWithSGD.train(trainingRDD, iterations, stepSize, regParam, minBatchFraction); - return new SVMClassificationModel(modelId, svmModel); - } -} http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/49fef8e2/lens-ml-lib/src/main/java/org/apache/lens/ml/task/MLTask.java ---------------------------------------------------------------------- diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/task/MLTask.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/task/MLTask.java index e413808..aa59100 100644 --- a/lens-ml-lib/src/main/java/org/apache/lens/ml/task/MLTask.java +++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/task/MLTask.java @@ -49,7 +49,7 @@ public class MLTask implements Runnable { private State taskState; /** - * Name of the trainer/algorithm. + * Name of the algo/algorithm. */ @Getter private String algorithm; @@ -253,10 +253,10 @@ public class MLTask implements Runnable { LOG.info("Working in Lens server"); } - String[] trainerArgs = buildTrainingArgs(); - LOG.info("Starting task " + taskID + " trainer args: " + Arrays.toString(trainerArgs)); + String[] algoArgs = buildTrainingArgs(); + LOG.info("Starting task " + taskID + " algo args: " + Arrays.toString(algoArgs)); - modelID = ml.train(trainingTable, algorithm, trainerArgs); + modelID = ml.train(trainingTable, algorithm, algoArgs); printModelMetadata(taskID, modelID); LOG.info("Starting test " + taskID); http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/49fef8e2/lens-ml-lib/src/main/java/org/apache/lens/server/ml/MLServiceImpl.java ---------------------------------------------------------------------- diff --git a/lens-ml-lib/src/main/java/org/apache/lens/server/ml/MLServiceImpl.java b/lens-ml-lib/src/main/java/org/apache/lens/server/ml/MLServiceImpl.java index d34d77b..9eb2723 100644 --- a/lens-ml-lib/src/main/java/org/apache/lens/server/ml/MLServiceImpl.java +++ b/lens-ml-lib/src/main/java/org/apache/lens/server/ml/MLServiceImpl.java @@ -80,11 +80,11 @@ public class MLServiceImpl extends CompositeService implements MLService { /* * (non-Javadoc) * - * @see org.apache.lens.ml.LensML#getTrainerForName(java.lang.String) + * @see org.apache.lens.ml.LensML#getAlgoForName(java.lang.String) */ @Override - public MLTrainer getTrainerForName(String algorithm) throws LensException { - return ml.getTrainerForName(algorithm); + public MLAlgo getAlgoForName(String algorithm) throws LensException { + return ml.getAlgoForName(algorithm); } /* http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/49fef8e2/lens-ml-lib/src/main/java/org/apache/lens/server/ml/MLServiceResource.java ---------------------------------------------------------------------- diff --git a/lens-ml-lib/src/main/java/org/apache/lens/server/ml/MLServiceResource.java b/lens-ml-lib/src/main/java/org/apache/lens/server/ml/MLServiceResource.java index 992e610..c0b32d3 100644 --- a/lens-ml-lib/src/main/java/org/apache/lens/server/ml/MLServiceResource.java +++ b/lens-ml-lib/src/main/java/org/apache/lens/server/ml/MLServiceResource.java @@ -129,15 +129,15 @@ public class MLServiceResource { } /** - * Get a list of trainers available + * Get a list of algos available * * @return */ @GET - @Path("trainers") - public StringList getTrainerNames() { - List<String> trainers = getMlService().getAlgorithms(); - StringList result = new StringList(trainers); + @Path("algos") + public StringList getAlgoNames() { + List<String> algos = getMlService().getAlgorithms(); + StringList result = new StringList(algos); return result; } @@ -148,7 +148,7 @@ public class MLServiceResource { * @return the param description */ @GET - @Path("trainers/{algorithm}") + @Path("algos/{algorithm}") public StringList getParamDescription(@PathParam("algorithm") String algorithm) { Map<String, String> paramDesc = getMlService().getAlgoParamDescription(algorithm); if (paramDesc == null) { @@ -196,7 +196,7 @@ public class MLServiceResource { throw new NotFoundException("Model not found " + modelID + ", algo=" + algorithm); } - ModelMetadata meta = new ModelMetadata(model.getId(), model.getTable(), model.getTrainerName(), StringUtils.join( + ModelMetadata meta = new ModelMetadata(model.getId(), model.getTable(), model.getAlgoName(), StringUtils.join( model.getParams(), ' '), model.getCreatedAt().toString(), getMlService().getModelPath(algorithm, modelID), model.getLabelColumn(), StringUtils.join(model.getFeatureColumns(), ",")); return meta; @@ -243,9 +243,9 @@ public class MLServiceResource { public String train(@PathParam("algorithm") String algorithm, MultivaluedMap<String, String> form) throws LensException { - // Check if trainer is valid - if (getMlService().getTrainerForName(algorithm) == null) { - throw new NotFoundException("Trainer for algo: " + algorithm + " not found"); + // Check if algo is valid + if (getMlService().getAlgoForName(algorithm) == null) { + throw new NotFoundException("Algo for algo: " + algorithm + " not found"); } if (isBlank(form.getFirst("table"))) { @@ -264,7 +264,7 @@ public class MLServiceResource { throw new BadRequestException("At least one feature is required"); } - List<String> trainerArgs = new ArrayList<String>(); + List<String> algoArgs = new ArrayList<String>(); Set<Map.Entry<String, List<String>>> paramSet = form.entrySet(); for (Map.Entry<String, List<String>> e : paramSet) { @@ -274,19 +274,19 @@ public class MLServiceResource { continue; } else if ("feature".equals(p)) { for (String feature : values) { - trainerArgs.add("feature"); - trainerArgs.add(feature); + algoArgs.add("feature"); + algoArgs.add(feature); } } else if ("label".equals(p)) { - trainerArgs.add("label"); - trainerArgs.add(values.get(0)); + algoArgs.add("label"); + algoArgs.add(values.get(0)); } else { - trainerArgs.add(p); - trainerArgs.add(values.get(0)); + algoArgs.add(p); + algoArgs.add(values.get(0)); } } - LOG.info("Training table " + table + " with algo " + algorithm + " params=" + trainerArgs.toString()); - String modelId = getMlService().train(table, algorithm, trainerArgs.toArray(new String[]{})); + LOG.info("Training table " + table + " with algo " + algorithm + " params=" + algoArgs.toString()); + String modelId = getMlService().train(table, algorithm, algoArgs.toArray(new String[]{})); LOG.info("Done training " + table + " modelid = " + modelId); return modelId; } http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/49fef8e2/lens-ml-lib/src/test/java/org/apache/lens/ml/TestMLResource.java ---------------------------------------------------------------------- diff --git a/lens-ml-lib/src/test/java/org/apache/lens/ml/TestMLResource.java b/lens-ml-lib/src/test/java/org/apache/lens/ml/TestMLResource.java index 7548ed2..1d40b76 100644 --- a/lens-ml-lib/src/test/java/org/apache/lens/ml/TestMLResource.java +++ b/lens-ml-lib/src/test/java/org/apache/lens/ml/TestMLResource.java @@ -28,10 +28,10 @@ import javax.ws.rs.core.Application; import org.apache.lens.api.LensSessionHandle; import org.apache.lens.client.LensConnectionParams; import org.apache.lens.client.LensMLClient; -import org.apache.lens.ml.spark.trainers.DecisionTreeTrainer; -import org.apache.lens.ml.spark.trainers.LogisticRegressionTrainer; -import org.apache.lens.ml.spark.trainers.NaiveBayesTrainer; -import org.apache.lens.ml.spark.trainers.SVMTrainer; +import org.apache.lens.ml.spark.algos.DecisionTreeAlgo; +import org.apache.lens.ml.spark.algos.LogisticRegressionAlgo; +import org.apache.lens.ml.spark.algos.NaiveBayesAlgo; +import org.apache.lens.ml.spark.algos.SVMAlgo; import org.apache.lens.ml.task.MLTask; import org.apache.lens.server.LensJerseyTest; import org.apache.lens.server.LensServerConf; @@ -54,6 +54,7 @@ import org.apache.hadoop.hive.ql.metadata.Hive; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.metadata.Partition; import org.apache.hadoop.hive.ql.metadata.Table; + import org.apache.hive.service.Service; import org.glassfish.jersey.client.ClientConfig; @@ -137,26 +138,26 @@ public class TestMLResource extends LensJerseyTest { } @Test - public void testGetTrainers() throws Exception { - List<String> trainerNames = mlClient.getAlgorithms(); - Assert.assertNotNull(trainerNames); + public void testGetAlgos() throws Exception { + List<String> algoNames = mlClient.getAlgorithms(); + Assert.assertNotNull(algoNames); - Assert.assertTrue(trainerNames.contains(MLUtils.getTrainerName(NaiveBayesTrainer.class)), - MLUtils.getTrainerName(NaiveBayesTrainer.class)); + Assert.assertTrue(algoNames.contains(MLUtils.getAlgoName(NaiveBayesAlgo.class)), + MLUtils.getAlgoName(NaiveBayesAlgo.class)); - Assert.assertTrue(trainerNames.contains(MLUtils.getTrainerName(SVMTrainer.class)), - MLUtils.getTrainerName(SVMTrainer.class)); + Assert.assertTrue(algoNames.contains(MLUtils.getAlgoName(SVMAlgo.class)), + MLUtils.getAlgoName(SVMAlgo.class)); - Assert.assertTrue(trainerNames.contains(MLUtils.getTrainerName(LogisticRegressionTrainer.class)), - MLUtils.getTrainerName(LogisticRegressionTrainer.class)); + Assert.assertTrue(algoNames.contains(MLUtils.getAlgoName(LogisticRegressionAlgo.class)), + MLUtils.getAlgoName(LogisticRegressionAlgo.class)); - Assert.assertTrue(trainerNames.contains(MLUtils.getTrainerName(DecisionTreeTrainer.class)), - MLUtils.getTrainerName(DecisionTreeTrainer.class)); + Assert.assertTrue(algoNames.contains(MLUtils.getAlgoName(DecisionTreeAlgo.class)), + MLUtils.getAlgoName(DecisionTreeAlgo.class)); } @Test - public void testGetTrainerParams() throws Exception { - Map<String, String> params = mlClient.getAlgoParamDescription(MLUtils.getTrainerName(DecisionTreeTrainer.class)); + public void testGetAlgoParams() throws Exception { + Map<String, String> params = mlClient.getAlgoParamDescription(MLUtils.getAlgoName(DecisionTreeAlgo.class)); Assert.assertNotNull(params); Assert.assertFalse(params.isEmpty()); @@ -168,7 +169,7 @@ public class TestMLResource extends LensJerseyTest { @Test public void trainAndEval() throws Exception { LOG.info("Starting train & eval"); - final String algoName = MLUtils.getTrainerName(NaiveBayesTrainer.class); + final String algoName = MLUtils.getAlgoName(NaiveBayesAlgo.class); HiveConf conf = new HiveConf(); String database = "default"; String tableName = "naivebayes_training_table";