You are viewing a plain-text version of this content; the canonical link to it (a hyperlink in the original page) is not preserved in this copy.
Posted to commits@lens.apache.org by am...@apache.org on 2015/04/15 21:47:44 UTC

[20/30] incubator-lens git commit: LENS-319: Renamed Trainer to Algo (sharad)

LENS-319: Renamed Trainer to Algo (sharad)


Project: http://git-wip-us.apache.org/repos/asf/incubator-lens/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-lens/commit/49fef8e2
Tree: http://git-wip-us.apache.org/repos/asf/incubator-lens/tree/49fef8e2
Diff: http://git-wip-us.apache.org/repos/asf/incubator-lens/diff/49fef8e2

Branch: refs/heads/master
Commit: 49fef8e2bb3995ee7d5b3db74bf954c205c187fc
Parents: ff152d9
Author: Sharad Agarwal <sh...@flipkarts-MacBook-Pro.local>
Authored: Wed Feb 18 14:40:17 2015 +0530
Committer: Amareshwari Sriramadasu <am...@apache.org>
Committed: Thu Feb 19 10:23:12 2015 +0530

----------------------------------------------------------------------
 .../org/apache/lens/client/LensMLClient.java    |  14 +-
 .../apache/lens/client/LensMLJerseyClient.java  |  14 +-
 .../java/org/apache/lens/ml/AlgoArgParser.java  | 114 ++++++++
 .../main/java/org/apache/lens/ml/AlgoParam.java |  53 ++++
 .../java/org/apache/lens/ml/Algorithms.java     |  32 +-
 .../main/java/org/apache/lens/ml/LensML.java    |   6 +-
 .../java/org/apache/lens/ml/LensMLImpl.java     |  66 ++---
 .../main/java/org/apache/lens/ml/MLAlgo.java    |  53 ++++
 .../main/java/org/apache/lens/ml/MLDriver.java  |  18 +-
 .../main/java/org/apache/lens/ml/MLModel.java   |   4 +-
 .../main/java/org/apache/lens/ml/MLTrainer.java |  53 ----
 .../main/java/org/apache/lens/ml/MLUtils.java   |   6 +-
 .../org/apache/lens/ml/TrainerArgParser.java    | 114 --------
 .../java/org/apache/lens/ml/TrainerParam.java   |  53 ----
 .../org/apache/lens/ml/spark/SparkMLDriver.java |  42 +--
 .../lens/ml/spark/algos/BaseSparkAlgo.java      | 290 +++++++++++++++++++
 .../lens/ml/spark/algos/DecisionTreeAlgo.java   | 109 +++++++
 .../apache/lens/ml/spark/algos/KMeansAlgo.java  | 163 +++++++++++
 .../ml/spark/algos/LogisticRegressionAlgo.java  |  86 ++++++
 .../lens/ml/spark/algos/NaiveBayesAlgo.java     |  73 +++++
 .../org/apache/lens/ml/spark/algos/SVMAlgo.java |  90 ++++++
 .../ml/spark/trainers/BaseSparkTrainer.java     | 289 ------------------
 .../ml/spark/trainers/DecisionTreeTrainer.java  | 109 -------
 .../lens/ml/spark/trainers/KMeansTrainer.java   | 163 -----------
 .../trainers/LogisticRegressionTrainer.java     |  86 ------
 .../ml/spark/trainers/NaiveBayesTrainer.java    |  73 -----
 .../lens/ml/spark/trainers/SVMTrainer.java      |  90 ------
 .../java/org/apache/lens/ml/task/MLTask.java    |   8 +-
 .../apache/lens/server/ml/MLServiceImpl.java    |   6 +-
 .../lens/server/ml/MLServiceResource.java       |  38 +--
 .../java/org/apache/lens/ml/TestMLResource.java |  37 +--
 31 files changed, 1177 insertions(+), 1175 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/49fef8e2/lens-ml-lib/src/main/java/org/apache/lens/client/LensMLClient.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/client/LensMLClient.java b/lens-ml-lib/src/main/java/org/apache/lens/client/LensMLClient.java
index 1bad7a0..9f7fa26 100644
--- a/lens-ml-lib/src/main/java/org/apache/lens/client/LensMLClient.java
+++ b/lens-ml-lib/src/main/java/org/apache/lens/client/LensMLClient.java
@@ -35,9 +35,9 @@ import org.apache.lens.api.LensSessionHandle;
 import org.apache.lens.api.ml.ModelMetadata;
 import org.apache.lens.api.ml.TestReport;
 import org.apache.lens.ml.LensML;
+import org.apache.lens.ml.MLAlgo;
 import org.apache.lens.ml.MLModel;
 import org.apache.lens.ml.MLTestReport;
-import org.apache.lens.ml.MLTrainer;
 
 import org.apache.commons.lang.StringUtils;
 import org.apache.commons.logging.Log;
@@ -81,7 +81,7 @@ public class LensMLClient implements LensML, Closeable {
    */
   @Override
   public List<String> getAlgorithms() {
-    return client.getTrainerNames();
+    return client.getAlgoNames();
   }
 
   /**
@@ -92,7 +92,7 @@ public class LensMLClient implements LensML, Closeable {
    */
   @Override
   public Map<String, String> getAlgoParamDescription(String algorithm) {
-    List<String> paramDesc = client.getParamDescriptionOfTrainer(algorithm);
+    List<String> paramDesc = client.getParamDescriptionOfAlgo(algorithm);
     // convert paramDesc to map
     Map<String, String> paramDescMap = new LinkedHashMap<String, String>();
     for (String str : paramDesc) {
@@ -103,15 +103,15 @@ public class LensMLClient implements LensML, Closeable {
   }
 
   /**
-   * Get a trainer object instance which could be used to generate a model of the given algorithm.
+   * Get a algo object instance which could be used to generate a model of the given algorithm.
    *
    * @param algorithm the algorithm
-   * @return the trainer for name
+   * @return the algo for name
    * @throws LensException the lens exception
    */
   @Override
-  public MLTrainer getTrainerForName(String algorithm) throws LensException {
-    throw new UnsupportedOperationException("MLTrainer cannot be accessed from client");
+  public MLAlgo getAlgoForName(String algorithm) throws LensException {
+    throw new UnsupportedOperationException("MLAlgo cannot be accessed from client");
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/49fef8e2/lens-ml-lib/src/main/java/org/apache/lens/client/LensMLJerseyClient.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/client/LensMLJerseyClient.java b/lens-ml-lib/src/main/java/org/apache/lens/client/LensMLJerseyClient.java
index 2b1ece4..af47a41 100644
--- a/lens-ml-lib/src/main/java/org/apache/lens/client/LensMLJerseyClient.java
+++ b/lens-ml-lib/src/main/java/org/apache/lens/client/LensMLJerseyClient.java
@@ -134,9 +134,9 @@ public class LensMLJerseyClient {
     }
   }
 
-  public List<String> getTrainerNames() {
-    StringList trainerNames = getMLWebTarget().path("trainers").request().get(StringList.class);
-    return trainerNames == null ? null : trainerNames.getElements();
+  public List<String> getAlgoNames() {
+    StringList algoNames = getMLWebTarget().path("algos").request().get(StringList.class);
+    return algoNames == null ? null : algoNames.getElements();
   }
 
   /**
@@ -234,14 +234,14 @@ public class LensMLJerseyClient {
   }
 
   /**
-   * Gets the param description of trainer.
+   * Gets the param description of algo.
    *
    * @param algorithm the algorithm
-   * @return the param description of trainer
+   * @return the param description of algo
    */
-  public List<String> getParamDescriptionOfTrainer(String algorithm) {
+  public List<String> getParamDescriptionOfAlgo(String algorithm) {
     try {
-      StringList paramHelp = getMLWebTarget().path("trainers").path(algorithm).request(MediaType.APPLICATION_XML)
+      StringList paramHelp = getMLWebTarget().path("algos").path(algorithm).request(MediaType.APPLICATION_XML)
         .get(StringList.class);
       return paramHelp.getElements();
     } catch (NotFoundException exc) {

http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/49fef8e2/lens-ml-lib/src/main/java/org/apache/lens/ml/AlgoArgParser.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/AlgoArgParser.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/AlgoArgParser.java
new file mode 100644
index 0000000..20da083
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/AlgoArgParser.java
@@ -0,0 +1,114 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml;
+
+import java.lang.reflect.Field;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+
+/**
+ * The Class AlgoArgParser.
+ */
+public final class AlgoArgParser {
+  private AlgoArgParser() {
+  }
+
+  /**
+   * The Class CustomArgParser.
+   *
+   * @param <E> the element type
+   */
+  public abstract static class CustomArgParser<E> {
+
+    /**
+     * Parses the.
+     *
+     * @param value the value
+     * @return the e
+     */
+    public abstract E parse(String value);
+  }
+
+  /** The Constant LOG. */
+  public static final Log LOG = LogFactory.getLog(AlgoArgParser.class);
+
+  /**
+   * Extracts feature names. If the algo has any parameters associated with @AlgoParam annotation, those are set
+   * as well.
+   *
+   * @param algo the algo
+   * @param args    the args
+   * @return List of feature column names.
+   */
+  public static List<String> parseArgs(MLAlgo algo, String[] args) {
+    List<String> featureColumns = new ArrayList<String>();
+    Class<? extends MLAlgo> algoClass = algo.getClass();
+    // Get param fields
+    Map<String, Field> fieldMap = new HashMap<String, Field>();
+
+    for (Field fld : algoClass.getDeclaredFields()) {
+      fld.setAccessible(true);
+      AlgoParam paramAnnotation = fld.getAnnotation(AlgoParam.class);
+      if (paramAnnotation != null) {
+        fieldMap.put(paramAnnotation.name(), fld);
+      }
+    }
+
+    for (int i = 0; i < args.length; i += 2) {
+      String key = args[i].trim();
+      String value = args[i + 1].trim();
+
+      try {
+        if ("feature".equalsIgnoreCase(key)) {
+          featureColumns.add(value);
+        } else if (fieldMap.containsKey(key)) {
+          Field f = fieldMap.get(key);
+          if (String.class.equals(f.getType())) {
+            f.set(algo, value);
+          } else if (Integer.TYPE.equals(f.getType())) {
+            f.setInt(algo, Integer.parseInt(value));
+          } else if (Double.TYPE.equals(f.getType())) {
+            f.setDouble(algo, Double.parseDouble(value));
+          } else if (Long.TYPE.equals(f.getType())) {
+            f.setLong(algo, Long.parseLong(value));
+          } else {
+            // check if the algo provides a deserializer for this param
+            String customParserClass = algo.getConf().getProperties().get("lens.ml.args." + key);
+            if (customParserClass != null) {
+              Class<? extends CustomArgParser<?>> clz = (Class<? extends CustomArgParser<?>>) Class
+                .forName(customParserClass);
+              CustomArgParser<?> parser = clz.newInstance();
+              f.set(algo, parser.parse(value));
+            } else {
+              LOG.warn("Ignored param " + key + "=" + value + " as no parser found");
+            }
+          }
+        }
+      } catch (Exception exc) {
+        LOG.error("Error while setting param " + key + " to " + value + " for algo " + algo);
+      }
+    }
+    return featureColumns;
+  }
+}

http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/49fef8e2/lens-ml-lib/src/main/java/org/apache/lens/ml/AlgoParam.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/AlgoParam.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/AlgoParam.java
new file mode 100644
index 0000000..5836f51
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/AlgoParam.java
@@ -0,0 +1,53 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml;
+
+import java.lang.annotation.ElementType;
+import java.lang.annotation.Retention;
+import java.lang.annotation.RetentionPolicy;
+import java.lang.annotation.Target;
+
+/**
+ * The Interface AlgoParam.
+ */
+@Retention(RetentionPolicy.RUNTIME)
+@Target(ElementType.FIELD)
+public @interface AlgoParam {
+
+  /**
+   * Name.
+   *
+   * @return the string
+   */
+  String name();
+
+  /**
+   * Help.
+   *
+   * @return the string
+   */
+  String help();
+
+  /**
+   * Default value.
+   *
+   * @return the string
+   */
+  String defaultValue() default "None";
+}

http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/49fef8e2/lens-ml-lib/src/main/java/org/apache/lens/ml/Algorithms.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/Algorithms.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/Algorithms.java
index 3b74a09..c1b7212 100644
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/Algorithms.java
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/Algorithms.java
@@ -32,41 +32,41 @@ import org.apache.lens.api.LensException;
 public class Algorithms {
 
   /** The algorithm classes. */
-  private final Map<String, Class<? extends MLTrainer>> algorithmClasses
-    = new HashMap<String, Class<? extends MLTrainer>>();
+  private final Map<String, Class<? extends MLAlgo>> algorithmClasses
+    = new HashMap<String, Class<? extends MLAlgo>>();
 
   /**
    * Register.
    *
-   * @param trainerClass the trainer class
+   * @param algoClass the algo class
    */
-  public void register(Class<? extends MLTrainer> trainerClass) {
-    if (trainerClass != null && trainerClass.getAnnotation(Algorithm.class) != null) {
-      algorithmClasses.put(trainerClass.getAnnotation(Algorithm.class).name(), trainerClass);
+  public void register(Class<? extends MLAlgo> algoClass) {
+    if (algoClass != null && algoClass.getAnnotation(Algorithm.class) != null) {
+      algorithmClasses.put(algoClass.getAnnotation(Algorithm.class).name(), algoClass);
     } else {
-      throw new IllegalArgumentException("Not a valid algorithm class: " + trainerClass);
+      throw new IllegalArgumentException("Not a valid algorithm class: " + algoClass);
     }
   }
 
   /**
-   * Gets the trainer for name.
+   * Gets the algo for name.
    *
    * @param name the name
-   * @return the trainer for name
+   * @return the algo for name
    * @throws LensException the lens exception
    */
-  public MLTrainer getTrainerForName(String name) throws LensException {
-    Class<? extends MLTrainer> trainerClass = algorithmClasses.get(name);
-    if (trainerClass == null) {
+  public MLAlgo getAlgoForName(String name) throws LensException {
+    Class<? extends MLAlgo> algoClass = algorithmClasses.get(name);
+    if (algoClass == null) {
       return null;
     }
-    Algorithm algoAnnotation = trainerClass.getAnnotation(Algorithm.class);
+    Algorithm algoAnnotation = algoClass.getAnnotation(Algorithm.class);
     String description = algoAnnotation.description();
     try {
-      Constructor<? extends MLTrainer> trainerConstructor = trainerClass.getConstructor(String.class, String.class);
-      return trainerConstructor.newInstance(name, description);
+      Constructor<? extends MLAlgo> algoConstructor = algoClass.getConstructor(String.class, String.class);
+      return algoConstructor.newInstance(name, description);
     } catch (Exception exc) {
-      throw new LensException("Unable to get trainer: " + name, exc);
+      throw new LensException("Unable to get algo: " + name, exc);
     }
   }
 

http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/49fef8e2/lens-ml-lib/src/main/java/org/apache/lens/ml/LensML.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/LensML.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/LensML.java
index 9a15ea0..fe65d2f 100644
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/LensML.java
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/LensML.java
@@ -48,13 +48,13 @@ public interface LensML {
   Map<String, String> getAlgoParamDescription(String algorithm);
 
   /**
-   * Get a trainer object instance which could be used to generate a model of the given algorithm.
+   * Get a algo object instance which could be used to generate a model of the given algorithm.
    *
    * @param algorithm the algorithm
-   * @return the trainer for name
+   * @return the algo for name
    * @throws LensException the lens exception
    */
-  MLTrainer getTrainerForName(String algorithm) throws LensException;
+  MLAlgo getAlgoForName(String algorithm) throws LensException;
 
   /**
    * Create a model using the given HCatalog table as input. The arguments should contain information needeed to

http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/49fef8e2/lens-ml-lib/src/main/java/org/apache/lens/ml/LensMLImpl.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/LensMLImpl.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/LensMLImpl.java
index 7cd0580..2555ca0 100644
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/LensMLImpl.java
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/LensMLImpl.java
@@ -35,7 +35,7 @@ import org.apache.lens.api.query.LensQuery;
 import org.apache.lens.api.query.QueryHandle;
 import org.apache.lens.api.query.QueryStatus;
 import org.apache.lens.ml.spark.SparkMLDriver;
-import org.apache.lens.ml.spark.trainers.BaseSparkTrainer;
+import org.apache.lens.ml.spark.algos.BaseSparkAlgo;
 import org.apache.lens.server.api.LensConfConstants;
 
 import org.apache.commons.io.IOUtils;
@@ -93,25 +93,25 @@ public class LensMLImpl implements LensML {
   }
 
   public List<String> getAlgorithms() {
-    List<String> trainers = new ArrayList<String>();
+    List<String> algos = new ArrayList<String>();
     for (MLDriver driver : drivers) {
-      trainers.addAll(driver.getTrainerNames());
+      algos.addAll(driver.getAlgoNames());
     }
-    return trainers;
+    return algos;
   }
 
   /*
    * (non-Javadoc)
    *
-   * @see org.apache.lens.ml.LensML#getTrainerForName(java.lang.String)
+   * @see org.apache.lens.ml.LensML#getAlgoForName(java.lang.String)
    */
-  public MLTrainer getTrainerForName(String algorithm) throws LensException {
+  public MLAlgo getAlgoForName(String algorithm) throws LensException {
     for (MLDriver driver : drivers) {
-      if (driver.isTrainerSupported(algorithm)) {
-        return driver.getTrainerInstance(algorithm);
+      if (driver.isAlgoSupported(algorithm)) {
+        return driver.getAlgoInstance(algorithm);
       }
     }
-    throw new LensException("Trainer not supported " + algorithm);
+    throw new LensException("Algo not supported " + algorithm);
   }
 
   /*
@@ -120,11 +120,11 @@ public class LensMLImpl implements LensML {
    * @see org.apache.lens.ml.LensML#train(java.lang.String, java.lang.String, java.lang.String[])
    */
   public String train(String table, String algorithm, String[] args) throws LensException {
-    MLTrainer trainer = getTrainerForName(algorithm);
+    MLAlgo algo = getAlgoForName(algorithm);
 
     String modelId = UUID.randomUUID().toString();
 
-    LOG.info("Begin training model " + modelId + ", trainer=" + algorithm + ", table=" + table + ", params="
+    LOG.info("Begin training model " + modelId + ", algo=" + algorithm + ", table=" + table + ", params="
       + Arrays.toString(args));
 
     String database = null;
@@ -134,33 +134,33 @@ public class LensMLImpl implements LensML {
       database = "default";
     }
 
-    MLModel model = trainer.train(toLensConf(conf), database, table, modelId, args);
+    MLModel model = algo.train(toLensConf(conf), database, table, modelId, args);
 
     LOG.info("Done training model: " + modelId);
 
     model.setCreatedAt(new Date());
-    model.setTrainerName(algorithm);
+    model.setAlgoName(algorithm);
 
     Path modelLocation = null;
     try {
       modelLocation = persistModel(model);
-      LOG.info("Model saved: " + modelId + ", trainer: " + algorithm + ", path: " + modelLocation);
+      LOG.info("Model saved: " + modelId + ", algo: " + algorithm + ", path: " + modelLocation);
       return model.getId();
     } catch (IOException e) {
-      throw new LensException("Error saving model " + modelId + " for trainer " + algorithm, e);
+      throw new LensException("Error saving model " + modelId + " for algo " + algorithm, e);
     }
   }
 
   /**
-   * Gets the trainer dir.
+   * Gets the algo dir.
    *
-   * @param trainerName the trainer name
-   * @return the trainer dir
+   * @param algoName the algo name
+   * @return the algo dir
    * @throws IOException Signals that an I/O exception has occurred.
    */
-  private Path getTrainerDir(String trainerName) throws IOException {
+  private Path getAlgoDir(String algoName) throws IOException {
     String modelSaveBaseDir = conf.get(ModelLoader.MODEL_PATH_BASE_DIR, ModelLoader.MODEL_PATH_BASE_DIR_DEFAULT);
-    return new Path(new Path(modelSaveBaseDir), trainerName);
+    return new Path(new Path(modelSaveBaseDir), algoName);
   }
 
   /**
@@ -172,14 +172,14 @@ public class LensMLImpl implements LensML {
    */
   private Path persistModel(MLModel model) throws IOException {
     // Get model save path
-    Path trainerDir = getTrainerDir(model.getTrainerName());
-    FileSystem fs = trainerDir.getFileSystem(conf);
+    Path algoDir = getAlgoDir(model.getAlgoName());
+    FileSystem fs = algoDir.getFileSystem(conf);
 
-    if (!fs.exists(trainerDir)) {
-      fs.mkdirs(trainerDir);
+    if (!fs.exists(algoDir)) {
+      fs.mkdirs(algoDir);
     }
 
-    Path modelSavePath = new Path(trainerDir, model.getId());
+    Path modelSavePath = new Path(algoDir, model.getId());
     ObjectOutputStream outputStream = null;
 
     try {
@@ -202,15 +202,15 @@ public class LensMLImpl implements LensML {
    */
   public List<String> getModels(String algorithm) throws LensException {
     try {
-      Path trainerDir = getTrainerDir(algorithm);
-      FileSystem fs = trainerDir.getFileSystem(conf);
-      if (!fs.exists(trainerDir)) {
+      Path algoDir = getAlgoDir(algorithm);
+      FileSystem fs = algoDir.getFileSystem(conf);
+      if (!fs.exists(algoDir)) {
         return null;
       }
 
       List<String> models = new ArrayList<String>();
 
-      for (FileStatus stat : fs.listStatus(trainerDir)) {
+      for (FileStatus stat : fs.listStatus(algoDir)) {
         models.add(stat.getPath().getName());
       }
 
@@ -563,15 +563,15 @@ public class LensMLImpl implements LensML {
    * @see org.apache.lens.ml.LensML#getAlgoParamDescription(java.lang.String)
    */
   public Map<String, String> getAlgoParamDescription(String algorithm) {
-    MLTrainer trainer = null;
+    MLAlgo algo = null;
     try {
-      trainer = getTrainerForName(algorithm);
+      algo = getAlgoForName(algorithm);
     } catch (LensException e) {
       LOG.error("Error getting algo description : " + algorithm, e);
       return null;
     }
-    if (trainer instanceof BaseSparkTrainer) {
-      return ((BaseSparkTrainer) trainer).getArgUsage();
+    if (algo instanceof BaseSparkAlgo) {
+      return ((BaseSparkAlgo) algo).getArgUsage();
     }
     return null;
   }

http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/49fef8e2/lens-ml-lib/src/main/java/org/apache/lens/ml/MLAlgo.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/MLAlgo.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/MLAlgo.java
new file mode 100644
index 0000000..7dccf2c
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/MLAlgo.java
@@ -0,0 +1,53 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml;
+
+import org.apache.lens.api.LensConf;
+import org.apache.lens.api.LensException;
+
+/**
+ * The Interface MLAlgo.
+ */
+public interface MLAlgo {
+  String getName();
+
+  String getDescription();
+
+  /**
+   * Configure.
+   *
+   * @param configuration the configuration
+   */
+  void configure(LensConf configuration);
+
+  LensConf getConf();
+
+  /**
+   * Train.
+   *
+   * @param conf    the conf
+   * @param db      the db
+   * @param table   the table
+   * @param modelId the model id
+   * @param params  the params
+   * @return the ML model
+   * @throws LensException the lens exception
+   */
+  MLModel train(LensConf conf, String db, String table, String modelId, String... params) throws LensException;
+}

http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/49fef8e2/lens-ml-lib/src/main/java/org/apache/lens/ml/MLDriver.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/MLDriver.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/MLDriver.java
index 562253a..567e717 100644
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/MLDriver.java
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/MLDriver.java
@@ -29,21 +29,21 @@ import org.apache.lens.api.LensException;
 public interface MLDriver {
 
   /**
-   * Checks if is trainer supported.
+   * Checks if is algo supported.
    *
-   * @param trainer the trainer
-   * @return true, if is trainer supported
+   * @param algo the algo
+   * @return true, if is algo supported
    */
-  boolean isTrainerSupported(String trainer);
+  boolean isAlgoSupported(String algo);
 
   /**
-   * Gets the trainer instance.
+   * Gets the algo instance.
    *
-   * @param trainer the trainer
-   * @return the trainer instance
+   * @param algo the algo
+   * @return the algo instance
    * @throws LensException the lens exception
    */
-  MLTrainer getTrainerInstance(String trainer) throws LensException;
+  MLAlgo getAlgoInstance(String algo) throws LensException;
 
   /**
    * Inits the.
@@ -67,5 +67,5 @@ public interface MLDriver {
    */
   void stop() throws LensException;
 
-  List<String> getTrainerNames();
+  List<String> getAlgoNames();
 }

http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/49fef8e2/lens-ml-lib/src/main/java/org/apache/lens/ml/MLModel.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/MLModel.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/MLModel.java
index 863cdfe..c177757 100644
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/MLModel.java
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/MLModel.java
@@ -44,10 +44,10 @@ public abstract class MLModel<PREDICTION> implements Serializable {
   @Setter
   private Date createdAt;
 
-  /** The trainer name. */
+  /** The algo name. */
   @Getter
   @Setter
-  private String trainerName;
+  private String algoName;
 
   /** The table. */
   @Getter

http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/49fef8e2/lens-ml-lib/src/main/java/org/apache/lens/ml/MLTrainer.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/MLTrainer.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/MLTrainer.java
deleted file mode 100644
index f1ae291..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/MLTrainer.java
+++ /dev/null
@@ -1,53 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml;
-
-import org.apache.lens.api.LensConf;
-import org.apache.lens.api.LensException;
-
-/**
- * The Interface MLTrainer.
- */
-public interface MLTrainer {
-  String getName();
-
-  String getDescription();
-
-  /**
-   * Configure.
-   *
-   * @param configuration the configuration
-   */
-  void configure(LensConf configuration);
-
-  LensConf getConf();
-
-  /**
-   * Train.
-   *
-   * @param conf    the conf
-   * @param db      the db
-   * @param table   the table
-   * @param modelId the model id
-   * @param params  the params
-   * @return the ML model
-   * @throws LensException the lens exception
-   */
-  MLModel train(LensConf conf, String db, String table, String modelId, String... params) throws LensException;
-}

http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/49fef8e2/lens-ml-lib/src/main/java/org/apache/lens/ml/MLUtils.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/MLUtils.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/MLUtils.java
index 4ea8773..2e240af 100644
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/MLUtils.java
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/MLUtils.java
@@ -39,12 +39,12 @@ public final class MLUtils {
     HIVE_CONF.addResource("lens-site.xml");
   }
 
-  public static String getTrainerName(Class<? extends MLTrainer> trainerClass) {
-    Algorithm annotation = trainerClass.getAnnotation(Algorithm.class);
+  public static String getAlgoName(Class<? extends MLAlgo> algoClass) {
+    Algorithm annotation = algoClass.getAnnotation(Algorithm.class);
     if (annotation != null) {
       return annotation.name();
     }
-    throw new IllegalArgumentException("Trainer should be decorated with annotation - " + Algorithm.class.getName());
+    throw new IllegalArgumentException("Algo should be decorated with annotation - " + Algorithm.class.getName());
   }
 
   public static MLServiceImpl getMLService() throws Exception {

http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/49fef8e2/lens-ml-lib/src/main/java/org/apache/lens/ml/TrainerArgParser.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/TrainerArgParser.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/TrainerArgParser.java
deleted file mode 100644
index 92c025d..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/TrainerArgParser.java
+++ /dev/null
@@ -1,114 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml;
-
-import java.lang.reflect.Field;
-import java.util.ArrayList;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
-
-/**
- * The Class TrainerArgParser.
- */
-public final class TrainerArgParser {
-  private TrainerArgParser() {
-  }
-
-  /**
-   * The Class CustomArgParser.
-   *
-   * @param <E> the element type
-   */
-  public abstract static class CustomArgParser<E> {
-
-    /**
-     * Parses the.
-     *
-     * @param value the value
-     * @return the e
-     */
-    public abstract E parse(String value);
-  }
-
-  /** The Constant LOG. */
-  public static final Log LOG = LogFactory.getLog(TrainerArgParser.class);
-
-  /**
-   * Extracts feature names. If the trainer has any parameters associated with @TrainerParam annotation, those are set
-   * as well.
-   *
-   * @param trainer the trainer
-   * @param args    the args
-   * @return List of feature column names.
-   */
-  public static List<String> parseArgs(MLTrainer trainer, String[] args) {
-    List<String> featureColumns = new ArrayList<String>();
-    Class<? extends MLTrainer> trainerClass = trainer.getClass();
-    // Get param fields
-    Map<String, Field> fieldMap = new HashMap<String, Field>();
-
-    for (Field fld : trainerClass.getDeclaredFields()) {
-      fld.setAccessible(true);
-      TrainerParam paramAnnotation = fld.getAnnotation(TrainerParam.class);
-      if (paramAnnotation != null) {
-        fieldMap.put(paramAnnotation.name(), fld);
-      }
-    }
-
-    for (int i = 0; i < args.length; i += 2) {
-      String key = args[i].trim();
-      String value = args[i + 1].trim();
-
-      try {
-        if ("feature".equalsIgnoreCase(key)) {
-          featureColumns.add(value);
-        } else if (fieldMap.containsKey(key)) {
-          Field f = fieldMap.get(key);
-          if (String.class.equals(f.getType())) {
-            f.set(trainer, value);
-          } else if (Integer.TYPE.equals(f.getType())) {
-            f.setInt(trainer, Integer.parseInt(value));
-          } else if (Double.TYPE.equals(f.getType())) {
-            f.setDouble(trainer, Double.parseDouble(value));
-          } else if (Long.TYPE.equals(f.getType())) {
-            f.setLong(trainer, Long.parseLong(value));
-          } else {
-            // check if the trainer provides a deserializer for this param
-            String customParserClass = trainer.getConf().getProperties().get("lens.ml.args." + key);
-            if (customParserClass != null) {
-              Class<? extends CustomArgParser<?>> clz = (Class<? extends CustomArgParser<?>>) Class
-                .forName(customParserClass);
-              CustomArgParser<?> parser = clz.newInstance();
-              f.set(trainer, parser.parse(value));
-            } else {
-              LOG.warn("Ignored param " + key + "=" + value + " as no parser found");
-            }
-          }
-        }
-      } catch (Exception exc) {
-        LOG.error("Error while setting param " + key + " to " + value + " for trainer " + trainer);
-      }
-    }
-    return featureColumns;
-  }
-}

http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/49fef8e2/lens-ml-lib/src/main/java/org/apache/lens/ml/TrainerParam.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/TrainerParam.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/TrainerParam.java
deleted file mode 100644
index fe8af60..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/TrainerParam.java
+++ /dev/null
@@ -1,53 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml;
-
-import java.lang.annotation.ElementType;
-import java.lang.annotation.Retention;
-import java.lang.annotation.RetentionPolicy;
-import java.lang.annotation.Target;
-
-/**
- * The Interface TrainerParam.
- */
-@Retention(RetentionPolicy.RUNTIME)
-@Target(ElementType.FIELD)
-public @interface TrainerParam {
-
-  /**
-   * Name.
-   *
-   * @return the string
-   */
-  String name();
-
-  /**
-   * Help.
-   *
-   * @return the string
-   */
-  String help();
-
-  /**
-   * Default value.
-   *
-   * @return the string
-   */
-  String defaultValue() default "None";
-}

http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/49fef8e2/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/SparkMLDriver.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/SparkMLDriver.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/SparkMLDriver.java
index b19e8ea..1e452f1 100644
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/SparkMLDriver.java
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/SparkMLDriver.java
@@ -26,9 +26,9 @@ import java.util.List;
 import org.apache.lens.api.LensConf;
 import org.apache.lens.api.LensException;
 import org.apache.lens.ml.Algorithms;
+import org.apache.lens.ml.MLAlgo;
 import org.apache.lens.ml.MLDriver;
-import org.apache.lens.ml.MLTrainer;
-import org.apache.lens.ml.spark.trainers.*;
+import org.apache.lens.ml.spark.algos.*;
 
 import org.apache.commons.lang.StringUtils;
 import org.apache.commons.logging.Log;
@@ -90,46 +90,46 @@ public class SparkMLDriver implements MLDriver {
   /*
    * (non-Javadoc)
    *
-   * @see org.apache.lens.ml.MLDriver#isTrainerSupported(java.lang.String)
+   * @see org.apache.lens.ml.MLDriver#isAlgoSupported(java.lang.String)
    */
   @Override
-  public boolean isTrainerSupported(String name) {
+  public boolean isAlgoSupported(String name) {
     return algorithms.isAlgoSupported(name);
   }
 
   /*
    * (non-Javadoc)
    *
-   * @see org.apache.lens.ml.MLDriver#getTrainerInstance(java.lang.String)
+   * @see org.apache.lens.ml.MLDriver#getAlgoInstance(java.lang.String)
    */
   @Override
-  public MLTrainer getTrainerInstance(String name) throws LensException {
+  public MLAlgo getAlgoInstance(String name) throws LensException {
     checkStarted();
 
-    if (!isTrainerSupported(name)) {
+    if (!isAlgoSupported(name)) {
       return null;
     }
 
-    MLTrainer trainer = null;
+    MLAlgo algo = null;
     try {
-      trainer = algorithms.getTrainerForName(name);
-      if (trainer instanceof BaseSparkTrainer) {
-        ((BaseSparkTrainer) trainer).setSparkContext(sparkContext);
+      algo = algorithms.getAlgoForName(name);
+      if (algo instanceof BaseSparkAlgo) {
+        ((BaseSparkAlgo) algo).setSparkContext(sparkContext);
       }
     } catch (LensException exc) {
-      LOG.error("Error creating trainer object", exc);
+      LOG.error("Error creating algo object", exc);
     }
-    return trainer;
+    return algo;
   }
 
   /**
-   * Register trainers.
+   * Register algos.
    */
-  private void registerTrainers() {
-    algorithms.register(NaiveBayesTrainer.class);
-    algorithms.register(SVMTrainer.class);
-    algorithms.register(LogisticRegressionTrainer.class);
-    algorithms.register(DecisionTreeTrainer.class);
+  private void registerAlgos() {
+    algorithms.register(NaiveBayesAlgo.class);
+    algorithms.register(SVMAlgo.class);
+    algorithms.register(LogisticRegressionAlgo.class);
+    algorithms.register(DecisionTreeAlgo.class);
   }
 
   /*
@@ -140,7 +140,7 @@ public class SparkMLDriver implements MLDriver {
   @Override
   public void init(LensConf conf) throws LensException {
     sparkConf = new SparkConf();
-    registerTrainers();
+    registerAlgos();
     for (String key : conf.getProperties().keySet()) {
       if (key.startsWith("lens.ml.sparkdriver.")) {
         sparkConf.set(key.substring("lens.ml.sparkdriver.".length()), conf.getProperties().get(key));
@@ -253,7 +253,7 @@ public class SparkMLDriver implements MLDriver {
   }
 
   @Override
-  public List<String> getTrainerNames() {
+  public List<String> getAlgoNames() {
     return algorithms.getAlgorithmNames();
   }
 

http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/49fef8e2/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/BaseSparkAlgo.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/BaseSparkAlgo.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/BaseSparkAlgo.java
new file mode 100644
index 0000000..22cda6d
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/BaseSparkAlgo.java
@@ -0,0 +1,324 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.spark.algos;
+
+import java.lang.reflect.Field;
+import java.util.*;
+
+import org.apache.lens.api.LensConf;
+import org.apache.lens.api.LensException;
+import org.apache.lens.ml.AlgoParam;
+import org.apache.lens.ml.Algorithm;
+import org.apache.lens.ml.MLAlgo;
+import org.apache.lens.ml.MLModel;
+
+import org.apache.lens.ml.spark.TableTrainingSpec;
+import org.apache.lens.ml.spark.models.BaseSparkClassificationModel;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.hive.conf.HiveConf;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.rdd.RDD;
+
+/**
+ * Base class for Spark MLlib algorithms that are trained on a Hive table.
+ *
+ * <p>Implements the common argument protocol (feature/label columns, training fraction, partition filter), builds
+ * the training RDD through {@link TableTrainingSpec} and delegates model fitting to
+ * {@link #trainInternal(String, RDD)}. Subclasses read their algorithm-specific arguments in
+ * {@link #parseAlgoParams(Map)}.
+ */
+public abstract class BaseSparkAlgo implements MLAlgo {
+
+  /** The Constant LOG. */
+  public static final Log LOG = LogFactory.getLog(BaseSparkAlgo.class);
+
+  /** The name. */
+  private final String name;
+
+  /** The description. */
+  private final String description;
+
+  /** The spark context used to build training RDDs. */
+  protected JavaSparkContext sparkContext;
+
+  /** Arguments (other than feature/label) from the last parseParams call, with dashes stripped from keys. */
+  protected Map<String, String> params;
+
+  /** The conf. */
+  protected transient LensConf conf;
+
+  /** The training fraction. */
+  @AlgoParam(name = "trainingFraction", help = "% of dataset to be used for training", defaultValue = "0")
+  protected double trainingFraction;
+
+  /** Whether trainingFraction was supplied in the last parseParams call. */
+  private boolean useTrainingFraction;
+
+  /** The label. */
+  @AlgoParam(name = "label", help = "Name of column which is used as a training label for supervised learning")
+  protected String label;
+
+  /** The partition filter. */
+  @AlgoParam(name = "partition", help = "Partition filter used to create HCatInputFormats")
+  protected String partitionFilter;
+
+  /** The features. */
+  @AlgoParam(name = "feature", help = "Column name(s) which are to be used as sample features")
+  protected List<String> features;
+
+  /**
+   * Instantiates a new base spark algo.
+   *
+   * @param name        the name
+   * @param description the description
+   */
+  public BaseSparkAlgo(String name, String description) {
+    this.name = name;
+    this.description = description;
+  }
+
+  public void setSparkContext(JavaSparkContext sparkContext) {
+    this.sparkContext = sparkContext;
+  }
+
+  @Override
+  public LensConf getConf() {
+    return conf;
+  }
+
+  /*
+   * (non-Javadoc)
+   *
+   * @see org.apache.lens.ml.MLAlgo#configure(org.apache.lens.api.LensConf)
+   */
+  @Override
+  public void configure(LensConf configuration) {
+    this.conf = configuration;
+  }
+
+  /**
+   * Parses {@code params}, builds the training RDD from {@code db.table} and trains a model with
+   * {@link #trainInternal(String, RDD)}.
+   *
+   * @param conf    configuration copied into the {@link HiveConf} used to read the table
+   * @param db      database of the training table
+   * @param table   training table
+   * @param modelId id given to the trained model
+   * @param params  alternating key/value arguments, see {@link #parseParams(String[])}
+   * @return the trained model, tagged with its table, params, label and feature columns
+   * @throws LensException if training fails
+   * @throws IllegalArgumentException if the arguments are malformed or no feature column is given
+   */
+  @Override
+  public MLModel<?> train(LensConf conf, String db, String table, String modelId, String... params)
+    throws LensException {
+    parseParams(params);
+    if (features.isEmpty()) {
+      // Fail fast with a clear message instead of an NPE / empty feature vectors further down.
+      throw new IllegalArgumentException("At least one feature column must be specified for " + name);
+    }
+
+    TableTrainingSpec.TableTrainingSpecBuilder builder = TableTrainingSpec.newBuilder().hiveConf(toHiveConf(conf))
+      .database(db).table(table).partitionFilter(partitionFilter).featureColumns(features).labelColumn(label);
+
+    if (useTrainingFraction) {
+      builder.trainingFraction(trainingFraction);
+    }
+
+    TableTrainingSpec spec = builder.build();
+    LOG.info("Training " + name + " with " + features.size() + " features");
+
+    spec.createRDDs(sparkContext);
+
+    RDD<LabeledPoint> trainingRDD = spec.getTrainingRDD();
+    BaseSparkClassificationModel<?> model = trainInternal(modelId, trainingRDD);
+    model.setTable(table);
+    model.setParams(Arrays.asList(params));
+    model.setLabelColumn(label);
+    model.setFeatureColumns(features);
+    return model;
+  }
+
+  /**
+   * Copies every property of a {@link LensConf} into a new {@link HiveConf}.
+   *
+   * @param conf the lens configuration
+   * @return a new HiveConf holding the properties of {@code conf}
+   */
+  protected HiveConf toHiveConf(LensConf conf) {
+    HiveConf hiveConf = new HiveConf();
+    for (String key : conf.getProperties().keySet()) {
+      hiveConf.set(key, conf.getProperties().get(key));
+    }
+    return hiveConf;
+  }
+
+  /**
+   * Parses alternating key/value arguments.
+   *
+   * <p>Keys are compared after removing all dashes, so {@code -f}, {@code --feature} and {@code feature} are
+   * equivalent. {@code f}/{@code feature} (repeatable) add a feature column and {@code l}/{@code label} sets the
+   * label column; every other pair is stored in {@link #params}. {@code trainingFraction} and
+   * {@code partition}/{@code p} are then applied before {@link #parseAlgoParams(Map)} is called.
+   *
+   * @param args alternating keys and values
+   * @throws IllegalArgumentException if the number of arguments is odd or trainingFraction is not a number
+   */
+  public void parseParams(String[] args) {
+    if (args.length % 2 != 0) {
+      throw new IllegalArgumentException("Invalid number of params " + args.length);
+    }
+
+    params = new LinkedHashMap<String, String>();
+    // Reset per-call state: features used to accumulate across train() calls on one instance (also mutating
+    // the list already handed to an earlier model) and a stale trainingFraction flag leaked through.
+    features = new ArrayList<String>();
+    useTrainingFraction = false;
+
+    for (int i = 0; i < args.length; i += 2) {
+      String key = args[i].replace("-", "");
+      String value = args[i + 1];
+      if ("f".equalsIgnoreCase(key) || "feature".equalsIgnoreCase(key)) {
+        features.add(value);
+      } else if ("l".equalsIgnoreCase(key) || "label".equalsIgnoreCase(key)) {
+        label = value;
+      } else {
+        params.put(key, value);
+      }
+    }
+
+    if (params.containsKey("trainingFraction")) {
+      String trainingFractionStr = params.get("trainingFraction");
+      try {
+        trainingFraction = Double.parseDouble(trainingFractionStr);
+        useTrainingFraction = true;
+      } catch (NumberFormatException nfe) {
+        throw new IllegalArgumentException("Invalid training fraction", nfe);
+      }
+    }
+
+    if (params.containsKey("partition") || params.containsKey("p")) {
+      partitionFilter = params.containsKey("partition") ? params.get("partition") : params.get("p");
+    }
+
+    parseAlgoParams(params);
+  }
+
+  /**
+   * Returns a parsed double argument, or {@code defaultVal} if it is absent or not a valid double.
+   *
+   * @param param      the argument name (without dashes)
+   * @param defaultVal value returned when the argument is missing or unparsable
+   * @return the param value
+   */
+  public double getParamValue(String param, double defaultVal) {
+    String raw = params == null ? null : params.get(param);
+    if (raw != null) {
+      try {
+        return Double.parseDouble(raw);
+      } catch (NumberFormatException nfe) {
+        LOG.warn("Couldn't parse param value: " + param + " as double.");
+      }
+    }
+    return defaultVal;
+  }
+
+  /**
+   * Returns a parsed int argument, or {@code defaultVal} if it is absent or not a valid integer.
+   *
+   * @param param      the argument name (without dashes)
+   * @param defaultVal value returned when the argument is missing or unparsable
+   * @return the param value
+   */
+  public int getParamValue(String param, int defaultVal) {
+    String raw = params == null ? null : params.get(param);
+    if (raw != null) {
+      try {
+        return Integer.parseInt(raw);
+      } catch (NumberFormatException nfe) {
+        LOG.warn("Couldn't parse param value: " + param + " as integer.");
+      }
+    }
+    return defaultVal;
+  }
+
+  @Override
+  public String getName() {
+    return name;
+  }
+
+  @Override
+  public String getDescription() {
+    return description;
+  }
+
+  /**
+   * Describes the arguments this algorithm accepts.
+   *
+   * @return algorithm name/description plus one {@code [param] name} entry per {@link AlgoParam} field declared on
+   *         this class or any superclass up to BaseSparkAlgo
+   */
+  public Map<String, String> getArgUsage() {
+    Map<String, String> usage = new LinkedHashMap<String, String>();
+    Class<?> clz = this.getClass();
+    // Put class name and description as well as part of the usage
+    Algorithm algorithm = clz.getAnnotation(Algorithm.class);
+    if (algorithm != null) {
+      usage.put("Algorithm Name", algorithm.name());
+      usage.put("Algorithm Description", algorithm.description());
+    }
+
+    // Get all algo params including base algo params
+    while (clz != null) {
+      for (Field field : clz.getDeclaredFields()) {
+        AlgoParam param = field.getAnnotation(AlgoParam.class);
+        if (param != null) {
+          usage.put("[param] " + param.name(), param.help() + " Default Value = " + param.defaultValue());
+        }
+      }
+
+      if (clz.equals(BaseSparkAlgo.class)) {
+        break;
+      }
+      clz = clz.getSuperclass();
+    }
+    return usage;
+  }
+
+  /**
+   * Parses the algorithm-specific arguments. Called at the end of {@link #parseParams(String[])}.
+   *
+   * @param params arguments other than feature/label, keyed without dashes
+   */
+  public abstract void parseAlgoParams(Map<String, String> params);
+
+  /**
+   * Fits the algorithm on the prepared training data.
+   *
+   * @param modelId     the model id
+   * @param trainingRDD the training rdd
+   * @return the base spark classification model
+   * @throws LensException the lens exception
+   */
+  protected abstract BaseSparkClassificationModel trainInternal(String modelId, RDD<LabeledPoint> trainingRDD)
+    throws LensException;
+}

http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/49fef8e2/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/DecisionTreeAlgo.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/DecisionTreeAlgo.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/DecisionTreeAlgo.java
new file mode 100644
index 0000000..a6d66c5
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/DecisionTreeAlgo.java
@@ -0,0 +1,129 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.spark.algos;
+
+import java.util.Map;
+
+import org.apache.lens.api.LensException;
+import org.apache.lens.ml.AlgoParam;
+import org.apache.lens.ml.Algorithm;
+import org.apache.lens.ml.spark.models.BaseSparkClassificationModel;
+import org.apache.lens.ml.spark.models.DecisionTreeClassificationModel;
+import org.apache.lens.ml.spark.models.SparkDecisionTreeModel;
+
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.mllib.tree.DecisionTree$;
+import org.apache.spark.mllib.tree.configuration.Algo$;
+import org.apache.spark.mllib.tree.impurity.Entropy$;
+import org.apache.spark.mllib.tree.impurity.Gini$;
+import org.apache.spark.mllib.tree.impurity.Impurity;
+import org.apache.spark.mllib.tree.impurity.Variance$;
+import org.apache.spark.mllib.tree.model.DecisionTreeModel;
+import org.apache.spark.rdd.RDD;
+
+import scala.Enumeration;
+
+/**
+ * Spark MLlib decision tree algorithm.
+ *
+ * <p>Requires the {@code algo} ('classification' or 'regression') and {@code impurity} ('gini', 'entropy' or
+ * 'variance') arguments; {@code maxDepth} defaults to 100.
+ */
+@Algorithm(name = "spark_decision_tree", description = "Spark Decision Tree classifier algo")
+public class DecisionTreeAlgo extends BaseSparkAlgo {
+
+  /** The algo. */
+  @AlgoParam(name = "algo", help = "Decision tree algorithm. Allowed values are 'classification' and 'regression'")
+  private Enumeration.Value algo;
+
+  /** The decision tree impurity. */
+  @AlgoParam(name = "impurity", help = "Impurity measure used by the decision tree. "
+    + "Allowed values are 'gini', 'entropy' and 'variance'")
+  private Impurity decisionTreeImpurity;
+
+  /** The max depth. */
+  @AlgoParam(name = "maxDepth", help = "Max depth of the decision tree. Integer values expected.",
+    defaultValue = "100")
+  private int maxDepth;
+
+  /**
+   * Instantiates a new decision tree algo.
+   *
+   * @param name        the name
+   * @param description the description
+   */
+  public DecisionTreeAlgo(String name, String description) {
+    super(name, description);
+  }
+
+  /**
+   * Reads {@code algo}, {@code impurity} and {@code maxDepth}.
+   *
+   * @param params arguments other than feature/label, keyed without dashes
+   * @throws IllegalArgumentException if {@code algo} or {@code impurity} is missing or not one of the allowed values
+   */
+  @Override
+  public void parseAlgoParams(Map<String, String> params) {
+    algo = parseAlgo(params.get("algo"));
+    decisionTreeImpurity = parseImpurity(params.get("impurity"));
+    maxDepth = getParamValue("maxDepth", 100);
+  }
+
+  /**
+   * Maps an algo argument to the Spark {@code Algo} enumeration value (case-insensitive).
+   */
+  private static Enumeration.Value parseAlgo(String dtreeAlgoName) {
+    if ("classification".equalsIgnoreCase(dtreeAlgoName)) {
+      return Algo$.MODULE$.Classification();
+    } else if ("regression".equalsIgnoreCase(dtreeAlgoName)) {
+      return Algo$.MODULE$.Regression();
+    }
+    // Previously an unknown/missing value left the field null (or stale from an earlier call) and training failed
+    // later with an NPE inside Spark.
+    throw new IllegalArgumentException("Invalid decision tree algo '" + dtreeAlgoName
+      + "'. Allowed values are 'classification' and 'regression'");
+  }
+
+  /**
+   * Maps an impurity argument to the Spark impurity instance (case-insensitive).
+   */
+  private static Impurity parseImpurity(String impurity) {
+    if ("gini".equalsIgnoreCase(impurity)) {
+      return Gini$.MODULE$;
+    } else if ("entropy".equalsIgnoreCase(impurity)) {
+      return Entropy$.MODULE$;
+    } else if ("variance".equalsIgnoreCase(impurity)) {
+      return Variance$.MODULE$;
+    }
+    throw new IllegalArgumentException("Invalid impurity '" + impurity
+      + "'. Allowed values are 'gini', 'entropy' and 'variance'");
+  }
+
+  /*
+   * (non-Javadoc)
+   *
+   * @see org.apache.lens.ml.spark.algos.BaseSparkAlgo#trainInternal(java.lang.String, org.apache.spark.rdd.RDD)
+   */
+  @Override
+  protected BaseSparkClassificationModel trainInternal(String modelId, RDD<LabeledPoint> trainingRDD)
+    throws LensException {
+    DecisionTreeModel model = DecisionTree$.MODULE$.train(trainingRDD, algo, decisionTreeImpurity, maxDepth);
+    return new DecisionTreeClassificationModel(modelId, new SparkDecisionTreeModel(model));
+  }
+}

http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/49fef8e2/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/KMeansAlgo.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/KMeansAlgo.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/KMeansAlgo.java
new file mode 100644
index 0000000..7ca5a79
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/KMeansAlgo.java
@@ -0,0 +1,225 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.spark.algos;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Locale;
+import java.util.Map;
+
+import org.apache.lens.api.LensConf;
+import org.apache.lens.api.LensException;
+import org.apache.lens.ml.*;
+import org.apache.lens.ml.spark.HiveTableRDD;
+import org.apache.lens.ml.spark.models.KMeansClusteringModel;
+
+import org.apache.hadoop.hive.conf.HiveConf;
+import org.apache.hadoop.hive.metastore.api.FieldSchema;
+import org.apache.hadoop.hive.ql.metadata.Hive;
+import org.apache.hadoop.hive.ql.metadata.Table;
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.hive.hcatalog.data.HCatRecord;
+import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.mllib.clustering.KMeans;
+import org.apache.spark.mllib.clustering.KMeansModel;
+import org.apache.spark.mllib.linalg.Vector;
+import org.apache.spark.mllib.linalg.Vectors;
+
+import scala.Tuple2;
+
+/**
+ * Spark MLlib KMeans clustering over numeric feature columns of a Hive table.
+ */
+@Algorithm(name = "spark_kmeans_algo", description = "Spark MLLib KMeans algo")
+public class KMeansAlgo implements MLAlgo {
+
+  /** The conf. */
+  private transient LensConf conf;
+
+  /** The spark context used to read the input table. */
+  private JavaSparkContext sparkContext;
+
+  /** The part filter. */
+  @AlgoParam(name = "partition", help = "Partition filter to be used while constructing table RDD")
+  private String partFilter = null;
+
+  /** The k. */
+  @AlgoParam(name = "k", help = "Number of clusters")
+  private int k;
+
+  /** The max iterations. */
+  @AlgoParam(name = "maxIterations", help = "Maximum number of iterations", defaultValue = "100")
+  private int maxIterations = 100;
+
+  /** The runs. */
+  @AlgoParam(name = "runs", help = "Number of parallel runs", defaultValue = "1")
+  private int runs = 1;
+
+  /** The initialization mode. */
+  @AlgoParam(name = "initializationMode",
+    help = "Initialization mode, either \"random\" or \"k-means||\" (default).", defaultValue = "k-means||")
+  private String initializationMode = "k-means||";
+
+  @Override
+  public String getName() {
+    return getClass().getAnnotation(Algorithm.class).name();
+  }
+
+  @Override
+  public String getDescription() {
+    return getClass().getAnnotation(Algorithm.class).description();
+  }
+
+  /**
+   * Sets the spark context used to read the input table; {@link #train} fails if it was never set.
+   *
+   * @param sparkContext the spark context
+   */
+  public void setSparkContext(JavaSparkContext sparkContext) {
+    this.sparkContext = sparkContext;
+  }
+
+  /*
+   * (non-Javadoc)
+   *
+   * @see org.apache.lens.ml.MLAlgo#configure(org.apache.lens.api.LensConf)
+   */
+  @Override
+  public void configure(LensConf configuration) {
+    this.conf = configuration;
+  }
+
+  @Override
+  public LensConf getConf() {
+    return conf;
+  }
+
+  /**
+   * Clusters the feature columns of {@code db.table} with Spark MLlib KMeans.
+   *
+   * @param conf    configuration copied into the {@link HiveConf} used to read the table
+   * @param db      database of the input table
+   * @param table   input table
+   * @param modelId id given to the trained model
+   * @param params  alternating key/value arguments, parsed by {@link AlgoArgParser#parseArgs}
+   * @return the trained clustering model
+   * @throws LensException if the spark context, k or feature columns are invalid, or training fails
+   */
+  @Override
+  public MLModel train(LensConf conf, String db, String table, String modelId, String... params) throws LensException {
+    List<String> features = AlgoArgParser.parseArgs(this, params);
+    try {
+      if (sparkContext == null) {
+        throw new IllegalStateException("Spark context is not set, call setSparkContext() before training");
+      }
+      if (k <= 0) {
+        throw new IllegalArgumentException("Number of clusters (k) must be positive, got " + k);
+      }
+      if (features.isEmpty()) {
+        throw new IllegalArgumentException("At least one feature column must be specified");
+      }
+
+      HiveConf hiveConf = toHiveConf(conf);
+      Table tbl = Hive.get(hiveConf).getTable(db, table);
+      int[] featurePositions = resolveFeaturePositions(tbl.getAllCols(), features);
+
+      JavaPairRDD<WritableComparable, HCatRecord> rdd =
+        HiveTableRDD.createHiveTableRDD(sparkContext, hiveConf, db, table, partFilter);
+      // A static mapper (instead of an anonymous inner class) keeps the enclosing KMeansAlgo, which holds the
+      // JavaSparkContext, out of the serialized Spark closure.
+      JavaRDD<Vector> trainableRDD = rdd.map(new FeatureVectorMapper(featurePositions));
+
+      KMeansModel model = KMeans.train(trainableRDD.rdd(), k, maxIterations, runs, initializationMode);
+      return new KMeansClusteringModel(modelId, model);
+    } catch (Exception e) {
+      throw new LensException("KMeans algo failed for " + db + "." + table, e);
+    }
+  }
+
+  /**
+   * Resolves each requested feature to its column position, keeping the order in which features were given.
+   *
+   * @param allCols  table columns; positions are assumed to match HCatRecord field positions
+   * @param features requested feature column names, matched case-insensitively
+   * @return the column position of each feature
+   * @throws IllegalArgumentException if a feature is not a column of the table
+   */
+  private static int[] resolveFeaturePositions(List<FieldSchema> allCols, List<String> features) {
+    Map<String, Integer> columnIndex = new HashMap<String, Integer>();
+    for (int i = 0; i < allCols.size(); i++) {
+      columnIndex.put(allCols.get(i).getName().toLowerCase(Locale.ROOT), i);
+    }
+    // Look up by feature rather than scanning columns: the old scan emitted positions in table order (mismatching
+    // the feature order) and left unknown features pointing at column 0.
+    int[] positions = new int[features.size()];
+    for (int f = 0; f < positions.length; f++) {
+      Integer position = columnIndex.get(features.get(f).toLowerCase(Locale.ROOT));
+      if (position == null) {
+        throw new IllegalArgumentException("Feature column " + features.get(f) + " not found in table");
+      }
+      positions[f] = position;
+    }
+    return positions;
+  }
+
+  /**
+   * Converts an HCatalog record into a dense vector of the selected feature columns.
+   */
+  private static final class FeatureVectorMapper
+    implements Function<Tuple2<WritableComparable, HCatRecord>, Vector> {
+
+    private static final long serialVersionUID = 1L;
+
+    /** Record positions of the feature columns, in feature order. */
+    private final int[] featurePositions;
+
+    FeatureVectorMapper(int[] featurePositions) {
+      this.featurePositions = featurePositions;
+    }
+
+    @Override
+    public Vector call(Tuple2<WritableComparable, HCatRecord> v1) throws Exception {
+      HCatRecord hCatRecord = v1._2();
+      double[] arr = new double[featurePositions.length];
+      for (int i = 0; i < featurePositions.length; i++) {
+        Object val = hCatRecord.get(featurePositions[i]);
+        // Accept any numeric column type; a plain (Double) cast fails on int, bigint and float columns.
+        arr[i] = val == null ? 0d : ((Number) val).doubleValue();
+      }
+      return Vectors.dense(arr);
+    }
+  }
+
+  /**
+   * Copies every property of a {@link LensConf} into a new {@link HiveConf}.
+   *
+   * @param conf the lens configuration
+   * @return a new HiveConf holding the properties of {@code conf}
+   */
+  private HiveConf toHiveConf(LensConf conf) {
+    HiveConf hiveConf = new HiveConf();
+    for (String key : conf.getProperties().keySet()) {
+      hiveConf.set(key, conf.getProperties().get(key));
+    }
+    return hiveConf;
+  }
+}

http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/49fef8e2/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/LogisticRegressionAlgo.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/LogisticRegressionAlgo.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/LogisticRegressionAlgo.java
new file mode 100644
index 0000000..106b3c5
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/LogisticRegressionAlgo.java
@@ -0,0 +1,86 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.spark.algos;
+
+import java.util.Map;
+
+import org.apache.lens.api.LensException;
+import org.apache.lens.ml.AlgoParam;
+import org.apache.lens.ml.Algorithm;
+import org.apache.lens.ml.spark.models.BaseSparkClassificationModel;
+import org.apache.lens.ml.spark.models.LogitRegressionClassificationModel;
+
+import org.apache.spark.mllib.classification.LogisticRegressionModel;
+import org.apache.spark.mllib.classification.LogisticRegressionWithSGD;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.rdd.RDD;
+
+/**
+ * Logistic regression classifier trained with Spark MLlib's stochastic gradient descent implementation
+ * ({@link LogisticRegressionWithSGD}).
+ */
+@Algorithm(name = "spark_logistic_regression", description = "Spark logistic regression algo")
+public class LogisticRegressionAlgo extends BaseSparkAlgo {
+
+  /** Number of SGD iterations handed to Spark. */
+  @AlgoParam(name = "iterations", help = "Max number of iterations", defaultValue = "100")
+  private int iterations;
+
+  /** SGD step size handed to Spark. */
+  @AlgoParam(name = "stepSize", help = "Step size", defaultValue = "1.0d")
+  private double stepSize;
+
+  /** Value handed to Spark as the mini-batch fraction of each SGD iteration. */
+  @AlgoParam(name = "minBatchFraction", help = "Fraction for batched learning", defaultValue = "1.0d")
+  private double minBatchFraction;
+
+  /**
+   * Creates the algorithm.
+   *
+   * @param name        algorithm name
+   * @param description algorithm description
+   */
+  public LogisticRegressionAlgo(String name, String description) {
+    super(name, description);
+  }
+
+  /**
+   * Reads the SGD settings through {@code getParamValue}, supplying 100, 1.0 and 1.0 as the fallback values for
+   * iterations, step size and mini-batch fraction respectively.
+   *
+   * @param params algorithm parameters (not read directly; values are looked up via {@code getParamValue})
+   */
+  @Override
+  public void parseAlgoParams(Map<String, String> params) {
+    this.iterations = getParamValue("iterations", 100);
+    this.stepSize = getParamValue("stepSize", 1.0d);
+    this.minBatchFraction = getParamValue("minBatchFraction", 1.0d);
+  }
+
+  /**
+   * Fits a logistic regression model on the labelled points and wraps it as a Lens classification model.
+   *
+   * @param modelId     id given to the resulting model
+   * @param trainingRDD labelled training points
+   * @return the trained model wrapped in a {@link LogitRegressionClassificationModel}
+   * @throws LensException declared by the base class contract
+   */
+  @Override
+  protected BaseSparkClassificationModel trainInternal(String modelId, RDD<LabeledPoint> trainingRDD)
+    throws LensException {
+    final LogisticRegressionModel trained =
+      LogisticRegressionWithSGD.train(trainingRDD, iterations, stepSize, minBatchFraction);
+    return new LogitRegressionClassificationModel(modelId, trained);
+  }
+}

http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/49fef8e2/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/NaiveBayesAlgo.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/NaiveBayesAlgo.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/NaiveBayesAlgo.java
new file mode 100644
index 0000000..f7652d1
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/NaiveBayesAlgo.java
@@ -0,0 +1,73 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.spark.algos;
+
+import java.util.Map;
+
+import org.apache.lens.api.LensException;
+import org.apache.lens.ml.AlgoParam;
+import org.apache.lens.ml.Algorithm;
+import org.apache.lens.ml.spark.models.BaseSparkClassificationModel;
+import org.apache.lens.ml.spark.models.NaiveBayesClassificationModel;
+
+import org.apache.spark.mllib.classification.NaiveBayes;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.rdd.RDD;
+
+/**
+ * Naive Bayes classifier trained with Spark MLlib's {@link NaiveBayes}.
+ */
+@Algorithm(name = "spark_naive_bayes", description = "Spark Naive Bayes classifier algo")
+public class NaiveBayesAlgo extends BaseSparkAlgo {
+
+  /** Lambda handed to {@code NaiveBayes.train}; holds 1.0 until the parameters are parsed. */
+  @AlgoParam(name = "lambda", help = "Lambda parameter for naive bayes learner", defaultValue = "1.0d")
+  private double lambda = 1.0d;
+
+  /**
+   * Creates the algorithm.
+   *
+   * @param name        algorithm name
+   * @param description algorithm description
+   */
+  public NaiveBayesAlgo(String name, String description) {
+    super(name, description);
+  }
+
+  /**
+   * Reads {@code lambda} through {@code getParamValue}, supplying 1.0 as the fallback value.
+   *
+   * @param params algorithm parameters (not read directly; values are looked up via {@code getParamValue})
+   */
+  @Override
+  public void parseAlgoParams(Map<String, String> params) {
+    this.lambda = getParamValue("lambda", 1.0d);
+  }
+
+  /**
+   * Fits a naive Bayes model on the labelled points and wraps it as a Lens classification model.
+   *
+   * @param modelId     id given to the resulting model
+   * @param trainingRDD labelled training points
+   * @return the trained model wrapped in a {@link NaiveBayesClassificationModel}
+   * @throws LensException declared by the base class contract
+   */
+  @Override
+  protected BaseSparkClassificationModel trainInternal(String modelId, RDD<LabeledPoint> trainingRDD)
+    throws LensException {
+    final double smoothing = lambda;
+    return new NaiveBayesClassificationModel(modelId,
+      NaiveBayes.train(trainingRDD, smoothing));
+  }
+}

http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/49fef8e2/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/SVMAlgo.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/SVMAlgo.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/SVMAlgo.java
new file mode 100644
index 0000000..09251b7
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/SVMAlgo.java
@@ -0,0 +1,90 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.spark.algos;
+
+import java.util.Map;
+
+import org.apache.lens.api.LensException;
+import org.apache.lens.ml.AlgoParam;
+import org.apache.lens.ml.Algorithm;
+import org.apache.lens.ml.spark.models.BaseSparkClassificationModel;
+import org.apache.lens.ml.spark.models.SVMClassificationModel;
+
+import org.apache.spark.mllib.classification.SVMModel;
+import org.apache.spark.mllib.classification.SVMWithSGD;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.rdd.RDD;
+
+/**
+ * Linear SVM classifier trained with Spark MLlib's stochastic gradient descent implementation ({@link SVMWithSGD}).
+ */
+@Algorithm(name = "spark_svm", description = "Spark SVM classifier algo")
+public class SVMAlgo extends BaseSparkAlgo {
+
+  /** Value handed to Spark as the mini-batch fraction of each SGD iteration. */
+  @AlgoParam(name = "minBatchFraction", help = "Fraction for batched learning", defaultValue = "1.0d")
+  private double minBatchFraction;
+
+  /** Regularization parameter handed to Spark. */
+  @AlgoParam(name = "regParam", help = "regularization parameter for gradient descent", defaultValue = "1.0d")
+  private double regParam;
+
+  /** SGD step size handed to Spark. */
+  @AlgoParam(name = "stepSize", help = "Iteration step size", defaultValue = "1.0d")
+  private double stepSize;
+
+  /** Number of SGD iterations handed to Spark. */
+  @AlgoParam(name = "iterations", help = "Number of iterations", defaultValue = "100")
+  private int iterations;
+
+  /**
+   * Instantiates a new SVM algo.
+   *
+   * @param name        the algorithm name
+   * @param description the algorithm description
+   */
+  public SVMAlgo(String name, String description) {
+    super(name, description);
+  }
+
+  /**
+   * Reads the SGD settings through {@code getParamValue}, supplying 1.0 as the fallback for the mini-batch fraction,
+   * regularization parameter and step size, and 100 for the number of iterations.
+   *
+   * @param params algorithm parameters (not read directly; values are looked up via {@code getParamValue})
+   */
+  @Override
+  public void parseAlgoParams(Map<String, String> params) {
+    minBatchFraction = getParamValue("minBatchFraction", 1.0d);
+    regParam = getParamValue("regParam", 1.0d);
+    stepSize = getParamValue("stepSize", 1.0d);
+    iterations = getParamValue("iterations", 100);
+  }
+
+  /**
+   * Fits an SVM model on the labelled points and wraps it as a Lens classification model.
+   *
+   * @param modelId     id given to the resulting model
+   * @param trainingRDD labelled training points
+   * @return the trained model wrapped in a {@link SVMClassificationModel}
+   * @throws LensException declared by the base class contract
+   */
+  @Override
+  protected BaseSparkClassificationModel trainInternal(String modelId, RDD<LabeledPoint> trainingRDD)
+    throws LensException {
+    // Argument order follows SVMWithSGD.train(input, numIterations, stepSize, regParam, miniBatchFraction).
+    SVMModel svmModel = SVMWithSGD.train(trainingRDD, iterations, stepSize, regParam, minBatchFraction);
+    return new SVMClassificationModel(modelId, svmModel);
+  }
+}

http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/49fef8e2/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/trainers/BaseSparkTrainer.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/trainers/BaseSparkTrainer.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/trainers/BaseSparkTrainer.java
deleted file mode 100644
index f75e41c..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/trainers/BaseSparkTrainer.java
+++ /dev/null
@@ -1,289 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml.spark.trainers;
-
-import java.lang.reflect.Field;
-import java.util.*;
-
-import org.apache.lens.api.LensConf;
-import org.apache.lens.api.LensException;
-import org.apache.lens.ml.Algorithm;
-import org.apache.lens.ml.MLModel;
-import org.apache.lens.ml.MLTrainer;
-import org.apache.lens.ml.TrainerParam;
-import org.apache.lens.ml.spark.TableTrainingSpec;
-import org.apache.lens.ml.spark.models.BaseSparkClassificationModel;
-
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
-import org.apache.hadoop.hive.conf.HiveConf;
-import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.rdd.RDD;
-
-/**
- * Base class for trainers that fit Spark MLlib models on data read from a Hive table.
- *
- * <p>{@link #train} parses the shared parameters (feature and label columns, partition filter and training
- * fraction), builds the training RDD through {@link TableTrainingSpec}, delegates model fitting to
- * {@link #trainInternal(String, RDD)} and records table, parameters, label and feature columns on the model.
- * Algorithm specific parameters are handled by subclasses in {@link #parseTrainerParams(Map)}.
- *
- * <p>Parsed parameters are stored in instance fields, so one instance should not run several trainings at once.
- */
-public abstract class BaseSparkTrainer implements MLTrainer {
-
-  /** The Constant LOG. */
-  public static final Log LOG = LogFactory.getLog(BaseSparkTrainer.class);
-
-  /** Algorithm name passed to the constructor. */
-  private final String name;
-
-  /** Algorithm description passed to the constructor. */
-  private final String description;
-
-  /** Spark context used to create the training RDDs. */
-  protected JavaSparkContext sparkContext;
-
-  /**
-   * Parameters other than feature/label columns from the last {@link #parseParams(String[])} call, keyed by name
-   * with all '-' characters removed.
-   */
-  protected Map<String, String> params;
-
-  /** Configuration supplied through {@link #configure(LensConf)}. */
-  protected transient LensConf conf;
-
-  /** Fraction of the dataset used for training; only applied when a trainingFraction parameter was given. */
-  @TrainerParam(name = "trainingFraction", help = "% of dataset to be used for training", defaultValue = "0")
-  protected double trainingFraction;
-
-  /** Whether a {@code trainingFraction} parameter was supplied in the last parsed parameters. */
-  private boolean useTrainingFraction;
-
-  /** Name of the label column used for supervised learning. */
-  @TrainerParam(name = "label", help = "Name of column which is used as a training label for supervised learning")
-  protected String label;
-
-  /** Partition filter applied when reading the training table. */
-  @TrainerParam(name = "partition", help = "Partition filter used to create HCatInputFormats")
-  protected String partitionFilter;
-
-  /** Feature column names in the order they were supplied; {@code null} when none was given. */
-  @TrainerParam(name = "feature", help = "Column name(s) which are to be used as sample features")
-  protected List<String> features;
-
-  /**
-   * Instantiates a new base spark trainer.
-   *
-   * @param name        the algorithm name
-   * @param description the algorithm description
-   */
-  public BaseSparkTrainer(String name, String description) {
-    this.name = name;
-    this.description = description;
-  }
-
-  /**
-   * Sets the Spark context used to create the training RDDs.
-   *
-   * @param sparkContext the Spark context
-   */
-  public void setSparkContext(JavaSparkContext sparkContext) {
-    this.sparkContext = sparkContext;
-  }
-
-  @Override
-  public LensConf getConf() {
-    return conf;
-  }
-
-  @Override
-  public void configure(LensConf configuration) {
-    this.conf = configuration;
-  }
-
-  /**
-   * Trains a model on the given Hive table.
-   *
-   * @param conf    configuration converted into the Hive configuration used to read the table
-   * @param db      database of the training table
-   * @param table   training table
-   * @param modelId id assigned to the trained model
-   * @param params  alternating parameter names and values
-   * @return the trained model
-   * @throws LensException if training fails
-   * @throws IllegalArgumentException if the parameters are malformed or no feature column is specified
-   */
-  @Override
-  public MLModel<?> train(LensConf conf, String db, String table, String modelId, String... params)
-    throws LensException {
-    parseParams(params);
-
-    // Fail fast with a clear message instead of a NullPointerException further down.
-    if (features == null || features.isEmpty()) {
-      throw new IllegalArgumentException("At least one feature column must be specified");
-    }
-
-    TableTrainingSpec.TableTrainingSpecBuilder builder = TableTrainingSpec.newBuilder().hiveConf(toHiveConf(conf))
-      .database(db).table(table).partitionFilter(partitionFilter).featureColumns(features).labelColumn(label);
-
-    if (useTrainingFraction) {
-      builder.trainingFraction(trainingFraction);
-    }
-
-    TableTrainingSpec spec = builder.build();
-    LOG.info("Training " + name + " with " + features.size() + " features");
-
-    spec.createRDDs(sparkContext);
-
-    RDD<LabeledPoint> trainingRDD = spec.getTrainingRDD();
-    BaseSparkClassificationModel<?> model = trainInternal(modelId, trainingRDD);
-    model.setTable(table);
-    model.setParams(Arrays.asList(params));
-    model.setLabelColumn(label);
-    model.setFeatureColumns(features);
-    return model;
-  }
-
-  /**
-   * Copies every property of the Lens configuration into a new Hive configuration.
-   *
-   * @param conf the conf
-   * @return the hive conf
-   */
-  protected HiveConf toHiveConf(LensConf conf) {
-    HiveConf hiveConf = new HiveConf();
-    for (String key : conf.getProperties().keySet()) {
-      hiveConf.set(key, conf.getProperties().get(key));
-    }
-    return hiveConf;
-  }
-
-  /**
-   * Parses alternating name/value arguments into the trainer's fields.
-   *
-   * <p>Names have all '-' characters removed. {@code f}/{@code feature} adds a feature column, {@code l}/{@code label}
-   * sets the label column (both case-insensitive), and everything else goes into {@link #params}.
-   * State from a previous call on this instance is discarded.
-   *
-   * @param args alternating parameter names and values
-   * @throws IllegalArgumentException if the argument count is odd or the training fraction is not a number
-   */
-  public void parseParams(String[] args) {
-    if (args.length % 2 != 0) {
-      throw new IllegalArgumentException("Invalid number of params " + args.length);
-    }
-
-    // Reset so values from an earlier call (e.g. feature columns) do not leak into this one.
-    params = new LinkedHashMap<String, String>();
-    features = null;
-    label = null;
-    partitionFilter = null;
-    trainingFraction = 0;
-    useTrainingFraction = false;
-
-    for (int i = 0; i < args.length; i += 2) {
-      // Strip dashes before matching so "-f"/"--feature" and "-l"/"--label" work like their bare forms.
-      String key = args[i].replaceAll("\\-+", "");
-      String value = args[i + 1];
-      if ("f".equalsIgnoreCase(key) || "feature".equalsIgnoreCase(key)) {
-        if (features == null) {
-          features = new ArrayList<String>();
-        }
-        features.add(value);
-      } else if ("l".equalsIgnoreCase(key) || "label".equalsIgnoreCase(key)) {
-        label = value;
-      } else {
-        params.put(key, value);
-      }
-    }
-
-    if (params.containsKey("trainingFraction")) {
-      // Get training Fraction
-      String trainingFractionStr = params.get("trainingFraction");
-      try {
-        trainingFraction = Double.parseDouble(trainingFractionStr);
-        useTrainingFraction = true;
-      } catch (NumberFormatException nfe) {
-        throw new IllegalArgumentException("Invalid training fraction", nfe);
-      }
-    }
-
-    if (params.containsKey("partition") || params.containsKey("p")) {
-      partitionFilter = params.containsKey("partition") ? params.get("partition") : params.get("p");
-    }
-
-    parseTrainerParams(params);
-  }
-
-  /**
-   * Returns a parsed parameter as a double, or {@code defaultVal} when it is absent or not a valid double.
-   *
-   * @param param      the param
-   * @param defaultVal the default val
-   * @return the param value
-   */
-  public double getParamValue(String param, double defaultVal) {
-    if (params.containsKey(param)) {
-      try {
-        return Double.parseDouble(params.get(param));
-      } catch (NumberFormatException nfe) {
-        LOG.warn("Couldn't parse param value: " + param + " as double.");
-      }
-    }
-    return defaultVal;
-  }
-
-  /**
-   * Returns a parsed parameter as an int, or {@code defaultVal} when it is absent or not a valid integer.
-   *
-   * @param param      the param
-   * @param defaultVal the default val
-   * @return the param value
-   */
-  public int getParamValue(String param, int defaultVal) {
-    if (params.containsKey(param)) {
-      try {
-        return Integer.parseInt(params.get(param));
-      } catch (NumberFormatException nfe) {
-        LOG.warn("Couldn't parse param value: " + param + " as integer.");
-      }
-    }
-    return defaultVal;
-  }
-
-  public String getName() {
-    return name;
-  }
-
-  public String getDescription() {
-    return description;
-  }
-
-  /**
-   * Describes the algorithm and its {@link TrainerParam}-annotated fields.
-   *
-   * @return insertion-ordered map with the algorithm name/description followed by "[param] name" to help text
-   *         entries, collected from the concrete class up to and including this base class
-   */
-  public Map<String, String> getArgUsage() {
-    Map<String, String> usage = new LinkedHashMap<String, String>();
-    Class<?> clz = this.getClass();
-    // Put class name and description as well as part of the usage
-    Algorithm algorithm = clz.getAnnotation(Algorithm.class);
-    if (algorithm != null) {
-      usage.put("Algorithm Name", algorithm.name());
-      usage.put("Algorithm Description", algorithm.description());
-    }
-
-    // Get all trainer params including base trainer params
-    while (clz != null) {
-      for (Field field : clz.getDeclaredFields()) {
-        TrainerParam param = field.getAnnotation(TrainerParam.class);
-        if (param != null) {
-          usage.put("[param] " + param.name(), param.help() + " Default Value = " + param.defaultValue());
-        }
-      }
-
-      if (clz.equals(BaseSparkTrainer.class)) {
-        break;
-      }
-      clz = clz.getSuperclass();
-    }
-    return usage;
-  }
-
-  /**
-   * Parses the algorithm specific parameters.
-   *
-   * @param params parameters other than feature and label columns
-   */
-  public abstract void parseTrainerParams(Map<String, String> params);
-
-  /**
-   * Fits the model on the prepared training data.
-   *
-   * @param modelId     the model id
-   * @param trainingRDD the training rdd
-   * @return the base spark classification model
-   * @throws LensException the lens exception
-   */
-  protected abstract BaseSparkClassificationModel trainInternal(String modelId, RDD<LabeledPoint> trainingRDD)
-    throws LensException;
-}

http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/49fef8e2/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/trainers/DecisionTreeTrainer.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/trainers/DecisionTreeTrainer.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/trainers/DecisionTreeTrainer.java
deleted file mode 100644
index 96fbf28..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/trainers/DecisionTreeTrainer.java
+++ /dev/null
@@ -1,109 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml.spark.trainers;
-
-import java.util.Map;
-
-import org.apache.lens.api.LensException;
-import org.apache.lens.ml.Algorithm;
-import org.apache.lens.ml.TrainerParam;
-import org.apache.lens.ml.spark.models.BaseSparkClassificationModel;
-import org.apache.lens.ml.spark.models.DecisionTreeClassificationModel;
-import org.apache.lens.ml.spark.models.SparkDecisionTreeModel;
-
-import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.mllib.tree.DecisionTree$;
-import org.apache.spark.mllib.tree.configuration.Algo$;
-import org.apache.spark.mllib.tree.impurity.Entropy$;
-import org.apache.spark.mllib.tree.impurity.Gini$;
-import org.apache.spark.mllib.tree.impurity.Impurity;
-import org.apache.spark.mllib.tree.impurity.Variance$;
-import org.apache.spark.mllib.tree.model.DecisionTreeModel;
-import org.apache.spark.rdd.RDD;
-
-import scala.Enumeration;
-
-/**
- * Trainer for Spark MLlib decision trees.
- *
- * <p>The {@code algo} ('classification' or 'regression') and {@code impurity} ('gini', 'entropy' or 'variance')
- * parameters are required and matched case-insensitively; {@code maxDepth} falls back to 100.
- */
-@Algorithm(name = "spark_decision_tree", description = "Spark Decision Tree classifier trainer")
-public class DecisionTreeTrainer extends BaseSparkTrainer {
-
-  /** Spark decision tree algorithm (classification or regression). */
-  @TrainerParam(name = "algo", help = "Decision tree algorithm. Allowed values are 'classification' and 'regression'")
-  private Enumeration.Value algo;
-
-  /** Impurity measure used to choose splits. */
-  @TrainerParam(name = "impurity", help = "Impurity measure used by the decision tree. "
-    + "Allowed values are 'gini', 'entropy' and 'variance'")
-  private Impurity decisionTreeImpurity;
-
-  /** Maximum depth of the tree. */
-  @TrainerParam(name = "maxDepth", help = "Max depth of the decision tree. Integer values expected.",
-    defaultValue = "100")
-  private int maxDepth;
-
-  /**
-   * Instantiates a new decision tree trainer.
-   *
-   * @param name        the name
-   * @param description the description
-   */
-  public DecisionTreeTrainer(String name, String description) {
-    super(name, description);
-  }
-
-  /**
-   * Parses {@code algo}, {@code impurity} and {@code maxDepth}.
-   *
-   * @param params algorithm parameters
-   * @throws IllegalArgumentException if {@code algo} or {@code impurity} is missing or not recognised
-   */
-  @Override
-  public void parseTrainerParams(Map<String, String> params) {
-    // Always reassigned, so values from a previous call on this instance are never reused.
-    algo = parseAlgo(params.get("algo"));
-    decisionTreeImpurity = parseImpurity(params.get("impurity"));
-    maxDepth = getParamValue("maxDepth", 100);
-  }
-
-  /**
-   * Maps the {@code algo} parameter to Spark's decision tree {@code Algo} value.
-   *
-   * @param dtreeAlgoName parameter value, matched case-insensitively
-   * @return the matching {@code Algo} value
-   * @throws IllegalArgumentException if the value is missing or not recognised
-   */
-  private static Enumeration.Value parseAlgo(String dtreeAlgoName) {
-    if ("classification".equalsIgnoreCase(dtreeAlgoName)) {
-      return Algo$.MODULE$.Classification();
-    }
-    if ("regression".equalsIgnoreCase(dtreeAlgoName)) {
-      return Algo$.MODULE$.Regression();
-    }
-    throw new IllegalArgumentException("Invalid decision tree algo: " + dtreeAlgoName
-      + ". Allowed values are 'classification' and 'regression'");
-  }
-
-  /**
-   * Maps the {@code impurity} parameter to Spark's impurity implementation.
-   *
-   * @param impurity parameter value, matched case-insensitively
-   * @return the matching {@link Impurity}
-   * @throws IllegalArgumentException if the value is missing or not recognised
-   */
-  private static Impurity parseImpurity(String impurity) {
-    if ("gini".equalsIgnoreCase(impurity)) {
-      return Gini$.MODULE$;
-    }
-    if ("entropy".equalsIgnoreCase(impurity)) {
-      return Entropy$.MODULE$;
-    }
-    if ("variance".equalsIgnoreCase(impurity)) {
-      return Variance$.MODULE$;
-    }
-    throw new IllegalArgumentException("Invalid decision tree impurity: " + impurity
-      + ". Allowed values are 'gini', 'entropy' and 'variance'");
-  }
-
-  /**
-   * Fits a decision tree on the labelled points and wraps it as a Lens classification model.
-   *
-   * @param modelId     id given to the resulting model
-   * @param trainingRDD labelled training points
-   * @return the trained model
-   * @throws LensException declared by the base class contract
-   */
-  @Override
-  protected BaseSparkClassificationModel trainInternal(String modelId, RDD<LabeledPoint> trainingRDD)
-    throws LensException {
-    DecisionTreeModel model = DecisionTree$.MODULE$.train(trainingRDD, algo, decisionTreeImpurity, maxDepth);
-    return new DecisionTreeClassificationModel(modelId, new SparkDecisionTreeModel(model));
-  }
-}