You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by fh...@apache.org on 2014/06/13 10:14:29 UTC
[2/2] git commit: Add LinearRegression example (Java API)
Add LinearRegression example (Java API)
Project: http://git-wip-us.apache.org/repos/asf/incubator-flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-flink/commit/22949fce
Tree: http://git-wip-us.apache.org/repos/asf/incubator-flink/tree/22949fce
Diff: http://git-wip-us.apache.org/repos/asf/incubator-flink/diff/22949fce
Branch: refs/heads/master
Commit: 22949fce59dedc24004e0f66e499eaed11274eec
Parents: 6c6fc89
Author: wilsoncao <27...@qq.com>
Authored: Thu Jun 12 16:54:59 2014 +0800
Committer: Fabian Hueske <fh...@apache.org>
Committed: Fri Jun 13 10:13:45 2014 +0200
----------------------------------------------------------------------
.../example/java/ml/LinearRegression.java | 314 +++++++++++++++++++
.../java/ml/util/LinearRegressionData.java | 62 ++++
.../ml/util/LinearRegressionDataGenerator.java | 109 +++++++
3 files changed, 485 insertions(+)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/22949fce/stratosphere-examples/stratosphere-java-examples/src/main/java/eu/stratosphere/example/java/ml/LinearRegression.java
----------------------------------------------------------------------
diff --git a/stratosphere-examples/stratosphere-java-examples/src/main/java/eu/stratosphere/example/java/ml/LinearRegression.java b/stratosphere-examples/stratosphere-java-examples/src/main/java/eu/stratosphere/example/java/ml/LinearRegression.java
new file mode 100644
index 0000000..8f7c098
--- /dev/null
+++ b/stratosphere-examples/stratosphere-java-examples/src/main/java/eu/stratosphere/example/java/ml/LinearRegression.java
@@ -0,0 +1,314 @@
+/***********************************************************************************************************************
+ *
+ * Copyright (C) 2010-2013 by the Stratosphere project (http://stratosphere.eu)
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+ * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations under the License.
+ *
+ **********************************************************************************************************************/
+
+package eu.stratosphere.example.java.ml;
+
+import java.io.Serializable;
+import java.util.Collection;
+import eu.stratosphere.api.java.DataSet;
+import eu.stratosphere.api.java.ExecutionEnvironment;
+import eu.stratosphere.api.java.IterativeDataSet;
+import eu.stratosphere.api.java.functions.MapFunction;
+import eu.stratosphere.api.java.functions.ReduceFunction;
+import eu.stratosphere.api.java.tuple.Tuple2;
+import eu.stratosphere.configuration.Configuration;
+import eu.stratosphere.example.java.ml.util.LinearRegressionData;
+
+/**
+ * This example implements a basic Linear Regression using batch gradient descent algorithm.
+ *
+ * <p>
+ * Linear Regression with the BGD (batch gradient descent) algorithm is an iterative optimization algorithm and works as follows:<br>
+ * Given a data set and a target set, BGD tries to find the parameters that make the data set fit the target set best.
+ * In each iteration, the algorithm computes the gradient of the cost function and uses it to update all the parameters.
+ * The algorithm terminates after a fixed number of iterations (as in this implementation).
+ * With enough iterations, the algorithm can minimize the cost function and find the best parameters.
+ * These are the Wikipedia entries for <a href = "http://en.wikipedia.org/wiki/Linear_regression">Linear regression</a> and the <a href = "http://en.wikipedia.org/wiki/Gradient_descent">Gradient descent algorithm</a>.
+ *
+ * <p>
+ * This implementation works on one-dimensional data and finds the two-dimensional theta.<br>
+ * It finds the theta parameters that fit the target best.
+ *
+ * <p>
+ * Input files are plain text files and must be formatted as follows:
+ * <ul>
+ * <li>Data points are represented as two double values separated by a blank character. The first one represents X (the training data) and the second one represents Y (the target).
+ * Data points are separated by newline characters.<br>
+ * For example <code>"-0.02 -0.04\n5.3 10.6\n"</code> gives two data points (x=-0.02, y=-0.04) and (x=5.3, y=10.6).
+ * </ul>
+ *
+ * <p>
+ * This example shows how to use:
+ * <ul>
+ * <li> Bulk iterations
+ * <li> Broadcast variables in bulk iterations
+ * <li> Custom Java objects (PoJos)
+ * </ul>
+ */
+
+/**
+ * A linearRegression example to solve the y = theta0 + theta1*x problem.
+ */
+@SuppressWarnings("serial")
+public class LinearRegression {
+
+ // *************************************************************************
+ // PROGRAM
+ // *************************************************************************
+
+ public static void main(String[] args) throws Exception{
+
+ if(!parseParameters(args)) {
+ return;
+ }
+
+ // set up execution environment
+
+ final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+
+ // get input x data from elements
+ DataSet<Data> data = getDataSet(env);
+
+ // get the parameters from elements
+ DataSet<Params> parameters = getParamsDataSet(env);
+
+ // set number of bulk iterations for SGD linear Regression
+ IterativeDataSet<Params> loop = parameters.iterate(numIterations);
+
+ DataSet<Params> new_parameters = data
+ // compute a single step using every sample
+ .map(new SubUpdate()).withBroadcastSet(loop, "parameters")
+ // sum up all the steps
+ .reduce(new UpdateAccumulator())
+ // average the steps and update all parameters
+ .map(new Update());
+
+ // feed new parameters back into next iteration
+ DataSet<Params> result = loop.closeWith(new_parameters);
+
+ // emit result
+ if(fileOutput) {
+ result.writeAsCsv(outputPath, "\n", " ");
+ } else {
+ result.print();
+ }
+
+ // execute program
+ env.execute("Linear Regression example");
+
+ }
+
+ // *************************************************************************
+ // DATA TYPES
+ // *************************************************************************
+
+ /**
+  * A single training sample: {@code x} is the input value and {@code y} is the target value.
+  */
+ public static class Data implements Serializable {
+
+     public double x;
+     public double y;
+
+     /** Creates a sample whose fields keep their default value {@code 0.0}. */
+     public Data() {
+     }
+
+     /**
+      * Creates a sample.
+      *
+      * @param x the input value
+      * @param y the target value
+      */
+     public Data(double x, double y) {
+         this.x = x;
+         this.y = y;
+     }
+
+     /** Renders the sample as {@code (x|y)}. */
+     @Override
+     public String toString() {
+         final StringBuilder text = new StringBuilder("(");
+         text.append(x).append('|').append(y).append(')');
+         return text.toString();
+     }
+ }
+
+ /**
+  * The two model parameters {@code theta0} and {@code theta1} of the line
+  * {@code y = theta0 + theta1 * x}.
+  */
+ public static class Params implements Serializable {
+
+     private double theta0;
+     private double theta1;
+
+     /** Creates parameters whose thetas keep their default value {@code 0.0}. */
+     public Params() {
+     }
+
+     /**
+      * Creates parameters with the given values.
+      *
+      * @param x0 the value of theta0
+      * @param x1 the value of theta1
+      */
+     public Params(double x0, double x1) {
+         theta0 = x0;
+         theta1 = x1;
+     }
+
+     public double getTheta0() {
+         return theta0;
+     }
+
+     public void setTheta0(double theta0) {
+         this.theta0 = theta0;
+     }
+
+     public double getTheta1() {
+         return theta1;
+     }
+
+     public void setTheta1(double theta1) {
+         this.theta1 = theta1;
+     }
+
+     /**
+      * Divides both thetas by {@code a}. The division happens in place.
+      *
+      * @param a the divisor
+      * @return this (modified) instance
+      */
+     public Params div(Integer a) {
+         final int divisor = a;
+         theta0 /= divisor;
+         theta1 /= divisor;
+         return this;
+     }
+
+     /** Renders the parameters as {@code (theta0|theta1)}. */
+     @Override
+     public String toString() {
+         final StringBuilder text = new StringBuilder("(");
+         text.append(theta0).append('|').append(theta1).append(')');
+         return text.toString();
+     }
+ }
+
+ // *************************************************************************
+ // USER FUNCTIONS
+ // *************************************************************************
+
+ /**
+  * Turns a {@code Tuple2<Double, Double>} into a {@link Data} sample: field 0 becomes
+  * {@code x} and field 1 becomes {@code y}.
+  */
+ public static final class TupleDataConverter extends MapFunction<Tuple2<Double, Double>, Data> {
+
+     @Override
+     public Data map(Tuple2<Double, Double> t) throws Exception {
+         final double x = t.f0;
+         final double y = t.f1;
+         return new Data(x, y);
+     }
+ }
+
+ /**
+  * Turns a {@code Tuple2<Double, Double>} into {@link Params}: field 0 becomes theta0 and
+  * field 1 becomes theta1.
+  */
+ public static final class TupleParamsConverter extends MapFunction<Tuple2<Double, Double>, Params> {
+
+     @Override
+     public Params map(Tuple2<Double, Double> t) throws Exception {
+         final double first = t.f0;
+         final double second = t.f1;
+         return new Params(first, second);
+     }
+ }
+
+ /**
+ * Compute a single BGD type update for every parameters.
+ */
+ public static class SubUpdate extends MapFunction<Data,Tuple2<Params,Integer>>{
+
+ private Collection<Params> parameters;
+
+ private Params parameter;
+
+ private int count = 1;
+
+ /** Reads the parameters from a broadcast variable into a collection. */
+ @Override
+ public void open(Configuration parameters) throws Exception {
+ this.parameters = getRuntimeContext().getBroadcastVariable("parameters");
+ }
+
+ @Override
+ public Tuple2<Params, Integer> map(Data in) throws Exception {
+
+ for(Params p : parameters){
+ this.parameter = p;
+ }
+
+ double theta_0 = parameter.theta0 - 0.01*((parameter.theta0 + (parameter.theta1*in.x)) - in.y);
+ double theta_1 = parameter.theta1 - 0.01*(((parameter.theta0 + (parameter.theta1*in.x)) - in.y) * in.x);
+
+ return new Tuple2<Params,Integer>(new Params(theta_0,theta_1),count);
+ }
+ }
+
+ /**
+  * Accumulates the updates: adds up the thetas of both inputs field by field and adds up
+  * their counts.
+  */
+ public static class UpdateAccumulator extends ReduceFunction<Tuple2<Params, Integer>> {
+
+     @Override
+     public Tuple2<Params, Integer> reduce(Tuple2<Params, Integer> val1, Tuple2<Params, Integer> val2) {
+         final Params left = val1.f0;
+         final Params right = val2.f0;
+
+         final Params sum = new Params(left.theta0 + right.theta0, left.theta1 + right.theta1);
+         final int samples = val1.f1 + val2.f1;
+
+         return new Tuple2<Params, Integer>(sum, samples);
+     }
+ }
+
+ /**
+  * Computes the final update by averaging: divides the summed parameters (field 0) by the
+  * count (field 1).
+  */
+ public static class Update extends MapFunction<Tuple2<Params, Integer>, Params> {
+
+     @Override
+     public Params map(Tuple2<Params, Integer> arg0) throws Exception {
+         final Params summed = arg0.f0;
+         final Integer samples = arg0.f1;
+         // div() works in place and returns the very same Params instance
+         return summed.div(samples);
+     }
+ }
+ // *************************************************************************
+ // UTIL METHODS
+ // *************************************************************************
+
+ // Settings of the program, filled in by parseParameters().
+
+ // true if arguments were given: the data is then read from dataPath and the result is written to outputPath
+ private static boolean fileOutput = false;
+ // path of the input data file (first argument)
+ private static String dataPath = null;
+ // path the result is written to (second argument)
+ private static String outputPath = null;
+ // number of bulk iterations (third argument), 10 unless arguments are given
+ private static int numIterations = 10;
+
+ private static boolean parseParameters(String[] programArguments) {
+
+ if(programArguments.length > 0) {
+ // parse input arguments
+ fileOutput = true;
+ if(programArguments.length == 3) {
+ dataPath = programArguments[0];
+ outputPath = programArguments[1];
+ numIterations = Integer.parseInt(programArguments[2]);
+ } else {
+ System.err.println("Usage: LinearRegression <data path> <result path> <num iterations>");
+ return false;
+ }
+ } else {
+ System.out.println("Executing Linear Regression example with default parameters and built-in default data.");
+ System.out.println(" Provide parameters to read input data from files.");
+ System.out.println(" See the documentation for the correct format of input files.");
+ System.out.println(" We provide a data generator to create synthetic input files for this program.");
+ System.out.println(" Usage: LinearRegression <data path> <result path> <num iterations>");
+ }
+ return true;
+ }
+
+ private static DataSet<Data> getDataSet(ExecutionEnvironment env) {
+ if(fileOutput) {
+ // read data from CSV file
+ return env.readCsvFile(dataPath)
+ .fieldDelimiter(' ')
+ .includeFields(true, true)
+ .types(Double.class, Double.class)
+ .map(new TupleDataConverter());
+ } else {
+ return LinearRegressionData.getDefaultDataDataSet(env);
+ }
+ }
+
+ private static DataSet<Params> getParamsDataSet(ExecutionEnvironment env) {
+
+ return LinearRegressionData.getDefaultParamsDataSet(env);
+
+ }
+
+}
+
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/22949fce/stratosphere-examples/stratosphere-java-examples/src/main/java/eu/stratosphere/example/java/ml/util/LinearRegressionData.java
----------------------------------------------------------------------
diff --git a/stratosphere-examples/stratosphere-java-examples/src/main/java/eu/stratosphere/example/java/ml/util/LinearRegressionData.java b/stratosphere-examples/stratosphere-java-examples/src/main/java/eu/stratosphere/example/java/ml/util/LinearRegressionData.java
new file mode 100644
index 0000000..39d86ec
--- /dev/null
+++ b/stratosphere-examples/stratosphere-java-examples/src/main/java/eu/stratosphere/example/java/ml/util/LinearRegressionData.java
@@ -0,0 +1,62 @@
+/***********************************************************************************************************************
+ * Copyright (C) 2010-2013 by the Stratosphere project (http://stratosphere.eu)
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+ * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations under the License.
+ **********************************************************************************************************************/
+
+package eu.stratosphere.example.java.ml.util;
+
+import eu.stratosphere.api.java.DataSet;
+import eu.stratosphere.api.java.ExecutionEnvironment;
+import eu.stratosphere.example.java.ml.LinearRegression.Data;
+import eu.stratosphere.example.java.ml.LinearRegression.Params;
+
+/**
+ * Supplies the built-in data sets of the Linear Regression example program.
+ * They are used when the program is started without parameters.
+ */
+public class LinearRegressionData {
+
+    /**
+     * Creates the default initial parameters, a single element with both thetas set to zero.
+     *
+     * @param env the execution environment the data set is created in
+     * @return the data set holding the initial parameters
+     */
+    public static DataSet<Params> getDefaultParamsDataSet(ExecutionEnvironment env) {
+        final Params zero = new Params(0.0, 0.0);
+        return env.fromElements(zero);
+    }
+
+    /**
+     * Creates the default training samples.
+     *
+     * @param env the execution environment the data set is created in
+     * @return the data set holding the 21 built-in samples
+     */
+    public static DataSet<Data> getDefaultDataDataSet(ExecutionEnvironment env) {
+        final Data[] samples = {
+            new Data(0.5, 1.0),
+            new Data(1.0, 2.0),
+            new Data(2.0, 4.0),
+            new Data(3.0, 6.0),
+            new Data(4.0, 8.0),
+            new Data(5.0, 10.0),
+            new Data(6.0, 12.0),
+            new Data(7.0, 14.0),
+            new Data(8.0, 16.0),
+            new Data(9.0, 18.0),
+            new Data(10.0, 20.0),
+            new Data(-0.08, -0.16),
+            new Data(0.13, 0.26),
+            new Data(-1.17, -2.35),
+            new Data(1.72, 3.45),
+            new Data(1.70, 3.41),
+            new Data(1.20, 2.41),
+            new Data(-0.59, -1.18),
+            new Data(0.28, 0.57),
+            new Data(1.65, 3.30),
+            new Data(-0.55, -1.08)
+        };
+        return env.fromElements(samples);
+    }
+
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/22949fce/stratosphere-examples/stratosphere-java-examples/src/main/java/eu/stratosphere/example/java/ml/util/LinearRegressionDataGenerator.java
----------------------------------------------------------------------
diff --git a/stratosphere-examples/stratosphere-java-examples/src/main/java/eu/stratosphere/example/java/ml/util/LinearRegressionDataGenerator.java b/stratosphere-examples/stratosphere-java-examples/src/main/java/eu/stratosphere/example/java/ml/util/LinearRegressionDataGenerator.java
new file mode 100644
index 0000000..fe34681
--- /dev/null
+++ b/stratosphere-examples/stratosphere-java-examples/src/main/java/eu/stratosphere/example/java/ml/util/LinearRegressionDataGenerator.java
@@ -0,0 +1,109 @@
+/***********************************************************************************************************************
+ * Copyright (C) 2010-2013 by the Stratosphere project (http://stratosphere.eu)
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+ * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations under the License.
+ **********************************************************************************************************************/
+
+package eu.stratosphere.example.java.ml.util;
+
+import java.io.BufferedWriter;
+import java.io.File;
+import java.io.FileWriter;
+import java.io.IOException;
+import java.text.DecimalFormat;
+import java.util.Locale;
+import java.util.Random;
+
+/**
+ * Generates data for the {@code eu.stratosphere.example.java.ml.LinearRegression} example program.
+ */
+public class LinearRegressionDataGenerator {
+
+    static {
+        // runs before FORMAT is created, so that the default locale is US when the numbers are formatted
+        Locale.setDefault(Locale.US);
+    }
+
+    private static final String POINTS_FILE = "data";
+    private static final long DEFAULT_SEED = 4650285087650871364L;
+    private static final int DIMENSIONALITY = 1;
+    private static final DecimalFormat FORMAT = new DecimalFormat("#0.00");
+    private static final char DELIMITER = ' ';
+
+    /**
+     * Main method to generate data for the LinearRegression example program.
+     * <p>
+     * The generator creates one file:
+     * <ul>
+     * <li><code>{java.io.tmpdir}/data</code> for the data points
+     * </ul>
+     * Every line of the file holds one data point: a Gaussian random x value and the target
+     * <code>y = 2 * x + 0.01 * noise</code>, separated by a blank.
+     *
+     * @param args
+     * <ol>
+     * <li>Int: Number of data points
+     * <li><b>Optional</b> Long: Random seed
+     * </ol>
+     * @throws IOException if the file cannot be written
+     */
+    public static void main(String[] args) throws IOException {
+
+        // check parameter count
+        if (args.length < 1) {
+            System.out.println("LinearRegressionDataGenerator <numberOfDataPoints> [<seed>]");
+            System.exit(1);
+        }
+
+        // parse parameters, the optional seed is the second argument
+        final int numDataPoints = Integer.parseInt(args[0]);
+        final long firstSeed = args.length > 1 ? Long.parseLong(args[1]) : DEFAULT_SEED;
+        final Random random = new Random(firstSeed);
+
+        // the file is created in the temp directory, as documented and as reported below
+        final String tmpDir = System.getProperty("java.io.tmpdir");
+        final File pointsFile = new File(tmpDir, POINTS_FILE);
+
+        // write the points out
+        BufferedWriter pointsOut = null;
+        try {
+            pointsOut = new BufferedWriter(new FileWriter(pointsFile));
+            StringBuilder buffer = new StringBuilder();
+
+            // DIMENSIONALITY + 1: the x value(s) plus the target y
+            double[] point = new double[DIMENSIONALITY + 1];
+
+            for (int i = 1; i <= numDataPoints; i++) {
+                point[0] = random.nextGaussian();
+                point[1] = 2 * point[0] + 0.01 * random.nextGaussian();
+                writePoint(point, buffer, pointsOut);
+            }
+
+        }
+        finally {
+            if (pointsOut != null) {
+                pointsOut.close();
+            }
+        }
+
+        System.out.println("Wrote " + numDataPoints + " data points to " + pointsFile);
+    }
+
+    /**
+     * Writes one data point as a single line: all values formatted with {@code FORMAT} and
+     * separated by {@code DELIMITER}.
+     *
+     * @param data the values of the data point
+     * @param buffer a reusable buffer, its content is discarded
+     * @param out the writer the line is written to
+     * @throws IOException if writing fails
+     */
+    private static void writePoint(double[] data, StringBuilder buffer, BufferedWriter out) throws IOException {
+        buffer.setLength(0);
+
+        // write coordinates
+        for (int j = 0; j < data.length; j++) {
+            buffer.append(FORMAT.format(data[j]));
+            if (j < data.length - 1) {
+                buffer.append(DELIMITER);
+            }
+        }
+
+        out.write(buffer.toString());
+        out.newLine();
+    }
+}