You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lens.apache.org by am...@apache.org on 2015/04/15 21:47:44 UTC
[20/30] incubator-lens git commit: LENS-319: Renamed Trainer to Algo (sharad)
LENS-319: Renamed Trainer to Algo (sharad)
Project: http://git-wip-us.apache.org/repos/asf/incubator-lens/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-lens/commit/49fef8e2
Tree: http://git-wip-us.apache.org/repos/asf/incubator-lens/tree/49fef8e2
Diff: http://git-wip-us.apache.org/repos/asf/incubator-lens/diff/49fef8e2
Branch: refs/heads/master
Commit: 49fef8e2bb3995ee7d5b3db74bf954c205c187fc
Parents: ff152d9
Author: Sharad Agarwal <sh...@flipkarts-MacBook-Pro.local>
Authored: Wed Feb 18 14:40:17 2015 +0530
Committer: Amareshwari Sriramadasu <am...@apache.org>
Committed: Thu Feb 19 10:23:12 2015 +0530
----------------------------------------------------------------------
.../org/apache/lens/client/LensMLClient.java | 14 +-
.../apache/lens/client/LensMLJerseyClient.java | 14 +-
.../java/org/apache/lens/ml/AlgoArgParser.java | 114 ++++++++
.../main/java/org/apache/lens/ml/AlgoParam.java | 53 ++++
.../java/org/apache/lens/ml/Algorithms.java | 32 +-
.../main/java/org/apache/lens/ml/LensML.java | 6 +-
.../java/org/apache/lens/ml/LensMLImpl.java | 66 ++---
.../main/java/org/apache/lens/ml/MLAlgo.java | 53 ++++
.../main/java/org/apache/lens/ml/MLDriver.java | 18 +-
.../main/java/org/apache/lens/ml/MLModel.java | 4 +-
.../main/java/org/apache/lens/ml/MLTrainer.java | 53 ----
.../main/java/org/apache/lens/ml/MLUtils.java | 6 +-
.../org/apache/lens/ml/TrainerArgParser.java | 114 --------
.../java/org/apache/lens/ml/TrainerParam.java | 53 ----
.../org/apache/lens/ml/spark/SparkMLDriver.java | 42 +--
.../lens/ml/spark/algos/BaseSparkAlgo.java | 290 +++++++++++++++++++
.../lens/ml/spark/algos/DecisionTreeAlgo.java | 109 +++++++
.../apache/lens/ml/spark/algos/KMeansAlgo.java | 163 +++++++++++
.../ml/spark/algos/LogisticRegressionAlgo.java | 86 ++++++
.../lens/ml/spark/algos/NaiveBayesAlgo.java | 73 +++++
.../org/apache/lens/ml/spark/algos/SVMAlgo.java | 90 ++++++
.../ml/spark/trainers/BaseSparkTrainer.java | 289 ------------------
.../ml/spark/trainers/DecisionTreeTrainer.java | 109 -------
.../lens/ml/spark/trainers/KMeansTrainer.java | 163 -----------
.../trainers/LogisticRegressionTrainer.java | 86 ------
.../ml/spark/trainers/NaiveBayesTrainer.java | 73 -----
.../lens/ml/spark/trainers/SVMTrainer.java | 90 ------
.../java/org/apache/lens/ml/task/MLTask.java | 8 +-
.../apache/lens/server/ml/MLServiceImpl.java | 6 +-
.../lens/server/ml/MLServiceResource.java | 38 +--
.../java/org/apache/lens/ml/TestMLResource.java | 37 +--
31 files changed, 1177 insertions(+), 1175 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/49fef8e2/lens-ml-lib/src/main/java/org/apache/lens/client/LensMLClient.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/client/LensMLClient.java b/lens-ml-lib/src/main/java/org/apache/lens/client/LensMLClient.java
index 1bad7a0..9f7fa26 100644
--- a/lens-ml-lib/src/main/java/org/apache/lens/client/LensMLClient.java
+++ b/lens-ml-lib/src/main/java/org/apache/lens/client/LensMLClient.java
@@ -35,9 +35,9 @@ import org.apache.lens.api.LensSessionHandle;
import org.apache.lens.api.ml.ModelMetadata;
import org.apache.lens.api.ml.TestReport;
import org.apache.lens.ml.LensML;
+import org.apache.lens.ml.MLAlgo;
import org.apache.lens.ml.MLModel;
import org.apache.lens.ml.MLTestReport;
-import org.apache.lens.ml.MLTrainer;
import org.apache.commons.lang.StringUtils;
import org.apache.commons.logging.Log;
@@ -81,7 +81,7 @@ public class LensMLClient implements LensML, Closeable {
*/
@Override
public List<String> getAlgorithms() {
- return client.getTrainerNames();
+ return client.getAlgoNames();
}
/**
@@ -92,7 +92,7 @@ public class LensMLClient implements LensML, Closeable {
*/
@Override
public Map<String, String> getAlgoParamDescription(String algorithm) {
- List<String> paramDesc = client.getParamDescriptionOfTrainer(algorithm);
+ List<String> paramDesc = client.getParamDescriptionOfAlgo(algorithm);
// convert paramDesc to map
Map<String, String> paramDescMap = new LinkedHashMap<String, String>();
for (String str : paramDesc) {
@@ -103,15 +103,15 @@ public class LensMLClient implements LensML, Closeable {
}
/**
- * Get a trainer object instance which could be used to generate a model of the given algorithm.
+ * Get an algo object instance which could be used to generate a model of the given algorithm.
*
* @param algorithm the algorithm
- * @return the trainer for name
+ * @return the algo for name
* @throws LensException the lens exception
*/
@Override
- public MLTrainer getTrainerForName(String algorithm) throws LensException {
- throw new UnsupportedOperationException("MLTrainer cannot be accessed from client");
+ public MLAlgo getAlgoForName(String algorithm) throws LensException {
+ throw new UnsupportedOperationException("MLAlgo cannot be accessed from client");
}
/**
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/49fef8e2/lens-ml-lib/src/main/java/org/apache/lens/client/LensMLJerseyClient.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/client/LensMLJerseyClient.java b/lens-ml-lib/src/main/java/org/apache/lens/client/LensMLJerseyClient.java
index 2b1ece4..af47a41 100644
--- a/lens-ml-lib/src/main/java/org/apache/lens/client/LensMLJerseyClient.java
+++ b/lens-ml-lib/src/main/java/org/apache/lens/client/LensMLJerseyClient.java
@@ -134,9 +134,9 @@ public class LensMLJerseyClient {
}
}
- public List<String> getTrainerNames() {
- StringList trainerNames = getMLWebTarget().path("trainers").request().get(StringList.class);
- return trainerNames == null ? null : trainerNames.getElements();
+ public List<String> getAlgoNames() {
+ StringList algoNames = getMLWebTarget().path("algos").request().get(StringList.class);
+ return algoNames == null ? null : algoNames.getElements();
}
/**
@@ -234,14 +234,14 @@ public class LensMLJerseyClient {
}
/**
- * Gets the param description of trainer.
+ * Gets the param description of algo.
*
* @param algorithm the algorithm
- * @return the param description of trainer
+ * @return the param description of algo
*/
- public List<String> getParamDescriptionOfTrainer(String algorithm) {
+ public List<String> getParamDescriptionOfAlgo(String algorithm) {
try {
- StringList paramHelp = getMLWebTarget().path("trainers").path(algorithm).request(MediaType.APPLICATION_XML)
+ StringList paramHelp = getMLWebTarget().path("algos").path(algorithm).request(MediaType.APPLICATION_XML)
.get(StringList.class);
return paramHelp.getElements();
} catch (NotFoundException exc) {
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/49fef8e2/lens-ml-lib/src/main/java/org/apache/lens/ml/AlgoArgParser.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/AlgoArgParser.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/AlgoArgParser.java
new file mode 100644
index 0000000..20da083
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/AlgoArgParser.java
@@ -0,0 +1,114 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml;
+
+import java.lang.reflect.Field;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+
+/**
+ * The Class AlgoArgParser.
+ */
+public final class AlgoArgParser {
+ private AlgoArgParser() {
+ }
+
+ /**
+ * The Class CustomArgParser.
+ *
+ * @param <E> the element type
+ */
+ public abstract static class CustomArgParser<E> {
+
+ /**
+ * Parses the given string value into an instance of the element type.
+ *
+ * @param value the value
+ * @return the parsed element
+ */
+ public abstract E parse(String value);
+ }
+
+ /** The Constant LOG. */
+ public static final Log LOG = LogFactory.getLog(AlgoArgParser.class);
+
+ /**
+ * Extracts feature names. If the algo has any parameters associated with @AlgoParam annotation, those are set
+ * as well.
+ *
+ * @param algo the algo
+ * @param args the args
+ * @return List of feature column names.
+ */
+ public static List<String> parseArgs(MLAlgo algo, String[] args) {
+ List<String> featureColumns = new ArrayList<String>();
+ Class<? extends MLAlgo> algoClass = algo.getClass();
+ // Get param fields
+ Map<String, Field> fieldMap = new HashMap<String, Field>();
+
+ for (Field fld : algoClass.getDeclaredFields()) {
+ fld.setAccessible(true);
+ AlgoParam paramAnnotation = fld.getAnnotation(AlgoParam.class);
+ if (paramAnnotation != null) {
+ fieldMap.put(paramAnnotation.name(), fld);
+ }
+ }
+
+ for (int i = 0; i < args.length; i += 2) {
+ String key = args[i].trim();
+ String value = args[i + 1].trim();
+
+ try {
+ if ("feature".equalsIgnoreCase(key)) {
+ featureColumns.add(value);
+ } else if (fieldMap.containsKey(key)) {
+ Field f = fieldMap.get(key);
+ if (String.class.equals(f.getType())) {
+ f.set(algo, value);
+ } else if (Integer.TYPE.equals(f.getType())) {
+ f.setInt(algo, Integer.parseInt(value));
+ } else if (Double.TYPE.equals(f.getType())) {
+ f.setDouble(algo, Double.parseDouble(value));
+ } else if (Long.TYPE.equals(f.getType())) {
+ f.setLong(algo, Long.parseLong(value));
+ } else {
+ // check if the algo provides a deserializer for this param
+ String customParserClass = algo.getConf().getProperties().get("lens.ml.args." + key);
+ if (customParserClass != null) {
+ Class<? extends CustomArgParser<?>> clz = (Class<? extends CustomArgParser<?>>) Class
+ .forName(customParserClass);
+ CustomArgParser<?> parser = clz.newInstance();
+ f.set(algo, parser.parse(value));
+ } else {
+ LOG.warn("Ignored param " + key + "=" + value + " as no parser found");
+ }
+ }
+ }
+ } catch (Exception exc) {
+ LOG.error("Error while setting param " + key + " to " + value + " for algo " + algo);
+ }
+ }
+ return featureColumns;
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/49fef8e2/lens-ml-lib/src/main/java/org/apache/lens/ml/AlgoParam.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/AlgoParam.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/AlgoParam.java
new file mode 100644
index 0000000..5836f51
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/AlgoParam.java
@@ -0,0 +1,53 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml;
+
+import java.lang.annotation.ElementType;
+import java.lang.annotation.Retention;
+import java.lang.annotation.RetentionPolicy;
+import java.lang.annotation.Target;
+
+/**
+ * The Interface AlgoParam.
+ */
+@Retention(RetentionPolicy.RUNTIME)
+@Target(ElementType.FIELD)
+public @interface AlgoParam {
+
+ /**
+ * Name.
+ *
+ * @return the string
+ */
+ String name();
+
+ /**
+ * Help.
+ *
+ * @return the string
+ */
+ String help();
+
+ /**
+ * Default value.
+ *
+ * @return the string
+ */
+ String defaultValue() default "None";
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/49fef8e2/lens-ml-lib/src/main/java/org/apache/lens/ml/Algorithms.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/Algorithms.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/Algorithms.java
index 3b74a09..c1b7212 100644
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/Algorithms.java
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/Algorithms.java
@@ -32,41 +32,41 @@ import org.apache.lens.api.LensException;
public class Algorithms {
/** The algorithm classes. */
- private final Map<String, Class<? extends MLTrainer>> algorithmClasses
- = new HashMap<String, Class<? extends MLTrainer>>();
+ private final Map<String, Class<? extends MLAlgo>> algorithmClasses
+ = new HashMap<String, Class<? extends MLAlgo>>();
/**
* Register.
*
- * @param trainerClass the trainer class
+ * @param algoClass the algo class
*/
- public void register(Class<? extends MLTrainer> trainerClass) {
- if (trainerClass != null && trainerClass.getAnnotation(Algorithm.class) != null) {
- algorithmClasses.put(trainerClass.getAnnotation(Algorithm.class).name(), trainerClass);
+ public void register(Class<? extends MLAlgo> algoClass) {
+ if (algoClass != null && algoClass.getAnnotation(Algorithm.class) != null) {
+ algorithmClasses.put(algoClass.getAnnotation(Algorithm.class).name(), algoClass);
} else {
- throw new IllegalArgumentException("Not a valid algorithm class: " + trainerClass);
+ throw new IllegalArgumentException("Not a valid algorithm class: " + algoClass);
}
}
/**
- * Gets the trainer for name.
+ * Gets the algo for name.
*
* @param name the name
- * @return the trainer for name
+ * @return the algo for name
* @throws LensException the lens exception
*/
- public MLTrainer getTrainerForName(String name) throws LensException {
- Class<? extends MLTrainer> trainerClass = algorithmClasses.get(name);
- if (trainerClass == null) {
+ public MLAlgo getAlgoForName(String name) throws LensException {
+ Class<? extends MLAlgo> algoClass = algorithmClasses.get(name);
+ if (algoClass == null) {
return null;
}
- Algorithm algoAnnotation = trainerClass.getAnnotation(Algorithm.class);
+ Algorithm algoAnnotation = algoClass.getAnnotation(Algorithm.class);
String description = algoAnnotation.description();
try {
- Constructor<? extends MLTrainer> trainerConstructor = trainerClass.getConstructor(String.class, String.class);
- return trainerConstructor.newInstance(name, description);
+ Constructor<? extends MLAlgo> algoConstructor = algoClass.getConstructor(String.class, String.class);
+ return algoConstructor.newInstance(name, description);
} catch (Exception exc) {
- throw new LensException("Unable to get trainer: " + name, exc);
+ throw new LensException("Unable to get algo: " + name, exc);
}
}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/49fef8e2/lens-ml-lib/src/main/java/org/apache/lens/ml/LensML.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/LensML.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/LensML.java
index 9a15ea0..fe65d2f 100644
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/LensML.java
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/LensML.java
@@ -48,13 +48,13 @@ public interface LensML {
Map<String, String> getAlgoParamDescription(String algorithm);
/**
- * Get a trainer object instance which could be used to generate a model of the given algorithm.
+ * Get an algo object instance which could be used to generate a model of the given algorithm.
*
* @param algorithm the algorithm
- * @return the trainer for name
+ * @return the algo for name
* @throws LensException the lens exception
*/
- MLTrainer getTrainerForName(String algorithm) throws LensException;
+ MLAlgo getAlgoForName(String algorithm) throws LensException;
/**
* Create a model using the given HCatalog table as input. The arguments should contain information needeed to
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/49fef8e2/lens-ml-lib/src/main/java/org/apache/lens/ml/LensMLImpl.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/LensMLImpl.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/LensMLImpl.java
index 7cd0580..2555ca0 100644
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/LensMLImpl.java
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/LensMLImpl.java
@@ -35,7 +35,7 @@ import org.apache.lens.api.query.LensQuery;
import org.apache.lens.api.query.QueryHandle;
import org.apache.lens.api.query.QueryStatus;
import org.apache.lens.ml.spark.SparkMLDriver;
-import org.apache.lens.ml.spark.trainers.BaseSparkTrainer;
+import org.apache.lens.ml.spark.algos.BaseSparkAlgo;
import org.apache.lens.server.api.LensConfConstants;
import org.apache.commons.io.IOUtils;
@@ -93,25 +93,25 @@ public class LensMLImpl implements LensML {
}
public List<String> getAlgorithms() {
- List<String> trainers = new ArrayList<String>();
+ List<String> algos = new ArrayList<String>();
for (MLDriver driver : drivers) {
- trainers.addAll(driver.getTrainerNames());
+ algos.addAll(driver.getAlgoNames());
}
- return trainers;
+ return algos;
}
/*
* (non-Javadoc)
*
- * @see org.apache.lens.ml.LensML#getTrainerForName(java.lang.String)
+ * @see org.apache.lens.ml.LensML#getAlgoForName(java.lang.String)
*/
- public MLTrainer getTrainerForName(String algorithm) throws LensException {
+ public MLAlgo getAlgoForName(String algorithm) throws LensException {
for (MLDriver driver : drivers) {
- if (driver.isTrainerSupported(algorithm)) {
- return driver.getTrainerInstance(algorithm);
+ if (driver.isAlgoSupported(algorithm)) {
+ return driver.getAlgoInstance(algorithm);
}
}
- throw new LensException("Trainer not supported " + algorithm);
+ throw new LensException("Algo not supported " + algorithm);
}
/*
@@ -120,11 +120,11 @@ public class LensMLImpl implements LensML {
* @see org.apache.lens.ml.LensML#train(java.lang.String, java.lang.String, java.lang.String[])
*/
public String train(String table, String algorithm, String[] args) throws LensException {
- MLTrainer trainer = getTrainerForName(algorithm);
+ MLAlgo algo = getAlgoForName(algorithm);
String modelId = UUID.randomUUID().toString();
- LOG.info("Begin training model " + modelId + ", trainer=" + algorithm + ", table=" + table + ", params="
+ LOG.info("Begin training model " + modelId + ", algo=" + algorithm + ", table=" + table + ", params="
+ Arrays.toString(args));
String database = null;
@@ -134,33 +134,33 @@ public class LensMLImpl implements LensML {
database = "default";
}
- MLModel model = trainer.train(toLensConf(conf), database, table, modelId, args);
+ MLModel model = algo.train(toLensConf(conf), database, table, modelId, args);
LOG.info("Done training model: " + modelId);
model.setCreatedAt(new Date());
- model.setTrainerName(algorithm);
+ model.setAlgoName(algorithm);
Path modelLocation = null;
try {
modelLocation = persistModel(model);
- LOG.info("Model saved: " + modelId + ", trainer: " + algorithm + ", path: " + modelLocation);
+ LOG.info("Model saved: " + modelId + ", algo: " + algorithm + ", path: " + modelLocation);
return model.getId();
} catch (IOException e) {
- throw new LensException("Error saving model " + modelId + " for trainer " + algorithm, e);
+ throw new LensException("Error saving model " + modelId + " for algo " + algorithm, e);
}
}
/**
- * Gets the trainer dir.
+ * Gets the algo dir.
*
- * @param trainerName the trainer name
- * @return the trainer dir
+ * @param algoName the algo name
+ * @return the algo dir
* @throws IOException Signals that an I/O exception has occurred.
*/
- private Path getTrainerDir(String trainerName) throws IOException {
+ private Path getAlgoDir(String algoName) throws IOException {
String modelSaveBaseDir = conf.get(ModelLoader.MODEL_PATH_BASE_DIR, ModelLoader.MODEL_PATH_BASE_DIR_DEFAULT);
- return new Path(new Path(modelSaveBaseDir), trainerName);
+ return new Path(new Path(modelSaveBaseDir), algoName);
}
/**
@@ -172,14 +172,14 @@ public class LensMLImpl implements LensML {
*/
private Path persistModel(MLModel model) throws IOException {
// Get model save path
- Path trainerDir = getTrainerDir(model.getTrainerName());
- FileSystem fs = trainerDir.getFileSystem(conf);
+ Path algoDir = getAlgoDir(model.getAlgoName());
+ FileSystem fs = algoDir.getFileSystem(conf);
- if (!fs.exists(trainerDir)) {
- fs.mkdirs(trainerDir);
+ if (!fs.exists(algoDir)) {
+ fs.mkdirs(algoDir);
}
- Path modelSavePath = new Path(trainerDir, model.getId());
+ Path modelSavePath = new Path(algoDir, model.getId());
ObjectOutputStream outputStream = null;
try {
@@ -202,15 +202,15 @@ public class LensMLImpl implements LensML {
*/
public List<String> getModels(String algorithm) throws LensException {
try {
- Path trainerDir = getTrainerDir(algorithm);
- FileSystem fs = trainerDir.getFileSystem(conf);
- if (!fs.exists(trainerDir)) {
+ Path algoDir = getAlgoDir(algorithm);
+ FileSystem fs = algoDir.getFileSystem(conf);
+ if (!fs.exists(algoDir)) {
return null;
}
List<String> models = new ArrayList<String>();
- for (FileStatus stat : fs.listStatus(trainerDir)) {
+ for (FileStatus stat : fs.listStatus(algoDir)) {
models.add(stat.getPath().getName());
}
@@ -563,15 +563,15 @@ public class LensMLImpl implements LensML {
* @see org.apache.lens.ml.LensML#getAlgoParamDescription(java.lang.String)
*/
public Map<String, String> getAlgoParamDescription(String algorithm) {
- MLTrainer trainer = null;
+ MLAlgo algo = null;
try {
- trainer = getTrainerForName(algorithm);
+ algo = getAlgoForName(algorithm);
} catch (LensException e) {
LOG.error("Error getting algo description : " + algorithm, e);
return null;
}
- if (trainer instanceof BaseSparkTrainer) {
- return ((BaseSparkTrainer) trainer).getArgUsage();
+ if (algo instanceof BaseSparkAlgo) {
+ return ((BaseSparkAlgo) algo).getArgUsage();
}
return null;
}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/49fef8e2/lens-ml-lib/src/main/java/org/apache/lens/ml/MLAlgo.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/MLAlgo.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/MLAlgo.java
new file mode 100644
index 0000000..7dccf2c
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/MLAlgo.java
@@ -0,0 +1,53 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml;
+
+import org.apache.lens.api.LensConf;
+import org.apache.lens.api.LensException;
+
+/**
+ * The Interface MLAlgo.
+ */
+public interface MLAlgo {
+ String getName();
+
+ String getDescription();
+
+ /**
+ * Configure.
+ *
+ * @param configuration the configuration
+ */
+ void configure(LensConf configuration);
+
+ LensConf getConf();
+
+ /**
+ * Train.
+ *
+ * @param conf the conf
+ * @param db the db
+ * @param table the table
+ * @param modelId the model id
+ * @param params the params
+ * @return the ML model
+ * @throws LensException the lens exception
+ */
+ MLModel train(LensConf conf, String db, String table, String modelId, String... params) throws LensException;
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/49fef8e2/lens-ml-lib/src/main/java/org/apache/lens/ml/MLDriver.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/MLDriver.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/MLDriver.java
index 562253a..567e717 100644
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/MLDriver.java
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/MLDriver.java
@@ -29,21 +29,21 @@ import org.apache.lens.api.LensException;
public interface MLDriver {
/**
- * Checks if is trainer supported.
+ * Checks if the algo is supported.
*
- * @param trainer the trainer
- * @return true, if is trainer supported
+ * @param algo the algo
+ * @return true, if the algo is supported
*/
- boolean isTrainerSupported(String trainer);
+ boolean isAlgoSupported(String algo);
/**
- * Gets the trainer instance.
+ * Gets the algo instance.
*
- * @param trainer the trainer
- * @return the trainer instance
+ * @param algo the algo
+ * @return the algo instance
* @throws LensException the lens exception
*/
- MLTrainer getTrainerInstance(String trainer) throws LensException;
+ MLAlgo getAlgoInstance(String algo) throws LensException;
/**
* Inits the.
@@ -67,5 +67,5 @@ public interface MLDriver {
*/
void stop() throws LensException;
- List<String> getTrainerNames();
+ List<String> getAlgoNames();
}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/49fef8e2/lens-ml-lib/src/main/java/org/apache/lens/ml/MLModel.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/MLModel.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/MLModel.java
index 863cdfe..c177757 100644
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/MLModel.java
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/MLModel.java
@@ -44,10 +44,10 @@ public abstract class MLModel<PREDICTION> implements Serializable {
@Setter
private Date createdAt;
- /** The trainer name. */
+ /** The algo name. */
@Getter
@Setter
- private String trainerName;
+ private String algoName;
/** The table. */
@Getter
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/49fef8e2/lens-ml-lib/src/main/java/org/apache/lens/ml/MLTrainer.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/MLTrainer.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/MLTrainer.java
deleted file mode 100644
index f1ae291..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/MLTrainer.java
+++ /dev/null
@@ -1,53 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml;
-
-import org.apache.lens.api.LensConf;
-import org.apache.lens.api.LensException;
-
-/**
- * The Interface MLTrainer.
- */
-public interface MLTrainer {
- String getName();
-
- String getDescription();
-
- /**
- * Configure.
- *
- * @param configuration the configuration
- */
- void configure(LensConf configuration);
-
- LensConf getConf();
-
- /**
- * Train.
- *
- * @param conf the conf
- * @param db the db
- * @param table the table
- * @param modelId the model id
- * @param params the params
- * @return the ML model
- * @throws LensException the lens exception
- */
- MLModel train(LensConf conf, String db, String table, String modelId, String... params) throws LensException;
-}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/49fef8e2/lens-ml-lib/src/main/java/org/apache/lens/ml/MLUtils.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/MLUtils.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/MLUtils.java
index 4ea8773..2e240af 100644
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/MLUtils.java
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/MLUtils.java
@@ -39,12 +39,12 @@ public final class MLUtils {
HIVE_CONF.addResource("lens-site.xml");
}
- public static String getTrainerName(Class<? extends MLTrainer> trainerClass) {
- Algorithm annotation = trainerClass.getAnnotation(Algorithm.class);
+ public static String getAlgoName(Class<? extends MLAlgo> algoClass) {
+ Algorithm annotation = algoClass.getAnnotation(Algorithm.class);
if (annotation != null) {
return annotation.name();
}
- throw new IllegalArgumentException("Trainer should be decorated with annotation - " + Algorithm.class.getName());
+ throw new IllegalArgumentException("Algo should be decorated with annotation - " + Algorithm.class.getName());
}
public static MLServiceImpl getMLService() throws Exception {
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/49fef8e2/lens-ml-lib/src/main/java/org/apache/lens/ml/TrainerArgParser.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/TrainerArgParser.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/TrainerArgParser.java
deleted file mode 100644
index 92c025d..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/TrainerArgParser.java
+++ /dev/null
@@ -1,114 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml;
-
-import java.lang.reflect.Field;
-import java.util.ArrayList;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
-
-/**
- * The Class TrainerArgParser.
- */
-public final class TrainerArgParser {
- private TrainerArgParser() {
- }
-
- /**
- * The Class CustomArgParser.
- *
- * @param <E> the element type
- */
- public abstract static class CustomArgParser<E> {
-
- /**
- * Parses the.
- *
- * @param value the value
- * @return the e
- */
- public abstract E parse(String value);
- }
-
- /** The Constant LOG. */
- public static final Log LOG = LogFactory.getLog(TrainerArgParser.class);
-
- /**
- * Extracts feature names. If the trainer has any parameters associated with @TrainerParam annotation, those are set
- * as well.
- *
- * @param trainer the trainer
- * @param args the args
- * @return List of feature column names.
- */
- public static List<String> parseArgs(MLTrainer trainer, String[] args) {
- List<String> featureColumns = new ArrayList<String>();
- Class<? extends MLTrainer> trainerClass = trainer.getClass();
- // Get param fields
- Map<String, Field> fieldMap = new HashMap<String, Field>();
-
- for (Field fld : trainerClass.getDeclaredFields()) {
- fld.setAccessible(true);
- TrainerParam paramAnnotation = fld.getAnnotation(TrainerParam.class);
- if (paramAnnotation != null) {
- fieldMap.put(paramAnnotation.name(), fld);
- }
- }
-
- for (int i = 0; i < args.length; i += 2) {
- String key = args[i].trim();
- String value = args[i + 1].trim();
-
- try {
- if ("feature".equalsIgnoreCase(key)) {
- featureColumns.add(value);
- } else if (fieldMap.containsKey(key)) {
- Field f = fieldMap.get(key);
- if (String.class.equals(f.getType())) {
- f.set(trainer, value);
- } else if (Integer.TYPE.equals(f.getType())) {
- f.setInt(trainer, Integer.parseInt(value));
- } else if (Double.TYPE.equals(f.getType())) {
- f.setDouble(trainer, Double.parseDouble(value));
- } else if (Long.TYPE.equals(f.getType())) {
- f.setLong(trainer, Long.parseLong(value));
- } else {
- // check if the trainer provides a deserializer for this param
- String customParserClass = trainer.getConf().getProperties().get("lens.ml.args." + key);
- if (customParserClass != null) {
- Class<? extends CustomArgParser<?>> clz = (Class<? extends CustomArgParser<?>>) Class
- .forName(customParserClass);
- CustomArgParser<?> parser = clz.newInstance();
- f.set(trainer, parser.parse(value));
- } else {
- LOG.warn("Ignored param " + key + "=" + value + " as no parser found");
- }
- }
- }
- } catch (Exception exc) {
- LOG.error("Error while setting param " + key + " to " + value + " for trainer " + trainer);
- }
- }
- return featureColumns;
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/49fef8e2/lens-ml-lib/src/main/java/org/apache/lens/ml/TrainerParam.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/TrainerParam.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/TrainerParam.java
deleted file mode 100644
index fe8af60..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/TrainerParam.java
+++ /dev/null
@@ -1,53 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml;
-
-import java.lang.annotation.ElementType;
-import java.lang.annotation.Retention;
-import java.lang.annotation.RetentionPolicy;
-import java.lang.annotation.Target;
-
-/**
- * The Interface TrainerParam.
- */
-@Retention(RetentionPolicy.RUNTIME)
-@Target(ElementType.FIELD)
-public @interface TrainerParam {
-
- /**
- * Name.
- *
- * @return the string
- */
- String name();
-
- /**
- * Help.
- *
- * @return the string
- */
- String help();
-
- /**
- * Default value.
- *
- * @return the string
- */
- String defaultValue() default "None";
-}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/49fef8e2/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/SparkMLDriver.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/SparkMLDriver.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/SparkMLDriver.java
index b19e8ea..1e452f1 100644
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/SparkMLDriver.java
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/SparkMLDriver.java
@@ -26,9 +26,9 @@ import java.util.List;
import org.apache.lens.api.LensConf;
import org.apache.lens.api.LensException;
import org.apache.lens.ml.Algorithms;
+import org.apache.lens.ml.MLAlgo;
import org.apache.lens.ml.MLDriver;
-import org.apache.lens.ml.MLTrainer;
-import org.apache.lens.ml.spark.trainers.*;
+import org.apache.lens.ml.spark.algos.*;
import org.apache.commons.lang.StringUtils;
import org.apache.commons.logging.Log;
@@ -90,46 +90,46 @@ public class SparkMLDriver implements MLDriver {
/*
* (non-Javadoc)
*
- * @see org.apache.lens.ml.MLDriver#isTrainerSupported(java.lang.String)
+ * @see org.apache.lens.ml.MLDriver#isAlgoSupported(java.lang.String)
*/
@Override
- public boolean isTrainerSupported(String name) {
+ public boolean isAlgoSupported(String name) {
return algorithms.isAlgoSupported(name);
}
/*
* (non-Javadoc)
*
- * @see org.apache.lens.ml.MLDriver#getTrainerInstance(java.lang.String)
+ * @see org.apache.lens.ml.MLDriver#getAlgoInstance(java.lang.String)
*/
@Override
- public MLTrainer getTrainerInstance(String name) throws LensException {
+ public MLAlgo getAlgoInstance(String name) throws LensException {
checkStarted();
- if (!isTrainerSupported(name)) {
+ if (!isAlgoSupported(name)) {
return null;
}
- MLTrainer trainer = null;
+ MLAlgo algo = null;
try {
- trainer = algorithms.getTrainerForName(name);
- if (trainer instanceof BaseSparkTrainer) {
- ((BaseSparkTrainer) trainer).setSparkContext(sparkContext);
+ algo = algorithms.getAlgoForName(name);
+ if (algo instanceof BaseSparkAlgo) {
+ ((BaseSparkAlgo) algo).setSparkContext(sparkContext);
}
} catch (LensException exc) {
- LOG.error("Error creating trainer object", exc);
+ LOG.error("Error creating algo object", exc);
}
- return trainer;
+ return algo;
}
/**
- * Register trainers.
+ * Register algos.
*/
- private void registerTrainers() {
- algorithms.register(NaiveBayesTrainer.class);
- algorithms.register(SVMTrainer.class);
- algorithms.register(LogisticRegressionTrainer.class);
- algorithms.register(DecisionTreeTrainer.class);
+ private void registerAlgos() {
+ algorithms.register(NaiveBayesAlgo.class);
+ algorithms.register(SVMAlgo.class);
+ algorithms.register(LogisticRegressionAlgo.class);
+ algorithms.register(DecisionTreeAlgo.class);
}
/*
@@ -140,7 +140,7 @@ public class SparkMLDriver implements MLDriver {
@Override
public void init(LensConf conf) throws LensException {
sparkConf = new SparkConf();
- registerTrainers();
+ registerAlgos();
for (String key : conf.getProperties().keySet()) {
if (key.startsWith("lens.ml.sparkdriver.")) {
sparkConf.set(key.substring("lens.ml.sparkdriver.".length()), conf.getProperties().get(key));
@@ -253,7 +253,7 @@ public class SparkMLDriver implements MLDriver {
}
@Override
- public List<String> getTrainerNames() {
+ public List<String> getAlgoNames() {
return algorithms.getAlgorithmNames();
}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/49fef8e2/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/BaseSparkAlgo.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/BaseSparkAlgo.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/BaseSparkAlgo.java
new file mode 100644
index 0000000..22cda6d
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/BaseSparkAlgo.java
@@ -0,0 +1,290 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.spark.algos;
+
+import java.lang.reflect.Field;
+import java.util.*;
+
+import org.apache.lens.api.LensConf;
+import org.apache.lens.api.LensException;
+import org.apache.lens.ml.AlgoParam;
+import org.apache.lens.ml.Algorithm;
+import org.apache.lens.ml.MLAlgo;
+import org.apache.lens.ml.MLModel;
+
+import org.apache.lens.ml.spark.TableTrainingSpec;
+import org.apache.lens.ml.spark.models.BaseSparkClassificationModel;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.hive.conf.HiveConf;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.rdd.RDD;
+
+/**
+ * Base class for Spark algos that train a {@link BaseSparkClassificationModel} from a Hive table.
+ * <p>
+ * {@link #train} parses the key/value params, builds a {@link TableTrainingSpec} for the table, creates the RDDs on
+ * the configured spark context and hands the training RDD to {@link #trainInternal}. Subclasses read their own
+ * params in {@link #parseAlgoParams}.
+ */
+public abstract class BaseSparkAlgo implements MLAlgo {
+
+  /** The Constant LOG. */
+  public static final Log LOG = LogFactory.getLog(BaseSparkAlgo.class);
+
+  /** The name. */
+  private final String name;
+
+  /** The description. */
+  private final String description;
+
+  /** The spark context, supplied through {@link #setSparkContext}. */
+  protected JavaSparkContext sparkContext;
+
+  /** Params of the last {@link #parseParams} call, excluding feature and label entries. */
+  protected Map<String, String> params;
+
+  /** The conf. */
+  protected transient LensConf conf;
+
+  /** The training fraction. */
+  @AlgoParam(name = "trainingFraction", help = "% of dataset to be used for training", defaultValue = "0")
+  protected double trainingFraction;
+
+  /** Set when a trainingFraction param was parsed; only then is the fraction passed to the training spec. */
+  private boolean useTrainingFraction;
+
+  /** The label. */
+  @AlgoParam(name = "label", help = "Name of column which is used as a training label for supervised learning")
+  protected String label;
+
+  /** The partition filter. */
+  @AlgoParam(name = "partition", help = "Partition filter used to create HCatInputFormats")
+  protected String partitionFilter;
+
+  /** The features. */
+  @AlgoParam(name = "feature", help = "Column name(s) which are to be used as sample features")
+  protected List<String> features;
+
+  /**
+   * Instantiates a new base spark algo.
+   *
+   * @param name        the name
+   * @param description the description
+   */
+  public BaseSparkAlgo(String name, String description) {
+    this.name = name;
+    this.description = description;
+  }
+
+  public void setSparkContext(JavaSparkContext sparkContext) {
+    this.sparkContext = sparkContext;
+  }
+
+  @Override
+  public LensConf getConf() {
+    return conf;
+  }
+
+  /*
+   * (non-Javadoc)
+   *
+   * @see org.apache.lens.ml.MLAlgo#configure(org.apache.lens.api.LensConf)
+   */
+  @Override
+  public void configure(LensConf configuration) {
+    this.conf = configuration;
+  }
+
+  /**
+   * Trains a model on the given table.
+   *
+   * @param conf    the conf, copied into the hive conf used to read the table
+   * @param db      the db
+   * @param table   the table
+   * @param modelId the model id
+   * @param params  alternating param names and values, see {@link #parseParams}
+   * @return the trained model, annotated with table, params, label column and feature columns
+   * @throws LensException if no feature column was specified or training fails
+   * @see org.apache.lens.ml.MLAlgo#train(org.apache.lens.api.LensConf, java.lang.String, java.lang.String,
+   * java.lang.String, java.lang.String[])
+   */
+  @Override
+  public MLModel<?> train(LensConf conf, String db, String table, String modelId, String... params)
+    throws LensException {
+    parseParams(params);
+
+    // Fail with a clear error instead of a NullPointerException further down
+    if (features == null || features.isEmpty()) {
+      throw new LensException("No feature columns specified for training " + name + " on " + db + "." + table);
+    }
+
+    TableTrainingSpec.TableTrainingSpecBuilder builder = TableTrainingSpec.newBuilder().hiveConf(toHiveConf(conf))
+      .database(db).table(table).partitionFilter(partitionFilter).featureColumns(features).labelColumn(label);
+
+    if (useTrainingFraction) {
+      builder.trainingFraction(trainingFraction);
+    }
+
+    TableTrainingSpec spec = builder.build();
+    LOG.info("Training " + name + " on " + db + "." + table + " with " + features.size() + " features");
+
+    spec.createRDDs(sparkContext);
+
+    RDD<LabeledPoint> trainingRDD = spec.getTrainingRDD();
+    BaseSparkClassificationModel<?> model = trainInternal(modelId, trainingRDD);
+    model.setTable(table);
+    model.setParams(Arrays.asList(params));
+    model.setLabelColumn(label);
+    model.setFeatureColumns(features);
+    return model;
+  }
+
+  /**
+   * Copies all properties of the lens conf into a new hive conf.
+   *
+   * @param conf the conf
+   * @return the hive conf
+   */
+  protected HiveConf toHiveConf(LensConf conf) {
+    HiveConf hiveConf = new HiveConf();
+    for (String key : conf.getProperties().keySet()) {
+      hiveConf.set(key, conf.getProperties().get(key));
+    }
+    return hiveConf;
+  }
+
+  /**
+   * Parses alternating param names and values.
+   * <p>
+   * Hyphens are removed from every name before it is interpreted. "f"/"feature" entries are collected as feature
+   * columns, "l"/"label" sets the label column, and everything else is stored in {@link #params}. From those,
+   * "trainingFraction" and "partition"/"p" are read here and the rest is left to {@link #parseAlgoParams}.
+   *
+   * @param args the args, an even number of entries
+   * @throws IllegalArgumentException if the number of args is odd or the training fraction is not a number
+   */
+  public void parseParams(String[] args) {
+    if (args.length % 2 != 0) {
+      throw new IllegalArgumentException("Invalid number of params " + args.length);
+    }
+
+    params = new LinkedHashMap<String, String>();
+    // Start from an empty list so that parsing again on the same instance does not accumulate columns
+    features = new ArrayList<String>();
+
+    for (int i = 0; i < args.length; i += 2) {
+      // Normalise the name first so that "-f" or "--label" are recognised like their plain forms
+      String key = args[i].replaceAll("\\-+", "");
+      String value = args[i + 1];
+      if ("f".equalsIgnoreCase(key) || "feature".equalsIgnoreCase(key)) {
+        features.add(value);
+      } else if ("l".equalsIgnoreCase(key) || "label".equalsIgnoreCase(key)) {
+        label = value;
+      } else {
+        params.put(key, value);
+      }
+    }
+
+    if (params.containsKey("trainingFraction")) {
+      // Get training Fraction
+      String trainingFractionStr = params.get("trainingFraction");
+      try {
+        trainingFraction = Double.parseDouble(trainingFractionStr);
+        useTrainingFraction = true;
+      } catch (NumberFormatException nfe) {
+        throw new IllegalArgumentException("Invalid training fraction", nfe);
+      }
+    }
+
+    if (params.containsKey("partition") || params.containsKey("p")) {
+      partitionFilter = params.containsKey("partition") ? params.get("partition") : params.get("p");
+    }
+
+    parseAlgoParams(params);
+  }
+
+  /**
+   * Gets a param parsed by {@link #parseParams} as a double.
+   *
+   * @param param      the param
+   * @param defaultVal the default val
+   * @return the parsed value, or the default if the param is absent, has no value or is not a valid double
+   */
+  public double getParamValue(String param, double defaultVal) {
+    // params is null until parseParams has run; a null value cannot be parsed either
+    String value = params == null ? null : params.get(param);
+    if (value != null) {
+      try {
+        return Double.parseDouble(value);
+      } catch (NumberFormatException nfe) {
+        LOG.warn("Couldn't parse param value: " + param + " as double.");
+      }
+    }
+    return defaultVal;
+  }
+
+  /**
+   * Gets a param parsed by {@link #parseParams} as an int.
+   *
+   * @param param      the param
+   * @param defaultVal the default val
+   * @return the parsed value, or the default if the param is absent, has no value or is not a valid integer
+   */
+  public int getParamValue(String param, int defaultVal) {
+    // params is null until parseParams has run; a null value cannot be parsed either
+    String value = params == null ? null : params.get(param);
+    if (value != null) {
+      try {
+        return Integer.parseInt(value);
+      } catch (NumberFormatException nfe) {
+        LOG.warn("Couldn't parse param value: " + param + " as integer.");
+      }
+    }
+    return defaultVal;
+  }
+
+  public String getName() {
+    return name;
+  }
+
+  public String getDescription() {
+    return description;
+  }
+
+  /**
+   * Describes the algo and its params from the {@link Algorithm} and {@link AlgoParam} annotations.
+   *
+   * @return usage entries in insertion order: algorithm name and description (when annotated), followed by one
+   *         "[param] name" entry per annotated field of this class and its superclasses up to BaseSparkAlgo
+   */
+  public Map<String, String> getArgUsage() {
+    Map<String, String> usage = new LinkedHashMap<String, String>();
+    Class<?> clz = this.getClass();
+    // Put class name and description as well as part of the usage
+    Algorithm algorithm = clz.getAnnotation(Algorithm.class);
+    if (algorithm != null) {
+      usage.put("Algorithm Name", algorithm.name());
+      usage.put("Algorithm Description", algorithm.description());
+    }
+
+    // Get all algo params including base algo params
+    while (clz != null) {
+      for (Field field : clz.getDeclaredFields()) {
+        AlgoParam param = field.getAnnotation(AlgoParam.class);
+        if (param != null) {
+          usage.put("[param] " + param.name(), param.help() + " Default Value = " + param.defaultValue());
+        }
+      }
+
+      // Nothing above the base algo declares params, so stop walking the hierarchy here
+      if (clz.equals(BaseSparkAlgo.class)) {
+        break;
+      }
+      clz = clz.getSuperclass();
+    }
+    return usage;
+  }
+
+  /**
+   * Reads the algo specific params. Called at the end of {@link #parseParams}.
+   *
+   * @param params the parsed params, without feature and label entries
+   */
+  public abstract void parseAlgoParams(Map<String, String> params);
+
+  /**
+   * Trains the model on the prepared RDD.
+   *
+   * @param modelId     the model id
+   * @param trainingRDD the training rdd
+   * @return the base spark classification model
+   * @throws LensException the lens exception
+   */
+  protected abstract BaseSparkClassificationModel trainInternal(String modelId, RDD<LabeledPoint> trainingRDD)
+    throws LensException;
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/49fef8e2/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/DecisionTreeAlgo.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/DecisionTreeAlgo.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/DecisionTreeAlgo.java
new file mode 100644
index 0000000..a6d66c5
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/DecisionTreeAlgo.java
@@ -0,0 +1,109 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.spark.algos;
+
+import java.util.Map;
+
+import org.apache.lens.api.LensException;
+import org.apache.lens.ml.AlgoParam;
+import org.apache.lens.ml.Algorithm;
+import org.apache.lens.ml.spark.models.BaseSparkClassificationModel;
+import org.apache.lens.ml.spark.models.DecisionTreeClassificationModel;
+import org.apache.lens.ml.spark.models.SparkDecisionTreeModel;
+
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.mllib.tree.DecisionTree$;
+import org.apache.spark.mllib.tree.configuration.Algo$;
+import org.apache.spark.mllib.tree.impurity.Entropy$;
+import org.apache.spark.mllib.tree.impurity.Gini$;
+import org.apache.spark.mllib.tree.impurity.Impurity;
+import org.apache.spark.mllib.tree.impurity.Variance$;
+import org.apache.spark.mllib.tree.model.DecisionTreeModel;
+import org.apache.spark.rdd.RDD;
+
+import scala.Enumeration;
+
+/**
+ * Algo that trains a decision tree with Spark MLLib's DecisionTree.
+ * <p>
+ * The "algo" and "impurity" params have no default and must both be supplied before training.
+ */
+@Algorithm(name = "spark_decision_tree", description = "Spark Decision Tree classifier algo")
+public class DecisionTreeAlgo extends BaseSparkAlgo {
+
+  /** The algo. */
+  @AlgoParam(name = "algo", help = "Decision tree algorithm. Allowed values are 'classification' and 'regression'")
+  private Enumeration.Value algo;
+
+  /** The decision tree impurity. */
+  @AlgoParam(name = "impurity", help = "Impurity measure used by the decision tree. "
+    + "Allowed values are 'gini', 'entropy' and 'variance'")
+  private Impurity decisionTreeImpurity;
+
+  /** The max depth. */
+  // NOTE(review): some Spark MLLib releases reject maxDepth above 30 - confirm the default of 100 is accepted
+  @AlgoParam(name = "maxDepth", help = "Max depth of the decision tree. Integer values expected.",
+    defaultValue = "100")
+  private int maxDepth;
+
+  /**
+   * Instantiates a new decision tree algo.
+   *
+   * @param name        the name
+   * @param description the description
+   */
+  public DecisionTreeAlgo(String name, String description) {
+    super(name, description);
+  }
+
+  /**
+   * Reads the "algo", "impurity" and "maxDepth" params. Values are matched case-insensitively, and values left over
+   * from a previous parse are discarded.
+   *
+   * @param params the parsed params
+   * @throws IllegalArgumentException if "algo" or "impurity" is given with a value that is not allowed
+   * @see org.apache.lens.ml.spark.algos.BaseSparkAlgo#parseAlgoParams(java.util.Map)
+   */
+  @Override
+  public void parseAlgoParams(Map<String, String> params) {
+    algo = toAlgo(params.get("algo"));
+    decisionTreeImpurity = toImpurity(params.get("impurity"));
+    maxDepth = getParamValue("maxDepth", 100);
+  }
+
+  /**
+   * Maps the "algo" param to the Spark enumeration value.
+   *
+   * @param algoName the param value, may be null
+   * @return the Spark value, or null if the param was not given
+   */
+  private static Enumeration.Value toAlgo(String algoName) {
+    if (algoName == null) {
+      return null;
+    }
+    if ("classification".equalsIgnoreCase(algoName)) {
+      return Algo$.MODULE$.Classification();
+    }
+    if ("regression".equalsIgnoreCase(algoName)) {
+      return Algo$.MODULE$.Regression();
+    }
+    // Reject typos here rather than letting a null reach Spark
+    throw new IllegalArgumentException("Invalid decision tree algo '" + algoName
+      + "'. Allowed values are 'classification' and 'regression'");
+  }
+
+  /**
+   * Maps the "impurity" param to the Spark impurity implementation.
+   *
+   * @param impurity the param value, may be null
+   * @return the Spark impurity, or null if the param was not given
+   */
+  private static Impurity toImpurity(String impurity) {
+    if (impurity == null) {
+      return null;
+    }
+    if ("gini".equalsIgnoreCase(impurity)) {
+      return Gini$.MODULE$;
+    }
+    if ("entropy".equalsIgnoreCase(impurity)) {
+      return Entropy$.MODULE$;
+    }
+    if ("variance".equalsIgnoreCase(impurity)) {
+      return Variance$.MODULE$;
+    }
+    throw new IllegalArgumentException("Invalid decision tree impurity '" + impurity
+      + "'. Allowed values are 'gini', 'entropy' and 'variance'");
+  }
+
+  /**
+   * Trains the decision tree on the given RDD.
+   *
+   * @param modelId     the model id
+   * @param trainingRDD the training rdd
+   * @return the decision tree classification model
+   * @throws LensException if the "algo" or "impurity" param was not supplied
+   * @see org.apache.lens.ml.spark.algos.BaseSparkAlgo#trainInternal(java.lang.String, org.apache.spark.rdd.RDD)
+   */
+  @Override
+  protected BaseSparkClassificationModel trainInternal(String modelId, RDD<LabeledPoint> trainingRDD)
+    throws LensException {
+    // Neither param has a default, so report a missing one explicitly instead of passing null to Spark
+    if (algo == null) {
+      throw new LensException("Param 'algo' is required. Allowed values are 'classification' and 'regression'");
+    }
+    if (decisionTreeImpurity == null) {
+      throw new LensException("Param 'impurity' is required. Allowed values are 'gini', 'entropy' and 'variance'");
+    }
+    DecisionTreeModel model = DecisionTree$.MODULE$.train(trainingRDD, algo, decisionTreeImpurity, maxDepth);
+    return new DecisionTreeClassificationModel(modelId, new SparkDecisionTreeModel(model));
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/49fef8e2/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/KMeansAlgo.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/KMeansAlgo.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/KMeansAlgo.java
new file mode 100644
index 0000000..7ca5a79
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/KMeansAlgo.java
@@ -0,0 +1,163 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.spark.algos;
+
+import java.util.List;
+
+import org.apache.lens.api.LensConf;
+import org.apache.lens.api.LensException;
+import org.apache.lens.ml.*;
+import org.apache.lens.ml.spark.HiveTableRDD;
+import org.apache.lens.ml.spark.models.KMeansClusteringModel;
+
+import org.apache.hadoop.hive.conf.HiveConf;
+import org.apache.hadoop.hive.metastore.api.FieldSchema;
+import org.apache.hadoop.hive.ql.metadata.Hive;
+import org.apache.hadoop.hive.ql.metadata.Table;
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.hive.hcatalog.data.HCatRecord;
+import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.mllib.clustering.KMeans;
+import org.apache.spark.mllib.clustering.KMeansModel;
+import org.apache.spark.mllib.linalg.Vector;
+import org.apache.spark.mllib.linalg.Vectors;
+
+import scala.Tuple2;
+
+/**
+ * Algo that clusters the rows of a Hive table with Spark MLLib's KMeans.
+ * <p>
+ * Each row becomes a dense vector holding the values of the requested feature columns, in the order in which the
+ * features were passed to {@link #train}. Null column values are treated as 0.
+ */
+@Algorithm(name = "spark_kmeans_algo", description = "Spark MLLib KMeans algo")
+public class KMeansAlgo implements MLAlgo {
+
+  /** The conf. */
+  private transient LensConf conf;
+
+  /** The spark context, supplied through {@link #setSparkContext}. */
+  private JavaSparkContext sparkContext;
+
+  /** The part filter. */
+  @AlgoParam(name = "partition", help = "Partition filter to be used while constructing table RDD")
+  private String partFilter = null;
+
+  /** The k. */
+  @AlgoParam(name = "k", help = "Number of cluster")
+  private int k;
+
+  /** The max iterations. */
+  @AlgoParam(name = "maxIterations", help = "Maximum number of iterations", defaultValue = "100")
+  private int maxIterations = 100;
+
+  /** The runs. */
+  @AlgoParam(name = "runs", help = "Number of parallel run", defaultValue = "1")
+  private int runs = 1;
+
+  /** The initialization mode. */
+  @AlgoParam(name = "initializationMode",
+    help = "initialization model, either \"random\" or \"k-means||\" (default).", defaultValue = "k-means||")
+  private String initializationMode = "k-means||";
+
+  /**
+   * Sets the spark context used to read the training table. It has to be set before {@link #train} is called.
+   *
+   * @param sparkContext the spark context
+   */
+  public void setSparkContext(JavaSparkContext sparkContext) {
+    this.sparkContext = sparkContext;
+  }
+
+  @Override
+  public String getName() {
+    return getClass().getAnnotation(Algorithm.class).name();
+  }
+
+  @Override
+  public String getDescription() {
+    return getClass().getAnnotation(Algorithm.class).description();
+  }
+
+  /*
+   * (non-Javadoc)
+   *
+   * @see org.apache.lens.ml.MLAlgo#configure(org.apache.lens.api.LensConf)
+   */
+  @Override
+  public void configure(LensConf configuration) {
+    this.conf = configuration;
+  }
+
+  @Override
+  public LensConf getConf() {
+    return conf;
+  }
+
+  /**
+   * Trains a KMeans model on the feature columns of the given table.
+   *
+   * @param conf    the conf, copied into the hive conf used to read the table
+   * @param db      the db
+   * @param table   the table
+   * @param modelId the model id
+   * @param params  alternating param names and values, parsed by {@link AlgoArgParser}
+   * @return the KMeans clustering model
+   * @throws LensException if no feature is given, the spark context is not set, k is not positive, a feature
+   *                       column does not exist in the table, or reading the table or training fails
+   * @see org.apache.lens.ml.MLAlgo#train(org.apache.lens.api.LensConf, java.lang.String, java.lang.String,
+   * java.lang.String, java.lang.String[])
+   */
+  @Override
+  public MLModel train(LensConf conf, String db, String table, String modelId, String... params) throws LensException {
+    List<String> features = AlgoArgParser.parseArgs(this, params);
+    if (features == null || features.isEmpty()) {
+      throw new LensException("No feature columns specified for KMeans on " + db + "." + table);
+    }
+    if (sparkContext == null) {
+      throw new LensException("Spark context is not set for KMeans algo");
+    }
+    if (k <= 0) {
+      throw new LensException("Param 'k' must be a positive number of clusters, got " + k);
+    }
+
+    try {
+      HiveConf hiveConf = toHiveConf(conf);
+      Table tbl = Hive.get(hiveConf).getTable(db, table);
+      int[] featurePositions = resolveFeaturePositions(tbl.getAllCols(), features);
+
+      JavaPairRDD<WritableComparable, HCatRecord> rdd = HiveTableRDD.createHiveTableRDD(sparkContext, hiveConf, db,
+        table, partFilter);
+      JavaRDD<Vector> trainableRDD = rdd.map(new FeatureVectorFunction(featurePositions));
+
+      KMeansModel model = KMeans.train(trainableRDD.rdd(), k, maxIterations, runs, initializationMode);
+      return new KMeansClusteringModel(modelId, model);
+    } catch (Exception e) {
+      throw new LensException("KMeans algo failed for " + db + "." + table, e);
+    }
+  }
+
+  /**
+   * Maps feature names to column positions, keeping the order in which the features were requested.
+   *
+   * @param allCols  all columns of the table
+   * @param features the requested feature column names
+   * @return the position in allCols of every feature, in feature order
+   * @throws IllegalArgumentException if a feature is not a column of the table
+   */
+  private static int[] resolveFeaturePositions(List<FieldSchema> allCols, List<String> features) {
+    int[] positions = new int[features.size()];
+    for (int f = 0; f < positions.length; f++) {
+      String feature = features.get(f);
+      int position = -1;
+      for (int i = 0; i < allCols.size(); i++) {
+        if (allCols.get(i).getName().equals(feature)) {
+          position = i;
+          break;
+        }
+      }
+      // An unknown name must not silently fall back to column 0
+      if (position < 0) {
+        throw new IllegalArgumentException("Feature column '" + feature + "' not found in table");
+      }
+      positions[f] = position;
+    }
+    return positions;
+  }
+
+  /**
+   * Copies all properties of the lens conf into a new hive conf.
+   *
+   * @param conf the conf
+   * @return the hive conf
+   */
+  private HiveConf toHiveConf(LensConf conf) {
+    HiveConf hiveConf = new HiveConf();
+    for (String key : conf.getProperties().keySet()) {
+      hiveConf.set(key, conf.getProperties().get(key));
+    }
+    return hiveConf;
+  }
+
+  /**
+   * Turns a table record into a dense vector of its feature column values.
+   * <p>
+   * Declared as a static nested class so that it holds no reference to the enclosing algo instance; only the
+   * position array travels with the function.
+   */
+  private static final class FeatureVectorFunction
+    implements Function<Tuple2<WritableComparable, HCatRecord>, Vector> {
+
+    private static final long serialVersionUID = 1L;
+
+    /** Column position of every feature, in feature order. */
+    private final int[] featurePositions;
+
+    FeatureVectorFunction(int[] featurePositions) {
+      this.featurePositions = featurePositions;
+    }
+
+    @Override
+    public Vector call(Tuple2<WritableComparable, HCatRecord> v1) throws Exception {
+      HCatRecord hCatRecord = v1._2();
+      double[] arr = new double[featurePositions.length];
+      for (int i = 0; i < arr.length; i++) {
+        Object val = hCatRecord.get(featurePositions[i]);
+        // Accept any numeric column type, not only double; nulls count as 0
+        arr[i] = val == null ? 0d : ((Number) val).doubleValue();
+      }
+      return Vectors.dense(arr);
+    }
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/49fef8e2/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/LogisticRegressionAlgo.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/LogisticRegressionAlgo.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/LogisticRegressionAlgo.java
new file mode 100644
index 0000000..106b3c5
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/LogisticRegressionAlgo.java
@@ -0,0 +1,86 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.spark.algos;
+
+import java.util.Map;
+
+import org.apache.lens.api.LensException;
+import org.apache.lens.ml.AlgoParam;
+import org.apache.lens.ml.Algorithm;
+import org.apache.lens.ml.spark.models.BaseSparkClassificationModel;
+import org.apache.lens.ml.spark.models.LogitRegressionClassificationModel;
+
+import org.apache.spark.mllib.classification.LogisticRegressionModel;
+import org.apache.spark.mllib.classification.LogisticRegressionWithSGD;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.rdd.RDD;
+
+/**
+ * Algo that trains a logistic regression model with Spark MLLib's LogisticRegressionWithSGD.
+ */
+@Algorithm(name = "spark_logistic_regression", description = "Spark logistic regression algo")
+public class LogisticRegressionAlgo extends BaseSparkAlgo {
+
+  /** Number of iterations passed to the trainer. */
+  @AlgoParam(name = "iterations", help = "Max number of iterations", defaultValue = "100")
+  private int iterations;
+
+  /** Step size passed to the trainer. */
+  @AlgoParam(name = "stepSize", help = "Step size", defaultValue = "1.0d")
+  private double stepSize;
+
+  /** Batch fraction passed to the trainer. */
+  @AlgoParam(name = "minBatchFraction", help = "Fraction for batched learning", defaultValue = "1.0d")
+  private double minBatchFraction;
+
+  /**
+   * Creates the algo with the given name and description.
+   *
+   * @param name        the name
+   * @param description the description
+   */
+  public LogisticRegressionAlgo(String name, String description) {
+    super(name, description);
+  }
+
+  /**
+   * Reads "iterations", "stepSize" and "minBatchFraction", falling back to 100, 1.0 and 1.0 respectively.
+   *
+   * @param params the parsed params
+   * @see org.apache.lens.ml.spark.algos.BaseSparkAlgo#parseAlgoParams(java.util.Map)
+   */
+  @Override
+  public void parseAlgoParams(Map<String, String> params) {
+    this.iterations = getParamValue("iterations", 100);
+    this.stepSize = getParamValue("stepSize", 1.0d);
+    this.minBatchFraction = getParamValue("minBatchFraction", 1.0d);
+  }
+
+  /**
+   * Trains the logistic regression model on the given RDD and wraps it for Lens.
+   *
+   * @param modelId     the model id
+   * @param trainingRDD the training rdd
+   * @return the logistic regression classification model
+   * @throws LensException the lens exception
+   * @see org.apache.lens.ml.spark.algos.BaseSparkAlgo#trainInternal(java.lang.String, org.apache.spark.rdd.RDD)
+   */
+  @Override
+  protected BaseSparkClassificationModel trainInternal(String modelId, RDD<LabeledPoint> trainingRDD)
+    throws LensException {
+    final LogisticRegressionModel trained =
+      LogisticRegressionWithSGD.train(trainingRDD, iterations, stepSize, minBatchFraction);
+    return new LogitRegressionClassificationModel(modelId, trained);
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/49fef8e2/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/NaiveBayesAlgo.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/NaiveBayesAlgo.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/NaiveBayesAlgo.java
new file mode 100644
index 0000000..f7652d1
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/NaiveBayesAlgo.java
@@ -0,0 +1,73 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.spark.algos;
+
+import java.util.Map;
+
+import org.apache.lens.api.LensException;
+import org.apache.lens.ml.AlgoParam;
+import org.apache.lens.ml.Algorithm;
+import org.apache.lens.ml.spark.models.BaseSparkClassificationModel;
+import org.apache.lens.ml.spark.models.NaiveBayesClassificationModel;
+
+import org.apache.spark.mllib.classification.NaiveBayes;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.rdd.RDD;
+
+/**
+ * Algo annotated as "spark_naive_bayes"; builds its model through Spark MLlib's NaiveBayes.train.
+ */
+@Algorithm(name = "spark_naive_bayes", description = "Spark Naive Bayes classifier algo")
+public class NaiveBayesAlgo extends BaseSparkAlgo {
+
+ /** Lambda argument passed to NaiveBayes.train; starts at 1.0 and is reassigned in parseAlgoParams. */
+ @AlgoParam(name = "lambda", help = "Lambda parameter for naive bayes learner", defaultValue = "1.0d")
+ private double lambda = 1.0;
+
+ /**
+ * Creates the algo by forwarding both arguments unchanged to the superclass constructor.
+ *
+ * @param name the algo name, handed to the superclass
+ * @param description the algo description, handed to the superclass
+ */
+ public NaiveBayesAlgo(String name, String description) {
+ super(name, description);
+ }
+
+ /*
+ * Assigns lambda from getParamValue("lambda", 1.0d), i.e. with 1.0 as the default argument.
+ * The params argument is not read directly in this method.
+ * @see org.apache.lens.ml.spark.algos.BaseSparkAlgo#parseAlgoParams(java.util.Map)
+ */
+ @Override
+ public void parseAlgoParams(Map<String, String> params) {
+ lambda = getParamValue("lambda", 1.0d);
+ }
+
+ /*
+ * Calls NaiveBayes.train with the training RDD and lambda, then wraps the returned model
+ * together with modelId in a NaiveBayesClassificationModel.
+ * @see org.apache.lens.ml.spark.algos.BaseSparkAlgo#trainInternal(java.lang.String, org.apache.spark.rdd.RDD)
+ */
+ @Override
+ protected BaseSparkClassificationModel trainInternal(String modelId, RDD<LabeledPoint> trainingRDD)
+ throws LensException {
+ return new NaiveBayesClassificationModel(modelId, NaiveBayes.train(trainingRDD, lambda));
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/49fef8e2/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/SVMAlgo.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/SVMAlgo.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/SVMAlgo.java
new file mode 100644
index 0000000..09251b7
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/SVMAlgo.java
@@ -0,0 +1,90 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.spark.algos;
+
+import java.util.Map;
+
+import org.apache.lens.api.LensException;
+import org.apache.lens.ml.AlgoParam;
+import org.apache.lens.ml.Algorithm;
+import org.apache.lens.ml.spark.models.BaseSparkClassificationModel;
+import org.apache.lens.ml.spark.models.SVMClassificationModel;
+
+import org.apache.spark.mllib.classification.SVMModel;
+import org.apache.spark.mllib.classification.SVMWithSGD;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.rdd.RDD;
+
+/**
+ * Algo annotated as "spark_svm"; builds its model through Spark MLlib's SVMWithSGD.train.
+ */
+@Algorithm(name = "spark_svm", description = "Spark SVML classifier algo")
+public class SVMAlgo extends BaseSparkAlgo {
+
+ /** Batch fraction argument (last) passed to SVMWithSGD.train; assigned in parseAlgoParams. */
+ @AlgoParam(name = "minBatchFraction", help = "Fraction for batched learning", defaultValue = "1.0d")
+ private double minBatchFraction;
+
+ /** Regularization argument passed to SVMWithSGD.train; assigned in parseAlgoParams. */
+ @AlgoParam(name = "regParam", help = "regularization parameter for gradient descent", defaultValue = "1.0d")
+ private double regParam;
+
+ /** Step size argument passed to SVMWithSGD.train; assigned in parseAlgoParams. */
+ @AlgoParam(name = "stepSize", help = "Iteration step size", defaultValue = "1.0d")
+ private double stepSize;
+
+ /** Iteration count argument passed to SVMWithSGD.train; assigned in parseAlgoParams. */
+ @AlgoParam(name = "iterations", help = "Number of iterations", defaultValue = "100")
+ private int iterations;
+
+ /**
+ * Creates the algo by forwarding both arguments unchanged to the superclass constructor.
+ *
+ * @param name the algo name, handed to the superclass
+ * @param description the algo description, handed to the superclass
+ */
+ public SVMAlgo(String name, String description) {
+ super(name, description);
+ }
+
+ /*
+ * Assigns minBatchFraction, regParam, stepSize and iterations from getParamValue, passing
+ * 1.0, 1.0, 1.0 and 100 as the default arguments. The params argument is not read directly.
+ * @see org.apache.lens.ml.spark.algos.BaseSparkAlgo#parseAlgoParams(java.util.Map)
+ */
+ @Override
+ public void parseAlgoParams(Map<String, String> params) {
+ minBatchFraction = getParamValue("minBatchFraction", 1.0);
+ regParam = getParamValue("regParam", 1.0);
+ stepSize = getParamValue("stepSize", 1.0);
+ iterations = getParamValue("iterations", 100);
+ }
+
+ /*
+ * Calls SVMWithSGD.train with the RDD, iterations, stepSize, regParam and minBatchFraction
+ * (in that order), then wraps the returned model with modelId in an SVMClassificationModel.
+ * @see org.apache.lens.ml.spark.algos.BaseSparkAlgo#trainInternal(java.lang.String, org.apache.spark.rdd.RDD)
+ */
+ @Override
+ protected BaseSparkClassificationModel trainInternal(String modelId, RDD<LabeledPoint> trainingRDD)
+ throws LensException {
+ SVMModel svmModel = SVMWithSGD.train(trainingRDD, iterations, stepSize, regParam, minBatchFraction);
+ return new SVMClassificationModel(modelId, svmModel);
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/49fef8e2/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/trainers/BaseSparkTrainer.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/trainers/BaseSparkTrainer.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/trainers/BaseSparkTrainer.java
deleted file mode 100644
index f75e41c..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/trainers/BaseSparkTrainer.java
+++ /dev/null
@@ -1,289 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml.spark.trainers;
-
-import java.lang.reflect.Field;
-import java.util.*;
-
-import org.apache.lens.api.LensConf;
-import org.apache.lens.api.LensException;
-import org.apache.lens.ml.Algorithm;
-import org.apache.lens.ml.MLModel;
-import org.apache.lens.ml.MLTrainer;
-import org.apache.lens.ml.TrainerParam;
-import org.apache.lens.ml.spark.TableTrainingSpec;
-import org.apache.lens.ml.spark.models.BaseSparkClassificationModel;
-
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
-import org.apache.hadoop.hive.conf.HiveConf;
-import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.rdd.RDD;
-
-/**
- * The Class BaseSparkTrainer.
- */
-public abstract class BaseSparkTrainer implements MLTrainer {
-
- /** The Constant LOG. */
- public static final Log LOG = LogFactory.getLog(BaseSparkTrainer.class);
-
- /** The name. */
- private final String name;
-
- /** The description. */
- private final String description;
-
- /** The spark context. */
- protected JavaSparkContext sparkContext;
-
- /** The params. */
- protected Map<String, String> params;
-
- /** The conf. */
- protected transient LensConf conf;
-
- /** The training fraction. */
- @TrainerParam(name = "trainingFraction", help = "% of dataset to be used for training", defaultValue = "0")
- protected double trainingFraction;
-
- /** The use training fraction. */
- private boolean useTrainingFraction;
-
- /** The label. */
- @TrainerParam(name = "label", help = "Name of column which is used as a training label for supervised learning")
- protected String label;
-
- /** The partition filter. */
- @TrainerParam(name = "partition", help = "Partition filter used to create create HCatInputFormats")
- protected String partitionFilter;
-
- /** The features. */
- @TrainerParam(name = "feature", help = "Column name(s) which are to be used as sample features")
- protected List<String> features;
-
- /**
- * Instantiates a new base spark trainer.
- *
- * @param name the name
- * @param description the description
- */
- public BaseSparkTrainer(String name, String description) {
- this.name = name;
- this.description = description;
- }
-
- public void setSparkContext(JavaSparkContext sparkContext) {
- this.sparkContext = sparkContext;
- }
-
- @Override
- public LensConf getConf() {
- return conf;
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.MLTrainer#configure(org.apache.lens.api.LensConf)
- */
- @Override
- public void configure(LensConf configuration) {
- this.conf = configuration;
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.MLTrainer#train(org.apache.lens.api.LensConf, java.lang.String, java.lang.String,
- * java.lang.String, java.lang.String[])
- */
- @Override
- public MLModel<?> train(LensConf conf, String db, String table, String modelId, String... params)
- throws LensException {
- parseParams(params);
-
- TableTrainingSpec.TableTrainingSpecBuilder builder = TableTrainingSpec.newBuilder().hiveConf(toHiveConf(conf))
- .database(db).table(table).partitionFilter(partitionFilter).featureColumns(features).labelColumn(label);
-
- if (useTrainingFraction) {
- builder.trainingFraction(trainingFraction);
- }
-
- TableTrainingSpec spec = builder.build();
- LOG.info("Training " + " with " + features.size() + " features");
-
- spec.createRDDs(sparkContext);
-
- RDD<LabeledPoint> trainingRDD = spec.getTrainingRDD();
- BaseSparkClassificationModel<?> model = trainInternal(modelId, trainingRDD);
- model.setTable(table);
- model.setParams(Arrays.asList(params));
- model.setLabelColumn(label);
- model.setFeatureColumns(features);
- return model;
- }
-
- /**
- * To hive conf.
- *
- * @param conf the conf
- * @return the hive conf
- */
- protected HiveConf toHiveConf(LensConf conf) {
- HiveConf hiveConf = new HiveConf();
- for (String key : conf.getProperties().keySet()) {
- hiveConf.set(key, conf.getProperties().get(key));
- }
- return hiveConf;
- }
-
- /**
- * Parses the params.
- *
- * @param args the args
- */
- public void parseParams(String[] args) {
- if (args.length % 2 != 0) {
- throw new IllegalArgumentException("Invalid number of params " + args.length);
- }
-
- params = new LinkedHashMap<String, String>();
-
- for (int i = 0; i < args.length; i += 2) {
- if ("f".equalsIgnoreCase(args[i]) || "feature".equalsIgnoreCase(args[i])) {
- if (features == null) {
- features = new ArrayList<String>();
- }
- features.add(args[i + 1]);
- } else if ("l".equalsIgnoreCase(args[i]) || "label".equalsIgnoreCase(args[i])) {
- label = args[i + 1];
- } else {
- params.put(args[i].replaceAll("\\-+", ""), args[i + 1]);
- }
- }
-
- if (params.containsKey("trainingFraction")) {
- // Get training Fraction
- String trainingFractionStr = params.get("trainingFraction");
- try {
- trainingFraction = Double.parseDouble(trainingFractionStr);
- useTrainingFraction = true;
- } catch (NumberFormatException nfe) {
- throw new IllegalArgumentException("Invalid training fraction", nfe);
- }
- }
-
- if (params.containsKey("partition") || params.containsKey("p")) {
- partitionFilter = params.containsKey("partition") ? params.get("partition") : params.get("p");
- }
-
- parseTrainerParams(params);
- }
-
- /**
- * Gets the param value.
- *
- * @param param the param
- * @param defaultVal the default val
- * @return the param value
- */
- public double getParamValue(String param, double defaultVal) {
- if (params.containsKey(param)) {
- try {
- return Double.parseDouble(params.get(param));
- } catch (NumberFormatException nfe) {
- LOG.warn("Couldn't parse param value: " + param + " as double.");
- }
- }
- return defaultVal;
- }
-
- /**
- * Gets the param value.
- *
- * @param param the param
- * @param defaultVal the default val
- * @return the param value
- */
- public int getParamValue(String param, int defaultVal) {
- if (params.containsKey(param)) {
- try {
- return Integer.parseInt(params.get(param));
- } catch (NumberFormatException nfe) {
- LOG.warn("Couldn't parse param value: " + param + " as integer.");
- }
- }
- return defaultVal;
- }
-
- public String getName() {
- return name;
- }
-
- public String getDescription() {
- return description;
- }
-
- public Map<String, String> getArgUsage() {
- Map<String, String> usage = new LinkedHashMap<String, String>();
- Class<?> clz = this.getClass();
- // Put class name and description as well as part of the usage
- Algorithm algorithm = clz.getAnnotation(Algorithm.class);
- if (algorithm != null) {
- usage.put("Algorithm Name", algorithm.name());
- usage.put("Algorithm Description", algorithm.description());
- }
-
- // Get all trainer params including base trainer params
- while (clz != null) {
- for (Field field : clz.getDeclaredFields()) {
- TrainerParam param = field.getAnnotation(TrainerParam.class);
- if (param != null) {
- usage.put("[param] " + param.name(), param.help() + " Default Value = " + param.defaultValue());
- }
- }
-
- if (clz.equals(BaseSparkTrainer.class)) {
- break;
- }
- clz = clz.getSuperclass();
- }
- return usage;
- }
-
- /**
- * Parses the trainer params.
- *
- * @param params the params
- */
- public abstract void parseTrainerParams(Map<String, String> params);
-
- /**
- * Train internal.
- *
- * @param modelId the model id
- * @param trainingRDD the training rdd
- * @return the base spark classification model
- * @throws LensException the lens exception
- */
- protected abstract BaseSparkClassificationModel trainInternal(String modelId, RDD<LabeledPoint> trainingRDD)
- throws LensException;
-}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/49fef8e2/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/trainers/DecisionTreeTrainer.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/trainers/DecisionTreeTrainer.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/trainers/DecisionTreeTrainer.java
deleted file mode 100644
index 96fbf28..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/trainers/DecisionTreeTrainer.java
+++ /dev/null
@@ -1,109 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml.spark.trainers;
-
-import java.util.Map;
-
-import org.apache.lens.api.LensException;
-import org.apache.lens.ml.Algorithm;
-import org.apache.lens.ml.TrainerParam;
-import org.apache.lens.ml.spark.models.BaseSparkClassificationModel;
-import org.apache.lens.ml.spark.models.DecisionTreeClassificationModel;
-import org.apache.lens.ml.spark.models.SparkDecisionTreeModel;
-
-import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.mllib.tree.DecisionTree$;
-import org.apache.spark.mllib.tree.configuration.Algo$;
-import org.apache.spark.mllib.tree.impurity.Entropy$;
-import org.apache.spark.mllib.tree.impurity.Gini$;
-import org.apache.spark.mllib.tree.impurity.Impurity;
-import org.apache.spark.mllib.tree.impurity.Variance$;
-import org.apache.spark.mllib.tree.model.DecisionTreeModel;
-import org.apache.spark.rdd.RDD;
-
-import scala.Enumeration;
-
-/**
- * The Class DecisionTreeTrainer.
- */
-@Algorithm(name = "spark_decision_tree", description = "Spark Decision Tree classifier trainer")
-public class DecisionTreeTrainer extends BaseSparkTrainer {
-
- /** The algo. */
- @TrainerParam(name = "algo", help = "Decision tree algorithm. Allowed values are 'classification' and 'regression'")
- private Enumeration.Value algo;
-
- /** The decision tree impurity. */
- @TrainerParam(name = "impurity", help = "Impurity measure used by the decision tree. "
- + "Allowed values are 'gini', 'entropy' and 'variance'")
- private Impurity decisionTreeImpurity;
-
- /** The max depth. */
- @TrainerParam(name = "maxDepth", help = "Max depth of the decision tree. Integer values expected.",
- defaultValue = "100")
- private int maxDepth;
-
- /**
- * Instantiates a new decision tree trainer.
- *
- * @param name the name
- * @param description the description
- */
- public DecisionTreeTrainer(String name, String description) {
- super(name, description);
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.spark.trainers.BaseSparkTrainer#parseTrainerParams(java.util.Map)
- */
- @Override
- public void parseTrainerParams(Map<String, String> params) {
- String dtreeAlgoName = params.get("algo");
- if ("classification".equalsIgnoreCase(dtreeAlgoName)) {
- algo = Algo$.MODULE$.Classification();
- } else if ("regression".equalsIgnoreCase(dtreeAlgoName)) {
- algo = Algo$.MODULE$.Regression();
- }
-
- String impurity = params.get("impurity");
- if ("gini".equals(impurity)) {
- decisionTreeImpurity = Gini$.MODULE$;
- } else if ("entropy".equals(impurity)) {
- decisionTreeImpurity = Entropy$.MODULE$;
- } else if ("variance".equals(impurity)) {
- decisionTreeImpurity = Variance$.MODULE$;
- }
-
- maxDepth = getParamValue("maxDepth", 100);
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.spark.trainers.BaseSparkTrainer#trainInternal(java.lang.String, org.apache.spark.rdd.RDD)
- */
- @Override
- protected BaseSparkClassificationModel trainInternal(String modelId, RDD<LabeledPoint> trainingRDD)
- throws LensException {
- DecisionTreeModel model = DecisionTree$.MODULE$.train(trainingRDD, algo, decisionTreeImpurity, maxDepth);
- return new DecisionTreeClassificationModel(modelId, new SparkDecisionTreeModel(model));
- }
-}