You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by td...@apache.org on 2010/11/05 04:20:14 UTC
svn commit: r1031413 - in /mahout/trunk:
core/src/main/java/org/apache/mahout/vectorizer/encoders/
examples/src/main/java/org/apache/mahout/classifier/sgd/
Author: tdunning
Date: Fri Nov 5 03:20:13 2010
New Revision: 1031413
URL: http://svn.apache.org/viewvc?rev=1031413&view=rev
Log:
MAHOUT-539 - Added speedup examples and improved some value encoders
to make the speedups work even better.
Added:
mahout/trunk/core/src/main/java/org/apache/mahout/vectorizer/encoders/CachingValueEncoder.java
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/SimpleCsv.java
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/vectorizer/encoders/ConstantValueEncoder.java
mahout/trunk/core/src/main/java/org/apache/mahout/vectorizer/encoders/ContinuousValueEncoder.java
Added: mahout/trunk/core/src/main/java/org/apache/mahout/vectorizer/encoders/CachingValueEncoder.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/vectorizer/encoders/CachingValueEncoder.java?rev=1031413&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/vectorizer/encoders/CachingValueEncoder.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/vectorizer/encoders/CachingValueEncoder.java Fri Nov 5 03:20:13 2010
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.vectorizer.encoders;
+
+/**
+ * Provides basic hashing semantics for encoders where the probe locations
+ * depend only on the name of the variable.
+ *
+ * <p>Because the hash does not depend on the value being encoded, the raw hash of every
+ * probe is computed once (at construction and again whenever the number of probes changes).
+ * Only the reduction modulo the vector size is done on each call.
+ */
+public abstract class CachingValueEncoder extends FeatureVectorEncoder {
+  /** Seed given at construction; reused every time the probe cache is rebuilt. */
+  private final int seed;
+
+  /** Raw hash for each probe, not yet reduced modulo the size of the feature vector. */
+  private int[] cachedProbes;
+
+  /**
+   * @param name Name of the variable; this is the only data that is hashed.
+   * @param seed Base hash seed; probe {@code i} is hashed with seed {@code seed + i}.
+   */
+  public CachingValueEncoder(String name, int seed) {
+    super(name);
+    this.seed = seed;
+    cacheProbeLocations();
+  }
+
+  /**
+   * Sets the number of locations in the feature vector that a value should be in.
+   * This causes the cached probe locations to be recomputed.
+   *
+   * @param probes Number of locations to increment.
+   */
+  @Override
+  public void setProbes(int probes) {
+    super.setProbes(probes);
+    // Rebuild with the seed this encoder was created with. Rebuilding with a fixed constant
+    // would silently move the probe locations of encoders constructed with another seed.
+    cacheProbeLocations();
+  }
+
+  /** Recomputes the raw hash of every probe from the variable name and the stored seed. */
+  private void cacheProbeLocations() {
+    int[] probes = new int[getProbes()];
+    for (int i = 0; i < probes.length; i++) {
+      // note that the modulo operation is deferred
+      probes[i] = (int) MurmurHash.hash64A(bytesForString(getName()), seed + i);
+    }
+    cachedProbes = probes;
+  }
+
+  /**
+   * Returns the vector index for one probe by reducing the cached hash modulo the vector size.
+   *
+   * @param originalForm Ignored; the location does not depend on the encoded value.
+   * @param dataSize     Size of the feature vector.
+   * @param name         Ignored; the name given at construction was already hashed.
+   * @param probe        Which probe to look up, starting at 0.
+   * @return An index in the range {@code [0, dataSize)}.
+   */
+  @Override
+  protected int hashForProbe(byte[] originalForm, int dataSize, String name, int probe) {
+    int h = cachedProbes[probe] % dataSize;
+    // Java's % keeps the sign of the dividend, so shift negative remainders into range
+    if (h < 0) {
+      h += dataSize;
+    }
+    return h;
+  }
+}
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/vectorizer/encoders/ConstantValueEncoder.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/vectorizer/encoders/ConstantValueEncoder.java?rev=1031413&r1=1031412&r2=1031413&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/vectorizer/encoders/ConstantValueEncoder.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/vectorizer/encoders/ConstantValueEncoder.java Fri Nov 5 03:20:13 2010
@@ -22,9 +22,9 @@ import org.apache.mahout.math.Vector;
/**
* An encoder that does the standard thing for a virtual bias term.
*/
-public class ConstantValueEncoder extends FeatureVectorEncoder {
+public class ConstantValueEncoder extends CachingValueEncoder {
public ConstantValueEncoder(String name) {
- super(name);
+ super(name, 0);
}
@Override
@@ -49,10 +49,4 @@ public class ConstantValueEncoder extend
public String asString(String originalForm) {
return getName();
}
-
- @Override
- protected int hashForProbe(byte[] originalForm, int dataSize, String name, int probe){
- return hash(name, probe, dataSize);
- }
-
}
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/vectorizer/encoders/ContinuousValueEncoder.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/vectorizer/encoders/ContinuousValueEncoder.java?rev=1031413&r1=1031412&r2=1031413&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/vectorizer/encoders/ContinuousValueEncoder.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/vectorizer/encoders/ContinuousValueEncoder.java Fri Nov 5 03:20:13 2010
@@ -22,10 +22,9 @@ import org.apache.mahout.math.Vector;
/**
* Continuous values are stored in fixed randomized location in the feature vector.
*/
-public class ContinuousValueEncoder extends FeatureVectorEncoder {
-
+public class ContinuousValueEncoder extends CachingValueEncoder {
public ContinuousValueEncoder(String name) {
- super(name);
+ super(name, CONTINUOUS_VALUE_HASH_SEED);
}
/**
@@ -48,11 +47,6 @@ public class ContinuousValueEncoder exte
}
@Override
- protected int hashForProbe(byte[] originalForm, int dataSize, String name, int probe) {
- return hash(name, CONTINUOUS_VALUE_HASH_SEED + probe, dataSize);
- }
-
- @Override
protected double getWeight(byte[] originalForm, double w) {
return w * Double.parseDouble(new String(originalForm));
}
Added: mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/SimpleCsv.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/SimpleCsv.java?rev=1031413&view=auto
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/SimpleCsv.java (added)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/SimpleCsv.java Fri Nov 5 03:20:13 2010
@@ -0,0 +1,245 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import com.google.common.base.Joiner;
+import com.google.common.base.Splitter;
+import com.google.common.collect.Lists;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.list.IntArrayList;
+import org.apache.mahout.math.stats.OnlineSummarizer;
+import org.apache.mahout.vectorizer.encoders.ConstantValueEncoder;
+import org.apache.mahout.vectorizer.encoders.FeatureVectorEncoder;
+
+import java.io.BufferedReader;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileReader;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.PrintWriter;
+import java.nio.ByteBuffer;
+import java.util.List;
+import java.util.Random;
+
+/**
+ * Example program that generates tab-separated test data and measures how quickly it can be
+ * read and summarized, using either a simple String-based parser or a byte-buffer parser.
+ */
+public class SimpleCsv {
+ public static final int SEPARATOR_CHAR = '\t';
+ public static final String SEPARATOR = "\t";
+ private static final int FIELDS = 100;
+
+  /**
+   * Runs one of three modes selected by {@code args[0]}:
+   * <ul>
+   *   <li>{@code --generate n file} writes {@code n} random lines to {@code file}</li>
+   *   <li>{@code --parse file} reads {@code file} line by line as strings, encodes every
+   *   field into a vector and prints the mean of each field</li>
+   *   <li>{@code --fast file} reads {@code file} through a byte buffer and prints the mean
+   *   of each field</li>
+   * </ul>
+   * The elapsed wall-clock time is printed at the end.
+   *
+   * @param args Mode flag followed by the arguments of that mode, as listed above.
+   * @throws IOException If the input or output file cannot be opened, read or closed.
+   */
+  public static void main(String[] args) throws IOException {
+    if (args.length < 2 || ("--generate".equals(args[0]) && args.length < 3)) {
+      System.err.println(
+          "Usage: SimpleCsv --generate <lines> <file> | --parse <file> | --fast <file>");
+      return;
+    }
+
+    FeatureVectorEncoder[] encoder = new FeatureVectorEncoder[FIELDS];
+    for (int i = 0; i < FIELDS; i++) {
+      // every field gets its own name; the encoder's probe locations depend only on the name
+      encoder[i] = new ConstantValueEncoder("v" + i);
+    }
+
+    OnlineSummarizer[] s = new OnlineSummarizer[FIELDS];
+    for (int i = 0; i < FIELDS; i++) {
+      s[i] = new OnlineSummarizer();
+    }
+    long t0 = System.currentTimeMillis();
+    Vector v = new DenseVector(1000);
+    if ("--generate".equals(args[0])) {
+      // parse the count first so that a bad number does not leave an open file behind
+      int n = Integer.parseInt(args[1]);
+      PrintWriter out = new PrintWriter(new File(args[2]));
+      try {
+        for (int i = 0; i < n; i++) {
+          out.println(Line.generate());
+        }
+      } finally {
+        out.close();
+      }
+    } else if ("--parse".equals(args[0])) {
+      BufferedReader in = new BufferedReader(new FileReader(args[1]));
+      try {
+        String line = in.readLine();
+        while (line != null) {
+          v.assign(0);
+          Line x = new Line(line);
+          for (int i = 0; i < FIELDS; i++) {
+            s[i].add(x.getDouble(i));
+            encoder[i].addToVector(x.get(i), v);
+          }
+          line = in.readLine();
+        }
+      } finally {
+        in.close();
+      }
+      printMeans(s);
+    } else if ("--fast".equals(args[0])) {
+      // keep a handle on the raw stream; FastLineReader offers no way to close it
+      InputStream raw = new FileInputStream(args[1]);
+      try {
+        FastLineReader in = new FastLineReader(raw);
+        FastLine line = in.read();
+        while (line != null) {
+          v.assign(0);
+          for (int i = 0; i < FIELDS; i++) {
+            double z = line.getDouble(i);
+            s[i].add(z);
+            // vector encoding is currently switched off in this mode:
+            // encoder[i].addToVector((byte[]) null, z, v);
+          }
+          line = in.read();
+        }
+      } finally {
+        raw.close();
+      }
+      printMeans(s);
+    }
+    System.out.printf("\nElapsed time = %.3f\n", (System.currentTimeMillis() - t0) / 1000.0);
+  }
+
+  /** Prints the mean of every summarizer on one line, separated by commas. */
+  private static void printMeans(OnlineSummarizer[] summaries) {
+    String separator = "";
+    for (OnlineSummarizer summary : summaries) {
+      System.out.printf("%s%.3f", separator, summary.getMean());
+      separator = ",";
+    }
+  }
+
+
+  /**
+   * One line of input held as a list of string fields. Lines are split and joined on
+   * {@code SEPARATOR}.
+   */
+  private static class Line {
+    private static final Splitter ON_SEPARATOR = Splitter.on(SEPARATOR).trimResults();
+    private static final Joiner WITH_SEPARATOR = Joiner.on(SEPARATOR);
+
+    // fixed seed, so the same sequence of lines is generated on every run
+    private static final Random RAND = new Random(1);
+
+    private final List<String> data;
+
+    /**
+     * Splits a line of text into trimmed fields.
+     *
+     * @param line Text containing fields delimited by {@code SEPARATOR}.
+     */
+    private Line(String line) {
+      data = Lists.newArrayList(ON_SEPARATOR.split(line));
+    }
+
+    /** Creates a line with no fields. */
+    public Line() {
+      data = Lists.newArrayList();
+    }
+
+    /**
+     * @param field Index of the field, starting at 0.
+     * @return The field parsed as a double.
+     * @throws NumberFormatException If the field does not hold a parsable number.
+     */
+    public double getDouble(int field) {
+      return Double.parseDouble(data.get(field));
+    }
+
+    /**
+     * Generate a random line with {@code FIELDS} fields each with integer values. The mean
+     * used for each field is derived from the field index and lies between 1 and 50.
+     *
+     * @return A new line with data.
+     */
+    public static Line generate() {
+      Line r = new Line();
+      for (int i = 0; i < FIELDS; i++) {
+        double mean = ((i + 1) * 257) % 50 + 1;
+        r.data.add(Integer.toString(randomValue(mean)));
+      }
+      return r;
+    }
+
+    /**
+     * Returns a random exponentially distributed integer with a particular mean value. This is
+     * just a way to create more small numbers than big numbers.
+     *
+     * @param mean Mean of the exponential distribution that is sampled.
+     * @return The sample truncated to an integer, so always zero or greater.
+     */
+    private static int randomValue(double mean) {
+      // inverse-CDF sampling; nextDouble() is below 1, so the argument of log stays positive
+      return (int) (-mean * Math.log(1 - RAND.nextDouble()));
+    }
+
+    /** Returns the fields joined with {@code SEPARATOR}. */
+    @Override
+    public String toString() {
+      return WITH_SEPARATOR.join(data);
+    }
+
+    /**
+     * @param field Index of the field, starting at 0.
+     * @return The field as text.
+     */
+    public String get(int field) {
+      return data.get(field);
+    }
+  }
+
+  /**
+   * A line whose fields are located in place inside a byte buffer. Only the start offset and
+   * the length of each field are recorded; bytes are decoded on demand by
+   * {@link #getDouble(int)}. The recorded offsets are absolute positions in the buffer, so a
+   * line is only usable while the buffer content stays where it was when the line was read.
+   */
+  private static class FastLine {
+
+    private final ByteBuffer base;
+    private final IntArrayList start = new IntArrayList();
+    private final IntArrayList length = new IntArrayList();
+
+    /**
+     * @param base Buffer that holds the bytes of this line.
+     */
+    public FastLine(ByteBuffer base) {
+      this.base = base;
+    }
+
+    /**
+     * Scans one line, consuming bytes from the current position of the buffer up to and
+     * including the next newline.
+     *
+     * @param buf Buffer positioned at the first byte of a line.
+     * @return The line, with the location of every field recorded.
+     * @throws IllegalArgumentException If the buffer is exhausted before a newline is found.
+     */
+    public static FastLine read(ByteBuffer buf) {
+      FastLine r = new FastLine(buf);
+      r.start.add(buf.position());
+      while (buf.hasRemaining()) {
+        int ch = buf.get();
+        // position of the byte after the one just read; must be refreshed on every byte
+        int offset = buf.position();
+        switch (ch) {
+          case '\n':
+            // the delimiter sits at offset - 1 and is not part of the field
+            r.length.add(offset - r.start.get(r.length.size()) - 1);
+            return r;
+          case SEPARATOR_CHAR:
+            r.length.add(offset - r.start.get(r.length.size()) - 1);
+            r.start.add(offset);
+            break;
+          default:
+            // nothing to do for now
+        }
+      }
+      throw new IllegalArgumentException("Not enough bytes in buffer");
+    }
+
+    /**
+     * Decodes a field as a non-negative decimal integer made of ASCII digits. The bytes are
+     * not validated; signs, decimal points and exponents are not understood.
+     *
+     * @param field Index of the field, starting at 0.
+     * @return The value of the field; 0 for an empty field.
+     */
+    public double getDouble(int field) {
+      int offset = start.get(field);
+      int size = length.get(field);
+      switch (size) {
+        case 1:
+          return base.get(offset) - '0';
+        case 2:
+          return (base.get(offset) - '0') * 10 + base.get(offset + 1) - '0';
+        default:
+          double r = 0;
+          for (int i = 0; i < size; i++) {
+            // subtract '0' to turn the ASCII code of the digit into its value
+            r = 10 * r + (base.get(offset + i) - '0');
+          }
+          return r;
+      }
+    }
+  }
+
+  /**
+   * Reads {@code FastLine}s from an input stream through a single reusable byte buffer.
+   */
+  private static class FastLineReader {
+    /** Capacity of the buffer in bytes. */
+    private static final int BUFFER_SIZE = 100000;
+    /** The buffer is topped up whenever fewer unread bytes than this remain. */
+    private static final int REFILL_THRESHOLD = 10000;
+
+    private final InputStream in;
+    private final ByteBuffer buf;
+
+    /**
+     * @param in Stream to read from.
+     * @throws IOException If the first read from the stream fails.
+     */
+    public FastLineReader(InputStream in) throws IOException {
+      this.in = in;
+      this.buf = ByteBuffer.allocate(BUFFER_SIZE);
+      // start out empty so that the first fill loads data from the stream
+      this.buf.limit(0);
+      fillBuffer();
+    }
+
+    /**
+     * @return The next line, or {@code null} once no unread bytes are left.
+     * @throws IOException If reading from the stream fails.
+     */
+    public FastLine read() throws IOException {
+      fillBuffer();
+      return buf.hasRemaining() ? FastLine.read(buf) : null;
+    }
+
+    /**
+     * Moves the unread bytes to the front of the buffer and reads more data behind them,
+     * unless enough unread bytes are still available.
+     */
+    private void fillBuffer() throws IOException {
+      if (buf.remaining() >= REFILL_THRESHOLD) {
+        return;
+      }
+      // after compact() the position is just past the unread bytes and the limit is the capacity
+      buf.compact();
+      int bytesRead = in.read(buf.array(), buf.position(), buf.remaining());
+      if (bytesRead > 0) {
+        buf.position(buf.position() + bytesRead);
+      }
+      // expose everything from the start of the buffer up to the end of the valid data
+      buf.flip();
+    }
+  }
+}