You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lens.apache.org by am...@apache.org on 2015/04/15 21:49:59 UTC
[25/50] [abbrv] incubator-lens git commit: Lens-465 : Refactor ml
packages. (sharad)
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/server/MLServiceResource.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/server/MLServiceResource.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/server/MLServiceResource.java
new file mode 100644
index 0000000..f9c954e
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/server/MLServiceResource.java
@@ -0,0 +1,427 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.server;
+
+import static org.apache.commons.lang.StringUtils.isBlank;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+import javax.ws.rs.BadRequestException;
+import javax.ws.rs.Consumes;
+import javax.ws.rs.DELETE;
+import javax.ws.rs.GET;
+import javax.ws.rs.NotFoundException;
+import javax.ws.rs.POST;
+import javax.ws.rs.Path;
+import javax.ws.rs.PathParam;
+import javax.ws.rs.Produces;
+import javax.ws.rs.core.Context;
+import javax.ws.rs.core.MediaType;
+import javax.ws.rs.core.MultivaluedMap;
+import javax.ws.rs.core.Response;
+import javax.ws.rs.core.UriInfo;
+
+import org.apache.lens.api.LensException;
+import org.apache.lens.api.LensSessionHandle;
+import org.apache.lens.api.StringList;
+import org.apache.lens.ml.algo.api.MLModel;
+import org.apache.lens.ml.api.MLTestReport;
+import org.apache.lens.ml.api.ModelMetadata;
+import org.apache.lens.ml.api.TestReport;
+import org.apache.lens.ml.impl.ModelLoader;
+import org.apache.lens.server.api.LensConfConstants;
+import org.apache.lens.server.api.ServiceProvider;
+import org.apache.lens.server.api.ServiceProviderFactory;
+
+import org.apache.commons.lang.StringUtils;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.hive.conf.HiveConf;
+
+import org.glassfish.jersey.media.multipart.FormDataParam;
+
+/**
+ * Machine Learning service.
+ * <p>
+ * REST resource mounted at {@code /ml} that exposes algorithms, models, test reports and predictions. All calls are
+ * delegated to the {@link MLService} obtained from the configured {@link ServiceProviderFactory}.
+ * </p>
+ */
+@Path("/ml")
+@Produces({MediaType.APPLICATION_JSON, MediaType.APPLICATION_XML})
+public class MLServiceResource {
+
+  /** The Constant LOG. */
+  public static final Log LOG = LogFactory.getLog(MLServiceResource.class);
+
+  /** The ml service; resolved lazily on first use by {@link #getMlService()}. */
+  MLService mlService;
+
+  /** The service provider; resolved lazily on first use by {@link #getServiceProvider()}. */
+  ServiceProvider serviceProvider;
+
+  /** The service provider factory, created from {@link #HIVE_CONF} in the constructor. */
+  ServiceProviderFactory serviceProviderFactory;
+
+  /** Configuration (server defaults plus lens-site.xml) used to pick the service provider factory implementation. */
+  private static final HiveConf HIVE_CONF;
+
+  /**
+   * Message indicating if ML service is up
+   */
+  public static final String ML_UP_MESSAGE = "ML service is up";
+
+  static {
+    HIVE_CONF = new HiveConf();
+    // Add default config so that we know the service provider implementation
+    HIVE_CONF.addResource("lensserver-default.xml");
+    HIVE_CONF.addResource("lens-site.xml");
+  }
+
+  /**
+   * Instantiates a new ML service resource.
+   */
+  public MLServiceResource() {
+    serviceProviderFactory = getServiceProviderFactory(HIVE_CONF);
+  }
+
+  /**
+   * Gets the service provider, obtaining it from the factory on first call.
+   *
+   * @return the service provider
+   */
+  private ServiceProvider getServiceProvider() {
+    if (serviceProvider == null) {
+      serviceProvider = serviceProviderFactory.getServiceProvider();
+    }
+    return serviceProvider;
+  }
+
+  /**
+   * Gets the service provider factory.
+   *
+   * @param conf the conf from which {@link LensConfConstants#SERVICE_PROVIDER_FACTORY} is read
+   * @return a new instance of the configured service provider factory class
+   * @throws RuntimeException if the configured class cannot be instantiated
+   */
+  private ServiceProviderFactory getServiceProviderFactory(HiveConf conf) {
+    Class<?> spfClass = conf.getClass(LensConfConstants.SERVICE_PROVIDER_FACTORY, ServiceProviderFactory.class);
+    try {
+      return (ServiceProviderFactory) spfClass.newInstance();
+    } catch (InstantiationException e) {
+      throw new RuntimeException("Could not instantiate service provider factory " + spfClass.getName(), e);
+    } catch (IllegalAccessException e) {
+      throw new RuntimeException("Could not instantiate service provider factory " + spfClass.getName(), e);
+    }
+  }
+
+  /**
+   * Gets the ML service, looking it up from the service provider on first call.
+   *
+   * @return the ML service
+   */
+  private MLService getMlService() {
+    if (mlService == null) {
+      mlService = (MLService) getServiceProvider().getService(MLService.NAME);
+    }
+    return mlService;
+  }
+
+  /**
+   * Indicates if ML resource is up
+   *
+   * @return {@link #ML_UP_MESSAGE}
+   */
+  @GET
+  public String mlResourceUp() {
+    return ML_UP_MESSAGE;
+  }
+
+  /**
+   * Get a list of algos available
+   *
+   * @return names of the algorithms known to the ML service
+   */
+  @GET
+  @Path("algos")
+  public StringList getAlgoNames() {
+    List<String> algos = getMlService().getAlgorithms();
+    return new StringList(algos);
+  }
+
+  /**
+   * Gets the human readable param description of an algorithm
+   *
+   * @param algorithm the algorithm
+   * @return the param description, one {@code "name : description"} entry per parameter
+   * @throws NotFoundException if the ML service has no description for the algorithm
+   */
+  @GET
+  @Path("algos/{algorithm}")
+  public StringList getParamDescription(@PathParam("algorithm") String algorithm) {
+    Map<String, String> paramDesc = getMlService().getAlgoParamDescription(algorithm);
+    if (paramDesc == null) {
+      throw new NotFoundException("Param description not found for " + algorithm);
+    }
+
+    List<String> descriptions = new ArrayList<String>(paramDesc.size());
+    for (Map.Entry<String, String> entry : paramDesc.entrySet()) {
+      descriptions.add(entry.getKey() + " : " + entry.getValue());
+    }
+    return new StringList(descriptions);
+  }
+
+  /**
+   * Get model ID list for a given algorithm.
+   *
+   * @param algorithm algorithm name
+   * @return the models for algo
+   * @throws LensException the lens exception
+   * @throws NotFoundException if no models exist for the algorithm
+   */
+  @GET
+  @Path("models/{algorithm}")
+  public StringList getModelsForAlgo(@PathParam("algorithm") String algorithm) throws LensException {
+    List<String> models = getMlService().getModels(algorithm);
+    if (models == null || models.isEmpty()) {
+      throw new NotFoundException("No models found for algorithm " + algorithm);
+    }
+    return new StringList(models);
+  }
+
+  /**
+   * Get metadata of the model given algorithm and model ID.
+   *
+   * @param algorithm algorithm name
+   * @param modelID model ID
+   * @return model metadata
+   * @throws LensException the lens exception
+   * @throws NotFoundException if the model does not exist
+   */
+  @GET
+  @Path("models/{algorithm}/{modelID}")
+  public ModelMetadata getModelMetadata(@PathParam("algorithm") String algorithm, @PathParam("modelID") String modelID)
+    throws LensException {
+    MLModel<?> model = getMlService().getModel(algorithm, modelID);
+    if (model == null) {
+      throw new NotFoundException("Model not found " + modelID + ", algo=" + algorithm);
+    }
+
+    return new ModelMetadata(model.getId(), model.getTable(), model.getAlgoName(),
+      StringUtils.join(model.getParams(), ' '), model.getCreatedAt().toString(),
+      getMlService().getModelPath(algorithm, modelID), model.getLabelColumn(),
+      StringUtils.join(model.getFeatureColumns(), ","));
+  }
+
+  /**
+   * Delete a model given model ID and algorithm name.
+   *
+   * @param algorithm the algorithm
+   * @param modelID the model id
+   * @return confirmation text
+   * @throws LensException the lens exception
+   */
+  @DELETE
+  @Consumes({MediaType.APPLICATION_JSON, MediaType.APPLICATION_XML, MediaType.TEXT_PLAIN})
+  @Path("models/{algorithm}/{modelID}")
+  public String deleteModel(@PathParam("algorithm") String algorithm, @PathParam("modelID") String modelID)
+    throws LensException {
+    getMlService().deleteModel(algorithm, modelID);
+    return "DELETED model=" + modelID + " algorithm=" + algorithm;
+  }
+
+  /**
+   * Train a model given an algorithm name and algorithm parameters
+   * <p>
+   * Following parameters are mandatory and must be passed as part of the form
+   * </p>
+   * <ol>
+   * <li>table - input Hive table to load training data from</li>
+   * <li>label - name of the labelled column</li>
+   * <li>feature - one entry per feature column. At least one feature column is required</li>
+   * </ol>
+   * <p>
+   * Every other form parameter except {@code algorithm} is forwarded to the algorithm as a name/value pair, using
+   * only its first value.
+   * </p>
+   *
+   * @param algorithm algorithm name
+   * @param form form data
+   * @return if model is successfully trained, the model ID will be returned
+   * @throws LensException the lens exception
+   * @throws NotFoundException if the algorithm is unknown
+   * @throws BadRequestException if table, label or feature parameters are missing
+   */
+  @POST
+  @Consumes(MediaType.APPLICATION_FORM_URLENCODED)
+  @Path("{algorithm}/train")
+  public String train(@PathParam("algorithm") String algorithm, MultivaluedMap<String, String> form)
+    throws LensException {
+
+    // Check if algo is valid
+    if (getMlService().getAlgoForName(algorithm) == null) {
+      throw new NotFoundException("Algo for algo: " + algorithm + " not found");
+    }
+
+    String table = form.getFirst("table");
+    if (isBlank(table)) {
+      throw new BadRequestException("table parameter is required");
+    }
+
+    if (isBlank(form.getFirst("label"))) {
+      throw new BadRequestException("label parameter is required");
+    }
+
+    // Check features. MultivaluedMap.get returns null when no "feature" entry was submitted at all,
+    // which must be reported as a bad request rather than surfacing as a NullPointerException.
+    List<String> featureNames = form.get("feature");
+    if (featureNames == null || featureNames.isEmpty()) {
+      throw new BadRequestException("At least one feature is required");
+    }
+
+    // Flatten the form into alternating name/value arguments for the algorithm.
+    List<String> algoArgs = new ArrayList<String>();
+    Set<Map.Entry<String, List<String>>> paramSet = form.entrySet();
+
+    for (Map.Entry<String, List<String>> e : paramSet) {
+      String p = e.getKey();
+      List<String> values = e.getValue();
+      if ("algorithm".equals(p) || "table".equals(p)) {
+        continue;
+      }
+      if ("feature".equals(p)) {
+        // Features are multi-valued: forward each one as its own "feature" argument
+        for (String feature : values) {
+          algoArgs.add("feature");
+          algoArgs.add(feature);
+        }
+      } else {
+        // "label" and algorithm-specific params are single-valued: forward the first value only
+        algoArgs.add(p);
+        algoArgs.add(values.get(0));
+      }
+    }
+    LOG.info("Training table " + table + " with algo " + algorithm + " params=" + algoArgs.toString());
+    String modelId = getMlService().train(table, algorithm, algoArgs.toArray(new String[]{}));
+    LOG.info("Done training " + table + " modelid = " + modelId);
+    return modelId;
+  }
+
+  /**
+   * Clear model cache (for admin use).
+   *
+   * @return OK if the cache was cleared
+   */
+  @DELETE
+  @Path("clearModelCache")
+  @Produces(MediaType.TEXT_PLAIN)
+  public Response clearModelCache() {
+    ModelLoader.clearCache();
+    LOG.info("Cleared model cache");
+    return Response.ok("Cleared cache", MediaType.TEXT_PLAIN_TYPE).build();
+  }
+
+  /**
+   * Run a test on a model for an algorithm.
+   *
+   * @param algorithm algorithm name
+   * @param modelID model ID
+   * @param table Hive table to run test on
+   * @param session Lens session ID. This session ID will be used to run the test query
+   * @param outputTable table the test output is written to
+   * @return Test report ID
+   * @throws LensException the lens exception
+   */
+  @POST
+  @Path("test/{table}/{algorithm}/{modelID}")
+  @Consumes(MediaType.MULTIPART_FORM_DATA)
+  public String test(@PathParam("algorithm") String algorithm, @PathParam("modelID") String modelID,
+    @PathParam("table") String table, @FormDataParam("sessionid") LensSessionHandle session,
+    @FormDataParam("outputTable") String outputTable) throws LensException {
+    MLTestReport testReport = getMlService().testModel(session, table, algorithm, modelID, outputTable);
+    return testReport.getReportID();
+  }
+
+  /**
+   * Get list of reports for a given algorithm.
+   *
+   * @param algorithm the algorithm
+   * @return the reports for algorithm
+   * @throws LensException the lens exception
+   * @throws NotFoundException if no reports exist for the algorithm
+   */
+  @GET
+  @Path("reports/{algorithm}")
+  public StringList getReportsForAlgorithm(@PathParam("algorithm") String algorithm) throws LensException {
+    List<String> reports = getMlService().getTestReports(algorithm);
+    if (reports == null || reports.isEmpty()) {
+      throw new NotFoundException("No test reports found for " + algorithm);
+    }
+    return new StringList(reports);
+  }
+
+  /**
+   * Get a single test report given the algorithm name and report id.
+   *
+   * @param algorithm the algorithm
+   * @param reportID the report id
+   * @return the test report
+   * @throws LensException the lens exception
+   * @throws NotFoundException if the report does not exist
+   */
+  @GET
+  @Path("reports/{algorithm}/{reportID}")
+  public TestReport getTestReport(@PathParam("algorithm") String algorithm, @PathParam("reportID") String reportID)
+    throws LensException {
+    MLTestReport report = getMlService().getTestReport(algorithm, reportID);
+
+    if (report == null) {
+      throw new NotFoundException("Test report: " + reportID + " not found for algorithm " + algorithm);
+    }
+
+    return new TestReport(report.getTestTable(), report.getOutputTable(), report.getOutputColumn(),
+      report.getLabelColumn(), StringUtils.join(report.getFeatureColumns(), ","), report.getAlgorithm(),
+      report.getModelID(), report.getReportID(), report.getLensQueryID());
+  }
+
+  /**
+   * DELETE a report given the algorithm name and report ID.
+   *
+   * @param algorithm the algorithm
+   * @param reportID the report id
+   * @return confirmation text
+   * @throws LensException the lens exception
+   */
+  @DELETE
+  @Path("reports/{algorithm}/{reportID}")
+  @Consumes({MediaType.APPLICATION_JSON, MediaType.APPLICATION_XML, MediaType.TEXT_PLAIN})
+  public String deleteTestReport(@PathParam("algorithm") String algorithm, @PathParam("reportID") String reportID)
+    throws LensException {
+    getMlService().deleteTestReport(algorithm, reportID);
+    return "DELETED report=" + reportID + " algorithm=" + algorithm;
+  }
+
+  /**
+   * Predict using a trained model.
+   * <p>
+   * For each feature column of the model, the value is read from the query parameter with the same name. A missing
+   * parameter is passed to the model as {@code null}.
+   * </p>
+   *
+   * @param algorithm the algorithm
+   * @param modelID the model id
+   * @param uriInfo the uri info carrying the feature values as query parameters
+   * @return the prediction rendered with {@code toString()}
+   * @throws LensException the lens exception
+   * @throws NotFoundException if no model exists for the given algorithm and model ID
+   */
+  @GET
+  @Path("/predict/{algorithm}/{modelID}")
+  // NOTE(review): APPLICATION_ATOM_XML looks like it may have been intended as APPLICATION_XML -- confirm.
+  @Produces({MediaType.APPLICATION_ATOM_XML, MediaType.APPLICATION_JSON})
+  public String predict(@PathParam("algorithm") String algorithm, @PathParam("modelID") String modelID,
+    @Context UriInfo uriInfo) throws LensException {
+    // Load the model instance
+    MLModel<?> model = getMlService().getModel(algorithm, modelID);
+    if (model == null) {
+      throw new NotFoundException("Model not found " + modelID + ", algo=" + algorithm);
+    }
+
+    // Get input feature values; query parameter names are assumed to match the model's feature column names
+    MultivaluedMap<String, String> params = uriInfo.getQueryParameters();
+    String[] features = new String[model.getFeatureColumns().size()];
+    int i = 0;
+    for (String feature : model.getFeatureColumns()) {
+      features[i++] = params.getFirst(feature);
+    }
+
+    // TODO needs a 'prediction formatter'
+    return getMlService().predict(algorithm, modelID, features).toString();
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/ColumnFeatureFunction.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/ColumnFeatureFunction.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/ColumnFeatureFunction.java
deleted file mode 100644
index abdad68..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/ColumnFeatureFunction.java
+++ /dev/null
@@ -1,102 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml.spark;
-
-import org.apache.hadoop.io.WritableComparable;
-import org.apache.hive.hcatalog.data.HCatRecord;
-import org.apache.log4j.Logger;
-import org.apache.spark.mllib.linalg.Vectors;
-import org.apache.spark.mllib.regression.LabeledPoint;
-
-import com.google.common.base.Preconditions;
-import scala.Tuple2;
-
-/**
- * A feature function that directly maps an HCatRecord to a feature vector. Each column becomes a feature in the vector,
- * with the value of the feature obtained using the value mapper for that column
- */
-public class ColumnFeatureFunction extends FeatureFunction {
-
- /** The Constant LOG. */
- public static final Logger LOG = Logger.getLogger(ColumnFeatureFunction.class);
-
- /** Value mappers, parallel to featurePositions: featureValueMappers[i] converts the ith feature to a double. */
- private final FeatureValueMapper[] featureValueMappers;
-
- /** featurePositions[i] is the HCatRecord position of the ith feature. */
- private final int[] featurePositions;
-
- /** Position of the label column in the HCatRecord. */
- private final int labelColumnPos;
-
- /** Length of the produced feature vector; call() reads featurePositions[0..numFeatures-1]. */
- private final int numFeatures;
-
- /** Returned for null records: the default label with an all-zero feature vector. */
- private final LabeledPoint defaultLabeledPoint;
-
- /**
- * Feature positions and value mappers are parallel arrays. featurePositions[i] gives the position of ith feature in
- * the HCatRecord, and valueMappers[i] gives the value mapper used to map that feature to a Double value
- * <p>
- * NOTE(review): numFeatures is not checked against featurePositions.length. If it is larger, call() will throw
- * ArrayIndexOutOfBoundsException.
- *
- * @param featurePositions position number of feature column in the HCatRecord
- * @param valueMappers mapper for each column position
- * @param labelColumnPos position of the label column
- * @param numFeatures number of features in the feature vector
- * @param defaultLabel default label to be used for null records
- */
- public ColumnFeatureFunction(int[] featurePositions, FeatureValueMapper[] valueMappers, int labelColumnPos,
- int numFeatures, double defaultLabel) {
- Preconditions.checkNotNull(valueMappers, "Value mappers argument is required");
- Preconditions.checkNotNull(featurePositions, "Feature positions are required");
- Preconditions.checkArgument(valueMappers.length == featurePositions.length,
- "Mismatch between number of value mappers and feature positions");
-
- this.featurePositions = featurePositions;
- this.featureValueMappers = valueMappers;
- this.labelColumnPos = labelColumnPos;
- this.numFeatures = numFeatures;
- // Shared instance returned for every null record
- defaultLabeledPoint = new LabeledPoint(defaultLabel, Vectors.dense(new double[numFeatures]));
- }
-
- /**
- * Converts one (key, record) pair into a LabeledPoint. A null record yields the shared default labeled point.
- *
- * @see org.apache.lens.ml.spark.FeatureFunction#call(scala.Tuple2)
- */
- @Override
- public LabeledPoint call(Tuple2<WritableComparable, HCatRecord> tuple) throws Exception {
- HCatRecord record = tuple._2();
-
- if (record == null) {
- LOG.info("@@@ Null record");
- return defaultLabeledPoint;
- }
-
- double[] features = new double[numFeatures];
-
- for (int i = 0; i < numFeatures; i++) {
- int featurePos = featurePositions[i];
- features[i] = featureValueMappers[i].call(record.get(featurePos));
- }
-
- // NOTE(review): featureValueMappers is indexed by feature number (it parallels featurePositions, per the
- // constructor precondition), but here it is indexed by labelColumnPos, which is a record column position.
- // This picks an unrelated mapper, or throws ArrayIndexOutOfBoundsException when
- // labelColumnPos >= featureValueMappers.length. Verify against callers.
- double label = featureValueMappers[labelColumnPos].call(record.get(labelColumnPos));
- return new LabeledPoint(label, Vectors.dense(features));
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/DoubleValueMapper.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/DoubleValueMapper.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/DoubleValueMapper.java
deleted file mode 100644
index 781ccd1..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/DoubleValueMapper.java
+++ /dev/null
@@ -1,39 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml.spark;
-
-/**
- * Directly return input when it is known to be double.
- * <p>
- * A null input maps to 0.0. Any non-Double input, including other numeric boxes such as Integer or Float, is
- * rejected with IllegalArgumentException.
- */
-public class DoubleValueMapper extends FeatureValueMapper {
-
- /**
- * Returns the input unchanged if it is a Double, or 0.0 if it is null.
- *
- * @see org.apache.lens.ml.spark.FeatureValueMapper#call(java.lang.Object)
- */
- @Override
- public final Double call(Object input) {
- if (input instanceof Double || input == null) {
- return input == null ? Double.valueOf(0d) : (Double) input;
- }
-
- throw new IllegalArgumentException("Invalid input expecting only doubles, but got " + input);
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/FeatureFunction.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/FeatureFunction.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/FeatureFunction.java
deleted file mode 100644
index affed7b..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/FeatureFunction.java
+++ /dev/null
@@ -1,40 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml.spark;
-
-import org.apache.hadoop.io.WritableComparable;
-import org.apache.hive.hcatalog.data.HCatRecord;
-import org.apache.spark.api.java.function.Function;
-import org.apache.spark.mllib.regression.LabeledPoint;
-
-import scala.Tuple2;
-
-/**
- * Function to map an HCatRecord to a feature vector usable by MLLib.
- * <p>
- * Input is a (key, record) pair as produced by an HCatInputFormat-backed pair RDD. Subclasses decide how record
- * columns become the label and the features.
- */
-public abstract class FeatureFunction implements Function<Tuple2<WritableComparable, HCatRecord>, LabeledPoint> {
-
- /**
- * Converts one (key, record) pair into a LabeledPoint.
- *
- * @see org.apache.spark.api.java.function.Function#call(java.lang.Object)
- */
- @Override
- public abstract LabeledPoint call(Tuple2<WritableComparable, HCatRecord> tuple) throws Exception;
-}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/FeatureValueMapper.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/FeatureValueMapper.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/FeatureValueMapper.java
deleted file mode 100644
index b692379..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/FeatureValueMapper.java
+++ /dev/null
@@ -1,36 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml.spark;
-
-import java.io.Serializable;
-
-import org.apache.spark.api.java.function.Function;
-
-/**
- * Map a feature value to a Double value usable by MLLib.
- * <p>
- * Implementations receive the raw column value read from a record; whether it may be null is up to the
- * implementation (e.g. DoubleValueMapper maps null to 0.0).
- */
-public abstract class FeatureValueMapper implements Function<Object, Double>, Serializable {
-
- /**
- * Converts a raw column value into a Double. Unlike Function#call, this override declares no checked exceptions.
- *
- * @see org.apache.spark.api.java.function.Function#call(java.lang.Object)
- */
- public abstract Double call(Object input);
-}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/HiveTableRDD.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/HiveTableRDD.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/HiveTableRDD.java
deleted file mode 100644
index 44a8e1d..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/HiveTableRDD.java
+++ /dev/null
@@ -1,63 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml.spark;
-
-import java.io.IOException;
-
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
-import org.apache.hadoop.conf.Configuration;
-import org.apache.hadoop.io.WritableComparable;
-import org.apache.hive.hcatalog.data.HCatRecord;
-import org.apache.hive.hcatalog.mapreduce.HCatInputFormat;
-import org.apache.spark.api.java.JavaPairRDD;
-import org.apache.spark.api.java.JavaSparkContext;
-
-/**
- * Create a JavaRDD based on a Hive table using HCatInputFormat.
- */
-public final class HiveTableRDD {
- // Static utility class; not instantiable
- private HiveTableRDD() {
- }
-
- public static final Log LOG = LogFactory.getLog(HiveTableRDD.class);
-
- /**
- * Creates the hive table rdd.
- * <p>
- * The same conf object is first passed to HCatInputFormat.setInput and then to newAPIHadoopRDD. Presumably
- * setInput records the table/partition selection in conf (TODO confirm), so callers should not reuse conf
- * for a different table.
- *
- * @param javaSparkContext the java spark context
- * @param conf the conf
- * @param db the db
- * @param table the table
- * @param partitionFilter the partition filter
- * @return the java pair rdd
- * @throws IOException Signals that an I/O exception has occurred.
- */
- public static JavaPairRDD<WritableComparable, HCatRecord> createHiveTableRDD(JavaSparkContext javaSparkContext,
- Configuration conf, String db, String table, String partitionFilter) throws IOException {
-
- HCatInputFormat.setInput(conf, db, table, partitionFilter);
-
- JavaPairRDD<WritableComparable, HCatRecord> rdd = javaSparkContext.newAPIHadoopRDD(conf,
- HCatInputFormat.class, // Input
- WritableComparable.class, // input key class
- HCatRecord.class); // input value class
- return rdd;
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/SparkMLDriver.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/SparkMLDriver.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/SparkMLDriver.java
deleted file mode 100644
index 1e452f1..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/SparkMLDriver.java
+++ /dev/null
@@ -1,275 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml.spark;
-
-import java.io.File;
-import java.io.FilenameFilter;
-import java.util.ArrayList;
-import java.util.List;
-
-import org.apache.lens.api.LensConf;
-import org.apache.lens.api.LensException;
-import org.apache.lens.ml.Algorithms;
-import org.apache.lens.ml.MLAlgo;
-import org.apache.lens.ml.MLDriver;
-import org.apache.lens.ml.spark.algos.*;
-
-import org.apache.commons.lang.StringUtils;
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
-import org.apache.spark.SparkConf;
-import org.apache.spark.api.java.JavaSparkContext;
-
-/**
- * The Class SparkMLDriver.
- */
-public class SparkMLDriver implements MLDriver {
-
- /** The Constant LOG. */
- public static final Log LOG = LogFactory.getLog(SparkMLDriver.class);
-
- /** The owns spark context. */
- private boolean ownsSparkContext = true;
-
- /**
- * The Enum SparkMasterMode.
- */
- private enum SparkMasterMode {
- // Embedded mode used in tests
- /** The embedded. */
- EMBEDDED,
- // Yarn client and Yarn cluster modes are used when deploying the app to Yarn cluster
- /** The yarn client. */
- YARN_CLIENT,
-
- /** The yarn cluster. */
- YARN_CLUSTER
- }
-
- /** The algorithms. */
- private final Algorithms algorithms = new Algorithms();
-
- /** The client mode. */
- private SparkMasterMode clientMode = SparkMasterMode.EMBEDDED;
-
- /** The is started. */
- private boolean isStarted;
-
- /** The spark conf. */
- private SparkConf sparkConf;
-
- /** The spark context. */
- private JavaSparkContext sparkContext;
-
- /**
- * Use spark context.
- *
- * @param jsc the jsc
- */
- public void useSparkContext(JavaSparkContext jsc) {
- ownsSparkContext = false;
- this.sparkContext = jsc;
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.MLDriver#isAlgoSupported(java.lang.String)
- */
- @Override
- public boolean isAlgoSupported(String name) {
- return algorithms.isAlgoSupported(name);
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.MLDriver#getAlgoInstance(java.lang.String)
- */
- @Override
- public MLAlgo getAlgoInstance(String name) throws LensException {
- checkStarted();
-
- if (!isAlgoSupported(name)) {
- return null;
- }
-
- MLAlgo algo = null;
- try {
- algo = algorithms.getAlgoForName(name);
- if (algo instanceof BaseSparkAlgo) {
- ((BaseSparkAlgo) algo).setSparkContext(sparkContext);
- }
- } catch (LensException exc) {
- LOG.error("Error creating algo object", exc);
- }
- return algo;
- }
-
- /**
- * Register algos.
- */
- private void registerAlgos() {
- algorithms.register(NaiveBayesAlgo.class);
- algorithms.register(SVMAlgo.class);
- algorithms.register(LogisticRegressionAlgo.class);
- algorithms.register(DecisionTreeAlgo.class);
- }
-
-  /**
-   * Initializes the driver from the given Lens configuration.
-   *
-   * <p>Every property prefixed with {@code lens.ml.sparkdriver.} is copied into the Spark configuration with the
-   * prefix removed. The Spark master then selects the deployment mode:
-   * <ul>
-   * <li>{@code yarn-client} or {@code yarn-cluster} selects the matching YARN mode;</li>
-   * <li>{@code local}, {@code local[...]} or no master at all selects embedded mode. With no master, {@code local} is
-   * used.</li>
-   * </ul>
-   * The YARN modes require a Spark home, taken from the {@code SPARK_HOME} environment variable or from the
-   * {@code spark.home} property.
-   *
-   * @param conf the Lens configuration
-   * @throws IllegalArgumentException if the Spark master is not supported, or a YARN mode has no Spark home
-   * @see org.apache.lens.ml.MLDriver#init(org.apache.lens.api.LensConf)
-   */
-  @Override
-  public void init(LensConf conf) throws LensException {
-    sparkConf = new SparkConf();
-    registerAlgos();
-
-    final String prefix = "lens.ml.sparkdriver.";
-    for (String key : conf.getProperties().keySet()) {
-      if (key.startsWith(prefix)) {
-        sparkConf.set(key.substring(prefix.length()), conf.getProperties().get(key));
-      }
-    }
-
-    // SparkConf.get(key) throws NoSuchElementException for a missing key, so pass an explicit null default.
-    String sparkAppMaster = sparkConf.get("spark.master", null);
-    if ("yarn-client".equalsIgnoreCase(sparkAppMaster)) {
-      clientMode = SparkMasterMode.YARN_CLIENT;
-    } else if ("yarn-cluster".equalsIgnoreCase(sparkAppMaster)) {
-      clientMode = SparkMasterMode.YARN_CLUSTER;
-    } else if (StringUtils.isBlank(sparkAppMaster)) {
-      // No master configured: run embedded. A Spark context cannot be created without a master URL.
-      clientMode = SparkMasterMode.EMBEDDED;
-      sparkConf.setMaster("local");
-    } else if ("local".equalsIgnoreCase(sparkAppMaster)
-      || sparkAppMaster.regionMatches(true, 0, "local[", 0, "local[".length())) {
-      // local and local[N] both run Spark inside this process.
-      clientMode = SparkMasterMode.EMBEDDED;
-    } else {
-      throw new IllegalArgumentException("Invalid master mode " + sparkAppMaster);
-    }
-
-    if (clientMode == SparkMasterMode.YARN_CLIENT || clientMode == SparkMasterMode.YARN_CLUSTER) {
-      String sparkHome = System.getenv("SPARK_HOME");
-      if (StringUtils.isNotBlank(sparkHome)) {
-        sparkConf.setSparkHome(sparkHome);
-      }
-
-      // Without SPARK_HOME, spark.home may still come from the Lens configuration or from system properties.
-      // The null default makes a missing key reach the check below instead of throwing NoSuchElementException.
-      String configuredSparkHome = sparkConf.get("spark.home", null);
-      if (StringUtils.isBlank(configuredSparkHome)) {
-        throw new IllegalArgumentException("Spark home is not set");
-      }
-
-      LOG.info("Spark home is set to " + configuredSparkHome);
-    }
-
-    sparkConf.setAppName("lens-ml");
-  }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.MLDriver#start()
- */
- @Override
- public void start() throws LensException {
- if (sparkContext == null) {
- sparkContext = new JavaSparkContext(sparkConf);
- }
-
- // Adding jars to spark context is only required when running in yarn-client mode
- if (clientMode != SparkMasterMode.EMBEDDED) {
- // TODO Figure out only necessary set of JARs to be added for HCatalog
- // Add hcatalog and hive jars
- String hiveLocation = System.getenv("HIVE_HOME");
-
- if (StringUtils.isBlank(hiveLocation)) {
- throw new LensException("HIVE_HOME is not set");
- }
-
- LOG.info("HIVE_HOME at " + hiveLocation);
-
- File hiveLibDir = new File(hiveLocation, "lib");
- FilenameFilter jarFileFilter = new FilenameFilter() {
- @Override
- public boolean accept(File file, String s) {
- return s.endsWith(".jar");
- }
- };
-
- List<String> jarFiles = new ArrayList<String>();
- // Add hive jars
- for (File jarFile : hiveLibDir.listFiles(jarFileFilter)) {
- jarFiles.add(jarFile.getAbsolutePath());
- LOG.info("Adding HIVE jar " + jarFile.getAbsolutePath());
- sparkContext.addJar(jarFile.getAbsolutePath());
- }
-
- // Add hcatalog jars
- File hcatalogDir = new File(hiveLocation + "/hcatalog/share/hcatalog");
- for (File jarFile : hcatalogDir.listFiles(jarFileFilter)) {
- jarFiles.add(jarFile.getAbsolutePath());
- LOG.info("Adding HCATALOG jar " + jarFile.getAbsolutePath());
- sparkContext.addJar(jarFile.getAbsolutePath());
- }
-
- // Add the current jar
- String[] lensSparkLibJars = JavaSparkContext.jarOfClass(SparkMLDriver.class);
- for (String lensSparkJar : lensSparkLibJars) {
- LOG.info("Adding Lens JAR " + lensSparkJar);
- sparkContext.addJar(lensSparkJar);
- }
- }
-
- isStarted = true;
- LOG.info("Created Spark context for app: '" + sparkContext.appName() + "', Spark master: " + sparkContext.master());
- }
-
-  /**
-   * Stops the driver. The Spark context is stopped only if this driver created it; a context supplied through
-   * useSparkContext() is left running.
-   *
-   * @see org.apache.lens.ml.MLDriver#stop()
-   */
-  @Override
-  public void stop() throws LensException {
-    if (!isStarted) {
-      LOG.warn("Spark driver was not started");
-      return;
-    }
-    isStarted = false;
-    if (ownsSparkContext && sparkContext != null) {
-      sparkContext.stop();
-      // Drop the stopped context, so that a later start() creates a fresh one instead of reusing a dead context.
-      sparkContext = null;
-    }
-    LOG.info("Stopped spark context " + this);
-  }
-
-  /**
-   * Lists the algorithm names known to this driver's registry.
-   *
-   * @return algorithm names as reported by the registry
-   */
-  @Override
-  public List<String> getAlgoNames() {
-    List<String> names = algorithms.getAlgorithmNames();
-    return names;
-  }
-
-  /**
-   * Ensures the driver is running, i.e. start() was called and stop() has not been called since.
-   *
-   * @throws LensException if the driver is not running
-   */
-  public void checkStarted() throws LensException {
-    if (isStarted) {
-      return;
-    }
-    throw new LensException("Spark driver is not started yet");
-  }
-
-  /**
-   * Returns the Spark context used by this driver.
-   *
-   * @return the current Spark context; null until start() or useSparkContext() sets one
-   */
-  public JavaSparkContext getSparkContext() {
-    final JavaSparkContext current = sparkContext;
-    return current;
-  }
-
-}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/TableTrainingSpec.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/TableTrainingSpec.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/TableTrainingSpec.java
deleted file mode 100644
index e569b1e..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/TableTrainingSpec.java
+++ /dev/null
@@ -1,433 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml.spark;
-
-import java.io.IOException;
-import java.io.Serializable;
-import java.util.ArrayList;
-import java.util.List;
-
-import org.apache.lens.api.LensException;
-
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
-import org.apache.hadoop.hive.conf.HiveConf;
-import org.apache.hadoop.io.WritableComparable;
-import org.apache.hive.hcatalog.data.HCatRecord;
-import org.apache.hive.hcatalog.data.schema.HCatFieldSchema;
-import org.apache.hive.hcatalog.data.schema.HCatSchema;
-import org.apache.hive.hcatalog.mapreduce.HCatInputFormat;
-import org.apache.spark.api.java.JavaPairRDD;
-import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.api.java.function.Function;
-import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.rdd.RDD;
-
-import com.google.common.base.Preconditions;
-import lombok.Getter;
-import lombok.ToString;
-
-/**
- * Describes a Hive table used as a training data set, and turns it into Spark RDDs of {@link LabeledPoint}s.
- *
- * <p>Instances are built with {@link #newBuilder()}. {@link #createRDDs(JavaSparkContext)} validates the spec against
- * the table schema and then builds the training and testing RDDs. When a training fraction is set, the rows are split
- * randomly between the two RDDs; otherwise both refer to the same RDD.
- */
-@ToString
-public class TableTrainingSpec implements Serializable {
-
-  /** The Constant LOG. */
-  public static final Log LOG = LogFactory.getLog(TableTrainingSpec.class);
-
-  /** Database used when the spec does not name one. */
-  private static final String DEFAULT_DATABASE = "default";
-
-  /** RDD of samples used for training; set by createRDDs. */
-  @Getter
-  private transient RDD<LabeledPoint> trainingRDD;
-
-  /** RDD of samples used for testing; the same as the training RDD unless a training fraction was set. */
-  @Getter
-  private transient RDD<LabeledPoint> testingRDD;
-
-  /** Database of the source table; the default database is used when null. */
-  @Getter
-  private String database;
-
-  /** Name of the source table. */
-  @Getter
-  private String table;
-
-  /** Optional HCatalog partition filter applied when reading the table. */
-  @Getter
-  private String partitionFilter;
-
-  /** Feature column names; when not set, validate() fills in every column except the label column. */
-  @Getter
-  private List<String> featureColumns;
-
-  /** Name of the label column. */
-  @Getter
-  private String labelColumn;
-
-  /** Hive configuration used to read the table. */
-  @Getter
-  private transient HiveConf conf;
-
-  /** Whether rows are split between training and testing; by default all rows are used for training. */
-  private boolean splitTraining;
-
-  /** Fraction of rows assigned to the training RDD when splitTraining is set. */
-  private double trainingFraction = 1.0;
-
-  /** Position of the label column in the table schema; set by validate(). */
-  int labelPos;
-
-  /** Positions of the feature columns in the table schema; set by validate(). */
-  int[] featurePositions;
-
-  /** Number of feature columns; set by validate(). */
-  int numFeatures;
-
-  /** All rows of the table mapped to labeled points; set by createRDDs. */
-  transient JavaRDD<LabeledPoint> labeledRDD;
-
-  /**
-   * Creates a builder for a new spec.
-   *
-   * @return the table training spec builder
-   */
-  public static TableTrainingSpecBuilder newBuilder() {
-    return new TableTrainingSpecBuilder();
-  }
-
-  /**
-   * Builder for {@link TableTrainingSpec}. {@link #build()} returns the spec being configured, so a builder should
-   * not be reused for a second spec.
-   */
-  public static class TableTrainingSpecBuilder {
-
-    /** The spec being configured. */
-    final TableTrainingSpec spec;
-
-    /**
-     * Instantiates a new table training spec builder.
-     */
-    public TableTrainingSpecBuilder() {
-      spec = new TableTrainingSpec();
-    }
-
-    /**
-     * Sets the Hive configuration used to read the table.
-     *
-     * @param conf the conf
-     * @return this builder
-     */
-    public TableTrainingSpecBuilder hiveConf(HiveConf conf) {
-      spec.conf = conf;
-      return this;
-    }
-
-    /**
-     * Sets the database of the table; null selects the default database.
-     *
-     * @param db the db
-     * @return this builder
-     */
-    public TableTrainingSpecBuilder database(String db) {
-      spec.database = db;
-      return this;
-    }
-
-    /**
-     * Sets the table name.
-     *
-     * @param table the table
-     * @return this builder
-     */
-    public TableTrainingSpecBuilder table(String table) {
-      spec.table = table;
-      return this;
-    }
-
-    /**
-     * Sets the partition filter.
-     *
-     * @param partFilter the part filter
-     * @return this builder
-     */
-    public TableTrainingSpecBuilder partitionFilter(String partFilter) {
-      spec.partitionFilter = partFilter;
-      return this;
-    }
-
-    /**
-     * Sets the label column.
-     *
-     * @param labelColumn the label column
-     * @return this builder
-     */
-    public TableTrainingSpecBuilder labelColumn(String labelColumn) {
-      spec.labelColumn = labelColumn;
-      return this;
-    }
-
-    /**
-     * Sets the feature columns; null or empty means every column except the label column.
-     *
-     * @param featureColumns the feature columns
-     * @return this builder
-     */
-    public TableTrainingSpecBuilder featureColumns(List<String> featureColumns) {
-      spec.featureColumns = featureColumns;
-      return this;
-    }
-
-    /**
-     * Returns the configured spec.
-     *
-     * @return the table training spec
-     */
-    public TableTrainingSpec build() {
-      return spec;
-    }
-
-    /**
-     * Splits the table rows so that roughly the given fraction is used for training and the rest for testing.
-     *
-     * @param trainingFraction the training fraction, between 0 and 1 inclusive
-     * @return this builder
-     * @throws IllegalArgumentException if the fraction is outside [0, 1]
-     */
-    public TableTrainingSpecBuilder trainingFraction(double trainingFraction) {
-      Preconditions.checkArgument(trainingFraction >= 0 && trainingFraction <= 1.0,
-        "Training fraction should be between 0 and 1");
-      spec.trainingFraction = trainingFraction;
-      spec.splitTraining = true;
-      return this;
-    }
-  }
-
-  /**
-   * A labeled point paired with a random number in [0, 1).
-   *
-   * <p>Not used by createRDDs any more; kept for compatibility. The random value is drawn again every time an RDD of
-   * these samples is recomputed.
-   */
-  public static class DataSample implements Serializable {
-
-    /** The labeled point. */
-    private final LabeledPoint labeledPoint;
-
-    /** The random sample value. */
-    private final double sample;
-
-    /**
-     * Instantiates a new data sample.
-     *
-     * @param labeledPoint the labeled point
-     */
-    public DataSample(LabeledPoint labeledPoint) {
-      sample = Math.random();
-      this.labeledPoint = labeledPoint;
-    }
-  }
-
-  /**
-   * Keeps samples whose random value is at most the training fraction.
-   */
-  public static class TrainingFilter implements Function<DataSample, Boolean> {
-
-    /** The training fraction. */
-    private double trainingFraction;
-
-    /**
-     * Instantiates a new training filter.
-     *
-     * @param fraction the fraction
-     */
-    public TrainingFilter(double fraction) {
-      trainingFraction = fraction;
-    }
-
-    @Override
-    public Boolean call(DataSample v1) throws Exception {
-      return v1.sample <= trainingFraction;
-    }
-  }
-
-  /**
-   * Keeps samples whose random value is above the training fraction.
-   */
-  public static class TestingFilter implements Function<DataSample, Boolean> {
-
-    /** The training fraction. */
-    private double trainingFraction;
-
-    /**
-     * Instantiates a new testing filter.
-     *
-     * @param fraction the fraction
-     */
-    public TestingFilter(double fraction) {
-      trainingFraction = fraction;
-    }
-
-    @Override
-    public Boolean call(DataSample v1) throws Exception {
-      return v1.sample > trainingFraction;
-    }
-  }
-
-  /**
-   * Extracts the labeled point from a data sample.
-   */
-  public static class GetLabeledPoint implements Function<DataSample, LabeledPoint> {
-
-    @Override
-    public LabeledPoint call(DataSample v1) throws Exception {
-      return v1.labeledPoint;
-    }
-  }
-
-  /**
-   * Returns the configured database, or the default database when none was set.
-   */
-  private String resolvedDatabase() {
-    return database == null ? DEFAULT_DATABASE : database;
-  }
-
-  /**
-   * Checks the spec against the table schema, and resolves the label and feature column positions.
-   *
-   * <p>The table must have at least two columns, one of which is the label column. When no feature columns were
-   * configured, every column except the label becomes a feature. Otherwise all configured feature columns must exist.
-   *
-   * @return true if the spec is valid for the table
-   */
-  boolean validate() {
-    List<HCatFieldSchema> columns;
-    try {
-      HCatInputFormat.setInput(conf, resolvedDatabase(), table, partitionFilter);
-      HCatSchema tableSchema = HCatInputFormat.getTableSchema(conf);
-      columns = tableSchema.getFields();
-    } catch (IOException exc) {
-      LOG.error("Error getting table info " + toString(), exc);
-      return false;
-    }
-
-    // String concatenation keeps this safe when the schema has no field list.
-    LOG.info(table + " columns " + columns);
-    if (columns == null || columns.isEmpty()) {
-      return false;
-    }
-
-    List<String> columnNames = new ArrayList<String>();
-    for (HCatFieldSchema col : columns) {
-      columnNames.add(col.getName());
-    }
-
-    // Need at least one feature column and one label column
-    if (!columnNames.contains(labelColumn) || columnNames.size() < 2) {
-      return false;
-    }
-    labelPos = columnNames.indexOf(labelColumn);
-
-    if (featureColumns == null || featureColumns.isEmpty()) {
-      // Feature columns are not provided, so all columns except the label column are feature columns
-      featurePositions = new int[columnNames.size() - 1];
-      int p = 0;
-      for (int i = 0; i < columnNames.size(); i++) {
-        if (i != labelPos) {
-          featurePositions[p++] = i;
-        }
-      }
-      columnNames.remove(labelPos);
-      featureColumns = columnNames;
-    } else {
-      // Feature columns were provided; all of them must be present in the table
-      if (!columnNames.containsAll(featureColumns)) {
-        return false;
-      }
-      featurePositions = new int[featureColumns.size()];
-      for (int i = 0; i < featureColumns.size(); i++) {
-        featurePositions[i] = columnNames.indexOf(featureColumns.get(i));
-      }
-    }
-    numFeatures = featureColumns.size();
-    return true;
-  }
-
-  /**
-   * Validates the spec, then creates the training and testing RDDs from the table.
-   *
-   * @param sparkContext the spark context
-   * @throws LensException if the spec is not valid for the table, or the table RDD cannot be created
-   */
-  public void createRDDs(JavaSparkContext sparkContext) throws LensException {
-    if (!validate()) {
-      throw new LensException("Table spec not valid: " + toString());
-    }
-
-    LOG.info("Creating RDDs with spec " + toString());
-
-    // Read from the same database that validate() checked.
-    String db = resolvedDatabase();
-    JavaPairRDD<WritableComparable, HCatRecord> tableRDD;
-    try {
-      tableRDD = HiveTableRDD.createHiveTableRDD(sparkContext, conf, db, table, partitionFilter);
-    } catch (IOException e) {
-      throw new LensException(e);
-    }
-
-    // Map into trainable RDD
-    // TODO: Figure out a way to use custom value mappers
-    FeatureValueMapper[] valueMappers = new FeatureValueMapper[numFeatures];
-    final DoubleValueMapper doubleMapper = new DoubleValueMapper();
-    for (int i = 0; i < numFeatures; i++) {
-      valueMappers[i] = doubleMapper;
-    }
-
-    ColumnFeatureFunction trainPrepFunction = new ColumnFeatureFunction(featurePositions, valueMappers, labelPos,
-      numFeatures, 0);
-    labeledRDD = tableRDD.map(trainPrepFunction);
-
-    if (splitTraining) {
-      LOG.info("Splitting RDD for table " + db + "." + table + " with split fraction " + trainingFraction);
-      // randomSplit samples each partition with a seed fixed at call time. Every row therefore lands in exactly one
-      // split, even when the splits are recomputed (given the input rows come back in the same order).
-      // Drawing Math.random() inside a map would produce fresh values on each recomputation, so the training and
-      // testing sets could overlap or lose rows.
-      JavaRDD<LabeledPoint>[] splits =
-        labeledRDD.randomSplit(new double[] {trainingFraction, 1.0 - trainingFraction});
-      trainingRDD = splits[0].rdd();
-      testingRDD = splits[1].rdd();
-    } else {
-      LOG.info("Using same RDD for train and test");
-      trainingRDD = labeledRDD.rdd();
-      testingRDD = trainingRDD;
-    }
-    LOG.info("Generated RDDs");
-  }
-
-}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/BaseSparkAlgo.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/BaseSparkAlgo.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/BaseSparkAlgo.java
deleted file mode 100644
index 22cda6d..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/BaseSparkAlgo.java
+++ /dev/null
@@ -1,290 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml.spark.algos;
-
-import java.lang.reflect.Field;
-import java.util.*;
-
-import org.apache.lens.api.LensConf;
-import org.apache.lens.api.LensException;
-import org.apache.lens.ml.AlgoParam;
-import org.apache.lens.ml.Algorithm;
-import org.apache.lens.ml.MLAlgo;
-import org.apache.lens.ml.MLModel;
-
-import org.apache.lens.ml.spark.TableTrainingSpec;
-import org.apache.lens.ml.spark.models.BaseSparkClassificationModel;
-
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
-import org.apache.hadoop.hive.conf.HiveConf;
-import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.rdd.RDD;
-
-/**
- * Base class for algorithms that train a Spark MLlib model from a Hive table.
- *
- * <p>{@link #train(LensConf, String, String, String, String...)} does the following:
- * <ol>
- * <li>parses the algorithm arguments;</li>
- * <li>builds a {@link TableTrainingSpec} for the table;</li>
- * <li>creates the training RDD;</li>
- * <li>delegates to {@link #trainInternal(String, RDD)}.</li>
- * </ol>
- * Subclasses read their own arguments in {@link #parseAlgoParams(Map)}.
- */
-public abstract class BaseSparkAlgo implements MLAlgo {
-
-  /** The Constant LOG. */
-  public static final Log LOG = LogFactory.getLog(BaseSparkAlgo.class);
-
-  /** The name. */
-  private final String name;
-
-  /** The description. */
-  private final String description;
-
-  /** Spark context used to create the training RDDs. */
-  protected JavaSparkContext sparkContext;
-
-  /** Arguments of the last parseParams call, except the feature and label arguments; dashes removed from names. */
-  protected Map<String, String> params;
-
-  /** The conf. */
-  protected transient LensConf conf;
-
-  /** The training fraction; only used when the trainingFraction argument was given. */
-  @AlgoParam(name = "trainingFraction", help = "% of dataset to be used for training", defaultValue = "0")
-  protected double trainingFraction;
-
-  /** Whether the last parseParams call set a training fraction. */
-  private boolean useTrainingFraction;
-
-  /** The label column. */
-  @AlgoParam(name = "label", help = "Name of column which is used as a training label for supervised learning")
-  protected String label;
-
-  /** The partition filter. */
-  @AlgoParam(name = "partition", help = "Partition filter used to create create HCatInputFormats")
-  protected String partitionFilter;
-
-  /** The feature columns; empty means every column except the label column. */
-  @AlgoParam(name = "feature", help = "Column name(s) which are to be used as sample features")
-  protected List<String> features;
-
-  /**
-   * Instantiates a new base spark algo.
-   *
-   * @param name the name
-   * @param description the description
-   */
-  public BaseSparkAlgo(String name, String description) {
-    this.name = name;
-    this.description = description;
-  }
-
-  public void setSparkContext(JavaSparkContext sparkContext) {
-    this.sparkContext = sparkContext;
-  }
-
-  @Override
-  public LensConf getConf() {
-    return conf;
-  }
-
-  /*
-   * (non-Javadoc)
-   *
-   * @see org.apache.lens.ml.MLAlgo#configure(org.apache.lens.api.LensConf)
-   */
-  @Override
-  public void configure(LensConf configuration) {
-    this.conf = configuration;
-  }
-
-  /**
-   * Trains a model on the given table.
-   *
-   * @param conf configuration copied into the Hive configuration used to read the table
-   * @param db database of the table
-   * @param table training table
-   * @param modelId id of the model to create
-   * @param params algorithm arguments as name/value pairs
-   * @return the trained model
-   * @throws LensException if the table spec is invalid or training fails
-   * @see org.apache.lens.ml.MLAlgo#train(org.apache.lens.api.LensConf, java.lang.String, java.lang.String,
-   * java.lang.String, java.lang.String[])
-   */
-  @Override
-  public MLModel<?> train(LensConf conf, String db, String table, String modelId, String... params)
-    throws LensException {
-    parseParams(params);
-
-    TableTrainingSpec.TableTrainingSpecBuilder builder = TableTrainingSpec.newBuilder().hiveConf(toHiveConf(conf))
-      .database(db).table(table).partitionFilter(partitionFilter).featureColumns(features).labelColumn(label);
-
-    if (useTrainingFraction) {
-      builder.trainingFraction(trainingFraction);
-    }
-
-    TableTrainingSpec spec = builder.build();
-    // No feature arguments means the spec uses every column except the label column.
-    LOG.info("Training on table " + table + " with " + (features == null ? 0 : features.size()) + " features");
-
-    spec.createRDDs(sparkContext);
-
-    RDD<LabeledPoint> trainingRDD = spec.getTrainingRDD();
-    BaseSparkClassificationModel<?> model = trainInternal(modelId, trainingRDD);
-    model.setTable(table);
-    model.setParams(Arrays.asList(params));
-    model.setLabelColumn(label);
-    // Record the columns the spec actually used; they are resolved from the table schema when none were given.
-    model.setFeatureColumns(spec.getFeatureColumns());
-    return model;
-  }
-
-  /**
-   * Copies all properties of a Lens configuration into a new Hive configuration.
-   *
-   * @param conf the conf
-   * @return the hive conf
-   */
-  protected HiveConf toHiveConf(LensConf conf) {
-    HiveConf hiveConf = new HiveConf();
-    for (String key : conf.getProperties().keySet()) {
-      hiveConf.set(key, conf.getProperties().get(key));
-    }
-    return hiveConf;
-  }
-
-  /**
-   * Parses algorithm arguments given as name/value pairs.
-   *
-   * <p>Dashes are removed from argument names. {@code f}/{@code feature} adds a feature column and {@code l}/
-   * {@code label} sets the label column. Every other pair goes to {@link #params}, from which the training fraction
-   * and partition filter ({@code partition} or {@code p}) are read before {@link #parseAlgoParams(Map)} is called.
-   * Each call starts from a clean state, so arguments of an earlier call are not carried over.
-   *
-   * @param args the args
-   * @throws IllegalArgumentException if the number of arguments is odd or the training fraction is not a number
-   */
-  public void parseParams(String[] args) {
-    if (args.length % 2 != 0) {
-      throw new IllegalArgumentException("Invalid number of params " + args.length);
-    }
-
-    params = new LinkedHashMap<String, String>();
-    features = new ArrayList<String>();
-    label = null;
-    partitionFilter = null;
-    useTrainingFraction = false;
-
-    for (int i = 0; i < args.length; i += 2) {
-      // Strip dashes first, so that "-f"/"--feature" and "-l"/"--label" are recognised like "f"/"feature".
-      String argName = args[i].replaceAll("\\-+", "");
-      String argValue = args[i + 1];
-      if ("f".equalsIgnoreCase(argName) || "feature".equalsIgnoreCase(argName)) {
-        features.add(argValue);
-      } else if ("l".equalsIgnoreCase(argName) || "label".equalsIgnoreCase(argName)) {
-        label = argValue;
-      } else {
-        params.put(argName, argValue);
-      }
-    }
-
-    if (params.containsKey("trainingFraction")) {
-      String trainingFractionStr = params.get("trainingFraction");
-      try {
-        trainingFraction = Double.parseDouble(trainingFractionStr);
-        useTrainingFraction = true;
-      } catch (NumberFormatException nfe) {
-        throw new IllegalArgumentException("Invalid training fraction", nfe);
-      }
-    }
-
-    if (params.containsKey("partition")) {
-      partitionFilter = params.get("partition");
-    } else if (params.containsKey("p")) {
-      partitionFilter = params.get("p");
-    }
-
-    parseAlgoParams(params);
-  }
-
-  /**
-   * Returns a parsed argument as a double.
-   *
-   * @param param the param
-   * @param defaultVal value returned when the argument is absent or not a valid double
-   * @return the param value
-   */
-  public double getParamValue(String param, double defaultVal) {
-    String value = params == null ? null : params.get(param);
-    if (value != null) {
-      try {
-        return Double.parseDouble(value);
-      } catch (NumberFormatException nfe) {
-        LOG.warn("Couldn't parse param value: " + param + " as double.");
-      }
-    }
-    return defaultVal;
-  }
-
-  /**
-   * Returns a parsed argument as an int.
-   *
-   * @param param the param
-   * @param defaultVal value returned when the argument is absent or not a valid integer
-   * @return the param value
-   */
-  public int getParamValue(String param, int defaultVal) {
-    String value = params == null ? null : params.get(param);
-    if (value != null) {
-      try {
-        return Integer.parseInt(value);
-      } catch (NumberFormatException nfe) {
-        LOG.warn("Couldn't parse param value: " + param + " as integer.");
-      }
-    }
-    return defaultVal;
-  }
-
-  public String getName() {
-    return name;
-  }
-
-  public String getDescription() {
-    return description;
-  }
-
-  /**
-   * Describes the arguments accepted by this algorithm.
-   *
-   * <p>The result includes the {@link Algorithm} annotation of the concrete class, plus every {@link AlgoParam}
-   * field declared between the concrete class and this base class.
-   *
-   * @return usage text keyed by argument
-   */
-  public Map<String, String> getArgUsage() {
-    Map<String, String> usage = new LinkedHashMap<String, String>();
-    Class<?> clz = this.getClass();
-    // Put class name and description as well as part of the usage
-    Algorithm algorithm = clz.getAnnotation(Algorithm.class);
-    if (algorithm != null) {
-      usage.put("Algorithm Name", algorithm.name());
-      usage.put("Algorithm Description", algorithm.description());
-    }
-
-    // Get all algo params including base algo params
-    while (clz != null) {
-      for (Field field : clz.getDeclaredFields()) {
-        AlgoParam param = field.getAnnotation(AlgoParam.class);
-        if (param != null) {
-          usage.put("[param] " + param.name(), param.help() + " Default Value = " + param.defaultValue());
-        }
-      }
-
-      if (clz.equals(BaseSparkAlgo.class)) {
-        break;
-      }
-      clz = clz.getSuperclass();
-    }
-    return usage;
-  }
-
-  /**
-   * Reads the algorithm specific arguments.
-   *
-   * @param params arguments other than feature and label, with dashes removed from the names
-   */
-  public abstract void parseAlgoParams(Map<String, String> params);
-
-  /**
-   * Trains the model on the prepared training RDD.
-   *
-   * @param modelId the model id
-   * @param trainingRDD the training rdd
-   * @return the base spark classification model
-   * @throws LensException the lens exception
-   */
-  protected abstract BaseSparkClassificationModel trainInternal(String modelId, RDD<LabeledPoint> trainingRDD)
-    throws LensException;
-}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/DecisionTreeAlgo.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/DecisionTreeAlgo.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/DecisionTreeAlgo.java
deleted file mode 100644
index a6d66c5..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/DecisionTreeAlgo.java
+++ /dev/null
@@ -1,109 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml.spark.algos;
-
-import java.util.Map;
-
-import org.apache.lens.api.LensException;
-import org.apache.lens.ml.AlgoParam;
-import org.apache.lens.ml.Algorithm;
-import org.apache.lens.ml.spark.models.BaseSparkClassificationModel;
-import org.apache.lens.ml.spark.models.DecisionTreeClassificationModel;
-import org.apache.lens.ml.spark.models.SparkDecisionTreeModel;
-
-import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.mllib.tree.DecisionTree$;
-import org.apache.spark.mllib.tree.configuration.Algo$;
-import org.apache.spark.mllib.tree.impurity.Entropy$;
-import org.apache.spark.mllib.tree.impurity.Gini$;
-import org.apache.spark.mllib.tree.impurity.Impurity;
-import org.apache.spark.mllib.tree.impurity.Variance$;
-import org.apache.spark.mllib.tree.model.DecisionTreeModel;
-import org.apache.spark.rdd.RDD;
-
-import scala.Enumeration;
-
-/**
- * Decision tree algorithm backed by Spark MLlib's {@code DecisionTree}.
- *
- * <p>Arguments:
- * <ul>
- * <li>{@code algo}: {@code classification} or {@code regression}; classification when omitted.</li>
- * <li>{@code impurity}: {@code gini}, {@code entropy} or {@code variance}; when omitted, gini for classification and
- * variance for regression.</li>
- * <li>{@code maxDepth}: 100 when omitted.</li>
- * </ul>
- * Unrecognised {@code algo} or {@code impurity} values are rejected.
- */
-@Algorithm(name = "spark_decision_tree", description = "Spark Decision Tree classifier algo")
-public class DecisionTreeAlgo extends BaseSparkAlgo {
-
-  /** The decision tree algorithm (classification or regression). */
-  @AlgoParam(name = "algo", help = "Decision tree algorithm. Allowed values are 'classification' and 'regression'",
-    defaultValue = "classification")
-  private Enumeration.Value algo;
-
-  /** The decision tree impurity. */
-  @AlgoParam(name = "impurity", help = "Impurity measure used by the decision tree. "
-    + "Allowed values are 'gini', 'entropy' and 'variance'",
-    defaultValue = "gini for classification, variance for regression")
-  private Impurity decisionTreeImpurity;
-
-  /** The max depth. */
-  @AlgoParam(name = "maxDepth", help = "Max depth of the decision tree. Integer values expected.",
-    defaultValue = "100")
-  private int maxDepth;
-
-  /**
-   * Instantiates a new decision tree algo.
-   *
-   * @param name the name
-   * @param description the description
-   */
-  public DecisionTreeAlgo(String name, String description) {
-    super(name, description);
-  }
-
-  /**
-   * Reads the algo, impurity and maxDepth arguments.
-   *
-   * @param params the parsed arguments
-   * @throws IllegalArgumentException if algo or impurity has an unrecognised value
-   * @see org.apache.lens.ml.spark.algos.BaseSparkAlgo#parseAlgoParams(java.util.Map)
-   */
-  @Override
-  public void parseAlgoParams(Map<String, String> params) {
-    // Every field is assigned on each call, so no setting leaks from an earlier call. A missing setting used to leave
-    // a null that only failed deep inside Spark.
-    String dtreeAlgoName = params.get("algo");
-    if (dtreeAlgoName == null || "classification".equalsIgnoreCase(dtreeAlgoName)) {
-      algo = Algo$.MODULE$.Classification();
-    } else if ("regression".equalsIgnoreCase(dtreeAlgoName)) {
-      algo = Algo$.MODULE$.Regression();
-    } else {
-      throw new IllegalArgumentException("Invalid decision tree algo '" + dtreeAlgoName
-        + "'. Allowed values are 'classification' and 'regression'");
-    }
-
-    String impurity = params.get("impurity");
-    if (impurity == null) {
-      // Gini and entropy are classification impurities; regression trees use variance.
-      if (Algo$.MODULE$.Regression().equals(algo)) {
-        decisionTreeImpurity = Variance$.MODULE$;
-      } else {
-        decisionTreeImpurity = Gini$.MODULE$;
-      }
-    } else if ("gini".equalsIgnoreCase(impurity)) {
-      decisionTreeImpurity = Gini$.MODULE$;
-    } else if ("entropy".equalsIgnoreCase(impurity)) {
-      decisionTreeImpurity = Entropy$.MODULE$;
-    } else if ("variance".equalsIgnoreCase(impurity)) {
-      decisionTreeImpurity = Variance$.MODULE$;
-    } else {
-      throw new IllegalArgumentException("Invalid impurity '" + impurity
-        + "'. Allowed values are 'gini', 'entropy' and 'variance'");
-    }
-
-    // NOTE(review): some Spark MLlib versions reject maxDepth values above 30; confirm that the default of 100 works
-    // with the Spark version in use.
-    maxDepth = getParamValue("maxDepth", 100);
-  }
-
-  /**
-   * Trains a Spark decision tree with the parsed settings.
-   *
-   * @param modelId the model id
-   * @param trainingRDD the training rdd
-   * @return the decision tree model wrapped for Lens
-   * @throws LensException the lens exception
-   * @see org.apache.lens.ml.spark.algos.BaseSparkAlgo#trainInternal(java.lang.String, org.apache.spark.rdd.RDD)
-   */
-  @Override
-  protected BaseSparkClassificationModel trainInternal(String modelId, RDD<LabeledPoint> trainingRDD)
-    throws LensException {
-    DecisionTreeModel model = DecisionTree$.MODULE$.train(trainingRDD, algo, decisionTreeImpurity, maxDepth);
-    return new DecisionTreeClassificationModel(modelId, new SparkDecisionTreeModel(model));
-  }
-}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/KMeansAlgo.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/KMeansAlgo.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/KMeansAlgo.java
deleted file mode 100644
index 7ca5a79..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/KMeansAlgo.java
+++ /dev/null
@@ -1,163 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml.spark.algos;
-
-import java.util.List;
-
-import org.apache.lens.api.LensConf;
-import org.apache.lens.api.LensException;
-import org.apache.lens.ml.*;
-import org.apache.lens.ml.spark.HiveTableRDD;
-import org.apache.lens.ml.spark.models.KMeansClusteringModel;
-
-import org.apache.hadoop.hive.conf.HiveConf;
-import org.apache.hadoop.hive.metastore.api.FieldSchema;
-import org.apache.hadoop.hive.ql.metadata.Hive;
-import org.apache.hadoop.hive.ql.metadata.Table;
-import org.apache.hadoop.io.WritableComparable;
-import org.apache.hive.hcatalog.data.HCatRecord;
-import org.apache.spark.api.java.JavaPairRDD;
-import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.api.java.function.Function;
-import org.apache.spark.mllib.clustering.KMeans;
-import org.apache.spark.mllib.clustering.KMeansModel;
-import org.apache.spark.mllib.linalg.Vector;
-import org.apache.spark.mllib.linalg.Vectors;
-
-import scala.Tuple2;
-
-/**
- * MLAlgo that clusters rows of a Hive table with Spark MLlib KMeans, using the named feature columns as a dense vector.
- */
-@Algorithm(name = "spark_kmeans_algo", description = "Spark MLLib KMeans algo")
-public class KMeansAlgo implements MLAlgo {
-
- /** Configuration stored by configure(LensConf) and returned by getConf(); transient, so skipped by Java serialization. */
- private transient LensConf conf;
-
- /** Spark context handed to HiveTableRDD. NOTE(review): never assigned in this class, so it is null unless set elsewhere. */
- private JavaSparkContext sparkContext;
-
- /** Optional partition filter passed through to HiveTableRDD.createHiveTableRDD (defaults to null). */
- @AlgoParam(name = "partition", help = "Partition filter to be used while constructing table RDD")
- private String partFilter = null;
-
- /** Number of clusters (k) passed to KMeans.train; unlike the other params, no default value is declared. */
- @AlgoParam(name = "k", help = "Number of cluster")
- private int k;
-
- /** Maximum number of KMeans iterations (default 100). */
- @AlgoParam(name = "maxIterations", help = "Maximum number of iterations", defaultValue = "100")
- private int maxIterations = 100;
-
- /** Number of KMeans runs passed to KMeans.train (default 1). */
- @AlgoParam(name = "runs", help = "Number of parallel run", defaultValue = "1")
- private int runs = 1;
-
- /** KMeans initialization mode: "random" or "k-means||" per the help text (default "k-means||"). */
- @AlgoParam(name = "initializationMode",
- help = "initialization model, either \"random\" or \"k-means||\" (default).", defaultValue = "k-means||")
- private String initializationMode = "k-means||";
-
- @Override
- public String getName() { // name comes from the runtime class's @Algorithm annotation
- return getClass().getAnnotation(Algorithm.class).name();
- }
-
- @Override
- public String getDescription() { // description comes from the runtime class's @Algorithm annotation
- return getClass().getAnnotation(Algorithm.class).description();
- }
-
- /*
- * Stores the configuration so that getConf() can return it.
- * NOTE(review): train(...) does not read this field; it uses its own conf argument, which shadows it.
- * @see org.apache.lens.ml.MLAlgo#configure(org.apache.lens.api.LensConf)
- */
- @Override
- public void configure(LensConf configuration) {
- this.conf = configuration;
- }
-
- @Override
- public LensConf getConf() { // whatever configure(...) last stored, or null if it was never called
- return conf;
- }
-
- /*
- * Trains KMeans on db.table. AlgoArgParser.parseArgs(this, params) yields the feature column names; each
- * HCatRecord becomes a dense vector of those columns' values (null -> 0.0). Failures become LensException.
- * @see org.apache.lens.ml.MLAlgo#train(org.apache.lens.api.LensConf, java.lang.String, java.lang.String,
- * java.lang.String, java.lang.String[])
- */
- @Override
- public MLModel train(LensConf conf, String db, String table, String modelId, String... params) throws LensException {
- List<String> features = AlgoArgParser.parseArgs(this, params); // presumably also fills this object's @AlgoParam fields — confirm
- final int[] featurePositions = new int[features.size()]; // NOTE(review): slots for names not found in the table stay 0, i.e. read column 0
- final int NUM_FEATURES = features.size();
-
- JavaPairRDD<WritableComparable, HCatRecord> rdd = null;
- try {
- // Map feature names to their column positions in the table schema
- Table tbl = Hive.get(toHiveConf(conf)).getTable(db, table);
- List<FieldSchema> allCols = tbl.getAllCols();
- int f = 0;
- for (int i = 0; i < tbl.getAllCols().size(); i++) { // NOTE(review): positions follow table-column order, not feature order; getAllCols() is re-fetched each pass
- String colName = allCols.get(i).getName();
- if (features.contains(colName)) {
- featurePositions[f++] = i;
- }
- }
-
- rdd = HiveTableRDD.createHiveTableRDD(sparkContext, toHiveConf(conf), db, table, partFilter); // NOTE(review): sparkContext is never assigned in this class; also a second HiveConf from the same conf
- JavaRDD<Vector> trainableRDD = rdd.map(new Function<Tuple2<WritableComparable, HCatRecord>, Vector>() {
- @Override
- public Vector call(Tuple2<WritableComparable, HCatRecord> v1) throws Exception {
- HCatRecord hCatRecord = v1._2();
- double[] arr = new double[NUM_FEATURES];
- for (int i = 0; i < NUM_FEATURES; i++) {
- Object val = hCatRecord.get(featurePositions[i]);
- arr[i] = val == null ? 0d : (Double) val; // NOTE(review): ClassCastException for non-Double values (e.g. Integer); ((Number) val).doubleValue() would be safer
- }
- return Vectors.dense(arr);
- }
- });
-
- KMeansModel model = KMeans.train(trainableRDD.rdd(), k, maxIterations, runs, initializationMode);
- return new KMeansClusteringModel(modelId, model);
- } catch (Exception e) { // any failure is wrapped in a LensException, with e passed as the second argument
- throw new LensException("KMeans algo failed for " + db + "." + table, e);
- }
- }
-
- /**
- * Builds a new HiveConf and copies every property of the given LensConf into it.
- *
- * @param conf the Lens configuration whose properties are copied
- * @return a new HiveConf() with each LensConf property set on it
- */
- private HiveConf toHiveConf(LensConf conf) {
- HiveConf hiveConf = new HiveConf();
- for (String key : conf.getProperties().keySet()) { // iterating entrySet() would avoid the second get() per key
- hiveConf.set(key, conf.getProperties().get(key));
- }
- return hiveConf;
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/LogisticRegressionAlgo.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/LogisticRegressionAlgo.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/LogisticRegressionAlgo.java
deleted file mode 100644
index 106b3c5..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/LogisticRegressionAlgo.java
+++ /dev/null
@@ -1,86 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml.spark.algos;
-
-import java.util.Map;
-
-import org.apache.lens.api.LensException;
-import org.apache.lens.ml.AlgoParam;
-import org.apache.lens.ml.Algorithm;
-import org.apache.lens.ml.spark.models.BaseSparkClassificationModel;
-import org.apache.lens.ml.spark.models.LogitRegressionClassificationModel;
-
-import org.apache.spark.mllib.classification.LogisticRegressionModel;
-import org.apache.spark.mllib.classification.LogisticRegressionWithSGD;
-import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.rdd.RDD;
-
-/**
- * Spark MLlib logistic regression (LogisticRegressionWithSGD) exposed as a Lens BaseSparkAlgo.
- */
-@Algorithm(name = "spark_logistic_regression", description = "Spark logistic regression algo")
-public class LogisticRegressionAlgo extends BaseSparkAlgo {
-
- /** Number of SGD iterations passed to LogisticRegressionWithSGD.train (default 100). */
- @AlgoParam(name = "iterations", help = "Max number of iterations", defaultValue = "100")
- private int iterations;
-
- /** SGD step size (default 1.0). */
- @AlgoParam(name = "stepSize", help = "Step size", defaultValue = "1.0d")
- private double stepSize;
-
- /** Mini-batch fraction passed as the last argument of LogisticRegressionWithSGD.train (default 1.0). */
- @AlgoParam(name = "minBatchFraction", help = "Fraction for batched learning", defaultValue = "1.0d")
- private double minBatchFraction;
-
- /**
- * Creates the logistic regression algo; both arguments are forwarded unchanged to BaseSparkAlgo.
- *
- * @param name the algorithm name (forwarded to super)
- * @param description a human-readable description of the algorithm (forwarded to super)
- */
- public LogisticRegressionAlgo(String name, String description) {
- super(name, description);
- }
-
- /*
- * Reads iterations (default 100), stepSize (default 1.0) and minBatchFraction (default 1.0) via getParamValue.
- * NOTE(review): the params argument itself is not read here; getParamValue is defined outside this class.
- * @see org.apache.lens.ml.spark.algos.BaseSparkAlgo#parseAlgoParams(java.util.Map)
- */
- @Override
- public void parseAlgoParams(Map<String, String> params) {
- iterations = getParamValue("iterations", 100);
- stepSize = getParamValue("stepSize", 1.0d);
- minBatchFraction = getParamValue("minBatchFraction", 1.0d);
- }
-
- /*
- * Trains with LogisticRegressionWithSGD.train(trainingRDD, iterations, stepSize, minBatchFraction) and wraps
- * the result in a LogitRegressionClassificationModel tagged with modelId.
- * @see org.apache.lens.ml.spark.algos.BaseSparkAlgo#trainInternal(java.lang.String, org.apache.spark.rdd.RDD)
- */
- @Override
- protected BaseSparkClassificationModel trainInternal(String modelId, RDD<LabeledPoint> trainingRDD)
- throws LensException {
- LogisticRegressionModel lrModel = LogisticRegressionWithSGD.train(trainingRDD, iterations, stepSize,
- minBatchFraction);
- return new LogitRegressionClassificationModel(modelId, lrModel);
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/NaiveBayesAlgo.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/NaiveBayesAlgo.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/NaiveBayesAlgo.java
deleted file mode 100644
index f7652d1..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/NaiveBayesAlgo.java
+++ /dev/null
@@ -1,73 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml.spark.algos;
-
-import java.util.Map;
-
-import org.apache.lens.api.LensException;
-import org.apache.lens.ml.AlgoParam;
-import org.apache.lens.ml.Algorithm;
-import org.apache.lens.ml.spark.models.BaseSparkClassificationModel;
-import org.apache.lens.ml.spark.models.NaiveBayesClassificationModel;
-
-import org.apache.spark.mllib.classification.NaiveBayes;
-import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.rdd.RDD;
-
-/**
- * Spark MLlib Naive Bayes classifier (NaiveBayes.train) exposed as a Lens BaseSparkAlgo.
- */
-@Algorithm(name = "spark_naive_bayes", description = "Spark Naive Bayes classifier algo")
-public class NaiveBayesAlgo extends BaseSparkAlgo {
-
- /** Lambda parameter passed to NaiveBayes.train (default 1.0). */
- @AlgoParam(name = "lambda", help = "Lambda parameter for naive bayes learner", defaultValue = "1.0d")
- private double lambda = 1.0;
-
- /**
- * Creates the Naive Bayes algo; both arguments are forwarded unchanged to BaseSparkAlgo.
- *
- * @param name the algorithm name (forwarded to super)
- * @param description a human-readable description of the algorithm (forwarded to super)
- */
- public NaiveBayesAlgo(String name, String description) {
- super(name, description);
- }
-
- /*
- * Reads lambda (default 1.0) via getParamValue, which is defined outside this class.
- * NOTE(review): the params argument itself is not read here.
- * @see org.apache.lens.ml.spark.algos.BaseSparkAlgo#parseAlgoParams(java.util.Map)
- */
- @Override
- public void parseAlgoParams(Map<String, String> params) {
- lambda = getParamValue("lambda", 1.0d);
- }
-
- /*
- * Trains with NaiveBayes.train(trainingRDD, lambda) and wraps the result in a NaiveBayesClassificationModel
- * tagged with modelId.
- * @see org.apache.lens.ml.spark.algos.BaseSparkAlgo#trainInternal(java.lang.String, org.apache.spark.rdd.RDD)
- */
- @Override
- protected BaseSparkClassificationModel trainInternal(String modelId, RDD<LabeledPoint> trainingRDD)
- throws LensException {
- return new NaiveBayesClassificationModel(modelId, NaiveBayes.train(trainingRDD, lambda));
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/SVMAlgo.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/SVMAlgo.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/SVMAlgo.java
deleted file mode 100644
index 09251b7..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/algos/SVMAlgo.java
+++ /dev/null
@@ -1,90 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml.spark.algos;
-
-import java.util.Map;
-
-import org.apache.lens.api.LensException;
-import org.apache.lens.ml.AlgoParam;
-import org.apache.lens.ml.Algorithm;
-import org.apache.lens.ml.spark.models.BaseSparkClassificationModel;
-import org.apache.lens.ml.spark.models.SVMClassificationModel;
-
-import org.apache.spark.mllib.classification.SVMModel;
-import org.apache.spark.mllib.classification.SVMWithSGD;
-import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.rdd.RDD;
-
-/**
- * Spark MLlib linear SVM classifier (SVMWithSGD) exposed as a Lens BaseSparkAlgo.
- */
-@Algorithm(name = "spark_svm", description = "Spark SVML classifier algo") // NOTE(review): "SVML" in the description looks like a typo for "SVM"
-public class SVMAlgo extends BaseSparkAlgo {
-
- /** Mini-batch fraction passed as the last argument of SVMWithSGD.train (default 1.0). */
- @AlgoParam(name = "minBatchFraction", help = "Fraction for batched learning", defaultValue = "1.0d")
- private double minBatchFraction;
-
- /** Regularization parameter passed to SVMWithSGD.train (default 1.0). */
- @AlgoParam(name = "regParam", help = "regularization parameter for gradient descent", defaultValue = "1.0d")
- private double regParam;
-
- /** SGD step size (default 1.0). */
- @AlgoParam(name = "stepSize", help = "Iteration step size", defaultValue = "1.0d")
- private double stepSize;
-
- /** Number of SGD iterations (default 100). */
- @AlgoParam(name = "iterations", help = "Number of iterations", defaultValue = "100")
- private int iterations;
-
- /**
- * Creates the SVM algo; both arguments are forwarded unchanged to BaseSparkAlgo.
- *
- * @param name the algorithm name (forwarded to super)
- * @param description a human-readable description of the algorithm (forwarded to super)
- */
- public SVMAlgo(String name, String description) {
- super(name, description);
- }
-
- /*
- * Reads minBatchFraction, regParam, stepSize (each default 1.0) and iterations (default 100) via getParamValue.
- * NOTE(review): the params argument itself is not read here; getParamValue is defined outside this class.
- * @see org.apache.lens.ml.spark.algos.BaseSparkAlgo#parseAlgoParams(java.util.Map)
- */
- @Override
- public void parseAlgoParams(Map<String, String> params) {
- minBatchFraction = getParamValue("minBatchFraction", 1.0);
- regParam = getParamValue("regParam", 1.0);
- stepSize = getParamValue("stepSize", 1.0);
- iterations = getParamValue("iterations", 100);
- }
-
- /*
- * Trains with SVMWithSGD.train(trainingRDD, iterations, stepSize, regParam, minBatchFraction) and wraps the
- * result in an SVMClassificationModel tagged with modelId.
- * @see org.apache.lens.ml.spark.algos.BaseSparkAlgo#trainInternal(java.lang.String, org.apache.spark.rdd.RDD)
- */
- @Override
- protected BaseSparkClassificationModel trainInternal(String modelId, RDD<LabeledPoint> trainingRDD)
- throws LensException {
- SVMModel svmModel = SVMWithSGD.train(trainingRDD, iterations, stepSize, regParam, minBatchFraction);
- return new SVMClassificationModel(modelId, svmModel);
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/models/BaseSparkClassificationModel.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/models/BaseSparkClassificationModel.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/models/BaseSparkClassificationModel.java
deleted file mode 100644
index deee1b7..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/models/BaseSparkClassificationModel.java
+++ /dev/null
@@ -1,65 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml.spark.models;
-
-import org.apache.lens.ml.ClassifierBaseModel;
-
-import org.apache.spark.mllib.classification.ClassificationModel;
-import org.apache.spark.mllib.linalg.Vectors;
-
-/**
- * Lens classifier model backed by a Spark MLlib ClassificationModel; predict() delegates to the wrapped model.
- *
- * @param <MODEL> the concrete Spark ClassificationModel type being wrapped
- */
-public class BaseSparkClassificationModel<MODEL extends ClassificationModel> extends ClassifierBaseModel {
-
- /** Identifier supplied at construction and returned by getId(). */
- private final String modelId;
-
- /** The wrapped Spark model that predict() delegates to. */
- private final MODEL sparkModel;
-
- /**
- * Wraps a trained Spark classification model under the given id.
- *
- * @param modelId the id returned by getId()
- * @param model the trained Spark model used for prediction
- */
- public BaseSparkClassificationModel(String modelId, MODEL model) {
- this.modelId = modelId;
- this.sparkModel = model;
- }
-
- /*
- * Converts args with getFeatureVector (defined outside this class, presumably in ClassifierBaseModel) into a
- * dense vector and returns the wrapped model's prediction as a Double.
- * @see org.apache.lens.ml.MLModel#predict(java.lang.Object[])
- */
- @Override
- public Double predict(Object... args) {
- return sparkModel.predict(Vectors.dense(getFeatureVector(args)));
- }
-
- @Override
- public String getId() { // the id supplied at construction
- return modelId;
- }
-
-}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/models/DecisionTreeClassificationModel.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/models/DecisionTreeClassificationModel.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/models/DecisionTreeClassificationModel.java
deleted file mode 100644
index 0460024..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/models/DecisionTreeClassificationModel.java
+++ /dev/null
@@ -1,35 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml.spark.models;
-
-/**
- * Classification model wrapping a SparkDecisionTreeModel; all behaviour is inherited from BaseSparkClassificationModel.
- */
-public class DecisionTreeClassificationModel extends BaseSparkClassificationModel<SparkDecisionTreeModel> {
-
- /**
- * Wraps the given decision tree model under modelId.
- *
- * @param modelId the model id, forwarded to the superclass
- * @param model the trained tree (must satisfy BaseSparkClassificationModel's ClassificationModel bound)
- */
- public DecisionTreeClassificationModel(String modelId, SparkDecisionTreeModel model) {
- super(modelId, model);
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-lens/blob/0f5ea4c7/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/models/KMeansClusteringModel.java
----------------------------------------------------------------------
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/models/KMeansClusteringModel.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/models/KMeansClusteringModel.java
deleted file mode 100644
index 959d9f4..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/spark/models/KMeansClusteringModel.java
+++ /dev/null
@@ -1,67 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml.spark.models;
-
-import org.apache.lens.ml.MLModel;
-
-import org.apache.spark.mllib.clustering.KMeansModel;
-import org.apache.spark.mllib.linalg.Vectors;
-
-/**
- * Lens model wrapping a Spark MLlib KMeansModel; predict() returns the cluster chosen by KMeansModel.predict.
- */
-public class KMeansClusteringModel extends MLModel<Integer> {
-
- /** The wrapped Spark KMeans model used by predict(). */
- private final KMeansModel model;
-
- /** Model id. NOTE(review): assigned but never read here, and getId() is not overridden — confirm what getId() returns. */
- private final String modelId;
-
- /**
- * Wraps a trained KMeans model.
- *
- * @param modelId the model id (stored but not read in this class)
- * @param model the trained Spark KMeans model used by predict()
- */
- public KMeansClusteringModel(String modelId, KMeansModel model) {
- this.model = model;
- this.modelId = modelId;
- }
-
- /*
- * Converts args to a double[] (null -> 0.0, other values cast to Double) and returns the cluster that
- * KMeansModel.predict assigns to that dense vector.
- * @see org.apache.lens.ml.MLModel#predict(java.lang.Object[])
- */
- @Override
- public Integer predict(Object... args) {
- // Convert the arguments into a double[] feature vector
- double[] arr = new double[args.length];
- for (int i = 0; i < args.length; i++) {
- if (args[i] != null) {
- arr[i] = (Double) args[i]; // NOTE(review): ClassCastException for non-Double arguments (e.g. Integer)
- } else {
- arr[i] = 0d; // missing values are treated as 0.0
- }
- }
-
- return model.predict(Vectors.dense(arr));
- }
-}