You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by fh...@apache.org on 2014/06/13 10:14:29 UTC

[2/2] git commit: Add LinearRegression example (Java API)

Add LinearRegression example (Java API)


Project: http://git-wip-us.apache.org/repos/asf/incubator-flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-flink/commit/22949fce
Tree: http://git-wip-us.apache.org/repos/asf/incubator-flink/tree/22949fce
Diff: http://git-wip-us.apache.org/repos/asf/incubator-flink/diff/22949fce

Branch: refs/heads/master
Commit: 22949fce59dedc24004e0f66e499eaed11274eec
Parents: 6c6fc89
Author: wilsoncao <27...@qq.com>
Authored: Thu Jun 12 16:54:59 2014 +0800
Committer: Fabian Hueske <fh...@apache.org>
Committed: Fri Jun 13 10:13:45 2014 +0200

----------------------------------------------------------------------
 .../example/java/ml/LinearRegression.java       | 314 +++++++++++++++++++
 .../java/ml/util/LinearRegressionData.java      |  62 ++++
 .../ml/util/LinearRegressionDataGenerator.java  | 109 +++++++
 3 files changed, 485 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/22949fce/stratosphere-examples/stratosphere-java-examples/src/main/java/eu/stratosphere/example/java/ml/LinearRegression.java
----------------------------------------------------------------------
diff --git a/stratosphere-examples/stratosphere-java-examples/src/main/java/eu/stratosphere/example/java/ml/LinearRegression.java b/stratosphere-examples/stratosphere-java-examples/src/main/java/eu/stratosphere/example/java/ml/LinearRegression.java
new file mode 100644
index 0000000..8f7c098
--- /dev/null
+++ b/stratosphere-examples/stratosphere-java-examples/src/main/java/eu/stratosphere/example/java/ml/LinearRegression.java
@@ -0,0 +1,314 @@
+/***********************************************************************************************************************
+ *
+ * Copyright (C) 2010-2013 by the Stratosphere project (http://stratosphere.eu)
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+ * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations under the License.
+ *
+ **********************************************************************************************************************/
+
+package eu.stratosphere.example.java.ml;
+
+import java.io.Serializable;
+import java.util.Collection;
+import eu.stratosphere.api.java.DataSet;
+import eu.stratosphere.api.java.ExecutionEnvironment;
+import eu.stratosphere.api.java.IterativeDataSet;
+import eu.stratosphere.api.java.functions.MapFunction;
+import eu.stratosphere.api.java.functions.ReduceFunction;
+import eu.stratosphere.api.java.tuple.Tuple2;
+import eu.stratosphere.configuration.Configuration;
+import eu.stratosphere.example.java.ml.util.LinearRegressionData;
+
+/**
+ * This example implements a basic Linear Regression using the batch gradient descent (BGD) algorithm.
+ *
+ * <p>
+ * Linear Regression with BGD is an iterative optimization algorithm and works as follows:<br>
+ * Given a data set and a target set, BGD tries to find the parameters with which the data set best fits the target set.
+ * In each iteration, the algorithm computes the gradient of the cost function and uses it to update all the parameters.
+ * The algorithm terminates after a fixed number of iterations (as in this implementation).
+ * With enough iterations, the algorithm can minimize the cost function and find the best parameters.
+ * This is the Wikipedia entry for <a href = "http://en.wikipedia.org/wiki/Linear_regression">Linear regression</a> and the <a href = "http://en.wikipedia.org/wiki/Gradient_descent">Gradient descent algorithm</a>.
+ *
+ * <p>
+ * This implementation works on one-dimensional data and solves the problem <code>y = theta0 + theta1 * x</code>,
+ * i.e., it finds the two-dimensional parameter vector (theta0, theta1) that best fits the target.
+ *
+ * <p>
+ * Input files are plain text files and must be formatted as follows:
+ * <ul>
+ * <li>Data points are represented as two double values separated by a blank character. The first one represents X (the training data) and the second one represents Y (the target).
+ * Data points are separated by newline characters.<br>
+ * For example <code>"-0.02 -0.04\n5.3 10.6\n"</code> gives two data points (x=-0.02, y=-0.04) and (x=5.3, y=10.6).
+ * </ul>
+ *
+ * <p>
+ * When an output path is given, the result is written as text, using the <code>(theta0|theta1)</code>
+ * representation of the final parameters.
+ *
+ * <p>
+ * This example shows how to use:
+ * <ul>
+ * <li> Bulk iterations
+ * <li> Broadcast variables in bulk iterations
+ * <li> Custom Java objects (POJOs)
+ * </ul>
+ */
+@SuppressWarnings("serial")
+public class LinearRegression {
+
+	/** Step size (learning rate) applied by every gradient descent update. */
+	private static final double LEARNING_RATE = 0.01;
+
+	// *************************************************************************
+	//     PROGRAM
+	// *************************************************************************
+
+	public static void main(String[] args) throws Exception{
+
+		if(!parseParameters(args)) {
+			return;
+		}
+
+		// set up execution environment
+		final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+
+		// get the training samples (x, y)
+		DataSet<Data> data = getDataSet(env);
+
+		// get the initial parameters (theta0, theta1)
+		DataSet<Params> parameters = getParamsDataSet(env);
+
+		// set number of bulk iterations for BGD linear regression
+		IterativeDataSet<Params> loop = parameters.iterate(numIterations);
+
+		DataSet<Params> new_parameters = data
+				// compute a single step using every sample
+				.map(new SubUpdate()).withBroadcastSet(loop, "parameters")
+				// sum up all the steps
+				.reduce(new UpdateAccumulator())
+				// average the steps and update all parameters
+				.map(new Update());
+
+		// feed new parameters back into next iteration
+		DataSet<Params> result = loop.closeWith(new_parameters);
+
+		// emit result
+		if(fileOutput) {
+			// writeAsCsv() only supports tuple data sets and Params is a POJO,
+			// so the toString() representation of the parameters is written instead
+			result.writeAsText(outputPath);
+		} else {
+			result.print();
+		}
+
+		// execute program
+		env.execute("Linear Regression example");
+	}
+
+	// *************************************************************************
+	//     DATA TYPES
+	// *************************************************************************
+
+	/**
+	 * A simple data sample: x is the input and y is the target.
+	 */
+	public static class Data implements Serializable{
+		public double x,y;
+
+		public Data() {}
+
+		public Data(double x ,double y){
+			this.x = x;
+			this.y = y;
+		}
+
+		@Override
+		public String toString() {
+			return "(" + x + "|" + y + ")";
+		}
+
+	}
+
+	/**
+	 * A set of model parameters -- theta0 (intercept) and theta1 (slope) of y = theta0 + theta1 * x.
+	 */
+	public static class Params implements Serializable{
+
+		private double theta0,theta1;
+
+		public Params() {}
+
+		public Params(double x0, double x1){
+			this.theta0 = x0;
+			this.theta1 = x1;
+		}
+
+		@Override
+		public String toString() {
+			return "(" + theta0 + "|" + theta1 + ")";
+		}
+
+		public double getTheta0() {
+			return theta0;
+		}
+
+		public double getTheta1() {
+			return theta1;
+		}
+
+		public void setTheta0(double theta0) {
+			this.theta0 = theta0;
+		}
+
+		public void setTheta1(double theta1) {
+			this.theta1 = theta1;
+		}
+
+		/**
+		 * Divides both parameters by {@code a}, modifying this instance in place.
+		 *
+		 * @param a the divisor (must not be null)
+		 * @return this instance, for chaining
+		 */
+		public Params div(Integer a){
+			this.theta0 = theta0 / a;
+			this.theta1 = theta1 / a;
+			return this;
+		}
+
+	}
+
+	// *************************************************************************
+	//     USER FUNCTIONS
+	// *************************************************************************
+
+	/** Converts a Tuple2<Double,Double> into a Data. */
+	public static final class TupleDataConverter extends MapFunction<Tuple2<Double, Double>, Data> {
+
+		@Override
+		public Data map(Tuple2<Double, Double> t) throws Exception {
+			return new Data(t.f0, t.f1);
+		}
+	}
+
+	/** Converts a Tuple2<Double,Double> into a Params. */
+	public static final class TupleParamsConverter extends MapFunction<Tuple2<Double, Double>,Params> {
+
+		@Override
+		public Params map(Tuple2<Double, Double> t)throws Exception {
+			return new Params(t.f0,t.f1);
+		}
+	}
+
+	/**
+	 * Computes, for a single sample, the parameters after one gradient step:
+	 * <code>theta_j - alpha * (theta0 + theta1 * x - y) * x_j</code> with <code>x_0 = 1</code>, <code>x_1 = x</code>.
+	 * Averaging these per-sample results over all samples yields one batch gradient descent step.
+	 * Each result is paired with a count of 1 so that the samples can be counted by the reducer.
+	 */
+	public static class SubUpdate extends MapFunction<Data,Tuple2<Params,Integer>>{
+
+		private Collection<Params> parameters;
+
+		private Params parameter;
+
+		private int count = 1;
+
+		/** Reads the parameters from a broadcast variable into a collection. */
+		@Override
+		public void open(Configuration parameters) throws Exception {
+			this.parameters = getRuntimeContext().getBroadcastVariable("parameters");
+		}
+
+		@Override
+		public Tuple2<Params, Integer> map(Data in) throws Exception {
+
+			// use the (last) element of the broadcast set as the current model
+			Params current = null;
+			for(Params p : parameters){
+				current = p;
+			}
+			if (current == null) {
+				throw new IllegalStateException("The broadcast variable 'parameters' is empty.");
+			}
+			this.parameter = current;
+
+			// prediction error of the current model for this sample
+			double error = (parameter.theta0 + (parameter.theta1 * in.x)) - in.y;
+
+			double theta_0 = parameter.theta0 - LEARNING_RATE * error;
+			double theta_1 = parameter.theta1 - LEARNING_RATE * (error * in.x);
+
+			return new Tuple2<Params,Integer>(new Params(theta_0,theta_1),count);
+		}
+	}
+
+	/**
+	 * Accumulates all per-sample updates: sums the parameters and the sample counts.
+	 */
+	public static class UpdateAccumulator extends ReduceFunction<Tuple2<Params, Integer>> {
+
+		@Override
+		public Tuple2<Params, Integer> reduce(Tuple2<Params, Integer> val1, Tuple2<Params, Integer> val2) {
+
+			double new_theta0 = val1.f0.theta0 + val2.f0.theta0;
+			double new_theta1 = val1.f0.theta1 + val2.f0.theta1;
+			Params result = new Params(new_theta0,new_theta1);
+			return new Tuple2<Params, Integer>(result, val1.f1 + val2.f1);
+
+		}
+	}
+
+	/**
+	 * Computes the final update by averaging the accumulated parameters over the number of samples.
+	 */
+	public static class Update extends MapFunction<Tuple2<Params, Integer>,Params>{
+
+		@Override
+		public Params map(Tuple2<Params, Integer> arg0) throws Exception {
+
+			return arg0.f0.div(arg0.f1);
+
+		}
+
+	}
+	// *************************************************************************
+	//     UTIL METHODS
+	// *************************************************************************
+
+	private static boolean fileOutput = false;
+	private static String dataPath = null;
+	private static String outputPath = null;
+	private static int numIterations = 10;
+
+	/**
+	 * Parses the program arguments: either none (built-in data, printed result) or exactly
+	 * {@code <data path> <result path> <num iterations>}.
+	 *
+	 * @return false if the arguments are invalid and the program should stop
+	 */
+	private static boolean parseParameters(String[] programArguments) {
+
+		if(programArguments.length > 0) {
+			// parse input arguments
+			fileOutput = true;
+			if(programArguments.length == 3) {
+				dataPath = programArguments[0];
+				outputPath = programArguments[1];
+				numIterations = Integer.parseInt(programArguments[2]);
+			} else {
+				System.err.println("Usage: LinearRegression <data path> <result path> <num iterations>");
+				return false;
+			}
+		} else {
+			System.out.println("Executing Linear Regression example with default parameters and built-in default data.");
+			System.out.println("  Provide parameters to read input data from files.");
+			System.out.println("  See the documentation for the correct format of input files.");
+			System.out.println("  We provide a data generator to create synthetic input files for this program.");
+			System.out.println("  Usage: LinearRegression <data path> <result path> <num iterations>");
+		}
+		return true;
+	}
+
+	/** Reads the samples from the space-delimited input file, or returns the built-in samples. */
+	private static DataSet<Data> getDataSet(ExecutionEnvironment env) {
+		if(fileOutput) {
+			// read data from CSV file
+			return env.readCsvFile(dataPath)
+					.fieldDelimiter(' ')
+					.includeFields(true, true)
+					.types(Double.class, Double.class)
+					.map(new TupleDataConverter());
+		} else {
+			return LinearRegressionData.getDefaultDataDataSet(env);
+		}
+	}
+
+	/** Returns the initial parameters; the built-in defaults are always used. */
+	private static DataSet<Params> getParamsDataSet(ExecutionEnvironment env) {
+
+		return LinearRegressionData.getDefaultParamsDataSet(env);
+
+	}
+
+}
+

http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/22949fce/stratosphere-examples/stratosphere-java-examples/src/main/java/eu/stratosphere/example/java/ml/util/LinearRegressionData.java
----------------------------------------------------------------------
diff --git a/stratosphere-examples/stratosphere-java-examples/src/main/java/eu/stratosphere/example/java/ml/util/LinearRegressionData.java b/stratosphere-examples/stratosphere-java-examples/src/main/java/eu/stratosphere/example/java/ml/util/LinearRegressionData.java
new file mode 100644
index 0000000..39d86ec
--- /dev/null
+++ b/stratosphere-examples/stratosphere-java-examples/src/main/java/eu/stratosphere/example/java/ml/util/LinearRegressionData.java
@@ -0,0 +1,62 @@
+/***********************************************************************************************************************
+ * Copyright (C) 2010-2013 by the Stratosphere project (http://stratosphere.eu)
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+ * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations under the License.
+ **********************************************************************************************************************/
+
+package eu.stratosphere.example.java.ml.util;
+
+import eu.stratosphere.api.java.DataSet;
+import eu.stratosphere.api.java.ExecutionEnvironment;
+import eu.stratosphere.example.java.ml.LinearRegression.Data;
+import eu.stratosphere.example.java.ml.LinearRegression.Params;
+
+/**
+ * Supplies the built-in data sets of the Linear Regression example program.
+ * The program falls back to these data sets when it is started without parameters.
+ */
+public class LinearRegressionData{
+
+	/**
+	 * Returns the initial model, with theta0 and theta1 both set to zero.
+	 *
+	 * @param env the execution environment that creates the data set
+	 * @return a data set holding a single parameter set
+	 */
+	public static DataSet<Params> getDefaultParamsDataSet(ExecutionEnvironment env){
+		final Params initialParams = new Params(0.0, 0.0);
+		return env.fromElements(initialParams);
+	}
+
+	/**
+	 * Returns the default training samples; their y values are (approximately) twice their x values.
+	 *
+	 * @param env the execution environment that creates the data set
+	 * @return a data set of (x, y) samples
+	 */
+	public static DataSet<Data> getDefaultDataDataSet(ExecutionEnvironment env){
+		final Data[] samples = new Data[] {
+			new Data(0.5, 1.0),
+			new Data(1.0, 2.0),
+			new Data(2.0, 4.0),
+			new Data(3.0, 6.0),
+			new Data(4.0, 8.0),
+			new Data(5.0, 10.0),
+			new Data(6.0, 12.0),
+			new Data(7.0, 14.0),
+			new Data(8.0, 16.0),
+			new Data(9.0, 18.0),
+			new Data(10.0, 20.0),
+			new Data(-0.08, -0.16),
+			new Data(0.13, 0.26),
+			new Data(-1.17, -2.35),
+			new Data(1.72, 3.45),
+			new Data(1.70, 3.41),
+			new Data(1.20, 2.41),
+			new Data(-0.59, -1.18),
+			new Data(0.28, 0.57),
+			new Data(1.65, 3.30),
+			new Data(-0.55, -1.08)
+		};
+		return env.fromElements(samples);
+	}
+
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/22949fce/stratosphere-examples/stratosphere-java-examples/src/main/java/eu/stratosphere/example/java/ml/util/LinearRegressionDataGenerator.java
----------------------------------------------------------------------
diff --git a/stratosphere-examples/stratosphere-java-examples/src/main/java/eu/stratosphere/example/java/ml/util/LinearRegressionDataGenerator.java b/stratosphere-examples/stratosphere-java-examples/src/main/java/eu/stratosphere/example/java/ml/util/LinearRegressionDataGenerator.java
new file mode 100644
index 0000000..fe34681
--- /dev/null
+++ b/stratosphere-examples/stratosphere-java-examples/src/main/java/eu/stratosphere/example/java/ml/util/LinearRegressionDataGenerator.java
@@ -0,0 +1,109 @@
+/***********************************************************************************************************************
+ * Copyright (C) 2010-2013 by the Stratosphere project (http://stratosphere.eu)
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+ * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations under the License.
+ **********************************************************************************************************************/
+
+package eu.stratosphere.example.java.ml.util;
+
+import java.io.BufferedWriter;
+import java.io.File;
+import java.io.FileWriter;
+import java.io.IOException;
+import java.text.DecimalFormat;
+import java.util.Locale;
+import java.util.Random;
+
+/**
+ * Generates data for the {@link eu.stratosphere.example.java.ml.LinearRegression} example program.
+ */
+public class LinearRegressionDataGenerator {
+
+	static {
+		Locale.setDefault(Locale.US);
+	}
+
+	private static final String POINTS_FILE = "data";
+	private static final long DEFAULT_SEED = 4650285087650871364L;
+	private static final int DIMENSIONALITY = 1;
+	private static final DecimalFormat FORMAT = new DecimalFormat("#0.00");
+	private static final char DELIMITER = ' ';
+
+	/**
+	 * Main method to generate data for the {@link eu.stratosphere.example.java.ml.LinearRegression} example program.
+	 * <p>
+	 * The generator creates one file:
+	 * <ul>
+	 * <li><code>{tmp.dir}/data</code> for the data points
+	 * </ul>
+	 * Each line holds a pair <code>x y</code>, where x is drawn from a standard normal distribution and
+	 * <code>y = 2 * x + noise</code> with Gaussian noise of standard deviation 0.01.
+	 *
+	 * @param args
+	 * <ol>
+	 * <li>Int: Number of data points
+	 * <li><b>Optional</b> Long: Random seed
+	 * </ol>
+	 */
+	public static void main(String[] args) throws IOException {
+
+		// check parameter count
+		if (args.length < 1) {
+			System.out.println("LinearRegressionDataGenerator <numberOfDataPoints> [<seed>]");
+			System.exit(1);
+		}
+
+		// parse parameters
+		final int numDataPoints = Integer.parseInt(args[0]);
+		if (numDataPoints < 0) {
+			System.out.println("The number of data points must not be negative: " + numDataPoints);
+			System.exit(1);
+		}
+		// the optional seed is the second argument
+		final long firstSeed = args.length > 1 ? Long.parseLong(args[1]) : DEFAULT_SEED;
+		final Random random = new Random(firstSeed);
+		final String tmpDir = System.getProperty("java.io.tmpdir");
+
+		// write the points into the temp directory, which is where the final message says they are
+		final File pointsFile = new File(tmpDir, POINTS_FILE);
+		BufferedWriter pointsOut = null;
+		try {
+			pointsOut = new BufferedWriter(new FileWriter(pointsFile));
+			StringBuilder buffer = new StringBuilder();
+
+			// DIMENSIONALITY + 1 slots: the x value(s) followed by the target y
+			double[] point = new double[DIMENSIONALITY+1];
+
+			for (int i = 1; i <= numDataPoints; i++) {
+				point[0] = random.nextGaussian();
+				point[1] = 2 * point[0] + 0.01*random.nextGaussian();
+				writePoint(point, buffer, pointsOut);
+			}
+
+		}
+		finally {
+			if (pointsOut != null) {
+				pointsOut.close();
+			}
+		}
+
+		System.out.println("Wrote "+numDataPoints+" data points to "+tmpDir+"/"+POINTS_FILE);
+	}
+
+	/**
+	 * Writes one data point as a single line, formatting every value with {@code FORMAT}
+	 * and separating the values with {@code DELIMITER}.
+	 *
+	 * @param data the values of the point
+	 * @param buffer a reusable buffer (cleared before use)
+	 * @param out the writer that receives the line
+	 */
+	private static void writePoint(double[] data, StringBuilder buffer, BufferedWriter out) throws IOException {
+		buffer.setLength(0);
+
+		// write coordinates
+		for (int j = 0; j < data.length; j++) {
+			buffer.append(FORMAT.format(data[j]));
+			if(j < data.length - 1) {
+				buffer.append(DELIMITER);
+			}
+		}
+
+		out.write(buffer.toString());
+		out.newLine();
+	}
+}