You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lens.apache.org by am...@apache.org on 2015/04/15 21:47:43 UTC
[19/30] incubator-lens git commit: LENS-319: Renamed Trainer to Algo
(sharad)
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/49fef8e2/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/trainers/KMeansTrainer.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/trainers/KMeansTrainer.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/trainers/KMeansTrainer.java
deleted file mode 100644
index e4ad34e..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/trainers/KMeansTrainer.java
+++ /dev/null
@@ -1,163 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml.spark.trainers;
-
-import java.util.List;
-
-import org.apache.lens.api.LensConf;
-import org.apache.lens.api.LensException;
-import org.apache.lens.ml.*;
-import org.apache.lens.ml.spark.HiveTableRDD;
-import org.apache.lens.ml.spark.models.KMeansClusteringModel;
-
-import org.apache.hadoop.hive.conf.HiveConf;
-import org.apache.hadoop.hive.metastore.api.FieldSchema;
-import org.apache.hadoop.hive.ql.metadata.Hive;
-import org.apache.hadoop.hive.ql.metadata.Table;
-import org.apache.hadoop.io.WritableComparable;
-import org.apache.hive.hcatalog.data.HCatRecord;
-import org.apache.spark.api.java.JavaPairRDD;
-import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.api.java.function.Function;
-import org.apache.spark.mllib.clustering.KMeans;
-import org.apache.spark.mllib.clustering.KMeansModel;
-import org.apache.spark.mllib.linalg.Vector;
-import org.apache.spark.mllib.linalg.Vectors;
-
-import scala.Tuple2;
-
-/**
- * The Class KMeansTrainer.
- */
-@Algorithm(name = "spark_kmeans_trainer", description = "Spark MLLib KMeans trainer")
-public class KMeansTrainer implements MLTrainer {
-
- /** The conf. */
- private transient LensConf conf;
-
- /** The spark context. */
- private JavaSparkContext sparkContext;
-
- /** The part filter. */
- @TrainerParam(name = "partition", help = "Partition filter to be used while constructing table RDD")
- private String partFilter = null;
-
- /** The k. */
- @TrainerParam(name = "k", help = "Number of cluster")
- private int k;
-
- /** The max iterations. */
- @TrainerParam(name = "maxIterations", help = "Maximum number of iterations", defaultValue = "100")
- private int maxIterations = 100;
-
- /** The runs. */
- @TrainerParam(name = "runs", help = "Number of parallel run", defaultValue = "1")
- private int runs = 1;
-
- /** The initialization mode. */
- @TrainerParam(name = "initializationMode",
- help = "initialization model, either \"random\" or \"k-means||\" (default).", defaultValue = "k-means||")
- private String initializationMode = "k-means||";
-
- @Override
- public String getName() {
- return getClass().getAnnotation(Algorithm.class).name();
- }
-
- @Override
- public String getDescription() {
- return getClass().getAnnotation(Algorithm.class).description();
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.MLTrainer#configure(org.apache.lens.api.LensConf)
- */
- @Override
- public void configure(LensConf configuration) {
- this.conf = configuration;
- }
-
- @Override
- public LensConf getConf() {
- return conf;
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.MLTrainer#train(org.apache.lens.api.LensConf, java.lang.String, java.lang.String,
- * java.lang.String, java.lang.String[])
- */
- @Override
- public MLModel train(LensConf conf, String db, String table, String modelId, String... params) throws LensException {
- List<String> features = TrainerArgParser.parseArgs(this, params);
- final int[] featurePositions = new int[features.size()];
- final int NUM_FEATURES = features.size();
-
- JavaPairRDD<WritableComparable, HCatRecord> rdd = null;
- try {
- // Map feature names to positions
- Table tbl = Hive.get(toHiveConf(conf)).getTable(db, table);
- List<FieldSchema> allCols = tbl.getAllCols();
- int f = 0;
- for (int i = 0; i < tbl.getAllCols().size(); i++) {
- String colName = allCols.get(i).getName();
- if (features.contains(colName)) {
- featurePositions[f++] = i;
- }
- }
-
- rdd = HiveTableRDD.createHiveTableRDD(sparkContext, toHiveConf(conf), db, table, partFilter);
- JavaRDD<Vector> trainableRDD = rdd.map(new Function<Tuple2<WritableComparable, HCatRecord>, Vector>() {
- @Override
- public Vector call(Tuple2<WritableComparable, HCatRecord> v1) throws Exception {
- HCatRecord hCatRecord = v1._2();
- double[] arr = new double[NUM_FEATURES];
- for (int i = 0; i < NUM_FEATURES; i++) {
- Object val = hCatRecord.get(featurePositions[i]);
- arr[i] = val == null ? 0d : (Double) val;
- }
- return Vectors.dense(arr);
- }
- });
-
- KMeansModel model = KMeans.train(trainableRDD.rdd(), k, maxIterations, runs, initializationMode);
- return new KMeansClusteringModel(modelId, model);
- } catch (Exception e) {
- throw new LensException("KMeans trainer failed for " + db + "." + table, e);
- }
- }
-
- /**
- * To hive conf.
- *
- * @param conf the conf
- * @return the hive conf
- */
- private HiveConf toHiveConf(LensConf conf) {
- HiveConf hiveConf = new HiveConf();
- for (String key : conf.getProperties().keySet()) {
- hiveConf.set(key, conf.getProperties().get(key));
- }
- return hiveConf;
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/49fef8e2/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/trainers/LogisticRegressionTrainer.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/trainers/LogisticRegressionTrainer.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/trainers/LogisticRegressionTrainer.java
deleted file mode 100644
index b12e2be..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/trainers/LogisticRegressionTrainer.java
+++ /dev/null
@@ -1,86 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml.spark.trainers;
-
-import java.util.Map;
-
-import org.apache.lens.api.LensException;
-import org.apache.lens.ml.Algorithm;
-import org.apache.lens.ml.TrainerParam;
-import org.apache.lens.ml.spark.models.BaseSparkClassificationModel;
-import org.apache.lens.ml.spark.models.LogitRegressionClassificationModel;
-
-import org.apache.spark.mllib.classification.LogisticRegressionModel;
-import org.apache.spark.mllib.classification.LogisticRegressionWithSGD;
-import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.rdd.RDD;
-
-/**
- * The Class LogisticRegressionTrainer.
- */
-@Algorithm(name = "spark_logistic_regression", description = "Spark logistic regression trainer")
-public class LogisticRegressionTrainer extends BaseSparkTrainer {
-
- /** The iterations. */
- @TrainerParam(name = "iterations", help = "Max number of iterations", defaultValue = "100")
- private int iterations;
-
- /** The step size. */
- @TrainerParam(name = "stepSize", help = "Step size", defaultValue = "1.0d")
- private double stepSize;
-
- /** The min batch fraction. */
- @TrainerParam(name = "minBatchFraction", help = "Fraction for batched learning", defaultValue = "1.0d")
- private double minBatchFraction;
-
- /**
- * Instantiates a new logistic regression trainer.
- *
- * @param name the name
- * @param description the description
- */
- public LogisticRegressionTrainer(String name, String description) {
- super(name, description);
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.spark.trainers.BaseSparkTrainer#parseTrainerParams(java.util.Map)
- */
- @Override
- public void parseTrainerParams(Map<String, String> params) {
- iterations = getParamValue("iterations", 100);
- stepSize = getParamValue("stepSize", 1.0d);
- minBatchFraction = getParamValue("minBatchFraction", 1.0d);
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.spark.trainers.BaseSparkTrainer#trainInternal(java.lang.String, org.apache.spark.rdd.RDD)
- */
- @Override
- protected BaseSparkClassificationModel trainInternal(String modelId, RDD<LabeledPoint> trainingRDD)
- throws LensException {
- LogisticRegressionModel lrModel = LogisticRegressionWithSGD.train(trainingRDD, iterations, stepSize,
- minBatchFraction);
- return new LogitRegressionClassificationModel(modelId, lrModel);
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/49fef8e2/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/trainers/NaiveBayesTrainer.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/trainers/NaiveBayesTrainer.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/trainers/NaiveBayesTrainer.java
deleted file mode 100644
index 4eb50c9..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/trainers/NaiveBayesTrainer.java
+++ /dev/null
@@ -1,73 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml.spark.trainers;
-
-import java.util.Map;
-
-import org.apache.lens.api.LensException;
-import org.apache.lens.ml.Algorithm;
-import org.apache.lens.ml.TrainerParam;
-import org.apache.lens.ml.spark.models.BaseSparkClassificationModel;
-import org.apache.lens.ml.spark.models.NaiveBayesClassificationModel;
-
-import org.apache.spark.mllib.classification.NaiveBayes;
-import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.rdd.RDD;
-
-/**
- * The Class NaiveBayesTrainer.
- */
-@Algorithm(name = "spark_naive_bayes", description = "Spark Naive Bayes classifier trainer")
-public class NaiveBayesTrainer extends BaseSparkTrainer {
-
- /** The lambda. */
- @TrainerParam(name = "lambda", help = "Lambda parameter for naive bayes learner", defaultValue = "1.0d")
- private double lambda = 1.0;
-
- /**
- * Instantiates a new naive bayes trainer.
- *
- * @param name the name
- * @param description the description
- */
- public NaiveBayesTrainer(String name, String description) {
- super(name, description);
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.spark.trainers.BaseSparkTrainer#parseTrainerParams(java.util.Map)
- */
- @Override
- public void parseTrainerParams(Map<String, String> params) {
- lambda = getParamValue("lambda", 1.0d);
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.spark.trainers.BaseSparkTrainer#trainInternal(java.lang.String, org.apache.spark.rdd.RDD)
- */
- @Override
- protected BaseSparkClassificationModel trainInternal(String modelId, RDD<LabeledPoint> trainingRDD)
- throws LensException {
- return new NaiveBayesClassificationModel(modelId, NaiveBayes.train(trainingRDD, lambda));
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/49fef8e2/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/trainers/SVMTrainer.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/trainers/SVMTrainer.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/trainers/SVMTrainer.java
deleted file mode 100644
index cf7a7c9..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/trainers/SVMTrainer.java
+++ /dev/null
@@ -1,90 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml.spark.trainers;
-
-import java.util.Map;
-
-import org.apache.lens.api.LensException;
-import org.apache.lens.ml.Algorithm;
-import org.apache.lens.ml.TrainerParam;
-import org.apache.lens.ml.spark.models.BaseSparkClassificationModel;
-import org.apache.lens.ml.spark.models.SVMClassificationModel;
-
-import org.apache.spark.mllib.classification.SVMModel;
-import org.apache.spark.mllib.classification.SVMWithSGD;
-import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.rdd.RDD;
-
-/**
- * The Class SVMTrainer.
- */
-@Algorithm(name = "spark_svm", description = "Spark SVML classifier trainer")
-public class SVMTrainer extends BaseSparkTrainer {
-
- /** The min batch fraction. */
- @TrainerParam(name = "minBatchFraction", help = "Fraction for batched learning", defaultValue = "1.0d")
- private double minBatchFraction;
-
- /** The reg param. */
- @TrainerParam(name = "regParam", help = "regularization parameter for gradient descent", defaultValue = "1.0d")
- private double regParam;
-
- /** The step size. */
- @TrainerParam(name = "stepSize", help = "Iteration step size", defaultValue = "1.0d")
- private double stepSize;
-
- /** The iterations. */
- @TrainerParam(name = "iterations", help = "Number of iterations", defaultValue = "100")
- private int iterations;
-
- /**
- * Instantiates a new SVM trainer.
- *
- * @param name the name
- * @param description the description
- */
- public SVMTrainer(String name, String description) {
- super(name, description);
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.spark.trainers.BaseSparkTrainer#parseTrainerParams(java.util.Map)
- */
- @Override
- public void parseTrainerParams(Map<String, String> params) {
- minBatchFraction = getParamValue("minBatchFraction", 1.0);
- regParam = getParamValue("regParam", 1.0);
- stepSize = getParamValue("stepSize", 1.0);
- iterations = getParamValue("iterations", 100);
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.spark.trainers.BaseSparkTrainer#trainInternal(java.lang.String, org.apache.spark.rdd.RDD)
- */
- @Override
- protected BaseSparkClassificationModel trainInternal(String modelId, RDD<LabeledPoint> trainingRDD)
- throws LensException {
- SVMModel svmModel = SVMWithSGD.train(trainingRDD, iterations, stepSize, regParam, minBatchFraction);
- return new SVMClassificationModel(modelId, svmModel);
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/49fef8e2/lens-ml-lib/src/main/java/org/apache/lens/ml/task/MLTask.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/task/MLTask.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/task/MLTask.java
index e413808..aa59100 100644
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/task/MLTask.java
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/task/MLTask.java
@@ -49,7 +49,7 @@ public class MLTask implements Runnable {
private State taskState;
/**
- * Name of the trainer/algorithm.
+ * Name of the algorithm.
*/
@Getter
private String algorithm;
@@ -253,10 +253,10 @@ public class MLTask implements Runnable {
LOG.info("Working in Lens server");
}
- String[] trainerArgs = buildTrainingArgs();
- LOG.info("Starting task " + taskID + " trainer args: " + Arrays.toString(trainerArgs));
+ String[] algoArgs = buildTrainingArgs();
+ LOG.info("Starting task " + taskID + " algo args: " + Arrays.toString(algoArgs));
- modelID = ml.train(trainingTable, algorithm, trainerArgs);
+ modelID = ml.train(trainingTable, algorithm, algoArgs);
printModelMetadata(taskID, modelID);
LOG.info("Starting test " + taskID);
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/49fef8e2/lens-ml-lib/src/main/java/org/apache/lens/server/ml/MLServiceImpl.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/server/ml/MLServiceImpl.java b/lens-ml-lib/src/main/java/org/apache/lens/server/ml/MLServiceImpl.java
index d34d77b..9eb2723 100644
--- a/lens-ml-lib/src/main/java/org/apache/lens/server/ml/MLServiceImpl.java
+++ b/lens-ml-lib/src/main/java/org/apache/lens/server/ml/MLServiceImpl.java
@@ -80,11 +80,11 @@ public class MLServiceImpl extends CompositeService implements MLService {
/*
* (non-Javadoc)
*
- * @see org.apache.lens.ml.LensML#getTrainerForName(java.lang.String)
+ * @see org.apache.lens.ml.LensML#getAlgoForName(java.lang.String)
*/
@Override
- public MLTrainer getTrainerForName(String algorithm) throws LensException {
- return ml.getTrainerForName(algorithm);
+ public MLAlgo getAlgoForName(String algorithm) throws LensException {
+ return ml.getAlgoForName(algorithm);
}
/*
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/49fef8e2/lens-ml-lib/src/main/java/org/apache/lens/server/ml/MLServiceResource.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/server/ml/MLServiceResource.java b/lens-ml-lib/src/main/java/org/apache/lens/server/ml/MLServiceResource.java
index 992e610..c0b32d3 100644
--- a/lens-ml-lib/src/main/java/org/apache/lens/server/ml/MLServiceResource.java
+++ b/lens-ml-lib/src/main/java/org/apache/lens/server/ml/MLServiceResource.java
@@ -129,15 +129,15 @@ public class MLServiceResource {
}
/**
- * Get a list of trainers available
+ * Get the list of available algos.
*
* @return
*/
@GET
- @Path("trainers")
- public StringList getTrainerNames() {
- List<String> trainers = getMlService().getAlgorithms();
- StringList result = new StringList(trainers);
+ @Path("algos")
+ public StringList getAlgoNames() {
+ List<String> algos = getMlService().getAlgorithms();
+ StringList result = new StringList(algos);
return result;
}
@@ -148,7 +148,7 @@ public class MLServiceResource {
* @return the param description
*/
@GET
- @Path("trainers/{algorithm}")
+ @Path("algos/{algorithm}")
public StringList getParamDescription(@PathParam("algorithm") String algorithm) {
Map<String, String> paramDesc = getMlService().getAlgoParamDescription(algorithm);
if (paramDesc == null) {
@@ -196,7 +196,7 @@ public class MLServiceResource {
throw new NotFoundException("Model not found " + modelID + ", algo=" + algorithm);
}
- ModelMetadata meta = new ModelMetadata(model.getId(), model.getTable(), model.getTrainerName(), StringUtils.join(
+ ModelMetadata meta = new ModelMetadata(model.getId(), model.getTable(), model.getAlgoName(), StringUtils.join(
model.getParams(), ' '), model.getCreatedAt().toString(), getMlService().getModelPath(algorithm, modelID),
model.getLabelColumn(), StringUtils.join(model.getFeatureColumns(), ","));
return meta;
@@ -243,9 +243,9 @@ public class MLServiceResource {
public String train(@PathParam("algorithm") String algorithm, MultivaluedMap<String, String> form)
throws LensException {
- // Check if trainer is valid
- if (getMlService().getTrainerForName(algorithm) == null) {
- throw new NotFoundException("Trainer for algo: " + algorithm + " not found");
+ // Check if algo is valid
+ if (getMlService().getAlgoForName(algorithm) == null) {
+ throw new NotFoundException("Algo for algo: " + algorithm + " not found");
}
if (isBlank(form.getFirst("table"))) {
@@ -264,7 +264,7 @@ public class MLServiceResource {
throw new BadRequestException("At least one feature is required");
}
- List<String> trainerArgs = new ArrayList<String>();
+ List<String> algoArgs = new ArrayList<String>();
Set<Map.Entry<String, List<String>>> paramSet = form.entrySet();
for (Map.Entry<String, List<String>> e : paramSet) {
@@ -274,19 +274,19 @@ public class MLServiceResource {
continue;
} else if ("feature".equals(p)) {
for (String feature : values) {
- trainerArgs.add("feature");
- trainerArgs.add(feature);
+ algoArgs.add("feature");
+ algoArgs.add(feature);
}
} else if ("label".equals(p)) {
- trainerArgs.add("label");
- trainerArgs.add(values.get(0));
+ algoArgs.add("label");
+ algoArgs.add(values.get(0));
} else {
- trainerArgs.add(p);
- trainerArgs.add(values.get(0));
+ algoArgs.add(p);
+ algoArgs.add(values.get(0));
}
}
- LOG.info("Training table " + table + " with algo " + algorithm + " params=" + trainerArgs.toString());
- String modelId = getMlService().train(table, algorithm, trainerArgs.toArray(new String[]{}));
+ LOG.info("Training table " + table + " with algo " + algorithm + " params=" + algoArgs.toString());
+ String modelId = getMlService().train(table, algorithm, algoArgs.toArray(new String[]{}));
LOG.info("Done training " + table + " modelid = " + modelId);
return modelId;
}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/49fef8e2/lens-ml-lib/src/test/java/org/apache/lens/ml/TestMLResource.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/test/java/org/apache/lens/ml/TestMLResource.java b/lens-ml-lib/src/test/java/org/apache/lens/ml/TestMLResource.java
index 7548ed2..1d40b76 100644
--- a/lens-ml-lib/src/test/java/org/apache/lens/ml/TestMLResource.java
+++ b/lens-ml-lib/src/test/java/org/apache/lens/ml/TestMLResource.java
@@ -28,10 +28,10 @@ import javax.ws.rs.core.Application;
import org.apache.lens.api.LensSessionHandle;
import org.apache.lens.client.LensConnectionParams;
import org.apache.lens.client.LensMLClient;
-import org.apache.lens.ml.spark.trainers.DecisionTreeTrainer;
-import org.apache.lens.ml.spark.trainers.LogisticRegressionTrainer;
-import org.apache.lens.ml.spark.trainers.NaiveBayesTrainer;
-import org.apache.lens.ml.spark.trainers.SVMTrainer;
+import org.apache.lens.ml.spark.algos.DecisionTreeAlgo;
+import org.apache.lens.ml.spark.algos.LogisticRegressionAlgo;
+import org.apache.lens.ml.spark.algos.NaiveBayesAlgo;
+import org.apache.lens.ml.spark.algos.SVMAlgo;
import org.apache.lens.ml.task.MLTask;
import org.apache.lens.server.LensJerseyTest;
import org.apache.lens.server.LensServerConf;
@@ -54,6 +54,7 @@ import org.apache.hadoop.hive.ql.metadata.Hive;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.metadata.Partition;
import org.apache.hadoop.hive.ql.metadata.Table;
+
import org.apache.hive.service.Service;
import org.glassfish.jersey.client.ClientConfig;
@@ -137,26 +138,26 @@ public class TestMLResource extends LensJerseyTest {
}
@Test
- public void testGetTrainers() throws Exception {
- List<String> trainerNames = mlClient.getAlgorithms();
- Assert.assertNotNull(trainerNames);
+ public void testGetAlgos() throws Exception {
+ List<String> algoNames = mlClient.getAlgorithms();
+ Assert.assertNotNull(algoNames);
- Assert.assertTrue(trainerNames.contains(MLUtils.getTrainerName(NaiveBayesTrainer.class)),
- MLUtils.getTrainerName(NaiveBayesTrainer.class));
+ Assert.assertTrue(algoNames.contains(MLUtils.getAlgoName(NaiveBayesAlgo.class)),
+ MLUtils.getAlgoName(NaiveBayesAlgo.class));
- Assert.assertTrue(trainerNames.contains(MLUtils.getTrainerName(SVMTrainer.class)),
- MLUtils.getTrainerName(SVMTrainer.class));
+ Assert.assertTrue(algoNames.contains(MLUtils.getAlgoName(SVMAlgo.class)),
+ MLUtils.getAlgoName(SVMAlgo.class));
- Assert.assertTrue(trainerNames.contains(MLUtils.getTrainerName(LogisticRegressionTrainer.class)),
- MLUtils.getTrainerName(LogisticRegressionTrainer.class));
+ Assert.assertTrue(algoNames.contains(MLUtils.getAlgoName(LogisticRegressionAlgo.class)),
+ MLUtils.getAlgoName(LogisticRegressionAlgo.class));
- Assert.assertTrue(trainerNames.contains(MLUtils.getTrainerName(DecisionTreeTrainer.class)),
- MLUtils.getTrainerName(DecisionTreeTrainer.class));
+ Assert.assertTrue(algoNames.contains(MLUtils.getAlgoName(DecisionTreeAlgo.class)),
+ MLUtils.getAlgoName(DecisionTreeAlgo.class));
}
@Test
- public void testGetTrainerParams() throws Exception {
- Map<String, String> params = mlClient.getAlgoParamDescription(MLUtils.getTrainerName(DecisionTreeTrainer.class));
+ public void testGetAlgoParams() throws Exception {
+ Map<String, String> params = mlClient.getAlgoParamDescription(MLUtils.getAlgoName(DecisionTreeAlgo.class));
Assert.assertNotNull(params);
Assert.assertFalse(params.isEmpty());
@@ -168,7 +169,7 @@ public class TestMLResource extends LensJerseyTest {
@Test
public void trainAndEval() throws Exception {
LOG.info("Starting train & eval");
- final String algoName = MLUtils.getTrainerName(NaiveBayesTrainer.class);
+ final String algoName = MLUtils.getAlgoName(NaiveBayesAlgo.class);
HiveConf conf = new HiveConf();
String database = "default";
String tableName = "naivebayes_training_table";