You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lens.apache.org by sh...@apache.org on 2015/04/05 09:11:04 UTC
[3/6] incubator-lens git commit: Lens-465 : Refactor ml packages.
(sharad)
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/LensMLImpl.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/LensMLImpl.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/LensMLImpl.java
new file mode 100644
index 0000000..f0c6e04
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/LensMLImpl.java
@@ -0,0 +1,744 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.impl;
+
+import java.io.IOException;
+import java.io.ObjectOutputStream;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Date;
+import java.util.List;
+import java.util.Map;
+import java.util.UUID;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.TimeUnit;
+
+import javax.ws.rs.client.Client;
+import javax.ws.rs.client.ClientBuilder;
+import javax.ws.rs.client.Entity;
+import javax.ws.rs.client.WebTarget;
+import javax.ws.rs.core.MediaType;
+
+import org.apache.lens.api.LensConf;
+import org.apache.lens.api.LensException;
+import org.apache.lens.api.LensSessionHandle;
+import org.apache.lens.api.query.LensQuery;
+import org.apache.lens.api.query.QueryHandle;
+import org.apache.lens.api.query.QueryStatus;
+import org.apache.lens.ml.algo.api.MLAlgo;
+import org.apache.lens.ml.algo.api.MLDriver;
+import org.apache.lens.ml.algo.api.MLModel;
+import org.apache.lens.ml.algo.spark.BaseSparkAlgo;
+import org.apache.lens.ml.algo.spark.SparkMLDriver;
+import org.apache.lens.ml.api.LensML;
+import org.apache.lens.ml.api.MLTestReport;
+import org.apache.lens.server.api.LensConfConstants;
+import org.apache.lens.server.api.session.SessionService;
+
+import org.apache.commons.io.IOUtils;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.fs.FileStatus;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.hive.conf.HiveConf;
+import org.apache.hadoop.hive.ql.session.SessionState;
+import org.apache.spark.api.java.JavaSparkContext;
+
+import org.glassfish.jersey.media.multipart.FormDataBodyPart;
+import org.glassfish.jersey.media.multipart.FormDataContentDisposition;
+import org.glassfish.jersey.media.multipart.FormDataMultiPart;
+import org.glassfish.jersey.media.multipart.MultiPartFeature;
+
+/**
+ * The Class LensMLImpl.
+ */
+public class LensMLImpl implements LensML {
+
+ /** Logger for this class. */
+ public static final Log LOG = LogFactory.getLog(LensMLImpl.class);
+
+ /** The ML drivers; the list is created and filled by {@link #init(HiveConf)}. */
+ protected List<MLDriver> drivers;
+
+ /** The conf; set by the constructor and replaced by {@link #init(HiveConf)}. */
+ private HiveConf conf;
+
+ /** Optional Spark context; when non-null it is handed to Spark drivers in {@link #start()}. */
+ private JavaSparkContext sparkContext;
+
+ /** Sessions for which the predict UDF was registered; presence of the key is what marks a session as registered. */
+ private final Map<LensSessionHandle, Boolean> predictUdfStatus;
+ /** Scheduler created in {@link #start()} that periodically drops closed sessions from the UDF status map. */
+ private ScheduledExecutorService udfStatusExpirySvc;
+
+ /**
+ * Instantiates a new lens ml impl. Only stores the conf and creates the empty UDF status map; drivers are loaded
+ * later by {@link #init(HiveConf)}.
+ *
+ * @param conf the conf
+ */
+ public LensMLImpl(HiveConf conf) {
+ this.conf = conf;
+ this.predictUdfStatus = new ConcurrentHashMap<LensSessionHandle, Boolean>();
+ }
+
+ /**
+ * Gets the conf.
+ *
+ * @return the conf currently held by this instance
+ */
+ public HiveConf getConf() {
+ return conf;
+ }
+
+ /**
+ * Use an existing Spark context. If one is set, {@link #start()} passes it to every {@code SparkMLDriver} through
+ * {@code useSparkContext} before starting the driver.
+ *
+ * @param jsc JavaSparkContext instance
+ */
+ public void setSparkContext(JavaSparkContext jsc) {
+ this.sparkContext = jsc;
+ }
+
+ /**
+ * Lists the algorithm names reported by all loaded drivers.
+ *
+ * @return a new list with the names from each driver, in driver order
+ */
+ public List<String> getAlgorithms() {
+ List<String> algos = new ArrayList<String>();
+ for (MLDriver driver : drivers) {
+ algos.addAll(driver.getAlgoNames());
+ }
+ return algos;
+ }
+
+  /**
+   * Looks up an algorithm instance by name. Drivers are asked in order and the first one that supports the
+   * algorithm supplies the instance.
+   *
+   * @param algorithm the algorithm name
+   * @return the algorithm instance from the first supporting driver
+   * @throws LensException if no loaded driver supports the algorithm
+   */
+  public MLAlgo getAlgoForName(String algorithm) throws LensException {
+    for (MLDriver candidate : drivers) {
+      if (!candidate.isAlgoSupported(algorithm)) {
+        continue;
+      }
+      return candidate.getAlgoInstance(algorithm);
+    }
+    throw new LensException("Algo not supported " + algorithm);
+  }
+
+  /**
+   * Trains a model with the given algorithm on the given table and persists it.
+   *
+   * @param table     the training table
+   * @param algorithm the algorithm name
+   * @param args      the training parameters passed through to the algorithm
+   * @return the id of the trained model
+   * @throws LensException if the algorithm is not supported or the model cannot be saved
+   */
+  public String train(String table, String algorithm, String[] args) throws LensException {
+    final MLAlgo trainer = getAlgoForName(algorithm);
+    final String modelId = UUID.randomUUID().toString();
+
+    LOG.info("Begin training model " + modelId + ", algo=" + algorithm + ", table=" + table + ", params="
+      + Arrays.toString(args));
+
+    // Use the session's current database, or "default" when no Hive session state is attached
+    final SessionState sessionState = SessionState.get();
+    final String database = (sessionState == null) ? "default" : sessionState.getCurrentDatabase();
+
+    final MLModel model = trainer.train(toLensConf(conf), database, table, modelId, args);
+    LOG.info("Done training model: " + modelId);
+
+    model.setCreatedAt(new Date());
+    model.setAlgoName(algorithm);
+
+    final Path savedAt;
+    try {
+      savedAt = persistModel(model);
+    } catch (IOException e) {
+      throw new LensException("Error saving model " + modelId + " for algo " + algorithm, e);
+    }
+    LOG.info("Model saved: " + modelId + ", algo: " + algorithm + ", path: " + savedAt);
+    return model.getId();
+  }
+
+  /**
+   * Resolves the directory that holds the models of one algorithm: the configured model base directory with the
+   * algorithm name appended.
+   *
+   * @param algoName the algo name
+   * @return the algo dir
+   * @throws IOException Signals that an I/O exception has occurred.
+   */
+  private Path getAlgoDir(String algoName) throws IOException {
+    Path modelBaseDir = new Path(conf.get(ModelLoader.MODEL_PATH_BASE_DIR, ModelLoader.MODEL_PATH_BASE_DIR_DEFAULT));
+    return new Path(modelBaseDir, algoName);
+  }
+
+  /**
+   * Serializes the model to a file named after the model id inside the algorithm's directory. The directory is
+   * created when missing; an existing file is never overwritten.
+   *
+   * @param model the model
+   * @return the path the model was written to
+   * @throws IOException if the model cannot be written, or the stream cannot be closed cleanly
+   */
+  private Path persistModel(MLModel model) throws IOException {
+    // Get model save path
+    Path algoDir = getAlgoDir(model.getAlgoName());
+    FileSystem fs = algoDir.getFileSystem(conf);
+
+    if (!fs.exists(algoDir)) {
+      fs.mkdirs(algoDir);
+    }
+
+    Path modelSavePath = new Path(algoDir, model.getId());
+    // Tracked separately so it is released even if the ObjectOutputStream constructor fails
+    java.io.OutputStream rawStream = null;
+    ObjectOutputStream outputStream = null;
+
+    try {
+      rawStream = fs.create(modelSavePath, false);
+      outputStream = new ObjectOutputStream(rawStream);
+      outputStream.writeObject(model);
+      outputStream.flush();
+      // Close on the success path so a failed close is reported instead of being silently dropped
+      outputStream.close();
+    } catch (IOException io) {
+      LOG.error("Error saving model " + model.getId() + " reason: " + io.getMessage(), io);
+      throw io;
+    } finally {
+      // No-ops after a successful close; release the streams quietly on the failure path
+      IOUtils.closeQuietly(outputStream);
+      IOUtils.closeQuietly(rawStream);
+    }
+    return modelSavePath;
+  }
+
+  /**
+   * Lists the ids of the models saved for an algorithm, taken from the file names in the algorithm's directory.
+   *
+   * @param algorithm the algorithm name
+   * @return the model ids, or null when the directory does not exist or contains nothing
+   * @throws LensException if the file system cannot be read
+   */
+  public List<String> getModels(String algorithm) throws LensException {
+    final List<String> modelIds = new ArrayList<String>();
+    try {
+      Path algoDir = getAlgoDir(algorithm);
+      FileSystem fs = algoDir.getFileSystem(conf);
+      if (!fs.exists(algoDir)) {
+        return null;
+      }
+      for (FileStatus entry : fs.listStatus(algoDir)) {
+        modelIds.add(entry.getPath().getName());
+      }
+    } catch (IOException ioex) {
+      throw new LensException(ioex);
+    }
+    // Callers get null rather than an empty list
+    return modelIds.isEmpty() ? null : modelIds;
+  }
+
+ /**
+ * Loads a model through {@code ModelLoader.loadModel}.
+ *
+ * @param algorithm the algorithm name
+ * @param modelId the model id
+ * @return whatever the loader returns; may be null (testModel checks the loader result for null)
+ * @throws LensException wrapping any IOException raised by the loader
+ * @see org.apache.lens.ml.api.LensML#getModel(java.lang.String, java.lang.String)
+ */
+ public MLModel getModel(String algorithm, String modelId) throws LensException {
+ try {
+ return ModelLoader.loadModel(conf, algorithm, modelId);
+ } catch (IOException e) {
+ throw new LensException(e);
+ }
+ }
+
+ /**
+ * Inits the.
+ *
+ * @param hiveConf the hive conf
+ */
+ public synchronized void init(HiveConf hiveConf) {
+ this.conf = hiveConf;
+
+ // Get all the drivers
+ String[] driverClasses = hiveConf.getStrings("lens.ml.drivers");
+
+ if (driverClasses == null || driverClasses.length == 0) {
+ throw new RuntimeException("No ML Drivers specified in conf");
+ }
+
+ LOG.info("Loading drivers " + Arrays.toString(driverClasses));
+ drivers = new ArrayList<MLDriver>(driverClasses.length);
+
+ for (String driverClass : driverClasses) {
+ Class<?> cls;
+ try {
+ cls = Class.forName(driverClass);
+ } catch (ClassNotFoundException e) {
+ LOG.error("Driver class not found " + driverClass);
+ continue;
+ }
+
+ if (!MLDriver.class.isAssignableFrom(cls)) {
+ LOG.warn("Not a driver class " + driverClass);
+ continue;
+ }
+
+ try {
+ Class<? extends MLDriver> mlDriverClass = (Class<? extends MLDriver>) cls;
+ MLDriver driver = mlDriverClass.newInstance();
+ driver.init(toLensConf(conf));
+ drivers.add(driver);
+ LOG.info("Added driver " + driverClass);
+ } catch (Exception e) {
+ LOG.error("Failed to create driver " + driverClass + " reason: " + e.getMessage(), e);
+ }
+ }
+ if (drivers.isEmpty()) {
+ throw new RuntimeException("No ML drivers loaded");
+ }
+
+ LOG.info("Inited ML service");
+ }
+
+ /**
+ * Start.
+ */
+ public synchronized void start() {
+ for (MLDriver driver : drivers) {
+ try {
+ if (driver instanceof SparkMLDriver && sparkContext != null) {
+ ((SparkMLDriver) driver).useSparkContext(sparkContext);
+ }
+ driver.start();
+ } catch (LensException e) {
+ LOG.error("Failed to start driver " + driver, e);
+ }
+ }
+
+ udfStatusExpirySvc = Executors.newSingleThreadScheduledExecutor();
+ udfStatusExpirySvc.scheduleAtFixedRate(new UDFStatusExpiryRunnable(), 60, 60, TimeUnit.SECONDS);
+
+ LOG.info("Started ML service");
+ }
+
+  /**
+   * Stops all drivers, forgets them, and shuts down the UDF status cleanup scheduler. Safe to call when the service
+   * was never initialized or started. A driver that fails to stop is logged and does not block the others.
+   */
+  public synchronized void stop() {
+    // drivers is only assigned by init(); guard so stop() before init() does not throw
+    if (drivers != null) {
+      for (MLDriver driver : drivers) {
+        try {
+          driver.stop();
+        } catch (LensException e) {
+          LOG.error("Failed to stop driver " + driver, e);
+        }
+      }
+      drivers.clear();
+    }
+    // The scheduler only exists after start(); guard so stop() before start() does not throw
+    if (udfStatusExpirySvc != null) {
+      udfStatusExpirySvc.shutdownNow();
+      udfStatusExpirySvc = null;
+    }
+    LOG.info("Stopped ML service");
+  }
+
+ /**
+ * Gets the hive conf. Returns the same field as {@link #getConf()}, but under this instance's monitor.
+ *
+ * @return the conf currently held by this instance
+ */
+ public synchronized HiveConf getHiveConf() {
+ return conf;
+ }
+
+ /**
+ * Clear models. Delegates to {@code ModelLoader.clearCache()}.
+ */
+ public void clearModels() {
+ ModelLoader.clearCache();
+ }
+
+ /**
+ * Gets the location of a model as a string, as computed by {@code ModelLoader.getModelLocation}.
+ *
+ * @param algorithm the algorithm name
+ * @param modelID the model id
+ * @return the string form of the model location
+ * @see org.apache.lens.ml.api.LensML#getModelPath(java.lang.String, java.lang.String)
+ */
+ public String getModelPath(String algorithm, String modelID) {
+ return ModelLoader.getModelLocation(conf, algorithm, modelID).toString();
+ }
+
+ /**
+ * Not implemented in this class: this overload takes no query runner and always returns null. Use
+ * {@code testModelRemote} or the overload that accepts a {@code QueryRunner} instead.
+ *
+ * @param session the session handle (unused)
+ * @param table the table (unused)
+ * @param algorithm the algorithm (unused)
+ * @param modelID the model id (unused)
+ * @param outputTable the output table (unused)
+ * @return always null
+ * @throws LensException never thrown by this implementation
+ */
+ @Override
+ public MLTestReport testModel(LensSessionHandle session, String table, String algorithm, String modelID,
+ String outputTable) throws LensException {
+ return null;
+ }
+
+ /**
+ * Test a model by running the test queries through a {@link RemoteQueryRunner} that talks to the given query API
+ * URL.
+ *
+ * @param sessionHandle the session handle
+ * @param table the table
+ * @param algorithm the algorithm
+ * @param modelID the model id
+ * @param queryApiUrl the query api url
+ * @param outputTable table where test output will be written
+ * @return the ML test report
+ * @throws LensException the lens exception
+ */
+ public MLTestReport testModelRemote(LensSessionHandle sessionHandle, String table, String algorithm, String modelID,
+ String queryApiUrl, String outputTable) throws LensException {
+ return testModel(sessionHandle, table, algorithm, modelID, new RemoteQueryRunner(sessionHandle, queryApiUrl),
+ outputTable);
+ }
+
+ /**
+ * Evaluate a model. Evaluation is done on data selected from an input table. The model is run as a UDF and its
+ * output is inserted into a table with a partition. Each evaluation is given a unique ID. The partition label is
+ * associated with this unique ID.
+ * <p>
+ * This call also requires a query runner. The query runner is responsible for executing the evaluation query
+ * against the Lens server.
+ * </p>
+ * <p>
+ * A failure to save the test report is only logged; the report is still returned.
+ * </p>
+ *
+ * @param sessionHandle the session handle
+ * @param table the table
+ * @param algorithm the algorithm
+ * @param modelID the model id
+ * @param queryRunner the query runner
+ * @param outputTable table where test output will be written
+ * @return the ML test report
+ * @throws LensException if the algorithm or model is unknown, the test spec yields no query, or a query fails
+ * @throws NullPointerException if sessionHandle is null
+ */
+ public MLTestReport testModel(final LensSessionHandle sessionHandle, String table, String algorithm, String modelID,
+ QueryRunner queryRunner, String outputTable) throws LensException {
+ if (sessionHandle == null) {
+ throw new NullPointerException("Null session not allowed");
+ }
+ // check if algorithm exists
+ if (!getAlgorithms().contains(algorithm)) {
+ throw new LensException("No such algorithm " + algorithm);
+ }
+
+ MLModel<?> model;
+ try {
+ model = ModelLoader.loadModel(conf, algorithm, modelID);
+ } catch (IOException e) {
+ throw new LensException(e);
+ }
+
+ if (model == null) {
+ throw new LensException("Model not found: " + modelID + " algorithm=" + algorithm);
+ }
+
+ // Stays null without a Hive session state; the spec builder below then falls back to "default"
+ String database = null;
+
+ if (SessionState.get() != null) {
+ database = SessionState.get().getCurrentDatabase();
+ }
+
+ // Dashes in the UUID are replaced with underscores
+ String testID = UUID.randomUUID().toString().replace("-", "_");
+ final String testTable = outputTable;
+ final String testResultColumn = "prediction_result";
+
+ // TODO support error metric UDAFs
+ TableTestingSpec spec = TableTestingSpec.newBuilder().hiveConf(conf)
+ .database(database == null ? "default" : database).inputTable(table).featureColumns(model.getFeatureColumns())
+ .outputColumn(testResultColumn).lableColumn(model.getLabelColumn()).algorithm(algorithm).modelID(modelID)
+ .outputTable(testTable).testID(testID).build();
+
+ String testQuery = spec.getTestQuery();
+ if (testQuery == null) {
+ throw new LensException("Invalid test spec. " + "table=" + table + " algorithm=" + algorithm + " modelID="
+ + modelID);
+ }
+
+ if (!spec.isOutputTableExists()) {
+ LOG.info("Output table '" + testTable + "' does not exist for test algorithm = " + algorithm + " modelid="
+ + modelID + ", Creating table using query: " + spec.getCreateOutputTableQuery());
+ // create the output table
+ String createOutputTableQuery = spec.getCreateOutputTableQuery();
+ queryRunner.runQuery(createOutputTableQuery);
+ LOG.info("Table created " + testTable);
+ }
+
+ // Check if ML UDF is registered in this session
+ registerPredictUdf(sessionHandle, queryRunner);
+
+ // Set the query name only after UDF registration, which sets its own query name on the same runner
+ LOG.info("Running evaluation query " + testQuery);
+ queryRunner.setQueryName("model_test_" + modelID);
+ QueryHandle testQueryHandle = queryRunner.runQuery(testQuery);
+
+ MLTestReport testReport = new MLTestReport();
+ testReport.setReportID(testID);
+ testReport.setAlgorithm(algorithm);
+ testReport.setFeatureColumns(model.getFeatureColumns());
+ testReport.setLabelColumn(model.getLabelColumn());
+ testReport.setModelID(model.getId());
+ testReport.setOutputColumn(testResultColumn);
+ testReport.setOutputTable(testTable);
+ testReport.setTestTable(table);
+ testReport.setQueryID(testQueryHandle.toString());
+
+ // Save test report
+ persistTestReport(testReport);
+ // NOTE(review): logged even when persistTestReport failed, since that method only logs IOExceptions
+ LOG.info("Saved test report " + testReport.getReportID());
+ return testReport;
+ }
+
+  /**
+   * Saves a test report through {@code ModelLoader.saveTestReport}. Saving is best effort: an I/O failure is logged
+   * with its stack trace and is not propagated to the caller.
+   *
+   * @param testReport the test report
+   * @throws LensException declared for callers; not thrown by the current implementation
+   */
+  private void persistTestReport(MLTestReport testReport) throws LensException {
+    LOG.info("saving test report " + testReport.getReportID());
+    try {
+      ModelLoader.saveTestReport(conf, testReport);
+      LOG.info("Saved report " + testReport.getReportID());
+    } catch (IOException e) {
+      // Pass the exception to the logger so the cause is not reduced to its message
+      LOG.error("Error saving report " + testReport.getReportID() + " reason: " + e.getMessage(), e);
+    }
+  }
+
+  /**
+   * Lists the ids of the test reports saved for an algorithm, taken from the file names in the algorithm's report
+   * directory.
+   *
+   * @param algorithm the algorithm name
+   * @return the report ids (possibly empty), or null when the report directory is missing or cannot be read
+   * @throws LensException declared for callers; read errors are logged and reported as null
+   */
+  public List<String> getTestReports(String algorithm) throws LensException {
+    Path reportBaseDir = new Path(conf.get(ModelLoader.TEST_REPORT_BASE_DIR, ModelLoader.TEST_REPORT_BASE_DIR_DEFAULT));
+    List<String> reportIds = null;
+
+    try {
+      FileSystem fs = reportBaseDir.getFileSystem(conf);
+      if (fs.exists(reportBaseDir)) {
+        Path algoDir = new Path(reportBaseDir, algorithm);
+        if (fs.exists(algoDir)) {
+          reportIds = new ArrayList<String>();
+          for (FileStatus entry : fs.listStatus(algoDir)) {
+            reportIds.add(entry.getPath().getName());
+          }
+        }
+      }
+    } catch (IOException e) {
+      LOG.error("Error reading report list for " + algorithm, e);
+      return null;
+    }
+    return reportIds;
+  }
+
+ /**
+ * Loads a test report through {@code ModelLoader.loadReport}.
+ *
+ * @param algorithm the algorithm name
+ * @param reportID the report id
+ * @return the report returned by the loader
+ * @throws LensException wrapping any IOException raised by the loader
+ * @see org.apache.lens.ml.api.LensML#getTestReport(java.lang.String, java.lang.String)
+ */
+ public MLTestReport getTestReport(String algorithm, String reportID) throws LensException {
+ try {
+ return ModelLoader.loadReport(conf, algorithm, reportID);
+ } catch (IOException e) {
+ throw new LensException(e);
+ }
+ }
+
+  /**
+   * Runs an online prediction with a saved model.
+   *
+   * @param algorithm the algorithm name
+   * @param modelID   the model id
+   * @param features  the feature values passed to the model
+   * @return the prediction produced by the model
+   * @throws LensException if the model cannot be loaded or does not exist
+   */
+  public Object predict(String algorithm, String modelID, Object[] features) throws LensException {
+    // Load the model instance
+    MLModel<?> model = getModel(algorithm, modelID);
+    // The loader may yield null (testModel guards the same way); report it instead of failing with an NPE
+    if (model == null) {
+      throw new LensException("Model not found: " + modelID + " algorithm=" + algorithm);
+    }
+    return model.predict(features);
+  }
+
+ /**
+ * Deletes a saved model through {@code ModelLoader.deleteModel}.
+ *
+ * @param algorithm the algorithm name
+ * @param modelID the model id
+ * @throws LensException if the delete fails with an IOException (also logged)
+ * @see org.apache.lens.ml.api.LensML#deleteModel(java.lang.String, java.lang.String)
+ */
+ public void deleteModel(String algorithm, String modelID) throws LensException {
+ try {
+ ModelLoader.deleteModel(conf, algorithm, modelID);
+ LOG.info("DELETED model " + modelID + " algorithm=" + algorithm);
+ } catch (IOException e) {
+ LOG.error(
+ "Error deleting model file. algorithm=" + algorithm + " model=" + modelID + " reason: " + e.getMessage(), e);
+ throw new LensException("Unable to delete model " + modelID + " for algorithm " + algorithm, e);
+ }
+ }
+
+ /**
+ * Deletes a saved test report through {@code ModelLoader.deleteTestReport}.
+ *
+ * @param algorithm the algorithm name
+ * @param reportID the report id
+ * @throws LensException if the delete fails with an IOException (also logged)
+ * @see org.apache.lens.ml.api.LensML#deleteTestReport(java.lang.String, java.lang.String)
+ */
+ public void deleteTestReport(String algorithm, String reportID) throws LensException {
+ try {
+ ModelLoader.deleteTestReport(conf, algorithm, reportID);
+ LOG.info("DELETED report=" + reportID + " algorithm=" + algorithm);
+ } catch (IOException e) {
+ LOG.error("Error deleting report " + reportID + " algorithm=" + algorithm + " reason: " + e.getMessage(), e);
+ throw new LensException("Unable to delete report " + reportID + " for algorithm " + algorithm, e);
+ }
+ }
+
+  /**
+   * Describes the parameters of an algorithm. Only algorithms that extend {@link BaseSparkAlgo} expose a
+   * description, through {@code getArgUsage()}.
+   *
+   * @param algorithm the algorithm name
+   * @return the parameter description, or null when the algorithm is unknown or is not a {@link BaseSparkAlgo}
+   */
+  public Map<String, String> getAlgoParamDescription(String algorithm) {
+    final MLAlgo resolved;
+    try {
+      resolved = getAlgoForName(algorithm);
+    } catch (LensException e) {
+      LOG.error("Error getting algo description : " + algorithm, e);
+      return null;
+    }
+    if (!(resolved instanceof BaseSparkAlgo)) {
+      return null;
+    }
+    return ((BaseSparkAlgo) resolved).getArgUsage();
+  }
+
+  /**
+   * Submit model test query to a remote Lens server.
+   */
+  class RemoteQueryRunner extends QueryRunner {
+
+    /** The query api url. */
+    final String queryApiUrl;
+
+    /**
+     * Instantiates a new remote query runner.
+     *
+     * @param sessionHandle the session handle
+     * @param queryApiUrl   the query api url
+     */
+    public RemoteQueryRunner(LensSessionHandle sessionHandle, String queryApiUrl) {
+      super(sessionHandle);
+      this.queryApiUrl = queryApiUrl;
+    }
+
+    /**
+     * Posts the query to the query endpoint with operation "execute", then polls the query status every 500 ms
+     * until it is finished.
+     *
+     * @param query the query to run
+     * @return the handle of the finished query
+     * @throws LensException if the query does not end in SUCCESSFUL state, or the wait is interrupted
+     */
+    @Override
+    public QueryHandle runQuery(String query) throws LensException {
+      // Create jersey client for query endpoint
+      Client client = ClientBuilder.newBuilder().register(MultiPartFeature.class).build();
+      try {
+        WebTarget target = client.target(queryApiUrl);
+        final FormDataMultiPart mp = new FormDataMultiPart();
+        mp.bodyPart(new FormDataBodyPart(FormDataContentDisposition.name("sessionid").build(), sessionHandle,
+          MediaType.APPLICATION_XML_TYPE));
+        mp.bodyPart(new FormDataBodyPart(FormDataContentDisposition.name("query").build(), query));
+        mp.bodyPart(new FormDataBodyPart(FormDataContentDisposition.name("operation").build(), "execute"));
+
+        LensConf lensConf = new LensConf();
+        lensConf.addProperty(LensConfConstants.QUERY_PERSISTENT_RESULT_SET, false + "");
+        lensConf.addProperty(LensConfConstants.QUERY_PERSISTENT_RESULT_INDRIVER, false + "");
+        mp.bodyPart(new FormDataBodyPart(FormDataContentDisposition.name("conf").fileName("conf").build(), lensConf,
+          MediaType.APPLICATION_XML_TYPE));
+
+        final QueryHandle handle = target.request().post(Entity.entity(mp, MediaType.MULTIPART_FORM_DATA_TYPE),
+          QueryHandle.class);
+
+        LensQuery ctx = fetchQuery(target, handle);
+        QueryStatus stat = ctx.getStatus();
+        while (!stat.isFinished()) {
+          // Wait before polling again, so no sleep happens once a finished status was seen
+          try {
+            Thread.sleep(500);
+          } catch (InterruptedException e) {
+            // Restore the interrupt flag so code further up the stack can still observe it
+            Thread.currentThread().interrupt();
+            throw new LensException(e);
+          }
+          ctx = fetchQuery(target, handle);
+          stat = ctx.getStatus();
+        }
+
+        if (stat.getStatus() != QueryStatus.Status.SUCCESSFUL) {
+          throw new LensException("Query failed " + ctx.getQueryHandle().getHandleId() + " reason:"
+            + stat.getErrorMessage());
+        }
+
+        return ctx.getQueryHandle();
+      } finally {
+        // The client holds connection resources; release them on every exit path
+        client.close();
+      }
+    }
+
+    /** Fetches the current state of the query identified by the handle from the query endpoint. */
+    private LensQuery fetchQuery(WebTarget target, QueryHandle handle) {
+      return target.path(handle.toString()).queryParam("sessionid", sessionHandle).request().get(LensQuery.class);
+    }
+  }
+
+ /**
+ * To lens conf. Copies every property of the hive conf (all keys matching the regex {@code .*}) into a new
+ * {@code LensConf}.
+ *
+ * @param conf the conf
+ * @return the lens conf
+ */
+ private LensConf toLensConf(HiveConf conf) {
+ LensConf lensConf = new LensConf();
+ lensConf.getProperties().putAll(conf.getValByRegex(".*"));
+ return lensConf;
+ }
+
+ protected void registerPredictUdf(LensSessionHandle sessionHandle, QueryRunner queryRunner) throws LensException {
+ if (isUdfRegisterd(sessionHandle)) {
+ // Already registered, nothing to do
+ return;
+ }
+
+ LOG.info("Registering UDF for session " + sessionHandle.getPublicId().toString());
+ // We have to add UDF jars to the session
+ try {
+ SessionService sessionService = (SessionService) MLUtils.getServiceProvider().getService(SessionService.NAME);
+ String[] udfJars = conf.getStrings("lens.server.ml.predict.udf.jars");
+ if (udfJars != null) {
+ for (String jar : udfJars) {
+ sessionService.addResource(sessionHandle, "jar", jar);
+ LOG.info(jar + " added UDF session " + sessionHandle.getPublicId().toString());
+ }
+ }
+ } catch (Exception e) {
+ throw new LensException(e);
+ }
+
+ String regUdfQuery = "CREATE TEMPORARY FUNCTION " + HiveMLUDF.UDF_NAME + " AS '" + HiveMLUDF.class
+ .getCanonicalName() + "'";
+ queryRunner.setQueryName("register_predict_udf_" + sessionHandle.getPublicId().toString());
+ QueryHandle udfQuery = queryRunner.runQuery(regUdfQuery);
+ predictUdfStatus.put(sessionHandle, true);
+ LOG.info("Predict UDF registered for session " + sessionHandle.getPublicId().toString());
+ }
+
+ /**
+ * Checks whether the predict UDF was already registered for a session.
+ *
+ * @param sessionHandle the session handle
+ * @return true if the session has an entry in the UDF status map
+ */
+ protected boolean isUdfRegisterd(LensSessionHandle sessionHandle) {
+ return predictUdfStatus.containsKey(sessionHandle);
+ }
+
+  /**
+   * Cleanup task: drops the UDF registration status of every tracked session that the session service no longer
+   * reports as open. Any failure is logged as a warning and never propagates out of {@code run()}.
+   */
+  private class UDFStatusExpiryRunnable implements Runnable {
+    public void run() {
+      try {
+        SessionService sessionService = (SessionService) MLUtils.getServiceProvider().getService(SessionService.NAME);
+        // Clear status of sessions which are closed; work on a copy of the tracked handles
+        List<LensSessionHandle> tracked = new ArrayList<LensSessionHandle>(predictUdfStatus.keySet());
+        for (LensSessionHandle handle : tracked) {
+          if (sessionService.isOpen(handle)) {
+            continue;
+          }
+          LOG.info("Session closed, removing UDF status: " + handle);
+          predictUdfStatus.remove(handle);
+        }
+      } catch (Exception exc) {
+        LOG.warn("Error clearing UDF statuses", exc);
+      }
+    }
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/MLRunner.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/MLRunner.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/MLRunner.java
new file mode 100644
index 0000000..625d020
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/MLRunner.java
@@ -0,0 +1,172 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.impl;
+
+import java.io.File;
+import java.io.FileInputStream;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Properties;
+
+import org.apache.lens.client.LensClient;
+import org.apache.lens.client.LensClientConfig;
+import org.apache.lens.client.LensMLClient;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.hive.conf.HiveConf;
+import org.apache.hadoop.hive.metastore.TableType;
+import org.apache.hadoop.hive.metastore.api.FieldSchema;
+import org.apache.hadoop.hive.ql.metadata.Hive;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.metadata.Table;
+import org.apache.hadoop.hive.ql.plan.AddPartitionDesc;
+import org.apache.hadoop.hive.serde.serdeConstants;
+import org.apache.hadoop.mapred.TextInputFormat;
+
+public class MLRunner {
+
+ private static final Log LOG = LogFactory.getLog(MLRunner.class);
+
+ private LensMLClient mlClient;
+ private String algoName;
+ private String database;
+ private String trainTable;
+ private String trainFile;
+ private String testTable;
+ private String testFile;
+ private String outputTable;
+ private String[] features;
+ private String labelColumn;
+ private HiveConf conf;
+
+  /**
+   * Initializes the runner from a configuration directory. Settings are read from {@code ml.properties} in that
+   * directory (keys: algo, database, traintable, testtable, outputtable, features, labelcolumn); the data files are
+   * {@code train.data} and {@code test.data} in the same directory.
+   *
+   * @param mlClient the ML client
+   * @param confDir  the directory holding ml.properties and the data files
+   * @throws IllegalArgumentException if the mandatory {@code features} property is missing
+   * @throws Exception                if the properties file cannot be read
+   */
+  public void init(LensMLClient mlClient, String confDir) throws Exception {
+    File dir = new File(confDir);
+    File propFile = new File(dir, "ml.properties");
+    Properties props = new Properties();
+    FileInputStream propStream = new FileInputStream(propFile);
+    try {
+      props.load(propStream);
+    } finally {
+      // Release the file handle whether or not loading succeeded
+      propStream.close();
+    }
+
+    String feat = props.getProperty("features");
+    if (feat == null) {
+      throw new IllegalArgumentException("Property 'features' is missing in " + propFile);
+    }
+    // Distinct local names so the fields of the same name are not shadowed
+    String trainDataFile = confDir + File.separator + "train.data";
+    String testDataFile = confDir + File.separator + "test.data";
+    init(mlClient, props.getProperty("algo"), props.getProperty("database"),
+      props.getProperty("traintable"), trainDataFile,
+      props.getProperty("testtable"), testDataFile,
+      props.getProperty("outputtable"), feat.split(","),
+      props.getProperty("labelcolumn"));
+  }
+
+ /**
+ * Initializes the runner from explicit settings. All arguments are stored as-is, and a HiveConf is created on top
+ * of a new LensClientConfig.
+ *
+ * @param mlClient the ML client
+ * @param algoName the algorithm name
+ * @param database the database in which the tables are created
+ * @param trainTable the training table name
+ * @param trainFile the training data file
+ * @param testTable the test table name
+ * @param testFile the test data file
+ * @param outputTable the output table name
+ * @param features the feature column names
+ * @param labelColumn the label column name; may be null (createTable then adds no label column)
+ */
+ public void init(LensMLClient mlClient, String algoName,
+ String database, String trainTable, String trainFile,
+ String testTable, String testFile, String outputTable, String[] features,
+ String labelColumn) {
+ this.mlClient = mlClient;
+ this.algoName = algoName;
+ this.database = database;
+ this.trainTable = trainTable;
+ this.trainFile = trainFile;
+ this.testTable = testTable;
+ this.testFile = testFile;
+ this.outputTable = outputTable;
+ this.features = features;
+ this.labelColumn = labelColumn;
+ //hive metastore settings are loaded via lens-site.xml, so loading LensClientConfig
+ //is required
+ this.conf = new HiveConf(new LensClientConfig(), MLRunner.class);
+ }
+
+  /**
+   * Creates the training and test tables from their data files, then builds an {@link MLTask} from the configured
+   * settings and runs it on the calling thread.
+   *
+   * @return the task after {@code run()} has returned
+   * @throws Exception if creating the tables fails
+   */
+  public MLTask train() throws Exception {
+    LOG.info("Starting train & eval");
+
+    createTable(trainTable, trainFile);
+    createTable(testTable, testFile);
+
+    MLTask.Builder builder = new MLTask.Builder();
+    builder.algorithm(algoName).hiveConf(conf).labelColumn(labelColumn).outputTable(outputTable)
+      .client(mlClient).trainingTable(trainTable).testTable(testTable);
+
+    // Add features
+    for (int i = 0; i < features.length; i++) {
+      builder.addFeatureColumn(features[i]);
+    }
+
+    MLTask mlTask = builder.build();
+    LOG.info("Created task " + mlTask.toString());
+    mlTask.run();
+    return mlTask;
+  }
+
+  /**
+   * Drops and re-creates a text table with double-typed columns (the optional label first, then the features) and
+   * adds one dummy partition whose location is the directory containing the data file.
+   *
+   * @param tableName the table to create in the configured database
+   * @param dataFile  the data file; its parent directory becomes the partition location
+   * @throws HiveException if a metastore operation fails
+   */
+  public void createTable(String tableName, String dataFile) throws HiveException {
+    Path partDir = new Path(new File(dataFile).toURI()).getParent();
+
+    // Label is optional. Not used for unsupervised models.
+    // If present, label will be the first column, followed by features
+    List<FieldSchema> columns = new ArrayList<FieldSchema>();
+    if (labelColumn != null) {
+      columns.add(new FieldSchema(labelColumn, "double", "Labelled Column"));
+    }
+    for (String feature : features) {
+      columns.add(new FieldSchema(feature, "double", "Feature " + feature));
+    }
+
+    // Create table
+    Hive hive = Hive.get(conf);
+    Table tbl = hive.newTable(database + "." + tableName);
+    tbl.setTableType(TableType.MANAGED_TABLE);
+    tbl.getTTable().getSd().setCols(columns);
+    tbl.setInputFormatClass(TextInputFormat.class);
+    tbl.setSerdeParam(serdeConstants.LINE_DELIM, "\n");
+    tbl.setSerdeParam(serdeConstants.FIELD_DELIM, " ");
+
+    List<FieldSchema> partCols = new ArrayList<FieldSchema>(1);
+    partCols.add(new FieldSchema("dummy_partition_col", "string", ""));
+    tbl.setPartCols(partCols);
+
+    hive.dropTable(database, tableName, false, true);
+    hive.createTable(tbl, true);
+    LOG.info("Created table " + tableName);
+
+    // Add partition for the data file
+    Map<String, String> partSpec = new HashMap<String, String>();
+    partSpec.put("dummy_partition_col", "dummy_val");
+    AddPartitionDesc partitionDesc = new AddPartitionDesc(database, tableName,
+      false);
+    String partLocation = partDir.toUri().toString();
+    partitionDesc.addPartition(partSpec, partLocation);
+    hive.createPartitions(partitionDesc);
+    LOG.info(tableName + ": Added partition " + partLocation);
+  }
+
+ /**
+ * Command line entry point. Expects the ML configuration directory as the first argument, initializes a runner
+ * from it and runs train(). Prints usage and exits with status -1 when no argument is given.
+ *
+ * @param args command line arguments; args[0] is the configuration directory
+ * @throws Exception if initialization or training fails
+ */
+ public static void main(String[] args) throws Exception {
+ if (args.length < 1) {
+ System.out.println("Usage: " + MLRunner.class.getName() + " <ml-conf-dir>");
+ System.exit(-1);
+ }
+ String confDir = args[0];
+ LensMLClient client = new LensMLClient(new LensClient());
+ MLRunner runner = new MLRunner();
+ runner.init(client, confDir);
+ runner.train();
+ System.out.println("Created the Model successfully. Output Table: " + runner.outputTable);
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/MLTask.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/MLTask.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/MLTask.java
new file mode 100644
index 0000000..2867b90
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/MLTask.java
@@ -0,0 +1,285 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.impl;
+
+import java.util.*;
+
+import org.apache.lens.client.LensMLClient;
+import org.apache.lens.ml.api.LensML;
+import org.apache.lens.ml.api.MLTestReport;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.hive.conf.HiveConf;
+
+import lombok.Getter;
+import lombok.ToString;
+
+/**
+ * Run a complete cycle of train and test (evaluation) for an ML algorithm
+ */
+@ToString
+public class MLTask implements Runnable {
+ private static final Log LOG = LogFactory.getLog(MLTask.class);
+
+ public enum State {
+ RUNNING, SUCCESSFUL, FAILED
+ }
+
+ @Getter
+ private State taskState;
+
+ /**
+ * Name of the algo/algorithm.
+ */
+ @Getter
+ private String algorithm;
+
+ /**
+ * Name of the table containing training data.
+ */
+ @Getter
+ private String trainingTable;
+
+ /**
+ * Name of the table containing test data. Optional, if not provided trainingTable itself is
+ * used for testing
+ */
+ @Getter
+ private String testTable;
+
+ /**
+ * Training table partition spec
+ */
+ @Getter
+ private String partitionSpec;
+
+ /**
+ * Name of the column which is a label for supervised algorithms.
+ */
+ @Getter
+ private String labelColumn;
+
+ /**
+ * Names of columns which are features in the training data.
+ */
+ @Getter
+ private List<String> featureColumns;
+
+ /**
+ * Configuration for the example.
+ */
+ @Getter
+ private HiveConf configuration;
+
+ private LensML ml;
+ private String taskID;
+
+ /**
+ * ml client
+ */
+ @Getter
+ private LensMLClient mlClient;
+
+ /**
+ * Output table name
+ */
+ @Getter
+ private String outputTable;
+
+ /**
+ * Extra params passed to the training algorithm
+ */
+ @Getter
+ private Map<String, String> extraParams;
+
+ @Getter
+ private String modelID;
+
+ @Getter
+ private String reportID;
+
+ /**
+ * Use MLTask.Builder to create an instance
+ */
+ private MLTask() {
+ // Use builder to construct the example
+ extraParams = new HashMap<String, String>();
+ taskID = UUID.randomUUID().toString();
+ }
+
+ /**
+ * Builder to create an ML task
+ */
+ public static class Builder {
+ private MLTask task;
+
+ public Builder() {
+ task = new MLTask();
+ }
+
+ public Builder trainingTable(String trainingTable) {
+ task.trainingTable = trainingTable;
+ return this;
+ }
+
+ public Builder testTable(String testTable) {
+ task.testTable = testTable;
+ return this;
+ }
+
+ public Builder algorithm(String algorithm) {
+ task.algorithm = algorithm;
+ return this;
+ }
+
+ public Builder labelColumn(String labelColumn) {
+ task.labelColumn = labelColumn;
+ return this;
+ }
+
+ public Builder client(LensMLClient client) {
+ task.mlClient = client;
+ return this;
+ }
+
+ public Builder addFeatureColumn(String featureColumn) {
+ if (task.featureColumns == null) {
+ task.featureColumns = new ArrayList<String>();
+ }
+ task.featureColumns.add(featureColumn);
+ return this;
+ }
+
+ public Builder hiveConf(HiveConf hiveConf) {
+ task.configuration = hiveConf;
+ return this;
+ }
+
+
+
+ public Builder extraParam(String param, String value) {
+ task.extraParams.put(param, value);
+ return this;
+ }
+
+ public Builder partitionSpec(String partitionSpec) {
+ task.partitionSpec = partitionSpec;
+ return this;
+ }
+
+ public Builder outputTable(String outputTable) {
+ task.outputTable = outputTable;
+ return this;
+ }
+
+ public MLTask build() {
+ MLTask builtTask = task;
+ task = null;
+ return builtTask;
+ }
+
+ }
+
+ @Override
+ public void run() {
+ taskState = State.RUNNING;
+ LOG.info("Starting " + taskID);
+ try {
+ runTask();
+ taskState = State.SUCCESSFUL;
+ LOG.info("Complete " + taskID);
+ } catch (Exception e) {
+ taskState = State.FAILED;
+ LOG.info("Error running task " + taskID, e);
+ }
+ }
+
+  /**
+   * Train an ML model with the specified algorithm and input data, then evaluate it on the test table
+   * (the training table is used when no test table was given) and print the model and the evaluation report.
+   *
+   * @throws Exception if training or testing fails, or if no ML client is available to supply the
+   *                   session handle needed for testing
+   */
+  private void runTask() throws Exception {
+    if (mlClient != null) {
+      // Connect to a remote Lens server
+      ml = mlClient;
+      LOG.info("Working in client mode. Lens session handle " + mlClient.getSessionHandle().getPublicId());
+    } else {
+      // In server mode session handle has to be passed by the user as a request parameter
+      ml = MLUtils.getMLService();
+      LOG.info("Working in Lens server");
+    }
+
+    String[] algoArgs = buildTrainingArgs();
+    LOG.info("Starting task " + taskID + " algo args: " + Arrays.toString(algoArgs));
+
+    modelID = ml.train(trainingTable, algorithm, algoArgs);
+    printModelMetadata(taskID, modelID);
+
+    LOG.info("Starting test " + taskID);
+    testTable = (testTable != null) ? testTable : trainingTable;
+    if (mlClient == null) {
+      // The ML client is the only source of a session handle in this class. Without it the call below
+      // used to die with a bare NullPointerException; fail with an explanation instead.
+      throw new IllegalStateException("Cannot test model " + modelID + " for task " + taskID
+        + ": no LensMLClient was supplied, so there is no session handle to run the test with");
+    }
+    MLTestReport testReport = ml.testModel(mlClient.getSessionHandle(), testTable, algorithm, modelID, outputTable);
+    reportID = testReport.getReportID();
+    printTestReport(taskID, testReport);
+    saveTask();
+  }
+
+ // Save task metadata to DB
+ // NOTE(review): currently a placeholder - it only writes a log line and persists nothing.
+ private void saveTask() {
+ LOG.info("Saving task details to DB");
+ }
+
+  /**
+   * Prints the evaluation report of a task to stdout.
+   *
+   * @param exampleID  id printed in the heading line
+   * @param testReport report whose {@code toString()} form is printed
+   */
+  private void printTestReport(String exampleID, MLTestReport testReport) {
+    String reportText = "Example: " + exampleID + "\n\t" + "EvaluationReport: " + testReport.toString();
+    System.out.println(reportText);
+  }
+
+  /**
+   * Builds the argument array handed to the training call: {@code "label", <labelColumn>}, then a
+   * {@code "feature", <column>} pair for every feature column, then every extra param as a name/value pair.
+   *
+   * @return the flattened argument array
+   */
+  private String[] buildTrainingArgs() {
+    List<String> argList = new ArrayList<String>();
+    argList.add("label");
+    argList.add(labelColumn);
+
+    // Add all the features. featureColumns stays null until Builder.addFeatureColumn is called,
+    // so treat null as "no feature columns" instead of failing with a NullPointerException.
+    if (featureColumns != null) {
+      for (String featureCol : featureColumns) {
+        argList.add("feature");
+        argList.add(featureCol);
+      }
+    }
+
+    // Add extra params
+    for (Map.Entry<String, String> extraParam : extraParams.entrySet()) {
+      argList.add(extraParam.getKey());
+      argList.add(extraParam.getValue());
+    }
+
+    return argList.toArray(new String[argList.size()]);
+  }
+
+  /**
+   * Fetches the model instance from the ML service and prints its metadata to stdout.
+   *
+   * @param exampleID id printed in the heading line
+   * @param modelID   id of the model to look up for the configured algorithm
+   * @throws Exception if the model lookup fails
+   */
+  private void printModelMetadata(String exampleID, String modelID) throws Exception {
+    String modelDescription = ml.getModel(algorithm, modelID).toString();
+    System.out.println("Example: " + exampleID + "\n\t" + "Model: " + modelDescription);
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/MLUtils.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/MLUtils.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/MLUtils.java
new file mode 100644
index 0000000..9c96d9b
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/MLUtils.java
@@ -0,0 +1,62 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.impl;
+
+import org.apache.lens.ml.algo.api.Algorithm;
+import org.apache.lens.ml.algo.api.MLAlgo;
+import org.apache.lens.ml.server.MLService;
+import org.apache.lens.ml.server.MLServiceImpl;
+import org.apache.lens.server.api.LensConfConstants;
+import org.apache.lens.server.api.ServiceProvider;
+import org.apache.lens.server.api.ServiceProviderFactory;
+
+import org.apache.hadoop.hive.conf.HiveConf;
+
+/**
+ * Static helpers for the ML module: algorithm name lookup and access to the ML service through the
+ * configured service provider.
+ */
+public final class MLUtils {
+  private MLUtils() {
+  }
+
+  /** Configuration used to look up the service provider factory. */
+  private static final HiveConf HIVE_CONF;
+
+  static {
+    HIVE_CONF = new HiveConf();
+    // Add default config so that we know the service provider implementation
+    HIVE_CONF.addResource("lensserver-default.xml");
+    HIVE_CONF.addResource("lens-site.xml");
+  }
+
+  /**
+   * Reads the algorithm name from the {@link Algorithm} annotation on the given class.
+   *
+   * @param algoClass the algorithm implementation class
+   * @return the name declared in the annotation
+   * @throws IllegalArgumentException if the class carries no {@link Algorithm} annotation
+   */
+  public static String getAlgoName(Class<? extends MLAlgo> algoClass) {
+    Algorithm annotation = algoClass.getAnnotation(Algorithm.class);
+    if (annotation != null) {
+      return annotation.name();
+    }
+    throw new IllegalArgumentException("Algo should be decorated with annotation - " + Algorithm.class.getName());
+  }
+
+  /**
+   * Looks up the ML service registered under {@code MLService.NAME} in the service provider.
+   *
+   * @return the ML service implementation
+   * @throws Exception if the service provider cannot be obtained
+   */
+  public static MLServiceImpl getMLService() throws Exception {
+    return getServiceProvider().getService(MLService.NAME);
+  }
+
+  /**
+   * Instantiates the configured {@link ServiceProviderFactory} and returns its service provider.
+   *
+   * @return the service provider
+   * @throws IllegalStateException if no factory class is configured
+   * @throws Exception             if the factory cannot be instantiated
+   */
+  public static ServiceProvider getServiceProvider() throws Exception {
+    Class<? extends ServiceProviderFactory> spfClass = HIVE_CONF.getClass(LensConfConstants.SERVICE_PROVIDER_FACTORY,
+      null, ServiceProviderFactory.class);
+    if (spfClass == null) {
+      // getClass returns the (null) default when the property is absent; report that instead of an NPE
+      throw new IllegalStateException("No service provider factory configured; set "
+        + LensConfConstants.SERVICE_PROVIDER_FACTORY + " in lensserver-default.xml or lens-site.xml");
+    }
+    ServiceProviderFactory spf = spfClass.newInstance();
+    return spf.getServiceProvider();
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/ModelLoader.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/ModelLoader.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/ModelLoader.java
new file mode 100644
index 0000000..c0e7953
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/ModelLoader.java
@@ -0,0 +1,242 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.impl;
+
+import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+
+import org.apache.lens.ml.algo.api.MLModel;
+import org.apache.lens.ml.api.MLTestReport;
+
+import org.apache.commons.io.IOUtils;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.hive.conf.HiveConf;
+
+import com.google.common.cache.Cache;
+import com.google.common.cache.CacheBuilder;
+
+/**
+ * Load ML models from a FS location.
+ */
+public final class ModelLoader {
+ private ModelLoader() {
+ }
+
+ /** The Constant MODEL_PATH_BASE_DIR. */
+ public static final String MODEL_PATH_BASE_DIR = "lens.ml.model.basedir";
+
+ /** The Constant MODEL_PATH_BASE_DIR_DEFAULT. */
+ public static final String MODEL_PATH_BASE_DIR_DEFAULT = "file:///tmp";
+
+ /** The Constant LOG. */
+ public static final Log LOG = LogFactory.getLog(ModelLoader.class);
+
+ /** The Constant TEST_REPORT_BASE_DIR. */
+ public static final String TEST_REPORT_BASE_DIR = "lens.ml.test.basedir";
+
+ /** The Constant TEST_REPORT_BASE_DIR_DEFAULT. */
+ public static final String TEST_REPORT_BASE_DIR_DEFAULT = "file:///tmp/ml_reports";
+
+ // Model cache settings
+ /** The Constant MODEL_CACHE_SIZE. */
+ public static final long MODEL_CACHE_SIZE = 10;
+
+ /** The Constant MODEL_CACHE_TIMEOUT. */
+ public static final long MODEL_CACHE_TIMEOUT = 3600000L; // one hour
+
+ /** The model cache. */
+ private static Cache<Path, MLModel> modelCache = CacheBuilder.newBuilder().maximumSize(MODEL_CACHE_SIZE)
+ .expireAfterAccess(MODEL_CACHE_TIMEOUT, TimeUnit.MILLISECONDS).build();
+
+  /**
+   * Resolves the storage path of a model: {@code <base dir>/<algorithm>/<modelID>}, where the base dir is
+   * read from {@link #MODEL_PATH_BASE_DIR} and defaults to {@link #MODEL_PATH_BASE_DIR_DEFAULT}.
+   *
+   * @param conf      the conf holding the base dir setting
+   * @param algorithm the algorithm
+   * @param modelID   the model id
+   * @return the model location
+   */
+  public static Path getModelLocation(Configuration conf, String algorithm, String modelID) {
+    Path baseDir = new Path(conf.get(MODEL_PATH_BASE_DIR, MODEL_PATH_BASE_DIR_DEFAULT));
+    Path algorithmDir = new Path(baseDir, algorithm);
+    return new Path(algorithmDir, modelID);
+  }
+
+  /**
+   * Load model. The model is served from the in-memory cache when present; otherwise it is deserialized
+   * from its location on the file system and cached.
+   *
+   * @param conf      the conf used to locate the model and to resolve its file system
+   * @param algorithm the algorithm
+   * @param modelID   the model id
+   * @return the ML model
+   * @throws IOException if the model path does not exist or the model cannot be read
+   */
+  public static MLModel loadModel(Configuration conf, String algorithm, String modelID) throws IOException {
+    final Path modelPath = getModelLocation(conf, algorithm, modelID);
+    // Final copy so the loader below can capture it. The file system is resolved from the caller's conf,
+    // the same conf that produced modelPath, instead of an unrelated freshly created HiveConf.
+    final Configuration fsConf = conf;
+    LOG.info("Loading model for algorithm: " + algorithm + " modelID: " + modelID + " At path: "
+      + modelPath.toUri().toString());
+    try {
+      return modelCache.get(modelPath, new Callable<MLModel>() {
+        @Override
+        public MLModel call() throws Exception {
+          FileSystem fs = modelPath.getFileSystem(fsConf);
+          if (!fs.exists(modelPath)) {
+            throw new IOException("Model path not found " + modelPath.toString());
+          }
+
+          ObjectInputStream ois = null;
+          try {
+            ois = new ObjectInputStream(fs.open(modelPath));
+            MLModel model = (MLModel) ois.readObject();
+            LOG.info("Loaded model " + model.getId() + " from location " + modelPath);
+            return model;
+          } catch (ClassNotFoundException e) {
+            throw new IOException(e);
+          } finally {
+            IOUtils.closeQuietly(ois);
+          }
+        }
+      });
+    } catch (ExecutionException exc) {
+      throw new IOException(exc);
+    }
+  }
+
+  /**
+   * Clear cache. Discards every cached model so the next {@code loadModel} call reads from the file system.
+   */
+  public static void clearCache() {
+    // cleanUp() only performs pending maintenance and leaves live entries in place;
+    // invalidateAll() is the call that actually empties the cache.
+    modelCache.invalidateAll();
+  }
+
+  /**
+   * Resolves the storage path of a test report: {@code <report base dir>/<algorithm>/<report>}, where the
+   * base dir is read from {@link #TEST_REPORT_BASE_DIR} and defaults to {@link #TEST_REPORT_BASE_DIR_DEFAULT}.
+   *
+   * @param conf      the conf holding the base dir setting
+   * @param algorithm the algorithm
+   * @param report    the report id
+   * @return the test report path
+   */
+  public static Path getTestReportPath(Configuration conf, String algorithm, String report) {
+    String reportBaseDir = conf.get(TEST_REPORT_BASE_DIR, TEST_REPORT_BASE_DIR_DEFAULT);
+    Path algorithmReportDir = new Path(reportBaseDir, algorithm);
+    return new Path(algorithmReportDir, report);
+  }
+
+ /**
+ * Save test report.
+ *
+ * @param conf the conf
+ * @param report the report
+ * @throws IOException Signals that an I/O exception has occurred.
+ */
+ public static void saveTestReport(Configuration conf, MLTestReport report) throws IOException {
+ Path reportDir = new Path(conf.get(TEST_REPORT_BASE_DIR, TEST_REPORT_BASE_DIR_DEFAULT));
+ FileSystem fs = reportDir.getFileSystem(conf);
+
+ if (!fs.exists(reportDir)) {
+ LOG.info("Creating test report dir " + reportDir.toUri().toString());
+ fs.mkdirs(reportDir);
+ }
+
+ Path algoDir = new Path(reportDir, report.getAlgorithm());
+
+ if (!fs.exists(algoDir)) {
+ LOG.info("Creating algorithm report dir " + algoDir.toUri().toString());
+ fs.mkdirs(algoDir);
+ }
+
+ ObjectOutputStream reportOutputStream = null;
+ Path reportSaveLocation;
+ try {
+ reportSaveLocation = new Path(algoDir, report.getReportID());
+ reportOutputStream = new ObjectOutputStream(fs.create(reportSaveLocation));
+ reportOutputStream.writeObject(report);
+ reportOutputStream.flush();
+ } catch (IOException ioexc) {
+ LOG.error("Error saving test report " + report.getReportID(), ioexc);
+ throw ioexc;
+ } finally {
+ IOUtils.closeQuietly(reportOutputStream);
+ }
+ LOG.info("Saved report " + report.getReportID() + " at location " + reportSaveLocation.toUri());
+ }
+
+  /**
+   * Load report. Reads the serialized report stored at the path given by
+   * {@link #getTestReportPath(Configuration, String, String)}.
+   *
+   * @param conf      the conf
+   * @param algorithm the algorithm
+   * @param reportID  the report id
+   * @return the ML test report, or null when opening or reading the report file fails with an
+   *         IOException (the failure is logged)
+   * @throws IOException if the file system cannot be resolved or the report class cannot be found
+   */
+  public static MLTestReport loadReport(Configuration conf, String algorithm, String reportID) throws IOException {
+    Path reportLocation = getTestReportPath(conf, algorithm, reportID);
+    FileSystem fs = reportLocation.getFileSystem(conf);
+
+    ObjectInputStream reportStream = null;
+    try {
+      reportStream = new ObjectInputStream(fs.open(reportLocation));
+      return (MLTestReport) reportStream.readObject();
+    } catch (IOException ioex) {
+      // Read failures are reported to the caller as "no report" rather than as an exception
+      LOG.error("Error reading report " + reportLocation, ioex);
+      return null;
+    } catch (ClassNotFoundException e) {
+      throw new IOException(e);
+    } finally {
+      IOUtils.closeQuietly(reportStream);
+    }
+  }
+
+  /**
+   * Delete model. Removes the model file from the file system and evicts the model from the in-memory
+   * cache so that {@code loadModel} does not keep returning a deleted model.
+   *
+   * @param conf      the conf
+   * @param algorithm the algorithm
+   * @param modelID   the model id
+   * @throws IOException Signals that an I/O exception has occurred.
+   */
+  public static void deleteModel(HiveConf conf, String algorithm, String modelID) throws IOException {
+    Path modelLocation = getModelLocation(conf, algorithm, modelID);
+    FileSystem fs = modelLocation.getFileSystem(conf);
+    fs.delete(modelLocation, false);
+    // The cache is keyed by the model path; drop the entry that belongs to the file just removed
+    modelCache.invalidate(modelLocation);
+  }
+
+  /**
+   * Delete test report. Removes the report file (non-recursively) from the file system.
+   *
+   * @param conf      the conf
+   * @param algorithm the algorithm
+   * @param reportID  the report id
+   * @throws IOException Signals that an I/O exception has occurred.
+   */
+  public static void deleteTestReport(HiveConf conf, String algorithm, String reportID) throws IOException {
+    Path reportPath = getTestReportPath(conf, algorithm, reportID);
+    FileSystem reportFs = reportPath.getFileSystem(conf);
+    reportFs.delete(reportPath, false);
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/QueryRunner.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/QueryRunner.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/QueryRunner.java
new file mode 100644
index 0000000..2f2e017
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/QueryRunner.java
@@ -0,0 +1,56 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.impl;
+
+import org.apache.lens.api.LensException;
+import org.apache.lens.api.LensSessionHandle;
+import org.apache.lens.api.query.QueryHandle;
+
+import lombok.Getter;
+import lombok.Setter;
+
+/**
+ * Run a query against a Lens server. This base class only stores the session handle and an optional
+ * query name; subclasses provide the actual execution in {@link #runQuery(String)}.
+ */
+public abstract class QueryRunner {
+
+ /** The session handle, fixed at construction time. */
+ protected final LensSessionHandle sessionHandle;
+
+ /** The query name; its getter and setter are generated by Lombok. */
+ @Getter @Setter
+ protected String queryName;
+
+ /**
+ * Instantiates a new query runner.
+ *
+ * @param sessionHandle the session handle stored for use by subclasses
+ */
+ public QueryRunner(LensSessionHandle sessionHandle) {
+ this.sessionHandle = sessionHandle;
+ }
+
+ /**
+ * Run query.
+ *
+ * @param query the query
+ * @return the query handle
+ * @throws LensException the lens exception
+ */
+ public abstract QueryHandle runQuery(String query) throws LensException;
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/TableTestingSpec.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/TableTestingSpec.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/TableTestingSpec.java
new file mode 100644
index 0000000..34b2a3f
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/TableTestingSpec.java
@@ -0,0 +1,325 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.impl;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+
+import org.apache.commons.lang3.StringUtils;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.hive.conf.HiveConf;
+import org.apache.hadoop.hive.metastore.api.FieldSchema;
+import org.apache.hadoop.hive.ql.metadata.Hive;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.metadata.Table;
+
+import lombok.Getter;
+
+/**
+ * Table specification for running test on a table.
+ */
+public class TableTestingSpec {
+
+ /** The Constant LOG. */
+ public static final Log LOG = LogFactory.getLog(TableTestingSpec.class);
+
+ /** The db. */
+ private String db;
+
+ /** The table containing input data. */
+ private String inputTable;
+
+ // TODO use partition condition
+ /** The partition filter. */
+ private String partitionFilter;
+
+ /** The feature columns. */
+ private List<String> featureColumns;
+
+ /** The label column. */
+ private String labelColumn;
+
+ /** The output column. */
+ private String outputColumn;
+
+ /** The output table. */
+ private String outputTable;
+
+ /** The conf. */
+ private transient HiveConf conf;
+
+ /** The algorithm. */
+ private String algorithm;
+
+ /** The model id. */
+ private String modelID;
+
+ @Getter
+ private boolean outputTableExists;
+
+ @Getter
+ private String testID;
+
+ private HashMap<String, FieldSchema> columnNameToFieldSchema;
+
+ /**
+ * The Class TableTestingSpecBuilder.
+ */
+ public static class TableTestingSpecBuilder {
+
+ /** The spec. */
+ private final TableTestingSpec spec;
+
+ /**
+ * Instantiates a new table testing spec builder.
+ */
+ public TableTestingSpecBuilder() {
+ spec = new TableTestingSpec();
+ }
+
+ /**
+ * Database.
+ *
+ * @param database the database
+ * @return the table testing spec builder
+ */
+ public TableTestingSpecBuilder database(String database) {
+ spec.db = database;
+ return this;
+ }
+
+ /**
+ * Set the input table
+ *
+ * @param table the table
+ * @return the table testing spec builder
+ */
+ public TableTestingSpecBuilder inputTable(String table) {
+ spec.inputTable = table;
+ return this;
+ }
+
+ /**
+ * Partition filter for input table
+ *
+ * @param partFilter the part filter
+ * @return the table testing spec builder
+ */
+ public TableTestingSpecBuilder partitionFilter(String partFilter) {
+ spec.partitionFilter = partFilter;
+ return this;
+ }
+
+ /**
+ * Feature columns.
+ *
+ * @param featureColumns the feature columns
+ * @return the table testing spec builder
+ */
+ public TableTestingSpecBuilder featureColumns(List<String> featureColumns) {
+ spec.featureColumns = featureColumns;
+ return this;
+ }
+
+ /**
+ * Label column.
+ *
+ * @param labelColumn the label column
+ * @return the table testing spec builder
+ */
+ public TableTestingSpecBuilder lableColumn(String labelColumn) {
+ spec.labelColumn = labelColumn;
+ return this;
+ }
+
+ /**
+ * Output column.
+ *
+ * @param outputColumn the output column
+ * @return the table testing spec builder
+ */
+ public TableTestingSpecBuilder outputColumn(String outputColumn) {
+ spec.outputColumn = outputColumn;
+ return this;
+ }
+
+ /**
+ * Output table.
+ *
+ * @param table the table
+ * @return the table testing spec builder
+ */
+ public TableTestingSpecBuilder outputTable(String table) {
+ spec.outputTable = table;
+ return this;
+ }
+
+ /**
+ * Hive conf.
+ *
+ * @param conf the conf
+ * @return the table testing spec builder
+ */
+ public TableTestingSpecBuilder hiveConf(HiveConf conf) {
+ spec.conf = conf;
+ return this;
+ }
+
+ /**
+ * Algorithm.
+ *
+ * @param algorithm the algorithm
+ * @return the table testing spec builder
+ */
+ public TableTestingSpecBuilder algorithm(String algorithm) {
+ spec.algorithm = algorithm;
+ return this;
+ }
+
+ /**
+ * Model id.
+ *
+ * @param modelID the model id
+ * @return the table testing spec builder
+ */
+ public TableTestingSpecBuilder modelID(String modelID) {
+ spec.modelID = modelID;
+ return this;
+ }
+
+ /**
+ * Builds the table testing spec.
+ *
+ * @return the table testing spec
+ */
+ public TableTestingSpec build() {
+ return spec;
+ }
+
+ /**
+ * Set the unique test id
+ *
+ * @param testID the test id
+ * @return the table testing spec builder
+ */
+ public TableTestingSpecBuilder testID(String testID) {
+ spec.testID = testID;
+ return this;
+ }
+ }
+
+ /**
+ * New builder.
+ *
+ * @return the table testing spec builder
+ */
+ public static TableTestingSpecBuilder newBuilder() {
+ return new TableTestingSpecBuilder();
+ }
+
+  /**
+   * Validate. Checks that output column, output table and feature columns are set, that the input table
+   * can be read from the metastore, and that it contains the feature and label columns. As a side effect
+   * this fills the column-name to schema map and the {@code outputTableExists} flag.
+   *
+   * @return true, if successful
+   */
+  public boolean validate() {
+    // Required settings are checked first so the metastore is never queried with a blank table name
+    if (StringUtils.isBlank(outputColumn)) {
+      LOG.info("Output column is required");
+      return false;
+    }
+
+    if (StringUtils.isBlank(outputTable)) {
+      LOG.info("Output table is required");
+      return false;
+    }
+
+    // Without feature columns the generated test query would be malformed
+    if (featureColumns == null || featureColumns.isEmpty()) {
+      LOG.info("Feature columns are required");
+      return false;
+    }
+
+    List<FieldSchema> columns;
+    try {
+      Hive metastoreClient = Hive.get(conf);
+      // NOTE(review): with no db the input table lookup presumably uses the current database, while the
+      // output table lookup below uses "default" - confirm that this difference is intended
+      Table tbl = (db == null) ? metastoreClient.getTable(inputTable) : metastoreClient.getTable(db, inputTable);
+      columns = tbl.getAllCols();
+      columnNameToFieldSchema = new HashMap<String, FieldSchema>();
+
+      for (FieldSchema fieldSchema : columns) {
+        columnNameToFieldSchema.put(fieldSchema.getName(), fieldSchema);
+      }
+
+      // Check if output table exists
+      Table outTbl = metastoreClient.getTable(db == null ? "default" : db, outputTable, false);
+      outputTableExists = (outTbl != null);
+    } catch (HiveException exc) {
+      // Name the tables involved; this class has no toString() of its own to describe the spec
+      LOG.error("Error getting table info for db: " + db + " input table: " + inputTable + " output table: "
+        + outputTable, exc);
+      return false;
+    }
+
+    // Check if labeled column and feature columns are contained in the table
+    List<String> testTableColumns = new ArrayList<String>(columns.size());
+    for (FieldSchema column : columns) {
+      testTableColumns.add(column.getName());
+    }
+
+    if (!testTableColumns.containsAll(featureColumns)) {
+      LOG.info("Invalid feature columns: " + featureColumns + ". Actual columns in table:" + testTableColumns);
+      return false;
+    }
+
+    if (!testTableColumns.contains(labelColumn)) {
+      LOG.info("Invalid label column: " + labelColumn + ". Actual columns in table:" + testTableColumns);
+      return false;
+    }
+
+    return true;
+  }
+
+ public String getTestQuery() {
+ if (!validate()) {
+ return null;
+ }
+
+ // We always insert a dynamic partition
+ StringBuilder q = new StringBuilder("INSERT OVERWRITE TABLE " + outputTable + " PARTITION (part_testid='" + testID
+ + "') SELECT ");
+ String featureCols = StringUtils.join(featureColumns, ",");
+ q.append(featureCols).append(",").append(labelColumn).append(", ").append("predict(").append("'").append(algorithm)
+ .append("', ").append("'").append(modelID).append("', ").append(featureCols).append(") ").append(outputColumn)
+ .append(" FROM ").append(inputTable);
+
+ return q.toString();
+ }
+
+  /**
+   * Builds the {@code CREATE TABLE IF NOT EXISTS} statement for the output table. Column types of the
+   * feature and label columns are taken from the column schema map that {@link #validate()} fills, so
+   * this must only be called after a {@code validate()} call that reached the input table.
+   *
+   * @return the create table query
+   */
+  public String getCreateOutputTableQuery() {
+    // Output table contains feature columns, label column, output column
+    List<String> columnDefs = new ArrayList<String>();
+    for (String featureCol : featureColumns) {
+      String featureType = columnNameToFieldSchema.get(featureCol).getType();
+      columnDefs.add(featureCol + " " + featureType);
+    }
+    String labelType = columnNameToFieldSchema.get(labelColumn).getType();
+    columnDefs.add(labelColumn + " " + labelType);
+    columnDefs.add(outputColumn + " string");
+
+    // The column list is followed by the partition column
+    return "CREATE TABLE IF NOT EXISTS " + outputTable + "(" + StringUtils.join(columnDefs, ", ")
+      + ") PARTITIONED BY (part_testid string)";
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/server/MLApp.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/server/MLApp.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/server/MLApp.java
new file mode 100644
index 0000000..e6e3c02
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/server/MLApp.java
@@ -0,0 +1,60 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.server;
+
+import java.util.HashSet;
+import java.util.Set;
+
+import javax.ws.rs.ApplicationPath;
+import javax.ws.rs.core.Application;
+
+import org.glassfish.jersey.filter.LoggingFilter;
+import org.glassfish.jersey.media.multipart.MultiPartFeature;
+
+/**
+ * JAX-RS application for the ML resources, rooted at the "/ml" application path.
+ */
+@ApplicationPath("/ml")
+public class MLApp extends Application {
+
+  /** Resource, feature and filter classes exposed by this application. */
+  private final Set<Class<?>> classes = new HashSet<Class<?>>();
+
+  /**
+   * Registers the ML resource together with multipart support and request logging, followed by any
+   * extra classes supplied by the caller (used when running in test mode).
+   *
+   * @param additionalClasses extra classes to register
+   */
+  public MLApp(Class<?>... additionalClasses) {
+    // register root resource
+    classes.add(MLServiceResource.class);
+    classes.add(MultiPartFeature.class);
+    classes.add(LoggingFilter.class);
+
+    for (int i = 0; i < additionalClasses.length; i++) {
+      classes.add(additionalClasses[i]);
+    }
+  }
+
+  /**
+   * Get classes for this resource
+   *
+   * @return the registered classes
+   */
+  @Override
+  public Set<Class<?>> getClasses() {
+    return classes;
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/server/MLService.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/server/MLService.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/server/MLService.java
new file mode 100644
index 0000000..f8b7cd1
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/server/MLService.java
@@ -0,0 +1,27 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.server;
+
+import org.apache.lens.ml.api.LensML;
+
+/**
+ * Server-side view of the ML API.
+ *
+ * <p>This interface declares no methods of its own; it only extends {@link LensML}, so every
+ * operation available on an {@code MLService} is the one defined by {@code LensML}.
+ */
+public interface MLService extends LensML {
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/server/MLServiceImpl.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/server/MLServiceImpl.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/server/MLServiceImpl.java
new file mode 100644
index 0000000..f3e8ec1
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/server/MLServiceImpl.java
@@ -0,0 +1,329 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.server;
+
+import java.util.List;
+import java.util.Map;
+
+import org.apache.lens.api.LensConf;
+import org.apache.lens.api.LensException;
+import org.apache.lens.api.LensSessionHandle;
+import org.apache.lens.api.query.LensQuery;
+import org.apache.lens.api.query.QueryHandle;
+import org.apache.lens.api.query.QueryStatus;
+import org.apache.lens.ml.algo.api.*;
+import org.apache.lens.ml.api.MLTestReport;
+import org.apache.lens.ml.impl.HiveMLUDF;
+import org.apache.lens.ml.impl.LensMLImpl;
+import org.apache.lens.ml.impl.ModelLoader;
+import org.apache.lens.ml.impl.QueryRunner;
+import org.apache.lens.server.api.LensConfConstants;
+import org.apache.lens.server.api.ServiceProvider;
+import org.apache.lens.server.api.ServiceProviderFactory;
+import org.apache.lens.server.api.query.QueryExecutionService;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.hive.conf.HiveConf;
+import org.apache.hadoop.hive.ql.exec.FunctionRegistry;
+import org.apache.hive.service.CompositeService;
+
+/**
+ * {@link MLService} implementation that runs inside the lens server as a {@link CompositeService}.
+ *
+ * <p>All ML operations are delegated to an internal {@link LensMLImpl}, which is created in
+ * {@link #init(HiveConf)} and started/stopped together with this service. Test queries issued by
+ * {@code testModel} are executed in-process through the server's query execution service.
+ */
+public class MLServiceImpl extends CompositeService implements MLService {
+
+  /** Logger for this service. */
+  public static final Log LOG = LogFactory.getLog(MLServiceImpl.class);
+
+  /** Interval, in milliseconds, between two status polls while waiting for a test query. */
+  private static final long QUERY_POLL_INTERVAL_MILLIS = 500;
+
+  /** The ML implementation every call is delegated to; set in {@link #init(HiveConf)}. */
+  private LensMLImpl ml;
+
+  /** The service provider, resolved lazily from the factory on first use. */
+  private ServiceProvider serviceProvider;
+
+  /** The service provider factory; set in {@link #init(HiveConf)}. */
+  private ServiceProviderFactory serviceProviderFactory;
+
+  /**
+   * Instantiates a new ML service impl under the default service name {@code NAME}.
+   */
+  public MLServiceImpl() {
+    this(NAME);
+  }
+
+  /**
+   * Instantiates a new ML service impl.
+   *
+   * @param name the service name passed on to {@link CompositeService}
+   */
+  public MLServiceImpl(String name) {
+    super(name);
+  }
+
+  /*
+   * (non-Javadoc)
+   *
+   * @see org.apache.lens.ml.api.LensML#getAlgorithms()
+   */
+  @Override
+  public List<String> getAlgorithms() {
+    return ml.getAlgorithms();
+  }
+
+  /*
+   * (non-Javadoc)
+   *
+   * @see org.apache.lens.ml.api.LensML#getAlgoForName(java.lang.String)
+   */
+  @Override
+  public MLAlgo getAlgoForName(String algorithm) throws LensException {
+    return ml.getAlgoForName(algorithm);
+  }
+
+  /*
+   * (non-Javadoc)
+   *
+   * @see org.apache.lens.ml.api.LensML#train(java.lang.String, java.lang.String, java.lang.String[])
+   */
+  @Override
+  public String train(String table, String algorithm, String[] args) throws LensException {
+    return ml.train(table, algorithm, args);
+  }
+
+  /*
+   * (non-Javadoc)
+   *
+   * @see org.apache.lens.ml.api.LensML#getModels(java.lang.String)
+   */
+  @Override
+  public List<String> getModels(String algorithm) throws LensException {
+    return ml.getModels(algorithm);
+  }
+
+  /*
+   * (non-Javadoc)
+   *
+   * @see org.apache.lens.ml.api.LensML#getModel(java.lang.String, java.lang.String)
+   */
+  @Override
+  public MLModel getModel(String algorithm, String modelId) throws LensException {
+    return ml.getModel(algorithm, modelId);
+  }
+
+  /**
+   * Gets the service provider, asking the factory for it on first use and caching the result.
+   *
+   * @return the service provider
+   */
+  private ServiceProvider getServiceProvider() {
+    if (serviceProvider == null) {
+      serviceProvider = serviceProviderFactory.getServiceProvider();
+    }
+    return serviceProvider;
+  }
+
+  /**
+   * Gets the service provider factory.
+   *
+   * <p>The factory class is read from {@link LensConfConstants#SERVICE_PROVIDER_FACTORY} and
+   * instantiated through its no-argument constructor.
+   *
+   * @param conf the conf
+   * @return the service provider factory
+   * @throws RuntimeException if the configured class cannot be instantiated or accessed
+   */
+  private ServiceProviderFactory getServiceProviderFactory(HiveConf conf) {
+    Class<?> spfClass = conf.getClass(LensConfConstants.SERVICE_PROVIDER_FACTORY, ServiceProviderFactory.class);
+    try {
+      return (ServiceProviderFactory) spfClass.newInstance();
+    } catch (InstantiationException e) {
+      throw new RuntimeException("Could not instantiate service provider factory " + spfClass.getName(), e);
+    } catch (IllegalAccessException e) {
+      throw new RuntimeException("Could not access service provider factory " + spfClass.getName(), e);
+    }
+  }
+
+  /*
+   * (non-Javadoc)
+   *
+   * @see org.apache.hive.service.CompositeService#init(org.apache.hadoop.hive.conf.HiveConf)
+   */
+  @Override
+  public synchronized void init(HiveConf hiveConf) {
+    ml = new LensMLImpl(hiveConf);
+    ml.init(hiveConf);
+    super.init(hiveConf);
+    serviceProviderFactory = getServiceProviderFactory(hiveConf);
+    LOG.info("Inited ML service");
+  }
+
+  /*
+   * (non-Javadoc)
+   *
+   * @see org.apache.hive.service.CompositeService#start()
+   */
+  @Override
+  public synchronized void start() {
+    ml.start();
+    super.start();
+    LOG.info("Started ML service");
+  }
+
+  /*
+   * (non-Javadoc)
+   *
+   * @see org.apache.hive.service.CompositeService#stop()
+   */
+  @Override
+  public synchronized void stop() {
+    ml.stop();
+    super.stop();
+    LOG.info("Stopped ML service");
+  }
+
+  /**
+   * Clear models by clearing the {@link ModelLoader} cache.
+   */
+  public void clearModels() {
+    ModelLoader.clearCache();
+  }
+
+  /*
+   * (non-Javadoc)
+   *
+   * @see org.apache.lens.ml.api.LensML#getModelPath(java.lang.String, java.lang.String)
+   */
+  @Override
+  public String getModelPath(String algorithm, String modelID) {
+    return ml.getModelPath(algorithm, modelID);
+  }
+
+  /*
+   * (non-Javadoc)
+   *
+   * The test query is run in the current lens server process through a DirectQueryRunner.
+   *
+   * @see org.apache.lens.ml.api.LensML#testModel(org.apache.lens.api.LensSessionHandle, java.lang.String,
+   * java.lang.String, java.lang.String, java.lang.String)
+   */
+  @Override
+  public MLTestReport testModel(LensSessionHandle sessionHandle, String table, String algorithm, String modelID,
+    String outputTable) throws LensException {
+    return ml.testModel(sessionHandle, table, algorithm, modelID, new DirectQueryRunner(sessionHandle), outputTable);
+  }
+
+  /*
+   * (non-Javadoc)
+   *
+   * @see org.apache.lens.ml.api.LensML#getTestReports(java.lang.String)
+   */
+  @Override
+  public List<String> getTestReports(String algorithm) throws LensException {
+    return ml.getTestReports(algorithm);
+  }
+
+  /*
+   * (non-Javadoc)
+   *
+   * @see org.apache.lens.ml.api.LensML#getTestReport(java.lang.String, java.lang.String)
+   */
+  @Override
+  public MLTestReport getTestReport(String algorithm, String reportID) throws LensException {
+    return ml.getTestReport(algorithm, reportID);
+  }
+
+  /*
+   * (non-Javadoc)
+   *
+   * @see org.apache.lens.ml.api.LensML#predict(java.lang.String, java.lang.String, java.lang.Object[])
+   */
+  @Override
+  public Object predict(String algorithm, String modelID, Object[] features) throws LensException {
+    return ml.predict(algorithm, modelID, features);
+  }
+
+  /*
+   * (non-Javadoc)
+   *
+   * @see org.apache.lens.ml.api.LensML#deleteModel(java.lang.String, java.lang.String)
+   */
+  @Override
+  public void deleteModel(String algorithm, String modelID) throws LensException {
+    ml.deleteModel(algorithm, modelID);
+  }
+
+  /*
+   * (non-Javadoc)
+   *
+   * @see org.apache.lens.ml.api.LensML#deleteTestReport(java.lang.String, java.lang.String)
+   */
+  @Override
+  public void deleteTestReport(String algorithm, String reportID) throws LensException {
+    ml.deleteTestReport(algorithm, reportID);
+  }
+
+  /**
+   * Run the test model query directly in the current lens server process.
+   */
+  private class DirectQueryRunner extends QueryRunner {
+
+    /**
+     * Instantiates a new direct query runner.
+     *
+     * @param sessionHandle the session handle
+     */
+    public DirectQueryRunner(LensSessionHandle sessionHandle) {
+      super(sessionHandle);
+    }
+
+    /**
+     * Registers the {@code predict} UDF, submits the test query asynchronously to the query
+     * execution service with persistent result sets switched off, and polls until it finishes.
+     *
+     * @param testQuery the query to run
+     * @return the handle of the successfully completed query
+     * @throws LensException if the wait is interrupted or the query does not finish successfully
+     * @see org.apache.lens.ml.impl.QueryRunner#runQuery(java.lang.String)
+     */
+    @Override
+    public QueryHandle runQuery(String testQuery) throws LensException {
+      FunctionRegistry.registerTemporaryFunction("predict", HiveMLUDF.class);
+      LOG.info("Registered predict UDF");
+      // Run the query in query executions service
+      QueryExecutionService queryService = (QueryExecutionService) getServiceProvider().getService("query");
+
+      LensConf queryConf = new LensConf();
+      queryConf.addProperty(LensConfConstants.QUERY_PERSISTENT_RESULT_SET, Boolean.toString(false));
+      queryConf.addProperty(LensConfConstants.QUERY_PERSISTENT_RESULT_INDRIVER, Boolean.toString(false));
+
+      QueryHandle testQueryHandle = queryService.executeAsync(sessionHandle, testQuery, queryConf, queryName);
+
+      // Wait for test query to complete
+      LensQuery query = queryService.getQuery(sessionHandle, testQueryHandle);
+      LOG.info("Submitted query " + testQueryHandle.getHandleId());
+      while (!query.getStatus().isFinished()) {
+        try {
+          Thread.sleep(QUERY_POLL_INTERVAL_MILLIS);
+        } catch (InterruptedException e) {
+          // Restore the interrupt flag so callers further up can still observe the interruption.
+          Thread.currentThread().interrupt();
+          throw new LensException(e);
+        }
+
+        query = queryService.getQuery(sessionHandle, testQueryHandle);
+      }
+
+      if (query.getStatus().getStatus() != QueryStatus.Status.SUCCESSFUL) {
+        throw new LensException("Failed to run test query: " + testQueryHandle.getHandleId() + " reason= "
+          + query.getStatus().getErrorMessage());
+      }
+
+      return testQueryHandle;
+    }
+  }
+
+  /*
+   * (non-Javadoc)
+   *
+   * @see org.apache.lens.ml.api.LensML#getAlgoParamDescription(java.lang.String)
+   */
+  @Override
+  public Map<String, String> getAlgoParamDescription(String algorithm) {
+    return ml.getAlgoParamDescription(algorithm);
+  }
+}