You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lens.apache.org by sh...@apache.org on 2015/04/05 09:11:02 UTC
[1/6] incubator-lens git commit: Lens-465 : Refactor ml packages.
(sharad)
Repository: incubator-lens
Updated Branches:
refs/heads/master 278e0e857 -> 0f5ea4c78
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/models/LogitRegressionClassificationModel.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/models/LogitRegressionClassificationModel.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/models/LogitRegressionClassificationModel.java
deleted file mode 100644
index 1c5152b..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/models/LogitRegressionClassificationModel.java
+++ /dev/null
@@ -1,37 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml.spark.models;
-
-import org.apache.spark.mllib.classification.LogisticRegressionModel;
-
-/**
- * The Class LogitRegressionClassificationModel.
- */
-public class LogitRegressionClassificationModel extends BaseSparkClassificationModel<LogisticRegressionModel> {
-
- /**
- * Instantiates a new logit regression classification model.
- *
- * @param modelId the model id
- * @param model the model
- */
- public LogitRegressionClassificationModel(String modelId, LogisticRegressionModel model) {
- super(modelId, model);
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/models/NaiveBayesClassificationModel.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/models/NaiveBayesClassificationModel.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/models/NaiveBayesClassificationModel.java
deleted file mode 100644
index 8f4552c..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/models/NaiveBayesClassificationModel.java
+++ /dev/null
@@ -1,37 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml.spark.models;
-
-import org.apache.spark.mllib.classification.NaiveBayesModel;
-
-/**
- * The Class NaiveBayesClassificationModel.
- */
-public class NaiveBayesClassificationModel extends BaseSparkClassificationModel<NaiveBayesModel> {
-
- /**
- * Instantiates a new naive bayes classification model.
- *
- * @param modelId the model id
- * @param model the model
- */
- public NaiveBayesClassificationModel(String modelId, NaiveBayesModel model) {
- super(modelId, model);
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/models/SVMClassificationModel.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/models/SVMClassificationModel.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/models/SVMClassificationModel.java
deleted file mode 100644
index 4e504fb..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/models/SVMClassificationModel.java
+++ /dev/null
@@ -1,37 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml.spark.models;
-
-import org.apache.spark.mllib.classification.SVMModel;
-
-/**
- * The Class SVMClassificationModel.
- */
-public class SVMClassificationModel extends BaseSparkClassificationModel<SVMModel> {
-
- /**
- * Instantiates a new SVM classification model.
- *
- * @param modelId the model id
- * @param model the model
- */
- public SVMClassificationModel(String modelId, SVMModel model) {
- super(modelId, model);
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/models/SparkDecisionTreeModel.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/models/SparkDecisionTreeModel.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/models/SparkDecisionTreeModel.java
deleted file mode 100644
index 657070b..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/models/SparkDecisionTreeModel.java
+++ /dev/null
@@ -1,75 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml.spark.models;
-
-import org.apache.lens.ml.spark.DoubleValueMapper;
-
-import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.mllib.classification.ClassificationModel;
-import org.apache.spark.mllib.linalg.Vector;
-import org.apache.spark.mllib.tree.model.DecisionTreeModel;
-import org.apache.spark.rdd.RDD;
-
-/**
- * This class is created because the Spark decision tree model doesn't extend ClassificationModel.
- */
-public class SparkDecisionTreeModel implements ClassificationModel {
-
- /** The model. */
- private final DecisionTreeModel model;
-
- /**
- * Instantiates a new spark decision tree model.
- *
- * @param model the model
- */
- public SparkDecisionTreeModel(DecisionTreeModel model) {
- this.model = model;
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.spark.mllib.classification.ClassificationModel#predict(org.apache.spark.rdd.RDD)
- */
- @Override
- public RDD<Object> predict(RDD<Vector> testData) {
- return model.predict(testData);
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.spark.mllib.classification.ClassificationModel#predict(org.apache.spark.mllib.linalg.Vector)
- */
- @Override
- public double predict(Vector testData) {
- return model.predict(testData);
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.spark.mllib.classification.ClassificationModel#predict(org.apache.spark.api.java.JavaRDD)
- */
- @Override
- public JavaRDD<Double> predict(JavaRDD<Vector> testData) {
- return model.predict(testData.rdd()).toJavaRDD().map(new DoubleValueMapper());
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/task/MLTask.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/task/MLTask.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/task/MLTask.java
deleted file mode 100644
index e4bb329..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/task/MLTask.java
+++ /dev/null
@@ -1,286 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml.task;
-
-import java.util.*;
-
-import org.apache.lens.client.LensMLClient;
-import org.apache.lens.ml.LensML;
-import org.apache.lens.ml.MLTestReport;
-import org.apache.lens.ml.MLUtils;
-
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
-import org.apache.hadoop.hive.conf.HiveConf;
-
-import lombok.Getter;
-import lombok.ToString;
-
-/**
- * Run a complete cycle of train and test (evaluation) for an ML algorithm
- */
-@ToString
-public class MLTask implements Runnable {
- private static final Log LOG = LogFactory.getLog(MLTask.class);
-
- public enum State {
- RUNNING, SUCCESSFUL, FAILED
- }
-
- @Getter
- private State taskState;
-
- /**
- * Name of the algo/algorithm.
- */
- @Getter
- private String algorithm;
-
- /**
- * Name of the table containing training data.
- */
- @Getter
- private String trainingTable;
-
- /**
- * Name of the table containing test data. Optional, if not provided trainingTable itself is
- * used for testing
- */
- @Getter
- private String testTable;
-
- /**
- * Training table partition spec
- */
- @Getter
- private String partitionSpec;
-
- /**
- * Name of the column which is a label for supervised algorithms.
- */
- @Getter
- private String labelColumn;
-
- /**
- * Names of columns which are features in the training data.
- */
- @Getter
- private List<String> featureColumns;
-
- /**
- * Configuration for the example.
- */
- @Getter
- private HiveConf configuration;
-
- private LensML ml;
- private String taskID;
-
- /**
- * ml client
- */
- @Getter
- private LensMLClient mlClient;
-
- /**
- * Output table name
- */
- @Getter
- private String outputTable;
-
- /**
- * Extra params passed to the training algorithm
- */
- @Getter
- private Map<String, String> extraParams;
-
- @Getter
- private String modelID;
-
- @Getter
- private String reportID;
-
- /**
- * Use ExampleTask.Builder to create an instance
- */
- private MLTask() {
- // Use builder to construct the example
- extraParams = new HashMap<String, String>();
- taskID = UUID.randomUUID().toString();
- }
-
- /**
- * Builder to create an example task
- */
- public static class Builder {
- private MLTask task;
-
- public Builder() {
- task = new MLTask();
- }
-
- public Builder trainingTable(String trainingTable) {
- task.trainingTable = trainingTable;
- return this;
- }
-
- public Builder testTable(String testTable) {
- task.testTable = testTable;
- return this;
- }
-
- public Builder algorithm(String algorithm) {
- task.algorithm = algorithm;
- return this;
- }
-
- public Builder labelColumn(String labelColumn) {
- task.labelColumn = labelColumn;
- return this;
- }
-
- public Builder client(LensMLClient client) {
- task.mlClient = client;
- return this;
- }
-
- public Builder addFeatureColumn(String featureColumn) {
- if (task.featureColumns == null) {
- task.featureColumns = new ArrayList<String>();
- }
- task.featureColumns.add(featureColumn);
- return this;
- }
-
- public Builder hiveConf(HiveConf hiveConf) {
- task.configuration = hiveConf;
- return this;
- }
-
-
-
- public Builder extraParam(String param, String value) {
- task.extraParams.put(param, value);
- return this;
- }
-
- public Builder partitionSpec(String partitionSpec) {
- task.partitionSpec = partitionSpec;
- return this;
- }
-
- public Builder outputTable(String outputTable) {
- task.outputTable = outputTable;
- return this;
- }
-
- public MLTask build() {
- MLTask builtTask = task;
- task = null;
- return builtTask;
- }
-
- }
-
- @Override
- public void run() {
- taskState = State.RUNNING;
- LOG.info("Starting " + taskID);
- try {
- runTask();
- taskState = State.SUCCESSFUL;
- LOG.info("Complete " + taskID);
- } catch (Exception e) {
- taskState = State.FAILED;
- LOG.info("Error running task " + taskID, e);
- }
- }
-
- /**
- * Train an ML model, with specified algorithm and input data. Do model evaluation using the evaluation data and print
- * evaluation result
- *
- * @throws Exception
- */
- private void runTask() throws Exception {
- if (mlClient != null) {
- // Connect to a remote Lens server
- ml = mlClient;
- LOG.info("Working in client mode. Lens session handle " + mlClient.getSessionHandle().getPublicId());
- } else {
- // In server mode session handle has to be passed by the user as a request parameter
- ml = MLUtils.getMLService();
- LOG.info("Working in Lens server");
- }
-
- String[] algoArgs = buildTrainingArgs();
- LOG.info("Starting task " + taskID + " algo args: " + Arrays.toString(algoArgs));
-
- modelID = ml.train(trainingTable, algorithm, algoArgs);
- printModelMetadata(taskID, modelID);
-
- LOG.info("Starting test " + taskID);
- testTable = (testTable != null) ? testTable : trainingTable;
- MLTestReport testReport = ml.testModel(mlClient.getSessionHandle(), testTable, algorithm, modelID, outputTable);
- reportID = testReport.getReportID();
- printTestReport(taskID, testReport);
- saveTask();
- }
-
- // Save task metadata to DB
- private void saveTask() {
- LOG.info("Saving task details to DB");
- }
-
- private void printTestReport(String exampleID, MLTestReport testReport) {
- StringBuilder builder = new StringBuilder("Example: ").append(exampleID);
- builder.append("\n\t");
- builder.append("EvaluationReport: ").append(testReport.toString());
- System.out.println(builder.toString());
- }
-
- private String[] buildTrainingArgs() {
- List<String> argList = new ArrayList<String>();
- argList.add("label");
- argList.add(labelColumn);
-
- // Add all the features
- for (String featureCol : featureColumns) {
- argList.add("feature");
- argList.add(featureCol);
- }
-
- // Add extra params
- for (String param : extraParams.keySet()) {
- argList.add(param);
- argList.add(extraParams.get(param));
- }
-
- return argList.toArray(new String[argList.size()]);
- }
-
- // Get the model instance and print its metadat to stdout
- private void printModelMetadata(String exampleID, String modelID) throws Exception {
- StringBuilder builder = new StringBuilder("Example: ").append(exampleID);
- builder.append("\n\t");
- builder.append("Model: ");
- builder.append(ml.getModel(algorithm, modelID).toString());
- System.out.println(builder.toString());
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/rdd/LensRDDClient.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/rdd/LensRDDClient.java b/lens-ml-lib/src/main/java/org/apache/lens/rdd/LensRDDClient.java
index 2c2d28b..ac89eee 100644
--- a/lens-ml-lib/src/main/java/org/apache/lens/rdd/LensRDDClient.java
+++ b/lens-ml-lib/src/main/java/org/apache/lens/rdd/LensRDDClient.java
@@ -29,7 +29,7 @@ import org.apache.lens.api.LensException;
import org.apache.lens.api.query.*;
import org.apache.lens.client.LensClient;
import org.apache.lens.client.LensClientResultSet;
-import org.apache.lens.ml.spark.HiveTableRDD;
+import org.apache.lens.ml.algo.spark.HiveTableRDD;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/server/ml/MLApp.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/server/ml/MLApp.java b/lens-ml-lib/src/main/java/org/apache/lens/server/ml/MLApp.java
deleted file mode 100644
index 75d4f03..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/server/ml/MLApp.java
+++ /dev/null
@@ -1,60 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.server.ml;
-
-import java.util.HashSet;
-import java.util.Set;
-
-import javax.ws.rs.ApplicationPath;
-import javax.ws.rs.core.Application;
-
-import org.glassfish.jersey.filter.LoggingFilter;
-import org.glassfish.jersey.media.multipart.MultiPartFeature;
-
-@ApplicationPath("/ml")
-public class MLApp extends Application {
-
- private final Set<Class<?>> classes;
-
- /**
- * Pass additional classes when running in test mode
- *
- * @param additionalClasses
- */
- public MLApp(Class<?>... additionalClasses) {
- classes = new HashSet<Class<?>>();
-
- // register root resource
- classes.add(MLServiceResource.class);
- classes.add(MultiPartFeature.class);
- classes.add(LoggingFilter.class);
- for (Class<?> cls : additionalClasses) {
- classes.add(cls);
- }
-
- }
-
- /**
- * Get classes for this resource
- */
- @Override
- public Set<Class<?>> getClasses() {
- return classes;
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/server/ml/MLService.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/server/ml/MLService.java b/lens-ml-lib/src/main/java/org/apache/lens/server/ml/MLService.java
deleted file mode 100644
index 0dac605..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/server/ml/MLService.java
+++ /dev/null
@@ -1,27 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.server.ml;
-
-import org.apache.lens.ml.LensML;
-
-/**
- * The Interface MLService.
- */
-public interface MLService extends LensML {
-}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/server/ml/MLServiceImpl.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/server/ml/MLServiceImpl.java b/lens-ml-lib/src/main/java/org/apache/lens/server/ml/MLServiceImpl.java
deleted file mode 100644
index 0e8e9aa..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/server/ml/MLServiceImpl.java
+++ /dev/null
@@ -1,324 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.server.ml;
-
-import java.util.List;
-import java.util.Map;
-
-import org.apache.lens.api.LensConf;
-import org.apache.lens.api.LensException;
-import org.apache.lens.api.LensSessionHandle;
-import org.apache.lens.api.query.LensQuery;
-import org.apache.lens.api.query.QueryHandle;
-import org.apache.lens.api.query.QueryStatus;
-import org.apache.lens.ml.*;
-import org.apache.lens.server.api.LensConfConstants;
-import org.apache.lens.server.api.ServiceProvider;
-import org.apache.lens.server.api.ServiceProviderFactory;
-import org.apache.lens.server.api.query.QueryExecutionService;
-
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
-import org.apache.hadoop.hive.conf.HiveConf;
-import org.apache.hadoop.hive.ql.exec.FunctionRegistry;
-import org.apache.hive.service.CompositeService;
-
-/**
- * The Class MLServiceImpl.
- */
-public class MLServiceImpl extends CompositeService implements MLService {
-
- /** The Constant LOG. */
- public static final Log LOG = LogFactory.getLog(LensMLImpl.class);
-
- /** The ml. */
- private LensMLImpl ml;
-
- /** The service provider. */
- private ServiceProvider serviceProvider;
-
- /** The service provider factory. */
- private ServiceProviderFactory serviceProviderFactory;
-
- /**
- * Instantiates a new ML service impl.
- */
- public MLServiceImpl() {
- this(NAME);
- }
-
- /**
- * Instantiates a new ML service impl.
- *
- * @param name the name
- */
- public MLServiceImpl(String name) {
- super(name);
- }
-
- @Override
- public List<String> getAlgorithms() {
- return ml.getAlgorithms();
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.LensML#getAlgoForName(java.lang.String)
- */
- @Override
- public MLAlgo getAlgoForName(String algorithm) throws LensException {
- return ml.getAlgoForName(algorithm);
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.LensML#train(java.lang.String, java.lang.String, java.lang.String[])
- */
- @Override
- public String train(String table, String algorithm, String[] args) throws LensException {
- return ml.train(table, algorithm, args);
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.LensML#getModels(java.lang.String)
- */
- @Override
- public List<String> getModels(String algorithm) throws LensException {
- return ml.getModels(algorithm);
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.LensML#getModel(java.lang.String, java.lang.String)
- */
- @Override
- public MLModel getModel(String algorithm, String modelId) throws LensException {
- return ml.getModel(algorithm, modelId);
- }
-
- private ServiceProvider getServiceProvider() {
- if (serviceProvider == null) {
- serviceProvider = serviceProviderFactory.getServiceProvider();
- }
- return serviceProvider;
- }
-
- /**
- * Gets the service provider factory.
- *
- * @param conf the conf
- * @return the service provider factory
- */
- private ServiceProviderFactory getServiceProviderFactory(HiveConf conf) {
- Class<?> spfClass = conf.getClass(LensConfConstants.SERVICE_PROVIDER_FACTORY, ServiceProviderFactory.class);
- try {
- return (ServiceProviderFactory) spfClass.newInstance();
- } catch (InstantiationException e) {
- throw new RuntimeException(e);
- } catch (IllegalAccessException e) {
- throw new RuntimeException(e);
- }
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.hive.service.CompositeService#init(org.apache.hadoop.hive.conf.HiveConf)
- */
- @Override
- public synchronized void init(HiveConf hiveConf) {
- ml = new LensMLImpl(hiveConf);
- ml.init(hiveConf);
- super.init(hiveConf);
- serviceProviderFactory = getServiceProviderFactory(hiveConf);
- LOG.info("Inited ML service");
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.hive.service.CompositeService#start()
- */
- @Override
- public synchronized void start() {
- ml.start();
- super.start();
- LOG.info("Started ML service");
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.hive.service.CompositeService#stop()
- */
- @Override
- public synchronized void stop() {
- ml.stop();
- super.stop();
- LOG.info("Stopped ML service");
- }
-
- /**
- * Clear models.
- */
- public void clearModels() {
- ModelLoader.clearCache();
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.LensML#getModelPath(java.lang.String, java.lang.String)
- */
- @Override
- public String getModelPath(String algorithm, String modelID) {
- return ml.getModelPath(algorithm, modelID);
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.LensML#testModel(org.apache.lens.api.LensSessionHandle, java.lang.String, java.lang.String,
- * java.lang.String)
- */
- @Override
- public MLTestReport testModel(LensSessionHandle sessionHandle, String table, String algorithm, String modelID,
- String outputTable) throws LensException {
- return ml.testModel(sessionHandle, table, algorithm, modelID, new DirectQueryRunner(sessionHandle), outputTable);
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.LensML#getTestReports(java.lang.String)
- */
- @Override
- public List<String> getTestReports(String algorithm) throws LensException {
- return ml.getTestReports(algorithm);
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.LensML#getTestReport(java.lang.String, java.lang.String)
- */
- @Override
- public MLTestReport getTestReport(String algorithm, String reportID) throws LensException {
- return ml.getTestReport(algorithm, reportID);
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.LensML#predict(java.lang.String, java.lang.String, java.lang.Object[])
- */
- @Override
- public Object predict(String algorithm, String modelID, Object[] features) throws LensException {
- return ml.predict(algorithm, modelID, features);
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.LensML#deleteModel(java.lang.String, java.lang.String)
- */
- @Override
- public void deleteModel(String algorithm, String modelID) throws LensException {
- ml.deleteModel(algorithm, modelID);
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.LensML#deleteTestReport(java.lang.String, java.lang.String)
- */
- @Override
- public void deleteTestReport(String algorithm, String reportID) throws LensException {
- ml.deleteTestReport(algorithm, reportID);
- }
-
- /**
- * Run the test model query directly in the current lens server process.
- */
- private class DirectQueryRunner extends QueryRunner {
-
- /**
- * Instantiates a new direct query runner.
- *
- * @param sessionHandle the session handle
- */
- public DirectQueryRunner(LensSessionHandle sessionHandle) {
- super(sessionHandle);
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.TestQueryRunner#runQuery(java.lang.String)
- */
- @Override
- public QueryHandle runQuery(String testQuery) throws LensException {
- FunctionRegistry.registerTemporaryFunction("predict", HiveMLUDF.class);
- LOG.info("Registered predict UDF");
- // Run the query in query executions service
- QueryExecutionService queryService = (QueryExecutionService) getServiceProvider().getService("query");
-
- LensConf queryConf = new LensConf();
- queryConf.addProperty(LensConfConstants.QUERY_PERSISTENT_RESULT_SET, false + "");
- queryConf.addProperty(LensConfConstants.QUERY_PERSISTENT_RESULT_INDRIVER, false + "");
-
- QueryHandle testQueryHandle = queryService.executeAsync(sessionHandle, testQuery, queryConf, queryName);
-
- // Wait for test query to complete
- LensQuery query = queryService.getQuery(sessionHandle, testQueryHandle);
- LOG.info("Submitted query " + testQueryHandle.getHandleId());
- while (!query.getStatus().isFinished()) {
- try {
- Thread.sleep(500);
- } catch (InterruptedException e) {
- throw new LensException(e);
- }
-
- query = queryService.getQuery(sessionHandle, testQueryHandle);
- }
-
- if (query.getStatus().getStatus() != QueryStatus.Status.SUCCESSFUL) {
- throw new LensException("Failed to run test query: " + testQueryHandle.getHandleId() + " reason= "
- + query.getStatus().getErrorMessage());
- }
-
- return testQueryHandle;
- }
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.LensML#getAlgoParamDescription(java.lang.String)
- */
- @Override
- public Map<String, String> getAlgoParamDescription(String algorithm) {
- return ml.getAlgoParamDescription(algorithm);
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/server/ml/MLServiceResource.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/server/ml/MLServiceResource.java b/lens-ml-lib/src/main/java/org/apache/lens/server/ml/MLServiceResource.java
deleted file mode 100644
index c0b32d3..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/server/ml/MLServiceResource.java
+++ /dev/null
@@ -1,415 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.server.ml;
-
-import static org.apache.commons.lang.StringUtils.isBlank;
-
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Map;
-import java.util.Set;
-
-import javax.ws.rs.*;
-import javax.ws.rs.core.*;
-
-import org.apache.lens.api.LensException;
-import org.apache.lens.api.LensSessionHandle;
-import org.apache.lens.api.StringList;
-import org.apache.lens.api.ml.ModelMetadata;
-import org.apache.lens.api.ml.TestReport;
-import org.apache.lens.ml.MLModel;
-import org.apache.lens.ml.MLTestReport;
-import org.apache.lens.ml.ModelLoader;
-import org.apache.lens.server.api.LensConfConstants;
-import org.apache.lens.server.api.ServiceProvider;
-import org.apache.lens.server.api.ServiceProviderFactory;
-
-import org.apache.commons.lang.StringUtils;
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
-import org.apache.hadoop.hive.conf.HiveConf;
-
-import org.glassfish.jersey.media.multipart.FormDataParam;
-
-/**
- * Machine Learning service.
- */
-@Path("/ml")
-@Produces({MediaType.APPLICATION_JSON, MediaType.APPLICATION_XML})
-public class MLServiceResource {
-
- /** The Constant LOG. */
- public static final Log LOG = LogFactory.getLog(MLServiceResource.class);
-
- /** The ml service. */
- MLService mlService;
-
- /** The service provider. */
- ServiceProvider serviceProvider;
-
- /** The service provider factory. */
- ServiceProviderFactory serviceProviderFactory;
-
- private static final HiveConf HIVE_CONF;
-
- /**
- * Message indicating if ML service is up
- */
- public static final String ML_UP_MESSAGE = "ML service is up";
-
- static {
- HIVE_CONF = new HiveConf();
- // Add default config so that we know the service provider implementation
- HIVE_CONF.addResource("lensserver-default.xml");
- HIVE_CONF.addResource("lens-site.xml");
- }
-
- /**
- * Instantiates a new ML service resource.
- */
- public MLServiceResource() {
- serviceProviderFactory = getServiceProviderFactory(HIVE_CONF);
- }
-
- private ServiceProvider getServiceProvider() {
- if (serviceProvider == null) {
- serviceProvider = serviceProviderFactory.getServiceProvider();
- }
- return serviceProvider;
- }
-
- /**
- * Gets the service provider factory.
- *
- * @param conf the conf
- * @return the service provider factory
- */
- private ServiceProviderFactory getServiceProviderFactory(HiveConf conf) {
- Class<?> spfClass = conf.getClass(LensConfConstants.SERVICE_PROVIDER_FACTORY, ServiceProviderFactory.class);
- try {
- return (ServiceProviderFactory) spfClass.newInstance();
- } catch (InstantiationException e) {
- throw new RuntimeException(e);
- } catch (IllegalAccessException e) {
- throw new RuntimeException(e);
- }
- }
-
- private MLService getMlService() {
- if (mlService == null) {
- mlService = (MLService) getServiceProvider().getService(MLService.NAME);
- }
- return mlService;
- }
-
- /**
- * Indicates if ML resource is up
- *
- * @return
- */
- @GET
- public String mlResourceUp() {
- return ML_UP_MESSAGE;
- }
-
- /**
- * Get a list of algos available
- *
- * @return
- */
- @GET
- @Path("algos")
- public StringList getAlgoNames() {
- List<String> algos = getMlService().getAlgorithms();
- StringList result = new StringList(algos);
- return result;
- }
-
- /**
- * Gets the human readable param description of an algorithm
- *
- * @param algorithm the algorithm
- * @return the param description
- */
- @GET
- @Path("algos/{algorithm}")
- public StringList getParamDescription(@PathParam("algorithm") String algorithm) {
- Map<String, String> paramDesc = getMlService().getAlgoParamDescription(algorithm);
- if (paramDesc == null) {
- throw new NotFoundException("Param description not found for " + algorithm);
- }
-
- List<String> descriptions = new ArrayList<String>();
- for (String key : paramDesc.keySet()) {
- descriptions.add(key + " : " + paramDesc.get(key));
- }
- return new StringList(descriptions);
- }
-
- /**
- * Get model ID list for a given algorithm.
- *
- * @param algorithm algorithm name
- * @return the models for algo
- * @throws LensException the lens exception
- */
- @GET
- @Path("models/{algorithm}")
- public StringList getModelsForAlgo(@PathParam("algorithm") String algorithm) throws LensException {
- List<String> models = getMlService().getModels(algorithm);
- if (models == null || models.isEmpty()) {
- throw new NotFoundException("No models found for algorithm " + algorithm);
- }
- return new StringList(models);
- }
-
- /**
- * Get metadata of the model given algorithm and model ID.
- *
- * @param algorithm algorithm name
- * @param modelID model ID
- * @return model metadata
- * @throws LensException the lens exception
- */
- @GET
- @Path("models/{algorithm}/{modelID}")
- public ModelMetadata getModelMetadata(@PathParam("algorithm") String algorithm, @PathParam("modelID") String modelID)
- throws LensException {
- MLModel model = getMlService().getModel(algorithm, modelID);
- if (model == null) {
- throw new NotFoundException("Model not found " + modelID + ", algo=" + algorithm);
- }
-
- ModelMetadata meta = new ModelMetadata(model.getId(), model.getTable(), model.getAlgoName(), StringUtils.join(
- model.getParams(), ' '), model.getCreatedAt().toString(), getMlService().getModelPath(algorithm, modelID),
- model.getLabelColumn(), StringUtils.join(model.getFeatureColumns(), ","));
- return meta;
- }
-
- /**
- * Delete a model given model ID and algorithm name.
- *
- * @param algorithm the algorithm
- * @param modelID the model id
- * @return confirmation text
- * @throws LensException the lens exception
- */
- @DELETE
- @Consumes({MediaType.APPLICATION_JSON, MediaType.APPLICATION_XML, MediaType.TEXT_PLAIN})
- @Path("models/{algorithm}/{modelID}")
- public String deleteModel(@PathParam("algorithm") String algorithm, @PathParam("modelID") String modelID)
- throws LensException {
- getMlService().deleteModel(algorithm, modelID);
- return "DELETED model=" + modelID + " algorithm=" + algorithm;
- }
-
- /**
- * Train a model given an algorithm name and algorithm parameters
- * <p>
- * Following parameters are mandatory and must be passed as part of the form
- * <p/>
- * <ol>
- * <li>table - input Hive table to load training data from</li>
- * <li>label - name of the labelled column</li>
- * <li>feature - one entry per feature column. At least one feature column is required</li>
- * </ol>
- * <p/>
- * </p>
- *
- * @param algorithm algorithm name
- * @param form form data
- * @return if model is successfully trained, the model ID will be returned
- * @throws LensException the lens exception
- */
- @POST
- @Consumes(MediaType.APPLICATION_FORM_URLENCODED)
- @Path("{algorithm}/train")
- public String train(@PathParam("algorithm") String algorithm, MultivaluedMap<String, String> form)
- throws LensException {
-
- // Check if algo is valid
- if (getMlService().getAlgoForName(algorithm) == null) {
- throw new NotFoundException("Algo for algo: " + algorithm + " not found");
- }
-
- if (isBlank(form.getFirst("table"))) {
- throw new BadRequestException("table parameter is rquired");
- }
-
- String table = form.getFirst("table");
-
- if (isBlank(form.getFirst("label"))) {
- throw new BadRequestException("label parameter is required");
- }
-
- // Check features
- List<String> featureNames = form.get("feature");
- if (featureNames.size() < 1) {
- throw new BadRequestException("At least one feature is required");
- }
-
- List<String> algoArgs = new ArrayList<String>();
- Set<Map.Entry<String, List<String>>> paramSet = form.entrySet();
-
- for (Map.Entry<String, List<String>> e : paramSet) {
- String p = e.getKey();
- List<String> values = e.getValue();
- if ("algorithm".equals(p) || "table".equals(p)) {
- continue;
- } else if ("feature".equals(p)) {
- for (String feature : values) {
- algoArgs.add("feature");
- algoArgs.add(feature);
- }
- } else if ("label".equals(p)) {
- algoArgs.add("label");
- algoArgs.add(values.get(0));
- } else {
- algoArgs.add(p);
- algoArgs.add(values.get(0));
- }
- }
- LOG.info("Training table " + table + " with algo " + algorithm + " params=" + algoArgs.toString());
- String modelId = getMlService().train(table, algorithm, algoArgs.toArray(new String[]{}));
- LOG.info("Done training " + table + " modelid = " + modelId);
- return modelId;
- }
-
- /**
- * Clear model cache (for admin use).
- *
- * @return OK if the cache was cleared
- */
- @DELETE
- @Path("clearModelCache")
- @Produces(MediaType.TEXT_PLAIN)
- public Response clearModelCache() {
- ModelLoader.clearCache();
- LOG.info("Cleared model cache");
- return Response.ok("Cleared cache", MediaType.TEXT_PLAIN_TYPE).build();
- }
-
- /**
- * Run a test on a model for an algorithm.
- *
- * @param algorithm algorithm name
- * @param modelID model ID
- * @param table Hive table to run test on
- * @param session Lens session ID. This session ID will be used to run the test query
- * @return Test report ID
- * @throws LensException the lens exception
- */
- @POST
- @Path("test/{table}/{algorithm}/{modelID}")
- @Consumes(MediaType.MULTIPART_FORM_DATA)
- public String test(@PathParam("algorithm") String algorithm, @PathParam("modelID") String modelID,
- @PathParam("table") String table, @FormDataParam("sessionid") LensSessionHandle session,
- @FormDataParam("outputTable") String outputTable) throws LensException {
- MLTestReport testReport = getMlService().testModel(session, table, algorithm, modelID, outputTable);
- return testReport.getReportID();
- }
-
- /**
- * Get list of reports for a given algorithm.
- *
- * @param algoritm the algoritm
- * @return the reports for algorithm
- * @throws LensException the lens exception
- */
- @GET
- @Path("reports/{algorithm}")
- public StringList getReportsForAlgorithm(@PathParam("algorithm") String algoritm) throws LensException {
- List<String> reports = getMlService().getTestReports(algoritm);
- if (reports == null || reports.isEmpty()) {
- throw new NotFoundException("No test reports found for " + algoritm);
- }
- return new StringList(reports);
- }
-
- /**
- * Get a single test report given the algorithm name and report id.
- *
- * @param algorithm the algorithm
- * @param reportID the report id
- * @return the test report
- * @throws LensException the lens exception
- */
- @GET
- @Path("reports/{algorithm}/{reportID}")
- public TestReport getTestReport(@PathParam("algorithm") String algorithm, @PathParam("reportID") String reportID)
- throws LensException {
- MLTestReport report = getMlService().getTestReport(algorithm, reportID);
-
- if (report == null) {
- throw new NotFoundException("Test report: " + reportID + " not found for algorithm " + algorithm);
- }
-
- TestReport result = new TestReport(report.getTestTable(), report.getOutputTable(), report.getOutputColumn(),
- report.getLabelColumn(), StringUtils.join(report.getFeatureColumns(), ","), report.getAlgorithm(),
- report.getModelID(), report.getReportID(), report.getLensQueryID());
- return result;
- }
-
- /**
- * DELETE a report given the algorithm name and report ID.
- *
- * @param algorithm the algorithm
- * @param reportID the report id
- * @return the string
- * @throws LensException the lens exception
- */
- @DELETE
- @Path("reports/{algorithm}/{reportID}")
- @Consumes({MediaType.APPLICATION_JSON, MediaType.APPLICATION_XML, MediaType.TEXT_PLAIN})
- public String deleteTestReport(@PathParam("algorithm") String algorithm, @PathParam("reportID") String reportID)
- throws LensException {
- getMlService().deleteTestReport(algorithm, reportID);
- return "DELETED report=" + reportID + " algorithm=" + algorithm;
- }
-
- /**
- * Predict.
- *
- * @param algorithm the algorithm
- * @param modelID the model id
- * @param uriInfo the uri info
- * @return the string
- * @throws LensException the lens exception
- */
- @GET
- @Path("/predict/{algorithm}/{modelID}")
- @Produces({MediaType.APPLICATION_ATOM_XML, MediaType.APPLICATION_JSON})
- public String predict(@PathParam("algorithm") String algorithm, @PathParam("modelID") String modelID,
- @Context UriInfo uriInfo) throws LensException {
- // Load the model instance
- MLModel<?> model = getMlService().getModel(algorithm, modelID);
-
- // Get input feature names
- MultivaluedMap<String, String> params = uriInfo.getQueryParameters();
- String[] features = new String[model.getFeatureColumns().size()];
- // Assuming that feature name parameters are same
- int i = 0;
- for (String feature : model.getFeatureColumns()) {
- features[i++] = params.getFirst(feature);
- }
-
- // TODO needs a 'prediction formatter'
- return getMlService().predict(algorithm, modelID, features).toString();
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/test/java/org/apache/lens/ml/ExampleUtils.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/test/java/org/apache/lens/ml/ExampleUtils.java b/lens-ml-lib/src/test/java/org/apache/lens/ml/ExampleUtils.java
new file mode 100644
index 0000000..9fe1ea0
--- /dev/null
+++ b/lens-ml-lib/src/test/java/org/apache/lens/ml/ExampleUtils.java
@@ -0,0 +1,101 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.hive.conf.HiveConf;
+import org.apache.hadoop.hive.metastore.TableType;
+import org.apache.hadoop.hive.metastore.api.FieldSchema;
+import org.apache.hadoop.hive.ql.metadata.Hive;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.metadata.Table;
+import org.apache.hadoop.hive.ql.plan.AddPartitionDesc;
+import org.apache.hadoop.hive.serde.serdeConstants;
+import org.apache.hadoop.mapred.TextInputFormat;
+
+/**
+ * The Class ExampleUtils.
+ */
+public final class ExampleUtils {
+ private ExampleUtils() {
+ }
+
+ private static final Log LOG = LogFactory.getLog(ExampleUtils.class);
+
+ /**
+ * Creates the example table and adds the directory holding the sample data file as a partition.
+ *
+ * @param conf the conf
+ * @param database the database
+ * @param tableName the table name
+ * @param sampleDataFile the sample data file
+ * @param labelColumn the label column; when null the table gets feature columns only
+ * @param tableParams extra table parameters to set on the table; null is treated as none
+ * @param features the features
+ * @throws HiveException the hive exception
+ */
+ public static void createTable(HiveConf conf, String database, String tableName, String sampleDataFile,
+ String labelColumn, Map<String, String> tableParams, String... features) throws HiveException {
+ Path dataFilePath = new Path(sampleDataFile);
+ Path partDir = dataFilePath.getParent();
+
+ // Create table
+ List<FieldSchema> columns = new ArrayList<FieldSchema>();
+ // Label is optional (unused for unsupervised models); if present it is the first column, then features
+ if (labelColumn != null) {
+ columns.add(new FieldSchema(labelColumn, "double", "Labelled Column"));
+ }
+ for (String feature : features) {
+ columns.add(new FieldSchema(feature, "double", "Feature " + feature));
+ }
+
+ Table tbl = Hive.get(conf).newTable(database + "." + tableName);
+ tbl.setTableType(TableType.MANAGED_TABLE);
+ tbl.getTTable().getSd().setCols(columns);
+ // putAll would throw on a null map, so a null tableParams is treated as "no extra parameters".
+ if (tableParams != null) {
+ tbl.getTTable().getParameters().putAll(tableParams);
+ }
+ tbl.setInputFormatClass(TextInputFormat.class);
+ tbl.setSerdeParam(serdeConstants.LINE_DELIM, "\n");
+ tbl.setSerdeParam(serdeConstants.FIELD_DELIM, " ");
+
+ List<FieldSchema> partCols = new ArrayList<FieldSchema>(1);
+ partCols.add(new FieldSchema("dummy_partition_col", "string", ""));
+ tbl.setPartCols(partCols);
+
+ Hive.get(conf).createTable(tbl, false);
+ LOG.info("Created table " + tableName);
+
+ // Add partition for the data file
+ AddPartitionDesc partitionDesc = new AddPartitionDesc(database, tableName, false);
+ Map<String, String> partSpec = new HashMap<String, String>();
+ partSpec.put("dummy_partition_col", "dummy_val");
+ partitionDesc.addPartition(partSpec, partDir.toUri().toString());
+ Hive.get(conf).createPartitions(partitionDesc);
+ LOG.info(tableName + ": Added partition " + partDir.toUri().toString());
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/test/java/org/apache/lens/ml/TestMLResource.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/test/java/org/apache/lens/ml/TestMLResource.java b/lens-ml-lib/src/test/java/org/apache/lens/ml/TestMLResource.java
index f712481..8b7e3f3 100644
--- a/lens-ml-lib/src/test/java/org/apache/lens/ml/TestMLResource.java
+++ b/lens-ml-lib/src/test/java/org/apache/lens/ml/TestMLResource.java
@@ -33,15 +33,16 @@ import javax.ws.rs.core.UriBuilder;
import org.apache.lens.client.LensClient;
import org.apache.lens.client.LensClientConfig;
import org.apache.lens.client.LensMLClient;
-import org.apache.lens.ml.spark.algos.DecisionTreeAlgo;
-import org.apache.lens.ml.spark.algos.LogisticRegressionAlgo;
-import org.apache.lens.ml.spark.algos.NaiveBayesAlgo;
-import org.apache.lens.ml.spark.algos.SVMAlgo;
-import org.apache.lens.ml.task.MLTask;
+import org.apache.lens.ml.algo.spark.dt.DecisionTreeAlgo;
+import org.apache.lens.ml.algo.spark.lr.LogisticRegressionAlgo;
+import org.apache.lens.ml.algo.spark.nb.NaiveBayesAlgo;
+import org.apache.lens.ml.algo.spark.svm.SVMAlgo;
+import org.apache.lens.ml.impl.MLTask;
+import org.apache.lens.ml.impl.MLUtils;
+import org.apache.lens.ml.server.MLApp;
+import org.apache.lens.ml.server.MLServiceResource;
import org.apache.lens.server.LensJerseyTest;
import org.apache.lens.server.api.LensConfConstants;
-import org.apache.lens.server.ml.MLApp;
-import org.apache.lens.server.ml.MLServiceResource;
import org.apache.lens.server.query.QueryServiceResource;
import org.apache.lens.server.session.SessionResource;
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/test/java/org/apache/lens/ml/TestMLRunner.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/test/java/org/apache/lens/ml/TestMLRunner.java b/lens-ml-lib/src/test/java/org/apache/lens/ml/TestMLRunner.java
index d7f2f8f..655b55e 100644
--- a/lens-ml-lib/src/test/java/org/apache/lens/ml/TestMLRunner.java
+++ b/lens-ml-lib/src/test/java/org/apache/lens/ml/TestMLRunner.java
@@ -26,23 +26,24 @@ import javax.ws.rs.core.UriBuilder;
import org.apache.lens.client.LensClient;
import org.apache.lens.client.LensClientConfig;
import org.apache.lens.client.LensMLClient;
-import org.apache.lens.ml.task.MLTask;
+import org.apache.lens.ml.impl.MLRunner;
+import org.apache.lens.ml.impl.MLTask;
+import org.apache.lens.ml.server.MLApp;
import org.apache.lens.server.LensJerseyTest;
import org.apache.lens.server.api.LensConfConstants;
import org.apache.lens.server.metastore.MetastoreResource;
-import org.apache.lens.server.ml.MLApp;
import org.apache.lens.server.query.QueryServiceResource;
import org.apache.lens.server.session.SessionResource;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
-
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.hadoop.hive.metastore.api.Database;
import org.apache.hadoop.hive.ql.metadata.Hive;
import org.glassfish.jersey.client.ClientConfig;
import org.glassfish.jersey.media.multipart.MultiPartFeature;
+
import org.testng.Assert;
import org.testng.annotations.AfterTest;
import org.testng.annotations.BeforeTest;
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/test/resources/lens-site.xml
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/test/resources/lens-site.xml b/lens-ml-lib/src/test/resources/lens-site.xml
index 9ce4703..2e1ddab 100644
--- a/lens-ml-lib/src/test/resources/lens-site.xml
+++ b/lens-ml-lib/src/test/resources/lens-site.xml
@@ -103,7 +103,7 @@
<property>
<name>lens.server.ml.ws.resource.impl</name>
- <value>org.apache.lens.server.ml.MLServiceResource</value>
+ <value>org.apache.lens.ml.server.MLServiceResource</value>
<description>Implementation class for ML Service Resource</description>
</property>
@@ -138,13 +138,13 @@
<property>
<name>lens.server.ml.service.impl</name>
- <value>org.apache.lens.server.ml.MLServiceImpl</value>
+ <value>org.apache.lens.ml.server.MLServiceImpl</value>
<description>Implementation class for ML service</description>
</property>
<property>
<name>lens.ml.drivers</name>
- <value>org.apache.lens.ml.spark.SparkMLDriver</value>
+ <value>org.apache.lens.ml.algo.spark.SparkMLDriver</value>
</property>
<property>
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/tools/conf-pseudo-distr/server/lens-site.xml
----------------------------------------------------------------------
diff --git a/tools/conf-pseudo-distr/server/lens-site.xml b/tools/conf-pseudo-distr/server/lens-site.xml
index f11c6d6..ce3e753 100644
--- a/tools/conf-pseudo-distr/server/lens-site.xml
+++ b/tools/conf-pseudo-distr/server/lens-site.xml
@@ -59,19 +59,19 @@
<property>
<name>lens.server.ml.ws.resource.impl</name>
- <value>org.apache.lens.server.ml.MLServiceResource</value>
+ <value>org.apache.lens.ml.server.MLServiceResource</value>
<description>Implementation class for ML Service Resource</description>
</property>
<property>
<name>lens.server.ml.service.impl</name>
- <value>org.apache.lens.server.ml.MLServiceImpl</value>
+ <value>org.apache.lens.ml.server.MLServiceImpl</value>
<description>Implementation class for ML service</description>
</property>
<property>
<name>lens.ml.drivers</name>
- <value>org.apache.lens.ml.spark.SparkMLDriver</value>
+ <value>org.apache.lens.ml.algo.spark.SparkMLDriver</value>
</property>
<property>
[2/6] incubator-lens git commit: Lens-465 : Refactor ml packages.
(sharad)
Posted by sh...@apache.org.
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/server/MLServiceResource.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/server/MLServiceResource.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/server/MLServiceResource.java
new file mode 100644
index 0000000..f9c954e
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/server/MLServiceResource.java
@@ -0,0 +1,427 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.server;
+
+import static org.apache.commons.lang.StringUtils.isBlank;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+import javax.ws.rs.BadRequestException;
+import javax.ws.rs.Consumes;
+import javax.ws.rs.DELETE;
+import javax.ws.rs.GET;
+import javax.ws.rs.NotFoundException;
+import javax.ws.rs.POST;
+import javax.ws.rs.Path;
+import javax.ws.rs.PathParam;
+import javax.ws.rs.Produces;
+import javax.ws.rs.core.Context;
+import javax.ws.rs.core.MediaType;
+import javax.ws.rs.core.MultivaluedMap;
+import javax.ws.rs.core.Response;
+import javax.ws.rs.core.UriInfo;
+
+import org.apache.lens.api.LensException;
+import org.apache.lens.api.LensSessionHandle;
+import org.apache.lens.api.StringList;
+import org.apache.lens.ml.algo.api.MLModel;
+import org.apache.lens.ml.api.MLTestReport;
+import org.apache.lens.ml.api.ModelMetadata;
+import org.apache.lens.ml.api.TestReport;
+import org.apache.lens.ml.impl.ModelLoader;
+import org.apache.lens.server.api.LensConfConstants;
+import org.apache.lens.server.api.ServiceProvider;
+import org.apache.lens.server.api.ServiceProviderFactory;
+
+import org.apache.commons.lang.StringUtils;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.hive.conf.HiveConf;
+
+import org.glassfish.jersey.media.multipart.FormDataParam;
+
+/**
+ * Machine Learning service.
+ */
+@Path("/ml")
+@Produces({MediaType.APPLICATION_JSON, MediaType.APPLICATION_XML})
+public class MLServiceResource {
+
+ /** Logger shared by every instance of this resource. */
+ public static final Log LOG = LogFactory.getLog(MLServiceResource.class);
+
+ /** Text returned by the root endpoint to signal that the ML resource is reachable. */
+ public static final String ML_UP_MESSAGE = "ML service is up";
+
+ /**
+ * Configuration from which the service provider factory class is read.
+ */
+ private static final HiveConf HIVE_CONF = new HiveConf();
+
+ static {
+ // Add default config so that we know the service provider implementation
+ HIVE_CONF.addResource("lensserver-default.xml");
+ HIVE_CONF.addResource("lens-site.xml");
+ }
+
+ /** ML service, looked up through the service provider on first use. */
+ MLService mlService;
+
+ /** Service provider, obtained from the factory on first use. */
+ ServiceProvider serviceProvider;
+
+ /** Factory for the service provider, created in the constructor. */
+ ServiceProviderFactory serviceProviderFactory;
+
+ /**
+ * Creates the resource and instantiates the service provider factory named in the configuration.
+ */
+ public MLServiceResource() {
+ this.serviceProviderFactory = getServiceProviderFactory(HIVE_CONF);
+ }
+
+ private ServiceProvider getServiceProvider() {
+ // Ask the factory only when nothing has been cached yet;
+ // later calls return the provider stored in the field.
+ serviceProvider = (serviceProvider != null) ? serviceProvider : serviceProviderFactory.getServiceProvider();
+ return serviceProvider;
+ }
+
+ /**
+ * Instantiates the service provider factory class named in the configuration.
+ * @param conf configuration holding the factory class setting
+ * @return a new instance of the configured factory class
+ */
+ private ServiceProviderFactory getServiceProviderFactory(HiveConf conf) {
+ Class<?> factoryClass = conf.getClass(LensConfConstants.SERVICE_PROVIDER_FACTORY, ServiceProviderFactory.class);
+ try {
+ Object factory = factoryClass.newInstance();
+ return (ServiceProviderFactory) factory;
+ } catch (InstantiationException instantiationFailure) {
+ throw new RuntimeException(instantiationFailure);
+ } catch (IllegalAccessException accessFailure) {
+ throw new RuntimeException(accessFailure);
+ }
+ }
+
+ private MLService getMlService() {
+ // Look the service up by name only when nothing has been cached yet;
+ // later calls return the instance stored in the field.
+ mlService = (mlService != null) ? mlService : (MLService) getServiceProvider().getService(MLService.NAME);
+ return mlService;
+ }
+
+ /**
+ * Indicates if ML resource is up
+ *
+ * @return {@link #ML_UP_MESSAGE}, returned unconditionally
+ */
+ @GET
+ public String mlResourceUp() {
+ return ML_UP_MESSAGE;
+ }
+
+ /**
+ * Lists the names of the algorithms reported by the ML service.
+ *
+ * @return the algorithm names wrapped in a {@link StringList}
+ */
+ @GET
+ @Path("algos")
+ public StringList getAlgoNames() {
+ // The service's list is handed to StringList unchanged.
+ MLService service = getMlService();
+ return new StringList(service.getAlgorithms());
+ }
+
+ /**
+ * Gets the human readable param description of an algorithm
+ *
+ * @param algorithm the algorithm
+ * @return one "name : description" entry per parameter
+ */
+ @GET
+ @Path("algos/{algorithm}")
+ public StringList getParamDescription(@PathParam("algorithm") String algorithm) {
+ Map<String, String> paramDesc = getMlService().getAlgoParamDescription(algorithm);
+ if (paramDesc == null) {
+ throw new NotFoundException("Param description not found for " + algorithm);
+ }
+ // Walk the entries directly instead of looking every key up a second time.
+ List<String> descriptions = new ArrayList<String>(paramDesc.size());
+ for (Map.Entry<String, String> param : paramDesc.entrySet()) {
+ descriptions.add(param.getKey() + " : " + param.getValue());
+ }
+ return new StringList(descriptions);
+ }
+
+ /**
+ * Lists the IDs of the models stored for an algorithm.
+ *
+ * @param algorithm algorithm name
+ * @return the model IDs; never empty, since a missing or empty list is reported as not found
+ * @throws LensException the lens exception
+ */
+ @GET
+ @Path("models/{algorithm}")
+ public StringList getModelsForAlgo(@PathParam("algorithm") String algorithm) throws LensException {
+ List<String> modelIds = getMlService().getModels(algorithm);
+ if (modelIds != null && !modelIds.isEmpty()) {
+ return new StringList(modelIds);
+ }
+ throw new NotFoundException("No models found for algorithm " + algorithm);
+ }
+
+ /**
+ * Get metadata of the model given algorithm and model ID.
+ *
+ * @param algorithm algorithm name
+ * @param modelID model ID
+ * @return model metadata; params are joined with a space, feature columns with a comma
+ * @throws LensException the lens exception
+ */
+ @GET
+ @Path("models/{algorithm}/{modelID}")
+ public ModelMetadata getModelMetadata(@PathParam("algorithm") String algorithm, @PathParam("modelID") String modelID)
+ throws LensException {
+ // Wildcard rather than the raw type; nothing below depends on the model's type parameter.
+ MLModel<?> model = getMlService().getModel(algorithm, modelID);
+ if (model == null) {
+ throw new NotFoundException("Model not found " + modelID + ", algo=" + algorithm);
+ }
+ return new ModelMetadata(model.getId(), model.getTable(), model.getAlgoName(),
+ StringUtils.join(model.getParams(), ' '), model.getCreatedAt().toString(),
+ getMlService().getModelPath(algorithm, modelID), model.getLabelColumn(),
+ StringUtils.join(model.getFeatureColumns(), ","));
+ }
+
+ /**
+ * Deletes the model stored under the given algorithm name and model ID.
+ * @param algorithm the algorithm
+ * @param modelID the model id
+ * @return confirmation text naming the deleted model and its algorithm
+ * @throws LensException the lens exception
+ */
+ @DELETE
+ @Path("models/{algorithm}/{modelID}")
+ @Consumes({MediaType.APPLICATION_JSON, MediaType.APPLICATION_XML, MediaType.TEXT_PLAIN})
+ public String deleteModel(@PathParam("algorithm") String algorithm, @PathParam("modelID") String modelID)
+ throws LensException {
+ MLService service = getMlService();
+ service.deleteModel(algorithm, modelID);
+ return "DELETED model=" + modelID + " algorithm=" + algorithm;
+ }
+
+ /**
+ * Train a model given an algorithm name and algorithm parameters.
+ * <p>
+ * The following parameters are mandatory and must be passed as part of the form:
+ * </p>
+ * <ol>
+ * <li>table - input Hive table to load training data from</li>
+ * <li>label - name of the labelled column</li>
+ * <li>feature - one entry per feature column. At least one feature column is required</li>
+ * </ol>
+ * <p>Every other form field except "algorithm" is forwarded to the algorithm with its first value only.</p>
+ *
+ * @param algorithm algorithm name
+ * @param form form data
+ * @return if model is successfully trained, the model ID will be returned
+ * @throws LensException the lens exception
+ */
+ @POST
+ @Consumes(MediaType.APPLICATION_FORM_URLENCODED)
+ @Path("{algorithm}/train")
+ public String train(@PathParam("algorithm") String algorithm, MultivaluedMap<String, String> form)
+ throws LensException {
+
+ // Check if algo is valid
+ if (getMlService().getAlgoForName(algorithm) == null) {
+ throw new NotFoundException("Algo for algo: " + algorithm + " not found");
+ }
+
+ String table = form.getFirst("table");
+ if (isBlank(table)) {
+ throw new BadRequestException("table parameter is required");
+ }
+
+ if (isBlank(form.getFirst("label"))) {
+ throw new BadRequestException("label parameter is required");
+ }
+
+ // Check features. A form without any "feature" field yields null here rather than
+ // an empty list, so test for both before reporting the missing parameter.
+ List<String> featureNames = form.get("feature");
+ if (featureNames == null || featureNames.isEmpty()) {
+ throw new BadRequestException("At least one feature is required");
+ }
+
+ // Flatten the form into alternating name/value arguments for the algorithm.
+ List<String> algoArgs = new ArrayList<String>();
+ Set<Map.Entry<String, List<String>>> paramSet = form.entrySet();
+
+ for (Map.Entry<String, List<String>> e : paramSet) {
+ String p = e.getKey();
+ List<String> values = e.getValue();
+ if ("algorithm".equals(p) || "table".equals(p)) {
+ // Skipped: table and algorithm are passed to the service separately.
+ continue;
+ } else if ("feature".equals(p)) {
+ // Each feature value becomes its own "feature" pair.
+ for (String feature : values) {
+ algoArgs.add("feature");
+ algoArgs.add(feature);
+ }
+ } else {
+ // "label" and all remaining parameters contribute their first value only.
+ algoArgs.add(p);
+ algoArgs.add(values.get(0));
+ }
+ }
+ LOG.info("Training table " + table + " with algo " + algorithm + " params=" + algoArgs.toString());
+ String modelId = getMlService().train(table, algorithm, algoArgs.toArray(new String[]{}));
+ LOG.info("Done training " + table + " modelid = " + modelId);
+ return modelId;
+ }
+
+ /**
+ * Clear model cache (for admin use) by calling ModelLoader.clearCache().
+ *
+ * @return an OK response whose plain-text body is "Cleared cache"
+ */
+ @DELETE
+ @Path("clearModelCache")
+ @Produces(MediaType.TEXT_PLAIN)
+ public Response clearModelCache() {
+ ModelLoader.clearCache();
+ LOG.info("Cleared model cache");
+ return Response.ok("Cleared cache", MediaType.TEXT_PLAIN_TYPE).build();
+ }
+
+ /**
+ * Run a test on a model for an algorithm and return the ID of the resulting test report.
+ * @param algorithm algorithm name
+ * @param modelID model ID
+ * @param table Hive table to run test on
+ * @param session Lens session ID. This session ID will be used to run the test query
+ * @param outputTable value of the "outputTable" form field, passed through to the ML service's testModel call
+ * @return Test report ID
+ * @throws LensException the lens exception
+ */
+ @POST
+ @Path("test/{table}/{algorithm}/{modelID}")
+ @Consumes(MediaType.MULTIPART_FORM_DATA)
+ public String test(@PathParam("algorithm") String algorithm, @PathParam("modelID") String modelID,
+ @PathParam("table") String table, @FormDataParam("sessionid") LensSessionHandle session,
+ @FormDataParam("outputTable") String outputTable) throws LensException {
+ MLTestReport testReport = getMlService().testModel(session, table, algorithm, modelID, outputTable);
+ return testReport.getReportID();
+ }
+
+ /**
+ * Lists the test reports the ML service holds for an algorithm.
+ *
+ * @param algoritm the algorithm name
+ * @return the reports for the algorithm; a NotFoundException is thrown when there are none
+ * @throws LensException the lens exception
+ */
+ @GET
+ @Path("reports/{algorithm}")
+ public StringList getReportsForAlgorithm(@PathParam("algorithm") String algoritm) throws LensException {
+ List<String> available = getMlService().getTestReports(algoritm);
+ if (available != null && !available.isEmpty()) {
+ return new StringList(available);
+ }
+ throw new NotFoundException("No test reports found for " + algoritm);
+ }
+
+ /**
+ * Looks up one test report by algorithm name and report id and converts it to its API form.
+ *
+ * @param algorithm the algorithm
+ * @param reportID the report id
+ * @return the test report; a NotFoundException is thrown when the ML service has no such report
+ * @throws LensException the lens exception
+ */
+ @GET
+ @Path("reports/{algorithm}/{reportID}")
+ public TestReport getTestReport(@PathParam("algorithm") String algorithm, @PathParam("reportID") String reportID)
+ throws LensException {
+ MLTestReport stored = getMlService().getTestReport(algorithm, reportID);
+
+ if (stored == null) {
+ throw new NotFoundException("Test report: " + reportID + " not found for algorithm " + algorithm);
+ }
+
+ // Copy the stored report field by field; feature columns are flattened to a comma separated string
+ return new TestReport(stored.getTestTable(), stored.getOutputTable(), stored.getOutputColumn(),
+ stored.getLabelColumn(), StringUtils.join(stored.getFeatureColumns(), ","), stored.getAlgorithm(),
+ stored.getModelID(), stored.getReportID(), stored.getLensQueryID());
+ }
+
+ /**
+ * DELETE a report given the algorithm name and report ID, by delegating to the ML service.
+ *
+ * @param algorithm the algorithm name, bound from the request path
+ * @param reportID the report id, bound from the request path
+ * @return confirmation text of the form {@code DELETED report=<reportID> algorithm=<algorithm>}
+ * @throws LensException the lens exception
+ */
+ @DELETE
+ @Path("reports/{algorithm}/{reportID}")
+ @Consumes({MediaType.APPLICATION_JSON, MediaType.APPLICATION_XML, MediaType.TEXT_PLAIN})
+ public String deleteTestReport(@PathParam("algorithm") String algorithm, @PathParam("reportID") String reportID)
+ throws LensException {
+ getMlService().deleteTestReport(algorithm, reportID);
+ return "DELETED report=" + reportID + " algorithm=" + algorithm;
+ }
+
+ /**
+ * Predict using a trained model. Feature values are read from the query parameters named after the
+ * model's feature columns, in the model's column order; only the first value of each parameter is used.
+ * @param algorithm the algorithm
+ * @param modelID the model id
+ * @param uriInfo the uri info, used to read the query parameters
+ * @return string form of the prediction returned by the ML service
+ * @throws LensException the lens exception
+ */
+ @GET
+ @Path("/predict/{algorithm}/{modelID}")
+ @Produces({MediaType.APPLICATION_ATOM_XML, MediaType.APPLICATION_JSON})
+ public String predict(@PathParam("algorithm") String algorithm, @PathParam("modelID") String modelID,
+ @Context UriInfo uriInfo) throws LensException {
+ // Load the model instance; report an unknown model as not found instead of failing on a null reference
+ MLModel<?> model = getMlService().getModel(algorithm, modelID);
+ if (model == null) {
+ throw new NotFoundException("Model not found " + modelID + ", algo=" + algorithm);
+ }
+ // Get input feature values; assumes query parameter names are the same as the feature column names
+ MultivaluedMap<String, String> params = uriInfo.getQueryParameters();
+ String[] features = new String[model.getFeatureColumns().size()];
+ int i = 0;
+ for (String feature : model.getFeatureColumns()) {
+ features[i++] = params.getFirst(feature);
+ }
+ // TODO needs a 'prediction formatter'
+ return getMlService().predict(algorithm, modelID, features).toString();
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/ColumnFeatureFunction.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/ColumnFeatureFunction.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/ColumnFeatureFunction.java
deleted file mode 100644
index abdad68..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/ColumnFeatureFunction.java
+++ /dev/null
@@ -1,102 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml.spark;
-
-import org.apache.hadoop.io.WritableComparable;
-import org.apache.hive.hcatalog.data.HCatRecord;
-import org.apache.log4j.Logger;
-import org.apache.spark.mllib.linalg.Vectors;
-import org.apache.spark.mllib.regression.LabeledPoint;
-
-import com.google.common.base.Preconditions;
-import scala.Tuple2;
-
-/**
- * A feature function that directly maps an HCatRecord to a feature vector. Each column becomes a feature in the vector,
- * with the value of the feature obtained using the value mapper for that column
- */
-public class ColumnFeatureFunction extends FeatureFunction {
-
- /** The Constant LOG. */
- public static final Logger LOG = Logger.getLogger(ColumnFeatureFunction.class);
-
- /** The feature value mappers. */
- private final FeatureValueMapper[] featureValueMappers;
-
- /** The feature positions. */
- private final int[] featurePositions;
-
- /** The label column pos. */
- private final int labelColumnPos;
-
- /** The num features. */
- private final int numFeatures;
-
- /** The default labeled point. */
- private final LabeledPoint defaultLabeledPoint;
-
- /**
- * Feature positions and value mappers are parallel arrays. featurePositions[i] gives the position of ith feature in
- * the HCatRecord, and valueMappers[i] gives the value mapper used to map that feature to a Double value
- *
- * @param featurePositions position number of feature column in the HCatRecord
- * @param valueMappers mapper for each column position
- * @param labelColumnPos position of the label column
- * @param numFeatures number of features in the feature vector
- * @param defaultLabel default lable to be used for null records
- */
- public ColumnFeatureFunction(int[] featurePositions, FeatureValueMapper[] valueMappers, int labelColumnPos,
- int numFeatures, double defaultLabel) {
- Preconditions.checkNotNull(valueMappers, "Value mappers argument is required");
- Preconditions.checkNotNull(featurePositions, "Feature positions are required");
- Preconditions.checkArgument(valueMappers.length == featurePositions.length,
- "Mismatch between number of value mappers and feature positions");
-
- this.featurePositions = featurePositions;
- this.featureValueMappers = valueMappers;
- this.labelColumnPos = labelColumnPos;
- this.numFeatures = numFeatures;
- defaultLabeledPoint = new LabeledPoint(defaultLabel, Vectors.dense(new double[numFeatures]));
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.spark.FeatureFunction#call(scala.Tuple2)
- */
- @Override
- public LabeledPoint call(Tuple2<WritableComparable, HCatRecord> tuple) throws Exception {
- HCatRecord record = tuple._2();
-
- if (record == null) {
- LOG.info("@@@ Null record");
- return defaultLabeledPoint;
- }
-
- double[] features = new double[numFeatures];
-
- for (int i = 0; i < numFeatures; i++) {
- int featurePos = featurePositions[i];
- features[i] = featureValueMappers[i].call(record.get(featurePos));
- }
-
- double label = featureValueMappers[labelColumnPos].call(record.get(labelColumnPos));
- return new LabeledPoint(label, Vectors.dense(features));
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/DoubleValueMapper.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/DoubleValueMapper.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/DoubleValueMapper.java
deleted file mode 100644
index 781ccd1..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/DoubleValueMapper.java
+++ /dev/null
@@ -1,39 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml.spark;
-
-/**
- * Directly return input when it is known to be double.
- */
-public class DoubleValueMapper extends FeatureValueMapper {
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.spark.FeatureValueMapper#call(java.lang.Object)
- */
- @Override
- public final Double call(Object input) {
- if (input instanceof Double || input == null) {
- return input == null ? Double.valueOf(0d) : (Double) input;
- }
-
- throw new IllegalArgumentException("Invalid input expecting only doubles, but got " + input);
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/FeatureFunction.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/FeatureFunction.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/FeatureFunction.java
deleted file mode 100644
index affed7b..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/FeatureFunction.java
+++ /dev/null
@@ -1,40 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml.spark;
-
-import org.apache.hadoop.io.WritableComparable;
-import org.apache.hive.hcatalog.data.HCatRecord;
-import org.apache.spark.api.java.function.Function;
-import org.apache.spark.mllib.regression.LabeledPoint;
-
-import scala.Tuple2;
-
-/**
- * Function to map an HCatRecord to a feature vector usable by MLLib.
- */
-public abstract class FeatureFunction implements Function<Tuple2<WritableComparable, HCatRecord>, LabeledPoint> {
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.spark.api.java.function.Function#call(java.lang.Object)
- */
- @Override
- public abstract LabeledPoint call(Tuple2<WritableComparable, HCatRecord> tuple) throws Exception;
-}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/FeatureValueMapper.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/FeatureValueMapper.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/FeatureValueMapper.java
deleted file mode 100644
index b692379..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/FeatureValueMapper.java
+++ /dev/null
@@ -1,36 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml.spark;
-
-import java.io.Serializable;
-
-import org.apache.spark.api.java.function.Function;
-
-/**
- * Map a feature value to a Double value usable by MLLib.
- */
-public abstract class FeatureValueMapper implements Function<Object, Double>, Serializable {
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.spark.api.java.function.Function#call(java.lang.Object)
- */
- public abstract Double call(Object input);
-}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/HiveTableRDD.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/HiveTableRDD.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/HiveTableRDD.java
deleted file mode 100644
index 44a8e1d..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/HiveTableRDD.java
+++ /dev/null
@@ -1,63 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml.spark;
-
-import java.io.IOException;
-
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
-import org.apache.hadoop.conf.Configuration;
-import org.apache.hadoop.io.WritableComparable;
-import org.apache.hive.hcatalog.data.HCatRecord;
-import org.apache.hive.hcatalog.mapreduce.HCatInputFormat;
-import org.apache.spark.api.java.JavaPairRDD;
-import org.apache.spark.api.java.JavaSparkContext;
-
-/**
- * Create a JavaRDD based on a Hive table using HCatInputFormat.
- */
-public final class HiveTableRDD {
- private HiveTableRDD() {
- }
-
- public static final Log LOG = LogFactory.getLog(HiveTableRDD.class);
-
- /**
- * Creates the hive table rdd.
- *
- * @param javaSparkContext the java spark context
- * @param conf the conf
- * @param db the db
- * @param table the table
- * @param partitionFilter the partition filter
- * @return the java pair rdd
- * @throws IOException Signals that an I/O exception has occurred.
- */
- public static JavaPairRDD<WritableComparable, HCatRecord> createHiveTableRDD(JavaSparkContext javaSparkContext,
- Configuration conf, String db, String table, String partitionFilter) throws IOException {
-
- HCatInputFormat.setInput(conf, db, table, partitionFilter);
-
- JavaPairRDD<WritableComparable, HCatRecord> rdd = javaSparkContext.newAPIHadoopRDD(conf,
- HCatInputFormat.class, // Input
- WritableComparable.class, // input key class
- HCatRecord.class); // input value class
- return rdd;
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/SparkMLDriver.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/SparkMLDriver.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/SparkMLDriver.java
deleted file mode 100644
index 1e452f1..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/SparkMLDriver.java
+++ /dev/null
@@ -1,275 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml.spark;
-
-import java.io.File;
-import java.io.FilenameFilter;
-import java.util.ArrayList;
-import java.util.List;
-
-import org.apache.lens.api.LensConf;
-import org.apache.lens.api.LensException;
-import org.apache.lens.ml.Algorithms;
-import org.apache.lens.ml.MLAlgo;
-import org.apache.lens.ml.MLDriver;
-import org.apache.lens.ml.spark.algos.*;
-
-import org.apache.commons.lang.StringUtils;
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
-import org.apache.spark.SparkConf;
-import org.apache.spark.api.java.JavaSparkContext;
-
-/**
- * The Class SparkMLDriver.
- */
-public class SparkMLDriver implements MLDriver {
-
- /** The Constant LOG. */
- public static final Log LOG = LogFactory.getLog(SparkMLDriver.class);
-
- /** The owns spark context. */
- private boolean ownsSparkContext = true;
-
- /**
- * The Enum SparkMasterMode.
- */
- private enum SparkMasterMode {
- // Embedded mode used in tests
- /** The embedded. */
- EMBEDDED,
- // Yarn client and Yarn cluster modes are used when deploying the app to Yarn cluster
- /** The yarn client. */
- YARN_CLIENT,
-
- /** The yarn cluster. */
- YARN_CLUSTER
- }
-
- /** The algorithms. */
- private final Algorithms algorithms = new Algorithms();
-
- /** The client mode. */
- private SparkMasterMode clientMode = SparkMasterMode.EMBEDDED;
-
- /** The is started. */
- private boolean isStarted;
-
- /** The spark conf. */
- private SparkConf sparkConf;
-
- /** The spark context. */
- private JavaSparkContext sparkContext;
-
- /**
- * Use spark context.
- *
- * @param jsc the jsc
- */
- public void useSparkContext(JavaSparkContext jsc) {
- ownsSparkContext = false;
- this.sparkContext = jsc;
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.MLDriver#isAlgoSupported(java.lang.String)
- */
- @Override
- public boolean isAlgoSupported(String name) {
- return algorithms.isAlgoSupported(name);
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.MLDriver#getAlgoInstance(java.lang.String)
- */
- @Override
- public MLAlgo getAlgoInstance(String name) throws LensException {
- checkStarted();
-
- if (!isAlgoSupported(name)) {
- return null;
- }
-
- MLAlgo algo = null;
- try {
- algo = algorithms.getAlgoForName(name);
- if (algo instanceof BaseSparkAlgo) {
- ((BaseSparkAlgo) algo).setSparkContext(sparkContext);
- }
- } catch (LensException exc) {
- LOG.error("Error creating algo object", exc);
- }
- return algo;
- }
-
- /**
- * Register algos.
- */
- private void registerAlgos() {
- algorithms.register(NaiveBayesAlgo.class);
- algorithms.register(SVMAlgo.class);
- algorithms.register(LogisticRegressionAlgo.class);
- algorithms.register(DecisionTreeAlgo.class);
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.MLDriver#init(org.apache.lens.api.LensConf)
- */
- @Override
- public void init(LensConf conf) throws LensException {
- sparkConf = new SparkConf();
- registerAlgos();
- for (String key : conf.getProperties().keySet()) {
- if (key.startsWith("lens.ml.sparkdriver.")) {
- sparkConf.set(key.substring("lens.ml.sparkdriver.".length()), conf.getProperties().get(key));
- }
- }
-
- String sparkAppMaster = sparkConf.get("spark.master");
- if ("yarn-client".equalsIgnoreCase(sparkAppMaster)) {
- clientMode = SparkMasterMode.YARN_CLIENT;
- } else if ("yarn-cluster".equalsIgnoreCase(sparkAppMaster)) {
- clientMode = SparkMasterMode.YARN_CLUSTER;
- } else if ("local".equalsIgnoreCase(sparkAppMaster) || StringUtils.isBlank(sparkAppMaster)) {
- clientMode = SparkMasterMode.EMBEDDED;
- } else {
- throw new IllegalArgumentException("Invalid master mode " + sparkAppMaster);
- }
-
- if (clientMode == SparkMasterMode.YARN_CLIENT || clientMode == SparkMasterMode.YARN_CLUSTER) {
- String sparkHome = System.getenv("SPARK_HOME");
- if (StringUtils.isNotBlank(sparkHome)) {
- sparkConf.setSparkHome(sparkHome);
- }
-
- // If SPARK_HOME is not set, SparkConf can read from the Lens-site.xml or System properties.
- if (StringUtils.isBlank(sparkConf.get("spark.home"))) {
- throw new IllegalArgumentException("Spark home is not set");
- }
-
- LOG.info("Spark home is set to " + sparkConf.get("spark.home"));
- }
-
- sparkConf.setAppName("lens-ml");
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.MLDriver#start()
- */
- @Override
- public void start() throws LensException {
- if (sparkContext == null) {
- sparkContext = new JavaSparkContext(sparkConf);
- }
-
- // Adding jars to spark context is only required when running in yarn-client mode
- if (clientMode != SparkMasterMode.EMBEDDED) {
- // TODO Figure out only necessary set of JARs to be added for HCatalog
- // Add hcatalog and hive jars
- String hiveLocation = System.getenv("HIVE_HOME");
-
- if (StringUtils.isBlank(hiveLocation)) {
- throw new LensException("HIVE_HOME is not set");
- }
-
- LOG.info("HIVE_HOME at " + hiveLocation);
-
- File hiveLibDir = new File(hiveLocation, "lib");
- FilenameFilter jarFileFilter = new FilenameFilter() {
- @Override
- public boolean accept(File file, String s) {
- return s.endsWith(".jar");
- }
- };
-
- List<String> jarFiles = new ArrayList<String>();
- // Add hive jars
- for (File jarFile : hiveLibDir.listFiles(jarFileFilter)) {
- jarFiles.add(jarFile.getAbsolutePath());
- LOG.info("Adding HIVE jar " + jarFile.getAbsolutePath());
- sparkContext.addJar(jarFile.getAbsolutePath());
- }
-
- // Add hcatalog jars
- File hcatalogDir = new File(hiveLocation + "/hcatalog/share/hcatalog");
- for (File jarFile : hcatalogDir.listFiles(jarFileFilter)) {
- jarFiles.add(jarFile.getAbsolutePath());
- LOG.info("Adding HCATALOG jar " + jarFile.getAbsolutePath());
- sparkContext.addJar(jarFile.getAbsolutePath());
- }
-
- // Add the current jar
- String[] lensSparkLibJars = JavaSparkContext.jarOfClass(SparkMLDriver.class);
- for (String lensSparkJar : lensSparkLibJars) {
- LOG.info("Adding Lens JAR " + lensSparkJar);
- sparkContext.addJar(lensSparkJar);
- }
- }
-
- isStarted = true;
- LOG.info("Created Spark context for app: '" + sparkContext.appName() + "', Spark master: " + sparkContext.master());
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.MLDriver#stop()
- */
- @Override
- public void stop() throws LensException {
- if (!isStarted) {
- LOG.warn("Spark driver was not started");
- return;
- }
- isStarted = false;
- if (ownsSparkContext) {
- sparkContext.stop();
- }
- LOG.info("Stopped spark context " + this);
- }
-
- @Override
- public List<String> getAlgoNames() {
- return algorithms.getAlgorithmNames();
- }
-
- /**
- * Check started.
- *
- * @throws LensException the lens exception
- */
- public void checkStarted() throws LensException {
- if (!isStarted) {
- throw new LensException("Spark driver is not started yet");
- }
- }
-
- public JavaSparkContext getSparkContext() {
- return sparkContext;
- }
-
-}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/TableTrainingSpec.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/TableTrainingSpec.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/TableTrainingSpec.java
deleted file mode 100644
index e569b1e..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/TableTrainingSpec.java
+++ /dev/null
@@ -1,433 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml.spark;
-
-import java.io.IOException;
-import java.io.Serializable;
-import java.util.ArrayList;
-import java.util.List;
-
-import org.apache.lens.api.LensException;
-
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
-import org.apache.hadoop.hive.conf.HiveConf;
-import org.apache.hadoop.io.WritableComparable;
-import org.apache.hive.hcatalog.data.HCatRecord;
-import org.apache.hive.hcatalog.data.schema.HCatFieldSchema;
-import org.apache.hive.hcatalog.data.schema.HCatSchema;
-import org.apache.hive.hcatalog.mapreduce.HCatInputFormat;
-import org.apache.spark.api.java.JavaPairRDD;
-import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.api.java.function.Function;
-import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.rdd.RDD;
-
-import com.google.common.base.Preconditions;
-import lombok.Getter;
-import lombok.ToString;
-
-/**
- * The Class TableTrainingSpec.
- */
-@ToString
-public class TableTrainingSpec implements Serializable {
-
- /** The Constant LOG. */
- public static final Log LOG = LogFactory.getLog(TableTrainingSpec.class);
-
- /** The training rdd. */
- @Getter
- private transient RDD<LabeledPoint> trainingRDD;
-
- /** The testing rdd. */
- @Getter
- private transient RDD<LabeledPoint> testingRDD;
-
- /** The database. */
- @Getter
- private String database;
-
- /** The table. */
- @Getter
- private String table;
-
- /** The partition filter. */
- @Getter
- private String partitionFilter;
-
- /** The feature columns. */
- @Getter
- private List<String> featureColumns;
-
- /** The label column. */
- @Getter
- private String labelColumn;
-
- /** The conf. */
- @Getter
- private transient HiveConf conf;
-
- // By default all samples are considered for training
- /** The split training. */
- private boolean splitTraining;
-
- /** The training fraction. */
- private double trainingFraction = 1.0;
-
- /** The label pos. */
- int labelPos;
-
- /** The feature positions. */
- int[] featurePositions;
-
- /** The num features. */
- int numFeatures;
-
- /** The labeled rdd. */
- transient JavaRDD<LabeledPoint> labeledRDD;
-
- /**
- * New builder.
- *
- * @return the table training spec builder
- */
- public static TableTrainingSpecBuilder newBuilder() {
- return new TableTrainingSpecBuilder();
- }
-
- /**
- * The Class TableTrainingSpecBuilder.
- */
- public static class TableTrainingSpecBuilder {
-
- /** The spec. */
- final TableTrainingSpec spec;
-
- /**
- * Instantiates a new table training spec builder.
- */
- public TableTrainingSpecBuilder() {
- spec = new TableTrainingSpec();
- }
-
- /**
- * Hive conf.
- *
- * @param conf the conf
- * @return the table training spec builder
- */
- public TableTrainingSpecBuilder hiveConf(HiveConf conf) {
- spec.conf = conf;
- return this;
- }
-
- /**
- * Database.
- *
- * @param db the db
- * @return the table training spec builder
- */
- public TableTrainingSpecBuilder database(String db) {
- spec.database = db;
- return this;
- }
-
- /**
- * Table.
- *
- * @param table the table
- * @return the table training spec builder
- */
- public TableTrainingSpecBuilder table(String table) {
- spec.table = table;
- return this;
- }
-
- /**
- * Partition filter.
- *
- * @param partFilter the part filter
- * @return the table training spec builder
- */
- public TableTrainingSpecBuilder partitionFilter(String partFilter) {
- spec.partitionFilter = partFilter;
- return this;
- }
-
- /**
- * Label column.
- *
- * @param labelColumn the label column
- * @return the table training spec builder
- */
- public TableTrainingSpecBuilder labelColumn(String labelColumn) {
- spec.labelColumn = labelColumn;
- return this;
- }
-
- /**
- * Feature columns.
- *
- * @param featureColumns the feature columns
- * @return the table training spec builder
- */
- public TableTrainingSpecBuilder featureColumns(List<String> featureColumns) {
- spec.featureColumns = featureColumns;
- return this;
- }
-
- /**
- * Builds the.
- *
- * @return the table training spec
- */
- public TableTrainingSpec build() {
- return spec;
- }
-
- /**
- * Training fraction.
- *
- * @param trainingFraction the training fraction
- * @return the table training spec builder
- */
- public TableTrainingSpecBuilder trainingFraction(double trainingFraction) {
- Preconditions.checkArgument(trainingFraction >= 0 && trainingFraction <= 1.0,
- "Training fraction shoule be between 0 and 1");
- spec.trainingFraction = trainingFraction;
- spec.splitTraining = true;
- return this;
- }
- }
-
- /**
- * The Class DataSample.
- */
- public static class DataSample implements Serializable {
-
- /** The labeled point. */
- private final LabeledPoint labeledPoint;
-
- /** The sample. */
- private final double sample;
-
- /**
- * Instantiates a new data sample.
- *
- * @param labeledPoint the labeled point
- */
- public DataSample(LabeledPoint labeledPoint) {
- sample = Math.random();
- this.labeledPoint = labeledPoint;
- }
- }
-
- /**
- * The Class TrainingFilter.
- */
- public static class TrainingFilter implements Function<DataSample, Boolean> {
-
- /** The training fraction. */
- private double trainingFraction;
-
- /**
- * Instantiates a new training filter.
- *
- * @param fraction the fraction
- */
- public TrainingFilter(double fraction) {
- trainingFraction = fraction;
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.spark.api.java.function.Function#call(java.lang.Object)
- */
- @Override
- public Boolean call(DataSample v1) throws Exception {
- return v1.sample <= trainingFraction;
- }
- }
-
- /**
- * The Class TestingFilter.
- */
- public static class TestingFilter implements Function<DataSample, Boolean> {
-
- /** The training fraction. */
- private double trainingFraction;
-
- /**
- * Instantiates a new testing filter.
- *
- * @param fraction the fraction
- */
- public TestingFilter(double fraction) {
- trainingFraction = fraction;
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.spark.api.java.function.Function#call(java.lang.Object)
- */
- @Override
- public Boolean call(DataSample v1) throws Exception {
- return v1.sample > trainingFraction;
- }
- }
-
- /**
- * The Class GetLabeledPoint.
- */
- public static class GetLabeledPoint implements Function<DataSample, LabeledPoint> {
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.spark.api.java.function.Function#call(java.lang.Object)
- */
- @Override
- public LabeledPoint call(DataSample v1) throws Exception {
- return v1.labeledPoint;
- }
- }
-
- /**
- * Validate.
- *
- * @return true, if successful
- */
- boolean validate() {
- List<HCatFieldSchema> columns;
- try {
- HCatInputFormat.setInput(conf, database == null ? "default" : database, table, partitionFilter);
- HCatSchema tableSchema = HCatInputFormat.getTableSchema(conf);
- columns = tableSchema.getFields();
- } catch (IOException exc) {
- LOG.error("Error getting table info " + toString(), exc);
- return false;
- }
-
- LOG.info(table + " columns " + columns.toString());
-
- boolean valid = false;
- if (columns != null && !columns.isEmpty()) {
- // Check labeled column
- List<String> columnNames = new ArrayList<String>();
- for (HCatFieldSchema col : columns) {
- columnNames.add(col.getName());
- }
-
- // Need at least one feature column and one label column
- valid = columnNames.contains(labelColumn) && columnNames.size() > 1;
-
- if (valid) {
- labelPos = columnNames.indexOf(labelColumn);
-
- // Check feature columns
- if (featureColumns == null || featureColumns.isEmpty()) {
- // feature columns are not provided, so all columns except label column are feature columns
- featurePositions = new int[columnNames.size() - 1];
- int p = 0;
- for (int i = 0; i < columnNames.size(); i++) {
- if (i == labelPos) {
- continue;
- }
- featurePositions[p++] = i;
- }
-
- columnNames.remove(labelPos);
- featureColumns = columnNames;
- } else {
- // Feature columns were provided, verify all feature columns are present in the table
- valid = columnNames.containsAll(featureColumns);
- if (valid) {
- // Get feature positions
- featurePositions = new int[featureColumns.size()];
- for (int i = 0; i < featureColumns.size(); i++) {
- featurePositions[i] = columnNames.indexOf(featureColumns.get(i));
- }
- }
- }
- numFeatures = featureColumns.size();
- }
- }
-
- return valid;
- }
-
- /**
- * Creates the rd ds.
- *
- * @param sparkContext the spark context
- * @throws LensException the lens exception
- */
- public void createRDDs(JavaSparkContext sparkContext) throws LensException {
- // Validate the spec
- if (!validate()) {
- throw new LensException("Table spec not valid: " + toString());
- }
-
- LOG.info("Creating RDDs with spec " + toString());
-
- // Get the RDD for table
- JavaPairRDD<WritableComparable, HCatRecord> tableRDD;
- try {
- tableRDD = HiveTableRDD.createHiveTableRDD(sparkContext, conf, database, table, partitionFilter);
- } catch (IOException e) {
- throw new LensException(e);
- }
-
- // Map into trainable RDD
- // TODO: Figure out a way to use custom value mappers
- FeatureValueMapper[] valueMappers = new FeatureValueMapper[numFeatures];
- final DoubleValueMapper doubleMapper = new DoubleValueMapper();
- for (int i = 0; i < numFeatures; i++) {
- valueMappers[i] = doubleMapper;
- }
-
- ColumnFeatureFunction trainPrepFunction = new ColumnFeatureFunction(featurePositions, valueMappers, labelPos,
- numFeatures, 0);
- labeledRDD = tableRDD.map(trainPrepFunction);
-
- if (splitTraining) {
- // We have to split the RDD between a training RDD and a testing RDD
- LOG.info("Splitting RDD for table " + database + "." + table + " with split fraction " + trainingFraction);
- JavaRDD<DataSample> sampledRDD = labeledRDD.map(new Function<LabeledPoint, DataSample>() {
- @Override
- public DataSample call(LabeledPoint v1) throws Exception {
- return new DataSample(v1);
- }
- });
-
- trainingRDD = sampledRDD.filter(new TrainingFilter(trainingFraction)).map(new GetLabeledPoint()).rdd();
- testingRDD = sampledRDD.filter(new TestingFilter(trainingFraction)).map(new GetLabeledPoint()).rdd();
- } else {
- LOG.info("Using same RDD for train and test");
- trainingRDD = labeledRDD.rdd();
- testingRDD = trainingRDD;
- }
- LOG.info("Generated RDDs");
- }
-
-}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/BaseSparkAlgo.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/BaseSparkAlgo.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/BaseSparkAlgo.java
deleted file mode 100644
index 22cda6d..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/BaseSparkAlgo.java
+++ /dev/null
@@ -1,290 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml.spark.algos;
-
-import java.lang.reflect.Field;
-import java.util.*;
-
-import org.apache.lens.api.LensConf;
-import org.apache.lens.api.LensException;
-import org.apache.lens.ml.AlgoParam;
-import org.apache.lens.ml.Algorithm;
-import org.apache.lens.ml.MLAlgo;
-import org.apache.lens.ml.MLModel;
-
-import org.apache.lens.ml.spark.TableTrainingSpec;
-import org.apache.lens.ml.spark.models.BaseSparkClassificationModel;
-
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
-import org.apache.hadoop.hive.conf.HiveConf;
-import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.rdd.RDD;
-
-/**
- * The Class BaseSparkAlgo.
- */
-public abstract class BaseSparkAlgo implements MLAlgo {
-
- /** The Constant LOG. */
- public static final Log LOG = LogFactory.getLog(BaseSparkAlgo.class);
-
- /** The name. */
- private final String name;
-
- /** The description. */
- private final String description;
-
- /** The spark context. */
- protected JavaSparkContext sparkContext;
-
- /** The params. */
- protected Map<String, String> params;
-
- /** The conf. */
- protected transient LensConf conf;
-
- /** The training fraction. */
- @AlgoParam(name = "trainingFraction", help = "% of dataset to be used for training", defaultValue = "0")
- protected double trainingFraction;
-
- /** The use training fraction. */
- private boolean useTrainingFraction;
-
- /** The label. */
- @AlgoParam(name = "label", help = "Name of column which is used as a training label for supervised learning")
- protected String label;
-
- /** The partition filter. */
- @AlgoParam(name = "partition", help = "Partition filter used to create create HCatInputFormats")
- protected String partitionFilter;
-
- /** The features. */
- @AlgoParam(name = "feature", help = "Column name(s) which are to be used as sample features")
- protected List<String> features;
-
- /**
- * Instantiates a new base spark algo.
- *
- * @param name the name
- * @param description the description
- */
- public BaseSparkAlgo(String name, String description) {
- this.name = name;
- this.description = description;
- }
-
- public void setSparkContext(JavaSparkContext sparkContext) {
- this.sparkContext = sparkContext;
- }
-
- @Override
- public LensConf getConf() {
- return conf;
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.MLAlgo#configure(org.apache.lens.api.LensConf)
- */
- @Override
- public void configure(LensConf configuration) {
- this.conf = configuration;
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.MLAlgo#train(org.apache.lens.api.LensConf, java.lang.String, java.lang.String,
- * java.lang.String, java.lang.String[])
- */
- @Override
- public MLModel<?> train(LensConf conf, String db, String table, String modelId, String... params)
- throws LensException {
- parseParams(params);
-
- TableTrainingSpec.TableTrainingSpecBuilder builder = TableTrainingSpec.newBuilder().hiveConf(toHiveConf(conf))
- .database(db).table(table).partitionFilter(partitionFilter).featureColumns(features).labelColumn(label);
-
- if (useTrainingFraction) {
- builder.trainingFraction(trainingFraction);
- }
-
- TableTrainingSpec spec = builder.build();
- LOG.info("Training " + " with " + features.size() + " features");
-
- spec.createRDDs(sparkContext);
-
- RDD<LabeledPoint> trainingRDD = spec.getTrainingRDD();
- BaseSparkClassificationModel<?> model = trainInternal(modelId, trainingRDD);
- model.setTable(table);
- model.setParams(Arrays.asList(params));
- model.setLabelColumn(label);
- model.setFeatureColumns(features);
- return model;
- }
-
- /**
- * To hive conf.
- *
- * @param conf the conf
- * @return the hive conf
- */
- protected HiveConf toHiveConf(LensConf conf) {
- HiveConf hiveConf = new HiveConf();
- for (String key : conf.getProperties().keySet()) {
- hiveConf.set(key, conf.getProperties().get(key));
- }
- return hiveConf;
- }
-
- /**
- * Parses the params.
- *
- * @param args the args
- */
- public void parseParams(String[] args) {
- if (args.length % 2 != 0) {
- throw new IllegalArgumentException("Invalid number of params " + args.length);
- }
-
- params = new LinkedHashMap<String, String>();
-
- for (int i = 0; i < args.length; i += 2) {
- if ("f".equalsIgnoreCase(args[i]) || "feature".equalsIgnoreCase(args[i])) {
- if (features == null) {
- features = new ArrayList<String>();
- }
- features.add(args[i + 1]);
- } else if ("l".equalsIgnoreCase(args[i]) || "label".equalsIgnoreCase(args[i])) {
- label = args[i + 1];
- } else {
- params.put(args[i].replaceAll("\\-+", ""), args[i + 1]);
- }
- }
-
- if (params.containsKey("trainingFraction")) {
- // Get training Fraction
- String trainingFractionStr = params.get("trainingFraction");
- try {
- trainingFraction = Double.parseDouble(trainingFractionStr);
- useTrainingFraction = true;
- } catch (NumberFormatException nfe) {
- throw new IllegalArgumentException("Invalid training fraction", nfe);
- }
- }
-
- if (params.containsKey("partition") || params.containsKey("p")) {
- partitionFilter = params.containsKey("partition") ? params.get("partition") : params.get("p");
- }
-
- parseAlgoParams(params);
- }
-
- /**
- * Gets the param value.
- *
- * @param param the param
- * @param defaultVal the default val
- * @return the param value
- */
- public double getParamValue(String param, double defaultVal) {
- if (params.containsKey(param)) {
- try {
- return Double.parseDouble(params.get(param));
- } catch (NumberFormatException nfe) {
- LOG.warn("Couldn't parse param value: " + param + " as double.");
- }
- }
- return defaultVal;
- }
-
- /**
- * Gets the param value.
- *
- * @param param the param
- * @param defaultVal the default val
- * @return the param value
- */
- public int getParamValue(String param, int defaultVal) {
- if (params.containsKey(param)) {
- try {
- return Integer.parseInt(params.get(param));
- } catch (NumberFormatException nfe) {
- LOG.warn("Couldn't parse param value: " + param + " as integer.");
- }
- }
- return defaultVal;
- }
-
- public String getName() {
- return name;
- }
-
- public String getDescription() {
- return description;
- }
-
- public Map<String, String> getArgUsage() {
- Map<String, String> usage = new LinkedHashMap<String, String>();
- Class<?> clz = this.getClass();
- // Put class name and description as well as part of the usage
- Algorithm algorithm = clz.getAnnotation(Algorithm.class);
- if (algorithm != null) {
- usage.put("Algorithm Name", algorithm.name());
- usage.put("Algorithm Description", algorithm.description());
- }
-
- // Get all algo params including base algo params
- while (clz != null) {
- for (Field field : clz.getDeclaredFields()) {
- AlgoParam param = field.getAnnotation(AlgoParam.class);
- if (param != null) {
- usage.put("[param] " + param.name(), param.help() + " Default Value = " + param.defaultValue());
- }
- }
-
- if (clz.equals(BaseSparkAlgo.class)) {
- break;
- }
- clz = clz.getSuperclass();
- }
- return usage;
- }
-
- /**
- * Parses the algo params.
- *
- * @param params the params
- */
- public abstract void parseAlgoParams(Map<String, String> params);
-
- /**
- * Train internal.
- *
- * @param modelId the model id
- * @param trainingRDD the training rdd
- * @return the base spark classification model
- * @throws LensException the lens exception
- */
- protected abstract BaseSparkClassificationModel trainInternal(String modelId, RDD<LabeledPoint> trainingRDD)
- throws LensException;
-}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/DecisionTreeAlgo.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/DecisionTreeAlgo.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/DecisionTreeAlgo.java
deleted file mode 100644
index a6d66c5..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/DecisionTreeAlgo.java
+++ /dev/null
@@ -1,109 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml.spark.algos;
-
-import java.util.Map;
-
-import org.apache.lens.api.LensException;
-import org.apache.lens.ml.AlgoParam;
-import org.apache.lens.ml.Algorithm;
-import org.apache.lens.ml.spark.models.BaseSparkClassificationModel;
-import org.apache.lens.ml.spark.models.DecisionTreeClassificationModel;
-import org.apache.lens.ml.spark.models.SparkDecisionTreeModel;
-
-import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.mllib.tree.DecisionTree$;
-import org.apache.spark.mllib.tree.configuration.Algo$;
-import org.apache.spark.mllib.tree.impurity.Entropy$;
-import org.apache.spark.mllib.tree.impurity.Gini$;
-import org.apache.spark.mllib.tree.impurity.Impurity;
-import org.apache.spark.mllib.tree.impurity.Variance$;
-import org.apache.spark.mllib.tree.model.DecisionTreeModel;
-import org.apache.spark.rdd.RDD;
-
-import scala.Enumeration;
-
-/**
- * The Class DecisionTreeAlgo.
- */
-@Algorithm(name = "spark_decision_tree", description = "Spark Decision Tree classifier algo")
-public class DecisionTreeAlgo extends BaseSparkAlgo {
-
- /** The algo. */
- @AlgoParam(name = "algo", help = "Decision tree algorithm. Allowed values are 'classification' and 'regression'")
- private Enumeration.Value algo;
-
- /** The decision tree impurity. */
- @AlgoParam(name = "impurity", help = "Impurity measure used by the decision tree. "
- + "Allowed values are 'gini', 'entropy' and 'variance'")
- private Impurity decisionTreeImpurity;
-
- /** The max depth. */
- @AlgoParam(name = "maxDepth", help = "Max depth of the decision tree. Integer values expected.",
- defaultValue = "100")
- private int maxDepth;
-
- /**
- * Instantiates a new decision tree algo.
- *
- * @param name the name
- * @param description the description
- */
- public DecisionTreeAlgo(String name, String description) {
- super(name, description);
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.spark.algos.BaseSparkAlgo#parseAlgoParams(java.util.Map)
- */
- @Override
- public void parseAlgoParams(Map<String, String> params) {
- String dtreeAlgoName = params.get("algo");
- if ("classification".equalsIgnoreCase(dtreeAlgoName)) {
- algo = Algo$.MODULE$.Classification();
- } else if ("regression".equalsIgnoreCase(dtreeAlgoName)) {
- algo = Algo$.MODULE$.Regression();
- }
-
- String impurity = params.get("impurity");
- if ("gini".equals(impurity)) {
- decisionTreeImpurity = Gini$.MODULE$;
- } else if ("entropy".equals(impurity)) {
- decisionTreeImpurity = Entropy$.MODULE$;
- } else if ("variance".equals(impurity)) {
- decisionTreeImpurity = Variance$.MODULE$;
- }
-
- maxDepth = getParamValue("maxDepth", 100);
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.spark.algos.BaseSparkAlgo#trainInternal(java.lang.String, org.apache.spark.rdd.RDD)
- */
- @Override
- protected BaseSparkClassificationModel trainInternal(String modelId, RDD<LabeledPoint> trainingRDD)
- throws LensException {
- DecisionTreeModel model = DecisionTree$.MODULE$.train(trainingRDD, algo, decisionTreeImpurity, maxDepth);
- return new DecisionTreeClassificationModel(modelId, new SparkDecisionTreeModel(model));
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/KMeansAlgo.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/KMeansAlgo.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/KMeansAlgo.java
deleted file mode 100644
index 7ca5a79..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/KMeansAlgo.java
+++ /dev/null
@@ -1,163 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml.spark.algos;
-
-import java.util.List;
-
-import org.apache.lens.api.LensConf;
-import org.apache.lens.api.LensException;
-import org.apache.lens.ml.*;
-import org.apache.lens.ml.spark.HiveTableRDD;
-import org.apache.lens.ml.spark.models.KMeansClusteringModel;
-
-import org.apache.hadoop.hive.conf.HiveConf;
-import org.apache.hadoop.hive.metastore.api.FieldSchema;
-import org.apache.hadoop.hive.ql.metadata.Hive;
-import org.apache.hadoop.hive.ql.metadata.Table;
-import org.apache.hadoop.io.WritableComparable;
-import org.apache.hive.hcatalog.data.HCatRecord;
-import org.apache.spark.api.java.JavaPairRDD;
-import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.api.java.function.Function;
-import org.apache.spark.mllib.clustering.KMeans;
-import org.apache.spark.mllib.clustering.KMeansModel;
-import org.apache.spark.mllib.linalg.Vector;
-import org.apache.spark.mllib.linalg.Vectors;
-
-import scala.Tuple2;
-
-/**
- * The Class KMeansAlgo.
- */
-@Algorithm(name = "spark_kmeans_algo", description = "Spark MLLib KMeans algo")
-public class KMeansAlgo implements MLAlgo {
-
- /** The conf. */
- private transient LensConf conf;
-
- /** The spark context. */
- private JavaSparkContext sparkContext;
-
- /** The part filter. */
- @AlgoParam(name = "partition", help = "Partition filter to be used while constructing table RDD")
- private String partFilter = null;
-
- /** The k. */
- @AlgoParam(name = "k", help = "Number of cluster")
- private int k;
-
- /** The max iterations. */
- @AlgoParam(name = "maxIterations", help = "Maximum number of iterations", defaultValue = "100")
- private int maxIterations = 100;
-
- /** The runs. */
- @AlgoParam(name = "runs", help = "Number of parallel run", defaultValue = "1")
- private int runs = 1;
-
- /** The initialization mode. */
- @AlgoParam(name = "initializationMode",
- help = "initialization model, either \"random\" or \"k-means||\" (default).", defaultValue = "k-means||")
- private String initializationMode = "k-means||";
-
- @Override
- public String getName() {
- return getClass().getAnnotation(Algorithm.class).name();
- }
-
- @Override
- public String getDescription() {
- return getClass().getAnnotation(Algorithm.class).description();
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.MLAlgo#configure(org.apache.lens.api.LensConf)
- */
- @Override
- public void configure(LensConf configuration) {
- this.conf = configuration;
- }
-
- @Override
- public LensConf getConf() {
- return conf;
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.MLAlgo#train(org.apache.lens.api.LensConf, java.lang.String, java.lang.String,
- * java.lang.String, java.lang.String[])
- */
- @Override
- public MLModel train(LensConf conf, String db, String table, String modelId, String... params) throws LensException {
- List<String> features = AlgoArgParser.parseArgs(this, params);
- final int[] featurePositions = new int[features.size()];
- final int NUM_FEATURES = features.size();
-
- JavaPairRDD<WritableComparable, HCatRecord> rdd = null;
- try {
- // Map feature names to positions
- Table tbl = Hive.get(toHiveConf(conf)).getTable(db, table);
- List<FieldSchema> allCols = tbl.getAllCols();
- int f = 0;
- for (int i = 0; i < tbl.getAllCols().size(); i++) {
- String colName = allCols.get(i).getName();
- if (features.contains(colName)) {
- featurePositions[f++] = i;
- }
- }
-
- rdd = HiveTableRDD.createHiveTableRDD(sparkContext, toHiveConf(conf), db, table, partFilter);
- JavaRDD<Vector> trainableRDD = rdd.map(new Function<Tuple2<WritableComparable, HCatRecord>, Vector>() {
- @Override
- public Vector call(Tuple2<WritableComparable, HCatRecord> v1) throws Exception {
- HCatRecord hCatRecord = v1._2();
- double[] arr = new double[NUM_FEATURES];
- for (int i = 0; i < NUM_FEATURES; i++) {
- Object val = hCatRecord.get(featurePositions[i]);
- arr[i] = val == null ? 0d : (Double) val;
- }
- return Vectors.dense(arr);
- }
- });
-
- KMeansModel model = KMeans.train(trainableRDD.rdd(), k, maxIterations, runs, initializationMode);
- return new KMeansClusteringModel(modelId, model);
- } catch (Exception e) {
- throw new LensException("KMeans algo failed for " + db + "." + table, e);
- }
- }
-
- /**
- * To hive conf.
- *
- * @param conf the conf
- * @return the hive conf
- */
- private HiveConf toHiveConf(LensConf conf) {
- HiveConf hiveConf = new HiveConf();
- for (String key : conf.getProperties().keySet()) {
- hiveConf.set(key, conf.getProperties().get(key));
- }
- return hiveConf;
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/LogisticRegressionAlgo.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/LogisticRegressionAlgo.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/LogisticRegressionAlgo.java
deleted file mode 100644
index 106b3c5..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/LogisticRegressionAlgo.java
+++ /dev/null
@@ -1,86 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml.spark.algos;
-
-import java.util.Map;
-
-import org.apache.lens.api.LensException;
-import org.apache.lens.ml.AlgoParam;
-import org.apache.lens.ml.Algorithm;
-import org.apache.lens.ml.spark.models.BaseSparkClassificationModel;
-import org.apache.lens.ml.spark.models.LogitRegressionClassificationModel;
-
-import org.apache.spark.mllib.classification.LogisticRegressionModel;
-import org.apache.spark.mllib.classification.LogisticRegressionWithSGD;
-import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.rdd.RDD;
-
-/**
- * Lens ML algorithm, registered under the name {@code spark_logistic_regression}, that trains a classifier
- * by calling {@link LogisticRegressionWithSGD#train}.
- */
-@Algorithm(name = "spark_logistic_regression", description = "Spark logistic regression algo")
-public class LogisticRegressionAlgo extends BaseSparkAlgo {
-
-  /** Iteration count handed to the trainer. */
-  @AlgoParam(name = "iterations", help = "Max number of iterations", defaultValue = "100")
-  private int iterations;
-
-  /** Step size handed to the trainer. */
-  @AlgoParam(name = "stepSize", help = "Step size", defaultValue = "1.0d")
-  private double stepSize;
-
-  /** Batch fraction handed to the trainer. */
-  @AlgoParam(name = "minBatchFraction", help = "Fraction for batched learning", defaultValue = "1.0d")
-  private double minBatchFraction;
-
-  /**
-   * Creates the algorithm; both arguments are forwarded to the {@link BaseSparkAlgo} constructor.
-   *
-   * @param name        the name
-   * @param description the description
-   */
-  public LogisticRegressionAlgo(String name, String description) {
-    super(name, description);
-  }
-
-  /**
-   * Refreshes the three tunables through {@code getParamValue}, falling back to 100 iterations, a step size
-   * of 1.0 and a batch fraction of 1.0.
-   *
-   * @param params the raw parameters; not read directly in this method
-   */
-  @Override
-  public void parseAlgoParams(Map<String, String> params) {
-    this.iterations = getParamValue("iterations", 100);
-    this.stepSize = getParamValue("stepSize", 1.0d);
-    this.minBatchFraction = getParamValue("minBatchFraction", 1.0d);
-  }
-
-  /**
-   * Trains on the given RDD and wraps the resulting Spark model together with the model ID.
-   *
-   * @param modelId     the ID given to the returned model
-   * @param trainingRDD the labeled training points
-   * @return the trained model wrapped as a {@link LogitRegressionClassificationModel}
-   * @throws LensException declared by the overridden method
-   */
-  @Override
-  protected BaseSparkClassificationModel trainInternal(String modelId, RDD<LabeledPoint> trainingRDD)
-    throws LensException {
-    return new LogitRegressionClassificationModel(modelId,
-      LogisticRegressionWithSGD.train(trainingRDD, iterations, stepSize, minBatchFraction));
-  }
-}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/NaiveBayesAlgo.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/NaiveBayesAlgo.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/NaiveBayesAlgo.java
deleted file mode 100644
index f7652d1..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/NaiveBayesAlgo.java
+++ /dev/null
@@ -1,73 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml.spark.algos;
-
-import java.util.Map;
-
-import org.apache.lens.api.LensException;
-import org.apache.lens.ml.AlgoParam;
-import org.apache.lens.ml.Algorithm;
-import org.apache.lens.ml.spark.models.BaseSparkClassificationModel;
-import org.apache.lens.ml.spark.models.NaiveBayesClassificationModel;
-
-import org.apache.spark.mllib.classification.NaiveBayes;
-import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.rdd.RDD;
-
-/**
- * Lens ML algorithm registered under the name "spark_naive_bayes". Training delegates to Spark MLlib's
- * NaiveBayes.train and wraps the result in a NaiveBayesClassificationModel.
- */
-@Algorithm(name = "spark_naive_bayes", description = "Spark Naive Bayes classifier algo")
-public class NaiveBayesAlgo extends BaseSparkAlgo {
-
- /** Lambda value passed to NaiveBayes.train; starts at 1.0 and is reassigned by parseAlgoParams. */
- @AlgoParam(name = "lambda", help = "Lambda parameter for naive bayes learner", defaultValue = "1.0d")
- private double lambda = 1.0;
-
- /**
- * Instantiates a new naive bayes algo; both arguments are forwarded to the BaseSparkAlgo constructor.
- *
- * @param name the name
- * @param description the description
- */
- public NaiveBayesAlgo(String name, String description) {
- super(name, description);
- }
-
- /*
- * (non-Javadoc)
- *
- * Re-reads lambda through getParamValue with a fallback of 1.0. The params argument is not read directly here.
- *
- * @see org.apache.lens.ml.spark.algos.BaseSparkAlgo#parseAlgoParams(java.util.Map)
- */
- @Override
- public void parseAlgoParams(Map<String, String> params) {
- lambda = getParamValue("lambda", 1.0d);
- }
-
- /*
- * (non-Javadoc)
- *
- * Trains on the given RDD using the current lambda and returns the result tagged with modelId.
- *
- * @see org.apache.lens.ml.spark.algos.BaseSparkAlgo#trainInternal(java.lang.String, org.apache.spark.rdd.RDD)
- */
- @Override
- protected BaseSparkClassificationModel trainInternal(String modelId, RDD<LabeledPoint> trainingRDD)
- throws LensException {
- return new NaiveBayesClassificationModel(modelId, NaiveBayes.train(trainingRDD, lambda));
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/SVMAlgo.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/SVMAlgo.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/SVMAlgo.java
deleted file mode 100644
index 09251b7..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/SVMAlgo.java
+++ /dev/null
@@ -1,90 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml.spark.algos;
-
-import java.util.Map;
-
-import org.apache.lens.api.LensException;
-import org.apache.lens.ml.AlgoParam;
-import org.apache.lens.ml.Algorithm;
-import org.apache.lens.ml.spark.models.BaseSparkClassificationModel;
-import org.apache.lens.ml.spark.models.SVMClassificationModel;
-
-import org.apache.spark.mllib.classification.SVMModel;
-import org.apache.spark.mllib.classification.SVMWithSGD;
-import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.rdd.RDD;
-
-/**
- * Lens ML algorithm, registered under the name {@code spark_svm}, that trains a classifier by calling
- * {@link SVMWithSGD#train}.
- */
-@Algorithm(name = "spark_svm", description = "Spark SVML classifier algo")
-public class SVMAlgo extends BaseSparkAlgo {
-
-  /** Batch fraction handed to the trainer. */
-  @AlgoParam(name = "minBatchFraction", help = "Fraction for batched learning", defaultValue = "1.0d")
-  private double minBatchFraction;
-
-  /** Regularization parameter handed to the trainer. */
-  @AlgoParam(name = "regParam", help = "regularization parameter for gradient descent", defaultValue = "1.0d")
-  private double regParam;
-
-  /** Step size handed to the trainer. */
-  @AlgoParam(name = "stepSize", help = "Iteration step size", defaultValue = "1.0d")
-  private double stepSize;
-
-  /** Iteration count handed to the trainer. */
-  @AlgoParam(name = "iterations", help = "Number of iterations", defaultValue = "100")
-  private int iterations;
-
-  /**
-   * Creates the algorithm; both arguments are forwarded to the {@link BaseSparkAlgo} constructor.
-   *
-   * @param name        the name
-   * @param description the description
-   */
-  public SVMAlgo(String name, String description) {
-    super(name, description);
-  }
-
-  /**
-   * Refreshes the four tunables through {@code getParamValue}, falling back to 1.0 for the three doubles
-   * and 100 for the iteration count.
-   *
-   * @param params the raw parameters; not read directly in this method
-   */
-  @Override
-  public void parseAlgoParams(Map<String, String> params) {
-    this.minBatchFraction = getParamValue("minBatchFraction", 1.0);
-    this.regParam = getParamValue("regParam", 1.0);
-    this.stepSize = getParamValue("stepSize", 1.0);
-    this.iterations = getParamValue("iterations", 100);
-  }
-
-  /**
-   * Trains on the given RDD and wraps the resulting Spark model together with the model ID.
-   *
-   * @param modelId     the ID given to the returned model
-   * @param trainingRDD the labeled training points
-   * @return the trained model wrapped as an {@link SVMClassificationModel}
-   * @throws LensException declared by the overridden method
-   */
-  @Override
-  protected BaseSparkClassificationModel trainInternal(String modelId, RDD<LabeledPoint> trainingRDD)
-    throws LensException {
-    return new SVMClassificationModel(modelId,
-      SVMWithSGD.train(trainingRDD, iterations, stepSize, regParam, minBatchFraction));
-  }
-}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/models/BaseSparkClassificationModel.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/models/BaseSparkClassificationModel.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/models/BaseSparkClassificationModel.java
deleted file mode 100644
index deee1b7..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/models/BaseSparkClassificationModel.java
+++ /dev/null
@@ -1,65 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml.spark.models;
-
-import org.apache.lens.ml.ClassifierBaseModel;
-
-import org.apache.spark.mllib.classification.ClassificationModel;
-import org.apache.spark.mllib.linalg.Vectors;
-
-/**
- * Pairs a Spark MLlib {@link ClassificationModel} with a model ID and exposes it through the
- * {@link ClassifierBaseModel} API.
- *
- * @param <MODEL> the concrete Spark classification model type being wrapped
- */
-public class BaseSparkClassificationModel<MODEL extends ClassificationModel> extends ClassifierBaseModel {
-
-  /** ID returned by {@link #getId()}. */
-  private final String modelId;
-
-  /** Wrapped Spark model that performs the prediction. */
-  private final MODEL sparkModel;
-
-  /**
-   * Stores the given ID and Spark model.
-   *
-   * @param modelId the model id
-   * @param model   the Spark model to delegate predictions to
-   */
-  public BaseSparkClassificationModel(String modelId, MODEL model) {
-    this.sparkModel = model;
-    this.modelId = modelId;
-  }
-
-  /**
-   * Converts the arguments with {@code getFeatureVector}, wraps them in a dense vector and asks the Spark
-   * model for a prediction.
-   *
-   * @param args the feature values
-   * @return the value predicted by the wrapped Spark model
-   */
-  @Override
-  public Double predict(Object... args) {
-    double[] features = getFeatureVector(args);
-    return sparkModel.predict(Vectors.dense(features));
-  }
-
-  /**
-   * @return the model ID supplied at construction
-   */
-  @Override
-  public String getId() {
-    return modelId;
-  }
-
-}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/models/DecisionTreeClassificationModel.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/models/DecisionTreeClassificationModel.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/models/DecisionTreeClassificationModel.java
deleted file mode 100644
index 0460024..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/models/DecisionTreeClassificationModel.java
+++ /dev/null
@@ -1,35 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml.spark.models;
-
-/**
- * BaseSparkClassificationModel specialised for SparkDecisionTreeModel. Adds no behaviour of its own; it only
- * fixes the generic type and forwards construction to the base class.
- */
-public class DecisionTreeClassificationModel extends BaseSparkClassificationModel<SparkDecisionTreeModel> {
-
- /**
- * Instantiates a new decision tree classification model by passing both arguments to the base constructor.
- *
- * @param modelId the model id
- * @param model the model
- */
- public DecisionTreeClassificationModel(String modelId, SparkDecisionTreeModel model) {
- super(modelId, model);
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/models/KMeansClusteringModel.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/models/KMeansClusteringModel.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/models/KMeansClusteringModel.java
deleted file mode 100644
index 959d9f4..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/models/KMeansClusteringModel.java
+++ /dev/null
@@ -1,67 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml.spark.models;
-
-import org.apache.lens.ml.MLModel;
-
-import org.apache.spark.mllib.clustering.KMeansModel;
-import org.apache.spark.mllib.linalg.Vectors;
-
-/**
- * Wraps a Spark MLlib {@link KMeansModel} together with a model ID and exposes it as an
- * {@code MLModel<Integer>}.
- */
-public class KMeansClusteringModel extends MLModel<Integer> {
-
-  /** Wrapped Spark k-means model that performs the prediction. */
-  private final KMeansModel model;
-
-  /** ID supplied at construction. */
-  private final String modelId;
-
-  /**
-   * Instantiates a new k means clustering model.
-   *
-   * @param modelId the model id
-   * @param model   the Spark model to delegate predictions to
-   */
-  public KMeansClusteringModel(String modelId, KMeansModel model) {
-    this.model = model;
-    this.modelId = modelId;
-  }
-
-  /**
-   * Converts the feature values to a dense vector and returns the wrapped model's prediction for it.
-   * <p>
-   * A {@code null} feature is treated as {@code 0}. Any {@link Number} is accepted (Integer, Long, Float,
-   * Double, ...), not only {@link Double}, so callers passing other numeric types no longer fail with a
-   * {@link ClassCastException}.
-   * </p>
-   *
-   * @param args the feature values, each a {@link Number} or {@code null}
-   * @return the prediction returned by the wrapped {@link KMeansModel}
-   * @throws ClassCastException if a non-null feature is not a {@link Number}
-   */
-  @Override
-  public Integer predict(Object... args) {
-    double[] features = new double[args.length];
-    for (int i = 0; i < args.length; i++) {
-      // Widen through Number so that any boxed numeric type is usable as a feature
-      features[i] = args[i] == null ? 0d : ((Number) args[i]).doubleValue();
-    }
-
-    return model.predict(Vectors.dense(features));
-  }
-}
[3/6] incubator-lens git commit: Lens-465 : Refactor ml packages.
(sharad)
Posted by sh...@apache.org.
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/LensMLImpl.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/LensMLImpl.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/LensMLImpl.java
new file mode 100644
index 0000000..f0c6e04
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/LensMLImpl.java
@@ -0,0 +1,744 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.impl;
+
+import java.io.IOException;
+import java.io.ObjectOutputStream;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Date;
+import java.util.List;
+import java.util.Map;
+import java.util.UUID;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.TimeUnit;
+
+import javax.ws.rs.client.Client;
+import javax.ws.rs.client.ClientBuilder;
+import javax.ws.rs.client.Entity;
+import javax.ws.rs.client.WebTarget;
+import javax.ws.rs.core.MediaType;
+
+import org.apache.lens.api.LensConf;
+import org.apache.lens.api.LensException;
+import org.apache.lens.api.LensSessionHandle;
+import org.apache.lens.api.query.LensQuery;
+import org.apache.lens.api.query.QueryHandle;
+import org.apache.lens.api.query.QueryStatus;
+import org.apache.lens.ml.algo.api.MLAlgo;
+import org.apache.lens.ml.algo.api.MLDriver;
+import org.apache.lens.ml.algo.api.MLModel;
+import org.apache.lens.ml.algo.spark.BaseSparkAlgo;
+import org.apache.lens.ml.algo.spark.SparkMLDriver;
+import org.apache.lens.ml.api.LensML;
+import org.apache.lens.ml.api.MLTestReport;
+import org.apache.lens.server.api.LensConfConstants;
+import org.apache.lens.server.api.session.SessionService;
+
+import org.apache.commons.io.IOUtils;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.fs.FileStatus;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.hive.conf.HiveConf;
+import org.apache.hadoop.hive.ql.session.SessionState;
+import org.apache.spark.api.java.JavaSparkContext;
+
+import org.glassfish.jersey.media.multipart.FormDataBodyPart;
+import org.glassfish.jersey.media.multipart.FormDataContentDisposition;
+import org.glassfish.jersey.media.multipart.FormDataMultiPart;
+import org.glassfish.jersey.media.multipart.MultiPartFeature;
+
+/**
+ * The Class LensMLImpl.
+ */
+public class LensMLImpl implements LensML {
+
+ /** The Constant LOG. */
+ public static final Log LOG = LogFactory.getLog(LensMLImpl.class);
+
+ /** The drivers. */
+ protected List<MLDriver> drivers;
+
+ /** The conf. */
+ private HiveConf conf;
+
+ /** The spark context. */
+ private JavaSparkContext sparkContext;
+
+ /** Check if the predict UDF has been registered for a user */
+ private final Map<LensSessionHandle, Boolean> predictUdfStatus;
+ /** Background thread to periodically check if we need to clear expire status for a session */
+ private ScheduledExecutorService udfStatusExpirySvc;
+
+  /**
+   * Creates the ML service around the given configuration and an empty concurrent map for the per-session
+   * predict-UDF status.
+   *
+   * @param conf the Hive configuration kept by this instance
+   */
+  public LensMLImpl(HiveConf conf) {
+    this.predictUdfStatus = new ConcurrentHashMap<LensSessionHandle, Boolean>();
+    this.conf = conf;
+  }
+
+ /**
+ * Returns the Hive configuration currently held by this instance.
+ *
+ * @return the configuration given to the constructor, or the one passed to init(HiveConf) if that ran later
+ */
+ public HiveConf getConf() {
+ return conf;
+ }
+
+ /**
+ * Stores an existing Spark context for later use. When one has been set, start() hands it to every
+ * SparkMLDriver through useSparkContext before starting that driver.
+ *
+ * @param jsc JavaSparkContext instance
+ */
+ public void setSparkContext(JavaSparkContext jsc) {
+ this.sparkContext = jsc;
+ }
+
+  /**
+   * Collects the algorithm names reported by every loaded driver, in driver order.
+   *
+   * @return a new list holding the names from each driver's {@code getAlgoNames()}
+   */
+  public List<String> getAlgorithms() {
+    List<String> names = new ArrayList<String>();
+    for (MLDriver mlDriver : drivers) {
+      names.addAll(mlDriver.getAlgoNames());
+    }
+    return names;
+  }
+
+  /**
+   * Returns an algorithm instance from the first loaded driver that supports the given name.
+   *
+   * @param algorithm the algorithm name
+   * @return the instance produced by the supporting driver
+   * @throws LensException if no loaded driver supports the algorithm
+   * @see org.apache.lens.ml.LensML#getAlgoForName(java.lang.String)
+   */
+  public MLAlgo getAlgoForName(String algorithm) throws LensException {
+    for (MLDriver candidate : drivers) {
+      if (!candidate.isAlgoSupported(algorithm)) {
+        continue;
+      }
+      return candidate.getAlgoInstance(algorithm);
+    }
+    throw new LensException("Algo not supported " + algorithm);
+  }
+
+  /**
+   * Trains a model with the named algorithm on the given table and saves it under the algorithm's model
+   * directory.
+   * <p>
+   * The model ID is a freshly generated UUID. The database is taken from the current Hive
+   * {@link SessionState} when one exists, otherwise {@code "default"} is used.
+   * </p>
+   *
+   * @param table     the training table
+   * @param algorithm the algorithm name
+   * @param args      the training parameters, passed through to the algorithm
+   * @return the ID of the trained model
+   * @throws LensException if the algorithm is unknown, training fails, or the model cannot be saved
+   * @see org.apache.lens.ml.LensML#train(java.lang.String, java.lang.String, java.lang.String[])
+   */
+  public String train(String table, String algorithm, String[] args) throws LensException {
+    MLAlgo algo = getAlgoForName(algorithm);
+    String modelId = UUID.randomUUID().toString();
+
+    LOG.info("Begin training model " + modelId + ", algo=" + algorithm + ", table=" + table + ", params="
+      + Arrays.toString(args));
+
+    String database = SessionState.get() != null ? SessionState.get().getCurrentDatabase() : "default";
+
+    MLModel model = algo.train(toLensConf(conf), database, table, modelId, args);
+    LOG.info("Done training model: " + modelId);
+
+    model.setCreatedAt(new Date());
+    model.setAlgoName(algorithm);
+
+    try {
+      Path savedAt = persistModel(model);
+      LOG.info("Model saved: " + modelId + ", algo: " + algorithm + ", path: " + savedAt);
+      return model.getId();
+    } catch (IOException e) {
+      throw new LensException("Error saving model " + modelId + " for algo " + algorithm, e);
+    }
+  }
+
+  /**
+   * Resolves the directory that holds the models of one algorithm: the configured model base directory
+   * (or its default) with the algorithm name appended.
+   *
+   * @param algoName the algo name
+   * @return the algo dir
+   * @throws IOException declared for callers; nothing in this method body throws it
+   */
+  private Path getAlgoDir(String algoName) throws IOException {
+    Path baseDir = new Path(conf.get(ModelLoader.MODEL_PATH_BASE_DIR, ModelLoader.MODEL_PATH_BASE_DIR_DEFAULT));
+    return new Path(baseDir, algoName);
+  }
+
+  /**
+   * Serializes the model into a new file named after its ID inside the algorithm's model directory,
+   * creating that directory first if it does not exist.
+   * <p>
+   * The stream is closed inside the {@code try} block so that a failure while closing the file is reported
+   * to the caller instead of being swallowed; a model is only reported as saved once the close succeeded.
+   * </p>
+   *
+   * @param model the model
+   * @return the path the model was written to
+   * @throws IOException if the file cannot be created (it is opened with overwrite disabled), written or
+   *                     closed
+   */
+  private Path persistModel(MLModel model) throws IOException {
+    // Get model save path
+    Path algoDir = getAlgoDir(model.getAlgoName());
+    FileSystem fs = algoDir.getFileSystem(conf);
+
+    if (!fs.exists(algoDir)) {
+      fs.mkdirs(algoDir);
+    }
+
+    Path modelSavePath = new Path(algoDir, model.getId());
+    ObjectOutputStream outputStream = null;
+
+    try {
+      outputStream = new ObjectOutputStream(fs.create(modelSavePath, false));
+      outputStream.writeObject(model);
+      outputStream.flush();
+      // Explicit close: an error while finalising the file must fail the save
+      outputStream.close();
+      // Already closed successfully, nothing left for the finally block to do
+      outputStream = null;
+    } catch (IOException io) {
+      LOG.error("Error saving model " + model.getId() + " reason: " + io.getMessage(), io);
+      throw io;
+    } finally {
+      // Only reached with a live stream on the failure path, where the original error is already propagating
+      IOUtils.closeQuietly(outputStream);
+    }
+    return modelSavePath;
+  }
+
+  /**
+   * Lists the IDs of the models saved for an algorithm, taken from the file names in its model directory.
+   *
+   * @param algorithm the algorithm name
+   * @return the model IDs, or {@code null} when the directory does not exist or contains no entries
+   * @throws LensException wrapping any {@link IOException} raised while reading the file system
+   * @see org.apache.lens.ml.LensML#getModels(java.lang.String)
+   */
+  public List<String> getModels(String algorithm) throws LensException {
+    try {
+      Path algoDir = getAlgoDir(algorithm);
+      FileSystem fs = algoDir.getFileSystem(conf);
+      if (!fs.exists(algoDir)) {
+        return null;
+      }
+
+      List<String> modelIds = new ArrayList<String>();
+      for (FileStatus entry : fs.listStatus(algoDir)) {
+        modelIds.add(entry.getPath().getName());
+      }
+
+      return modelIds.isEmpty() ? null : modelIds;
+    } catch (IOException ioex) {
+      throw new LensException(ioex);
+    }
+  }
+
+ /*
+ * (non-Javadoc)
+ *
+ * Loads the model through ModelLoader.loadModel using this instance's configuration. An IOException from
+ * the loader is rethrown wrapped in a LensException.
+ *
+ * @see org.apache.lens.ml.LensML#getModel(java.lang.String, java.lang.String)
+ */
+ public MLModel getModel(String algorithm, String modelId) throws LensException {
+ try {
+ return ModelLoader.loadModel(conf, algorithm, modelId);
+ } catch (IOException e) {
+ throw new LensException(e);
+ }
+ }
+
+  /**
+   * Initialises the service: adopts the given configuration and instantiates the ML drivers named by the
+   * {@code lens.ml.drivers} property.
+   * <p>
+   * A driver class that cannot be found, is not an {@link MLDriver}, or fails to instantiate or initialise
+   * is logged and skipped.
+   * </p>
+   *
+   * @param hiveConf the hive conf
+   * @throws RuntimeException if no driver classes are configured, or none of them could be loaded
+   */
+  public synchronized void init(HiveConf hiveConf) {
+    this.conf = hiveConf;
+
+    // Get all the drivers
+    String[] driverClasses = hiveConf.getStrings("lens.ml.drivers");
+    if (driverClasses == null || driverClasses.length == 0) {
+      throw new RuntimeException("No ML Drivers specified in conf");
+    }
+
+    LOG.info("Loading drivers " + Arrays.toString(driverClasses));
+    drivers = new ArrayList<MLDriver>(driverClasses.length);
+
+    for (String className : driverClasses) {
+      Class<?> loaded;
+      try {
+        loaded = Class.forName(className);
+      } catch (ClassNotFoundException e) {
+        LOG.error("Driver class not found " + className);
+        continue;
+      }
+
+      if (!MLDriver.class.isAssignableFrom(loaded)) {
+        LOG.warn("Not a driver class " + className);
+        continue;
+      }
+
+      try {
+        // Assignability was verified above, so asSubclass cannot fail here
+        MLDriver driver = loaded.asSubclass(MLDriver.class).newInstance();
+        driver.init(toLensConf(conf));
+        drivers.add(driver);
+        LOG.info("Added driver " + className);
+      } catch (Exception e) {
+        LOG.error("Failed to create driver " + className + " reason: " + e.getMessage(), e);
+      }
+    }
+
+    if (drivers.isEmpty()) {
+      throw new RuntimeException("No ML drivers loaded");
+    }
+
+    LOG.info("Inited ML service");
+  }
+
+  /**
+   * Starts every loaded driver and then schedules the UDF status expiry task to run every 60 seconds,
+   * beginning 60 seconds from now.
+   * <p>
+   * If a Spark context was supplied through {@link #setSparkContext}, it is handed to each
+   * {@link SparkMLDriver} before that driver is started. A driver that fails to start is logged and the
+   * remaining drivers are still started.
+   * </p>
+   */
+  public synchronized void start() {
+    for (MLDriver mlDriver : drivers) {
+      try {
+        boolean shareContext = sparkContext != null && mlDriver instanceof SparkMLDriver;
+        if (shareContext) {
+          ((SparkMLDriver) mlDriver).useSparkContext(sparkContext);
+        }
+        mlDriver.start();
+      } catch (LensException e) {
+        LOG.error("Failed to start driver " + mlDriver, e);
+      }
+    }
+
+    udfStatusExpirySvc = Executors.newSingleThreadScheduledExecutor();
+    udfStatusExpirySvc.scheduleAtFixedRate(new UDFStatusExpiryRunnable(), 60, 60, TimeUnit.SECONDS);
+
+    LOG.info("Started ML service");
+  }
+
+  /**
+   * Stops every loaded driver, forgets them, and shuts down the UDF status expiry task.
+   * <p>
+   * Safe to call when {@link #init} or {@link #start} has not run (for example while cleaning up after a
+   * failed startup): a missing driver list or expiry executor is simply skipped rather than causing a
+   * {@link NullPointerException}. A driver that fails to stop is logged and the remaining drivers are
+   * still stopped.
+   * </p>
+   */
+  public synchronized void stop() {
+    // drivers is only assigned by init()
+    if (drivers != null) {
+      for (MLDriver driver : drivers) {
+        try {
+          driver.stop();
+        } catch (LensException e) {
+          LOG.error("Failed to stop driver " + driver, e);
+        }
+      }
+      drivers.clear();
+    }
+    // The executor is only created by start()
+    if (udfStatusExpirySvc != null) {
+      udfStatusExpirySvc.shutdownNow();
+    }
+    LOG.info("Stopped ML service");
+  }
+
+ /**
+ * Returns the Hive configuration held by this instance. Returns the same field as getConf(), but this
+ * method is synchronized on the instance, as init(HiveConf) is.
+ *
+ * @return the current Hive configuration
+ */
+ public synchronized HiveConf getHiveConf() {
+ return conf;
+ }
+
+ /**
+ * Clear models. Delegates to ModelLoader.clearCache().
+ */
+ public void clearModels() {
+ ModelLoader.clearCache();
+ }
+
+ /*
+ * (non-Javadoc)
+ *
+ * Returns the string form of the location that ModelLoader.getModelLocation computes for the given
+ * algorithm and model ID under this instance's configuration.
+ *
+ * @see org.apache.lens.ml.LensML#getModelPath(java.lang.String, java.lang.String)
+ */
+ public String getModelPath(String algorithm, String modelID) {
+ return ModelLoader.getModelLocation(conf, algorithm, modelID).toString();
+ }
+
+ /*
+ * (non-Javadoc)
+ *
+ * Stub: this overload ignores all of its arguments and always returns null. The overload that takes a
+ * QueryRunner performs the actual evaluation.
+ *
+ * @see org.apache.lens.ml.LensML#testModel(org.apache.lens.api.LensSessionHandle, java.lang.String, java.lang.String,
+ * java.lang.String, java.lang.String)
+ */
+ @Override
+ public MLTestReport testModel(LensSessionHandle session, String table, String algorithm, String modelID,
+ String outputTable) throws LensException {
+ return null;
+ }
+
+  /**
+   * Tests a model by running the evaluation queries through a {@code RemoteQueryRunner} built from the
+   * session handle and the query API URL.
+   *
+   * @param sessionHandle the session handle
+   * @param table         the table
+   * @param algorithm     the algorithm
+   * @param modelID       the model id
+   * @param queryApiUrl   the query api url
+   * @param outputTable   table where test output will be written
+   * @return the ML test report
+   * @throws LensException the lens exception
+   */
+  public MLTestReport testModelRemote(LensSessionHandle sessionHandle, String table, String algorithm, String modelID,
+    String queryApiUrl, String outputTable) throws LensException {
+    QueryRunner remoteRunner = new RemoteQueryRunner(sessionHandle, queryApiUrl);
+    return testModel(sessionHandle, table, algorithm, modelID, remoteRunner, outputTable);
+  }
+
+  /**
+   * Evaluates a stored model against an input table and records the run in a test report.
+   * <p>
+   * The method validates the session and algorithm, loads the model, and builds a {@link TableTestingSpec}
+   * describing the evaluation. If the spec reports that the output table is missing, the table is created
+   * through the query runner. The predict UDF is then registered for the session, the test query is
+   * submitted under the name {@code model_test_<modelID>}, and an {@link MLTestReport} carrying a freshly
+   * generated test ID is persisted and returned.
+   * </p>
+   * <p>
+   * The database is taken from the current Hive {@link SessionState} when it provides one, otherwise
+   * {@code "default"} is used. All queries are executed by the supplied query runner.
+   * </p>
+   *
+   * @param sessionHandle the session handle
+   * @param table         the table
+   * @param algorithm     the algorithm
+   * @param modelID       the model id
+   * @param queryRunner   the query runner
+   * @param outputTable   table where test output will be written
+   * @return the ML test report
+   * @throws LensException        if the algorithm or model is unknown, the model cannot be loaded, or the
+   *                              spec yields no test query
+   * @throws NullPointerException if {@code sessionHandle} is null
+   */
+  public MLTestReport testModel(final LensSessionHandle sessionHandle, String table, String algorithm, String modelID,
+    QueryRunner queryRunner, String outputTable) throws LensException {
+    if (sessionHandle == null) {
+      throw new NullPointerException("Null session not allowed");
+    }
+    // check if algorithm exists
+    if (!getAlgorithms().contains(algorithm)) {
+      throw new LensException("No such algorithm " + algorithm);
+    }
+
+    MLModel<?> model;
+    try {
+      model = ModelLoader.loadModel(conf, algorithm, modelID);
+    } catch (IOException e) {
+      throw new LensException(e);
+    }
+    if (model == null) {
+      throw new LensException("Model not found: " + modelID + " algorithm=" + algorithm);
+    }
+
+    // Current session database, falling back to "default" when there is no session or no database
+    String sessionDb = SessionState.get() != null ? SessionState.get().getCurrentDatabase() : null;
+    String database = sessionDb != null ? sessionDb : "default";
+
+    String testID = UUID.randomUUID().toString().replace("-", "_");
+    final String resultColumn = "prediction_result";
+
+    // TODO support error metric UDAFs
+    TableTestingSpec spec = TableTestingSpec.newBuilder()
+      .hiveConf(conf)
+      .database(database)
+      .inputTable(table)
+      .featureColumns(model.getFeatureColumns())
+      .outputColumn(resultColumn)
+      .lableColumn(model.getLabelColumn())
+      .algorithm(algorithm)
+      .modelID(modelID)
+      .outputTable(outputTable)
+      .testID(testID)
+      .build();
+
+    String testQuery = spec.getTestQuery();
+    if (testQuery == null) {
+      throw new LensException("Invalid test spec. " + "table=" + table + " algorithm=" + algorithm + " modelID="
+        + modelID);
+    }
+
+    if (!spec.isOutputTableExists()) {
+      LOG.info("Output table '" + outputTable + "' does not exist for test algorithm = " + algorithm + " modelid="
+        + modelID + ", Creating table using query: " + spec.getCreateOutputTableQuery());
+      // create the output table
+      queryRunner.runQuery(spec.getCreateOutputTableQuery());
+      LOG.info("Table created " + outputTable);
+    }
+
+    // Check if ML UDF is registered in this session
+    registerPredictUdf(sessionHandle, queryRunner);
+
+    LOG.info("Running evaluation query " + testQuery);
+    queryRunner.setQueryName("model_test_" + modelID);
+    QueryHandle evaluationHandle = queryRunner.runQuery(testQuery);
+
+    MLTestReport report = new MLTestReport();
+    report.setReportID(testID);
+    report.setAlgorithm(algorithm);
+    report.setFeatureColumns(model.getFeatureColumns());
+    report.setLabelColumn(model.getLabelColumn());
+    report.setModelID(model.getId());
+    report.setOutputColumn(resultColumn);
+    report.setOutputTable(outputTable);
+    report.setTestTable(table);
+    report.setQueryID(evaluationHandle.toString());
+
+    // Save test report
+    persistTestReport(report);
+    LOG.info("Saved test report " + report.getReportID());
+    return report;
+  }
+
+ /**
+  * Persists a test report through {@link ModelLoader#saveTestReport}.
+  *
+  * Saving is best-effort: an {@link IOException} raised while writing the report is logged, together with
+  * its stack trace, and is not propagated to the caller.
+  *
+  * @param testReport the test report to save
+  * @throws LensException declared for callers; the I/O failure path handled here does not throw it
+  */
+ private void persistTestReport(MLTestReport testReport) throws LensException {
+   LOG.info("saving test report " + testReport.getReportID());
+   try {
+     ModelLoader.saveTestReport(conf, testReport);
+     LOG.info("Saved report " + testReport.getReportID());
+   } catch (IOException e) {
+     // Hand the exception itself to the logger so the stack trace is not lost
+     LOG.error("Error saving report " + testReport.getReportID() + " reason: " + e.getMessage(), e);
+   }
+ }
+
+ /*
+ * (non-Javadoc)
+ *
+ * @see org.apache.lens.ml.LensML#getTestReports(java.lang.String)
+ */
+ public List<String> getTestReports(String algorithm) throws LensException {
+ Path reportBaseDir = new Path(conf.get(ModelLoader.TEST_REPORT_BASE_DIR, ModelLoader.TEST_REPORT_BASE_DIR_DEFAULT));
+ FileSystem fs = null;
+
+ try {
+ fs = reportBaseDir.getFileSystem(conf);
+ if (!fs.exists(reportBaseDir)) {
+ return null;
+ }
+
+ Path algoDir = new Path(reportBaseDir, algorithm);
+ if (!fs.exists(algoDir)) {
+ return null;
+ }
+
+ List<String> reports = new ArrayList<String>();
+ for (FileStatus stat : fs.listStatus(algoDir)) {
+ reports.add(stat.getPath().getName());
+ }
+ return reports;
+ } catch (IOException e) {
+ LOG.error("Error reading report list for " + algorithm, e);
+ return null;
+ }
+ }
+
+ /*
+ * (non-Javadoc)
+ *
+ * @see org.apache.lens.ml.LensML#getTestReport(java.lang.String, java.lang.String)
+ */
+ public MLTestReport getTestReport(String algorithm, String reportID) throws LensException {
+ try {
+ return ModelLoader.loadReport(conf, algorithm, reportID);
+ } catch (IOException e) {
+ throw new LensException(e);
+ }
+ }
+
+ /*
+ * (non-Javadoc)
+ *
+ * @see org.apache.lens.ml.LensML#predict(java.lang.String, java.lang.String, java.lang.Object[])
+ */
+ public Object predict(String algorithm, String modelID, Object[] features) throws LensException {
+ // Load the model instance
+ MLModel<?> model = getModel(algorithm, modelID);
+ return model.predict(features);
+ }
+
+ /*
+ * (non-Javadoc)
+ *
+ * @see org.apache.lens.ml.LensML#deleteModel(java.lang.String, java.lang.String)
+ */
+ public void deleteModel(String algorithm, String modelID) throws LensException {
+ try {
+ ModelLoader.deleteModel(conf, algorithm, modelID);
+ LOG.info("DELETED model " + modelID + " algorithm=" + algorithm);
+ } catch (IOException e) {
+ LOG.error(
+ "Error deleting model file. algorithm=" + algorithm + " model=" + modelID + " reason: " + e.getMessage(), e);
+ throw new LensException("Unable to delete model " + modelID + " for algorithm " + algorithm, e);
+ }
+ }
+
+ /*
+ * (non-Javadoc)
+ *
+ * @see org.apache.lens.ml.LensML#deleteTestReport(java.lang.String, java.lang.String)
+ */
+ public void deleteTestReport(String algorithm, String reportID) throws LensException {
+ try {
+ ModelLoader.deleteTestReport(conf, algorithm, reportID);
+ LOG.info("DELETED report=" + reportID + " algorithm=" + algorithm);
+ } catch (IOException e) {
+ LOG.error("Error deleting report " + reportID + " algorithm=" + algorithm + " reason: " + e.getMessage(), e);
+ throw new LensException("Unable to delete report " + reportID + " for algorithm " + algorithm, e);
+ }
+ }
+
+ /*
+ * (non-Javadoc)
+ *
+ * @see org.apache.lens.ml.LensML#getAlgoParamDescription(java.lang.String)
+ */
+ public Map<String, String> getAlgoParamDescription(String algorithm) {
+   // Resolve the algorithm first; a lookup failure is logged and reported as "no description"
+   final MLAlgo resolvedAlgo;
+   try {
+     resolvedAlgo = getAlgoForName(algorithm);
+   } catch (LensException e) {
+     LOG.error("Error getting algo description : " + algorithm, e);
+     return null;
+   }
+
+   // Only Spark based algorithms expose an argument usage map
+   if (!(resolvedAlgo instanceof BaseSparkAlgo)) {
+     return null;
+   }
+   return ((BaseSparkAlgo) resolvedAlgo).getArgUsage();
+ }
+
+ /**
+  * Submit model test query to a remote Lens server.
+  *
+  * The query is posted to the query endpoint with operation "execute"; the runner then polls the query
+  * status until it is finished and fails unless the final status is SUCCESSFUL.
+  */
+ class RemoteQueryRunner extends QueryRunner {
+
+   /** Pause between two status polls, in milliseconds. */
+   private static final long POLL_INTERVAL_MILLIS = 500L;
+
+   /** The query api url. */
+   final String queryApiUrl;
+
+   /**
+    * Instantiates a new remote query runner.
+    *
+    * @param sessionHandle the session handle
+    * @param queryApiUrl   the query api url
+    */
+   public RemoteQueryRunner(LensSessionHandle sessionHandle, String queryApiUrl) {
+     super(sessionHandle);
+     this.queryApiUrl = queryApiUrl;
+   }
+
+   /**
+    * Submits the query and blocks until it has finished.
+    *
+    * @param query the query to execute
+    * @return the handle of the finished query
+    * @throws LensException if the query does not end successfully or the waiting thread is interrupted
+    * @see org.apache.lens.ml.impl.QueryRunner#runQuery(java.lang.String)
+    */
+   @Override
+   public QueryHandle runQuery(String query) throws LensException {
+     // Create jersey client for query endpoint
+     Client client = ClientBuilder.newBuilder().register(MultiPartFeature.class).build();
+     try {
+       WebTarget target = client.target(queryApiUrl);
+       final FormDataMultiPart mp = new FormDataMultiPart();
+       mp.bodyPart(new FormDataBodyPart(FormDataContentDisposition.name("sessionid").build(), sessionHandle,
+         MediaType.APPLICATION_XML_TYPE));
+       mp.bodyPart(new FormDataBodyPart(FormDataContentDisposition.name("query").build(), query));
+       mp.bodyPart(new FormDataBodyPart(FormDataContentDisposition.name("operation").build(), "execute"));
+
+       LensConf lensConf = new LensConf();
+       lensConf.addProperty(LensConfConstants.QUERY_PERSISTENT_RESULT_SET, false + "");
+       lensConf.addProperty(LensConfConstants.QUERY_PERSISTENT_RESULT_INDRIVER, false + "");
+       mp.bodyPart(new FormDataBodyPart(FormDataContentDisposition.name("conf").fileName("conf").build(), lensConf,
+         MediaType.APPLICATION_XML_TYPE));
+
+       final QueryHandle handle = target.request().post(Entity.entity(mp, MediaType.MULTIPART_FORM_DATA_TYPE),
+         QueryHandle.class);
+
+       LensQuery ctx = fetchQuery(target, handle);
+       QueryStatus stat = ctx.getStatus();
+       while (!stat.isFinished()) {
+         // Wait before asking again, so that no request is sent right after the previous one
+         try {
+           Thread.sleep(POLL_INTERVAL_MILLIS);
+         } catch (InterruptedException e) {
+           // Keep the interrupt visible to the code that owns this thread
+           Thread.currentThread().interrupt();
+           throw new LensException(e);
+         }
+         ctx = fetchQuery(target, handle);
+         stat = ctx.getStatus();
+       }
+
+       if (stat.getStatus() != QueryStatus.Status.SUCCESSFUL) {
+         throw new LensException("Query failed " + ctx.getQueryHandle().getHandleId() + " reason:"
+           + stat.getErrorMessage());
+       }
+
+       return ctx.getQueryHandle();
+     } finally {
+       // Release the resources held by the client on every exit path
+       client.close();
+     }
+   }
+
+   /**
+    * Fetches the current state of a submitted query.
+    *
+    * @param target the query endpoint
+    * @param handle the handle of the submitted query
+    * @return the query as currently reported by the server
+    */
+   private LensQuery fetchQuery(WebTarget target, QueryHandle handle) {
+     return target.path(handle.toString()).queryParam("sessionid", sessionHandle).request().get(LensQuery.class);
+   }
+ }
+
+ /**
+ * To lens conf.
+ *
+ * @param conf the conf
+ * @return the lens conf
+ */
+ private LensConf toLensConf(HiveConf conf) {
+ LensConf lensConf = new LensConf();
+ lensConf.getProperties().putAll(conf.getValByRegex(".*"));
+ return lensConf;
+ }
+
+ /**
+  * Registers the predict UDF for the given session unless it is already recorded as registered.
+  *
+  * Jars listed under the "lens.server.ml.predict.udf.jars" setting are added to the session first, then a
+  * CREATE TEMPORARY FUNCTION query is run through the given query runner and the session is recorded in
+  * the UDF status map.
+  *
+  * @param sessionHandle the session in which the UDF is needed
+  * @param queryRunner   the runner used to execute the registration query
+  * @throws LensException if adding the jars fails or the registration query fails
+  */
+ protected void registerPredictUdf(LensSessionHandle sessionHandle, QueryRunner queryRunner) throws LensException {
+   if (isUdfRegisterd(sessionHandle)) {
+     // Already registered, nothing to do
+     return;
+   }
+
+   LOG.info("Registering UDF for session " + sessionHandle.getPublicId().toString());
+   // We have to add UDF jars to the session
+   try {
+     SessionService sessionService = (SessionService) MLUtils.getServiceProvider().getService(SessionService.NAME);
+     String[] udfJars = conf.getStrings("lens.server.ml.predict.udf.jars");
+     if (udfJars != null) {
+       for (String jar : udfJars) {
+         sessionService.addResource(sessionHandle, "jar", jar);
+         LOG.info(jar + " added UDF session " + sessionHandle.getPublicId().toString());
+       }
+     }
+   } catch (Exception e) {
+     throw new LensException(e);
+   }
+
+   String regUdfQuery = "CREATE TEMPORARY FUNCTION " + HiveMLUDF.UDF_NAME + " AS '" + HiveMLUDF.class
+     .getCanonicalName() + "'";
+   queryRunner.setQueryName("register_predict_udf_" + sessionHandle.getPublicId().toString());
+   // The returned query handle is not needed; a failure surfaces as an exception
+   queryRunner.runQuery(regUdfQuery);
+   predictUdfStatus.put(sessionHandle, true);
+   LOG.info("Predict UDF registered for session " + sessionHandle.getPublicId().toString());
+ }
+
+ protected boolean isUdfRegisterd(LensSessionHandle sessionHandle) {
+ return predictUdfStatus.containsKey(sessionHandle);
+ }
+
+ /**
+  * Drops the recorded UDF registration status of every session that is no longer open.
+  */
+ private class UDFStatusExpiryRunnable implements Runnable {
+   public void run() {
+     try {
+       SessionService sessions = (SessionService) MLUtils.getServiceProvider().getService(SessionService.NAME);
+       // Iterate over a snapshot of the keys because entries are removed while looping
+       List<LensSessionHandle> trackedHandles = new ArrayList<LensSessionHandle>(predictUdfStatus.keySet());
+       for (LensSessionHandle handle : trackedHandles) {
+         if (sessions.isOpen(handle)) {
+           continue;
+         }
+         LOG.info("Session closed, removing UDF status: " + handle);
+         predictUdfStatus.remove(handle);
+       }
+     } catch (Exception exc) {
+       LOG.warn("Error clearing UDF statuses", exc);
+     }
+   }
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/MLRunner.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/MLRunner.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/MLRunner.java
new file mode 100644
index 0000000..625d020
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/MLRunner.java
@@ -0,0 +1,172 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.impl;
+
+import java.io.File;
+import java.io.FileInputStream;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Properties;
+
+import org.apache.lens.client.LensClient;
+import org.apache.lens.client.LensClientConfig;
+import org.apache.lens.client.LensMLClient;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.hive.conf.HiveConf;
+import org.apache.hadoop.hive.metastore.TableType;
+import org.apache.hadoop.hive.metastore.api.FieldSchema;
+import org.apache.hadoop.hive.ql.metadata.Hive;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.metadata.Table;
+import org.apache.hadoop.hive.ql.plan.AddPartitionDesc;
+import org.apache.hadoop.hive.serde.serdeConstants;
+import org.apache.hadoop.mapred.TextInputFormat;
+
+public class MLRunner {
+
+ private static final Log LOG = LogFactory.getLog(MLRunner.class);
+
+ private LensMLClient mlClient;
+ private String algoName;
+ private String database;
+ private String trainTable;
+ private String trainFile;
+ private String testTable;
+ private String testFile;
+ private String outputTable;
+ private String[] features;
+ private String labelColumn;
+ private HiveConf conf;
+
+ /**
+  * Initializes the runner from a configuration directory.
+  *
+  * The directory must contain an "ml.properties" file; "train.data" and "test.data" inside the same
+  * directory are used as the training and test data files.
+  *
+  * @param mlClient the ML client to run against
+  * @param confDir  the directory holding ml.properties and the data files
+  * @throws Exception if ml.properties cannot be read or does not define the "features" property
+  */
+ public void init(LensMLClient mlClient, String confDir) throws Exception {
+   File dir = new File(confDir);
+   File propFile = new File(dir, "ml.properties");
+   Properties props = new Properties();
+   // Close the stream as soon as the properties are loaded, also when loading fails
+   FileInputStream propStream = new FileInputStream(propFile);
+   try {
+     props.load(propStream);
+   } finally {
+     propStream.close();
+   }
+
+   String feat = props.getProperty("features");
+   if (feat == null) {
+     // Fail with a clear message instead of a NullPointerException on split()
+     throw new IllegalArgumentException("Property 'features' is not set in " + propFile);
+   }
+   String trainFile = confDir + File.separator + "train.data";
+   String testFile = confDir + File.separator + "test.data";
+   init(mlClient, props.getProperty("algo"), props.getProperty("database"),
+     props.getProperty("traintable"), trainFile,
+     props.getProperty("testtable"), testFile,
+     props.getProperty("outputtable"), feat.split(","),
+     props.getProperty("labelcolumn"));
+ }
+
+ public void init(LensMLClient mlClient, String algoName,
+ String database, String trainTable, String trainFile,
+ String testTable, String testFile, String outputTable, String[] features,
+ String labelColumn) {
+ this.mlClient = mlClient;
+ this.algoName = algoName;
+ this.database = database;
+ this.trainTable = trainTable;
+ this.trainFile = trainFile;
+ this.testTable = testTable;
+ this.testFile = testFile;
+ this.outputTable = outputTable;
+ this.features = features;
+ this.labelColumn = labelColumn;
+ //hive metastore settings are loaded via lens-site.xml, so loading LensClientConfig
+ //is required
+ this.conf = new HiveConf(new LensClientConfig(), MLRunner.class);
+ }
+
+ public MLTask train() throws Exception {
+ LOG.info("Starting train & eval");
+
+ createTable(trainTable, trainFile);
+ createTable(testTable, testFile);
+ MLTask.Builder taskBuilder = new MLTask.Builder();
+ taskBuilder.algorithm(algoName).hiveConf(conf).labelColumn(labelColumn).outputTable(outputTable)
+ .client(mlClient).trainingTable(trainTable).testTable(testTable);
+
+ // Add features
+ for (String feature : features) {
+ taskBuilder.addFeatureColumn(feature);
+ }
+ MLTask task = taskBuilder.build();
+ LOG.info("Created task " + task.toString());
+ task.run();
+ return task;
+ }
+
+ /**
+  * (Re)creates a table for a data file and registers the file's parent directory as its only partition.
+  *
+  * The table gets one "double" column per feature, preceded by the label column when one is configured,
+  * plus a single string partition column named "dummy_partition_col". A table with the same name is
+  * dropped before the new one is created.
+  *
+  * @param tableName the name of the table to create in the configured database
+  * @param dataFile  the local path of the data file; its parent directory becomes the partition location
+  * @throws HiveException if one of the metastore operations fails
+  */
+ public void createTable(String tableName, String dataFile) throws HiveException {
+
+ File filedataFile = new File(dataFile);
+ Path dataFilePath = new Path(filedataFile.toURI());
+ // The partition points at the directory containing the data file, not at the file itself
+ Path partDir = dataFilePath.getParent();
+
+ // Create table
+ List<FieldSchema> columns = new ArrayList<FieldSchema>();
+
+ // Label is optional. Not used for unsupervised models.
+ // If present, label will be the first column, followed by features
+ if (labelColumn != null) {
+ columns.add(new FieldSchema(labelColumn, "double", "Labelled Column"));
+ }
+
+ for (String feature : features) {
+ columns.add(new FieldSchema(feature, "double", "Feature " + feature));
+ }
+
+ Table tbl = Hive.get(conf).newTable(database + "." + tableName);
+ tbl.setTableType(TableType.MANAGED_TABLE);
+ tbl.getTTable().getSd().setCols(columns);
+ // tbl.getTTable().getParameters().putAll(new HashMap<String, String>());
+ // Plain text input: one row per line, fields separated by a single space
+ tbl.setInputFormatClass(TextInputFormat.class);
+ tbl.setSerdeParam(serdeConstants.LINE_DELIM, "\n");
+ tbl.setSerdeParam(serdeConstants.FIELD_DELIM, " ");
+
+ List<FieldSchema> partCols = new ArrayList<FieldSchema>(1);
+ partCols.add(new FieldSchema("dummy_partition_col", "string", ""));
+ tbl.setPartCols(partCols);
+
+ // Drop a table left over from an earlier run, then create the new one.
+ // NOTE(review): the boolean flags presumably mean "keep data" / "ignore missing table" - confirm against Hive API
+ Hive.get(conf).dropTable(database, tableName, false, true);
+ Hive.get(conf).createTable(tbl, true);
+ LOG.info("Created table " + tableName);
+
+ // Add partition for the data file
+ AddPartitionDesc partitionDesc = new AddPartitionDesc(database, tableName,
+ false);
+ Map<String, String> partSpec = new HashMap<String, String>();
+ partSpec.put("dummy_partition_col", "dummy_val");
+ partitionDesc.addPartition(partSpec, partDir.toUri().toString());
+ Hive.get(conf).createPartitions(partitionDesc);
+ LOG.info(tableName + ": Added partition " + partDir.toUri().toString());
+ }
+
+ public static void main(String[] args) throws Exception {
+ if (args.length < 1) {
+ System.out.println("Usage: " + MLRunner.class.getName() + " <ml-conf-dir>");
+ System.exit(-1);
+ }
+ String confDir = args[0];
+ LensMLClient client = new LensMLClient(new LensClient());
+ MLRunner runner = new MLRunner();
+ runner.init(client, confDir);
+ runner.train();
+ System.out.println("Created the Model successfully. Output Table: " + runner.outputTable);
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/MLTask.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/MLTask.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/MLTask.java
new file mode 100644
index 0000000..2867b90
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/MLTask.java
@@ -0,0 +1,285 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.impl;
+
+import java.util.*;
+
+import org.apache.lens.client.LensMLClient;
+import org.apache.lens.ml.api.LensML;
+import org.apache.lens.ml.api.MLTestReport;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.hive.conf.HiveConf;
+
+import lombok.Getter;
+import lombok.ToString;
+
+/**
+ * Run a complete cycle of train and test (evaluation) for an ML algorithm
+ */
+@ToString
+public class MLTask implements Runnable {
+ private static final Log LOG = LogFactory.getLog(MLTask.class);
+
+ public enum State {
+ RUNNING, SUCCESSFUL, FAILED
+ }
+
+ @Getter
+ private State taskState;
+
+ /**
+ * Name of the algo/algorithm.
+ */
+ @Getter
+ private String algorithm;
+
+ /**
+ * Name of the table containing training data.
+ */
+ @Getter
+ private String trainingTable;
+
+ /**
+ * Name of the table containing test data. Optional, if not provided trainingTable itself is
+ * used for testing
+ */
+ @Getter
+ private String testTable;
+
+ /**
+ * Training table partition spec
+ */
+ @Getter
+ private String partitionSpec;
+
+ /**
+ * Name of the column which is a label for supervised algorithms.
+ */
+ @Getter
+ private String labelColumn;
+
+ /**
+ * Names of columns which are features in the training data.
+ */
+ @Getter
+ private List<String> featureColumns;
+
+ /**
+ * Configuration for the example.
+ */
+ @Getter
+ private HiveConf configuration;
+
+ private LensML ml;
+ private String taskID;
+
+ /**
+ * ml client
+ */
+ @Getter
+ private LensMLClient mlClient;
+
+ /**
+ * Output table name
+ */
+ @Getter
+ private String outputTable;
+
+ /**
+ * Extra params passed to the training algorithm
+ */
+ @Getter
+ private Map<String, String> extraParams;
+
+ @Getter
+ private String modelID;
+
+ @Getter
+ private String reportID;
+
+ /**
+ * Use MLTask.Builder to create an instance
+ */
+ private MLTask() {
+ // Use builder to construct the example
+ extraParams = new HashMap<String, String>();
+ taskID = UUID.randomUUID().toString();
+ }
+
+ /**
+ * Builder to create an ML task
+ */
+ public static class Builder {
+ private MLTask task;
+
+ public Builder() {
+ task = new MLTask();
+ }
+
+ public Builder trainingTable(String trainingTable) {
+ task.trainingTable = trainingTable;
+ return this;
+ }
+
+ public Builder testTable(String testTable) {
+ task.testTable = testTable;
+ return this;
+ }
+
+ public Builder algorithm(String algorithm) {
+ task.algorithm = algorithm;
+ return this;
+ }
+
+ public Builder labelColumn(String labelColumn) {
+ task.labelColumn = labelColumn;
+ return this;
+ }
+
+ public Builder client(LensMLClient client) {
+ task.mlClient = client;
+ return this;
+ }
+
+ public Builder addFeatureColumn(String featureColumn) {
+ if (task.featureColumns == null) {
+ task.featureColumns = new ArrayList<String>();
+ }
+ task.featureColumns.add(featureColumn);
+ return this;
+ }
+
+ public Builder hiveConf(HiveConf hiveConf) {
+ task.configuration = hiveConf;
+ return this;
+ }
+
+
+
+ public Builder extraParam(String param, String value) {
+ task.extraParams.put(param, value);
+ return this;
+ }
+
+ public Builder partitionSpec(String partitionSpec) {
+ task.partitionSpec = partitionSpec;
+ return this;
+ }
+
+ public Builder outputTable(String outputTable) {
+ task.outputTable = outputTable;
+ return this;
+ }
+
+ public MLTask build() {
+ MLTask builtTask = task;
+ task = null;
+ return builtTask;
+ }
+
+ }
+
+ @Override
+ public void run() {
+ taskState = State.RUNNING;
+ LOG.info("Starting " + taskID);
+ try {
+ runTask();
+ taskState = State.SUCCESSFUL;
+ LOG.info("Complete " + taskID);
+ } catch (Exception e) {
+ taskState = State.FAILED;
+ LOG.info("Error running task " + taskID, e);
+ }
+ }
+
+ /**
+  * Train an ML model, with specified algorithm and input data. Do model evaluation using the evaluation data and print
+  * evaluation result
+  *
+  * When no test table is set, the training table is used for the evaluation as well. The evaluation needs the
+  * session handle of the ML client, so without a client the task fails after training with a descriptive error.
+  *
+  * @throws Exception if training or testing fails, or if no ML client is available for the test step
+  */
+ private void runTask() throws Exception {
+   if (mlClient != null) {
+     // Connect to a remote Lens server
+     ml = mlClient;
+     LOG.info("Working in client mode. Lens session handle " + mlClient.getSessionHandle().getPublicId());
+   } else {
+     // In server mode session handle has to be passed by the user as a request parameter
+     ml = MLUtils.getMLService();
+     LOG.info("Working in Lens server");
+   }
+
+   String[] algoArgs = buildTrainingArgs();
+   LOG.info("Starting task " + taskID + " algo args: " + Arrays.toString(algoArgs));
+
+   modelID = ml.train(trainingTable, algorithm, algoArgs);
+   printModelMetadata(taskID, modelID);
+
+   LOG.info("Starting test " + taskID);
+   testTable = (testTable != null) ? testTable : trainingTable;
+   if (mlClient == null) {
+     // The session handle is only available through the client; say so instead of failing with a bare NPE
+     throw new IllegalStateException("Cannot test model " + modelID + " of task " + taskID
+       + ": no ML client is set, so no Lens session handle is available");
+   }
+   MLTestReport testReport = ml.testModel(mlClient.getSessionHandle(), testTable, algorithm, modelID, outputTable);
+   reportID = testReport.getReportID();
+   printTestReport(taskID, testReport);
+   saveTask();
+ }
+
+ // Save task metadata to DB
+ private void saveTask() {
+ LOG.info("Saving task details to DB");
+ }
+
+ private void printTestReport(String exampleID, MLTestReport testReport) {
+ StringBuilder builder = new StringBuilder("Example: ").append(exampleID);
+ builder.append("\n\t");
+ builder.append("EvaluationReport: ").append(testReport.toString());
+ System.out.println(builder.toString());
+ }
+
+ private String[] buildTrainingArgs() {
+   // Arguments are a flat list of name/value pairs: the label, one pair per feature, then the extra params
+   List<String> trainingArgs = new ArrayList<String>();
+   trainingArgs.add("label");
+   trainingArgs.add(labelColumn);
+
+   for (String column : featureColumns) {
+     trainingArgs.add("feature");
+     trainingArgs.add(column);
+   }
+
+   for (Map.Entry<String, String> extra : extraParams.entrySet()) {
+     trainingArgs.add(extra.getKey());
+     trainingArgs.add(extra.getValue());
+   }
+
+   String[] result = new String[trainingArgs.size()];
+   return trainingArgs.toArray(result);
+ }
+
+ // Get the model instance and print its metadata to stdout
+ private void printModelMetadata(String exampleID, String modelID) throws Exception {
+ StringBuilder builder = new StringBuilder("Example: ").append(exampleID);
+ builder.append("\n\t");
+ builder.append("Model: ");
+ builder.append(ml.getModel(algorithm, modelID).toString());
+ System.out.println(builder.toString());
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/MLUtils.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/MLUtils.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/MLUtils.java
new file mode 100644
index 0000000..9c96d9b
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/MLUtils.java
@@ -0,0 +1,62 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.impl;
+
+import org.apache.lens.ml.algo.api.Algorithm;
+import org.apache.lens.ml.algo.api.MLAlgo;
+import org.apache.lens.ml.server.MLService;
+import org.apache.lens.ml.server.MLServiceImpl;
+import org.apache.lens.server.api.LensConfConstants;
+import org.apache.lens.server.api.ServiceProvider;
+import org.apache.lens.server.api.ServiceProviderFactory;
+
+import org.apache.hadoop.hive.conf.HiveConf;
+
+public final class MLUtils {
+ private MLUtils() {
+ }
+
+ private static final HiveConf HIVE_CONF;
+
+ static {
+ HIVE_CONF = new HiveConf();
+ // Add default config so that we know the service provider implementation
+ HIVE_CONF.addResource("lensserver-default.xml");
+ HIVE_CONF.addResource("lens-site.xml");
+ }
+
+ /**
+  * Returns the algorithm name declared through the {@link Algorithm} annotation of the given class.
+  *
+  * @param algoClass the algorithm implementation class
+  * @return the name given in the annotation
+  * @throws IllegalArgumentException if the class does not carry the annotation
+  */
+ public static String getAlgoName(Class<? extends MLAlgo> algoClass) {
+   Algorithm algoAnnotation = algoClass.getAnnotation(Algorithm.class);
+   if (algoAnnotation == null) {
+     throw new IllegalArgumentException("Algo should be decorated with annotation - " + Algorithm.class.getName());
+   }
+   return algoAnnotation.name();
+ }
+
+ public static MLServiceImpl getMLService() throws Exception {
+ return getServiceProvider().getService(MLService.NAME);
+ }
+
+ /**
+  * Creates the service provider through the factory class configured under
+  * {@link LensConfConstants#SERVICE_PROVIDER_FACTORY}.
+  *
+  * @return the service provider supplied by a new instance of the configured factory
+  * @throws Exception if no factory class is configured or the factory cannot be instantiated
+  */
+ public static ServiceProvider getServiceProvider() throws Exception {
+   Class<? extends ServiceProviderFactory> spfClass = HIVE_CONF.getClass(LensConfConstants.SERVICE_PROVIDER_FACTORY,
+     null, ServiceProviderFactory.class);
+   if (spfClass == null) {
+     // The lookup falls back to null when the key is unset; report that instead of failing with an NPE
+     throw new IllegalStateException("No service provider factory configured for "
+       + LensConfConstants.SERVICE_PROVIDER_FACTORY);
+   }
+   ServiceProviderFactory spf = spfClass.newInstance();
+   return spf.getServiceProvider();
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/ModelLoader.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/ModelLoader.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/ModelLoader.java
new file mode 100644
index 0000000..c0e7953
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/ModelLoader.java
@@ -0,0 +1,242 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.impl;
+
+import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+
+import org.apache.lens.ml.algo.api.MLModel;
+import org.apache.lens.ml.api.MLTestReport;
+
+import org.apache.commons.io.IOUtils;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.hive.conf.HiveConf;
+
+import com.google.common.cache.Cache;
+import com.google.common.cache.CacheBuilder;
+
+/**
+ * Load ML models from a FS location.
+ */
+public final class ModelLoader {
+ private ModelLoader() {
+ }
+
+ /** The Constant MODEL_PATH_BASE_DIR. */
+ public static final String MODEL_PATH_BASE_DIR = "lens.ml.model.basedir";
+
+ /** The Constant MODEL_PATH_BASE_DIR_DEFAULT. */
+ public static final String MODEL_PATH_BASE_DIR_DEFAULT = "file:///tmp";
+
+ /** The Constant LOG. */
+ public static final Log LOG = LogFactory.getLog(ModelLoader.class);
+
+ /** The Constant TEST_REPORT_BASE_DIR. */
+ public static final String TEST_REPORT_BASE_DIR = "lens.ml.test.basedir";
+
+ /** The Constant TEST_REPORT_BASE_DIR_DEFAULT. */
+ public static final String TEST_REPORT_BASE_DIR_DEFAULT = "file:///tmp/ml_reports";
+
+ // Model cache settings
+ /** The Constant MODEL_CACHE_SIZE. */
+ public static final long MODEL_CACHE_SIZE = 10;
+
+ /** The Constant MODEL_CACHE_TIMEOUT. */
+ public static final long MODEL_CACHE_TIMEOUT = 3600000L; // one hour
+
+ /** The model cache. */
+ private static Cache<Path, MLModel> modelCache = CacheBuilder.newBuilder().maximumSize(MODEL_CACHE_SIZE)
+ .expireAfterAccess(MODEL_CACHE_TIMEOUT, TimeUnit.MILLISECONDS).build();
+
+ /**
+ * Gets the model location.
+ *
+ * @param conf the conf
+ * @param algorithm the algorithm
+ * @param modelID the model id
+ * @return the model location
+ */
+ public static Path getModelLocation(Configuration conf, String algorithm, String modelID) {
+ String modelDataBaseDir = conf.get(MODEL_PATH_BASE_DIR, MODEL_PATH_BASE_DIR_DEFAULT);
+ // Model location format - <modelDataBaseDir>/<algorithm>/modelID
+ return new Path(new Path(new Path(modelDataBaseDir), algorithm), modelID);
+ }
+
+ /**
+  * Loads a model, serving it from the model cache when it is already cached.
+  *
+  * On a cache miss the model is deserialized from its location on the file system, which is resolved with
+  * the configuration passed by the caller.
+  *
+  * @param conf      the conf
+  * @param algorithm the algorithm
+  * @param modelID   the model id
+  * @return the ML model
+  * @throws IOException if the model path does not exist or the model cannot be read
+  */
+ public static MLModel loadModel(Configuration conf, String algorithm, String modelID) throws IOException {
+   final Path modelPath = getModelLocation(conf, algorithm, modelID);
+   // Final alias so the loader below can resolve the file system with the caller's configuration
+   final Configuration fsConf = conf;
+   LOG.info("Loading model for algorithm: " + algorithm + " modelID: " + modelID + " At path: "
+     + modelPath.toUri().toString());
+   try {
+     return modelCache.get(modelPath, new Callable<MLModel>() {
+       @Override
+       public MLModel call() throws Exception {
+         FileSystem fs = modelPath.getFileSystem(fsConf);
+         if (!fs.exists(modelPath)) {
+           throw new IOException("Model path not found " + modelPath.toString());
+         }
+
+         ObjectInputStream ois = null;
+         try {
+           ois = new ObjectInputStream(fs.open(modelPath));
+           MLModel model = (MLModel) ois.readObject();
+           LOG.info("Loaded model " + model.getId() + " from location " + modelPath);
+           return model;
+         } catch (ClassNotFoundException e) {
+           throw new IOException(e);
+         } finally {
+           IOUtils.closeQuietly(ois);
+         }
+       }
+     });
+   } catch (ExecutionException exc) {
+     throw new IOException(exc);
+   }
+ }
+
+ /**
+  * Clear cache: discards every cached model so that the next load reads it from the file system again.
+  */
+ public static void clearCache() {
+   // cleanUp() only runs pending maintenance; invalidateAll() is what actually removes the entries
+   modelCache.invalidateAll();
+ }
+
+ /**
+ * Gets the test report path.
+ *
+ * @param conf the conf
+ * @param algorithm the algorithm
+ * @param report the report
+ * @return the test report path
+ */
+ public static Path getTestReportPath(Configuration conf, String algorithm, String report) {
+ String testReportDir = conf.get(TEST_REPORT_BASE_DIR, TEST_REPORT_BASE_DIR_DEFAULT);
+ return new Path(new Path(testReportDir, algorithm), report);
+ }
+
+ /**
+ * Save test report.
+ *
+ * @param conf the conf
+ * @param report the report
+ * @throws IOException Signals that an I/O exception has occurred.
+ */
+ public static void saveTestReport(Configuration conf, MLTestReport report) throws IOException {
+ Path reportDir = new Path(conf.get(TEST_REPORT_BASE_DIR, TEST_REPORT_BASE_DIR_DEFAULT));
+ FileSystem fs = reportDir.getFileSystem(conf);
+
+ if (!fs.exists(reportDir)) {
+ LOG.info("Creating test report dir " + reportDir.toUri().toString());
+ fs.mkdirs(reportDir);
+ }
+
+ Path algoDir = new Path(reportDir, report.getAlgorithm());
+
+ if (!fs.exists(algoDir)) {
+ LOG.info("Creating algorithm report dir " + algoDir.toUri().toString());
+ fs.mkdirs(algoDir);
+ }
+
+ ObjectOutputStream reportOutputStream = null;
+ Path reportSaveLocation;
+ try {
+ reportSaveLocation = new Path(algoDir, report.getReportID());
+ reportOutputStream = new ObjectOutputStream(fs.create(reportSaveLocation));
+ reportOutputStream.writeObject(report);
+ reportOutputStream.flush();
+ } catch (IOException ioexc) {
+ LOG.error("Error saving test report " + report.getReportID(), ioexc);
+ throw ioexc;
+ } finally {
+ IOUtils.closeQuietly(reportOutputStream);
+ }
+ LOG.info("Saved report " + report.getReportID() + " at location " + reportSaveLocation.toUri());
+ }
+
+ /**
+  * Reads a previously saved test report.
+  *
+  * An I/O error while opening or reading the report is logged and reported as a null result.
+  *
+  * @param conf      the conf
+  * @param algorithm the algorithm
+  * @param reportID  the report id
+  * @return the ML test report, or null if it could not be read
+  * @throws IOException if the file system cannot be obtained or the report class cannot be found
+  */
+ public static MLTestReport loadReport(Configuration conf, String algorithm, String reportID) throws IOException {
+   Path reportLocation = getTestReportPath(conf, algorithm, reportID);
+   FileSystem reportFs = reportLocation.getFileSystem(conf);
+   ObjectInputStream in = null;
+
+   try {
+     in = new ObjectInputStream(reportFs.open(reportLocation));
+     return (MLTestReport) in.readObject();
+   } catch (IOException ioex) {
+     LOG.error("Error reading report " + reportLocation, ioex);
+     return null;
+   } catch (ClassNotFoundException e) {
+     throw new IOException(e);
+   } finally {
+     IOUtils.closeQuietly(in);
+   }
+ }
+
+ /**
+  * Deletes a model file and drops the model from the model cache.
+  *
+  * @param conf      the conf
+  * @param algorithm the algorithm
+  * @param modelID   the model id
+  * @throws IOException Signals that an I/O exception has occurred.
+  */
+ public static void deleteModel(HiveConf conf, String algorithm, String modelID) throws IOException {
+   Path modelLocation = getModelLocation(conf, algorithm, modelID);
+   FileSystem fs = modelLocation.getFileSystem(conf);
+   fs.delete(modelLocation, false);
+   // The cache is keyed by model location; without this a deleted model would still be served by loadModel
+   modelCache.invalidate(modelLocation);
+ }
+
+ /**
+ * Delete test report.
+ *
+ * @param conf the conf
+ * @param algorithm the algorithm
+ * @param reportID the report id
+ * @throws IOException Signals that an I/O exception has occurred.
+ */
+ public static void deleteTestReport(HiveConf conf, String algorithm, String reportID) throws IOException {
+ Path reportPath = getTestReportPath(conf, algorithm, reportID);
+ reportPath.getFileSystem(conf).delete(reportPath, false);
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/QueryRunner.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/QueryRunner.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/QueryRunner.java
new file mode 100644
index 0000000..2f2e017
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/QueryRunner.java
@@ -0,0 +1,56 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.impl;
+
+import org.apache.lens.api.LensException;
+import org.apache.lens.api.LensSessionHandle;
+import org.apache.lens.api.query.QueryHandle;
+
+import lombok.Getter;
+import lombok.Setter;
+
+/**
+ * Base class for components that run a query on behalf of a Lens session.
+ * Subclasses decide how the query is actually submitted.
+ */
+public abstract class QueryRunner {
+
+ /** The session handle under which queries are run. */
+ protected final LensSessionHandle sessionHandle;
+
+ /** The name given to the query; getter and setter are generated by Lombok. */
+ @Getter
+ @Setter
+ protected String queryName;
+
+ /**
+ * Creates a query runner bound to the given session.
+ *
+ * @param sessionHandle the session handle
+ */
+ public QueryRunner(LensSessionHandle sessionHandle) {
+ this.sessionHandle = sessionHandle;
+ }
+
+ /**
+ * Runs the given query.
+ *
+ * @param query the query
+ * @return the query handle
+ * @throws LensException the lens exception
+ */
+ public abstract QueryHandle runQuery(String query) throws LensException;
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/TableTestingSpec.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/TableTestingSpec.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/TableTestingSpec.java
new file mode 100644
index 0000000..34b2a3f
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/TableTestingSpec.java
@@ -0,0 +1,325 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.impl;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+
+import org.apache.commons.lang3.StringUtils;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.hive.conf.HiveConf;
+import org.apache.hadoop.hive.metastore.api.FieldSchema;
+import org.apache.hadoop.hive.ql.metadata.Hive;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.metadata.Table;
+
+import lombok.Getter;
+
+/**
+ * Table specification for running test on a table.
+ */
+public class TableTestingSpec {
+
+ /** The Constant LOG. */
+ public static final Log LOG = LogFactory.getLog(TableTestingSpec.class);
+
+ /** The database of the input table; when null the input table is looked up without a database name. */
+ private String db;
+
+ /** The table containing input data. */
+ private String inputTable;
+
+ // TODO use partition condition
+ /** The partition filter. It is stored by the builder but not used by the queries generated in this class. */
+ private String partitionFilter;
+
+ /** The feature columns; each must exist in the input table for validation to pass. */
+ private List<String> featureColumns;
+
+ /** The label column; it must exist in the input table for validation to pass. */
+ private String labelColumn;
+
+ /** The output column holding the prediction; declared with type string in the output table. */
+ private String outputColumn;
+
+ /** The output table that the test query writes to. */
+ private String outputTable;
+
+ /** The Hive configuration used to obtain the metastore client. */
+ private transient HiveConf conf;
+
+ /** The algorithm name passed to the predict function in the test query. */
+ private String algorithm;
+
+ /** The model id passed to the predict function in the test query. */
+ private String modelID;
+
+ @Getter
+ private boolean outputTableExists;
+
+ @Getter
+ private String testID;
+
+ private HashMap<String, FieldSchema> columnNameToFieldSchema;
+
+ /**
+ * Fluent builder for {@link TableTestingSpec}. Every setter writes directly into one spec instance, and
+ * {@link #build()} returns that same instance, so a builder should be used to build a single spec only.
+ */
+ public static class TableTestingSpecBuilder {
+
+ /** The spec being populated. */
+ private final TableTestingSpec spec;
+
+ /**
+ * Instantiates a new table testing spec builder.
+ */
+ public TableTestingSpecBuilder() {
+ spec = new TableTestingSpec();
+ }
+
+ /**
+ * Sets the database.
+ *
+ * @param database the database
+ * @return the table testing spec builder
+ */
+ public TableTestingSpecBuilder database(String database) {
+ spec.db = database;
+ return this;
+ }
+
+ /**
+ * Sets the input table.
+ *
+ * @param table the table
+ * @return the table testing spec builder
+ */
+ public TableTestingSpecBuilder inputTable(String table) {
+ spec.inputTable = table;
+ return this;
+ }
+
+ /**
+ * Sets the partition filter for the input table.
+ *
+ * @param partFilter the part filter
+ * @return the table testing spec builder
+ */
+ public TableTestingSpecBuilder partitionFilter(String partFilter) {
+ spec.partitionFilter = partFilter;
+ return this;
+ }
+
+ /**
+ * Sets the feature columns.
+ *
+ * @param featureColumns the feature columns
+ * @return the table testing spec builder
+ */
+ public TableTestingSpecBuilder featureColumns(List<String> featureColumns) {
+ spec.featureColumns = featureColumns;
+ return this;
+ }
+
+ /**
+ * Sets the label column.
+ *
+ * @param labelColumn the label column
+ * @return the table testing spec builder
+ */
+ public TableTestingSpecBuilder labelColumn(String labelColumn) {
+ spec.labelColumn = labelColumn;
+ return this;
+ }
+
+ /**
+ * Sets the label column.
+ *
+ * @param labelColumn the label column
+ * @return the table testing spec builder
+ * @deprecated misspelled name kept for existing callers; use {@link #labelColumn(String)} instead
+ */
+ @Deprecated
+ public TableTestingSpecBuilder lableColumn(String labelColumn) {
+ return labelColumn(labelColumn);
+ }
+
+ /**
+ * Sets the output column.
+ *
+ * @param outputColumn the output column
+ * @return the table testing spec builder
+ */
+ public TableTestingSpecBuilder outputColumn(String outputColumn) {
+ spec.outputColumn = outputColumn;
+ return this;
+ }
+
+ /**
+ * Sets the output table.
+ *
+ * @param table the table
+ * @return the table testing spec builder
+ */
+ public TableTestingSpecBuilder outputTable(String table) {
+ spec.outputTable = table;
+ return this;
+ }
+
+ /**
+ * Sets the Hive conf.
+ *
+ * @param conf the conf
+ * @return the table testing spec builder
+ */
+ public TableTestingSpecBuilder hiveConf(HiveConf conf) {
+ spec.conf = conf;
+ return this;
+ }
+
+ /**
+ * Sets the algorithm.
+ *
+ * @param algorithm the algorithm
+ * @return the table testing spec builder
+ */
+ public TableTestingSpecBuilder algorithm(String algorithm) {
+ spec.algorithm = algorithm;
+ return this;
+ }
+
+ /**
+ * Sets the model id.
+ *
+ * @param modelID the model id
+ * @return the table testing spec builder
+ */
+ public TableTestingSpecBuilder modelID(String modelID) {
+ spec.modelID = modelID;
+ return this;
+ }
+
+ /**
+ * Returns the spec populated by this builder. No copy is made and no validation is performed here.
+ *
+ * @return the table testing spec
+ */
+ public TableTestingSpec build() {
+ return spec;
+ }
+
+ /**
+ * Sets the unique test id.
+ *
+ * @param testID the test id
+ * @return the table testing spec builder
+ */
+ public TableTestingSpecBuilder testID(String testID) {
+ spec.testID = testID;
+ return this;
+ }
+ }
+
+ /**
+ * Creates a new builder backed by a fresh, empty table testing spec.
+ *
+ * @return the table testing spec builder
+ */
+ public static TableTestingSpecBuilder newBuilder() {
+ return new TableTestingSpecBuilder();
+ }
+
+ /**
+ * Validates this spec. Required arguments are checked first; the input table is then read from the metastore,
+ * its column schemas are cached by name, and the existence of the output table is recorded. Finally the feature
+ * and label columns are checked against the input table's columns.
+ *
+ * @return true, if the spec is complete and matches the input table; false otherwise
+ */
+ public boolean validate() {
+ // Check the locally available arguments before touching the metastore, so that the metastore is never queried
+ // with a blank table name and a missing feature list is reported instead of causing a NullPointerException.
+ if (featureColumns == null || featureColumns.isEmpty()) {
+ LOG.info("Feature columns are required");
+ return false;
+ }
+
+ if (StringUtils.isBlank(outputColumn)) {
+ LOG.info("Output column is required");
+ return false;
+ }
+
+ if (StringUtils.isBlank(outputTable)) {
+ LOG.info("Output table is required");
+ return false;
+ }
+
+ List<FieldSchema> columns;
+ try {
+ Hive metastoreClient = Hive.get(conf);
+ Table tbl = (db == null) ? metastoreClient.getTable(inputTable) : metastoreClient.getTable(db, inputTable);
+ columns = tbl.getAllCols();
+ columnNameToFieldSchema = new HashMap<String, FieldSchema>();
+
+ for (FieldSchema fieldSchema : columns) {
+ columnNameToFieldSchema.put(fieldSchema.getName(), fieldSchema);
+ }
+
+ // Check if output table exists
+ Table outTbl = metastoreClient.getTable(db == null ? "default" : db, outputTable, false);
+ outputTableExists = (outTbl != null);
+ } catch (HiveException exc) {
+ // This class does not override toString(), so name the tables explicitly in the message.
+ LOG.error("Error getting table info for input table " + inputTable + ", output table " + outputTable, exc);
+ return false;
+ }
+
+ // Check if labeled column and feature columns are contained in the table
+ List<String> testTableColumns = new ArrayList<String>(columns.size());
+ for (FieldSchema column : columns) {
+ testTableColumns.add(column.getName());
+ }
+
+ if (!testTableColumns.containsAll(featureColumns)) {
+ LOG.info("Invalid feature columns: " + featureColumns + ". Actual columns in table:" + testTableColumns);
+ return false;
+ }
+
+ if (!testTableColumns.contains(labelColumn)) {
+ LOG.info("Invalid label column: " + labelColumn + ". Actual columns in table:" + testTableColumns);
+ return false;
+ }
+
+ return true;
+ }
+
+ /**
+ * Builds the INSERT OVERWRITE query that selects the feature columns, the label column and the predict() result
+ * from the input table into the output table.
+ *
+ * @return the test query, or null if {@link #validate()} returns false
+ */
+ public String getTestQuery() {
+ if (!validate()) {
+ return null;
+ }
+
+ final String featureCols = StringUtils.join(featureColumns, ",");
+
+ // The target partition is named explicitly: part_testid is set to this spec's test id.
+ StringBuilder q = new StringBuilder("INSERT OVERWRITE TABLE ");
+ q.append(outputTable).append(" PARTITION (part_testid='").append(testID).append("') SELECT ");
+ q.append(featureCols).append(",").append(labelColumn).append(", ");
+ q.append("predict('").append(algorithm).append("', '").append(modelID).append("', ");
+ q.append(featureCols).append(") ").append(outputColumn);
+ q.append(" FROM ").append(inputTable);
+
+ return q.toString();
+ }
+
+ /**
+ * Builds the CREATE TABLE IF NOT EXISTS statement for the output table. Column types are read from the schema
+ * map that {@link #validate()} populates, so that method must have run before this one is called.
+ *
+ * @return the create table query
+ */
+ public String getCreateOutputTableQuery() {
+ // Output table contains feature columns, label column, output column
+ List<String> columnDefs = new ArrayList<String>();
+ for (String featureCol : featureColumns) {
+ columnDefs.add(featureCol + " " + columnNameToFieldSchema.get(featureCol).getType());
+ }
+ columnDefs.add(labelColumn + " " + columnNameToFieldSchema.get(labelColumn).getType());
+ columnDefs.add(outputColumn + " string");
+
+ // The partition column is declared after the regular column list.
+ return "CREATE TABLE IF NOT EXISTS " + outputTable + "(" + StringUtils.join(columnDefs, ", ")
+ + ") PARTITIONED BY (part_testid string)";
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/server/MLApp.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/server/MLApp.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/server/MLApp.java
new file mode 100644
index 0000000..e6e3c02
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/server/MLApp.java
@@ -0,0 +1,60 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.server;
+
+import java.util.HashSet;
+import java.util.Set;
+
+import javax.ws.rs.ApplicationPath;
+import javax.ws.rs.core.Application;
+
+import org.glassfish.jersey.filter.LoggingFilter;
+import org.glassfish.jersey.media.multipart.MultiPartFeature;
+
+@ApplicationPath("/ml")
+public class MLApp extends Application {
+
+ private final Set<Class<?>> classes;
+
+ /**
+ * Creates the application with the ML service resource, multipart support and the logging filter registered.
+ * Additional classes can be passed when running in test mode.
+ *
+ * @param additionalClasses extra classes to register in addition to the default ones
+ */
+ public MLApp(Class<?>... additionalClasses) {
+ classes = new HashSet<Class<?>>();
+
+ // register the root resource together with the multipart feature and the logging filter
+ classes.add(MLServiceResource.class);
+ classes.add(MultiPartFeature.class);
+ classes.add(LoggingFilter.class);
+ for (Class<?> cls : additionalClasses) {
+ classes.add(cls);
+ }
+
+ }
+
+ /**
+ * Returns the set of classes registered in the constructor.
+ */
+ @Override
+ public Set<Class<?>> getClasses() {
+ return classes;
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/server/MLService.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/server/MLService.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/server/MLService.java
new file mode 100644
index 0000000..f8b7cd1
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/server/MLService.java
@@ -0,0 +1,27 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.server;
+
+import org.apache.lens.ml.api.LensML;
+
+/**
+ * Service interface for Lens ML in the server package. It declares no methods of its own; everything it offers is
+ * inherited from {@link LensML}.
+ */
+public interface MLService extends LensML {
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/server/MLServiceImpl.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/server/MLServiceImpl.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/server/MLServiceImpl.java
new file mode 100644
index 0000000..f3e8ec1
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/server/MLServiceImpl.java
@@ -0,0 +1,329 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.server;
+
+import java.util.List;
+import java.util.Map;
+
+import org.apache.lens.api.LensConf;
+import org.apache.lens.api.LensException;
+import org.apache.lens.api.LensSessionHandle;
+import org.apache.lens.api.query.LensQuery;
+import org.apache.lens.api.query.QueryHandle;
+import org.apache.lens.api.query.QueryStatus;
+import org.apache.lens.ml.algo.api.*;
+import org.apache.lens.ml.api.MLTestReport;
+import org.apache.lens.ml.impl.HiveMLUDF;
+import org.apache.lens.ml.impl.LensMLImpl;
+import org.apache.lens.ml.impl.ModelLoader;
+import org.apache.lens.ml.impl.QueryRunner;
+import org.apache.lens.server.api.LensConfConstants;
+import org.apache.lens.server.api.ServiceProvider;
+import org.apache.lens.server.api.ServiceProviderFactory;
+import org.apache.lens.server.api.query.QueryExecutionService;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.hive.conf.HiveConf;
+import org.apache.hadoop.hive.ql.exec.FunctionRegistry;
+import org.apache.hive.service.CompositeService;
+
+/**
+ * The Class MLServiceImpl.
+ */
+public class MLServiceImpl extends CompositeService implements MLService {
+
+ /** The Constant LOG. Created for this class so that log lines are attributed to MLServiceImpl. */
+ public static final Log LOG = LogFactory.getLog(MLServiceImpl.class);
+
+ /** The ml. */
+ private LensMLImpl ml;
+
+ /** The service provider. */
+ private ServiceProvider serviceProvider;
+
+ /** The service provider factory. */
+ private ServiceProviderFactory serviceProviderFactory;
+
+ /**
+ * Instantiates a new ML service impl.
+ */
+ public MLServiceImpl() {
+ this(NAME);
+ }
+
+ /**
+ * Instantiates a new ML service impl.
+ *
+ * @param name the name
+ */
+ public MLServiceImpl(String name) {
+ super(name);
+ }
+
+ @Override
+ public List<String> getAlgorithms() {
+ return ml.getAlgorithms();
+ }
+
+ /*
+ * (non-Javadoc)
+ *
+ * Delegates to the wrapped LensMLImpl.
+ *
+ * @see org.apache.lens.ml.api.LensML#getAlgoForName(java.lang.String)
+ */
+ @Override
+ public MLAlgo getAlgoForName(String algorithm) throws LensException {
+ return ml.getAlgoForName(algorithm);
+ }
+
+ /*
+ * (non-Javadoc)
+ *
+ * Delegates to the wrapped LensMLImpl.
+ *
+ * @see org.apache.lens.ml.api.LensML#train(java.lang.String, java.lang.String, java.lang.String[])
+ */
+ @Override
+ public String train(String table, String algorithm, String[] args) throws LensException {
+ return ml.train(table, algorithm, args);
+ }
+
+ /*
+ * (non-Javadoc)
+ *
+ * Delegates to the wrapped LensMLImpl.
+ *
+ * @see org.apache.lens.ml.api.LensML#getModels(java.lang.String)
+ */
+ @Override
+ public List<String> getModels(String algorithm) throws LensException {
+ return ml.getModels(algorithm);
+ }
+
+ /*
+ * (non-Javadoc)
+ *
+ * Delegates to the wrapped LensMLImpl.
+ *
+ * @see org.apache.lens.ml.api.LensML#getModel(java.lang.String, java.lang.String)
+ */
+ @Override
+ public MLModel getModel(String algorithm, String modelId) throws LensException {
+ return ml.getModel(algorithm, modelId);
+ }
+
+ private ServiceProvider getServiceProvider() {
+ // Obtain the provider from the factory on first use and reuse it on later calls.
+ if (serviceProvider != null) {
+ return serviceProvider;
+ }
+ serviceProvider = serviceProviderFactory.getServiceProvider();
+ return serviceProvider;
+ }
+
+ /**
+ * Gets the service provider factory.
+ *
+ * @param conf the conf
+ * @return the service provider factory
+ */
+ private ServiceProviderFactory getServiceProviderFactory(HiveConf conf) {
+ Class<?> spfClass = conf.getClass(LensConfConstants.SERVICE_PROVIDER_FACTORY, ServiceProviderFactory.class);
+ try {
+ return (ServiceProviderFactory) spfClass.newInstance();
+ } catch (InstantiationException e) {
+ throw new RuntimeException(e);
+ } catch (IllegalAccessException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ /*
+ * (non-Javadoc)
+ *
+ * Creates and initializes the wrapped LensMLImpl, initializes this composite service, and then instantiates the
+ * service provider factory from the given configuration.
+ *
+ * @see org.apache.hive.service.CompositeService#init(org.apache.hadoop.hive.conf.HiveConf)
+ */
+ @Override
+ public synchronized void init(HiveConf hiveConf) {
+ ml = new LensMLImpl(hiveConf);
+ ml.init(hiveConf);
+ super.init(hiveConf);
+ serviceProviderFactory = getServiceProviderFactory(hiveConf);
+ LOG.info("Inited ML service");
+ }
+
+ /*
+ * (non-Javadoc)
+ *
+ * Starts the wrapped LensMLImpl before starting this composite service.
+ *
+ * @see org.apache.hive.service.CompositeService#start()
+ */
+ @Override
+ public synchronized void start() {
+ ml.start();
+ super.start();
+ LOG.info("Started ML service");
+ }
+
+ /*
+ * (non-Javadoc)
+ *
+ * Stops the wrapped LensMLImpl before stopping this composite service.
+ *
+ * @see org.apache.hive.service.CompositeService#stop()
+ */
+ @Override
+ public synchronized void stop() {
+ ml.stop();
+ super.stop();
+ LOG.info("Stopped ML service");
+ }
+
+ /**
+ * Clear models.
+ */
+ public void clearModels() {
+ ModelLoader.clearCache();
+ }
+
+ /*
+ * (non-Javadoc)
+ *
+ * Delegates to the wrapped LensMLImpl.
+ *
+ * @see org.apache.lens.ml.api.LensML#getModelPath(java.lang.String, java.lang.String)
+ */
+ @Override
+ public String getModelPath(String algorithm, String modelID) {
+ return ml.getModelPath(algorithm, modelID);
+ }
+
+ /*
+ * (non-Javadoc)
+ *
+ * Delegates to the wrapped LensMLImpl, supplying a DirectQueryRunner bound to the given session.
+ *
+ * @see org.apache.lens.ml.api.LensML#testModel(org.apache.lens.api.LensSessionHandle, java.lang.String,
+ * java.lang.String, java.lang.String, java.lang.String)
+ */
+ @Override
+ public MLTestReport testModel(LensSessionHandle sessionHandle, String table, String algorithm, String modelID,
+ String outputTable) throws LensException {
+ return ml.testModel(sessionHandle, table, algorithm, modelID, new DirectQueryRunner(sessionHandle), outputTable);
+ }
+
+ /*
+ * (non-Javadoc)
+ *
+ * Delegates to the wrapped LensMLImpl.
+ *
+ * @see org.apache.lens.ml.api.LensML#getTestReports(java.lang.String)
+ */
+ @Override
+ public List<String> getTestReports(String algorithm) throws LensException {
+ return ml.getTestReports(algorithm);
+ }
+
+ /*
+ * (non-Javadoc)
+ *
+ * Delegates to the wrapped LensMLImpl.
+ *
+ * @see org.apache.lens.ml.api.LensML#getTestReport(java.lang.String, java.lang.String)
+ */
+ @Override
+ public MLTestReport getTestReport(String algorithm, String reportID) throws LensException {
+ return ml.getTestReport(algorithm, reportID);
+ }
+
+ /*
+ * (non-Javadoc)
+ *
+ * Delegates to the wrapped LensMLImpl.
+ *
+ * @see org.apache.lens.ml.api.LensML#predict(java.lang.String, java.lang.String, java.lang.Object[])
+ */
+ @Override
+ public Object predict(String algorithm, String modelID, Object[] features) throws LensException {
+ return ml.predict(algorithm, modelID, features);
+ }
+
+ /*
+ * (non-Javadoc)
+ *
+ * Delegates to the wrapped LensMLImpl.
+ *
+ * @see org.apache.lens.ml.api.LensML#deleteModel(java.lang.String, java.lang.String)
+ */
+ @Override
+ public void deleteModel(String algorithm, String modelID) throws LensException {
+ ml.deleteModel(algorithm, modelID);
+ }
+
+ /*
+ * (non-Javadoc)
+ *
+ * Delegates to the wrapped LensMLImpl.
+ *
+ * @see org.apache.lens.ml.api.LensML#deleteTestReport(java.lang.String, java.lang.String)
+ */
+ @Override
+ public void deleteTestReport(String algorithm, String reportID) throws LensException {
+ ml.deleteTestReport(algorithm, reportID);
+ }
+
+ /**
+ * Run the test model query directly in the current lens server process. The query is submitted asynchronously to
+ * the query execution service and then polled until it has finished.
+ */
+ private class DirectQueryRunner extends QueryRunner {
+
+ /**
+ * Instantiates a new direct query runner.
+ *
+ * @param sessionHandle the session handle
+ */
+ public DirectQueryRunner(LensSessionHandle sessionHandle) {
+ super(sessionHandle);
+ }
+
+ /*
+ * (non-Javadoc)
+ *
+ * Registers the predict UDF, submits the query and blocks until it finishes. Throws LensException if the wait
+ * is interrupted or the query does not end in SUCCESSFUL status.
+ *
+ * @see org.apache.lens.ml.impl.QueryRunner#runQuery(java.lang.String)
+ */
+ @Override
+ public QueryHandle runQuery(String testQuery) throws LensException {
+ FunctionRegistry.registerTemporaryFunction("predict", HiveMLUDF.class);
+ LOG.info("Registered predict UDF");
+ // Run the query in query executions service
+ QueryExecutionService queryService = (QueryExecutionService) getServiceProvider().getService("query");
+
+ LensConf queryConf = new LensConf();
+ queryConf.addProperty(LensConfConstants.QUERY_PERSISTENT_RESULT_SET, "false");
+ queryConf.addProperty(LensConfConstants.QUERY_PERSISTENT_RESULT_INDRIVER, "false");
+
+ QueryHandle testQueryHandle = queryService.executeAsync(sessionHandle, testQuery, queryConf, queryName);
+
+ // Wait for test query to complete
+ LensQuery query = queryService.getQuery(sessionHandle, testQueryHandle);
+ LOG.info("Submitted query " + testQueryHandle.getHandleId());
+ while (!query.getStatus().isFinished()) {
+ try {
+ Thread.sleep(500);
+ } catch (InterruptedException e) {
+ // Restore the interrupt flag so that code further up the stack can still observe the interruption.
+ Thread.currentThread().interrupt();
+ throw new LensException(e);
+ }
+
+ query = queryService.getQuery(sessionHandle, testQueryHandle);
+ }
+
+ if (query.getStatus().getStatus() != QueryStatus.Status.SUCCESSFUL) {
+ throw new LensException("Failed to run test query: " + testQueryHandle.getHandleId() + " reason= "
+ + query.getStatus().getErrorMessage());
+ }
+
+ return testQueryHandle;
+ }
+ }
+
+ /*
+ * (non-Javadoc)
+ *
+ * Delegates to the wrapped LensMLImpl.
+ *
+ * @see org.apache.lens.ml.api.LensML#getAlgoParamDescription(java.lang.String)
+ */
+ @Override
+ public Map<String, String> getAlgoParamDescription(String algorithm) {
+ return ml.getAlgoParamDescription(algorithm);
+ }
+}
[5/6] incubator-lens git commit: Lens-465 : Refactor ml packages.
(sharad)
Posted by sh...@apache.org.
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/MLRunner.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/MLRunner.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/MLRunner.java
deleted file mode 100644
index bd50cba..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/MLRunner.java
+++ /dev/null
@@ -1,173 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml;
-
-import java.io.File;
-import java.io.FileInputStream;
-import java.util.ArrayList;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.Properties;
-
-import org.apache.lens.client.LensClient;
-import org.apache.lens.client.LensClientConfig;
-import org.apache.lens.client.LensMLClient;
-import org.apache.lens.ml.task.MLTask;
-
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
-import org.apache.hadoop.fs.Path;
-import org.apache.hadoop.hive.conf.HiveConf;
-import org.apache.hadoop.hive.metastore.TableType;
-import org.apache.hadoop.hive.metastore.api.FieldSchema;
-import org.apache.hadoop.hive.ql.metadata.Hive;
-import org.apache.hadoop.hive.ql.metadata.HiveException;
-import org.apache.hadoop.hive.ql.metadata.Table;
-import org.apache.hadoop.hive.ql.plan.AddPartitionDesc;
-import org.apache.hadoop.hive.serde.serdeConstants;
-import org.apache.hadoop.mapred.TextInputFormat;
-
-public class MLRunner {
-
- private static final Log LOG = LogFactory.getLog(MLRunner.class);
-
- private LensMLClient mlClient;
- private String algoName;
- private String database;
- private String trainTable;
- private String trainFile;
- private String testTable;
- private String testFile;
- private String outputTable;
- private String[] features;
- private String labelColumn;
- private HiveConf conf;
-
- public void init(LensMLClient mlClient, String confDir) throws Exception {
- File dir = new File(confDir);
- File propFile = new File(dir, "ml.properties");
- Properties props = new Properties();
- props.load(new FileInputStream(propFile));
- String feat = props.getProperty("features");
- String trainFile = confDir + File.separator + "train.data";
- String testFile = confDir + File.separator + "test.data";
- init(mlClient, props.getProperty("algo"), props.getProperty("database"),
- props.getProperty("traintable"), trainFile,
- props.getProperty("testtable"), testFile,
- props.getProperty("outputtable"), feat.split(","),
- props.getProperty("labelcolumn"));
- }
-
- public void init(LensMLClient mlClient, String algoName,
- String database, String trainTable, String trainFile,
- String testTable, String testFile, String outputTable, String[] features,
- String labelColumn) {
- this.mlClient = mlClient;
- this.algoName = algoName;
- this.database = database;
- this.trainTable = trainTable;
- this.trainFile = trainFile;
- this.testTable = testTable;
- this.testFile = testFile;
- this.outputTable = outputTable;
- this.features = features;
- this.labelColumn = labelColumn;
- //hive metastore settings are loaded via lens-site.xml, so loading LensClientConfig
- //is required
- this.conf = new HiveConf(new LensClientConfig(), MLRunner.class);
- }
-
- public MLTask train() throws Exception {
- LOG.info("Starting train & eval");
-
- createTable(trainTable, trainFile);
- createTable(testTable, testFile);
- MLTask.Builder taskBuilder = new MLTask.Builder();
- taskBuilder.algorithm(algoName).hiveConf(conf).labelColumn(labelColumn).outputTable(outputTable)
- .client(mlClient).trainingTable(trainTable).testTable(testTable);
-
- // Add features
- for (String feature : features) {
- taskBuilder.addFeatureColumn(feature);
- }
- MLTask task = taskBuilder.build();
- LOG.info("Created task " + task.toString());
- task.run();
- return task;
- }
-
- public void createTable(String tableName, String dataFile) throws HiveException {
-
- File filedataFile = new File(dataFile);
- Path dataFilePath = new Path(filedataFile.toURI());
- Path partDir = dataFilePath.getParent();
-
- // Create table
- List<FieldSchema> columns = new ArrayList<FieldSchema>();
-
- // Label is optional. Not used for unsupervised models.
- // If present, label will be the first column, followed by features
- if (labelColumn != null) {
- columns.add(new FieldSchema(labelColumn, "double", "Labelled Column"));
- }
-
- for (String feature : features) {
- columns.add(new FieldSchema(feature, "double", "Feature " + feature));
- }
-
- Table tbl = Hive.get(conf).newTable(database + "." + tableName);
- tbl.setTableType(TableType.MANAGED_TABLE);
- tbl.getTTable().getSd().setCols(columns);
- // tbl.getTTable().getParameters().putAll(new HashMap<String, String>());
- tbl.setInputFormatClass(TextInputFormat.class);
- tbl.setSerdeParam(serdeConstants.LINE_DELIM, "\n");
- tbl.setSerdeParam(serdeConstants.FIELD_DELIM, " ");
-
- List<FieldSchema> partCols = new ArrayList<FieldSchema>(1);
- partCols.add(new FieldSchema("dummy_partition_col", "string", ""));
- tbl.setPartCols(partCols);
-
- Hive.get(conf).dropTable(database, tableName, false, true);
- Hive.get(conf).createTable(tbl, true);
- LOG.info("Created table " + tableName);
-
- // Add partition for the data file
- AddPartitionDesc partitionDesc = new AddPartitionDesc(database, tableName,
- false);
- Map<String, String> partSpec = new HashMap<String, String>();
- partSpec.put("dummy_partition_col", "dummy_val");
- partitionDesc.addPartition(partSpec, partDir.toUri().toString());
- Hive.get(conf).createPartitions(partitionDesc);
- LOG.info(tableName + ": Added partition " + partDir.toUri().toString());
- }
-
- public static void main(String[] args) throws Exception {
- if (args.length < 1) {
- System.out.println("Usage: org.apache.lens.ml.MLRunner <ml-conf-dir>");
- System.exit(-1);
- }
- String confDir = args[0];
- LensMLClient client = new LensMLClient(new LensClient());
- MLRunner runner = new MLRunner();
- runner.init(client, confDir);
- runner.train();
- System.out.println("Created the Model successfully. Output Table: " + runner.outputTable);
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/MLTestMetric.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/MLTestMetric.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/MLTestMetric.java
deleted file mode 100644
index 57adecc..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/MLTestMetric.java
+++ /dev/null
@@ -1,28 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml;
-
-/**
- * The Interface MLTestMetric.
- */
-public interface MLTestMetric {
- String getName();
-
- String getDescription();
-}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/MLTestReport.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/MLTestReport.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/MLTestReport.java
deleted file mode 100644
index 909e6df..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/MLTestReport.java
+++ /dev/null
@@ -1,95 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml;
-
-import java.io.Serializable;
-import java.util.List;
-
-import lombok.Getter;
-import lombok.NoArgsConstructor;
-import lombok.Setter;
-import lombok.ToString;
-
-/**
- * Instantiates a new ML test report.
- */
-@NoArgsConstructor
-@ToString
-public class MLTestReport implements Serializable {
-
- /** The test table. */
- @Getter
- @Setter
- private String testTable;
-
- /** The output table. */
- @Getter
- @Setter
- private String outputTable;
-
- /** The output column. */
- @Getter
- @Setter
- private String outputColumn;
-
- /** The label column. */
- @Getter
- @Setter
- private String labelColumn;
-
- /** The feature columns. */
- @Getter
- @Setter
- private List<String> featureColumns;
-
- /** The algorithm. */
- @Getter
- @Setter
- private String algorithm;
-
- /** The model id. */
- @Getter
- @Setter
- private String modelID;
-
- /** The report id. */
- @Getter
- @Setter
- private String reportID;
-
- /** The query id. */
- @Getter
- @Setter
- private String queryID;
-
- /** The test output path. */
- @Getter
- @Setter
- private String testOutputPath;
-
- /** The prediction result column. */
- @Getter
- @Setter
- private String predictionResultColumn;
-
- /** The lens query id. */
- @Getter
- @Setter
- private String lensQueryID;
-}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/MLUtils.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/MLUtils.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/MLUtils.java
deleted file mode 100644
index 2e240af..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/MLUtils.java
+++ /dev/null
@@ -1,60 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml;
-
-import org.apache.lens.server.api.LensConfConstants;
-import org.apache.lens.server.api.ServiceProvider;
-import org.apache.lens.server.api.ServiceProviderFactory;
-import org.apache.lens.server.ml.MLService;
-import org.apache.lens.server.ml.MLServiceImpl;
-
-import org.apache.hadoop.hive.conf.HiveConf;
-
-public final class MLUtils {
- private MLUtils() {
- }
-
- private static final HiveConf HIVE_CONF;
-
- static {
- HIVE_CONF = new HiveConf();
- // Add default config so that we know the service provider implementation
- HIVE_CONF.addResource("lensserver-default.xml");
- HIVE_CONF.addResource("lens-site.xml");
- }
-
- public static String getAlgoName(Class<? extends MLAlgo> algoClass) {
- Algorithm annotation = algoClass.getAnnotation(Algorithm.class);
- if (annotation != null) {
- return annotation.name();
- }
- throw new IllegalArgumentException("Algo should be decorated with annotation - " + Algorithm.class.getName());
- }
-
- public static MLServiceImpl getMLService() throws Exception {
- return getServiceProvider().getService(MLService.NAME);
- }
-
- public static ServiceProvider getServiceProvider() throws Exception {
- Class<? extends ServiceProviderFactory> spfClass = HIVE_CONF.getClass(LensConfConstants.SERVICE_PROVIDER_FACTORY,
- null, ServiceProviderFactory.class);
- ServiceProviderFactory spf = spfClass.newInstance();
- return spf.getServiceProvider();
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/ModelLoader.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/ModelLoader.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/ModelLoader.java
deleted file mode 100644
index 429cbf9..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/ModelLoader.java
+++ /dev/null
@@ -1,239 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml;
-
-import java.io.IOException;
-import java.io.ObjectInputStream;
-import java.io.ObjectOutputStream;
-import java.util.concurrent.Callable;
-import java.util.concurrent.ExecutionException;
-import java.util.concurrent.TimeUnit;
-
-import org.apache.commons.io.IOUtils;
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
-import org.apache.hadoop.conf.Configuration;
-import org.apache.hadoop.fs.FileSystem;
-import org.apache.hadoop.fs.Path;
-import org.apache.hadoop.hive.conf.HiveConf;
-
-import com.google.common.cache.Cache;
-import com.google.common.cache.CacheBuilder;
-
-/**
- * Load ML models from a FS location.
- */
-public final class ModelLoader {
- private ModelLoader() {
- }
-
- /** The Constant MODEL_PATH_BASE_DIR. */
- public static final String MODEL_PATH_BASE_DIR = "lens.ml.model.basedir";
-
- /** The Constant MODEL_PATH_BASE_DIR_DEFAULT. */
- public static final String MODEL_PATH_BASE_DIR_DEFAULT = "file:///tmp";
-
- /** The Constant LOG. */
- public static final Log LOG = LogFactory.getLog(ModelLoader.class);
-
- /** The Constant TEST_REPORT_BASE_DIR. */
- public static final String TEST_REPORT_BASE_DIR = "lens.ml.test.basedir";
-
- /** The Constant TEST_REPORT_BASE_DIR_DEFAULT. */
- public static final String TEST_REPORT_BASE_DIR_DEFAULT = "file:///tmp/ml_reports";
-
- // Model cache settings
- /** The Constant MODEL_CACHE_SIZE. */
- public static final long MODEL_CACHE_SIZE = 10;
-
- /** The Constant MODEL_CACHE_TIMEOUT. */
- public static final long MODEL_CACHE_TIMEOUT = 3600000L; // one hour
-
- /** The model cache. */
- private static Cache<Path, MLModel> modelCache = CacheBuilder.newBuilder().maximumSize(MODEL_CACHE_SIZE)
- .expireAfterAccess(MODEL_CACHE_TIMEOUT, TimeUnit.MILLISECONDS).build();
-
- /**
- * Gets the model location.
- *
- * @param conf the conf
- * @param algorithm the algorithm
- * @param modelID the model id
- * @return the model location
- */
- public static Path getModelLocation(Configuration conf, String algorithm, String modelID) {
- String modelDataBaseDir = conf.get(MODEL_PATH_BASE_DIR, MODEL_PATH_BASE_DIR_DEFAULT);
- // Model location format - <modelDataBaseDir>/<algorithm>/modelID
- return new Path(new Path(new Path(modelDataBaseDir), algorithm), modelID);
- }
-
- /**
- * Load model.
- *
- * @param conf the conf
- * @param algorithm the algorithm
- * @param modelID the model id
- * @return the ML model
- * @throws IOException Signals that an I/O exception has occurred.
- */
- public static MLModel loadModel(Configuration conf, String algorithm, String modelID) throws IOException {
- final Path modelPath = getModelLocation(conf, algorithm, modelID);
- LOG.info("Loading model for algorithm: " + algorithm + " modelID: " + modelID + " At path: "
- + modelPath.toUri().toString());
- try {
- return modelCache.get(modelPath, new Callable<MLModel>() {
- @Override
- public MLModel call() throws Exception {
- FileSystem fs = modelPath.getFileSystem(new HiveConf());
- if (!fs.exists(modelPath)) {
- throw new IOException("Model path not found " + modelPath.toString());
- }
-
- ObjectInputStream ois = null;
- try {
- ois = new ObjectInputStream(fs.open(modelPath));
- MLModel model = (MLModel) ois.readObject();
- LOG.info("Loaded model " + model.getId() + " from location " + modelPath);
- return model;
- } catch (ClassNotFoundException e) {
- throw new IOException(e);
- } finally {
- IOUtils.closeQuietly(ois);
- }
- }
- });
- } catch (ExecutionException exc) {
- throw new IOException(exc);
- }
- }
-
- /**
- * Clear cache.
- */
- public static void clearCache() {
- modelCache.cleanUp();
- }
-
- /**
- * Gets the test report path.
- *
- * @param conf the conf
- * @param algorithm the algorithm
- * @param report the report
- * @return the test report path
- */
- public static Path getTestReportPath(Configuration conf, String algorithm, String report) {
- String testReportDir = conf.get(TEST_REPORT_BASE_DIR, TEST_REPORT_BASE_DIR_DEFAULT);
- return new Path(new Path(testReportDir, algorithm), report);
- }
-
- /**
- * Save test report.
- *
- * @param conf the conf
- * @param report the report
- * @throws IOException Signals that an I/O exception has occurred.
- */
- public static void saveTestReport(Configuration conf, MLTestReport report) throws IOException {
- Path reportDir = new Path(conf.get(TEST_REPORT_BASE_DIR, TEST_REPORT_BASE_DIR_DEFAULT));
- FileSystem fs = reportDir.getFileSystem(conf);
-
- if (!fs.exists(reportDir)) {
- LOG.info("Creating test report dir " + reportDir.toUri().toString());
- fs.mkdirs(reportDir);
- }
-
- Path algoDir = new Path(reportDir, report.getAlgorithm());
-
- if (!fs.exists(algoDir)) {
- LOG.info("Creating algorithm report dir " + algoDir.toUri().toString());
- fs.mkdirs(algoDir);
- }
-
- ObjectOutputStream reportOutputStream = null;
- Path reportSaveLocation;
- try {
- reportSaveLocation = new Path(algoDir, report.getReportID());
- reportOutputStream = new ObjectOutputStream(fs.create(reportSaveLocation));
- reportOutputStream.writeObject(report);
- reportOutputStream.flush();
- } catch (IOException ioexc) {
- LOG.error("Error saving test report " + report.getReportID(), ioexc);
- throw ioexc;
- } finally {
- IOUtils.closeQuietly(reportOutputStream);
- }
- LOG.info("Saved report " + report.getReportID() + " at location " + reportSaveLocation.toUri());
- }
-
- /**
- * Load report.
- *
- * @param conf the conf
- * @param algorithm the algorithm
- * @param reportID the report id
- * @return the ML test report
- * @throws IOException Signals that an I/O exception has occurred.
- */
- public static MLTestReport loadReport(Configuration conf, String algorithm, String reportID) throws IOException {
- Path reportLocation = getTestReportPath(conf, algorithm, reportID);
- FileSystem fs = reportLocation.getFileSystem(conf);
- ObjectInputStream reportStream = null;
- MLTestReport report = null;
-
- try {
- reportStream = new ObjectInputStream(fs.open(reportLocation));
- report = (MLTestReport) reportStream.readObject();
- } catch (IOException ioex) {
- LOG.error("Error reading report " + reportLocation, ioex);
- } catch (ClassNotFoundException e) {
- throw new IOException(e);
- } finally {
- IOUtils.closeQuietly(reportStream);
- }
- return report;
- }
-
- /**
- * Delete model.
- *
- * @param conf the conf
- * @param algorithm the algorithm
- * @param modelID the model id
- * @throws IOException Signals that an I/O exception has occurred.
- */
- public static void deleteModel(HiveConf conf, String algorithm, String modelID) throws IOException {
- Path modelLocation = getModelLocation(conf, algorithm, modelID);
- FileSystem fs = modelLocation.getFileSystem(conf);
- fs.delete(modelLocation, false);
- }
-
- /**
- * Delete test report.
- *
- * @param conf the conf
- * @param algorithm the algorithm
- * @param reportID the report id
- * @throws IOException Signals that an I/O exception has occurred.
- */
- public static void deleteTestReport(HiveConf conf, String algorithm, String reportID) throws IOException {
- Path reportPath = getTestReportPath(conf, algorithm, reportID);
- reportPath.getFileSystem(conf).delete(reportPath, false);
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/MultiPrediction.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/MultiPrediction.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/MultiPrediction.java
deleted file mode 100644
index 4794c97..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/MultiPrediction.java
+++ /dev/null
@@ -1,28 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml;
-
-import java.util.List;
-
-/**
- * The Interface MultiPrediction.
- */
-public interface MultiPrediction {
- List<LabelledPrediction> getPredictions();
-}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/QueryRunner.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/QueryRunner.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/QueryRunner.java
deleted file mode 100644
index 56f9a88..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/QueryRunner.java
+++ /dev/null
@@ -1,56 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml;
-
-import org.apache.lens.api.LensException;
-import org.apache.lens.api.LensSessionHandle;
-import org.apache.lens.api.query.QueryHandle;
-
-import lombok.Getter;
-import lombok.Setter;
-
-/**
- * Run a query against a Lens server.
- */
-public abstract class QueryRunner {
-
- /** The session handle. */
- protected final LensSessionHandle sessionHandle;
-
- @Getter @Setter
- protected String queryName;
-
- /**
- * Instantiates a new query runner.
- *
- * @param sessionHandle the session handle
- */
- public QueryRunner(LensSessionHandle sessionHandle) {
- this.sessionHandle = sessionHandle;
- }
-
- /**
- * Run query.
- *
- * @param query the query
- * @return the query handle
- * @throws LensException the lens exception
- */
- public abstract QueryHandle runQuery(String query) throws LensException;
-}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/TableTestingSpec.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/TableTestingSpec.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/TableTestingSpec.java
deleted file mode 100644
index f7fb1f8..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/TableTestingSpec.java
+++ /dev/null
@@ -1,325 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml;
-
-import java.util.ArrayList;
-import java.util.HashMap;
-import java.util.List;
-
-import org.apache.commons.lang3.StringUtils;
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
-import org.apache.hadoop.hive.conf.HiveConf;
-import org.apache.hadoop.hive.metastore.api.FieldSchema;
-import org.apache.hadoop.hive.ql.metadata.Hive;
-import org.apache.hadoop.hive.ql.metadata.HiveException;
-import org.apache.hadoop.hive.ql.metadata.Table;
-
-import lombok.Getter;
-
-/**
- * Table specification for running test on a table.
- */
-public class TableTestingSpec {
-
- /** The Constant LOG. */
- public static final Log LOG = LogFactory.getLog(TableTestingSpec.class);
-
- /** The db. */
- private String db;
-
- /** The table containing input data. */
- private String inputTable;
-
- // TODO use partition condition
- /** The partition filter. */
- private String partitionFilter;
-
- /** The feature columns. */
- private List<String> featureColumns;
-
- /** The label column. */
- private String labelColumn;
-
- /** The output column. */
- private String outputColumn;
-
- /** The output table. */
- private String outputTable;
-
- /** The conf. */
- private transient HiveConf conf;
-
- /** The algorithm. */
- private String algorithm;
-
- /** The model id. */
- private String modelID;
-
- @Getter
- private boolean outputTableExists;
-
- @Getter
- private String testID;
-
- private HashMap<String, FieldSchema> columnNameToFieldSchema;
-
- /**
- * The Class TableTestingSpecBuilder.
- */
- public static class TableTestingSpecBuilder {
-
- /** The spec. */
- private final TableTestingSpec spec;
-
- /**
- * Instantiates a new table testing spec builder.
- */
- public TableTestingSpecBuilder() {
- spec = new TableTestingSpec();
- }
-
- /**
- * Database.
- *
- * @param database the database
- * @return the table testing spec builder
- */
- public TableTestingSpecBuilder database(String database) {
- spec.db = database;
- return this;
- }
-
- /**
- * Set the input table
- *
- * @param table the table
- * @return the table testing spec builder
- */
- public TableTestingSpecBuilder inputTable(String table) {
- spec.inputTable = table;
- return this;
- }
-
- /**
- * Partition filter for input table
- *
- * @param partFilter the part filter
- * @return the table testing spec builder
- */
- public TableTestingSpecBuilder partitionFilter(String partFilter) {
- spec.partitionFilter = partFilter;
- return this;
- }
-
- /**
- * Feature columns.
- *
- * @param featureColumns the feature columns
- * @return the table testing spec builder
- */
- public TableTestingSpecBuilder featureColumns(List<String> featureColumns) {
- spec.featureColumns = featureColumns;
- return this;
- }
-
- /**
- * Labe column.
- *
- * @param labelColumn the label column
- * @return the table testing spec builder
- */
- public TableTestingSpecBuilder lableColumn(String labelColumn) {
- spec.labelColumn = labelColumn;
- return this;
- }
-
- /**
- * Output column.
- *
- * @param outputColumn the output column
- * @return the table testing spec builder
- */
- public TableTestingSpecBuilder outputColumn(String outputColumn) {
- spec.outputColumn = outputColumn;
- return this;
- }
-
- /**
- * Output table.
- *
- * @param table the table
- * @return the table testing spec builder
- */
- public TableTestingSpecBuilder outputTable(String table) {
- spec.outputTable = table;
- return this;
- }
-
- /**
- * Hive conf.
- *
- * @param conf the conf
- * @return the table testing spec builder
- */
- public TableTestingSpecBuilder hiveConf(HiveConf conf) {
- spec.conf = conf;
- return this;
- }
-
- /**
- * Algorithm.
- *
- * @param algorithm the algorithm
- * @return the table testing spec builder
- */
- public TableTestingSpecBuilder algorithm(String algorithm) {
- spec.algorithm = algorithm;
- return this;
- }
-
- /**
- * Model id.
- *
- * @param modelID the model id
- * @return the table testing spec builder
- */
- public TableTestingSpecBuilder modelID(String modelID) {
- spec.modelID = modelID;
- return this;
- }
-
- /**
- * Builds the.
- *
- * @return the table testing spec
- */
- public TableTestingSpec build() {
- return spec;
- }
-
- /**
- * Set the unique test id
- *
- * @param testID
- * @return
- */
- public TableTestingSpecBuilder testID(String testID) {
- spec.testID = testID;
- return this;
- }
- }
-
- /**
- * New builder.
- *
- * @return the table testing spec builder
- */
- public static TableTestingSpecBuilder newBuilder() {
- return new TableTestingSpecBuilder();
- }
-
- /**
- * Validate.
- *
- * @return true, if successful
- */
- public boolean validate() {
- List<FieldSchema> columns;
- try {
- Hive metastoreClient = Hive.get(conf);
- Table tbl = (db == null) ? metastoreClient.getTable(inputTable) : metastoreClient.getTable(db, inputTable);
- columns = tbl.getAllCols();
- columnNameToFieldSchema = new HashMap<String, FieldSchema>();
-
- for (FieldSchema fieldSchema : columns) {
- columnNameToFieldSchema.put(fieldSchema.getName(), fieldSchema);
- }
-
- // Check if output table exists
- Table outTbl = metastoreClient.getTable(db == null ? "default" : db, outputTable, false);
- outputTableExists = (outTbl != null);
- } catch (HiveException exc) {
- LOG.error("Error getting table info " + toString(), exc);
- return false;
- }
-
- // Check if labeled column and feature columns are contained in the table
- List<String> testTableColumns = new ArrayList<String>(columns.size());
- for (FieldSchema column : columns) {
- testTableColumns.add(column.getName());
- }
-
- if (!testTableColumns.containsAll(featureColumns)) {
- LOG.info("Invalid feature columns: " + featureColumns + ". Actual columns in table:" + testTableColumns);
- return false;
- }
-
- if (!testTableColumns.contains(labelColumn)) {
- LOG.info("Invalid label column: " + labelColumn + ". Actual columns in table:" + testTableColumns);
- return false;
- }
-
- if (StringUtils.isBlank(outputColumn)) {
- LOG.info("Output column is required");
- return false;
- }
-
- if (StringUtils.isBlank(outputTable)) {
- LOG.info("Output table is required");
- return false;
- }
- return true;
- }
-
- public String getTestQuery() {
- if (!validate()) {
- return null;
- }
-
- // We always insert a dynamic partition
- StringBuilder q = new StringBuilder("INSERT OVERWRITE TABLE " + outputTable + " PARTITION (part_testid='" + testID
- + "') SELECT ");
- String featureCols = StringUtils.join(featureColumns, ",");
- q.append(featureCols).append(",").append(labelColumn).append(", ").append("predict(").append("'").append(algorithm)
- .append("', ").append("'").append(modelID).append("', ").append(featureCols).append(") ").append(outputColumn)
- .append(" FROM ").append(inputTable);
-
- return q.toString();
- }
-
- public String getCreateOutputTableQuery() {
- StringBuilder createTableQuery = new StringBuilder("CREATE TABLE IF NOT EXISTS ").append(outputTable).append("(");
- // Output table contains feature columns, label column, output column
- List<String> outputTableColumns = new ArrayList<String>();
- for (String featureCol : featureColumns) {
- outputTableColumns.add(featureCol + " " + columnNameToFieldSchema.get(featureCol).getType());
- }
-
- outputTableColumns.add(labelColumn + " " + columnNameToFieldSchema.get(labelColumn).getType());
- outputTableColumns.add(outputColumn + " string");
-
- createTableQuery.append(StringUtils.join(outputTableColumns, ", "));
-
- // Append partition column
- createTableQuery.append(") PARTITIONED BY (part_testid string)");
-
- return createTableQuery.toString();
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/api/AlgoParam.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/api/AlgoParam.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/api/AlgoParam.java
new file mode 100644
index 0000000..e0d13c0
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/api/AlgoParam.java
@@ -0,0 +1,53 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.algo.api;
+
+import java.lang.annotation.ElementType;
+import java.lang.annotation.Retention;
+import java.lang.annotation.RetentionPolicy;
+import java.lang.annotation.Target;
+
+/**
+ * The Interface AlgoParam.
+ */
+@Retention(RetentionPolicy.RUNTIME)
+@Target(ElementType.FIELD)
+public @interface AlgoParam {
+
+ /**
+ * Name.
+ *
+ * @return the string
+ */
+ String name();
+
+ /**
+ * Help.
+ *
+ * @return the string
+ */
+ String help();
+
+ /**
+ * Default value.
+ *
+ * @return the string
+ */
+ String defaultValue() default "None";
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/api/Algorithm.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/api/Algorithm.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/api/Algorithm.java
new file mode 100644
index 0000000..29bde29
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/api/Algorithm.java
@@ -0,0 +1,46 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.algo.api;
+
+import java.lang.annotation.ElementType;
+import java.lang.annotation.Retention;
+import java.lang.annotation.RetentionPolicy;
+import java.lang.annotation.Target;
+
+/**
+ * The Interface Algorithm.
+ */
+@Retention(RetentionPolicy.RUNTIME)
+@Target(ElementType.TYPE)
+public @interface Algorithm {
+
+ /**
+ * Name.
+ *
+ * @return the string
+ */
+ String name();
+
+ /**
+ * Description.
+ *
+ * @return the string
+ */
+ String description();
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/api/MLAlgo.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/api/MLAlgo.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/api/MLAlgo.java
new file mode 100644
index 0000000..44b0043
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/api/MLAlgo.java
@@ -0,0 +1,53 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.algo.api;
+
+import org.apache.lens.api.LensConf;
+import org.apache.lens.api.LensException;
+
+/** Contract for a machine learning algorithm that can be configured and then trained on a table. */
+public interface MLAlgo {
+ /** @return the name of this algo */
+ String getName();
+
+ /** @return the description of this algo */
+ String getDescription();
+
+ /**
+ * Sets the configuration of this algo.
+ * @param configuration the configuration
+ */
+ void configure(LensConf configuration);
+
+ /** @return the LensConf of this algo */
+ LensConf getConf();
+
+ /**
+ * Trains a model on the given table.
+ *
+ * @param conf the conf
+ * @param db the db containing the table
+ * @param table the table to train on
+ * @param modelId the model id
+ * @param params the training params
+ * @return the ML model
+ * @throws LensException the lens exception
+ */
+ MLModel train(LensConf conf, String db, String table, String modelId, String... params) throws LensException;
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/api/MLDriver.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/api/MLDriver.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/api/MLDriver.java
new file mode 100644
index 0000000..1aa699d
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/api/MLDriver.java
@@ -0,0 +1,71 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.algo.api;
+
+import java.util.List;
+
+import org.apache.lens.api.LensConf;
+import org.apache.lens.api.LensException;
+
+/**
+ * Contract for a driver that provides ML algos by name and has an init/start/stop life cycle.
+ */
+public interface MLDriver {
+
+ /**
+ * Checks whether the named algo is supported.
+ *
+ * @param algo the algo
+ * @return true, if is algo supported
+ */
+ boolean isAlgoSupported(String algo);
+
+ /**
+ * Gets an instance of the named algo.
+ *
+ * @param algo the algo
+ * @return the algo instance
+ * @throws LensException the lens exception
+ */
+ MLAlgo getAlgoInstance(String algo) throws LensException;
+
+ /**
+ * Initializes the driver with the given configuration.
+ *
+ * @param conf the conf
+ * @throws LensException the lens exception
+ */
+ void init(LensConf conf) throws LensException;
+
+ /**
+ * Starts the driver.
+ *
+ * @throws LensException the lens exception
+ */
+ void start() throws LensException;
+
+ /**
+ * Stops the driver.
+ * @throws LensException the lens exception
+ */
+ void stop() throws LensException;
+
+ /** @return the algo names */
+ List<String> getAlgoNames();
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/api/MLModel.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/api/MLModel.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/api/MLModel.java
new file mode 100644
index 0000000..73717ac
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/api/MLModel.java
@@ -0,0 +1,79 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.algo.api;
+
+import java.io.Serializable;
+import java.util.Date;
+import java.util.List;
+
+import lombok.Getter;
+import lombok.NoArgsConstructor;
+import lombok.Setter;
+import lombok.ToString;
+
+/**
+ * Serializable base class for ML models; holds model metadata and declares the prediction method.
+ * @param <PREDICTION> the type returned by predict
+ */
+@NoArgsConstructor
+@ToString
+public abstract class MLModel<PREDICTION> implements Serializable {
+
+ /** The id. */
+ @Getter
+ @Setter
+ private String id;
+
+ /** The created at. */
+ @Getter
+ @Setter
+ private Date createdAt;
+
+ /** The algo name. */
+ @Getter
+ @Setter
+ private String algoName;
+
+ /** The table. */
+ @Getter
+ @Setter
+ private String table;
+
+ /** The params. */
+ @Getter
+ @Setter
+ private List<String> params;
+
+ /** The label column. */
+ @Getter
+ @Setter
+ private String labelColumn;
+
+ /** The feature columns. */
+ @Getter
+ @Setter
+ private List<String> featureColumns;
+
+ /**
+ * Computes a prediction for the given args.
+ * @param args the args
+ * @return the prediction
+ */
+ public abstract PREDICTION predict(Object... args);
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/lib/AlgoArgParser.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/lib/AlgoArgParser.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/lib/AlgoArgParser.java
new file mode 100644
index 0000000..51979d8
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/lib/AlgoArgParser.java
@@ -0,0 +1,117 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.algo.lib;
+
+import java.lang.reflect.Field;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.lens.ml.algo.api.AlgoParam;
+import org.apache.lens.ml.algo.api.MLAlgo;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+
+/**
+ * Utility for applying key/value training args to the {@link AlgoParam} fields of an {@link MLAlgo}.
+ */
+public final class AlgoArgParser {
+  private AlgoArgParser() {
+  }
+
+  /**
+   * Parser for param values whose field type is not String, int, double or long.
+   *
+   * @param <E> the type produced by the parser
+   */
+  public abstract static class CustomArgParser<E> {
+
+    /**
+     * Parses the raw param value.
+     *
+     * @param value the raw param value
+     * @return the parsed value
+     */
+    public abstract E parse(String value);
+  }
+
+  /** The Constant LOG. */
+  public static final Log LOG = LogFactory.getLog(AlgoArgParser.class);
+
+  /**
+   * Extracts feature names from alternating key/value args. Keys matching an {@link AlgoParam} field declared on
+   * the algo's own class are set on the algo; a trailing key without a value is ignored and a failure to set a
+   * param is logged, not thrown.
+   * @param algo the algo
+   * @param args the args, as alternating keys and values
+   * @return List of feature column names.
+   */
+  public static List<String> parseArgs(MLAlgo algo, String[] args) {
+    List<String> featureColumns = new ArrayList<String>();
+    // Index the @AlgoParam fields declared on the algo's own class by param name
+    Map<String, Field> fieldMap = new HashMap<String, Field>();
+    for (Field fld : algo.getClass().getDeclaredFields()) {
+      AlgoParam paramAnnotation = fld.getAnnotation(AlgoParam.class);
+      if (paramAnnotation != null) {
+        fld.setAccessible(true);
+        fieldMap.put(paramAnnotation.name(), fld);
+      }
+    }
+
+    if (args.length % 2 != 0) {
+      LOG.warn("Ignoring trailing param " + args[args.length - 1] + " as it has no value");
+    }
+    for (int i = 0; i + 1 < args.length; i += 2) {
+      String key = args[i].trim();
+      String value = args[i + 1].trim();
+      try {
+        if ("feature".equalsIgnoreCase(key)) {
+          featureColumns.add(value);
+        } else if (fieldMap.containsKey(key)) {
+          Field f = fieldMap.get(key);
+          if (String.class.equals(f.getType())) {
+            f.set(algo, value);
+          } else if (Integer.TYPE.equals(f.getType())) {
+            f.setInt(algo, Integer.parseInt(value));
+          } else if (Double.TYPE.equals(f.getType())) {
+            f.setDouble(algo, Double.parseDouble(value));
+          } else if (Long.TYPE.equals(f.getType())) {
+            f.setLong(algo, Long.parseLong(value));
+          } else {
+            // check if the algo provides a deserializer for this param
+            String customParserClass = algo.getConf().getProperties().get("lens.ml.args." + key);
+            if (customParserClass != null) {
+              Class<? extends CustomArgParser> clz = Class.forName(customParserClass)
+                .asSubclass(CustomArgParser.class);
+              CustomArgParser<?> parser = clz.newInstance();
+              f.set(algo, parser.parse(value));
+            } else {
+              LOG.warn("Ignored param " + key + "=" + value + " as no parser found");
+            }
+          }
+        }
+      } catch (Exception exc) {
+        LOG.error("Error while setting param " + key + " to " + value + " for algo " + algo, exc);
+      }
+    }
+    return featureColumns;
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/lib/Algorithms.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/lib/Algorithms.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/lib/Algorithms.java
new file mode 100644
index 0000000..a2fd94b
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/lib/Algorithms.java
@@ -0,0 +1,89 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.algo.lib;
+
+import java.lang.reflect.Constructor;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.lens.api.LensException;
+import org.apache.lens.ml.algo.api.Algorithm;
+import org.apache.lens.ml.algo.api.MLAlgo;
+
+/**
+ * Registry of {@link MLAlgo} classes, keyed by the name declared in their {@link Algorithm} annotation.
+ */
+public class Algorithms {
+
+  /** Registered algo classes by algorithm name. */
+  private final Map<String, Class<? extends MLAlgo>> algorithmClasses
+    = new HashMap<String, Class<? extends MLAlgo>>();
+
+  /**
+   * Registers an algo class under the name given by its {@link Algorithm} annotation.
+   * @param algoClass the algo class
+   * @throws IllegalArgumentException if the class is null or does not carry the annotation
+   */
+  public void register(Class<? extends MLAlgo> algoClass) {
+    Algorithm annotation = algoClass == null ? null : algoClass.getAnnotation(Algorithm.class);
+    if (annotation == null) {
+      throw new IllegalArgumentException("Not a valid algorithm class: " + algoClass);
+    }
+    algorithmClasses.put(annotation.name(), algoClass);
+  }
+
+  /**
+   * Creates a new instance of the algo registered under the given name through its (String, String)
+   * constructor, passing the name and the description from the class's {@link Algorithm} annotation.
+   *
+   * @param name the name
+   * @return the algo for name, or null if no algo is registered under that name
+   * @throws LensException if the algo could not be instantiated
+   */
+  public MLAlgo getAlgoForName(String name) throws LensException {
+    Class<? extends MLAlgo> registered = algorithmClasses.get(name);
+    if (registered != null) {
+      String description = registered.getAnnotation(Algorithm.class).description();
+      try {
+        Constructor<? extends MLAlgo> ctor = registered.getConstructor(String.class, String.class);
+        return ctor.newInstance(name, description);
+      } catch (Exception exc) {
+        throw new LensException("Unable to get algo: " + name, exc);
+      }
+    }
+    return null;
+  }
+
+  /**
+   * Checks if is algo supported.
+   *
+   * @param name the name
+   * @return true, if an algo is registered under the name
+   */
+  public boolean isAlgoSupported(String name) {
+    return algorithmClasses.containsKey(name);
+  }
+
+  /** @return a new list holding the names of all registered algos */
+  public List<String> getAlgorithmNames() {
+    return new ArrayList<String>(algorithmClasses.keySet());
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/lib/ClassifierBaseModel.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/lib/ClassifierBaseModel.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/lib/ClassifierBaseModel.java
new file mode 100644
index 0000000..a960a4a
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/lib/ClassifierBaseModel.java
@@ -0,0 +1,48 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.algo.lib;
+
+import org.apache.lens.ml.algo.api.MLModel;
+
+/**
+ * Return a single double value as a prediction. This is useful in classifiers where the classifier returns a single
+ * class label as a prediction.
+ */
+public abstract class ClassifierBaseModel extends MLModel<Double> {
+
+  /**
+   * Converts prediction args to a feature vector. Numbers are used by their double value and strings are parsed
+   * as doubles (a malformed string raises NumberFormatException); any other arg, including null, becomes 0.0.
+   * @param args the args
+   * @return the feature vector, one entry per arg
+   */
+  public final double[] getFeatureVector(Object[] args) {
+    double[] features = new double[args.length];
+    for (int i = 0; i < args.length; i++) {
+      if (args[i] instanceof Number) {
+        features[i] = ((Number) args[i]).doubleValue();
+      } else if (args[i] instanceof String) {
+        features[i] = Double.parseDouble((String) args[i]);
+      } else {
+        features[i] = 0.0;
+      }
+    }
+    return features;
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/lib/ForecastingModel.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/lib/ForecastingModel.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/lib/ForecastingModel.java
new file mode 100644
index 0000000..16a6180
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/lib/ForecastingModel.java
@@ -0,0 +1,95 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.algo.lib;
+
+import java.util.List;
+
+import org.apache.lens.ml.algo.api.MLModel;
+
+/**
+ * Forecasting model whose prediction is a {@link MultiPrediction}.
+ */
+public class ForecastingModel extends MLModel<MultiPrediction> {
+
+  /*
+   * (non-Javadoc) Ignores the args and returns predictions backed by a null list.
+   *
+   * @see org.apache.lens.ml.algo.api.MLModel#predict(java.lang.Object[])
+   */
+  @Override
+  public MultiPrediction predict(Object... args) {
+    return new ForecastingPredictions(null);
+  }
+
+  /**
+   * {@link MultiPrediction} that hands back the list it was created with.
+   */
+  public static class ForecastingPredictions implements MultiPrediction {
+
+    /** The labelled predictions, as passed to the constructor. */
+    private final List<LabelledPrediction> predictions;
+
+    /**
+     * Instantiates a new forecasting predictions.
+     *
+     * @param values the labelled predictions to expose; stored as is
+     */
+    public ForecastingPredictions(List<LabelledPrediction> values) {
+      predictions = values;
+    }
+
+    @Override
+    public List<LabelledPrediction> getPredictions() {
+      return predictions;
+    }
+  }
+
+  /**
+   * A forecast point: a timestamp label paired with the predicted value.
+   */
+  public static class ForecastingLabel implements LabelledPrediction<Long, Double> {
+
+    /** The timestamp used as label. */
+    private final Long label;
+
+    /** The predicted value. */
+    private final double prediction;
+
+    /**
+     * Instantiates a new forecasting label.
+     *
+     * @param timestamp the timestamp
+     * @param value the value
+     */
+    public ForecastingLabel(long timestamp, double value) {
+      label = Long.valueOf(timestamp);
+      prediction = value;
+    }
+
+    @Override
+    public Long getLabel() {
+      return label;
+    }
+
+    @Override
+    public Double getPrediction() {
+      return Double.valueOf(prediction);
+    }
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/lib/LabelledPrediction.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/lib/LabelledPrediction.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/lib/LabelledPrediction.java
new file mode 100644
index 0000000..9c7737e
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/lib/LabelledPrediction.java
@@ -0,0 +1,32 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.algo.lib;
+
+/**
+ * Prediction type used when the model prediction is of a complex type. For example, in forecasting the
+ * predictions are a series of (timestamp, value) pairs.
+ * @param <LABELTYPE> the type of the label
+ * @param <PREDICTIONTYPE> the type of the prediction
+ */
+public interface LabelledPrediction<LABELTYPE, PREDICTIONTYPE> {
+ /** @return the label of this prediction */
+ LABELTYPE getLabel();
+ /** @return the prediction */
+ PREDICTIONTYPE getPrediction();
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/lib/MultiPrediction.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/lib/MultiPrediction.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/lib/MultiPrediction.java
new file mode 100644
index 0000000..e910a92
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/lib/MultiPrediction.java
@@ -0,0 +1,28 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.algo.lib;
+
+import java.util.List;
+
+/**
+ * A prediction made up of a list of {@link LabelledPrediction} entries.
+ */
+public interface MultiPrediction {
+ List<LabelledPrediction> getPredictions();
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/BaseSparkAlgo.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/BaseSparkAlgo.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/BaseSparkAlgo.java
new file mode 100644
index 0000000..4012085
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/BaseSparkAlgo.java
@@ -0,0 +1,287 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.algo.spark;
+
+import java.lang.reflect.Field;
+import java.util.*;
+
+import org.apache.lens.api.LensConf;
+import org.apache.lens.api.LensException;
+import org.apache.lens.ml.algo.api.AlgoParam;
+import org.apache.lens.ml.algo.api.Algorithm;
+import org.apache.lens.ml.algo.api.MLAlgo;
+import org.apache.lens.ml.algo.api.MLModel;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.hive.conf.HiveConf;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.rdd.RDD;
+
+/**
+ * The Class BaseSparkAlgo.
+ */
+public abstract class BaseSparkAlgo implements MLAlgo {
+
+ /** The Constant LOG. */
+ public static final Log LOG = LogFactory.getLog(BaseSparkAlgo.class);
+
+ /** The name. */
+ private final String name;
+
+ /** The description. */
+ private final String description;
+
+ /** The spark context. */
+ protected JavaSparkContext sparkContext;
+
+ /** The params. */
+ protected Map<String, String> params;
+
+ /** The conf. */
+ protected transient LensConf conf;
+
+ /** The training fraction. */
+ @AlgoParam(name = "trainingFraction", help = "% of dataset to be used for training", defaultValue = "0")
+ protected double trainingFraction;
+
+ /** The use training fraction. */
+ private boolean useTrainingFraction;
+
+ /** The label. */
+ @AlgoParam(name = "label", help = "Name of column which is used as a training label for supervised learning")
+ protected String label;
+
+ /** The partition filter; set by parseParams from the "partition" or "p" param. */
+ @AlgoParam(name = "partition", help = "Partition filter used to create HCatInputFormats")
+ protected String partitionFilter;
+
+ /** The features. */
+ @AlgoParam(name = "feature", help = "Column name(s) which are to be used as sample features")
+ protected List<String> features;
+
+ /**
+ * Instantiates a new base spark algo.
+ *
+ * @param name the name
+ * @param description the description
+ */
+ public BaseSparkAlgo(String name, String description) {
+ this.name = name;
+ this.description = description;
+ }
+
+ public void setSparkContext(JavaSparkContext sparkContext) {
+ this.sparkContext = sparkContext;
+ }
+
+ @Override
+ public LensConf getConf() {
+ return conf;
+ }
+
+ /*
+ * (non-Javadoc) Stores the configuration that getConf() returns.
+ *
+ * @see org.apache.lens.ml.algo.api.MLAlgo#configure(org.apache.lens.api.LensConf)
+ */
+ @Override
+ public void configure(LensConf configuration) {
+ this.conf = configuration;
+ }
+
+ /*
+  * (non-Javadoc) Parses the params, builds the table training spec, creates its RDDs and passes the training RDD
+  * to trainInternal. Throws IllegalArgumentException when the params are invalid or name no feature column.
+  * @see org.apache.lens.ml.algo.api.MLAlgo#train(org.apache.lens.api.LensConf, java.lang.String,
+  * java.lang.String, java.lang.String, java.lang.String[])
+  */
+ @Override
+ public MLModel<?> train(LensConf conf, String db, String table, String modelId, String... params)
+   throws LensException {
+   parseParams(params);
+   if (features == null) {
+     throw new IllegalArgumentException("No feature columns specified for training algo " + name);
+   }
+
+   TableTrainingSpec.TableTrainingSpecBuilder builder = TableTrainingSpec.newBuilder().hiveConf(toHiveConf(conf))
+     .database(db).table(table).partitionFilter(partitionFilter).featureColumns(features).labelColumn(label);
+   if (useTrainingFraction) {
+     builder.trainingFraction(trainingFraction);
+   }
+
+   TableTrainingSpec spec = builder.build();
+   LOG.info("Training " + name + " with " + features.size() + " features");
+   spec.createRDDs(sparkContext);
+
+   BaseSparkClassificationModel<?> model = trainInternal(modelId, spec.getTrainingRDD());
+   model.setTable(table);
+   model.setParams(Arrays.asList(params));
+   model.setLabelColumn(label);
+   model.setFeatureColumns(features);
+   return model;
+ }
+
+ /**
+  * Builds a new {@link HiveConf} and sets every property of the given Lens conf on it.
+  *
+  * @param conf the conf
+  * @return the hive conf
+  */
+ protected HiveConf toHiveConf(LensConf conf) {
+   HiveConf result = new HiveConf();
+   for (Map.Entry<String, String> property : conf.getProperties().entrySet()) {
+     result.set(property.getKey(), property.getValue());
+   }
+   return result;
+ }
+
+ /**
+  * Parses alternating key/value args: f/feature adds a feature, l/label sets the label, any other pair goes into
+  * params with '-' stripped from its key. Then applies trainingFraction and partition/p and calls parseAlgoParams.
+  * @param args the args
+  */
+ public void parseParams(String[] args) {
+   if (args.length % 2 != 0) {
+     throw new IllegalArgumentException("Invalid number of params " + args.length);
+   }
+   params = new LinkedHashMap<String, String>();
+   for (int i = 0; i < args.length; i += 2) {
+     String key = args[i];
+     String value = args[i + 1];
+     if ("f".equalsIgnoreCase(key) || "feature".equalsIgnoreCase(key)) {
+       if (features == null) {
+         features = new ArrayList<String>();
+       }
+       features.add(value);
+     } else if ("l".equalsIgnoreCase(key) || "label".equalsIgnoreCase(key)) {
+       label = value;
+     } else {
+       params.put(key.replaceAll("\\-+", ""), value);
+     }
+   }
+
+   if (params.containsKey("trainingFraction")) {
+     try {
+       trainingFraction = Double.parseDouble(params.get("trainingFraction"));
+       useTrainingFraction = true;
+     } catch (NumberFormatException nfe) {
+       throw new IllegalArgumentException("Invalid training fraction", nfe);
+     }
+   }
+
+   if (params.containsKey("partition")) {
+     partitionFilter = params.get("partition");
+   } else if (params.containsKey("p")) {
+     partitionFilter = params.get("p");
+   }
+
+   parseAlgoParams(params);
+ }
+
+ /**
+  * Reads a param from the params map as a double.
+  * @param param the param
+  * @param defaultVal the default val
+  * @return the parsed value, or defaultVal when the param is absent or is not a valid double
+  */
+ public double getParamValue(String param, double defaultVal) {
+   if (!params.containsKey(param)) {
+     return defaultVal;
+   }
+   try {
+     return Double.parseDouble(params.get(param));
+   } catch (NumberFormatException nfe) {
+     LOG.warn("Couldn't parse param value: " + param + " as double.");
+     return defaultVal;
+   }
+ }
+
+ /**
+  * Reads a param from the params map as an int.
+  * @param param the param
+  * @param defaultVal the default val
+  * @return the parsed value, or defaultVal when the param is absent or is not a valid integer
+  */
+ public int getParamValue(String param, int defaultVal) {
+   if (!params.containsKey(param)) {
+     return defaultVal;
+   }
+   try {
+     return Integer.parseInt(params.get(param));
+   } catch (NumberFormatException nfe) {
+     LOG.warn("Couldn't parse param value: " + param + " as integer.");
+     return defaultVal;
+   }
+ }
+
+ public String getName() {
+ return name;
+ }
+
+ public String getDescription() {
+ return description;
+ }
+
+ /**
+  * Describes this algo: its {@link Algorithm} name and description, if annotated, plus one entry per param.
+  * @return insertion-ordered map from usage label to help text
+  */
+ public Map<String, String> getArgUsage() {
+   Map<String, String> usage = new LinkedHashMap<String, String>();
+   Algorithm algorithm = getClass().getAnnotation(Algorithm.class);
+   if (algorithm != null) {
+     usage.put("Algorithm Name", algorithm.name());
+     usage.put("Algorithm Description", algorithm.description());
+   }
+
+   // Collect @AlgoParam fields from the concrete class up to and including BaseSparkAlgo
+   Class<?> current = getClass();
+   while (current != null) {
+     for (Field field : current.getDeclaredFields()) {
+       AlgoParam param = field.getAnnotation(AlgoParam.class);
+       if (param == null) {
+         continue;
+       }
+       usage.put("[param] " + param.name(), param.help() + " Default Value = " + param.defaultValue());
+     }
+     current = BaseSparkAlgo.class.equals(current) ? null : current.getSuperclass();
+   }
+   return usage;
+ }
+
+ /**
+ * Hook for algo-specific param parsing; parseParams calls it last with the map it has built.
+ *
+ * @param params the params map; feature and label args are not in it and keys have '-' characters removed
+ */
+ public abstract void parseAlgoParams(Map<String, String> params);
+
+ /**
+ * Trains the model on the training RDD; train calls it after the training spec's RDDs have been created.
+ *
+ * @param modelId the model id
+ * @param trainingRDD the training rdd
+ * @return the base spark classification model
+ * @throws LensException the lens exception
+ */
+ protected abstract BaseSparkClassificationModel trainInternal(String modelId, RDD<LabeledPoint> trainingRDD)
+ throws LensException;
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/BaseSparkClassificationModel.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/BaseSparkClassificationModel.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/BaseSparkClassificationModel.java
new file mode 100644
index 0000000..806dc1f
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/BaseSparkClassificationModel.java
@@ -0,0 +1,65 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.algo.spark;
+
+import org.apache.lens.ml.algo.lib.ClassifierBaseModel;
+
+import org.apache.spark.mllib.classification.ClassificationModel;
+import org.apache.spark.mllib.linalg.Vectors;
+
+/**
+ * The Class BaseSparkClassificationModel.
+ *
+ * @param <MODEL> the generic type
+ */
+public class BaseSparkClassificationModel<MODEL extends ClassificationModel> extends ClassifierBaseModel {
+
+ /** The model id. */
+ private final String modelId;
+
+ /** The spark model. */
+ private final MODEL sparkModel;
+
+ /**
+ * Instantiates a new base spark classification model.
+ *
+ * @param modelId the model id
+ * @param model the model
+ */
+ public BaseSparkClassificationModel(String modelId, MODEL model) {
+ this.modelId = modelId;
+ this.sparkModel = model;
+ }
+
+ /*
+ * (non-Javadoc)
+ *
+ * @see org.apache.lens.ml.MLModel#predict(java.lang.Object[])
+ */
+ @Override
+ public Double predict(Object... args) {
+ return sparkModel.predict(Vectors.dense(getFeatureVector(args)));
+ }
+
+ @Override
+ public String getId() {
+ return modelId;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/ColumnFeatureFunction.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/ColumnFeatureFunction.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/ColumnFeatureFunction.java
new file mode 100644
index 0000000..d75efc0
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/ColumnFeatureFunction.java
@@ -0,0 +1,102 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.algo.spark;
+
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.hive.hcatalog.data.HCatRecord;
+import org.apache.log4j.Logger;
+import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.mllib.regression.LabeledPoint;
+
+import com.google.common.base.Preconditions;
+import scala.Tuple2;
+
+/**
+ * A feature function that directly maps an HCatRecord to a feature vector. Each column becomes a feature in the vector,
+ * with the value of the feature obtained using the value mapper for that column
+ */
+public class ColumnFeatureFunction extends FeatureFunction {
+
+ /** The Constant LOG. */
+ public static final Logger LOG = Logger.getLogger(ColumnFeatureFunction.class);
+
+ /** The feature value mappers. */
+ private final FeatureValueMapper[] featureValueMappers;
+
+ /** The feature positions. */
+ private final int[] featurePositions;
+
+ /** The label column pos. */
+ private final int labelColumnPos;
+
+ /** The num features. */
+ private final int numFeatures;
+
+ /** The default labeled point. */
+ private final LabeledPoint defaultLabeledPoint;
+
+ /**
+ * Feature positions and value mappers are parallel arrays. featurePositions[i] gives the position of ith feature in
+ * the HCatRecord, and valueMappers[i] gives the value mapper used to map that feature to a Double value
+ *
+ * @param featurePositions position number of feature column in the HCatRecord
+ * @param valueMappers mapper for each column position
+ * @param labelColumnPos position of the label column
+ * @param numFeatures number of features in the feature vector
+ * @param defaultLabel default lable to be used for null records
+ */
+ public ColumnFeatureFunction(int[] featurePositions, FeatureValueMapper[] valueMappers, int labelColumnPos,
+ int numFeatures, double defaultLabel) {
+ Preconditions.checkNotNull(valueMappers, "Value mappers argument is required");
+ Preconditions.checkNotNull(featurePositions, "Feature positions are required");
+ Preconditions.checkArgument(valueMappers.length == featurePositions.length,
+ "Mismatch between number of value mappers and feature positions");
+
+ this.featurePositions = featurePositions;
+ this.featureValueMappers = valueMappers;
+ this.labelColumnPos = labelColumnPos;
+ this.numFeatures = numFeatures;
+ defaultLabeledPoint = new LabeledPoint(defaultLabel, Vectors.dense(new double[numFeatures]));
+ }
+
+ /*
+ * (non-Javadoc)
+ *
+ * @see org.apache.lens.ml.spark.FeatureFunction#call(scala.Tuple2)
+ */
+ @Override
+ public LabeledPoint call(Tuple2<WritableComparable, HCatRecord> tuple) throws Exception {
+ HCatRecord record = tuple._2();
+
+ if (record == null) {
+ LOG.info("@@@ Null record");
+ return defaultLabeledPoint;
+ }
+
+ double[] features = new double[numFeatures];
+
+ for (int i = 0; i < numFeatures; i++) {
+ int featurePos = featurePositions[i];
+ features[i] = featureValueMappers[i].call(record.get(featurePos));
+ }
+
+ double label = featureValueMappers[labelColumnPos].call(record.get(labelColumnPos));
+ return new LabeledPoint(label, Vectors.dense(features));
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/DoubleValueMapper.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/DoubleValueMapper.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/DoubleValueMapper.java
new file mode 100644
index 0000000..15ba9ea
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/DoubleValueMapper.java
@@ -0,0 +1,39 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.algo.spark;
+
+/**
+ * Directly return input when it is known to be double.
+ */
+public class DoubleValueMapper extends FeatureValueMapper {
+
+ /*
+ * (non-Javadoc)
+ *
+ * @see org.apache.lens.ml.spark.FeatureValueMapper#call(java.lang.Object)
+ */
+ @Override
+ public final Double call(Object input) {
+ if (input instanceof Double || input == null) {
+ return input == null ? Double.valueOf(0d) : (Double) input;
+ }
+
+ throw new IllegalArgumentException("Invalid input expecting only doubles, but got " + input);
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/FeatureFunction.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/FeatureFunction.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/FeatureFunction.java
new file mode 100644
index 0000000..5e2ab49
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/FeatureFunction.java
@@ -0,0 +1,40 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.algo.spark;
+
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.hive.hcatalog.data.HCatRecord;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.mllib.regression.LabeledPoint;
+
+import scala.Tuple2;
+
+/**
+ * Function to map an HCatRecord to a feature vector usable by MLLib.
+ */
+public abstract class FeatureFunction implements Function<Tuple2<WritableComparable, HCatRecord>, LabeledPoint> {
+
+ /*
+ * (non-Javadoc)
+ *
+ * @see org.apache.spark.api.java.function.Function#call(java.lang.Object)
+ */
+ @Override
+ public abstract LabeledPoint call(Tuple2<WritableComparable, HCatRecord> tuple) throws Exception;
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/FeatureValueMapper.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/FeatureValueMapper.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/FeatureValueMapper.java
new file mode 100644
index 0000000..28c8787
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/FeatureValueMapper.java
@@ -0,0 +1,36 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.algo.spark;
+
+import java.io.Serializable;
+
+import org.apache.spark.api.java.function.Function;
+
+/**
+ * Map a feature value to a Double value usable by MLLib.
+ */
+public abstract class FeatureValueMapper implements Function<Object, Double>, Serializable {
+
+ /*
+ * (non-Javadoc)
+ *
+ * @see org.apache.spark.api.java.function.Function#call(java.lang.Object)
+ */
+ public abstract Double call(Object input);
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/HiveTableRDD.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/HiveTableRDD.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/HiveTableRDD.java
new file mode 100644
index 0000000..4960e1e
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/HiveTableRDD.java
@@ -0,0 +1,63 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.algo.spark;
+
+import java.io.IOException;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.hive.hcatalog.data.HCatRecord;
+import org.apache.hive.hcatalog.mapreduce.HCatInputFormat;
+import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+
+/**
+ * Create a JavaRDD based on a Hive table using HCatInputFormat.
+ */
+public final class HiveTableRDD {
+ private HiveTableRDD() {
+ }
+
+ public static final Log LOG = LogFactory.getLog(HiveTableRDD.class);
+
+ /**
+ * Creates the hive table rdd.
+ *
+ * @param javaSparkContext the java spark context
+ * @param conf the conf
+ * @param db the db
+ * @param table the table
+ * @param partitionFilter the partition filter
+ * @return the java pair rdd
+ * @throws IOException Signals that an I/O exception has occurred.
+ */
+ public static JavaPairRDD<WritableComparable, HCatRecord> createHiveTableRDD(JavaSparkContext javaSparkContext,
+ Configuration conf, String db, String table, String partitionFilter) throws IOException {
+
+ HCatInputFormat.setInput(conf, db, table, partitionFilter);
+
+ JavaPairRDD<WritableComparable, HCatRecord> rdd = javaSparkContext.newAPIHadoopRDD(conf,
+ HCatInputFormat.class, // Input
+ WritableComparable.class, // input key class
+ HCatRecord.class); // input value class
+ return rdd;
+ }
+}
[4/6] incubator-lens git commit: Lens-465 : Refactor ml packages.
(sharad)
Posted by sh...@apache.org.
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/SparkMLDriver.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/SparkMLDriver.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/SparkMLDriver.java
new file mode 100644
index 0000000..c955268
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/SparkMLDriver.java
@@ -0,0 +1,278 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.algo.spark;
+
+import java.io.File;
+import java.io.FilenameFilter;
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.lens.api.LensConf;
+import org.apache.lens.api.LensException;
+import org.apache.lens.ml.algo.api.MLAlgo;
+import org.apache.lens.ml.algo.api.MLDriver;
+import org.apache.lens.ml.algo.lib.Algorithms;
+import org.apache.lens.ml.algo.spark.dt.DecisionTreeAlgo;
+import org.apache.lens.ml.algo.spark.lr.LogisticRegressionAlgo;
+import org.apache.lens.ml.algo.spark.nb.NaiveBayesAlgo;
+import org.apache.lens.ml.algo.spark.svm.SVMAlgo;
+
+import org.apache.commons.lang.StringUtils;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaSparkContext;
+
+/**
+ * The Class SparkMLDriver.
+ */
+public class SparkMLDriver implements MLDriver {
+
+ /** The Constant LOG. */
+ public static final Log LOG = LogFactory.getLog(SparkMLDriver.class);
+
+ /** The owns spark context. */
+ private boolean ownsSparkContext = true;
+
+ /**
+ * The Enum SparkMasterMode.
+ */
+ private enum SparkMasterMode {
+ // Embedded mode used in tests
+ /** The embedded. */
+ EMBEDDED,
+ // Yarn client and Yarn cluster modes are used when deploying the app to Yarn cluster
+ /** The yarn client. */
+ YARN_CLIENT,
+
+ /** The yarn cluster. */
+ YARN_CLUSTER
+ }
+
+ /** The algorithms. */
+ private final Algorithms algorithms = new Algorithms();
+
+ /** The client mode. */
+ private SparkMasterMode clientMode = SparkMasterMode.EMBEDDED;
+
+ /** The is started. */
+ private boolean isStarted;
+
+ /** The spark conf. */
+ private SparkConf sparkConf;
+
+ /** The spark context. */
+ private JavaSparkContext sparkContext;
+
+ /**
+ * Use spark context.
+ * <p>
+ * Installs an externally created context and clears {@code ownsSparkContext}. With that flag cleared,
+ * {@code stop()} leaves the context running, and {@code start()} does not create a new context while this one
+ * is non-null.
+ * </p>
+ *
+ * @param jsc the jsc
+ */
+ public void useSparkContext(JavaSparkContext jsc) {
+ ownsSparkContext = false;
+ this.sparkContext = jsc;
+ }
+
+ /*
+ * (non-Javadoc)
+ *
+ * Delegates to the algorithm registry held in the algorithms field.
+ *
+ * @see org.apache.lens.ml.algo.api.MLDriver#isAlgoSupported(java.lang.String)
+ */
+ @Override
+ public boolean isAlgoSupported(String name) {
+ return algorithms.isAlgoSupported(name);
+ }
+
+ /**
+  * Creates an instance of the named algorithm and, for Spark algorithms, hands it this driver's Spark context.
+  * <p>
+  * A {@link LensException} raised while creating or configuring the algorithm is logged, not rethrown.
+  * </p>
+  *
+  * @param name the algorithm name
+  * @return the algorithm instance, or {@code null} if the name is not supported or creation failed
+  * @throws LensException if the driver has not been started
+  * @see org.apache.lens.ml.algo.api.MLDriver#getAlgoInstance(java.lang.String)
+  */
+ @Override
+ public MLAlgo getAlgoInstance(String name) throws LensException {
+   checkStarted();
+
+   MLAlgo created = null;
+   if (isAlgoSupported(name)) {
+     try {
+       created = algorithms.getAlgoForName(name);
+       if (created instanceof BaseSparkAlgo) {
+         BaseSparkAlgo sparkAlgo = (BaseSparkAlgo) created;
+         sparkAlgo.setSparkContext(sparkContext);
+       }
+     } catch (LensException exc) {
+       LOG.error("Error creating algo object", exc);
+     }
+   }
+   return created;
+ }
+
+ /**
+ * Register algos.
+ * <p>
+ * Registers the four built-in Spark algorithm classes (naive Bayes, SVM, logistic regression, decision tree)
+ * with the algorithm registry.
+ * </p>
+ */
+ private void registerAlgos() {
+ algorithms.register(NaiveBayesAlgo.class);
+ algorithms.register(SVMAlgo.class);
+ algorithms.register(LogisticRegressionAlgo.class);
+ algorithms.register(DecisionTreeAlgo.class);
+ }
+
+ /**
+  * Builds the Spark configuration for this driver and registers the built-in algorithms.
+  * <p>
+  * Every Lens property whose key starts with {@code lens.ml.sparkdriver.} is copied into the Spark configuration
+  * with that prefix removed. The master mode is then derived from {@code spark.master}: {@code yarn-client},
+  * {@code yarn-cluster}, or embedded when the value is {@code local} or not set. For the Yarn modes a Spark home
+  * is required, taken from the {@code SPARK_HOME} environment variable or from {@code spark.home}.
+  * </p>
+  *
+  * @param conf the Lens configuration to read Spark settings from
+  * @throws LensException            declared by the driver interface
+  * @throws IllegalArgumentException if the master mode is not recognised, or a Yarn mode has no Spark home
+  * @see org.apache.lens.ml.algo.api.MLDriver#init(org.apache.lens.api.LensConf)
+  */
+ @Override
+ public void init(LensConf conf) throws LensException {
+   final String prefix = "lens.ml.sparkdriver.";
+
+   sparkConf = new SparkConf();
+   registerAlgos();
+   for (String key : conf.getProperties().keySet()) {
+     if (key.startsWith(prefix)) {
+       sparkConf.set(key.substring(prefix.length()), conf.getProperties().get(key));
+     }
+   }
+
+   // SparkConf.get(key) throws NoSuchElementException for an unset key, which made the blank check below
+   // unreachable. The defaulting overload lets an unset master fall through to embedded mode as intended.
+   String sparkAppMaster = sparkConf.get("spark.master", "");
+   if ("yarn-client".equalsIgnoreCase(sparkAppMaster)) {
+     clientMode = SparkMasterMode.YARN_CLIENT;
+   } else if ("yarn-cluster".equalsIgnoreCase(sparkAppMaster)) {
+     clientMode = SparkMasterMode.YARN_CLUSTER;
+   } else if ("local".equalsIgnoreCase(sparkAppMaster) || StringUtils.isBlank(sparkAppMaster)) {
+     clientMode = SparkMasterMode.EMBEDDED;
+   } else {
+     throw new IllegalArgumentException("Invalid master mode " + sparkAppMaster);
+   }
+
+   if (clientMode == SparkMasterMode.YARN_CLIENT || clientMode == SparkMasterMode.YARN_CLUSTER) {
+     String sparkHome = System.getenv("SPARK_HOME");
+     if (StringUtils.isNotBlank(sparkHome)) {
+       sparkConf.setSparkHome(sparkHome);
+     }
+
+     // If SPARK_HOME is not set, SparkConf can read from the Lens-site.xml or System properties.
+     // Same defaulting overload as above, so a missing value reaches the explicit error below.
+     String configuredHome = sparkConf.get("spark.home", "");
+     if (StringUtils.isBlank(configuredHome)) {
+       throw new IllegalArgumentException("Spark home is not set");
+     }
+
+     LOG.info("Spark home is set to " + configuredHome);
+   }
+
+   sparkConf.setAppName("lens-ml");
+ }
+
+ /**
+  * Starts the driver: creates the Spark context unless one was supplied, and in the non-embedded modes ships the
+  * Hive, HCatalog and Lens jars to it.
+  *
+  * @throws LensException if the driver was not initialised and has no context, if {@code HIVE_HOME} is not set,
+  *                       or if one of the expected jar directories cannot be listed
+  * @see org.apache.lens.ml.algo.api.MLDriver#start()
+  */
+ @Override
+ public void start() throws LensException {
+   if (sparkContext == null) {
+     if (sparkConf == null) {
+       // init() builds the configuration; without it there is nothing to create a context from
+       throw new LensException("Spark driver is not initialized");
+     }
+     sparkContext = new JavaSparkContext(sparkConf);
+   }
+
+   // Adding jars to spark context is only required when running in the yarn modes
+   if (clientMode != SparkMasterMode.EMBEDDED) {
+     // TODO Figure out only necessary set of JARs to be added for HCatalog
+     String hiveLocation = System.getenv("HIVE_HOME");
+
+     if (StringUtils.isBlank(hiveLocation)) {
+       throw new LensException("HIVE_HOME is not set");
+     }
+
+     LOG.info("HIVE_HOME at " + hiveLocation);
+
+     FilenameFilter jarFileFilter = new FilenameFilter() {
+       @Override
+       public boolean accept(File file, String s) {
+         return s.endsWith(".jar");
+       }
+     };
+
+     List<String> jarFiles = new ArrayList<String>();
+     addJarsFromDir(new File(hiveLocation, "lib"), jarFileFilter, "HIVE", jarFiles);
+     addJarsFromDir(new File(hiveLocation + "/hcatalog/share/hcatalog"), jarFileFilter, "HCATALOG", jarFiles);
+     LOG.info("Added " + jarFiles.size() + " Hive and HCatalog jars to the Spark context");
+
+     // Add the current jar
+     String[] lensSparkLibJars = JavaSparkContext.jarOfClass(SparkMLDriver.class);
+     for (String lensSparkJar : lensSparkLibJars) {
+       LOG.info("Adding Lens JAR " + lensSparkJar);
+       sparkContext.addJar(lensSparkJar);
+     }
+   }
+
+   isStarted = true;
+   LOG.info("Created Spark context for app: '" + sparkContext.appName() + "', Spark master: " + sparkContext.master());
+ }
+
+ /**
+  * Adds every jar found directly in the given directory to the Spark context.
+  *
+  * @param dir      directory to scan
+  * @param filter   filter selecting the jar files
+  * @param label    short name of the jar group, used in log messages
+  * @param jarFiles collector receiving the absolute path of each jar added
+  * @throws LensException if the directory does not exist or cannot be listed
+  */
+ private void addJarsFromDir(File dir, FilenameFilter filter, String label, List<String> jarFiles)
+   throws LensException {
+   // File.listFiles returns null, not an empty array, when the path is not a readable directory
+   File[] jars = dir.listFiles(filter);
+   if (jars == null) {
+     throw new LensException("Cannot list " + label + " jars, not a readable directory: " + dir.getAbsolutePath());
+   }
+   for (File jarFile : jars) {
+     String jarPath = jarFile.getAbsolutePath();
+     jarFiles.add(jarPath);
+     LOG.info("Adding " + label + " jar " + jarPath);
+     sparkContext.addJar(jarPath);
+   }
+ }
+
+ /**
+  * Stops the driver. A context created by this driver is stopped and released; a context supplied through
+  * {@code useSparkContext} is left running for its owner to manage.
+  *
+  * @throws LensException declared by the driver interface
+  * @see org.apache.lens.ml.algo.api.MLDriver#stop()
+  */
+ @Override
+ public void stop() throws LensException {
+   if (!isStarted) {
+     LOG.warn("Spark driver was not started");
+     return;
+   }
+   isStarted = false;
+   if (ownsSparkContext) {
+     sparkContext.stop();
+     // Drop the reference so that a later start() creates a fresh context instead of reusing a stopped one
+     sparkContext = null;
+   }
+   LOG.info("Stopped spark context " + this);
+ }
+
+ /**
+ * Returns the algorithm names reported by the algorithm registry.
+ *
+ * @return the algorithm names
+ */
+ @Override
+ public List<String> getAlgoNames() {
+ return algorithms.getAlgorithmNames();
+ }
+
+ /**
+ * Check started.
+ * <p>
+ * Guard for operations that need a running driver: does nothing when the started flag is set and throws
+ * otherwise.
+ * </p>
+ *
+ * @throws LensException the lens exception, when the driver has not been started
+ */
+ public void checkStarted() throws LensException {
+ if (!isStarted) {
+ throw new LensException("Spark driver is not started yet");
+ }
+ }
+
+ /**
+ * Returns the current value of the spark context field, which is null until a context has been created or
+ * supplied.
+ *
+ * @return the spark context, possibly null
+ */
+ public JavaSparkContext getSparkContext() {
+ return sparkContext;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/TableTrainingSpec.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/TableTrainingSpec.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/TableTrainingSpec.java
new file mode 100644
index 0000000..33fd801
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/TableTrainingSpec.java
@@ -0,0 +1,433 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.algo.spark;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.lens.api.LensException;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.hive.conf.HiveConf;
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.hive.hcatalog.data.HCatRecord;
+import org.apache.hive.hcatalog.data.schema.HCatFieldSchema;
+import org.apache.hive.hcatalog.data.schema.HCatSchema;
+import org.apache.hive.hcatalog.mapreduce.HCatInputFormat;
+import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.rdd.RDD;
+
+import com.google.common.base.Preconditions;
+import lombok.Getter;
+import lombok.ToString;
+
+/**
+ * The Class TableTrainingSpec.
+ */
+@ToString
+public class TableTrainingSpec implements Serializable {
+
+ /** The Constant LOG. */
+ public static final Log LOG = LogFactory.getLog(TableTrainingSpec.class);
+
+ /** The training rdd. */
+ @Getter
+ private transient RDD<LabeledPoint> trainingRDD;
+
+ /** The testing rdd. */
+ @Getter
+ private transient RDD<LabeledPoint> testingRDD;
+
+ /** The database. */
+ @Getter
+ private String database;
+
+ /** The table. */
+ @Getter
+ private String table;
+
+ /** The partition filter. */
+ @Getter
+ private String partitionFilter;
+
+ /** The feature columns. */
+ @Getter
+ private List<String> featureColumns;
+
+ /** The label column. */
+ @Getter
+ private String labelColumn;
+
+ /** The conf. */
+ @Getter
+ private transient HiveConf conf;
+
+ // By default all samples are considered for training
+ /** The split training. */
+ private boolean splitTraining;
+
+ /** The training fraction. */
+ private double trainingFraction = 1.0;
+
+ /** The label pos. */
+ int labelPos;
+
+ /** The feature positions. */
+ int[] featurePositions;
+
+ /** The num features. */
+ int numFeatures;
+
+ /** The labeled rdd. */
+ transient JavaRDD<LabeledPoint> labeledRDD;
+
+ /**
+ * New builder.
+ * <p>
+ * Each call returns a fresh builder.
+ * </p>
+ *
+ * @return the table training spec builder
+ */
+ public static TableTrainingSpecBuilder newBuilder() {
+ return new TableTrainingSpecBuilder();
+ }
+
+ /**
+  * Fluent builder for {@link TableTrainingSpec}.
+  * <p>
+  * The builder writes directly into a single spec instance created in its constructor; {@link #build()} returns
+  * that same instance on every call.
+  * </p>
+  */
+ public static class TableTrainingSpecBuilder {
+
+   /** The spec being populated. */
+   final TableTrainingSpec spec;
+
+   /**
+    * Instantiates a new table training spec builder around an empty spec.
+    */
+   public TableTrainingSpecBuilder() {
+     spec = new TableTrainingSpec();
+   }
+
+   /**
+    * Sets the Hive configuration of the spec.
+    *
+    * @param conf the conf
+    * @return this builder
+    */
+   public TableTrainingSpecBuilder hiveConf(HiveConf conf) {
+     spec.conf = conf;
+     return this;
+   }
+
+   /**
+    * Sets the database of the training table.
+    *
+    * @param db the db
+    * @return this builder
+    */
+   public TableTrainingSpecBuilder database(String db) {
+     spec.database = db;
+     return this;
+   }
+
+   /**
+    * Sets the training table.
+    *
+    * @param table the table
+    * @return this builder
+    */
+   public TableTrainingSpecBuilder table(String table) {
+     spec.table = table;
+     return this;
+   }
+
+   /**
+    * Sets the partition filter applied to the training table.
+    *
+    * @param partFilter the part filter
+    * @return this builder
+    */
+   public TableTrainingSpecBuilder partitionFilter(String partFilter) {
+     spec.partitionFilter = partFilter;
+     return this;
+   }
+
+   /**
+    * Sets the name of the label column.
+    *
+    * @param labelColumn the label column
+    * @return this builder
+    */
+   public TableTrainingSpecBuilder labelColumn(String labelColumn) {
+     spec.labelColumn = labelColumn;
+     return this;
+   }
+
+   /**
+    * Sets the names of the feature columns.
+    *
+    * @param featureColumns the feature columns
+    * @return this builder
+    */
+   public TableTrainingSpecBuilder featureColumns(List<String> featureColumns) {
+     spec.featureColumns = featureColumns;
+     return this;
+   }
+
+   /**
+    * Sets the fraction of samples used for training and switches the spec to split mode, in which the remaining
+    * samples form the testing set.
+    *
+    * @param trainingFraction the training fraction, between 0 and 1 inclusive
+    * @return this builder
+    * @throws IllegalArgumentException if the fraction lies outside [0, 1]
+    */
+   public TableTrainingSpecBuilder trainingFraction(double trainingFraction) {
+     Preconditions.checkArgument(trainingFraction >= 0 && trainingFraction <= 1.0,
+       "Training fraction should be between 0 and 1");
+     spec.trainingFraction = trainingFraction;
+     spec.splitTraining = true;
+     return this;
+   }
+
+   /**
+    * Returns the populated spec. No validation is performed here.
+    *
+    * @return the table training spec
+    */
+   public TableTrainingSpec build() {
+     return spec;
+   }
+ }
+
+ /**
+ * The Class DataSample.
+ * <p>
+ * Pairs a labeled point with a random number drawn once at construction time. The training and testing
+ * filters compare that number against the training fraction.
+ * </p>
+ */
+ public static class DataSample implements Serializable {
+
+ /** The labeled point. */
+ private final LabeledPoint labeledPoint;
+
+ /** The sample: a value from Math.random(), i.e. in [0, 1). */
+ private final double sample;
+
+ /**
+ * Instantiates a new data sample.
+ *
+ * @param labeledPoint the labeled point
+ */
+ public DataSample(LabeledPoint labeledPoint) {
+ // A new random draw is made for every instance, so re-creating a sample for the same point changes its draw
+ sample = Math.random();
+ this.labeledPoint = labeledPoint;
+ }
+ }
+
+ /**
+ * The Class TrainingFilter.
+ * <p>
+ * Accepts the samples whose random draw is at most the training fraction.
+ * </p>
+ */
+ public static class TrainingFilter implements Function<DataSample, Boolean> {
+
+ /** The training fraction. */
+ private double trainingFraction;
+
+ /**
+ * Instantiates a new training filter.
+ *
+ * @param fraction the fraction
+ */
+ public TrainingFilter(double fraction) {
+ trainingFraction = fraction;
+ }
+
+ /*
+ * (non-Javadoc)
+ *
+ * Returns true when the sample's draw is less than or equal to the training fraction.
+ *
+ * @see org.apache.spark.api.java.function.Function#call(java.lang.Object)
+ */
+ @Override
+ public Boolean call(DataSample v1) throws Exception {
+ return v1.sample <= trainingFraction;
+ }
+ }
+
+ /**
+ * The Class TestingFilter.
+ * <p>
+ * Accepts the samples whose random draw is strictly greater than the training fraction, i.e. exactly the
+ * samples the training filter rejects for the same draw.
+ * </p>
+ */
+ public static class TestingFilter implements Function<DataSample, Boolean> {
+
+ /** The training fraction. */
+ private double trainingFraction;
+
+ /**
+ * Instantiates a new testing filter.
+ *
+ * @param fraction the fraction
+ */
+ public TestingFilter(double fraction) {
+ trainingFraction = fraction;
+ }
+
+ /*
+ * (non-Javadoc)
+ *
+ * Returns true when the sample's draw is greater than the training fraction.
+ *
+ * @see org.apache.spark.api.java.function.Function#call(java.lang.Object)
+ */
+ @Override
+ public Boolean call(DataSample v1) throws Exception {
+ return v1.sample > trainingFraction;
+ }
+ }
+
+ /**
+ * The Class GetLabeledPoint.
+ * <p>
+ * Unwraps a data sample back into the labeled point it carries.
+ * </p>
+ */
+ public static class GetLabeledPoint implements Function<DataSample, LabeledPoint> {
+
+ /*
+ * (non-Javadoc)
+ *
+ * Returns the labeled point stored in the sample.
+ *
+ * @see org.apache.spark.api.java.function.Function#call(java.lang.Object)
+ */
+ @Override
+ public LabeledPoint call(DataSample v1) throws Exception {
+ return v1.labeledPoint;
+ }
+ }
+
+ /**
+  * Validates the spec against the table schema and resolves column positions.
+  * <p>
+  * The table schema is read through HCatalog, using the {@code default} database when none was set. The spec is
+  * valid when the table has more than one column, contains the label column and, if feature columns were given,
+  * contains all of them. As a side effect this sets {@code labelPos}, {@code featurePositions} and
+  * {@code numFeatures}; when no feature columns were given, {@code featureColumns} is set to every column except
+  * the label.
+  * </p>
+  *
+  * @return true, if the spec is valid; false otherwise, including when the table info cannot be read
+  */
+ boolean validate() {
+   List<HCatFieldSchema> columns;
+   try {
+     HCatInputFormat.setInput(conf, database == null ? "default" : database, table, partitionFilter);
+     HCatSchema tableSchema = HCatInputFormat.getTableSchema(conf);
+     columns = tableSchema.getFields();
+   } catch (IOException exc) {
+     LOG.error("Error getting table info " + toString(), exc);
+     return false;
+   }
+
+   // Concatenation is null-safe; calling columns.toString() here would throw before the null check below
+   LOG.info(table + " columns " + columns);
+
+   if (columns == null || columns.isEmpty()) {
+     return false;
+   }
+
+   List<String> columnNames = new ArrayList<String>();
+   for (HCatFieldSchema col : columns) {
+     columnNames.add(col.getName());
+   }
+
+   // Need at least one feature column and one label column
+   boolean valid = columnNames.contains(labelColumn) && columnNames.size() > 1;
+   if (!valid) {
+     return false;
+   }
+
+   labelPos = columnNames.indexOf(labelColumn);
+
+   if (featureColumns == null || featureColumns.isEmpty()) {
+     // feature columns are not provided, so all columns except label column are feature columns
+     featurePositions = new int[columnNames.size() - 1];
+     int p = 0;
+     for (int i = 0; i < columnNames.size(); i++) {
+       if (i == labelPos) {
+         continue;
+       }
+       featurePositions[p++] = i;
+     }
+
+     // labelPos is an int, so this removes by index
+     columnNames.remove(labelPos);
+     featureColumns = columnNames;
+   } else {
+     // Feature columns were provided, verify all feature columns are present in the table
+     valid = columnNames.containsAll(featureColumns);
+     if (valid) {
+       featurePositions = new int[featureColumns.size()];
+       for (int i = 0; i < featureColumns.size(); i++) {
+         featurePositions[i] = columnNames.indexOf(featureColumns.get(i));
+       }
+     }
+   }
+   numFeatures = featureColumns.size();
+
+   return valid;
+ }
+
+ /**
+  * Creates the training and testing RDDs for this spec.
+  * <p>
+  * The spec is validated, the Hive table is loaded as an RDD and every row is mapped to a labeled point using a
+  * double mapper for each feature column. When a training fraction was configured, each point is tagged with a
+  * random draw and the points are divided between the training and testing RDDs by comparing that draw with the
+  * fraction. Otherwise both RDDs refer to the same data.
+  * </p>
+  *
+  * @param sparkContext the spark context
+  * @throws LensException if the spec is not valid or the table RDD cannot be created
+  */
+ public void createRDDs(JavaSparkContext sparkContext) throws LensException {
+   // Validate the spec
+   if (!validate()) {
+     throw new LensException("Table spec not valid: " + toString());
+   }
+
+   LOG.info("Creating RDDs with spec " + toString());
+
+   // Get the RDD for table
+   JavaPairRDD<WritableComparable, HCatRecord> tableRDD;
+   try {
+     tableRDD = HiveTableRDD.createHiveTableRDD(sparkContext, conf, database, table, partitionFilter);
+   } catch (IOException e) {
+     throw new LensException(e);
+   }
+
+   // Map into trainable RDD
+   // TODO: Figure out a way to use custom value mappers
+   FeatureValueMapper[] valueMappers = new FeatureValueMapper[numFeatures];
+   final DoubleValueMapper doubleMapper = new DoubleValueMapper();
+   for (int i = 0; i < numFeatures; i++) {
+     valueMappers[i] = doubleMapper;
+   }
+
+   ColumnFeatureFunction trainPrepFunction = new ColumnFeatureFunction(featurePositions, valueMappers, labelPos,
+     numFeatures, 0);
+   labeledRDD = tableRDD.map(trainPrepFunction);
+
+   if (splitTraining) {
+     // We have to split the RDD between a training RDD and a testing RDD
+     LOG.info("Splitting RDD for table " + database + "." + table + " with split fraction " + trainingFraction);
+     // Each DataSample draws its random number when it is constructed. Without persisting, each of the two
+     // filters below would recompute this map and see different draws, so a record could land in both sets
+     // or in neither. Caching lets both filters read the same draws.
+     // NOTE(review): cache() is memory-only; partitions that get evicted are recomputed with new draws.
+     JavaRDD<DataSample> sampledRDD = labeledRDD.map(new Function<LabeledPoint, DataSample>() {
+       @Override
+       public DataSample call(LabeledPoint v1) throws Exception {
+         return new DataSample(v1);
+       }
+     }).cache();
+
+     trainingRDD = sampledRDD.filter(new TrainingFilter(trainingFraction)).map(new GetLabeledPoint()).rdd();
+     testingRDD = sampledRDD.filter(new TestingFilter(trainingFraction)).map(new GetLabeledPoint()).rdd();
+   } else {
+     LOG.info("Using same RDD for train and test");
+     trainingRDD = labeledRDD.rdd();
+     testingRDD = trainingRDD;
+   }
+   LOG.info("Generated RDDs");
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/dt/DecisionTreeAlgo.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/dt/DecisionTreeAlgo.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/dt/DecisionTreeAlgo.java
new file mode 100644
index 0000000..6c7619a
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/dt/DecisionTreeAlgo.java
@@ -0,0 +1,108 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.algo.spark.dt;
+
+import java.util.Map;
+
+import org.apache.lens.api.LensException;
+import org.apache.lens.ml.algo.api.AlgoParam;
+import org.apache.lens.ml.algo.api.Algorithm;
+import org.apache.lens.ml.algo.spark.BaseSparkAlgo;
+import org.apache.lens.ml.algo.spark.BaseSparkClassificationModel;
+
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.mllib.tree.DecisionTree$;
+import org.apache.spark.mllib.tree.configuration.Algo$;
+import org.apache.spark.mllib.tree.impurity.Entropy$;
+import org.apache.spark.mllib.tree.impurity.Gini$;
+import org.apache.spark.mllib.tree.impurity.Impurity;
+import org.apache.spark.mllib.tree.impurity.Variance$;
+import org.apache.spark.mllib.tree.model.DecisionTreeModel;
+import org.apache.spark.rdd.RDD;
+
+import scala.Enumeration;
+
+/**
+ * Decision tree algo that trains a Spark MLlib {@link DecisionTreeModel} and returns it wrapped in a
+ * {@link DecisionTreeClassificationModel}.
+ */
+@Algorithm(name = "spark_decision_tree", description = "Spark Decision Tree classifier algo")
+public class DecisionTreeAlgo extends BaseSparkAlgo {
+
+ /** The tree type passed to Spark; set by {@link #parseAlgoParams(Map)} from the "algo" param. */
+ @AlgoParam(name = "algo", help = "Decision tree algorithm. Allowed values are 'classification' and 'regression'")
+ private Enumeration.Value algo;
+
+ /** The impurity measure passed to Spark; set by {@link #parseAlgoParams(Map)} from the "impurity" param. */
+ @AlgoParam(name = "impurity", help = "Impurity measure used by the decision tree. "
+ + "Allowed values are 'gini', 'entropy' and 'variance'")
+ private Impurity decisionTreeImpurity;
+
+ /**
+ * The max depth of the tree.
+ * NOTE(review): some Spark MLlib releases put an upper bound on maxDepth; confirm that the default of 100 is
+ * accepted by the Spark version in use.
+ */
+ @AlgoParam(name = "maxDepth", help = "Max depth of the decision tree. Integer values expected.",
+ defaultValue = "100")
+ private int maxDepth;
+
+ /**
+ * Instantiates a new decision tree algo.
+ *
+ * @param name the name
+ * @param description the description
+ */
+ public DecisionTreeAlgo(String name, String description) {
+ super(name, description);
+ }
+
+ /**
+ * Reads the "algo", "impurity" and "maxDepth" settings. Both "algo" and "impurity" are matched
+ * case-insensitively and are mandatory: training cannot run without them, so a missing or unknown value is
+ * rejected here instead of leaving a null field behind for {@link #trainInternal(String, RDD)} to trip over.
+ *
+ * @param params the raw param name to value map
+ * @throws IllegalArgumentException if "algo" or "impurity" is missing or not one of the allowed values
+ * @see org.apache.lens.ml.algo.spark.BaseSparkAlgo#parseAlgoParams(java.util.Map)
+ */
+ @Override
+ public void parseAlgoParams(Map<String, String> params) {
+ String dtreeAlgoName = params.get("algo");
+ if ("classification".equalsIgnoreCase(dtreeAlgoName)) {
+ algo = Algo$.MODULE$.Classification();
+ } else if ("regression".equalsIgnoreCase(dtreeAlgoName)) {
+ algo = Algo$.MODULE$.Regression();
+ } else {
+ throw new IllegalArgumentException("Missing or unsupported value '" + dtreeAlgoName
+ + "' for param 'algo'. Allowed values are 'classification' and 'regression'");
+ }
+
+ String impurity = params.get("impurity");
+ if ("gini".equalsIgnoreCase(impurity)) {
+ decisionTreeImpurity = Gini$.MODULE$;
+ } else if ("entropy".equalsIgnoreCase(impurity)) {
+ decisionTreeImpurity = Entropy$.MODULE$;
+ } else if ("variance".equalsIgnoreCase(impurity)) {
+ decisionTreeImpurity = Variance$.MODULE$;
+ } else {
+ throw new IllegalArgumentException("Missing or unsupported value '" + impurity
+ + "' for param 'impurity'. Allowed values are 'gini', 'entropy' and 'variance'");
+ }
+
+ // getParamValue is not defined in this class; 100 is the fallback handed to it.
+ maxDepth = getParamValue("maxDepth", 100);
+ }
+
+ /**
+ * Trains a Spark decision tree with the parsed algo, impurity and max depth.
+ *
+ * @param modelId the id given to the returned model
+ * @param trainingRDD the labeled points to train on
+ * @return the trained tree wrapped as a {@link DecisionTreeClassificationModel}
+ * @throws LensException declared by the overridden method
+ * @see org.apache.lens.ml.algo.spark.BaseSparkAlgo#trainInternal(java.lang.String, org.apache.spark.rdd.RDD)
+ */
+ @Override
+ protected BaseSparkClassificationModel trainInternal(String modelId, RDD<LabeledPoint> trainingRDD)
+ throws LensException {
+ DecisionTreeModel model = DecisionTree$.MODULE$.train(trainingRDD, algo, decisionTreeImpurity, maxDepth);
+ return new DecisionTreeClassificationModel(modelId, new SparkDecisionTreeModel(model));
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/dt/DecisionTreeClassificationModel.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/dt/DecisionTreeClassificationModel.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/dt/DecisionTreeClassificationModel.java
new file mode 100644
index 0000000..27c32f4
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/dt/DecisionTreeClassificationModel.java
@@ -0,0 +1,37 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.algo.spark.dt;
+
+import org.apache.lens.ml.algo.spark.BaseSparkClassificationModel;
+
+/**
+ * Classification model whose underlying Spark model is a {@link SparkDecisionTreeModel}. All behaviour comes from
+ * {@link BaseSparkClassificationModel}; this class only fixes the model type parameter.
+ */
+public class DecisionTreeClassificationModel extends BaseSparkClassificationModel<SparkDecisionTreeModel> {
+
+ /**
+ * Instantiates a new decision tree classification model by handing both arguments to the base class.
+ *
+ * @param modelId the model id
+ * @param model the wrapped decision tree model
+ */
+ public DecisionTreeClassificationModel(String modelId, SparkDecisionTreeModel model) {
+ super(modelId, model);
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/dt/SparkDecisionTreeModel.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/dt/SparkDecisionTreeModel.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/dt/SparkDecisionTreeModel.java
new file mode 100644
index 0000000..e561a8d
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/dt/SparkDecisionTreeModel.java
@@ -0,0 +1,75 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.algo.spark.dt;
+
+import org.apache.lens.ml.algo.spark.DoubleValueMapper;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.mllib.classification.ClassificationModel;
+import org.apache.spark.mllib.linalg.Vector;
+import org.apache.spark.mllib.tree.model.DecisionTreeModel;
+import org.apache.spark.rdd.RDD;
+
+/**
+ * This class is created because the Spark decision tree model doesn't extend ClassificationModel.
+ */
+public class SparkDecisionTreeModel implements ClassificationModel {
+
+ /** The model. */
+ private final DecisionTreeModel model;
+
+ /**
+ * Instantiates a new spark decision tree model.
+ *
+ * @param model the model
+ */
+ public SparkDecisionTreeModel(DecisionTreeModel model) {
+ this.model = model;
+ }
+
+ /*
+ * (non-Javadoc)
+ *
+ * @see org.apache.spark.mllib.classification.ClassificationModel#predict(org.apache.spark.rdd.RDD)
+ */
+ @Override
+ public RDD<Object> predict(RDD<Vector> testData) {
+ return model.predict(testData);
+ }
+
+ /*
+ * (non-Javadoc)
+ *
+ * @see org.apache.spark.mllib.classification.ClassificationModel#predict(org.apache.spark.mllib.linalg.Vector)
+ */
+ @Override
+ public double predict(Vector testData) {
+ return model.predict(testData);
+ }
+
+ /*
+ * (non-Javadoc)
+ *
+ * @see org.apache.spark.mllib.classification.ClassificationModel#predict(org.apache.spark.api.java.JavaRDD)
+ */
+ @Override
+ public JavaRDD<Double> predict(JavaRDD<Vector> testData) {
+ return model.predict(testData.rdd()).toJavaRDD().map(new DoubleValueMapper());
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/kmeans/KMeansAlgo.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/kmeans/KMeansAlgo.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/kmeans/KMeansAlgo.java
new file mode 100644
index 0000000..6450f70
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/kmeans/KMeansAlgo.java
@@ -0,0 +1,163 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.algo.spark.kmeans;
+
+import java.util.List;
+
+import org.apache.lens.api.LensConf;
+import org.apache.lens.api.LensException;
+import org.apache.lens.ml.algo.api.*;
+import org.apache.lens.ml.algo.lib.AlgoArgParser;
+import org.apache.lens.ml.algo.spark.HiveTableRDD;
+
+import org.apache.hadoop.hive.conf.HiveConf;
+import org.apache.hadoop.hive.metastore.api.FieldSchema;
+import org.apache.hadoop.hive.ql.metadata.Hive;
+import org.apache.hadoop.hive.ql.metadata.Table;
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.hive.hcatalog.data.HCatRecord;
+import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.mllib.clustering.KMeans;
+import org.apache.spark.mllib.clustering.KMeansModel;
+import org.apache.spark.mllib.linalg.Vector;
+import org.apache.spark.mllib.linalg.Vectors;
+
+import scala.Tuple2;
+
+/**
+ * The Class KMeansAlgo.
+ */
+@Algorithm(name = "spark_kmeans_algo", description = "Spark MLLib KMeans algo")
+public class KMeansAlgo implements MLAlgo {
+
+ /** The conf. */
+ private transient LensConf conf;
+
+ /** The spark context. */
+ private JavaSparkContext sparkContext;
+
+ /** The part filter. */
+ @AlgoParam(name = "partition", help = "Partition filter to be used while constructing table RDD")
+ private String partFilter = null;
+
+ /** The k. */
+ @AlgoParam(name = "k", help = "Number of cluster")
+ private int k;
+
+ /** The max iterations. */
+ @AlgoParam(name = "maxIterations", help = "Maximum number of iterations", defaultValue = "100")
+ private int maxIterations = 100;
+
+ /** The runs. */
+ @AlgoParam(name = "runs", help = "Number of parallel run", defaultValue = "1")
+ private int runs = 1;
+
+ /** The initialization mode. */
+ @AlgoParam(name = "initializationMode",
+ help = "initialization model, either \"random\" or \"k-means||\" (default).", defaultValue = "k-means||")
+ private String initializationMode = "k-means||";
+
+ @Override
+ public String getName() {
+ return getClass().getAnnotation(Algorithm.class).name();
+ }
+
+ @Override
+ public String getDescription() {
+ return getClass().getAnnotation(Algorithm.class).description();
+ }
+
+ /*
+ * (non-Javadoc)
+ *
+ * @see org.apache.lens.ml.MLAlgo#configure(org.apache.lens.api.LensConf)
+ */
+ @Override
+ public void configure(LensConf configuration) {
+ this.conf = configuration;
+ }
+
+ @Override
+ public LensConf getConf() {
+ return conf;
+ }
+
+ /*
+ * (non-Javadoc)
+ *
+ * @see org.apache.lens.ml.MLAlgo#train(org.apache.lens.api.LensConf, java.lang.String, java.lang.String,
+ * java.lang.String, java.lang.String[])
+ */
+ @Override
+ public MLModel train(LensConf conf, String db, String table, String modelId, String... params) throws LensException {
+ List<String> features = AlgoArgParser.parseArgs(this, params);
+ final int[] featurePositions = new int[features.size()];
+ final int NUM_FEATURES = features.size();
+
+ JavaPairRDD<WritableComparable, HCatRecord> rdd = null;
+ try {
+ // Map feature names to positions
+ Table tbl = Hive.get(toHiveConf(conf)).getTable(db, table);
+ List<FieldSchema> allCols = tbl.getAllCols();
+ int f = 0;
+ for (int i = 0; i < tbl.getAllCols().size(); i++) {
+ String colName = allCols.get(i).getName();
+ if (features.contains(colName)) {
+ featurePositions[f++] = i;
+ }
+ }
+
+ rdd = HiveTableRDD.createHiveTableRDD(sparkContext, toHiveConf(conf), db, table, partFilter);
+ JavaRDD<Vector> trainableRDD = rdd.map(new Function<Tuple2<WritableComparable, HCatRecord>, Vector>() {
+ @Override
+ public Vector call(Tuple2<WritableComparable, HCatRecord> v1) throws Exception {
+ HCatRecord hCatRecord = v1._2();
+ double[] arr = new double[NUM_FEATURES];
+ for (int i = 0; i < NUM_FEATURES; i++) {
+ Object val = hCatRecord.get(featurePositions[i]);
+ arr[i] = val == null ? 0d : (Double) val;
+ }
+ return Vectors.dense(arr);
+ }
+ });
+
+ KMeansModel model = KMeans.train(trainableRDD.rdd(), k, maxIterations, runs, initializationMode);
+ return new KMeansClusteringModel(modelId, model);
+ } catch (Exception e) {
+ throw new LensException("KMeans algo failed for " + db + "." + table, e);
+ }
+ }
+
+ /**
+ * To hive conf.
+ *
+ * @param conf the conf
+ * @return the hive conf
+ */
+ private HiveConf toHiveConf(LensConf conf) {
+ HiveConf hiveConf = new HiveConf();
+ for (String key : conf.getProperties().keySet()) {
+ hiveConf.set(key, conf.getProperties().get(key));
+ }
+ return hiveConf;
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/kmeans/KMeansClusteringModel.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/kmeans/KMeansClusteringModel.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/kmeans/KMeansClusteringModel.java
new file mode 100644
index 0000000..62dc536
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/kmeans/KMeansClusteringModel.java
@@ -0,0 +1,67 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.algo.spark.kmeans;
+
+import org.apache.lens.ml.algo.api.MLModel;
+
+import org.apache.spark.mllib.clustering.KMeansModel;
+import org.apache.spark.mllib.linalg.Vectors;
+
+/**
+ * Model that wraps a Spark MLlib {@link KMeansModel} and predicts the cluster index for a set of feature values.
+ */
+public class KMeansClusteringModel extends MLModel<Integer> {
+
+ /** The wrapped Spark k-means model. */
+ private final KMeansModel model;
+
+ /** The model id. Stored by the constructor; not read anywhere in this class. */
+ private final String modelId;
+
+ /**
+ * Instantiates a new k means clustering model.
+ *
+ * @param modelId the model id
+ * @param model the model
+ */
+ public KMeansClusteringModel(String modelId, KMeansModel model) {
+ this.model = model;
+ this.modelId = modelId;
+ }
+
+ /**
+ * Predicts the cluster for the given feature values. The values are placed, in argument order, into a dense
+ * vector that is handed to the wrapped model.
+ *
+ * @param args the feature values; each must be null (treated as 0) or a {@link Number}
+ * @return the cluster index returned by the wrapped model
+ * @throws ClassCastException if a non-null argument is not a {@link Number}
+ * @see org.apache.lens.ml.algo.api.MLModel#predict(java.lang.Object[])
+ */
+ @Override
+ public Integer predict(Object... args) {
+ // Convert the params to array of double
+ double[] arr = new double[args.length];
+ for (int i = 0; i < args.length; i++) {
+ // Go through Number so that Integer, Long and Float arguments are accepted as well as Double
+ arr[i] = args[i] == null ? 0d : ((Number) args[i]).doubleValue();
+ }
+
+ return model.predict(Vectors.dense(arr));
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/lr/LogisticRegressionAlgo.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/lr/LogisticRegressionAlgo.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/lr/LogisticRegressionAlgo.java
new file mode 100644
index 0000000..55caf59
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/lr/LogisticRegressionAlgo.java
@@ -0,0 +1,86 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.algo.spark.lr;
+
+import java.util.Map;
+
+import org.apache.lens.api.LensException;
+import org.apache.lens.ml.algo.api.AlgoParam;
+import org.apache.lens.ml.algo.api.Algorithm;
+import org.apache.lens.ml.algo.spark.BaseSparkAlgo;
+import org.apache.lens.ml.algo.spark.BaseSparkClassificationModel;
+
+import org.apache.spark.mllib.classification.LogisticRegressionModel;
+import org.apache.spark.mllib.classification.LogisticRegressionWithSGD;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.rdd.RDD;
+
+/**
+ * Logistic regression algo that trains with Spark MLlib's {@link LogisticRegressionWithSGD}.
+ */
+@Algorithm(name = "spark_logistic_regression", description = "Spark logistic regression algo")
+public class LogisticRegressionAlgo extends BaseSparkAlgo {
+
+ /** Max number of iterations handed to the trainer. */
+ @AlgoParam(name = "iterations", help = "Max number of iterations", defaultValue = "100")
+ private int iterations;
+
+ /** Step size handed to the trainer. */
+ @AlgoParam(name = "stepSize", help = "Step size", defaultValue = "1.0d")
+ private double stepSize;
+
+ /** Batch fraction handed to the trainer. */
+ @AlgoParam(name = "minBatchFraction", help = "Fraction for batched learning", defaultValue = "1.0d")
+ private double minBatchFraction;
+
+ /**
+ * Instantiates a new logistic regression algo.
+ *
+ * @param name the name
+ * @param description the description
+ */
+ public LogisticRegressionAlgo(String name, String description) {
+ super(name, description);
+ }
+
+ /**
+ * Resolves "iterations", "stepSize" and "minBatchFraction" through getParamValue, with fallbacks of 100, 1.0
+ * and 1.0 respectively.
+ *
+ * @param params the raw param map (not read directly here)
+ * @see org.apache.lens.ml.algo.spark.BaseSparkAlgo#parseAlgoParams(java.util.Map)
+ */
+ @Override
+ public void parseAlgoParams(Map<String, String> params) {
+ this.iterations = getParamValue("iterations", 100);
+ this.stepSize = getParamValue("stepSize", 1.0d);
+ this.minBatchFraction = getParamValue("minBatchFraction", 1.0d);
+ }
+
+ /**
+ * Trains a logistic regression model with the parsed settings.
+ *
+ * @param modelId the id given to the returned model
+ * @param trainingRDD the labeled points to train on
+ * @return the trained model wrapped as a {@link LogitRegressionClassificationModel}
+ * @throws LensException declared by the overridden method
+ * @see org.apache.lens.ml.algo.spark.BaseSparkAlgo#trainInternal(java.lang.String, org.apache.spark.rdd.RDD)
+ */
+ @Override
+ protected BaseSparkClassificationModel trainInternal(String modelId, RDD<LabeledPoint> trainingRDD)
+ throws LensException {
+ return new LogitRegressionClassificationModel(modelId,
+ LogisticRegressionWithSGD.train(trainingRDD, iterations, stepSize, minBatchFraction));
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/lr/LogitRegressionClassificationModel.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/lr/LogitRegressionClassificationModel.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/lr/LogitRegressionClassificationModel.java
new file mode 100644
index 0000000..a4206e5
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/lr/LogitRegressionClassificationModel.java
@@ -0,0 +1,39 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.algo.spark.lr;
+
+import org.apache.lens.ml.algo.spark.BaseSparkClassificationModel;
+
+import org.apache.spark.mllib.classification.LogisticRegressionModel;
+
+/**
+ * Classification model whose underlying Spark model is a {@link LogisticRegressionModel}. All behaviour comes from
+ * {@link BaseSparkClassificationModel}; this class only fixes the model type parameter.
+ */
+public class LogitRegressionClassificationModel extends BaseSparkClassificationModel<LogisticRegressionModel> {
+
+ /**
+ * Instantiates a new logit regression classification model by handing both arguments to the base class.
+ *
+ * @param modelId the model id
+ * @param model the trained logistic regression model
+ */
+ public LogitRegressionClassificationModel(String modelId, LogisticRegressionModel model) {
+ super(modelId, model);
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/nb/NaiveBayesAlgo.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/nb/NaiveBayesAlgo.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/nb/NaiveBayesAlgo.java
new file mode 100644
index 0000000..b4e1e78
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/nb/NaiveBayesAlgo.java
@@ -0,0 +1,73 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.algo.spark.nb;
+
+import java.util.Map;
+
+import org.apache.lens.api.LensException;
+import org.apache.lens.ml.algo.api.AlgoParam;
+import org.apache.lens.ml.algo.api.Algorithm;
+import org.apache.lens.ml.algo.spark.BaseSparkAlgo;
+import org.apache.lens.ml.algo.spark.BaseSparkClassificationModel;
+
+import org.apache.spark.mllib.classification.NaiveBayes;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.rdd.RDD;
+
+/**
+ * Naive Bayes algo that trains with Spark MLlib's {@link NaiveBayes}.
+ */
+@Algorithm(name = "spark_naive_bayes", description = "Spark Naive Bayes classifier algo")
+public class NaiveBayesAlgo extends BaseSparkAlgo {
+
+ /** Lambda value handed to the trainer. */
+ @AlgoParam(name = "lambda", help = "Lambda parameter for naive bayes learner", defaultValue = "1.0d")
+ private double lambda = 1.0;
+
+ /**
+ * Instantiates a new naive bayes algo.
+ *
+ * @param name the name
+ * @param description the description
+ */
+ public NaiveBayesAlgo(String name, String description) {
+ super(name, description);
+ }
+
+ /**
+ * Resolves "lambda" through getParamValue, with a fallback of 1.0.
+ *
+ * @param params the raw param map (not read directly here)
+ * @see org.apache.lens.ml.algo.spark.BaseSparkAlgo#parseAlgoParams(java.util.Map)
+ */
+ @Override
+ public void parseAlgoParams(Map<String, String> params) {
+ this.lambda = getParamValue("lambda", 1.0d);
+ }
+
+ /**
+ * Trains a naive Bayes model with the parsed lambda.
+ *
+ * @param modelId the id given to the returned model
+ * @param trainingRDD the labeled points to train on
+ * @return the trained model wrapped as a {@link NaiveBayesClassificationModel}
+ * @throws LensException declared by the overridden method
+ * @see org.apache.lens.ml.algo.spark.BaseSparkAlgo#trainInternal(java.lang.String, org.apache.spark.rdd.RDD)
+ */
+ @Override
+ protected BaseSparkClassificationModel trainInternal(String modelId, RDD<LabeledPoint> trainingRDD)
+ throws LensException {
+ return new NaiveBayesClassificationModel(modelId,
+ NaiveBayes.train(trainingRDD, this.lambda));
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/nb/NaiveBayesClassificationModel.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/nb/NaiveBayesClassificationModel.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/nb/NaiveBayesClassificationModel.java
new file mode 100644
index 0000000..26d39df
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/nb/NaiveBayesClassificationModel.java
@@ -0,0 +1,39 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.algo.spark.nb;
+
+import org.apache.lens.ml.algo.spark.BaseSparkClassificationModel;
+
+import org.apache.spark.mllib.classification.NaiveBayesModel;
+
+/**
+ * Classification model whose underlying Spark model is a {@link NaiveBayesModel}. All behaviour comes from
+ * {@link BaseSparkClassificationModel}; this class only fixes the model type parameter.
+ */
+public class NaiveBayesClassificationModel extends BaseSparkClassificationModel<NaiveBayesModel> {
+
+ /**
+ * Instantiates a new naive bayes classification model by handing both arguments to the base class.
+ *
+ * @param modelId the model id
+ * @param model the trained naive Bayes model
+ */
+ public NaiveBayesClassificationModel(String modelId, NaiveBayesModel model) {
+ super(modelId, model);
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/svm/SVMAlgo.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/svm/SVMAlgo.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/svm/SVMAlgo.java
new file mode 100644
index 0000000..21a036a
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/svm/SVMAlgo.java
@@ -0,0 +1,90 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.algo.spark.svm;
+
+import java.util.Map;
+
+import org.apache.lens.api.LensException;
+import org.apache.lens.ml.algo.api.AlgoParam;
+import org.apache.lens.ml.algo.api.Algorithm;
+import org.apache.lens.ml.algo.spark.BaseSparkAlgo;
+import org.apache.lens.ml.algo.spark.BaseSparkClassificationModel;
+
+import org.apache.spark.mllib.classification.SVMModel;
+import org.apache.spark.mllib.classification.SVMWithSGD;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.rdd.RDD;
+
+/**
+ * SVM algo that trains with Spark MLlib's {@link SVMWithSGD}.
+ */
+@Algorithm(name = "spark_svm", description = "Spark SVML classifier algo")
+public class SVMAlgo extends BaseSparkAlgo {
+
+ /** Batch fraction handed to the trainer. */
+ @AlgoParam(name = "minBatchFraction", help = "Fraction for batched learning", defaultValue = "1.0d")
+ private double minBatchFraction;
+
+ /** Regularization parameter handed to the trainer. */
+ @AlgoParam(name = "regParam", help = "regularization parameter for gradient descent", defaultValue = "1.0d")
+ private double regParam;
+
+ /** Step size handed to the trainer. */
+ @AlgoParam(name = "stepSize", help = "Iteration step size", defaultValue = "1.0d")
+ private double stepSize;
+
+ /** Number of iterations handed to the trainer. */
+ @AlgoParam(name = "iterations", help = "Number of iterations", defaultValue = "100")
+ private int iterations;
+
+ /**
+ * Instantiates a new SVM algo.
+ *
+ * @param name the name
+ * @param description the description
+ */
+ public SVMAlgo(String name, String description) {
+ super(name, description);
+ }
+
+ /**
+ * Resolves "minBatchFraction", "regParam", "stepSize" and "iterations" through getParamValue, with fallbacks
+ * of 1.0, 1.0, 1.0 and 100 respectively.
+ *
+ * @param params the raw param map (not read directly here)
+ * @see org.apache.lens.ml.algo.spark.BaseSparkAlgo#parseAlgoParams(java.util.Map)
+ */
+ @Override
+ public void parseAlgoParams(Map<String, String> params) {
+ this.minBatchFraction = getParamValue("minBatchFraction", 1.0);
+ this.regParam = getParamValue("regParam", 1.0);
+ this.stepSize = getParamValue("stepSize", 1.0);
+ this.iterations = getParamValue("iterations", 100);
+ }
+
+ /**
+ * Trains an SVM model with the parsed settings.
+ *
+ * @param modelId the id given to the returned model
+ * @param trainingRDD the labeled points to train on
+ * @return the trained model wrapped as a {@link SVMClassificationModel}
+ * @throws LensException declared by the overridden method
+ * @see org.apache.lens.ml.algo.spark.BaseSparkAlgo#trainInternal(java.lang.String, org.apache.spark.rdd.RDD)
+ */
+ @Override
+ protected BaseSparkClassificationModel trainInternal(String modelId, RDD<LabeledPoint> trainingRDD)
+ throws LensException {
+ return new SVMClassificationModel(modelId,
+ SVMWithSGD.train(trainingRDD, iterations, stepSize, regParam, minBatchFraction));
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/svm/SVMClassificationModel.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/svm/SVMClassificationModel.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/svm/SVMClassificationModel.java
new file mode 100644
index 0000000..433c0f9
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/svm/SVMClassificationModel.java
@@ -0,0 +1,39 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.algo.spark.svm;
+
+import org.apache.lens.ml.algo.spark.BaseSparkClassificationModel;
+
+import org.apache.spark.mllib.classification.SVMModel;
+
+/**
+ * The Class SVMClassificationModel.
+ */
+public class SVMClassificationModel extends BaseSparkClassificationModel<SVMModel> {
+
+ /**
+ * Instantiates a new SVM classification model.
+ *
+ * @param modelId the model id
+ * @param model the model
+ */
+ public SVMClassificationModel(String modelId, SVMModel model) {
+ super(modelId, model);
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/api/LensML.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/api/LensML.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/api/LensML.java
new file mode 100644
index 0000000..e124fb0
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/api/LensML.java
@@ -0,0 +1,161 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.api;
+
+import java.util.List;
+import java.util.Map;
+
+import org.apache.lens.api.LensException;
+import org.apache.lens.api.LensSessionHandle;
+import org.apache.lens.ml.algo.api.MLAlgo;
+import org.apache.lens.ml.algo.api.MLModel;
+
+/**
+ * Lens's machine learning interface used by client code as well as Lens ML service.
+ */
+public interface LensML {
+
+ /** Name of ML service */
+ String NAME = "ml";
+
+ /**
+ * Get list of available machine learning algorithms
+ *
+ * @return the list of available algorithms
+ */
+ List<String> getAlgorithms();
+
+ /**
+ * Get user friendly information about parameters accepted by the algorithm.
+ *
+ * @param algorithm the algorithm
+ * @return map of param key to its help message
+ */
+ Map<String, String> getAlgoParamDescription(String algorithm);
+
+ /**
+ * Get an algo object instance which could be used to generate a model of the given algorithm.
+ *
+ * @param algorithm the algorithm
+ * @return the algo for name
+ * @throws LensException the lens exception
+ */
+ MLAlgo getAlgoForName(String algorithm) throws LensException;
+
+ /**
+ * Create a model using the given HCatalog table as input. The arguments should contain information needed to
+ * generate the model.
+ *
+ * @param table the table
+ * @param algorithm the algorithm
+ * @param args the args
+ * @return Unique ID of the model created after training is complete
+ * @throws LensException the lens exception
+ */
+ String train(String table, String algorithm, String[] args) throws LensException;
+
+ /**
+ * Get model IDs for the given algorithm.
+ *
+ * @param algorithm the algorithm
+ * @return the models
+ * @throws LensException the lens exception
+ */
+ List<String> getModels(String algorithm) throws LensException;
+
+ /**
+ * Get a model instance given the algorithm name and model ID.
+ *
+ * @param algorithm the algorithm
+ * @param modelId the model id
+ * @return the model
+ * @throws LensException the lens exception
+ */
+ MLModel getModel(String algorithm, String modelId) throws LensException;
+
+ /**
+ * Get the FS location where model instance is saved.
+ *
+ * @param algorithm the algorithm
+ * @param modelID the model id
+ * @return the model path
+ */
+ String getModelPath(String algorithm, String modelID);
+
+ /**
+ * Evaluate model by running it against test data contained in the given table.
+ *
+ * @param session the session
+ * @param table the table
+ * @param algorithm the algorithm
+ * @param modelID the model id
+ * @return Test report object containing test output table, and various evaluation metrics
+ * @throws LensException the lens exception
+ */
+ MLTestReport testModel(LensSessionHandle session, String table, String algorithm, String modelID,
+ String outputTable) throws LensException;
+
+ /**
+ * Get test reports for an algorithm.
+ *
+ * @param algorithm the algorithm
+ * @return the test reports
+ * @throws LensException the lens exception
+ */
+ List<String> getTestReports(String algorithm) throws LensException;
+
+ /**
+ * Get a test report by ID.
+ *
+ * @param algorithm the algorithm
+ * @param reportID the report id
+ * @return the test report
+ * @throws LensException the lens exception
+ */
+ MLTestReport getTestReport(String algorithm, String reportID) throws LensException;
+
+ /**
+ * Online predict call given a model ID, algorithm name and sample feature values.
+ *
+ * @param algorithm the algorithm
+ * @param modelID the model id
+ * @param features the features
+ * @return prediction result
+ * @throws LensException the lens exception
+ */
+ Object predict(String algorithm, String modelID, Object[] features) throws LensException;
+
+ /**
+ * Permanently delete a model instance.
+ *
+ * @param algorithm the algorithm
+ * @param modelID the model id
+ * @throws LensException the lens exception
+ */
+ void deleteModel(String algorithm, String modelID) throws LensException;
+
+ /**
+ * Permanently delete a test report instance.
+ *
+ * @param algorithm the algorithm
+ * @param reportID the report id
+ * @throws LensException the lens exception
+ */
+ void deleteTestReport(String algorithm, String reportID) throws LensException;
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/api/MLTestReport.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/api/MLTestReport.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/api/MLTestReport.java
new file mode 100644
index 0000000..965161a
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/api/MLTestReport.java
@@ -0,0 +1,95 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.api;
+
+import java.io.Serializable;
+import java.util.List;
+
+import lombok.Getter;
+import lombok.NoArgsConstructor;
+import lombok.Setter;
+import lombok.ToString;
+
+/**
+ * Describes a model test run: the tables and columns used, the model tested and the associated IDs.
+ */
+@NoArgsConstructor
+@ToString
+public class MLTestReport implements Serializable {
+
+ /** The test table. */
+ @Getter
+ @Setter
+ private String testTable;
+
+ /** The output table. */
+ @Getter
+ @Setter
+ private String outputTable;
+
+ /** The output column. */
+ @Getter
+ @Setter
+ private String outputColumn;
+
+ /** The label column. */
+ @Getter
+ @Setter
+ private String labelColumn;
+
+ /** The feature columns. */
+ @Getter
+ @Setter
+ private List<String> featureColumns;
+
+ /** The algorithm. */
+ @Getter
+ @Setter
+ private String algorithm;
+
+ /** The model id. */
+ @Getter
+ @Setter
+ private String modelID;
+
+ /** The report id. */
+ @Getter
+ @Setter
+ private String reportID;
+
+ /** The query id. */
+ @Getter
+ @Setter
+ private String queryID;
+
+ /** The test output path. */
+ @Getter
+ @Setter
+ private String testOutputPath;
+
+ /** The prediction result column. */
+ @Getter
+ @Setter
+ private String predictionResultColumn;
+
+ /** The lens query id. */
+ @Getter
+ @Setter
+ private String lensQueryID;
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/api/ModelMetadata.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/api/ModelMetadata.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/api/ModelMetadata.java
new file mode 100644
index 0000000..3f7dff1
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/api/ModelMetadata.java
@@ -0,0 +1,118 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.api;
+
+import javax.xml.bind.annotation.XmlElement;
+import javax.xml.bind.annotation.XmlRootElement;
+
+import lombok.AllArgsConstructor;
+import lombok.Getter;
+import lombok.NoArgsConstructor;
+
+/**
+ * The Class ModelMetadata.
+ */
+@XmlRootElement
+/**
+ * Instantiates a new model metadata.
+ *
+ * @param modelID
+ * the model id
+ * @param table
+ * the table
+ * @param algorithm
+ * the algorithm
+ * @param params
+ * the params
+ * @param createdAt
+ * the created at
+ * @param modelPath
+ * the model path
+ * @param labelColumn
+ * the label column
+ * @param features
+ * the features
+ */
+@AllArgsConstructor
+/**
+ * Instantiates a new model metadata.
+ */
+@NoArgsConstructor
+public class ModelMetadata {
+
+ /** The model id. */
+ @XmlElement
+ @Getter
+ private String modelID;
+
+ /** The table. */
+ @XmlElement
+ @Getter
+ private String table;
+
+ /** The algorithm. */
+ @XmlElement
+ @Getter
+ private String algorithm;
+
+ /** The params. */
+ @XmlElement
+ @Getter
+ private String params;
+
+ /** The created at. */
+ @XmlElement
+ @Getter
+ private String createdAt;
+
+ /** The model path. */
+ @XmlElement
+ @Getter
+ private String modelPath;
+
+ /** The label column. */
+ @XmlElement
+ @Getter
+ private String labelColumn;
+
+ /** The features. */
+ @XmlElement
+ @Getter
+ private String features;
+
+ /*
+ * (non-Javadoc)
+ *
+ * @see java.lang.Object#toString()
+ */
+ @Override
+ public String toString() {
+ StringBuilder builder = new StringBuilder();
+
+ builder.append("Algorithm: ").append(algorithm).append('\n');
+ builder.append("Model ID: ").append(modelID).append('\n');
+ builder.append("Training table: ").append(table).append('\n');
+ builder.append("Features: ").append(features).append('\n');
+ builder.append("Labelled Column: ").append(labelColumn).append('\n');
+ builder.append("Training params: ").append(params).append('\n');
+ builder.append("Created on: ").append(createdAt).append('\n');
+ builder.append("Model saved at: ").append(modelPath).append('\n');
+ return builder.toString();
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/api/TestReport.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/api/TestReport.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/api/TestReport.java
new file mode 100644
index 0000000..294fef3
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/api/TestReport.java
@@ -0,0 +1,125 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.api;
+
+import javax.xml.bind.annotation.XmlElement;
+import javax.xml.bind.annotation.XmlRootElement;
+
+import lombok.AllArgsConstructor;
+import lombok.Getter;
+import lombok.NoArgsConstructor;
+
+/**
+ * The Class TestReport.
+ */
+@XmlRootElement
+/**
+ * Instantiates a new test report.
+ *
+ * @param testTable
+ * the test table
+ * @param outputTable
+ * the output table
+ * @param outputColumn
+ * the output column
+ * @param labelColumn
+ * the label column
+ * @param featureColumns
+ * the feature columns
+ * @param algorithm
+ * the algorithm
+ * @param modelID
+ * the model id
+ * @param reportID
+ * the report id
+ * @param queryID
+ * the query id
+ */
+@AllArgsConstructor
+/**
+ * Instantiates a new test report.
+ */
+@NoArgsConstructor
+public class TestReport {
+
+ /** The test table. */
+ @XmlElement
+ @Getter
+ private String testTable;
+
+ /** The output table. */
+ @XmlElement
+ @Getter
+ private String outputTable;
+
+ /** The output column. */
+ @XmlElement
+ @Getter
+ private String outputColumn;
+
+ /** The label column. */
+ @XmlElement
+ @Getter
+ private String labelColumn;
+
+ /** The feature columns. */
+ @XmlElement
+ @Getter
+ private String featureColumns;
+
+ /** The algorithm. */
+ @XmlElement
+ @Getter
+ private String algorithm;
+
+ /** The model id. */
+ @XmlElement
+ @Getter
+ private String modelID;
+
+ /** The report id. */
+ @XmlElement
+ @Getter
+ private String reportID;
+
+ /** The query id. */
+ @XmlElement
+ @Getter
+ private String queryID;
+
+ /*
+ * (non-Javadoc)
+ *
+ * @see java.lang.Object#toString()
+ */
+ @Override
+ public String toString() {
+ StringBuilder builder = new StringBuilder();
+ builder.append("Input test table: ").append(testTable).append('\n');
+ builder.append("Algorithm: ").append(algorithm).append('\n');
+ builder.append("Report id: ").append(reportID).append('\n');
+ builder.append("Model id: ").append(modelID).append('\n');
+ builder.append("Lens Query id: ").append(queryID).append('\n');
+ builder.append("Feature columns: ").append(featureColumns).append('\n');
+ builder.append("Labelled column: ").append(labelColumn).append('\n');
+ builder.append("Predicted column: ").append(outputColumn).append('\n');
+ builder.append("Test output table: ").append(outputTable).append('\n');
+ return builder.toString();
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/dao/MLDBUtils.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/dao/MLDBUtils.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/dao/MLDBUtils.java
index 5e4d307..d444a32 100644
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/dao/MLDBUtils.java
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/dao/MLDBUtils.java
@@ -18,9 +18,9 @@
*/
package org.apache.lens.ml.dao;
-import org.apache.lens.ml.MLModel;
-import org.apache.lens.ml.MLTestReport;
-import org.apache.lens.ml.task.MLTask;
+import org.apache.lens.ml.algo.api.MLModel;
+import org.apache.lens.ml.api.MLTestReport;
+import org.apache.lens.ml.impl.MLTask;
public class MLDBUtils {
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/HiveMLUDF.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/HiveMLUDF.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/HiveMLUDF.java
new file mode 100644
index 0000000..60a4008
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/HiveMLUDF.java
@@ -0,0 +1,138 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.impl;
+
+import java.io.IOException;
+
+import org.apache.lens.ml.algo.api.MLModel;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.MapredContext;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
+import org.apache.hadoop.hive.serde2.lazy.LazyDouble;
+import org.apache.hadoop.hive.serde2.lazy.objectinspector.primitive.LazyDoubleObjectInspector;
+import org.apache.hadoop.hive.serde2.lazy.objectinspector.primitive.LazyPrimitiveObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector;
+import org.apache.hadoop.mapred.JobConf;
+
+/**
+ * Generic UDF to load ML Models saved in HDFS and apply the model on the list of columns passed as argument.
+ */
+@Description(name = "predict",
+ value = "_FUNC_(algorithm, modelID, features...) - Run prediction algorithm with given "
+ + "algorithm name, model ID and input feature columns")
+public final class HiveMLUDF extends GenericUDF {
+ private HiveMLUDF() {
+ }
+
+ /** The Constant UDF_NAME. */
+ public static final String UDF_NAME = "predict";
+
+ /** The Constant LOG. */
+ public static final Log LOG = LogFactory.getLog(HiveMLUDF.class);
+
+ /** The conf. */
+ private JobConf conf;
+
+ /** The soi. */
+ private StringObjectInspector soi;
+
+ /** The doi. */
+ private LazyDoubleObjectInspector doi;
+
+ /** The model. */
+ private MLModel model;
+
+ /**
+ * Currently we only support double as the return value.
+ *
+ * @param objectInspectors the object inspectors
+ * @return the object inspector
+ * @throws UDFArgumentException the UDF argument exception
+ */
+ @Override
+ public ObjectInspector initialize(ObjectInspector[] objectInspectors) throws UDFArgumentException {
+ // We require algo name, model id and at least one feature
+ if (objectInspectors.length < 3) {
+ throw new UDFArgumentLengthException("Algo name, model ID and at least one feature should be passed to "
+ + UDF_NAME);
+ }
+ LOG.info(UDF_NAME + " initialized");
+ return PrimitiveObjectInspectorFactory.javaDoubleObjectInspector;
+ }
+
+ /*
+ * (non-Javadoc)
+ *
+ * @see org.apache.hadoop.hive.ql.udf.generic.GenericUDF#evaluate(org.apache.hadoop.hive.ql.udf.generic.GenericUDF.
+ * DeferredObject[])
+ */
+ @Override
+ public Object evaluate(DeferredObject[] deferredObjects) throws HiveException {
+ String algorithm = soi.getPrimitiveJavaObject(deferredObjects[0].get());
+ String modelId = soi.getPrimitiveJavaObject(deferredObjects[1].get());
+
+ Double[] features = new Double[deferredObjects.length - 2];
+ for (int i = 2; i < deferredObjects.length; i++) {
+ LazyDouble lazyDouble = (LazyDouble) deferredObjects[i].get();
+ features[i - 2] = (lazyDouble == null) ? 0d : doi.get(lazyDouble);
+ }
+
+ try {
+ if (model == null) {
+ model = ModelLoader.loadModel(conf, algorithm, modelId);
+ }
+ } catch (IOException e) {
+ throw new HiveException(e);
+ }
+
+ return model.predict(features);
+ }
+
+ /*
+ * (non-Javadoc)
+ *
+ * @see org.apache.hadoop.hive.ql.udf.generic.GenericUDF#getDisplayString(java.lang.String[])
+ */
+ @Override
+ public String getDisplayString(String[] strings) {
+ return UDF_NAME;
+ }
+
+ /*
+ * (non-Javadoc)
+ *
+ * @see org.apache.hadoop.hive.ql.udf.generic.GenericUDF#configure(org.apache.hadoop.hive.ql.exec.MapredContext)
+ */
+ @Override
+ public void configure(MapredContext context) {
+ super.configure(context);
+ conf = context.getJobConf();
+ soi = PrimitiveObjectInspectorFactory.javaStringObjectInspector;
+ doi = LazyPrimitiveObjectInspectorFactory.LAZY_DOUBLE_OBJECT_INSPECTOR;
+ LOG.info(UDF_NAME + " configured. Model base dir path: " + conf.get(ModelLoader.MODEL_PATH_BASE_DIR));
+ }
+}
[6/6] incubator-lens git commit: Lens-465 : Refactor ml packages.
(sharad)
Posted by sh...@apache.org.
Lens-465 : Refactor ml packages. (sharad)
Project: http://git-wip-us.apache.org/repos/asf/incubator-lens/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-lens/commit/0f5ea4c7
Tree: http://git-wip-us.apache.org/repos/asf/incubator-lens/tree/0f5ea4c7
Diff: http://git-wip-us.apache.org/repos/asf/incubator-lens/diff/0f5ea4c7
Branch: refs/heads/master
Commit: 0f5ea4c7827fc4740c1c2ba0fb9527903a2b910c
Parents: 278e0e8
Author: Sharad Agarwal <sh...@flipkarts-MacBook-Pro.local>
Authored: Sun Apr 5 10:48:32 2015 +0530
Committer: Sharad Agarwal <sh...@flipkarts-MacBook-Pro.local>
Committed: Sun Apr 5 10:48:32 2015 +0530
----------------------------------------------------------------------
.../org/apache/lens/api/ml/ModelMetadata.java | 118 ---
.../java/org/apache/lens/api/ml/TestReport.java | 125 ----
.../org/apache/lens/client/LensMLClient.java | 12 +-
.../apache/lens/client/LensMLJerseyClient.java | 4 +-
.../java/org/apache/lens/ml/AlgoArgParser.java | 114 ---
.../main/java/org/apache/lens/ml/AlgoParam.java | 53 --
.../main/java/org/apache/lens/ml/Algorithm.java | 46 --
.../java/org/apache/lens/ml/Algorithms.java | 87 ---
.../org/apache/lens/ml/ClassifierBaseModel.java | 46 --
.../java/org/apache/lens/ml/ExampleUtils.java | 101 ---
.../org/apache/lens/ml/ForecastingModel.java | 93 ---
.../main/java/org/apache/lens/ml/HiveMLUDF.java | 136 ----
.../org/apache/lens/ml/LabelledPrediction.java | 32 -
.../main/java/org/apache/lens/ml/LensML.java | 159 ----
.../java/org/apache/lens/ml/LensMLImpl.java | 734 ------------------
.../main/java/org/apache/lens/ml/MLAlgo.java | 53 --
.../main/java/org/apache/lens/ml/MLDriver.java | 71 --
.../main/java/org/apache/lens/ml/MLModel.java | 79 --
.../main/java/org/apache/lens/ml/MLRunner.java | 173 -----
.../java/org/apache/lens/ml/MLTestMetric.java | 28 -
.../java/org/apache/lens/ml/MLTestReport.java | 95 ---
.../main/java/org/apache/lens/ml/MLUtils.java | 60 --
.../java/org/apache/lens/ml/ModelLoader.java | 239 ------
.../org/apache/lens/ml/MultiPrediction.java | 28 -
.../java/org/apache/lens/ml/QueryRunner.java | 56 --
.../org/apache/lens/ml/TableTestingSpec.java | 325 --------
.../org/apache/lens/ml/algo/api/AlgoParam.java | 53 ++
.../org/apache/lens/ml/algo/api/Algorithm.java | 46 ++
.../org/apache/lens/ml/algo/api/MLAlgo.java | 53 ++
.../org/apache/lens/ml/algo/api/MLDriver.java | 71 ++
.../org/apache/lens/ml/algo/api/MLModel.java | 79 ++
.../apache/lens/ml/algo/lib/AlgoArgParser.java | 117 +++
.../org/apache/lens/ml/algo/lib/Algorithms.java | 89 +++
.../lens/ml/algo/lib/ClassifierBaseModel.java | 48 ++
.../lens/ml/algo/lib/ForecastingModel.java | 95 +++
.../lens/ml/algo/lib/LabelledPrediction.java | 32 +
.../lens/ml/algo/lib/MultiPrediction.java | 28 +
.../lens/ml/algo/spark/BaseSparkAlgo.java | 287 +++++++
.../spark/BaseSparkClassificationModel.java | 65 ++
.../ml/algo/spark/ColumnFeatureFunction.java | 102 +++
.../lens/ml/algo/spark/DoubleValueMapper.java | 39 +
.../lens/ml/algo/spark/FeatureFunction.java | 40 +
.../lens/ml/algo/spark/FeatureValueMapper.java | 36 +
.../apache/lens/ml/algo/spark/HiveTableRDD.java | 63 ++
.../lens/ml/algo/spark/SparkMLDriver.java | 278 +++++++
.../lens/ml/algo/spark/TableTrainingSpec.java | 433 +++++++++++
.../lens/ml/algo/spark/dt/DecisionTreeAlgo.java | 108 +++
.../dt/DecisionTreeClassificationModel.java | 37 +
.../algo/spark/dt/SparkDecisionTreeModel.java | 75 ++
.../lens/ml/algo/spark/kmeans/KMeansAlgo.java | 163 ++++
.../spark/kmeans/KMeansClusteringModel.java | 67 ++
.../algo/spark/lr/LogisticRegressionAlgo.java | 86 +++
.../lr/LogitRegressionClassificationModel.java | 39 +
.../lens/ml/algo/spark/nb/NaiveBayesAlgo.java | 73 ++
.../spark/nb/NaiveBayesClassificationModel.java | 39 +
.../apache/lens/ml/algo/spark/svm/SVMAlgo.java | 90 +++
.../algo/spark/svm/SVMClassificationModel.java | 39 +
.../java/org/apache/lens/ml/api/LensML.java | 161 ++++
.../org/apache/lens/ml/api/MLTestReport.java | 95 +++
.../org/apache/lens/ml/api/ModelMetadata.java | 118 +++
.../java/org/apache/lens/ml/api/TestReport.java | 125 ++++
.../java/org/apache/lens/ml/dao/MLDBUtils.java | 6 +-
.../java/org/apache/lens/ml/impl/HiveMLUDF.java | 138 ++++
.../org/apache/lens/ml/impl/LensMLImpl.java | 744 +++++++++++++++++++
.../java/org/apache/lens/ml/impl/MLRunner.java | 172 +++++
.../java/org/apache/lens/ml/impl/MLTask.java | 285 +++++++
.../java/org/apache/lens/ml/impl/MLUtils.java | 62 ++
.../org/apache/lens/ml/impl/ModelLoader.java | 242 ++++++
.../org/apache/lens/ml/impl/QueryRunner.java | 56 ++
.../apache/lens/ml/impl/TableTestingSpec.java | 325 ++++++++
.../java/org/apache/lens/ml/server/MLApp.java | 60 ++
.../org/apache/lens/ml/server/MLService.java | 27 +
.../apache/lens/ml/server/MLServiceImpl.java | 329 ++++++++
.../lens/ml/server/MLServiceResource.java | 427 +++++++++++
.../lens/ml/spark/ColumnFeatureFunction.java | 102 ---
.../apache/lens/ml/spark/DoubleValueMapper.java | 39 -
.../apache/lens/ml/spark/FeatureFunction.java | 40 -
.../lens/ml/spark/FeatureValueMapper.java | 36 -
.../org/apache/lens/ml/spark/HiveTableRDD.java | 63 --
.../org/apache/lens/ml/spark/SparkMLDriver.java | 275 -------
.../apache/lens/ml/spark/TableTrainingSpec.java | 433 -----------
.../lens/ml/spark/algos/BaseSparkAlgo.java | 290 --------
.../lens/ml/spark/algos/DecisionTreeAlgo.java | 109 ---
.../apache/lens/ml/spark/algos/KMeansAlgo.java | 163 ----
.../ml/spark/algos/LogisticRegressionAlgo.java | 86 ---
.../lens/ml/spark/algos/NaiveBayesAlgo.java | 73 --
.../org/apache/lens/ml/spark/algos/SVMAlgo.java | 90 ---
.../models/BaseSparkClassificationModel.java | 65 --
.../models/DecisionTreeClassificationModel.java | 35 -
.../ml/spark/models/KMeansClusteringModel.java | 67 --
.../LogitRegressionClassificationModel.java | 37 -
.../models/NaiveBayesClassificationModel.java | 37 -
.../ml/spark/models/SVMClassificationModel.java | 37 -
.../ml/spark/models/SparkDecisionTreeModel.java | 75 --
.../java/org/apache/lens/ml/task/MLTask.java | 286 -------
.../java/org/apache/lens/rdd/LensRDDClient.java | 2 +-
.../java/org/apache/lens/server/ml/MLApp.java | 60 --
.../org/apache/lens/server/ml/MLService.java | 27 -
.../apache/lens/server/ml/MLServiceImpl.java | 324 --------
.../lens/server/ml/MLServiceResource.java | 415 -----------
.../java/org/apache/lens/ml/ExampleUtils.java | 101 +++
.../java/org/apache/lens/ml/TestMLResource.java | 15 +-
.../java/org/apache/lens/ml/TestMLRunner.java | 7 +-
lens-ml-lib/src/test/resources/lens-site.xml | 6 +-
tools/conf-pseudo-distr/server/lens-site.xml | 6 +-
105 files changed, 6367 insertions(+), 6343 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/api/ml/ModelMetadata.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/api/ml/ModelMetadata.java b/lens-ml-lib/src/main/java/org/apache/lens/api/ml/ModelMetadata.java
deleted file mode 100644
index 0f072bf..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/api/ml/ModelMetadata.java
+++ /dev/null
@@ -1,118 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.api.ml;
-
-import javax.xml.bind.annotation.XmlElement;
-import javax.xml.bind.annotation.XmlRootElement;
-
-import lombok.AllArgsConstructor;
-import lombok.Getter;
-import lombok.NoArgsConstructor;
-
-/**
- * The Class ModelMetadata.
- */
-@XmlRootElement
-/**
- * Instantiates a new model metadata.
- *
- * @param modelID
- * the model id
- * @param table
- * the table
- * @param algorithm
- * the algorithm
- * @param params
- * the params
- * @param createdAt
- * the created at
- * @param modelPath
- * the model path
- * @param labelColumn
- * the label column
- * @param features
- * the features
- */
-@AllArgsConstructor
-/**
- * Instantiates a new model metadata.
- */
-@NoArgsConstructor
-public class ModelMetadata {
-
- /** The model id. */
- @XmlElement
- @Getter
- private String modelID;
-
- /** The table. */
- @XmlElement
- @Getter
- private String table;
-
- /** The algorithm. */
- @XmlElement
- @Getter
- private String algorithm;
-
- /** The params. */
- @XmlElement
- @Getter
- private String params;
-
- /** The created at. */
- @XmlElement
- @Getter
- private String createdAt;
-
- /** The model path. */
- @XmlElement
- @Getter
- private String modelPath;
-
- /** The label column. */
- @XmlElement
- @Getter
- private String labelColumn;
-
- /** The features. */
- @XmlElement
- @Getter
- private String features;
-
- /*
- * (non-Javadoc)
- *
- * @see java.lang.Object#toString()
- */
- @Override
- public String toString() {
- StringBuilder builder = new StringBuilder();
-
- builder.append("Algorithm: ").append(algorithm).append('\n');
- builder.append("Model ID: ").append(modelID).append('\n');
- builder.append("Training table: ").append(table).append('\n');
- builder.append("Features: ").append(features).append('\n');
- builder.append("Labelled Column: ").append(labelColumn).append('\n');
- builder.append("Training params: ").append(params).append('\n');
- builder.append("Created on: ").append(createdAt).append('\n');
- builder.append("Model saved at: ").append(modelPath).append('\n');
- return builder.toString();
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/api/ml/TestReport.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/api/ml/TestReport.java b/lens-ml-lib/src/main/java/org/apache/lens/api/ml/TestReport.java
deleted file mode 100644
index 2ae384b..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/api/ml/TestReport.java
+++ /dev/null
@@ -1,125 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.api.ml;
-
-import javax.xml.bind.annotation.XmlElement;
-import javax.xml.bind.annotation.XmlRootElement;
-
-import lombok.AllArgsConstructor;
-import lombok.Getter;
-import lombok.NoArgsConstructor;
-
-/**
- * The Class TestReport.
- */
-@XmlRootElement
-/**
- * Instantiates a new test report.
- *
- * @param testTable
- * the test table
- * @param outputTable
- * the output table
- * @param outputColumn
- * the output column
- * @param labelColumn
- * the label column
- * @param featureColumns
- * the feature columns
- * @param algorithm
- * the algorithm
- * @param modelID
- * the model id
- * @param reportID
- * the report id
- * @param queryID
- * the query id
- */
-@AllArgsConstructor
-/**
- * Instantiates a new test report.
- */
-@NoArgsConstructor
-public class TestReport {
-
- /** The test table. */
- @XmlElement
- @Getter
- private String testTable;
-
- /** The output table. */
- @XmlElement
- @Getter
- private String outputTable;
-
- /** The output column. */
- @XmlElement
- @Getter
- private String outputColumn;
-
- /** The label column. */
- @XmlElement
- @Getter
- private String labelColumn;
-
- /** The feature columns. */
- @XmlElement
- @Getter
- private String featureColumns;
-
- /** The algorithm. */
- @XmlElement
- @Getter
- private String algorithm;
-
- /** The model id. */
- @XmlElement
- @Getter
- private String modelID;
-
- /** The report id. */
- @XmlElement
- @Getter
- private String reportID;
-
- /** The query id. */
- @XmlElement
- @Getter
- private String queryID;
-
- /*
- * (non-Javadoc)
- *
- * @see java.lang.Object#toString()
- */
- @Override
- public String toString() {
- StringBuilder builder = new StringBuilder();
- builder.append("Input test table: ").append(testTable).append('\n');
- builder.append("Algorithm: ").append(algorithm).append('\n');
- builder.append("Report id: ").append(reportID).append('\n');
- builder.append("Model id: ").append(modelID).append('\n');
- builder.append("Lens Query id: ").append(queryID).append('\n');
- builder.append("Feature columns: ").append(featureColumns).append('\n');
- builder.append("Labelled column: ").append(labelColumn).append('\n');
- builder.append("Predicted column: ").append(outputColumn).append('\n');
- builder.append("Test output table: ").append(outputTable).append('\n');
- return builder.toString();
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/client/LensMLClient.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/client/LensMLClient.java b/lens-ml-lib/src/main/java/org/apache/lens/client/LensMLClient.java
index d9ec314..4d4521e 100644
--- a/lens-ml-lib/src/main/java/org/apache/lens/client/LensMLClient.java
+++ b/lens-ml-lib/src/main/java/org/apache/lens/client/LensMLClient.java
@@ -32,12 +32,12 @@ import javax.ws.rs.core.Form;
import org.apache.lens.api.LensException;
import org.apache.lens.api.LensSessionHandle;
-import org.apache.lens.api.ml.ModelMetadata;
-import org.apache.lens.api.ml.TestReport;
-import org.apache.lens.ml.LensML;
-import org.apache.lens.ml.MLAlgo;
-import org.apache.lens.ml.MLModel;
-import org.apache.lens.ml.MLTestReport;
+import org.apache.lens.ml.algo.api.MLAlgo;
+import org.apache.lens.ml.algo.api.MLModel;
+import org.apache.lens.ml.api.LensML;
+import org.apache.lens.ml.api.MLTestReport;
+import org.apache.lens.ml.api.ModelMetadata;
+import org.apache.lens.ml.api.TestReport;
import org.apache.commons.lang.StringUtils;
import org.apache.commons.logging.Log;
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/client/LensMLJerseyClient.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/client/LensMLJerseyClient.java b/lens-ml-lib/src/main/java/org/apache/lens/client/LensMLJerseyClient.java
index af47a41..c68dd12 100644
--- a/lens-ml-lib/src/main/java/org/apache/lens/client/LensMLJerseyClient.java
+++ b/lens-ml-lib/src/main/java/org/apache/lens/client/LensMLJerseyClient.java
@@ -31,8 +31,8 @@ import javax.ws.rs.core.MediaType;
import org.apache.lens.api.LensSessionHandle;
import org.apache.lens.api.StringList;
-import org.apache.lens.api.ml.ModelMetadata;
-import org.apache.lens.api.ml.TestReport;
+import org.apache.lens.ml.api.ModelMetadata;
+import org.apache.lens.ml.api.TestReport;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/AlgoArgParser.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/AlgoArgParser.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/AlgoArgParser.java
deleted file mode 100644
index 20da083..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/AlgoArgParser.java
+++ /dev/null
@@ -1,114 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml;
-
-import java.lang.reflect.Field;
-import java.util.ArrayList;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
-
-/**
- * The Class AlgoArgParser.
- */
-public final class AlgoArgParser {
- private AlgoArgParser() {
- }
-
- /**
- * The Class CustomArgParser.
- *
- * @param <E> the element type
- */
- public abstract static class CustomArgParser<E> {
-
- /**
- * Parses the.
- *
- * @param value the value
- * @return the e
- */
- public abstract E parse(String value);
- }
-
- /** The Constant LOG. */
- public static final Log LOG = LogFactory.getLog(AlgoArgParser.class);
-
- /**
- * Extracts feature names. If the algo has any parameters associated with @AlgoParam annotation, those are set
- * as well.
- *
- * @param algo the algo
- * @param args the args
- * @return List of feature column names.
- */
- public static List<String> parseArgs(MLAlgo algo, String[] args) {
- List<String> featureColumns = new ArrayList<String>();
- Class<? extends MLAlgo> algoClass = algo.getClass();
- // Get param fields
- Map<String, Field> fieldMap = new HashMap<String, Field>();
-
- for (Field fld : algoClass.getDeclaredFields()) {
- fld.setAccessible(true);
- AlgoParam paramAnnotation = fld.getAnnotation(AlgoParam.class);
- if (paramAnnotation != null) {
- fieldMap.put(paramAnnotation.name(), fld);
- }
- }
-
- for (int i = 0; i < args.length; i += 2) {
- String key = args[i].trim();
- String value = args[i + 1].trim();
-
- try {
- if ("feature".equalsIgnoreCase(key)) {
- featureColumns.add(value);
- } else if (fieldMap.containsKey(key)) {
- Field f = fieldMap.get(key);
- if (String.class.equals(f.getType())) {
- f.set(algo, value);
- } else if (Integer.TYPE.equals(f.getType())) {
- f.setInt(algo, Integer.parseInt(value));
- } else if (Double.TYPE.equals(f.getType())) {
- f.setDouble(algo, Double.parseDouble(value));
- } else if (Long.TYPE.equals(f.getType())) {
- f.setLong(algo, Long.parseLong(value));
- } else {
- // check if the algo provides a deserializer for this param
- String customParserClass = algo.getConf().getProperties().get("lens.ml.args." + key);
- if (customParserClass != null) {
- Class<? extends CustomArgParser<?>> clz = (Class<? extends CustomArgParser<?>>) Class
- .forName(customParserClass);
- CustomArgParser<?> parser = clz.newInstance();
- f.set(algo, parser.parse(value));
- } else {
- LOG.warn("Ignored param " + key + "=" + value + " as no parser found");
- }
- }
- }
- } catch (Exception exc) {
- LOG.error("Error while setting param " + key + " to " + value + " for algo " + algo);
- }
- }
- return featureColumns;
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/AlgoParam.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/AlgoParam.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/AlgoParam.java
deleted file mode 100644
index 5836f51..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/AlgoParam.java
+++ /dev/null
@@ -1,53 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml;
-
-import java.lang.annotation.ElementType;
-import java.lang.annotation.Retention;
-import java.lang.annotation.RetentionPolicy;
-import java.lang.annotation.Target;
-
-/**
- * The Interface AlgoParam.
- */
-@Retention(RetentionPolicy.RUNTIME)
-@Target(ElementType.FIELD)
-public @interface AlgoParam {
-
- /**
- * Name.
- *
- * @return the string
- */
- String name();
-
- /**
- * Help.
- *
- * @return the string
- */
- String help();
-
- /**
- * Default value.
- *
- * @return the string
- */
- String defaultValue() default "None";
-}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/Algorithm.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/Algorithm.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/Algorithm.java
deleted file mode 100644
index 7025d7b..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/Algorithm.java
+++ /dev/null
@@ -1,46 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml;
-
-import java.lang.annotation.ElementType;
-import java.lang.annotation.Retention;
-import java.lang.annotation.RetentionPolicy;
-import java.lang.annotation.Target;
-
-/**
- * The Interface Algorithm.
- */
-@Retention(RetentionPolicy.RUNTIME)
-@Target(ElementType.TYPE)
-public @interface Algorithm {
-
- /**
- * Name.
- *
- * @return the string
- */
- String name();
-
- /**
- * Description.
- *
- * @return the string
- */
- String description();
-}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/Algorithms.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/Algorithms.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/Algorithms.java
deleted file mode 100644
index c1b7212..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/Algorithms.java
+++ /dev/null
@@ -1,87 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml;
-
-import java.lang.reflect.Constructor;
-import java.util.ArrayList;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-
-import org.apache.lens.api.LensException;
-
-/**
- * The Class Algorithms.
- */
-public class Algorithms {
-
- /** The algorithm classes. */
- private final Map<String, Class<? extends MLAlgo>> algorithmClasses
- = new HashMap<String, Class<? extends MLAlgo>>();
-
- /**
- * Register.
- *
- * @param algoClass the algo class
- */
- public void register(Class<? extends MLAlgo> algoClass) {
- if (algoClass != null && algoClass.getAnnotation(Algorithm.class) != null) {
- algorithmClasses.put(algoClass.getAnnotation(Algorithm.class).name(), algoClass);
- } else {
- throw new IllegalArgumentException("Not a valid algorithm class: " + algoClass);
- }
- }
-
- /**
- * Gets the algo for name.
- *
- * @param name the name
- * @return the algo for name
- * @throws LensException the lens exception
- */
- public MLAlgo getAlgoForName(String name) throws LensException {
- Class<? extends MLAlgo> algoClass = algorithmClasses.get(name);
- if (algoClass == null) {
- return null;
- }
- Algorithm algoAnnotation = algoClass.getAnnotation(Algorithm.class);
- String description = algoAnnotation.description();
- try {
- Constructor<? extends MLAlgo> algoConstructor = algoClass.getConstructor(String.class, String.class);
- return algoConstructor.newInstance(name, description);
- } catch (Exception exc) {
- throw new LensException("Unable to get algo: " + name, exc);
- }
- }
-
- /**
- * Checks if is algo supported.
- *
- * @param name the name
- * @return true, if is algo supported
- */
- public boolean isAlgoSupported(String name) {
- return algorithmClasses.containsKey(name);
- }
-
- public List<String> getAlgorithmNames() {
- return new ArrayList<String>(algorithmClasses.keySet());
- }
-
-}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/ClassifierBaseModel.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/ClassifierBaseModel.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/ClassifierBaseModel.java
deleted file mode 100644
index 68008fe..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/ClassifierBaseModel.java
+++ /dev/null
@@ -1,46 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml;
-
-/**
- * Return a single double value as a prediction. This is useful in classifiers where the classifier returns a single
- * class label as a prediction.
- */
-public abstract class ClassifierBaseModel extends MLModel<Double> {
-
- /**
- * Gets the feature vector.
- *
- * @param args the args
- * @return the feature vector
- */
- public final double[] getFeatureVector(Object[] args) {
- double[] features = new double[args.length];
- for (int i = 0; i < args.length; i++) {
- if (args[i] instanceof Double) {
- features[i] = (Double) args[i];
- } else if (args[i] instanceof String) {
- features[i] = Double.parseDouble((String) args[i]);
- } else {
- features[i] = 0.0;
- }
- }
- return features;
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/ExampleUtils.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/ExampleUtils.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/ExampleUtils.java
deleted file mode 100644
index 9fe1ea0..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/ExampleUtils.java
+++ /dev/null
@@ -1,101 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml;
-
-import java.util.ArrayList;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
-import org.apache.hadoop.fs.Path;
-import org.apache.hadoop.hive.conf.HiveConf;
-import org.apache.hadoop.hive.metastore.TableType;
-import org.apache.hadoop.hive.metastore.api.FieldSchema;
-import org.apache.hadoop.hive.ql.metadata.Hive;
-import org.apache.hadoop.hive.ql.metadata.HiveException;
-import org.apache.hadoop.hive.ql.metadata.Table;
-import org.apache.hadoop.hive.ql.plan.AddPartitionDesc;
-import org.apache.hadoop.hive.serde.serdeConstants;
-import org.apache.hadoop.mapred.TextInputFormat;
-
-/**
- * The Class ExampleUtils.
- */
-public final class ExampleUtils {
- private ExampleUtils() {
- }
-
- private static final Log LOG = LogFactory.getLog(ExampleUtils.class);
-
- /**
- * Creates the example table.
- *
- * @param conf the conf
- * @param database the database
- * @param tableName the table name
- * @param sampleDataFile the sample data file
- * @param labelColumn the label column
- * @param features the features
- * @throws HiveException the hive exception
- */
- public static void createTable(HiveConf conf, String database, String tableName, String sampleDataFile,
- String labelColumn, Map<String, String> tableParams, String... features) throws HiveException {
-
- Path dataFilePath = new Path(sampleDataFile);
- Path partDir = dataFilePath.getParent();
-
- // Create table
- List<FieldSchema> columns = new ArrayList<FieldSchema>();
-
- // Label is optional. Not used for unsupervised models.
- // If present, label will be the first column, followed by features
- if (labelColumn != null) {
- columns.add(new FieldSchema(labelColumn, "double", "Labelled Column"));
- }
-
- for (String feature : features) {
- columns.add(new FieldSchema(feature, "double", "Feature " + feature));
- }
-
- Table tbl = Hive.get(conf).newTable(database + "." + tableName);
- tbl.setTableType(TableType.MANAGED_TABLE);
- tbl.getTTable().getSd().setCols(columns);
- tbl.getTTable().getParameters().putAll(tableParams);
- tbl.setInputFormatClass(TextInputFormat.class);
- tbl.setSerdeParam(serdeConstants.LINE_DELIM, "\n");
- tbl.setSerdeParam(serdeConstants.FIELD_DELIM, " ");
-
- List<FieldSchema> partCols = new ArrayList<FieldSchema>(1);
- partCols.add(new FieldSchema("dummy_partition_col", "string", ""));
- tbl.setPartCols(partCols);
-
- Hive.get(conf).createTable(tbl, false);
- LOG.info("Created table " + tableName);
-
- // Add partition for the data file
- AddPartitionDesc partitionDesc = new AddPartitionDesc(database, tableName, false);
- Map<String, String> partSpec = new HashMap<String, String>();
- partSpec.put("dummy_partition_col", "dummy_val");
- partitionDesc.addPartition(partSpec, partDir.toUri().toString());
- Hive.get(conf).createPartitions(partitionDesc);
- LOG.info(tableName + ": Added partition " + partDir.toUri().toString());
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/ForecastingModel.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/ForecastingModel.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/ForecastingModel.java
deleted file mode 100644
index 5163db5..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/ForecastingModel.java
+++ /dev/null
@@ -1,93 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml;
-
-import java.util.List;
-
-/**
- * The Class ForecastingModel.
- */
-public class ForecastingModel extends MLModel<MultiPrediction> {
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.MLModel#predict(java.lang.Object[])
- */
- @Override
- public MultiPrediction predict(Object... args) {
- return new ForecastingPredictions(null);
- }
-
- /**
- * The Class ForecastingPredictions.
- */
- public static class ForecastingPredictions implements MultiPrediction {
-
- /** The values. */
- private final List<LabelledPrediction> values;
-
- /**
- * Instantiates a new forecasting predictions.
- *
- * @param values the values
- */
- public ForecastingPredictions(List<LabelledPrediction> values) {
- this.values = values;
- }
-
- @Override
- public List<LabelledPrediction> getPredictions() {
- return values;
- }
- }
-
- /**
- * The Class ForecastingLabel.
- */
- public static class ForecastingLabel implements LabelledPrediction<Long, Double> {
-
- /** The timestamp. */
- private final Long timestamp;
-
- /** The value. */
- private final double value;
-
- /**
- * Instantiates a new forecasting label.
- *
- * @param timestamp the timestamp
- * @param value the value
- */
- public ForecastingLabel(long timestamp, double value) {
- this.timestamp = timestamp;
- this.value = value;
- }
-
- @Override
- public Long getLabel() {
- return timestamp;
- }
-
- @Override
- public Double getPrediction() {
- return value;
- }
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/HiveMLUDF.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/HiveMLUDF.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/HiveMLUDF.java
deleted file mode 100644
index 687ca54..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/HiveMLUDF.java
+++ /dev/null
@@ -1,136 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml;
-
-import java.io.IOException;
-
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
-import org.apache.hadoop.hive.ql.exec.Description;
-import org.apache.hadoop.hive.ql.exec.MapredContext;
-import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
-import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
-import org.apache.hadoop.hive.ql.metadata.HiveException;
-import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
-import org.apache.hadoop.hive.serde2.lazy.LazyDouble;
-import org.apache.hadoop.hive.serde2.lazy.objectinspector.primitive.LazyDoubleObjectInspector;
-import org.apache.hadoop.hive.serde2.lazy.objectinspector.primitive.LazyPrimitiveObjectInspectorFactory;
-import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector;
-import org.apache.hadoop.mapred.JobConf;
-
-/**
- * Generic UDF to laod ML Models saved in HDFS and apply the model on list of columns passed as argument.
- */
-@Description(name = "predict",
- value = "_FUNC_(algorithm, modelID, features...) - Run prediction algorithm with given "
- + "algorithm name, model ID and input feature columns")
-public final class HiveMLUDF extends GenericUDF {
- private HiveMLUDF() {
- }
-
- /** The Constant UDF_NAME. */
- public static final String UDF_NAME = "predict";
-
- /** The Constant LOG. */
- public static final Log LOG = LogFactory.getLog(HiveMLUDF.class);
-
- /** The conf. */
- private JobConf conf;
-
- /** The soi. */
- private StringObjectInspector soi;
-
- /** The doi. */
- private LazyDoubleObjectInspector doi;
-
- /** The model. */
- private MLModel model;
-
- /**
- * Currently we only support double as the return value.
- *
- * @param objectInspectors the object inspectors
- * @return the object inspector
- * @throws UDFArgumentException the UDF argument exception
- */
- @Override
- public ObjectInspector initialize(ObjectInspector[] objectInspectors) throws UDFArgumentException {
- // We require algo name, model id and at least one feature
- if (objectInspectors.length < 3) {
- throw new UDFArgumentLengthException("Algo name, model ID and at least one feature should be passed to "
- + UDF_NAME);
- }
- LOG.info(UDF_NAME + " initialized");
- return PrimitiveObjectInspectorFactory.javaDoubleObjectInspector;
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.hadoop.hive.ql.udf.generic.GenericUDF#evaluate(org.apache.hadoop.hive.ql.udf.generic.GenericUDF.
- * DeferredObject[])
- */
- @Override
- public Object evaluate(DeferredObject[] deferredObjects) throws HiveException {
- String algorithm = soi.getPrimitiveJavaObject(deferredObjects[0].get());
- String modelId = soi.getPrimitiveJavaObject(deferredObjects[1].get());
-
- Double[] features = new Double[deferredObjects.length - 2];
- for (int i = 2; i < deferredObjects.length; i++) {
- LazyDouble lazyDouble = (LazyDouble) deferredObjects[i].get();
- features[i - 2] = (lazyDouble == null) ? 0d : doi.get(lazyDouble);
- }
-
- try {
- if (model == null) {
- model = ModelLoader.loadModel(conf, algorithm, modelId);
- }
- } catch (IOException e) {
- throw new HiveException(e);
- }
-
- return model.predict(features);
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.hadoop.hive.ql.udf.generic.GenericUDF#getDisplayString(java.lang.String[])
- */
- @Override
- public String getDisplayString(String[] strings) {
- return UDF_NAME;
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.hadoop.hive.ql.udf.generic.GenericUDF#configure(org.apache.hadoop.hive.ql.exec.MapredContext)
- */
- @Override
- public void configure(MapredContext context) {
- super.configure(context);
- conf = context.getJobConf();
- soi = PrimitiveObjectInspectorFactory.javaStringObjectInspector;
- doi = LazyPrimitiveObjectInspectorFactory.LAZY_DOUBLE_OBJECT_INSPECTOR;
- LOG.info(UDF_NAME + " configured. Model base dir path: " + conf.get(ModelLoader.MODEL_PATH_BASE_DIR));
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/LabelledPrediction.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/LabelledPrediction.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/LabelledPrediction.java
deleted file mode 100644
index 6e7f677..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/LabelledPrediction.java
+++ /dev/null
@@ -1,32 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml;
-
-/**
- * Prediction type used when the model prediction is of complex types. For example, in forecasting the predictions are a
- * series of timestamp and value pairs.
- *
- * @param <LABELTYPE> the type of the label returned by getLabel()
- * @param <PREDICTIONTYPE> the type of the prediction returned by getPrediction()
- */
-public interface LabelledPrediction<LABELTYPE, PREDICTIONTYPE> {
- LABELTYPE getLabel();
-
- PREDICTIONTYPE getPrediction();
-}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/LensML.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/LensML.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/LensML.java
deleted file mode 100644
index cdf28dd..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/LensML.java
+++ /dev/null
@@ -1,159 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml;
-
-import java.util.List;
-import java.util.Map;
-
-import org.apache.lens.api.LensException;
-import org.apache.lens.api.LensSessionHandle;
-
-/**
- * Lens's machine learning interface used by client code as well as Lens ML service.
- */
-public interface LensML {
-
- /** Name of ML service */
- String NAME = "ml";
-
- /**
- * Get list of available machine learning algorithms
- *
- * @return the names of the available algorithms
- */
- List<String> getAlgorithms();
-
- /**
- * Get user friendly information about parameters accepted by the algorithm.
- *
- * @param algorithm the algorithm
- * @return map of param key to its help message
- */
- Map<String, String> getAlgoParamDescription(String algorithm);
-
- /**
- * Get an algo object instance which can be used to generate a model of the given algorithm.
- *
- * @param algorithm the algorithm
- * @return the algo for name
- * @throws LensException the lens exception
- */
- MLAlgo getAlgoForName(String algorithm) throws LensException;
-
- /**
- * Create a model using the given HCatalog table as input. The arguments should contain information needed to
- * generate the model.
- *
- * @param table the table
- * @param algorithm the algorithm
- * @param args the args
- * @return Unique ID of the model created after training is complete
- * @throws LensException the lens exception
- */
- String train(String table, String algorithm, String[] args) throws LensException;
-
- /**
- * Get model IDs for the given algorithm.
- *
- * @param algorithm the algorithm
- * @return the models
- * @throws LensException the lens exception
- */
- List<String> getModels(String algorithm) throws LensException;
-
- /**
- * Get a model instance given the algorithm name and model ID.
- *
- * @param algorithm the algorithm
- * @param modelId the model id
- * @return the model
- * @throws LensException the lens exception
- */
- MLModel getModel(String algorithm, String modelId) throws LensException;
-
- /**
- * Get the FS location where model instance is saved.
- *
- * @param algorithm the algorithm
- * @param modelID the model id
- * @return the model path
- */
- String getModelPath(String algorithm, String modelID);
-
- /**
- * Evaluate model by running it against test data contained in the given table; test output is written to outputTable.
- *
- * @param session the session
- * @param table the table
- * @param algorithm the algorithm
- * @param modelID the model id
- * @return Test report object containing test output table, and various evaluation metrics
- * @throws LensException the lens exception
- */
- MLTestReport testModel(LensSessionHandle session, String table, String algorithm, String modelID,
- String outputTable) throws LensException;
-
- /**
- * Get test reports for an algorithm.
- *
- * @param algorithm the algorithm
- * @return the test reports
- * @throws LensException the lens exception
- */
- List<String> getTestReports(String algorithm) throws LensException;
-
- /**
- * Get a test report by ID.
- *
- * @param algorithm the algorithm
- * @param reportID the report id
- * @return the test report
- * @throws LensException the lens exception
- */
- MLTestReport getTestReport(String algorithm, String reportID) throws LensException;
-
- /**
- * Online predict call given a model ID, algorithm name and sample feature values.
- *
- * @param algorithm the algorithm
- * @param modelID the model id
- * @param features the features
- * @return prediction result
- * @throws LensException the lens exception
- */
- Object predict(String algorithm, String modelID, Object[] features) throws LensException;
-
- /**
- * Permanently delete a model instance.
- *
- * @param algorithm the algorithm
- * @param modelID the model id
- * @throws LensException the lens exception
- */
- void deleteModel(String algorithm, String modelID) throws LensException;
-
- /**
- * Permanently delete a test report instance.
- *
- * @param algorithm the algorithm
- * @param reportID the report id
- * @throws LensException the lens exception
- */
- void deleteTestReport(String algorithm, String reportID) throws LensException;
-}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/LensMLImpl.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/LensMLImpl.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/LensMLImpl.java
deleted file mode 100644
index b45f7f2..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/LensMLImpl.java
+++ /dev/null
@@ -1,734 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml;
-
-import java.io.IOException;
-import java.io.ObjectOutputStream;
-import java.util.*;
-import java.util.concurrent.ConcurrentHashMap;
-import java.util.concurrent.Executors;
-import java.util.concurrent.ScheduledExecutorService;
-import java.util.concurrent.TimeUnit;
-
-import javax.ws.rs.client.Client;
-import javax.ws.rs.client.ClientBuilder;
-import javax.ws.rs.client.Entity;
-import javax.ws.rs.client.WebTarget;
-import javax.ws.rs.core.MediaType;
-
-import org.apache.lens.api.LensConf;
-import org.apache.lens.api.LensException;
-import org.apache.lens.api.LensSessionHandle;
-import org.apache.lens.api.query.LensQuery;
-import org.apache.lens.api.query.QueryHandle;
-import org.apache.lens.api.query.QueryStatus;
-import org.apache.lens.ml.spark.SparkMLDriver;
-import org.apache.lens.ml.spark.algos.BaseSparkAlgo;
-import org.apache.lens.server.api.LensConfConstants;
-import org.apache.lens.server.api.session.SessionService;
-
-import org.apache.commons.io.IOUtils;
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
-import org.apache.hadoop.fs.FileStatus;
-import org.apache.hadoop.fs.FileSystem;
-import org.apache.hadoop.fs.Path;
-import org.apache.hadoop.hive.conf.HiveConf;
-import org.apache.hadoop.hive.ql.session.SessionState;
-import org.apache.spark.api.java.JavaSparkContext;
-
-import org.glassfish.jersey.media.multipart.FormDataBodyPart;
-import org.glassfish.jersey.media.multipart.FormDataContentDisposition;
-import org.glassfish.jersey.media.multipart.FormDataMultiPart;
-import org.glassfish.jersey.media.multipart.MultiPartFeature;
-
-/**
- * The Class LensMLImpl.
- */
-public class LensMLImpl implements LensML {
-
- /** The Constant LOG. */
- public static final Log LOG = LogFactory.getLog(LensMLImpl.class);
-
- /** The drivers. */
- protected List<MLDriver> drivers;
-
- /** The conf. */
- private HiveConf conf;
-
- /** The spark context. */
- private JavaSparkContext sparkContext;
-
- /** Sessions for which the predict UDF has already been registered */
- private final Map<LensSessionHandle, Boolean> predictUdfStatus;
- /** Background thread that periodically clears the UDF registration status of sessions which are closed */
- private ScheduledExecutorService udfStatusExpirySvc;
-
- /**
- * Instantiates a new lens ml impl.
- *
- * @param conf the conf
- */
- public LensMLImpl(HiveConf conf) {
- this.conf = conf;
- this.predictUdfStatus = new ConcurrentHashMap<LensSessionHandle, Boolean>();
- }
-
- public HiveConf getConf() {
- return conf;
- }
-
- /**
- * Use an existing Spark context. If set, it is passed to every SparkMLDriver when the service is started.
- *
- * @param jsc JavaSparkContext instance
- */
- public void setSparkContext(JavaSparkContext jsc) {
- this.sparkContext = jsc;
- }
-
- public List<String> getAlgorithms() {
- List<String> algos = new ArrayList<String>();
- for (MLDriver driver : drivers) {
- algos.addAll(driver.getAlgoNames());
- }
- return algos;
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.LensML#getAlgoForName(java.lang.String)
- */
- public MLAlgo getAlgoForName(String algorithm) throws LensException {
- for (MLDriver driver : drivers) {
- if (driver.isAlgoSupported(algorithm)) {
- return driver.getAlgoInstance(algorithm);
- }
- }
- throw new LensException("Algo not supported " + algorithm);
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.LensML#train(java.lang.String, java.lang.String, java.lang.String[])
- */
- public String train(String table, String algorithm, String[] args) throws LensException {
- MLAlgo algo = getAlgoForName(algorithm);
-
- String modelId = UUID.randomUUID().toString();
-
- LOG.info("Begin training model " + modelId + ", algo=" + algorithm + ", table=" + table + ", params="
- + Arrays.toString(args));
-
- String database = null;
- if (SessionState.get() != null) {
- database = SessionState.get().getCurrentDatabase();
- } else {
- database = "default";
- }
-
- MLModel model = algo.train(toLensConf(conf), database, table, modelId, args);
-
- LOG.info("Done training model: " + modelId);
-
- model.setCreatedAt(new Date());
- model.setAlgoName(algorithm);
-
- Path modelLocation = null;
- try {
- modelLocation = persistModel(model);
- LOG.info("Model saved: " + modelId + ", algo: " + algorithm + ", path: " + modelLocation);
- return model.getId();
- } catch (IOException e) {
- throw new LensException("Error saving model " + modelId + " for algo " + algorithm, e);
- }
- }
-
- /**
- * Gets the directory for the given algorithm under the configured model base directory.
- *
- * @param algoName the algo name
- * @return the algo dir
- * @throws IOException Signals that an I/O exception has occurred.
- */
- private Path getAlgoDir(String algoName) throws IOException {
- String modelSaveBaseDir = conf.get(ModelLoader.MODEL_PATH_BASE_DIR, ModelLoader.MODEL_PATH_BASE_DIR_DEFAULT);
- return new Path(new Path(modelSaveBaseDir), algoName);
- }
-
- /**
- * Serializes the model to a new file named after the model ID inside the algorithm's directory.
- *
- * @param model the model
- * @return the path
- * @throws IOException Signals that an I/O exception has occurred.
- */
- private Path persistModel(MLModel model) throws IOException {
- // Get model save path
- Path algoDir = getAlgoDir(model.getAlgoName());
- FileSystem fs = algoDir.getFileSystem(conf);
-
- if (!fs.exists(algoDir)) {
- fs.mkdirs(algoDir);
- }
-
- Path modelSavePath = new Path(algoDir, model.getId());
- ObjectOutputStream outputStream = null;
-
- try {
- outputStream = new ObjectOutputStream(fs.create(modelSavePath, false));
- outputStream.writeObject(model);
- outputStream.flush();
- } catch (IOException io) {
- LOG.error("Error saving model " + model.getId() + " reason: " + io.getMessage());
- throw io;
- } finally {
- IOUtils.closeQuietly(outputStream);
- }
- return modelSavePath;
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.LensML#getModels(java.lang.String)
- */
- public List<String> getModels(String algorithm) throws LensException {
- try {
- Path algoDir = getAlgoDir(algorithm);
- FileSystem fs = algoDir.getFileSystem(conf);
- if (!fs.exists(algoDir)) {
- return null;
- }
-
- List<String> models = new ArrayList<String>();
-
- for (FileStatus stat : fs.listStatus(algoDir)) {
- models.add(stat.getPath().getName());
- }
-
- if (models.isEmpty()) {
- return null;
- }
-
- return models;
- } catch (IOException ioex) {
- throw new LensException(ioex);
- }
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.LensML#getModel(java.lang.String, java.lang.String)
- */
- public MLModel getModel(String algorithm, String modelId) throws LensException {
- try {
- return ModelLoader.loadModel(conf, algorithm, modelId);
- } catch (IOException e) {
- throw new LensException(e);
- }
- }
-
- /**
- * Initializes the ML service by loading and initializing the driver classes listed under "lens.ml.drivers".
- *
- * @param hiveConf the hive conf
- */
- public synchronized void init(HiveConf hiveConf) {
- this.conf = hiveConf;
-
- // Get all the drivers
- String[] driverClasses = hiveConf.getStrings("lens.ml.drivers");
-
- if (driverClasses == null || driverClasses.length == 0) {
- throw new RuntimeException("No ML Drivers specified in conf");
- }
-
- LOG.info("Loading drivers " + Arrays.toString(driverClasses));
- drivers = new ArrayList<MLDriver>(driverClasses.length);
-
- for (String driverClass : driverClasses) {
- Class<?> cls;
- try {
- cls = Class.forName(driverClass);
- } catch (ClassNotFoundException e) {
- LOG.error("Driver class not found " + driverClass);
- continue;
- }
-
- if (!MLDriver.class.isAssignableFrom(cls)) {
- LOG.warn("Not a driver class " + driverClass);
- continue;
- }
-
- try {
- Class<? extends MLDriver> mlDriverClass = (Class<? extends MLDriver>) cls;
- MLDriver driver = mlDriverClass.newInstance();
- driver.init(toLensConf(conf));
- drivers.add(driver);
- LOG.info("Added driver " + driverClass);
- } catch (Exception e) {
- LOG.error("Failed to create driver " + driverClass + " reason: " + e.getMessage(), e);
- }
- }
- if (drivers.isEmpty()) {
- throw new RuntimeException("No ML drivers loaded");
- }
-
- LOG.info("Inited ML service");
- }
-
- /**
- * Starts all loaded drivers and schedules the periodic cleanup of UDF registration statuses.
- */
- public synchronized void start() {
- for (MLDriver driver : drivers) {
- try {
- if (driver instanceof SparkMLDriver && sparkContext != null) {
- ((SparkMLDriver) driver).useSparkContext(sparkContext);
- }
- driver.start();
- } catch (LensException e) {
- LOG.error("Failed to start driver " + driver, e);
- }
- }
-
- udfStatusExpirySvc = Executors.newSingleThreadScheduledExecutor();
- udfStatusExpirySvc.scheduleAtFixedRate(new UDFStatusExpiryRunnable(), 60, 60, TimeUnit.SECONDS);
-
- LOG.info("Started ML service");
- }
-
- /**
- * Stops all drivers, clears the driver list and shuts down the UDF status cleanup thread.
- */
- public synchronized void stop() {
- for (MLDriver driver : drivers) {
- try {
- driver.stop();
- } catch (LensException e) {
- LOG.error("Failed to stop driver " + driver, e);
- }
- }
- drivers.clear();
- udfStatusExpirySvc.shutdownNow();
- LOG.info("Stopped ML service");
- }
-
- public synchronized HiveConf getHiveConf() {
- return conf;
- }
-
- /**
- * Clears the cache kept by ModelLoader.
- */
- public void clearModels() {
- ModelLoader.clearCache();
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.LensML#getModelPath(java.lang.String, java.lang.String)
- */
- public String getModelPath(String algorithm, String modelID) {
- return ModelLoader.getModelLocation(conf, algorithm, modelID).toString();
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.LensML#testModel(org.apache.lens.api.LensSessionHandle, java.lang.String, java.lang.String,
- * java.lang.String, java.lang.String)
- */
- @Override
- public MLTestReport testModel(LensSessionHandle session, String table, String algorithm, String modelID,
- String outputTable) throws LensException {
- return null;
- }
-
- /**
- * Test a model by running the test queries against a remote Lens server.
- *
- * @param sessionHandle the session handle
- * @param table the table
- * @param algorithm the algorithm
- * @param modelID the model id
- * @param queryApiUrl the query api url
- * @return the ML test report
- * @throws LensException the lens exception
- */
- public MLTestReport testModelRemote(LensSessionHandle sessionHandle, String table, String algorithm, String modelID,
- String queryApiUrl, String outputTable) throws LensException {
- return testModel(sessionHandle, table, algorithm, modelID, new RemoteQueryRunner(sessionHandle, queryApiUrl),
- outputTable);
- }
-
- /**
- * Evaluate a model. Evaluation is done on data selected from an input table. The model is run as a UDF and its
- * output is inserted into a table with a partition. Each evaluation is given a unique ID. The partition label is
- * associated with this unique ID.
- * <p/>
- * <p>
- * This call also requires a query runner. Query runner is responsible for executing the evaluation query against Lens
- * server.
- * </p>
- *
- * @param sessionHandle the session handle
- * @param table the table
- * @param algorithm the algorithm
- * @param modelID the model id
- * @param queryRunner the query runner
- * @param outputTable table where test output will be written
- * @return the ML test report
- * @throws LensException the lens exception
- */
- public MLTestReport testModel(final LensSessionHandle sessionHandle, String table, String algorithm, String modelID,
- QueryRunner queryRunner, String outputTable) throws LensException {
- if (sessionHandle == null) {
- throw new NullPointerException("Null session not allowed");
- }
- // check if algorithm exists
- if (!getAlgorithms().contains(algorithm)) {
- throw new LensException("No such algorithm " + algorithm);
- }
-
- MLModel<?> model;
- try {
- model = ModelLoader.loadModel(conf, algorithm, modelID);
- } catch (IOException e) {
- throw new LensException(e);
- }
-
- if (model == null) {
- throw new LensException("Model not found: " + modelID + " algorithm=" + algorithm);
- }
-
- String database = null;
-
- if (SessionState.get() != null) {
- database = SessionState.get().getCurrentDatabase();
- }
-
- String testID = UUID.randomUUID().toString().replace("-", "_");
- final String testTable = outputTable;
- final String testResultColumn = "prediction_result";
-
- // TODO support error metric UDAFs
- TableTestingSpec spec = TableTestingSpec.newBuilder().hiveConf(conf)
- .database(database == null ? "default" : database).inputTable(table).featureColumns(model.getFeatureColumns())
- .outputColumn(testResultColumn).lableColumn(model.getLabelColumn()).algorithm(algorithm).modelID(modelID)
- .outputTable(testTable).testID(testID).build();
-
- String testQuery = spec.getTestQuery();
- if (testQuery == null) {
- throw new LensException("Invalid test spec. " + "table=" + table + " algorithm=" + algorithm + " modelID="
- + modelID);
- }
-
- if (!spec.isOutputTableExists()) {
- LOG.info("Output table '" + testTable + "' does not exist for test algorithm = " + algorithm + " modelid="
- + modelID + ", Creating table using query: " + spec.getCreateOutputTableQuery());
- // create the output table
- String createOutputTableQuery = spec.getCreateOutputTableQuery();
- queryRunner.runQuery(createOutputTableQuery);
- LOG.info("Table created " + testTable);
- }
-
- // Check if ML UDF is registered in this session
- registerPredictUdf(sessionHandle, queryRunner);
-
- LOG.info("Running evaluation query " + testQuery);
- queryRunner.setQueryName("model_test_" + modelID);
- QueryHandle testQueryHandle = queryRunner.runQuery(testQuery);
-
- MLTestReport testReport = new MLTestReport();
- testReport.setReportID(testID);
- testReport.setAlgorithm(algorithm);
- testReport.setFeatureColumns(model.getFeatureColumns());
- testReport.setLabelColumn(model.getLabelColumn());
- testReport.setModelID(model.getId());
- testReport.setOutputColumn(testResultColumn);
- testReport.setOutputTable(testTable);
- testReport.setTestTable(table);
- testReport.setQueryID(testQueryHandle.toString());
-
- // Save test report
- persistTestReport(testReport);
- LOG.info("Saved test report " + testReport.getReportID());
- return testReport;
- }
-
- /**
- * Persist test report. An I/O failure while saving is only logged; it is not propagated to the caller.
- *
- * @param testReport the test report
- * @throws LensException the lens exception
- */
- private void persistTestReport(MLTestReport testReport) throws LensException {
- LOG.info("saving test report " + testReport.getReportID());
- try {
- ModelLoader.saveTestReport(conf, testReport);
- LOG.info("Saved report " + testReport.getReportID());
- } catch (IOException e) {
- LOG.error("Error saving report " + testReport.getReportID() + " reason: " + e.getMessage());
- }
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.LensML#getTestReports(java.lang.String)
- */
- public List<String> getTestReports(String algorithm) throws LensException {
- Path reportBaseDir = new Path(conf.get(ModelLoader.TEST_REPORT_BASE_DIR, ModelLoader.TEST_REPORT_BASE_DIR_DEFAULT));
- FileSystem fs = null;
-
- try {
- fs = reportBaseDir.getFileSystem(conf);
- if (!fs.exists(reportBaseDir)) {
- return null;
- }
-
- Path algoDir = new Path(reportBaseDir, algorithm);
- if (!fs.exists(algoDir)) {
- return null;
- }
-
- List<String> reports = new ArrayList<String>();
- for (FileStatus stat : fs.listStatus(algoDir)) {
- reports.add(stat.getPath().getName());
- }
- return reports;
- } catch (IOException e) {
- LOG.error("Error reading report list for " + algorithm, e);
- return null;
- }
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.LensML#getTestReport(java.lang.String, java.lang.String)
- */
- public MLTestReport getTestReport(String algorithm, String reportID) throws LensException {
- try {
- return ModelLoader.loadReport(conf, algorithm, reportID);
- } catch (IOException e) {
- throw new LensException(e);
- }
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.LensML#predict(java.lang.String, java.lang.String, java.lang.Object[])
- */
- public Object predict(String algorithm, String modelID, Object[] features) throws LensException {
- // Load the model instance
- MLModel<?> model = getModel(algorithm, modelID);
- return model.predict(features);
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.LensML#deleteModel(java.lang.String, java.lang.String)
- */
- public void deleteModel(String algorithm, String modelID) throws LensException {
- try {
- ModelLoader.deleteModel(conf, algorithm, modelID);
- LOG.info("DELETED model " + modelID + " algorithm=" + algorithm);
- } catch (IOException e) {
- LOG.error(
- "Error deleting model file. algorithm=" + algorithm + " model=" + modelID + " reason: " + e.getMessage(), e);
- throw new LensException("Unable to delete model " + modelID + " for algorithm " + algorithm, e);
- }
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.LensML#deleteTestReport(java.lang.String, java.lang.String)
- */
- public void deleteTestReport(String algorithm, String reportID) throws LensException {
- try {
- ModelLoader.deleteTestReport(conf, algorithm, reportID);
- LOG.info("DELETED report=" + reportID + " algorithm=" + algorithm);
- } catch (IOException e) {
- LOG.error("Error deleting report " + reportID + " algorithm=" + algorithm + " reason: " + e.getMessage(), e);
- throw new LensException("Unable to delete report " + reportID + " for algorithm " + algorithm, e);
- }
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.LensML#getAlgoParamDescription(java.lang.String)
- */
- public Map<String, String> getAlgoParamDescription(String algorithm) {
- MLAlgo algo = null;
- try {
- algo = getAlgoForName(algorithm);
- } catch (LensException e) {
- LOG.error("Error getting algo description : " + algorithm, e);
- return null;
- }
- if (algo instanceof BaseSparkAlgo) {
- return ((BaseSparkAlgo) algo).getArgUsage();
- }
- return null;
- }
-
- /**
- * Submit model test query to a remote Lens server.
- */
- class RemoteQueryRunner extends QueryRunner {
-
- /** The query api url. */
- final String queryApiUrl;
-
- /**
- * Instantiates a new remote query runner.
- *
- * @param sessionHandle the session handle
- * @param queryApiUrl the query api url
- */
- public RemoteQueryRunner(LensSessionHandle sessionHandle, String queryApiUrl) {
- super(sessionHandle);
- this.queryApiUrl = queryApiUrl;
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.QueryRunner#runQuery(java.lang.String)
- */
- @Override
- public QueryHandle runQuery(String query) throws LensException {
- // Create jersey client for query endpoint
- Client client = ClientBuilder.newBuilder().register(MultiPartFeature.class).build();
- WebTarget target = client.target(queryApiUrl);
- final FormDataMultiPart mp = new FormDataMultiPart();
- mp.bodyPart(new FormDataBodyPart(FormDataContentDisposition.name("sessionid").build(), sessionHandle,
- MediaType.APPLICATION_XML_TYPE));
- mp.bodyPart(new FormDataBodyPart(FormDataContentDisposition.name("query").build(), query));
- mp.bodyPart(new FormDataBodyPart(FormDataContentDisposition.name("operation").build(), "execute"));
-
- LensConf lensConf = new LensConf();
- lensConf.addProperty(LensConfConstants.QUERY_PERSISTENT_RESULT_SET, false + "");
- lensConf.addProperty(LensConfConstants.QUERY_PERSISTENT_RESULT_INDRIVER, false + "");
- mp.bodyPart(new FormDataBodyPart(FormDataContentDisposition.name("conf").fileName("conf").build(), lensConf,
- MediaType.APPLICATION_XML_TYPE));
-
- final QueryHandle handle = target.request().post(Entity.entity(mp, MediaType.MULTIPART_FORM_DATA_TYPE),
- QueryHandle.class);
-
- LensQuery ctx = target.path(handle.toString()).queryParam("sessionid", sessionHandle).request()
- .get(LensQuery.class);
-
- QueryStatus stat = ctx.getStatus();
- while (!stat.isFinished()) {
- ctx = target.path(handle.toString()).queryParam("sessionid", sessionHandle).request().get(LensQuery.class);
- stat = ctx.getStatus();
- try {
- Thread.sleep(500);
- } catch (InterruptedException e) {
- throw new LensException(e);
- }
- }
-
- if (stat.getStatus() != QueryStatus.Status.SUCCESSFUL) {
- throw new LensException("Query failed " + ctx.getQueryHandle().getHandleId() + " reason:"
- + stat.getErrorMessage());
- }
-
- return ctx.getQueryHandle();
- }
- }
-
- /**
- * Copies every property of the given HiveConf into a new LensConf.
- *
- * @param conf the conf
- * @return the lens conf
- */
- private LensConf toLensConf(HiveConf conf) {
- LensConf lensConf = new LensConf();
- lensConf.getProperties().putAll(conf.getValByRegex(".*"));
- return lensConf;
- }
-
- protected void registerPredictUdf(LensSessionHandle sessionHandle, QueryRunner queryRunner) throws LensException {
- if (isUdfRegisterd(sessionHandle)) {
- // Already registered, nothing to do
- return;
- }
-
- LOG.info("Registering UDF for session " + sessionHandle.getPublicId().toString());
- // We have to add UDF jars to the session
- try {
- SessionService sessionService = (SessionService) MLUtils.getServiceProvider().getService(SessionService.NAME);
- String[] udfJars = conf.getStrings("lens.server.ml.predict.udf.jars");
- if (udfJars != null) {
- for (String jar : udfJars) {
- sessionService.addResource(sessionHandle, "jar", jar);
- LOG.info(jar + " added UDF session " + sessionHandle.getPublicId().toString());
- }
- }
- } catch (Exception e) {
- throw new LensException(e);
- }
-
- String regUdfQuery = "CREATE TEMPORARY FUNCTION " + HiveMLUDF.UDF_NAME + " AS '" + HiveMLUDF.class
- .getCanonicalName() + "'";
- queryRunner.setQueryName("register_predict_udf_" + sessionHandle.getPublicId().toString());
- QueryHandle udfQuery = queryRunner.runQuery(regUdfQuery);
- predictUdfStatus.put(sessionHandle, true);
- LOG.info("Predict UDF registered for session " + sessionHandle.getPublicId().toString());
- }
-
- protected boolean isUdfRegisterd(LensSessionHandle sessionHandle) {
- return predictUdfStatus.containsKey(sessionHandle);
- }
-
- /**
- * Periodically check if sessions have been closed, and clear UDF registered status.
- */
- private class UDFStatusExpiryRunnable implements Runnable {
- public void run() {
- try {
- SessionService sessionService = (SessionService) MLUtils.getServiceProvider().getService(SessionService.NAME);
- // Clear status of sessions which are closed.
- List<LensSessionHandle> sessions = new ArrayList<LensSessionHandle>(predictUdfStatus.keySet());
- for (LensSessionHandle sessionHandle : sessions) {
- if (!sessionService.isOpen(sessionHandle)) {
- LOG.info("Session closed, removing UDF status: " + sessionHandle);
- predictUdfStatus.remove(sessionHandle);
- }
- }
- } catch (Exception exc) {
- LOG.warn("Error clearing UDF statuses", exc);
- }
- }
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/MLAlgo.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/MLAlgo.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/MLAlgo.java
deleted file mode 100644
index 7dccf2c..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/MLAlgo.java
+++ /dev/null
@@ -1,53 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml;
-
-import org.apache.lens.api.LensConf;
-import org.apache.lens.api.LensException;
-
-/**
- * The Interface MLAlgo.
- */
-public interface MLAlgo {
- String getName();
-
- String getDescription();
-
- /**
- * Configure.
- *
- * @param configuration the configuration
- */
- void configure(LensConf configuration);
-
- LensConf getConf();
-
- /**
- * Train.
- *
- * @param conf the conf
- * @param db the db
- * @param table the table
- * @param modelId the model id
- * @param params the params
- * @return the ML model
- * @throws LensException the lens exception
- */
- MLModel train(LensConf conf, String db, String table, String modelId, String... params) throws LensException;
-}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/MLDriver.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/MLDriver.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/MLDriver.java
deleted file mode 100644
index 567e717..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/MLDriver.java
+++ /dev/null
@@ -1,71 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml;
-
-import java.util.List;
-
-import org.apache.lens.api.LensConf;
-import org.apache.lens.api.LensException;
-
-/**
- * Contract for an ML driver: lists and instantiates the algos it supports and exposes init/start/stop.
- */
-public interface MLDriver {
-
- /**
- * Checks whether the given algo is supported.
- *
- * @param algo the algo
- * @return true if the algo is supported
- */
- boolean isAlgoSupported(String algo);
-
- /**
- * Gets the algo instance.
- *
- * @param algo the algo
- * @return the algo instance
- * @throws LensException the lens exception
- */
- MLAlgo getAlgoInstance(String algo) throws LensException;
-
- /**
- * Initializes the driver.
- *
- * @param conf the conf
- * @throws LensException the lens exception
- */
- void init(LensConf conf) throws LensException;
-
- /**
- * Start.
- *
- * @throws LensException the lens exception
- */
- void start() throws LensException;
-
- /**
- * Stop.
- *
- * @throws LensException the lens exception
- */
- void stop() throws LensException;
-
- List<String> getAlgoNames();
-}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/MLModel.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/MLModel.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/MLModel.java
deleted file mode 100644
index c177757..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/MLModel.java
+++ /dev/null
@@ -1,79 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml;
-
-import java.io.Serializable;
-import java.util.Date;
-import java.util.List;
-
-import lombok.Getter;
-import lombok.NoArgsConstructor;
-import lombok.Setter;
-import lombok.ToString;
-
-/**
- * Abstract base class for ML models: holds model metadata and declares {@link #predict(Object...)}.
- */
-@NoArgsConstructor
-@ToString
-public abstract class MLModel<PREDICTION> implements Serializable {
-
- /** The id. */
- @Getter
- @Setter
- private String id;
-
- /** The created at. */
- @Getter
- @Setter
- private Date createdAt;
-
- /** The algo name. */
- @Getter
- @Setter
- private String algoName;
-
- /** The table. */
- @Getter
- @Setter
- private String table;
-
- /** The params. */
- @Getter
- @Setter
- private List<String> params;
-
- /** The label column. */
- @Getter
- @Setter
- private String labelColumn;
-
- /** The feature columns. */
- @Getter
- @Setter
- private List<String> featureColumns;
-
- /**
- * Predict.
- *
- * @param args the args
- * @return the prediction
- */
- public abstract PREDICTION predict(Object... args);
-}