You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lens.apache.org by sh...@apache.org on 2015/04/05 09:11:05 UTC
[4/6] incubator-lens git commit: Lens-465 : Refactor ml packages.
(sharad)
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/SparkMLDriver.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/SparkMLDriver.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/SparkMLDriver.java
new file mode 100644
index 0000000..c955268
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/SparkMLDriver.java
@@ -0,0 +1,278 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.algo.spark;
+
+import java.io.File;
+import java.io.FilenameFilter;
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.lens.api.LensConf;
+import org.apache.lens.api.LensException;
+import org.apache.lens.ml.algo.api.MLAlgo;
+import org.apache.lens.ml.algo.api.MLDriver;
+import org.apache.lens.ml.algo.lib.Algorithms;
+import org.apache.lens.ml.algo.spark.dt.DecisionTreeAlgo;
+import org.apache.lens.ml.algo.spark.lr.LogisticRegressionAlgo;
+import org.apache.lens.ml.algo.spark.nb.NaiveBayesAlgo;
+import org.apache.lens.ml.algo.spark.svm.SVMAlgo;
+
+import org.apache.commons.lang.StringUtils;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaSparkContext;
+
+/**
+ * The Class SparkMLDriver.
+ */
+public class SparkMLDriver implements MLDriver {
+
+ /** The Constant LOG. */
+ public static final Log LOG = LogFactory.getLog(SparkMLDriver.class);
+
+ /** Whether this driver created the Spark context; stop() only stops a context it owns. */
+ private boolean ownsSparkContext = true;
+
+ /**
+ * The Enum SparkMasterMode.
+ */
+ private enum SparkMasterMode {
+ // Embedded mode used in tests
+ /** The embedded. */
+ EMBEDDED,
+ // Yarn client and Yarn cluster modes are used when deploying the app to Yarn cluster
+ /** The yarn client. */
+ YARN_CLIENT,
+
+ /** The yarn cluster. */
+ YARN_CLUSTER
+ }
+
+ /** The algorithms. */
+ private final Algorithms algorithms = new Algorithms();
+
+ /** The client mode. */
+ private SparkMasterMode clientMode = SparkMasterMode.EMBEDDED;
+
+ /** The is started. */
+ private boolean isStarted;
+
+ /** The spark conf. */
+ private SparkConf sparkConf;
+
+ /** The spark context. */
+ private JavaSparkContext sparkContext;
+
+ /**
+ * Use spark context.
+ *
+ * @param jsc the jsc
+ */
+ public void useSparkContext(JavaSparkContext jsc) {
+ ownsSparkContext = false;
+ this.sparkContext = jsc;
+ }
+
+ /*
+ * (non-Javadoc)
+ *
+ * @see org.apache.lens.ml.algo.api.MLDriver#isAlgoSupported(java.lang.String)
+ */
+ @Override
+ public boolean isAlgoSupported(String name) {
+ return algorithms.isAlgoSupported(name);
+ }
+
+ /*
+ * (non-Javadoc)
+ *
+ * @see org.apache.lens.ml.algo.api.MLDriver#getAlgoInstance(java.lang.String)
+ */
+ @Override
+ public MLAlgo getAlgoInstance(String name) throws LensException {
+ checkStarted();
+
+ if (!isAlgoSupported(name)) {
+ return null;
+ }
+
+ MLAlgo algo = null;
+ try {
+ algo = algorithms.getAlgoForName(name);
+ if (algo instanceof BaseSparkAlgo) {
+ ((BaseSparkAlgo) algo).setSparkContext(sparkContext);
+ }
+ } catch (LensException exc) {
+ LOG.error("Error creating algo object", exc);
+ }
+ return algo;
+ }
+
+ /**
+ * Register algos.
+ */
+ private void registerAlgos() {
+ algorithms.register(NaiveBayesAlgo.class);
+ algorithms.register(SVMAlgo.class);
+ algorithms.register(LogisticRegressionAlgo.class);
+ algorithms.register(DecisionTreeAlgo.class);
+ }
+
+ /*
+ * Builds the SparkConf from all "lens.ml.sparkdriver."-prefixed Lens properties and derives the master mode.
+ * spark.master and spark.home are read with a default so that an unset key reaches the blank checks below.
+ * @see org.apache.lens.ml.algo.api.MLDriver#init(org.apache.lens.api.LensConf)
+ */
+ @Override
+ public void init(LensConf conf) throws LensException {
+ sparkConf = new SparkConf();
+ registerAlgos();
+ for (String key : conf.getProperties().keySet()) {
+ if (key.startsWith("lens.ml.sparkdriver.")) {
+ sparkConf.set(key.substring("lens.ml.sparkdriver.".length()), conf.getProperties().get(key));
+ }
+ }
+ // SparkConf.get(key) throws NoSuchElementException for an unset key, so use the overload taking a default
+ String sparkAppMaster = sparkConf.get("spark.master", "");
+ if ("yarn-client".equalsIgnoreCase(sparkAppMaster)) {
+ clientMode = SparkMasterMode.YARN_CLIENT;
+ } else if ("yarn-cluster".equalsIgnoreCase(sparkAppMaster)) {
+ clientMode = SparkMasterMode.YARN_CLUSTER;
+ } else if ("local".equalsIgnoreCase(sparkAppMaster) || StringUtils.isBlank(sparkAppMaster)) {
+ clientMode = SparkMasterMode.EMBEDDED;
+ } else {
+ throw new IllegalArgumentException("Invalid master mode " + sparkAppMaster);
+ }
+
+ if (clientMode == SparkMasterMode.YARN_CLIENT || clientMode == SparkMasterMode.YARN_CLUSTER) {
+ String sparkHome = System.getenv("SPARK_HOME");
+ if (StringUtils.isNotBlank(sparkHome)) {
+ sparkConf.setSparkHome(sparkHome);
+ }
+
+ // If SPARK_HOME is not set, SparkConf can read from the Lens-site.xml or System properties.
+ if (StringUtils.isBlank(sparkConf.get("spark.home", ""))) {
+ throw new IllegalArgumentException("Spark home is not set");
+ }
+
+ LOG.info("Spark home is set to " + sparkConf.get("spark.home"));
+ }
+
+ sparkConf.setAppName("lens-ml");
+ }
+
+ /*
+ * Creates the Spark context unless one was injected and, outside embedded mode, adds the Hive, HCatalog and
+ * Lens jars to it. Throws LensException if HIVE_HOME is unset or one of its jar directories cannot be listed.
+ * @see org.apache.lens.ml.algo.api.MLDriver#start()
+ */
+ @Override
+ public void start() throws LensException {
+ if (sparkContext == null) {
+ sparkContext = new JavaSparkContext(sparkConf);
+ }
+ // Adding jars to spark context is only required when not running embedded (yarn-client or yarn-cluster)
+ if (clientMode != SparkMasterMode.EMBEDDED) {
+ // TODO Figure out only necessary set of JARs to be added for HCatalog
+ String hiveLocation = System.getenv("HIVE_HOME");
+ if (StringUtils.isBlank(hiveLocation)) {
+ throw new LensException("HIVE_HOME is not set");
+ }
+ LOG.info("HIVE_HOME at " + hiveLocation);
+ FilenameFilter jarFileFilter = new FilenameFilter() {
+ @Override
+ public boolean accept(File file, String s) {
+ return s.endsWith(".jar");
+ }
+ };
+ List<String> jarFiles = new ArrayList<String>();
+ // Add hive jars. File.listFiles returns null, not an empty array, for a missing or unreadable directory
+ File hiveLibDir = new File(hiveLocation, "lib");
+ File[] hiveJars = hiveLibDir.listFiles(jarFileFilter);
+ if (hiveJars == null) {
+ throw new LensException("Cannot list Hive jars in " + hiveLibDir.getAbsolutePath());
+ }
+ for (File jarFile : hiveJars) {
+ jarFiles.add(jarFile.getAbsolutePath());
+ LOG.info("Adding HIVE jar " + jarFile.getAbsolutePath());
+ sparkContext.addJar(jarFile.getAbsolutePath());
+ }
+ // Add hcatalog jars
+ File hcatalogDir = new File(hiveLocation + "/hcatalog/share/hcatalog");
+ File[] hcatalogJars = hcatalogDir.listFiles(jarFileFilter);
+ if (hcatalogJars == null) {
+ throw new LensException("Cannot list HCatalog jars in " + hcatalogDir.getAbsolutePath());
+ }
+ for (File jarFile : hcatalogJars) {
+ jarFiles.add(jarFile.getAbsolutePath());
+ LOG.info("Adding HCATALOG jar " + jarFile.getAbsolutePath());
+ sparkContext.addJar(jarFile.getAbsolutePath());
+ }
+ // Add the current jar
+ String[] lensSparkLibJars = JavaSparkContext.jarOfClass(SparkMLDriver.class);
+ for (String lensSparkJar : lensSparkLibJars) {
+ LOG.info("Adding Lens JAR " + lensSparkJar);
+ sparkContext.addJar(lensSparkJar);
+ }
+ }
+
+ isStarted = true;
+ LOG.info("Created Spark context for app: '" + sparkContext.appName() + "', Spark master: " + sparkContext.master());
+ }
+
+ /*
+ * (non-Javadoc)
+ *
+ * @see org.apache.lens.ml.algo.api.MLDriver#stop()
+ */
+ @Override
+ public void stop() throws LensException {
+ if (!isStarted) {
+ LOG.warn("Spark driver was not started");
+ return;
+ }
+ isStarted = false;
+ if (ownsSparkContext) {
+ sparkContext.stop();
+ }
+ LOG.info("Stopped spark context " + this);
+ }
+
+ @Override
+ public List<String> getAlgoNames() {
+ return algorithms.getAlgorithmNames();
+ }
+
+ /**
+ * Check started.
+ *
+ * @throws LensException the lens exception
+ */
+ public void checkStarted() throws LensException {
+ if (!isStarted) {
+ throw new LensException("Spark driver is not started yet");
+ }
+ }
+
+ public JavaSparkContext getSparkContext() {
+ return sparkContext;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/TableTrainingSpec.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/TableTrainingSpec.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/TableTrainingSpec.java
new file mode 100644
index 0000000..33fd801
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/TableTrainingSpec.java
@@ -0,0 +1,433 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.algo.spark;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.lens.api.LensException;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.hive.conf.HiveConf;
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.hive.hcatalog.data.HCatRecord;
+import org.apache.hive.hcatalog.data.schema.HCatFieldSchema;
+import org.apache.hive.hcatalog.data.schema.HCatSchema;
+import org.apache.hive.hcatalog.mapreduce.HCatInputFormat;
+import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.rdd.RDD;
+
+import com.google.common.base.Preconditions;
+import lombok.Getter;
+import lombok.ToString;
+
+/**
+ * The Class TableTrainingSpec.
+ */
+@ToString
+public class TableTrainingSpec implements Serializable {
+
+ /** The Constant LOG. */
+ public static final Log LOG = LogFactory.getLog(TableTrainingSpec.class);
+
+ /** The training rdd. */
+ @Getter
+ private transient RDD<LabeledPoint> trainingRDD;
+
+ /** The testing rdd. */
+ @Getter
+ private transient RDD<LabeledPoint> testingRDD;
+
+ /** The database. */
+ @Getter
+ private String database;
+
+ /** The table. */
+ @Getter
+ private String table;
+
+ /** The partition filter. */
+ @Getter
+ private String partitionFilter;
+
+ /** The feature columns. */
+ @Getter
+ private List<String> featureColumns;
+
+ /** The label column. */
+ @Getter
+ private String labelColumn;
+
+ /** The conf. */
+ @Getter
+ private transient HiveConf conf;
+
+ // By default all samples are considered for training
+ /** The split training. */
+ private boolean splitTraining;
+
+ /** The training fraction. */
+ private double trainingFraction = 1.0;
+
+ /** The label pos. */
+ int labelPos;
+
+ /** The feature positions. */
+ int[] featurePositions;
+
+ /** The num features. */
+ int numFeatures;
+
+ /** The labeled rdd. */
+ transient JavaRDD<LabeledPoint> labeledRDD;
+
+ /**
+ * New builder.
+ *
+ * @return the table training spec builder
+ */
+ public static TableTrainingSpecBuilder newBuilder() {
+ return new TableTrainingSpecBuilder();
+ }
+
+ /**
+ * The Class TableTrainingSpecBuilder.
+ */
+ public static class TableTrainingSpecBuilder {
+
+ /** The spec. */
+ final TableTrainingSpec spec;
+
+ /**
+ * Instantiates a new table training spec builder.
+ */
+ public TableTrainingSpecBuilder() {
+ spec = new TableTrainingSpec();
+ }
+
+ /**
+ * Hive conf.
+ *
+ * @param conf the conf
+ * @return the table training spec builder
+ */
+ public TableTrainingSpecBuilder hiveConf(HiveConf conf) {
+ spec.conf = conf;
+ return this;
+ }
+
+ /**
+ * Database.
+ *
+ * @param db the db
+ * @return the table training spec builder
+ */
+ public TableTrainingSpecBuilder database(String db) {
+ spec.database = db;
+ return this;
+ }
+
+ /**
+ * Table.
+ *
+ * @param table the table
+ * @return the table training spec builder
+ */
+ public TableTrainingSpecBuilder table(String table) {
+ spec.table = table;
+ return this;
+ }
+
+ /**
+ * Partition filter.
+ *
+ * @param partFilter the part filter
+ * @return the table training spec builder
+ */
+ public TableTrainingSpecBuilder partitionFilter(String partFilter) {
+ spec.partitionFilter = partFilter;
+ return this;
+ }
+
+ /**
+ * Label column.
+ *
+ * @param labelColumn the label column
+ * @return the table training spec builder
+ */
+ public TableTrainingSpecBuilder labelColumn(String labelColumn) {
+ spec.labelColumn = labelColumn;
+ return this;
+ }
+
+ /**
+ * Feature columns.
+ *
+ * @param featureColumns the feature columns
+ * @return the table training spec builder
+ */
+ public TableTrainingSpecBuilder featureColumns(List<String> featureColumns) {
+ spec.featureColumns = featureColumns;
+ return this;
+ }
+
+ /**
+ * Returns the spec configured through this builder.
+ *
+ * @return the table training spec
+ */
+ public TableTrainingSpec build() {
+ return spec;
+ }
+
+ /**
+ * Training fraction.
+ *
+ * @param trainingFraction the training fraction
+ * @return the table training spec builder
+ */
+ public TableTrainingSpecBuilder trainingFraction(double trainingFraction) {
+ Preconditions.checkArgument(trainingFraction >= 0 && trainingFraction <= 1.0,
+ "Training fraction shoule be between 0 and 1");
+ spec.trainingFraction = trainingFraction;
+ spec.splitTraining = true;
+ return this;
+ }
+ }
+
+ /**
+ * The Class DataSample.
+ */
+ public static class DataSample implements Serializable {
+
+ /** The labeled point. */
+ private final LabeledPoint labeledPoint;
+
+ /** The sample. */
+ private final double sample;
+
+ /**
+ * Instantiates a new data sample.
+ *
+ * @param labeledPoint the labeled point
+ */
+ public DataSample(LabeledPoint labeledPoint) {
+ sample = Math.random();
+ this.labeledPoint = labeledPoint;
+ }
+ }
+
+ /**
+ * The Class TrainingFilter.
+ */
+ public static class TrainingFilter implements Function<DataSample, Boolean> {
+
+ /** The training fraction. */
+ private double trainingFraction;
+
+ /**
+ * Instantiates a new training filter.
+ *
+ * @param fraction the fraction
+ */
+ public TrainingFilter(double fraction) {
+ trainingFraction = fraction;
+ }
+
+ /*
+ * (non-Javadoc)
+ *
+ * @see org.apache.spark.api.java.function.Function#call(java.lang.Object)
+ */
+ @Override
+ public Boolean call(DataSample v1) throws Exception {
+ return v1.sample <= trainingFraction;
+ }
+ }
+
+ /**
+ * The Class TestingFilter.
+ */
+ public static class TestingFilter implements Function<DataSample, Boolean> {
+
+ /** The training fraction. */
+ private double trainingFraction;
+
+ /**
+ * Instantiates a new testing filter.
+ *
+ * @param fraction the fraction
+ */
+ public TestingFilter(double fraction) {
+ trainingFraction = fraction;
+ }
+
+ /*
+ * (non-Javadoc)
+ *
+ * @see org.apache.spark.api.java.function.Function#call(java.lang.Object)
+ */
+ @Override
+ public Boolean call(DataSample v1) throws Exception {
+ return v1.sample > trainingFraction;
+ }
+ }
+
+ /**
+ * The Class GetLabeledPoint.
+ */
+ public static class GetLabeledPoint implements Function<DataSample, LabeledPoint> {
+
+ /*
+ * (non-Javadoc)
+ *
+ * @see org.apache.spark.api.java.function.Function#call(java.lang.Object)
+ */
+ @Override
+ public LabeledPoint call(DataSample v1) throws Exception {
+ return v1.labeledPoint;
+ }
+ }
+
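+ /**
+ * Checks the spec against the table schema and records the label and feature column positions.
+ *
+ * @return true if the label column and every feature column exist in the table, false otherwise
+ */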
+ boolean validate() {
+ List<HCatFieldSchema> columns;
+ try {
+ HCatInputFormat.setInput(conf, database == null ? "default" : database, table, partitionFilter);
+ HCatSchema tableSchema = HCatInputFormat.getTableSchema(conf);
+ columns = tableSchema.getFields();
+ } catch (IOException exc) {
+ LOG.error("Error getting table info " + toString(), exc);
+ return false;
+ }
+
+ LOG.info(table + " columns " + columns.toString());
+
+ boolean valid = false;
+ if (columns != null && !columns.isEmpty()) {
+ // Check labeled column
+ List<String> columnNames = new ArrayList<String>();
+ for (HCatFieldSchema col : columns) {
+ columnNames.add(col.getName());
+ }
+
+ // Need at least one feature column and one label column
+ valid = columnNames.contains(labelColumn) && columnNames.size() > 1;
+
+ if (valid) {
+ labelPos = columnNames.indexOf(labelColumn);
+
+ // Check feature columns
+ if (featureColumns == null || featureColumns.isEmpty()) {
+ // feature columns are not provided, so all columns except label column are feature columns
+ featurePositions = new int[columnNames.size() - 1];
+ int p = 0;
+ for (int i = 0; i < columnNames.size(); i++) {
+ if (i == labelPos) {
+ continue;
+ }
+ featurePositions[p++] = i;
+ }
+
+ columnNames.remove(labelPos);
+ featureColumns = columnNames;
+ } else {
+ // Feature columns were provided, verify all feature columns are present in the table
+ valid = columnNames.containsAll(featureColumns);
+ if (valid) {
+ // Get feature positions
+ featurePositions = new int[featureColumns.size()];
+ for (int i = 0; i < featureColumns.size(); i++) {
+ featurePositions[i] = columnNames.indexOf(featureColumns.get(i));
+ }
+ }
+ }
+ numFeatures = featureColumns.size();
+ }
+ }
+
+ return valid;
+ }
+
+ /**
+ * Validates this spec, maps the Hive table to an RDD of labeled points and sets trainingRDD and testingRDD.
+ * If a training fraction was set the points are split randomly; otherwise both RDDs are the same RDD.
+ * @param sparkContext the spark context
+ * @throws LensException if the spec is not valid or the table RDD cannot be created
+ */
+ public void createRDDs(JavaSparkContext sparkContext) throws LensException {
+ if (!validate()) {
+ throw new LensException("Table spec not valid: " + toString());
+ }
+
+ LOG.info("Creating RDDs with spec " + toString());
+
+ // Get the RDD for table
+ JavaPairRDD<WritableComparable, HCatRecord> tableRDD;
+ try {
+ tableRDD = HiveTableRDD.createHiveTableRDD(sparkContext, conf, database, table, partitionFilter);
+ } catch (IOException e) {
+ throw new LensException(e);
+ }
+
+ // Map into trainable RDD
+ // TODO: Figure out a way to use custom value mappers
+ FeatureValueMapper[] valueMappers = new FeatureValueMapper[numFeatures];
+ final DoubleValueMapper doubleMapper = new DoubleValueMapper();
+ for (int i = 0; i < numFeatures; i++) {
+ valueMappers[i] = doubleMapper;
+ }
+
+ ColumnFeatureFunction trainPrepFunction = new ColumnFeatureFunction(featurePositions, valueMappers, labelPos,
+ numFeatures, 0);
+ labeledRDD = tableRDD.map(trainPrepFunction);
+
+ if (splitTraining) {
+ // We have to split the RDD between a training RDD and a testing RDD
+ LOG.info("Splitting RDD for table " + database + "." + table + " with split fraction " + trainingFraction);
+ // DataSample draws Math.random() on construction. Uncached, each filter below recomputes the map and re-draws,
+ // so a point may land in both sets or in neither. cache() is best effort: evicted partitions are re-drawn.
+ JavaRDD<DataSample> sampledRDD = labeledRDD.map(new Function<LabeledPoint, DataSample>() {
+ @Override
+ public DataSample call(LabeledPoint v1) throws Exception {
+ return new DataSample(v1);
+ }
+ }).cache();
+ trainingRDD = sampledRDD.filter(new TrainingFilter(trainingFraction)).map(new GetLabeledPoint()).rdd();
+ testingRDD = sampledRDD.filter(new TestingFilter(trainingFraction)).map(new GetLabeledPoint()).rdd();
+ } else {
+ LOG.info("Using same RDD for train and test");
+ trainingRDD = labeledRDD.rdd();
+ testingRDD = trainingRDD;
+ }
+ LOG.info("Generated RDDs");
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/dt/DecisionTreeAlgo.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/dt/DecisionTreeAlgo.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/dt/DecisionTreeAlgo.java
new file mode 100644
index 0000000..6c7619a
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/dt/DecisionTreeAlgo.java
@@ -0,0 +1,108 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.algo.spark.dt;
+
+import java.util.Map;
+
+import org.apache.lens.api.LensException;
+import org.apache.lens.ml.algo.api.AlgoParam;
+import org.apache.lens.ml.algo.api.Algorithm;
+import org.apache.lens.ml.algo.spark.BaseSparkAlgo;
+import org.apache.lens.ml.algo.spark.BaseSparkClassificationModel;
+
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.mllib.tree.DecisionTree$;
+import org.apache.spark.mllib.tree.configuration.Algo$;
+import org.apache.spark.mllib.tree.impurity.Entropy$;
+import org.apache.spark.mllib.tree.impurity.Gini$;
+import org.apache.spark.mllib.tree.impurity.Impurity;
+import org.apache.spark.mllib.tree.impurity.Variance$;
+import org.apache.spark.mllib.tree.model.DecisionTreeModel;
+import org.apache.spark.rdd.RDD;
+
+import scala.Enumeration;
+
+/**
+ * The Class DecisionTreeAlgo.
+ */
+@Algorithm(name = "spark_decision_tree", description = "Spark Decision Tree classifier algo")
+public class DecisionTreeAlgo extends BaseSparkAlgo {
+
+ /** The algo. */
+ @AlgoParam(name = "algo", help = "Decision tree algorithm. Allowed values are 'classification' and 'regression'")
+ private Enumeration.Value algo;
+
+ /** The decision tree impurity. */
+ @AlgoParam(name = "impurity", help = "Impurity measure used by the decision tree. "
+ + "Allowed values are 'gini', 'entropy' and 'variance'")
+ private Impurity decisionTreeImpurity;
+
+ /** The max depth. */
+ @AlgoParam(name = "maxDepth", help = "Max depth of the decision tree. Integer values expected.",
+ defaultValue = "100")
+ private int maxDepth;
+
+ /**
+ * Instantiates a new decision tree algo.
+ *
+ * @param name the name
+ * @param description the description
+ */
+ public DecisionTreeAlgo(String name, String description) {
+ super(name, description);
+ }
+
+ /*
+ * Reads the "algo", "impurity" and "maxDepth" parameters. Both names are matched ignoring case; an absent
+ * or unrecognised "algo" or "impurity" value leaves the corresponding field as it was.
+ * @see org.apache.lens.ml.algo.spark.BaseSparkAlgo#parseAlgoParams(java.util.Map)
+ */
+ @Override
+ public void parseAlgoParams(Map<String, String> params) {
+ String dtreeAlgoName = params.get("algo");
+ if ("classification".equalsIgnoreCase(dtreeAlgoName)) {
+ algo = Algo$.MODULE$.Classification();
+ } else if ("regression".equalsIgnoreCase(dtreeAlgoName)) {
+ algo = Algo$.MODULE$.Regression();
+ }
+
+ String impurity = params.get("impurity");
+ if ("gini".equalsIgnoreCase(impurity)) {
+ decisionTreeImpurity = Gini$.MODULE$;
+ } else if ("entropy".equalsIgnoreCase(impurity)) {
+ decisionTreeImpurity = Entropy$.MODULE$;
+ } else if ("variance".equalsIgnoreCase(impurity)) {
+ decisionTreeImpurity = Variance$.MODULE$;
+ }
+
+ maxDepth = getParamValue("maxDepth", 100);
+ }
+
+ /*
+ * (non-Javadoc)
+ *
+ * @see org.apache.lens.ml.algo.spark.BaseSparkAlgo#trainInternal(java.lang.String, org.apache.spark.rdd.RDD)
+ */
+ @Override
+ protected BaseSparkClassificationModel trainInternal(String modelId, RDD<LabeledPoint> trainingRDD)
+ throws LensException {
+ DecisionTreeModel model = DecisionTree$.MODULE$.train(trainingRDD, algo, decisionTreeImpurity, maxDepth);
+ return new DecisionTreeClassificationModel(modelId, new SparkDecisionTreeModel(model));
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/dt/DecisionTreeClassificationModel.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/dt/DecisionTreeClassificationModel.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/dt/DecisionTreeClassificationModel.java
new file mode 100644
index 0000000..27c32f4
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/dt/DecisionTreeClassificationModel.java
@@ -0,0 +1,37 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.algo.spark.dt;
+
+import org.apache.lens.ml.algo.spark.BaseSparkClassificationModel;
+
+/**
+ * The Class DecisionTreeClassificationModel.
+ */
+public class DecisionTreeClassificationModel extends BaseSparkClassificationModel<SparkDecisionTreeModel> {
+
+ /**
+ * Instantiates a new decision tree classification model.
+ *
+ * @param modelId the model id
+ * @param model the model
+ */
+ public DecisionTreeClassificationModel(String modelId, SparkDecisionTreeModel model) {
+ super(modelId, model);
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/dt/SparkDecisionTreeModel.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/dt/SparkDecisionTreeModel.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/dt/SparkDecisionTreeModel.java
new file mode 100644
index 0000000..e561a8d
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/dt/SparkDecisionTreeModel.java
@@ -0,0 +1,75 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.algo.spark.dt;
+
+import org.apache.lens.ml.algo.spark.DoubleValueMapper;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.mllib.classification.ClassificationModel;
+import org.apache.spark.mllib.linalg.Vector;
+import org.apache.spark.mllib.tree.model.DecisionTreeModel;
+import org.apache.spark.rdd.RDD;
+
+/**
+ * This class is created because the Spark decision tree model doesn't extend ClassificationModel.
+ */
+public class SparkDecisionTreeModel implements ClassificationModel {
+
+ /** The model. */
+ private final DecisionTreeModel model;
+
+ /**
+ * Instantiates a new spark decision tree model.
+ *
+ * @param model the model
+ */
+ public SparkDecisionTreeModel(DecisionTreeModel model) {
+ this.model = model;
+ }
+
+ /*
+ * (non-Javadoc)
+ *
+ * @see org.apache.spark.mllib.classification.ClassificationModel#predict(org.apache.spark.rdd.RDD)
+ */
+ @Override
+ public RDD<Object> predict(RDD<Vector> testData) {
+ return model.predict(testData);
+ }
+
+ /*
+ * (non-Javadoc)
+ *
+ * @see org.apache.spark.mllib.classification.ClassificationModel#predict(org.apache.spark.mllib.linalg.Vector)
+ */
+ @Override
+ public double predict(Vector testData) {
+ return model.predict(testData);
+ }
+
+ /*
+ * (non-Javadoc)
+ *
+ * @see org.apache.spark.mllib.classification.ClassificationModel#predict(org.apache.spark.api.java.JavaRDD)
+ */
+ @Override
+ public JavaRDD<Double> predict(JavaRDD<Vector> testData) {
+ return model.predict(testData.rdd()).toJavaRDD().map(new DoubleValueMapper());
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/kmeans/KMeansAlgo.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/kmeans/KMeansAlgo.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/kmeans/KMeansAlgo.java
new file mode 100644
index 0000000..6450f70
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/kmeans/KMeansAlgo.java
@@ -0,0 +1,163 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.algo.spark.kmeans;
+
+import java.util.List;
+
+import org.apache.lens.api.LensConf;
+import org.apache.lens.api.LensException;
+import org.apache.lens.ml.algo.api.*;
+import org.apache.lens.ml.algo.lib.AlgoArgParser;
+import org.apache.lens.ml.algo.spark.HiveTableRDD;
+
+import org.apache.hadoop.hive.conf.HiveConf;
+import org.apache.hadoop.hive.metastore.api.FieldSchema;
+import org.apache.hadoop.hive.ql.metadata.Hive;
+import org.apache.hadoop.hive.ql.metadata.Table;
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.hive.hcatalog.data.HCatRecord;
+import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.mllib.clustering.KMeans;
+import org.apache.spark.mllib.clustering.KMeansModel;
+import org.apache.spark.mllib.linalg.Vector;
+import org.apache.spark.mllib.linalg.Vectors;
+
+import scala.Tuple2;
+
+/**
+ * The Class KMeansAlgo.
+ */
+@Algorithm(name = "spark_kmeans_algo", description = "Spark MLLib KMeans algo")
+public class KMeansAlgo implements MLAlgo {
+
+ /** The conf. */
+ private transient LensConf conf;
+
+ /** The spark context. */
+ private JavaSparkContext sparkContext;
+
+ /** The part filter. */
+ @AlgoParam(name = "partition", help = "Partition filter to be used while constructing table RDD")
+ private String partFilter = null;
+
+ /** The k. */
+ @AlgoParam(name = "k", help = "Number of cluster")
+ private int k;
+
+ /** The max iterations. */
+ @AlgoParam(name = "maxIterations", help = "Maximum number of iterations", defaultValue = "100")
+ private int maxIterations = 100;
+
+ /** The runs. */
+ @AlgoParam(name = "runs", help = "Number of parallel run", defaultValue = "1")
+ private int runs = 1;
+
+ /** The initialization mode. */
+ @AlgoParam(name = "initializationMode",
+ help = "initialization model, either \"random\" or \"k-means||\" (default).", defaultValue = "k-means||")
+ private String initializationMode = "k-means||";
+
+ @Override
+ public String getName() {
+ return getClass().getAnnotation(Algorithm.class).name();
+ }
+
+ @Override
+ public String getDescription() {
+ return getClass().getAnnotation(Algorithm.class).description();
+ }
+
+ /*
+ * (non-Javadoc)
+ *
+ * @see org.apache.lens.ml.MLAlgo#configure(org.apache.lens.api.LensConf)
+ */
+ @Override
+ public void configure(LensConf configuration) {
+ this.conf = configuration;
+ }
+
+ @Override
+ public LensConf getConf() {
+ return conf;
+ }
+
+ /*
+ * (non-Javadoc)
+ *
+ * @see org.apache.lens.ml.MLAlgo#train(org.apache.lens.api.LensConf, java.lang.String, java.lang.String,
+ * java.lang.String, java.lang.String[])
+ */
+ @Override
+ public MLModel train(LensConf conf, String db, String table, String modelId, String... params) throws LensException {
+ List<String> features = AlgoArgParser.parseArgs(this, params);
+ final int[] featurePositions = new int[features.size()];
+ final int NUM_FEATURES = features.size();
+
+ JavaPairRDD<WritableComparable, HCatRecord> rdd = null;
+ try {
+ // Map feature names to positions
+ Table tbl = Hive.get(toHiveConf(conf)).getTable(db, table);
+ List<FieldSchema> allCols = tbl.getAllCols();
+ int f = 0;
+ for (int i = 0; i < tbl.getAllCols().size(); i++) {
+ String colName = allCols.get(i).getName();
+ if (features.contains(colName)) {
+ featurePositions[f++] = i;
+ }
+ }
+
+ rdd = HiveTableRDD.createHiveTableRDD(sparkContext, toHiveConf(conf), db, table, partFilter);
+ JavaRDD<Vector> trainableRDD = rdd.map(new Function<Tuple2<WritableComparable, HCatRecord>, Vector>() {
+ @Override
+ public Vector call(Tuple2<WritableComparable, HCatRecord> v1) throws Exception {
+ HCatRecord hCatRecord = v1._2();
+ double[] arr = new double[NUM_FEATURES];
+ for (int i = 0; i < NUM_FEATURES; i++) {
+ Object val = hCatRecord.get(featurePositions[i]);
+ arr[i] = val == null ? 0d : (Double) val;
+ }
+ return Vectors.dense(arr);
+ }
+ });
+
+ KMeansModel model = KMeans.train(trainableRDD.rdd(), k, maxIterations, runs, initializationMode);
+ return new KMeansClusteringModel(modelId, model);
+ } catch (Exception e) {
+ throw new LensException("KMeans algo failed for " + db + "." + table, e);
+ }
+ }
+
+ /**
+ * To hive conf.
+ *
+ * @param conf the conf
+ * @return the hive conf
+ */
+ private HiveConf toHiveConf(LensConf conf) {
+ HiveConf hiveConf = new HiveConf();
+ for (String key : conf.getProperties().keySet()) {
+ hiveConf.set(key, conf.getProperties().get(key));
+ }
+ return hiveConf;
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/kmeans/KMeansClusteringModel.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/kmeans/KMeansClusteringModel.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/kmeans/KMeansClusteringModel.java
new file mode 100644
index 0000000..62dc536
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/kmeans/KMeansClusteringModel.java
@@ -0,0 +1,67 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.algo.spark.kmeans;
+
+import org.apache.lens.ml.algo.api.MLModel;
+
+import org.apache.spark.mllib.clustering.KMeansModel;
+import org.apache.spark.mllib.linalg.Vectors;
+
+/**
+ * The Class KMeansClusteringModel.
+ */
+public class KMeansClusteringModel extends MLModel<Integer> {
+
+ /** The model. */
+ private final KMeansModel model;
+
+ /** The model id. */
+ private final String modelId;
+
+ /**
+ * Instantiates a new k means clustering model.
+ *
+ * @param modelId the model id
+ * @param model the model
+ */
+ public KMeansClusteringModel(String modelId, KMeansModel model) {
+ this.model = model;
+ this.modelId = modelId;
+ }
+
+ /*
+ * (non-Javadoc)
+ *
+ * @see org.apache.lens.ml.MLModel#predict(java.lang.Object[])
+ */
+ @Override
+ public Integer predict(Object... args) {
+ // Convert the params to array of double
+ double[] arr = new double[args.length];
+ for (int i = 0; i < args.length; i++) {
+ if (args[i] != null) {
+ arr[i] = (Double) args[i];
+ } else {
+ arr[i] = 0d;
+ }
+ }
+
+ return model.predict(Vectors.dense(arr));
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/lr/LogisticRegressionAlgo.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/lr/LogisticRegressionAlgo.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/lr/LogisticRegressionAlgo.java
new file mode 100644
index 0000000..55caf59
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/lr/LogisticRegressionAlgo.java
@@ -0,0 +1,86 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.algo.spark.lr;
+
+import java.util.Map;
+
+import org.apache.lens.api.LensException;
+import org.apache.lens.ml.algo.api.AlgoParam;
+import org.apache.lens.ml.algo.api.Algorithm;
+import org.apache.lens.ml.algo.spark.BaseSparkAlgo;
+import org.apache.lens.ml.algo.spark.BaseSparkClassificationModel;
+
+import org.apache.spark.mllib.classification.LogisticRegressionModel;
+import org.apache.spark.mllib.classification.LogisticRegressionWithSGD;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.rdd.RDD;
+
+/**
+ * Logistic regression algo backed by Spark MLlib's {@link LogisticRegressionWithSGD}.
+ */
+@Algorithm(name = "spark_logistic_regression", description = "Spark logistic regression algo")
+public class LogisticRegressionAlgo extends BaseSparkAlgo {
+
+  /** The iterations; initialised to the default declared in the annotation. */
+  @AlgoParam(name = "iterations", help = "Max number of iterations", defaultValue = "100")
+  private int iterations = 100;
+
+  /** The step size; initialised to the default declared in the annotation. */
+  @AlgoParam(name = "stepSize", help = "Step size", defaultValue = "1.0d")
+  private double stepSize = 1.0d;
+
+  /** The min batch fraction; initialised to the default declared in the annotation. */
+  @AlgoParam(name = "minBatchFraction", help = "Fraction for batched learning", defaultValue = "1.0d")
+  private double minBatchFraction = 1.0d;
+
+  /**
+   * Instantiates a new logistic regression algo.
+   *
+   * @param name        the name
+   * @param description the description
+   */
+  public LogisticRegressionAlgo(String name, String description) {
+    super(name, description);
+  }
+
+  /**
+   * Reads "iterations", "stepSize" and "minBatchFraction" through {@code getParamValue}, passing 100, 1.0 and
+   * 1.0 as the respective fallback arguments.
+   * NOTE(review): the {@code params} argument is not read here; presumably the base class already holds the
+   * parsed values - confirm.
+   *
+   * @param params the algo parameters (unused in this method)
+   * @see org.apache.lens.ml.algo.spark.BaseSparkAlgo#parseAlgoParams(java.util.Map)
+   */
+  @Override
+  public void parseAlgoParams(Map<String, String> params) {
+    iterations = getParamValue("iterations", 100);
+    stepSize = getParamValue("stepSize", 1.0d);
+    minBatchFraction = getParamValue("minBatchFraction", 1.0d);
+  }
+
+  /**
+   * Trains a logistic regression model with the configured iterations, step size and batch fraction.
+   *
+   * @param modelId     the id given to the resulting model
+   * @param trainingRDD the labeled training points
+   * @return the trained model wrapped in a {@link LogitRegressionClassificationModel}
+   * @throws LensException declared by the overridden method; not thrown directly here
+   * @see org.apache.lens.ml.algo.spark.BaseSparkAlgo#trainInternal(java.lang.String, org.apache.spark.rdd.RDD)
+   */
+  @Override
+  protected BaseSparkClassificationModel trainInternal(String modelId, RDD<LabeledPoint> trainingRDD)
+    throws LensException {
+    LogisticRegressionModel lrModel = LogisticRegressionWithSGD.train(trainingRDD, iterations, stepSize,
+      minBatchFraction);
+    return new LogitRegressionClassificationModel(modelId, lrModel);
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/lr/LogitRegressionClassificationModel.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/lr/LogitRegressionClassificationModel.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/lr/LogitRegressionClassificationModel.java
new file mode 100644
index 0000000..a4206e5
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/lr/LogitRegressionClassificationModel.java
@@ -0,0 +1,39 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.algo.spark.lr;
+
+import org.apache.lens.ml.algo.spark.BaseSparkClassificationModel;
+
+import org.apache.spark.mllib.classification.LogisticRegressionModel;
+
+/**
+ * Lens classification model backed by a Spark MLlib {@link LogisticRegressionModel}.
+ * <p>
+ * Adds no behaviour of its own; it only fixes the type parameter of {@link BaseSparkClassificationModel}.
+ */
+public class LogitRegressionClassificationModel extends BaseSparkClassificationModel<LogisticRegressionModel> {
+
+ /**
+ * Instantiates a new logit regression classification model by passing both arguments to the superclass.
+ *
+ * @param modelId the model id
+ * @param model the Spark logistic regression model
+ */
+ public LogitRegressionClassificationModel(String modelId, LogisticRegressionModel model) {
+ super(modelId, model);
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/nb/NaiveBayesAlgo.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/nb/NaiveBayesAlgo.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/nb/NaiveBayesAlgo.java
new file mode 100644
index 0000000..b4e1e78
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/nb/NaiveBayesAlgo.java
@@ -0,0 +1,73 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.algo.spark.nb;
+
+import java.util.Map;
+
+import org.apache.lens.api.LensException;
+import org.apache.lens.ml.algo.api.AlgoParam;
+import org.apache.lens.ml.algo.api.Algorithm;
+import org.apache.lens.ml.algo.spark.BaseSparkAlgo;
+import org.apache.lens.ml.algo.spark.BaseSparkClassificationModel;
+
+import org.apache.spark.mllib.classification.NaiveBayes;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.rdd.RDD;
+
+/**
+ * Naive Bayes classifier algo backed by Spark MLlib's {@link NaiveBayes}.
+ */
+@Algorithm(name = "spark_naive_bayes", description = "Spark Naive Bayes classifier algo")
+public class NaiveBayesAlgo extends BaseSparkAlgo {
+
+ /** The lambda value handed to {@code NaiveBayes.train}; initialised to 1.0. */
+ @AlgoParam(name = "lambda", help = "Lambda parameter for naive bayes learner", defaultValue = "1.0d")
+ private double lambda = 1.0;
+
+ /**
+ * Instantiates a new naive bayes algo.
+ *
+ * @param name the name
+ * @param description the description
+ */
+ public NaiveBayesAlgo(String name, String description) {
+ super(name, description);
+ }
+
+ /**
+ * Reads "lambda" through {@code getParamValue}, passing 1.0 as the fallback argument.
+ * NOTE(review): the {@code params} argument is not read here; presumably the base class already holds the
+ * parsed values - confirm.
+ *
+ * @param params the algo parameters (unused in this method)
+ * @see org.apache.lens.ml.algo.spark.BaseSparkAlgo#parseAlgoParams(java.util.Map)
+ */
+ @Override
+ public void parseAlgoParams(Map<String, String> params) {
+ lambda = getParamValue("lambda", 1.0d);
+ }
+
+ /**
+ * Trains a naive Bayes model on the given points using the configured lambda.
+ *
+ * @param modelId the id given to the resulting model
+ * @param trainingRDD the labeled training points
+ * @return the trained model wrapped in a {@link NaiveBayesClassificationModel}
+ * @throws LensException declared by the overridden method; not thrown directly here
+ * @see org.apache.lens.ml.algo.spark.BaseSparkAlgo#trainInternal(java.lang.String, org.apache.spark.rdd.RDD)
+ */
+ @Override
+ protected BaseSparkClassificationModel trainInternal(String modelId, RDD<LabeledPoint> trainingRDD)
+ throws LensException {
+ return new NaiveBayesClassificationModel(modelId, NaiveBayes.train(trainingRDD, lambda));
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/nb/NaiveBayesClassificationModel.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/nb/NaiveBayesClassificationModel.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/nb/NaiveBayesClassificationModel.java
new file mode 100644
index 0000000..26d39df
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/nb/NaiveBayesClassificationModel.java
@@ -0,0 +1,39 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.algo.spark.nb;
+
+import org.apache.lens.ml.algo.spark.BaseSparkClassificationModel;
+
+import org.apache.spark.mllib.classification.NaiveBayesModel;
+
+/**
+ * Lens classification model backed by a Spark MLlib {@link NaiveBayesModel}.
+ * <p>
+ * Adds no behaviour of its own; it only fixes the type parameter of {@link BaseSparkClassificationModel}.
+ */
+public class NaiveBayesClassificationModel extends BaseSparkClassificationModel<NaiveBayesModel> {
+
+ /**
+ * Instantiates a new naive bayes classification model by passing both arguments to the superclass.
+ *
+ * @param modelId the model id
+ * @param model the Spark naive Bayes model
+ */
+ public NaiveBayesClassificationModel(String modelId, NaiveBayesModel model) {
+ super(modelId, model);
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/svm/SVMAlgo.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/svm/SVMAlgo.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/svm/SVMAlgo.java
new file mode 100644
index 0000000..21a036a
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/svm/SVMAlgo.java
@@ -0,0 +1,90 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.algo.spark.svm;
+
+import java.util.Map;
+
+import org.apache.lens.api.LensException;
+import org.apache.lens.ml.algo.api.AlgoParam;
+import org.apache.lens.ml.algo.api.Algorithm;
+import org.apache.lens.ml.algo.spark.BaseSparkAlgo;
+import org.apache.lens.ml.algo.spark.BaseSparkClassificationModel;
+
+import org.apache.spark.mllib.classification.SVMModel;
+import org.apache.spark.mllib.classification.SVMWithSGD;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.rdd.RDD;
+
+/**
+ * SVM classifier algo backed by Spark MLlib's {@link SVMWithSGD}.
+ */
+@Algorithm(name = "spark_svm", description = "Spark SVML classifier algo")
+public class SVMAlgo extends BaseSparkAlgo {
+
+  /** The min batch fraction; initialised to the default declared in the annotation. */
+  @AlgoParam(name = "minBatchFraction", help = "Fraction for batched learning", defaultValue = "1.0d")
+  private double minBatchFraction = 1.0d;
+
+  /** The reg param; initialised to the default declared in the annotation. */
+  @AlgoParam(name = "regParam", help = "regularization parameter for gradient descent", defaultValue = "1.0d")
+  private double regParam = 1.0d;
+
+  /** The step size; initialised to the default declared in the annotation. */
+  @AlgoParam(name = "stepSize", help = "Iteration step size", defaultValue = "1.0d")
+  private double stepSize = 1.0d;
+
+  /** The iterations; initialised to the default declared in the annotation. */
+  @AlgoParam(name = "iterations", help = "Number of iterations", defaultValue = "100")
+  private int iterations = 100;
+
+  /**
+   * Instantiates a new SVM algo.
+   *
+   * @param name        the name
+   * @param description the description
+   */
+  public SVMAlgo(String name, String description) {
+    super(name, description);
+  }
+
+  /**
+   * Reads "minBatchFraction", "regParam", "stepSize" and "iterations" through {@code getParamValue}, passing
+   * 1.0, 1.0, 1.0 and 100 as the respective fallback arguments.
+   * NOTE(review): the {@code params} argument is not read here; presumably the base class already holds the
+   * parsed values - confirm.
+   *
+   * @param params the algo parameters (unused in this method)
+   * @see org.apache.lens.ml.algo.spark.BaseSparkAlgo#parseAlgoParams(java.util.Map)
+   */
+  @Override
+  public void parseAlgoParams(Map<String, String> params) {
+    minBatchFraction = getParamValue("minBatchFraction", 1.0);
+    regParam = getParamValue("regParam", 1.0);
+    stepSize = getParamValue("stepSize", 1.0);
+    iterations = getParamValue("iterations", 100);
+  }
+
+  /**
+   * Trains an SVM model with the configured iterations, step size, regularization parameter and batch fraction.
+   *
+   * @param modelId     the id given to the resulting model
+   * @param trainingRDD the labeled training points
+   * @return the trained model wrapped in a {@link SVMClassificationModel}
+   * @throws LensException declared by the overridden method; not thrown directly here
+   * @see org.apache.lens.ml.algo.spark.BaseSparkAlgo#trainInternal(java.lang.String, org.apache.spark.rdd.RDD)
+   */
+  @Override
+  protected BaseSparkClassificationModel trainInternal(String modelId, RDD<LabeledPoint> trainingRDD)
+    throws LensException {
+    SVMModel svmModel = SVMWithSGD.train(trainingRDD, iterations, stepSize, regParam, minBatchFraction);
+    return new SVMClassificationModel(modelId, svmModel);
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/svm/SVMClassificationModel.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/svm/SVMClassificationModel.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/svm/SVMClassificationModel.java
new file mode 100644
index 0000000..433c0f9
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/svm/SVMClassificationModel.java
@@ -0,0 +1,39 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.algo.spark.svm;
+
+import org.apache.lens.ml.algo.spark.BaseSparkClassificationModel;
+
+import org.apache.spark.mllib.classification.SVMModel;
+
+/**
+ * Lens classification model backed by a Spark MLlib {@link SVMModel}.
+ * <p>
+ * Adds no behaviour of its own; it only fixes the type parameter of {@link BaseSparkClassificationModel}.
+ */
+public class SVMClassificationModel extends BaseSparkClassificationModel<SVMModel> {
+
+ /**
+ * Instantiates a new SVM classification model by passing both arguments to the superclass.
+ *
+ * @param modelId the model id
+ * @param model the Spark SVM model
+ */
+ public SVMClassificationModel(String modelId, SVMModel model) {
+ super(modelId, model);
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/api/LensML.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/api/LensML.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/api/LensML.java
new file mode 100644
index 0000000..e124fb0
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/api/LensML.java
@@ -0,0 +1,161 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.api;
+
+import java.util.List;
+import java.util.Map;
+
+import org.apache.lens.api.LensException;
+import org.apache.lens.api.LensSessionHandle;
+import org.apache.lens.ml.algo.api.MLAlgo;
+import org.apache.lens.ml.algo.api.MLModel;
+
+/**
+ * Lens's machine learning interface used by client code as well as Lens ML service.
+ */
+public interface LensML {
+
+ /** Name of ML service */
+ String NAME = "ml";
+
+ /**
+ * Get list of available machine learning algorithms
+ *
+ * @return the available algorithms, as strings (presumably their names - confirm against the implementation)
+ */
+ List<String> getAlgorithms();
+
+ /**
+ * Get user friendly information about parameters accepted by the algorithm.
+ *
+ * @param algorithm the algorithm
+ * @return map of param key to its help message
+ */
+ Map<String, String> getAlgoParamDescription(String algorithm);
+
+ /**
+ * Get an algo object instance which could be used to generate a model of the given algorithm.
+ *
+ * @param algorithm the algorithm
+ * @return the algo for name
+ * @throws LensException the lens exception
+ */
+ MLAlgo getAlgoForName(String algorithm) throws LensException;
+
+ /**
+ * Create a model using the given HCatalog table as input. The arguments should contain information needed to
+ * generate the model.
+ *
+ * @param table the table
+ * @param algorithm the algorithm
+ * @param args the args
+ * @return Unique ID of the model created after training is complete
+ * @throws LensException the lens exception
+ */
+ String train(String table, String algorithm, String[] args) throws LensException;
+
+ /**
+ * Get model IDs for the given algorithm.
+ *
+ * @param algorithm the algorithm
+ * @return the models
+ * @throws LensException the lens exception
+ */
+ List<String> getModels(String algorithm) throws LensException;
+
+ /**
+ * Get a model instance given the algorithm name and model ID.
+ *
+ * @param algorithm the algorithm
+ * @param modelId the model id
+ * @return the model
+ * @throws LensException the lens exception
+ */
+ MLModel getModel(String algorithm, String modelId) throws LensException;
+
+ /**
+ * Get the FS location where model instance is saved.
+ *
+ * @param algorithm the algorithm
+ * @param modelID the model id
+ * @return the model path
+ */
+ String getModelPath(String algorithm, String modelID);
+
+ /**
+ * Evaluate model by running it against test data contained in the given table.
+ *
+ * @param session the session
+ * @param table the table
+ * @param algorithm the algorithm
+ * @param modelID the model id
+ * @param outputTable the output table (presumably where the test output is written - confirm against the
+ * implementation)
+ * @return Test report object containing test output table, and various evaluation metrics
+ * @throws LensException the lens exception
+ */
+ MLTestReport testModel(LensSessionHandle session, String table, String algorithm, String modelID,
+ String outputTable) throws LensException;
+
+ /**
+ * Get test reports for an algorithm.
+ *
+ * @param algorithm the algorithm
+ * @return the test reports
+ * @throws LensException the lens exception
+ */
+ List<String> getTestReports(String algorithm) throws LensException;
+
+ /**
+ * Get a test report by ID.
+ *
+ * @param algorithm the algorithm
+ * @param reportID the report id
+ * @return the test report
+ * @throws LensException the lens exception
+ */
+ MLTestReport getTestReport(String algorithm, String reportID) throws LensException;
+
+ /**
+ * Online predict call given a model ID, algorithm name and sample feature values.
+ *
+ * @param algorithm the algorithm
+ * @param modelID the model id
+ * @param features the features
+ * @return prediction result
+ * @throws LensException the lens exception
+ */
+ Object predict(String algorithm, String modelID, Object[] features) throws LensException;
+
+ /**
+ * Permanently delete a model instance.
+ *
+ * @param algorithm the algorithm
+ * @param modelID the model id
+ * @throws LensException the lens exception
+ */
+ void deleteModel(String algorithm, String modelID) throws LensException;
+
+ /**
+ * Permanently delete a test report instance.
+ *
+ * @param algorithm the algorithm
+ * @param reportID the report id
+ * @throws LensException the lens exception
+ */
+ void deleteTestReport(String algorithm, String reportID) throws LensException;
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/api/MLTestReport.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/api/MLTestReport.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/api/MLTestReport.java
new file mode 100644
index 0000000..965161a
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/api/MLTestReport.java
@@ -0,0 +1,95 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.api;
+
+import java.io.Serializable;
+import java.util.List;
+
+import lombok.Getter;
+import lombok.NoArgsConstructor;
+import lombok.Setter;
+import lombok.ToString;
+
+/**
+ * Serializable bean holding the details of one model test run.
+ * <p>
+ * Lombok generates a getter and a setter for every field, a no-argument constructor and {@code toString()}.
+ */
+@Getter
+@Setter
+@ToString
+@NoArgsConstructor
+public class MLTestReport implements Serializable {
+
+  /** The test table. */
+  private String testTable;
+
+  /** The output table. */
+  private String outputTable;
+
+  /** The output column. */
+  private String outputColumn;
+
+  /** The label column. */
+  private String labelColumn;
+
+  /** The feature columns. */
+  private List<String> featureColumns;
+
+  /** The algorithm. */
+  private String algorithm;
+
+  /** The model id. */
+  private String modelID;
+
+  /** The report id. */
+  private String reportID;
+
+  /** The query id. */
+  private String queryID;
+
+  /** The test output path. */
+  private String testOutputPath;
+
+  /** The prediction result column. */
+  private String predictionResultColumn;
+
+  /** The lens query id. */
+  private String lensQueryID;
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/api/ModelMetadata.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/api/ModelMetadata.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/api/ModelMetadata.java
new file mode 100644
index 0000000..3f7dff1
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/api/ModelMetadata.java
@@ -0,0 +1,118 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.api;
+
+import javax.xml.bind.annotation.XmlElement;
+import javax.xml.bind.annotation.XmlRootElement;
+
+import lombok.AllArgsConstructor;
+import lombok.Getter;
+import lombok.NoArgsConstructor;
+
+/**
+ * JAXB-annotated, getter-only holder for the metadata of a model (algorithm, training table, params, path).
+ */
+@XmlRootElement
+/**
+ * All-arguments constructor generated by Lombok; parameters follow the field declaration order.
+ *
+ * @param modelID
+ * the model id
+ * @param table
+ * the table
+ * @param algorithm
+ * the algorithm
+ * @param params
+ * the params
+ * @param createdAt
+ * the created at
+ * @param modelPath
+ * the model path
+ * @param labelColumn
+ * the label column
+ * @param features
+ * the features
+ */
+@AllArgsConstructor
+/**
+ * No-argument constructor generated by Lombok; every field is left null.
+ */
+@NoArgsConstructor
+public class ModelMetadata {
+
+ /** The model id. */
+ @XmlElement
+ @Getter
+ private String modelID;
+
+ /** The table; printed as "Training table" by toString(). */
+ @XmlElement
+ @Getter
+ private String table;
+
+ /** The algorithm. */
+ @XmlElement
+ @Getter
+ private String algorithm;
+
+ /** The params; printed as "Training params" by toString(). */
+ @XmlElement
+ @Getter
+ private String params;
+
+ /** The created at value, stored as a string. */
+ @XmlElement
+ @Getter
+ private String createdAt;
+
+ /** The model path; printed as "Model saved at" by toString(). */
+ @XmlElement
+ @Getter
+ private String modelPath;
+
+ /** The label column. */
+ @XmlElement
+ @Getter
+ private String labelColumn;
+
+ /** The features, stored as a single string. */
+ @XmlElement
+ @Getter
+ private String features;
+
+ /*
+ * Renders all eight fields, one "Label: value" line per field, each terminated by '\n'.
+ *
+ * @see java.lang.Object#toString()
+ */
+ @Override
+ public String toString() {
+ StringBuilder builder = new StringBuilder();
+
+ builder.append("Algorithm: ").append(algorithm).append('\n');
+ builder.append("Model ID: ").append(modelID).append('\n');
+ builder.append("Training table: ").append(table).append('\n');
+ builder.append("Features: ").append(features).append('\n');
+ builder.append("Labelled Column: ").append(labelColumn).append('\n');
+ builder.append("Training params: ").append(params).append('\n');
+ builder.append("Created on: ").append(createdAt).append('\n');
+ builder.append("Model saved at: ").append(modelPath).append('\n');
+ return builder.toString();
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/api/TestReport.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/api/TestReport.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/api/TestReport.java
new file mode 100644
index 0000000..294fef3
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/api/TestReport.java
@@ -0,0 +1,125 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.api;
+
+import javax.xml.bind.annotation.XmlElement;
+import javax.xml.bind.annotation.XmlRootElement;
+
+import lombok.AllArgsConstructor;
+import lombok.Getter;
+import lombok.NoArgsConstructor;
+
+/**
+ * JAXB-annotated, getter-only holder for the details of a model test report.
+ */
+@XmlRootElement
+/**
+ * All-arguments constructor generated by Lombok; parameters follow the field declaration order.
+ *
+ * @param testTable
+ * the test table
+ * @param outputTable
+ * the output table
+ * @param outputColumn
+ * the output column
+ * @param labelColumn
+ * the label column
+ * @param featureColumns
+ * the feature columns
+ * @param algorithm
+ * the algorithm
+ * @param modelID
+ * the model id
+ * @param reportID
+ * the report id
+ * @param queryID
+ * the query id
+ */
+@AllArgsConstructor
+/**
+ * No-argument constructor generated by Lombok; every field is left null.
+ */
+@NoArgsConstructor
+public class TestReport {
+
+ /** The test table; printed as "Input test table" by toString(). */
+ @XmlElement
+ @Getter
+ private String testTable;
+
+ /** The output table; printed as "Test output table" by toString(). */
+ @XmlElement
+ @Getter
+ private String outputTable;
+
+ /** The output column; printed as "Predicted column" by toString(). */
+ @XmlElement
+ @Getter
+ private String outputColumn;
+
+ /** The label column. */
+ @XmlElement
+ @Getter
+ private String labelColumn;
+
+ /** The feature columns, stored as a single string. */
+ @XmlElement
+ @Getter
+ private String featureColumns;
+
+ /** The algorithm. */
+ @XmlElement
+ @Getter
+ private String algorithm;
+
+ /** The model id. */
+ @XmlElement
+ @Getter
+ private String modelID;
+
+ /** The report id. */
+ @XmlElement
+ @Getter
+ private String reportID;
+
+ /** The query id; printed as "Lens Query id" by toString(). */
+ @XmlElement
+ @Getter
+ private String queryID;
+
+ /*
+ * Renders all nine fields, one "Label: value" line per field, each terminated by '\n'.
+ *
+ * @see java.lang.Object#toString()
+ */
+ @Override
+ public String toString() {
+ StringBuilder builder = new StringBuilder();
+ builder.append("Input test table: ").append(testTable).append('\n');
+ builder.append("Algorithm: ").append(algorithm).append('\n');
+ builder.append("Report id: ").append(reportID).append('\n');
+ builder.append("Model id: ").append(modelID).append('\n');
+ builder.append("Lens Query id: ").append(queryID).append('\n');
+ builder.append("Feature columns: ").append(featureColumns).append('\n');
+ builder.append("Labelled column: ").append(labelColumn).append('\n');
+ builder.append("Predicted column: ").append(outputColumn).append('\n');
+ builder.append("Test output table: ").append(outputTable).append('\n');
+ return builder.toString();
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/dao/MLDBUtils.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/dao/MLDBUtils.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/dao/MLDBUtils.java
index 5e4d307..d444a32 100644
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/dao/MLDBUtils.java
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/dao/MLDBUtils.java
@@ -18,9 +18,9 @@
*/
package org.apache.lens.ml.dao;
-import org.apache.lens.ml.MLModel;
-import org.apache.lens.ml.MLTestReport;
-import org.apache.lens.ml.task.MLTask;
+import org.apache.lens.ml.algo.api.MLModel;
+import org.apache.lens.ml.api.MLTestReport;
+import org.apache.lens.ml.impl.MLTask;
public class MLDBUtils {
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/HiveMLUDF.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/HiveMLUDF.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/HiveMLUDF.java
new file mode 100644
index 0000000..60a4008
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/HiveMLUDF.java
@@ -0,0 +1,138 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.impl;
+
+import java.io.IOException;
+
+import org.apache.lens.ml.algo.api.MLModel;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.MapredContext;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
+import org.apache.hadoop.hive.serde2.lazy.LazyDouble;
+import org.apache.hadoop.hive.serde2.lazy.objectinspector.primitive.LazyDoubleObjectInspector;
+import org.apache.hadoop.hive.serde2.lazy.objectinspector.primitive.LazyPrimitiveObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector;
+import org.apache.hadoop.mapred.JobConf;
+
+/**
+ * Generic UDF to load ML Models saved in HDFS and apply the model on list of columns passed as argument.
+ */
+@Description(name = "predict",
+ value = "_FUNC_(algorithm, modelID, features...) - Run prediction algorithm with given "
+ + "algorithm name, model ID and input feature columns")
+public final class HiveMLUDF extends GenericUDF {
+ private HiveMLUDF() { // NOTE(review): private - confirm Hive can still instantiate this UDF reflectively
+ }
+
+ /** The name under which this UDF is exposed; also used in log and error messages. */
+ public static final String UDF_NAME = "predict";
+
+ /** The Constant LOG. */
+ public static final Log LOG = LogFactory.getLog(HiveMLUDF.class);
+
+ /** Job configuration captured in configure() and handed to ModelLoader when the model is loaded. */
+ private JobConf conf;
+
+ /** String inspector used to read the algorithm name and model id arguments; assigned in configure(). */
+ private StringObjectInspector soi;
+
+ /** Lazy double inspector used to unwrap the feature arguments; assigned in configure(). */
+ private LazyDoubleObjectInspector doi;
+
+ /** The model, loaded on the first evaluate() call and reused for every later call. */
+ private MLModel model;
+
+ /**
+ * Checks the argument count; currently we only support double as the return value.
+ *
+ * @param objectInspectors inspectors for the algo name, the model id and the features (at least three in total)
+ * @return the Java double object inspector
+ * @throws UDFArgumentException if fewer than three arguments are passed
+ */
+ @Override
+ public ObjectInspector initialize(ObjectInspector[] objectInspectors) throws UDFArgumentException {
+ // We require algo name, model id and at least one feature
+ if (objectInspectors.length < 3) {
+ throw new UDFArgumentLengthException("Algo name, model ID and at least one feature should be passed to "
+ + UDF_NAME);
+ }
+ LOG.info(UDF_NAME + " initialized");
+ return PrimitiveObjectInspectorFactory.javaDoubleObjectInspector;
+ }
+
+ /*
+ * Arguments 0 and 1 are the algorithm name and model id; the rest are features (null counts as 0). The model is
+ * loaded once and reused. NOTE(review): soi, doi and conf are only assigned in configure() - confirm it runs first.
+ * @see org.apache.hadoop.hive.ql.udf.generic.GenericUDF#evaluate(org.apache.hadoop.hive.ql.udf.generic.GenericUDF.
+ * DeferredObject[])
+ */
+ @Override
+ public Object evaluate(DeferredObject[] deferredObjects) throws HiveException {
+ String algorithm = soi.getPrimitiveJavaObject(deferredObjects[0].get());
+ String modelId = soi.getPrimitiveJavaObject(deferredObjects[1].get());
+
+ Double[] features = new Double[deferredObjects.length - 2];
+ for (int i = 2; i < deferredObjects.length; i++) {
+ LazyDouble lazyDouble = (LazyDouble) deferredObjects[i].get(); // NOTE(review): assumes every feature is a LazyDouble
+ features[i - 2] = (lazyDouble == null) ? 0d : doi.get(lazyDouble);
+ }
+
+ try {
+ if (model == null) { // cached after the first load, so algorithm/modelId of later calls are not consulted
+ model = ModelLoader.loadModel(conf, algorithm, modelId);
+ }
+ } catch (IOException e) {
+ throw new HiveException(e);
+ }
+
+ return model.predict(features);
+ }
+
+ /*
+ * Always returns UDF_NAME; the argument display strings are ignored.
+ *
+ * @see org.apache.hadoop.hive.ql.udf.generic.GenericUDF#getDisplayString(java.lang.String[])
+ */
+ @Override
+ public String getDisplayString(String[] strings) {
+ return UDF_NAME;
+ }
+
+ /*
+ * Captures the job conf from the context and sets up the string and lazy-double inspectors used by evaluate().
+ *
+ * @see org.apache.hadoop.hive.ql.udf.generic.GenericUDF#configure(org.apache.hadoop.hive.ql.exec.MapredContext)
+ */
+ @Override
+ public void configure(MapredContext context) {
+ super.configure(context);
+ conf = context.getJobConf();
+ soi = PrimitiveObjectInspectorFactory.javaStringObjectInspector;
+ doi = LazyPrimitiveObjectInspectorFactory.LAZY_DOUBLE_OBJECT_INSPECTOR;
+ LOG.info(UDF_NAME + " configured. Model base dir path: " + conf.get(ModelLoader.MODEL_PATH_BASE_DIR));
+ }
+}