Posted to commits@mahout.apache.org by pa...@apache.org on 2015/04/01 20:07:48 UTC
[17/51] [partial] mahout git commit: MAHOUT-1655 Refactors mr-legacy
into mahout-hdfs and mahout-mr, closes apache/mahout#86
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/YtYJob.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/YtYJob.java b/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/YtYJob.java
new file mode 100644
index 0000000..378a885
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/YtYJob.java
@@ -0,0 +1,220 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.math.hadoop.stochasticsvd;
+
+import org.apache.commons.lang3.Validate;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile.CompressionType;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
+import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.UpperTriangular;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+import java.io.IOException;
+
+/**
+ * Job that accumulates Y'Y output
+ */
+public final class YtYJob {
+
+ public static final String PROP_OMEGA_SEED = "ssvd.omegaseed";
+ public static final String PROP_K = "ssvd.k";
+ public static final String PROP_P = "ssvd.p";
+
+ // we have a single output, so we use the standard output
+ public static final String OUTPUT_YT_Y = "part-";
+
+ private YtYJob() {
+ }
+
+ public static class YtYMapper extends
+ Mapper<Writable, VectorWritable, IntWritable, VectorWritable> {
+
+ private int kp;
+ private Omega omega;
+ private UpperTriangular mYtY;
+
+ /*
+ * We keep yRow in dense form here, but take care not to densify the result
+ * unnecessarily while doing the YtY products. A sparse vector is unlikely to
+ * bring much performance benefit, since we must assume that y is more often
+ * dense than sparse, so for bulk dense operations a dense vector performs
+ * somewhat better than frequent updates to a RandomAccessSparse vector.
+ */
+ private Vector yRow;
+
+ @Override
+ protected void setup(Context context) throws IOException,
+ InterruptedException {
+ int k = context.getConfiguration().getInt(PROP_K, -1);
+ int p = context.getConfiguration().getInt(PROP_P, -1);
+
+ Validate.isTrue(k > 0, "invalid k parameter");
+ Validate.isTrue(p > 0, "invalid p parameter");
+
+ kp = k + p;
+ long omegaSeed =
+ Long.parseLong(context.getConfiguration().get(PROP_OMEGA_SEED));
+
+ omega = new Omega(omegaSeed, k + p);
+
+ mYtY = new UpperTriangular(kp);
+
+ // see which one works better!
+ // yRow = new RandomAccessSparseVector(kp);
+ yRow = new DenseVector(kp);
+ }
+
+ @Override
+ protected void map(Writable key, VectorWritable value, Context context)
+ throws IOException, InterruptedException {
+ omega.computeYRow(value.get(), yRow);
+ // compute outer product update for YtY
+
+ if (yRow.isDense()) {
+ for (int i = 0; i < kp; i++) {
+ double yi;
+ if ((yi = yRow.getQuick(i)) == 0.0) {
+ continue; // avoid densing up here unnecessarily
+ }
+ for (int j = i; j < kp; j++) {
+ double yj;
+ if ((yj = yRow.getQuick(j)) != 0.0) {
+ mYtY.setQuick(i, j, mYtY.getQuick(i, j) + yi * yj);
+ }
+ }
+ }
+ } else {
+ /*
+ * The disadvantage of using a sparse vector here (aside from the fact that
+ * we are creating some short-lived references) is that we obviously do
+ * twice as many iterations as necessary if the y row is fairly dense.
+ */
+ for (Vector.Element eli : yRow.nonZeroes()) {
+ int i = eli.index();
+ for (Vector.Element elj : yRow.nonZeroes()) {
+ int j = elj.index();
+ if (j < i) {
+ continue;
+ }
+ mYtY.setQuick(i, j, mYtY.getQuick(i, j) + eli.get() * elj.get());
+ }
+ }
+ }
+ }
+
+ @Override
+ protected void cleanup(Context context) throws IOException,
+ InterruptedException {
+ context.write(new IntWritable(context.getTaskAttemptID().getTaskID()
+ .getId()),
+ new VectorWritable(new DenseVector(mYtY.getData())));
+ }
+ }
+
+ public static class YtYReducer extends
+ Reducer<IntWritable, VectorWritable, IntWritable, VectorWritable> {
+ private final VectorWritable accum = new VectorWritable();
+ private DenseVector acc;
+
+ @Override
+ protected void setup(Context context) throws IOException,
+ InterruptedException {
+ int k = context.getConfiguration().getInt(PROP_K, -1);
+ int p = context.getConfiguration().getInt(PROP_P, -1);
+
+ Validate.isTrue(k > 0, "invalid k parameter");
+ Validate.isTrue(p > 0, "invalid p parameter");
+ accum.set(acc = new DenseVector(k + p));
+ }
+
+ @Override
+ protected void cleanup(Context context) throws IOException,
+ InterruptedException {
+ context.write(new IntWritable(), accum);
+ }
+
+ @Override
+ protected void reduce(IntWritable key,
+ Iterable<VectorWritable> values,
+ Context arg2) throws IOException,
+ InterruptedException {
+ for (VectorWritable vw : values) {
+ acc.addAll(vw.get());
+ }
+ }
+ }
+
+ public static void run(Configuration conf,
+ Path[] inputPaths,
+ Path outputPath,
+ int k,
+ int p,
+ long seed) throws ClassNotFoundException,
+ InterruptedException, IOException {
+
+ Job job = new Job(conf);
+ job.setJobName("YtY-job");
+ job.setJarByClass(YtYJob.class);
+
+ job.setInputFormatClass(SequenceFileInputFormat.class);
+ FileInputFormat.setInputPaths(job, inputPaths);
+ FileOutputFormat.setOutputPath(job, outputPath);
+
+ SequenceFileOutputFormat.setOutputCompressionType(job,
+ CompressionType.BLOCK);
+
+ job.setMapOutputKeyClass(IntWritable.class);
+ job.setMapOutputValueClass(VectorWritable.class);
+
+ job.setOutputKeyClass(IntWritable.class);
+ job.setOutputValueClass(VectorWritable.class);
+
+ job.setMapperClass(YtYMapper.class);
+
+ job.getConfiguration().setLong(PROP_OMEGA_SEED, seed);
+ job.getConfiguration().setInt(PROP_K, k);
+ job.getConfiguration().setInt(PROP_P, p);
+
+ /*
+ * We must reduce to just one matrix, which means we need only one reducer.
+ * That's OK, since each mapper outputs only one vector (a packed
+ * UpperTriangular), so even if there are thousands of mappers, one reducer
+ * should cope just fine.
+ */
+ job.setNumReduceTasks(1);
+
+ job.submit();
+ job.waitForCompletion(false);
+
+ if (!job.isSuccessful()) {
+ throw new IOException("YtY job unsuccessful.");
+ }
+
+ }
+
+}
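
For illustration, the heart of YtYMapper above is an upper-triangular rank-one
update per y row. A minimal standalone sketch of that accumulation (a
hypothetical helper, using only the UpperTriangular class already imported
above):

    import org.apache.mahout.math.UpperTriangular;

    final class YtYAccumulationSketch {
      /** Accumulates the upper triangle of Y'Y from rows of Y, as YtYMapper does. */
      static UpperTriangular accumulate(double[][] yRows, int kp) {
        UpperTriangular ytY = new UpperTriangular(kp);
        for (double[] y : yRows) {
          for (int i = 0; i < kp; i++) {
            double yi = y[i];
            if (yi == 0.0) {
              continue; // skip zero entries, as the mapper does
            }
            for (int j = i; j < kp; j++) {
              ytY.setQuick(i, j, ytY.getQuick(i, j) + yi * y[j]);
            }
          }
        }
        return ytY;
      }
    }

Each mapper emits its partial triangle packed into one dense vector, and the
single reducer simply adds those vectors together.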
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/qr/GivensThinSolver.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/qr/GivensThinSolver.java b/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/qr/GivensThinSolver.java
new file mode 100644
index 0000000..7033efe
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/qr/GivensThinSolver.java
@@ -0,0 +1,638 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.hadoop.stochasticsvd.qr;
+
+import java.util.Arrays;
+import java.util.Iterator;
+import java.util.List;
+
+import com.google.common.collect.Lists;
+import org.apache.mahout.math.AbstractVector;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.OrderedIntDoubleMapping;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.UpperTriangular;
+
+/**
+ * Givens Thin solver. Standard Givens operations are reordered in a way that
+ * helps us to push them thru MapReduce operations in a block fashion.
+ */
+public class GivensThinSolver {
+
+ private double[] vARow;
+ private double[] vQtRow;
+ private final double[][] mQt;
+ private final double[][] mR;
+ private int qtStartRow;
+ private int rStartRow;
+ private int m;
+ private final int n; // m - row count, n - column count, m >= n
+ private int cnt;
+ private final double[] cs = new double[2];
+
+ public GivensThinSolver(int m, int n) {
+ if (!(m >= n)) {
+ throw new IllegalArgumentException("Givens thin QR: must be true: m>=n");
+ }
+
+ this.m = m;
+ this.n = n;
+
+ mQt = new double[n][];
+ mR = new double[n][];
+ vARow = new double[n];
+ vQtRow = new double[m];
+
+ for (int i = 0; i < n; i++) {
+ mQt[i] = new double[this.m];
+ mR[i] = new double[this.n];
+ }
+ cnt = 0;
+ }
+
+ public void reset() {
+ cnt = 0;
+ }
+
+ public void solve(Matrix a) {
+
+ assert a.rowSize() == m;
+ assert a.columnSize() == n;
+
+ double[] aRow = new double[n];
+ for (int i = 0; i < m; i++) {
+ Vector aRowV = a.viewRow(i);
+ for (int j = 0; j < n; j++) {
+ aRow[j] = aRowV.getQuick(j);
+ }
+ appendRow(aRow);
+ }
+ }
+
+ public boolean isFull() {
+ return cnt == m;
+ }
+
+ public int getM() {
+ return m;
+ }
+
+ public int getN() {
+ return n;
+ }
+
+ public int getCnt() {
+ return cnt;
+ }
+
+ public void adjust(int newM) {
+ if (newM == m) {
+ // no adjustment is required.
+ return;
+ }
+ if (newM < n) {
+ throw new IllegalArgumentException("new m can't be less than n");
+ }
+ if (newM < cnt) {
+ throw new IllegalArgumentException(
+ "new m can't be less than rows accumulated");
+ }
+ vQtRow = new double[newM];
+
+ // grow or shrink qt rows
+ if (newM > m) {
+ // grow qt rows
+ for (int i = 0; i < n; i++) {
+ mQt[i] = Arrays.copyOf(mQt[i], newM);
+ System.arraycopy(mQt[i], 0, mQt[i], newM - m, m);
+ Arrays.fill(mQt[i], 0, newM - m, 0);
+ }
+ } else {
+ // shrink qt rows
+ for (int i = 0; i < n; i++) {
+ mQt[i] = Arrays.copyOfRange(mQt[i], m - newM, m);
+ }
+ }
+
+ m = newM;
+
+ }
+
+ public void trim() {
+ adjust(cnt);
+ }
+
+ /**
+ * API for row-by-row addition.
+ *
+ * @param aRow the next row of A to append
+ */
+ public void appendRow(double[] aRow) {
+ if (cnt >= m) {
+ throw new IllegalStateException("thin QR solver fed more rows than initialized for");
+ }
+ try {
+ /*
+ * Moving pointers around is inefficient, but for sanity's sake I am
+ * keeping it this way so I don't have to guess how an R-tilde index maps
+ * to the actual block index.
+ */
+ Arrays.fill(vQtRow, 0);
+ vQtRow[m - cnt - 1] = 1;
+ int height = cnt > n ? n : cnt;
+ System.arraycopy(aRow, 0, vARow, 0, n);
+
+ if (height > 0) {
+ givens(vARow[0], getRRow(0)[0], cs);
+ applyGivensInPlace(cs[0], cs[1], vARow, getRRow(0), 0, n);
+ applyGivensInPlace(cs[0], cs[1], vQtRow, getQtRow(0), 0, m);
+ }
+
+ for (int i = 1; i < height; i++) {
+ givens(getRRow(i - 1)[i], getRRow(i)[i], cs);
+ applyGivensInPlace(cs[0], cs[1], getRRow(i - 1), getRRow(i), i,
+ n - i);
+ applyGivensInPlace(cs[0], cs[1], getQtRow(i - 1), getQtRow(i), 0,
+ m);
+ }
+ /*
+ * push qt and r-tilde 1 row down
+ *
+ * just swap the references to reduce GC churning
+ */
+ pushQtDown();
+ double[] swap = getQtRow(0);
+ setQtRow(0, vQtRow);
+ vQtRow = swap;
+
+ pushRDown();
+ swap = getRRow(0);
+ setRRow(0, vARow);
+ vARow = swap;
+
+ } finally {
+ cnt++;
+ }
+ }
+
+ private double[] getQtRow(int row) {
+
+ return mQt[(row += qtStartRow) >= n ? row - n : row];
+ }
+
+ private void setQtRow(int row, double[] qtRow) {
+ mQt[(row += qtStartRow) >= n ? row - n : row] = qtRow;
+ }
+
+ private void pushQtDown() {
+ qtStartRow = qtStartRow == 0 ? n - 1 : qtStartRow - 1;
+ }
+
+ private double[] getRRow(int row) {
+ row += rStartRow;
+ return mR[row >= n ? row - n : row];
+ }
+
+ private void setRRow(int row, double[] rrow) {
+ mR[(row += rStartRow) >= n ? row - n : row] = rrow;
+ }
+
+ private void pushRDown() {
+ rStartRow = rStartRow == 0 ? n - 1 : rStartRow - 1;
+ }
+
+ /*
+ * Warning: both of these methods actually return n + 1 rows, with the last
+ * one being of no interest.
+ */
+ public UpperTriangular getRTilde() {
+ UpperTriangular packedR = new UpperTriangular(n);
+ for (int i = 0; i < n; i++) {
+ packedR.assignNonZeroElementsInRow(i, getRRow(i));
+ }
+ return packedR;
+ }
+
+ public double[][] getThinQtTilde() {
+ if (qtStartRow != 0) {
+ /*
+ * rotate qt rows into place
+ *
+ * double[~500][], once per block, not a big deal.
+ */
+ double[][] qt = new double[n][];
+ System.arraycopy(mQt, qtStartRow, qt, 0, n - qtStartRow);
+ System.arraycopy(mQt, 0, qt, n - qtStartRow, qtStartRow);
+ return qt;
+ }
+ return mQt;
+ }
+
+ public static void applyGivensInPlace(double c, double s, double[] row1,
+ double[] row2, int offset, int len) {
+
+ int n = offset + len;
+ for (int j = offset; j < n; j++) {
+ double tau1 = row1[j];
+ double tau2 = row2[j];
+ row1[j] = c * tau1 - s * tau2;
+ row2[j] = s * tau1 + c * tau2;
+ }
+ }
+
+ public static void applyGivensInPlace(double c, double s, Vector row1,
+ Vector row2, int offset, int len) {
+
+ int n = offset + len;
+ for (int j = offset; j < n; j++) {
+ double tau1 = row1.getQuick(j);
+ double tau2 = row2.getQuick(j);
+ row1.setQuick(j, c * tau1 - s * tau2);
+ row2.setQuick(j, s * tau1 + c * tau2);
+ }
+ }
+
+ public static void applyGivensInPlace(double c, double s, int i, int k,
+ Matrix mx) {
+ int n = mx.columnSize();
+
+ for (int j = 0; j < n; j++) {
+ double tau1 = mx.get(i, j);
+ double tau2 = mx.get(k, j);
+ mx.set(i, j, c * tau1 - s * tau2);
+ mx.set(k, j, s * tau1 + c * tau2);
+ }
+ }
+
+ public static void fromRho(double rho, double[] csOut) {
+ if (rho == 1) {
+ csOut[0] = 0;
+ csOut[1] = 1;
+ return;
+ }
+ if (Math.abs(rho) < 1) {
+ csOut[1] = 2 * rho;
+ csOut[0] = Math.sqrt(1 - csOut[1] * csOut[1]);
+ return;
+ }
+ csOut[0] = 2 / rho;
+ csOut[1] = Math.sqrt(1 - csOut[0] * csOut[0]);
+ }
+
+ public static void givens(double a, double b, double[] csOut) {
+ if (b == 0) {
+ csOut[0] = 1;
+ csOut[1] = 0;
+ return;
+ }
+ if (Math.abs(b) > Math.abs(a)) {
+ double tau = -a / b;
+ csOut[1] = 1 / Math.sqrt(1 + tau * tau);
+ csOut[0] = csOut[1] * tau;
+ } else {
+ double tau = -b / a;
+ csOut[0] = 1 / Math.sqrt(1 + tau * tau);
+ csOut[1] = csOut[0] * tau;
+ }
+ }
+
+ public static double toRho(double c, double s) {
+ if (c == 0) {
+ return 1;
+ }
+ if (Math.abs(s) < Math.abs(c)) {
+ return Math.signum(c) * s / 2;
+ } else {
+ return Math.signum(s) * 2 / c;
+ }
+ }
+
+ public static void mergeR(UpperTriangular r1, UpperTriangular r2) {
+ TriangularRowView r1Row = new TriangularRowView(r1);
+ TriangularRowView r2Row = new TriangularRowView(r2);
+
+ int kp = r1Row.size();
+ assert kp == r2Row.size();
+
+ double[] cs = new double[2];
+
+ for (int v = 0; v < kp; v++) {
+ for (int u = v; u < kp; u++) {
+ givens(r1Row.setViewedRow(u).get(u), r2Row.setViewedRow(u - v).get(u),
+ cs);
+ applyGivensInPlace(cs[0], cs[1], r1Row, r2Row, u, kp - u);
+ }
+ }
+ }
+
+ public static void mergeR(double[][] r1, double[][] r2) {
+ int kp = r1[0].length;
+ assert kp == r2[0].length;
+
+ double[] cs = new double[2];
+
+ for (int v = 0; v < kp; v++) {
+ for (int u = v; u < kp; u++) {
+ givens(r1[u][u], r2[u - v][u], cs);
+ applyGivensInPlace(cs[0], cs[1], r1[u], r2[u - v], u, kp - u);
+ }
+ }
+
+ }
+
+ public static void mergeRonQ(UpperTriangular r1, UpperTriangular r2,
+ double[][] qt1, double[][] qt2) {
+ TriangularRowView r1Row = new TriangularRowView(r1);
+ TriangularRowView r2Row = new TriangularRowView(r2);
+ int kp = r1Row.size();
+ assert kp == r2Row.size();
+ assert kp == qt1.length;
+ assert kp == qt2.length;
+
+ int r = qt1[0].length;
+ assert qt2[0].length == r;
+
+ double[] cs = new double[2];
+
+ for (int v = 0; v < kp; v++) {
+ for (int u = v; u < kp; u++) {
+ givens(r1Row.setViewedRow(u).get(u), r2Row.setViewedRow(u - v).get(u),
+ cs);
+ applyGivensInPlace(cs[0], cs[1], r1Row, r2Row, u, kp - u);
+ applyGivensInPlace(cs[0], cs[1], qt1[u], qt2[u - v], 0, r);
+ }
+ }
+ }
+
+ public static void mergeRonQ(double[][] r1, double[][] r2, double[][] qt1,
+ double[][] qt2) {
+
+ int kp = r1[0].length;
+ assert kp == r2[0].length;
+ assert kp == qt1.length;
+ assert kp == qt2.length;
+
+ int r = qt1[0].length;
+ assert qt2[0].length == r;
+ double[] cs = new double[2];
+
+ /*
+ * Pairwise givens(a, b) such that the a's come off the main diagonal in r1
+ * and the b's come off the v-th upper subdiagonal in r2.
+ */
+ for (int v = 0; v < kp; v++) {
+ for (int u = v; u < kp; u++) {
+ givens(r1[u][u], r2[u - v][u], cs);
+ applyGivensInPlace(cs[0], cs[1], r1[u], r2[u - v], u, kp - u);
+ applyGivensInPlace(cs[0], cs[1], qt1[u], qt2[u - v], 0, r);
+ }
+ }
+ }
+
+ // returns merged Q (which in this case is the qt1)
+ public static double[][] mergeQrUp(double[][] qt1, double[][] r1,
+ double[][] r2) {
+ int kp = qt1.length;
+ int r = qt1[0].length;
+
+ double[][] qTilde = new double[kp][];
+ for (int i = 0; i < kp; i++) {
+ qTilde[i] = new double[r];
+ }
+ mergeRonQ(r1, r2, qt1, qTilde);
+ return qt1;
+ }
+
+ // returns merged Q (which in this case is the qt1)
+ public static double[][] mergeQrUp(double[][] qt1, UpperTriangular r1, UpperTriangular r2) {
+ int kp = qt1.length;
+ int r = qt1[0].length;
+
+ double[][] qTilde = new double[kp][];
+ for (int i = 0; i < kp; i++) {
+ qTilde[i] = new double[r];
+ }
+ mergeRonQ(r1, r2, qt1, qTilde);
+ return qt1;
+ }
+
+ public static double[][] mergeQrDown(double[][] r1, double[][] qt2, double[][] r2) {
+ int kp = qt2.length;
+ int r = qt2[0].length;
+
+ double[][] qTilde = new double[kp][];
+ for (int i = 0; i < kp; i++) {
+ qTilde[i] = new double[r];
+ }
+ mergeRonQ(r1, r2, qTilde, qt2);
+ return qTilde;
+
+ }
+
+ public static double[][] mergeQrDown(UpperTriangular r1, double[][] qt2, UpperTriangular r2) {
+ int kp = qt2.length;
+ int r = qt2[0].length;
+
+ double[][] qTilde = new double[kp][];
+ for (int i = 0; i < kp; i++) {
+ qTilde[i] = new double[r];
+ }
+ mergeRonQ(r1, r2, qTilde, qt2);
+ return qTilde;
+
+ }
+
+ public static double[][] computeQtHat(double[][] qt, int i,
+ Iterator<UpperTriangular> rIter) {
+ UpperTriangular rTilde = rIter.next();
+ for (int j = 1; j < i; j++) {
+ mergeR(rTilde, rIter.next());
+ }
+ if (i > 0) {
+ qt = mergeQrDown(rTilde, qt, rIter.next());
+ }
+ while (rIter.hasNext()) {
+ qt = mergeQrUp(qt, rTilde, rIter.next());
+ }
+ return qt;
+ }
+
+ // test helpers
+ public static boolean isOrthonormal(double[][] qt, boolean insufficientRank, double epsilon) {
+ int n = qt.length;
+ int rank = 0;
+ for (int i = 0; i < n; i++) {
+ Vector ei = new DenseVector(qt[i], true);
+
+ double norm = ei.norm(2);
+
+ if (Math.abs(1.0 - norm) < epsilon) {
+ rank++;
+ } else if (Math.abs(norm) > epsilon) {
+ return false; // not a rank deficiency, either
+ }
+
+ for (int j = 0; j <= i; j++) {
+ Vector ej = new DenseVector(qt[j], true);
+ double dot = ei.dot(ej);
+ if (!(Math.abs((i == j && rank > j ? 1.0 : 0.0) - dot) < epsilon)) {
+ return false;
+ }
+ }
+ }
+ return insufficientRank ? rank < n : rank == n;
+ }
+
+ public static boolean isOrthonormalBlocked(Iterable<double[][]> qtHats,
+ boolean insufficientRank, double epsilon) {
+ int n = qtHats.iterator().next().length;
+ int rank = 0;
+ for (int i = 0; i < n; i++) {
+ List<Vector> ei = Lists.newArrayList();
+ // Vector e_i=new DenseVector (qt[i],true);
+ for (double[][] qtHat : qtHats) {
+ ei.add(new DenseVector(qtHat[i], true));
+ }
+
+ double norm = 0;
+ for (Vector v : ei) {
+ norm += v.dot(v);
+ }
+ norm = Math.sqrt(norm);
+ if (Math.abs(1 - norm) < epsilon) {
+ rank++;
+ } else if (Math.abs(norm) > epsilon) {
+ return false; // not a rank deficiency, either
+ }
+
+ for (int j = 0; j <= i; j++) {
+ List<Vector> ej = Lists.newArrayList();
+ for (double[][] qtHat : qtHats) {
+ ej.add(new DenseVector(qtHat[j], true));
+ }
+
+ // Vector e_j = new DenseVector ( qt[j], true);
+ double dot = 0;
+ for (int k = 0; k < ei.size(); k++) {
+ dot += ei.get(k).dot(ej.get(k));
+ }
+ if (!(Math.abs((i == j && rank > j ? 1 : 0) - dot) < epsilon)) {
+ return false;
+ }
+ }
+ }
+ return insufficientRank ? rank < n : rank == n;
+ }
+
+ private static final class TriangularRowView extends AbstractVector {
+ private final UpperTriangular viewed;
+ private int rowNum;
+
+ private TriangularRowView(UpperTriangular viewed) {
+ super(viewed.columnSize());
+ this.viewed = viewed;
+
+ }
+
+ TriangularRowView setViewedRow(int row) {
+ rowNum = row;
+ return this;
+ }
+
+ @Override
+ public boolean isDense() {
+ return true;
+ }
+
+ @Override
+ public boolean isSequentialAccess() {
+ return false;
+ }
+
+ @Override
+ public Iterator<Element> iterator() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public Iterator<Element> iterateNonZero() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public double getQuick(int index) {
+ return viewed.getQuick(rowNum, index);
+ }
+
+ @Override
+ public Vector like() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void setQuick(int index, double value) {
+ viewed.setQuick(rowNum, index, value);
+
+ }
+
+ @Override
+ public int getNumNondefaultElements() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public double getLookupCost() {
+ return 1;
+ }
+
+ @Override
+ public double getIteratorAdvanceCost() {
+ return 1;
+ }
+
+ @Override
+ public boolean isAddConstantTime() {
+ return true;
+ }
+
+ @Override
+ public Matrix matrixLike(int rows, int columns) {
+ throw new UnsupportedOperationException();
+ }
+
+ /**
+ * Used internally by assign() to update multiple indices and values at once.
+ * Only really useful for sparse vectors (especially SequentialAccessSparseVector).
+ * <p/>
+ * If someone ever adds a new type of sparse vector, this method must merge (index, value) pairs into the vector.
+ *
+ * @param updates a mapping of indices to values to merge into the vector.
+ */
+ @Override
+ public void mergeUpdates(OrderedIntDoubleMapping updates) {
+ int[] indices = updates.getIndices();
+ double[] values = updates.getValues();
+ for (int i = 0; i < updates.getNumMappings(); ++i) {
+ viewed.setQuick(rowNum, indices[i], values[i]);
+ }
+ }
+
+ }
+
+}
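
As a quick sanity check of the rotation primitives above: givens(a, b, cs)
computes (c, s) such that applying the rotation to two rows annihilates the
leading element of the second row. A hypothetical test sketch:

    import org.apache.mahout.math.hadoop.stochasticsvd.qr.GivensThinSolver;

    final class GivensSketch {
      public static void main(String[] args) {
        double[] cs = new double[2];
        double[] row1 = {3.0, 1.0};
        double[] row2 = {4.0, 2.0};
        GivensThinSolver.givens(row1[0], row2[0], cs);
        GivensThinSolver.applyGivensInPlace(cs[0], cs[1], row1, row2, 0, 2);
        // prints approximately -5.0 and 0.0: row2[0] is annihilated and
        // row1[0] absorbs the hypotenuse of (3, 4), up to sign
        System.out.println(row1[0] + " " + row2[0]);
      }
    }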
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/qr/GramSchmidt.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/qr/GramSchmidt.java b/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/qr/GramSchmidt.java
new file mode 100644
index 0000000..09be91f
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/qr/GramSchmidt.java
@@ -0,0 +1,52 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.math.hadoop.stochasticsvd.qr;
+
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.DoubleFunction;
+
+/**
+ * Quick Gram-Schmidt orthonormalization helper.
+ */
+public final class GramSchmidt {
+
+ private GramSchmidt() {
+ }
+
+ public static void orthonormalizeColumns(Matrix mx) {
+
+ int n = mx.numCols();
+
+ for (int c = 0; c < n; c++) {
+ Vector col = mx.viewColumn(c);
+ for (int c1 = 0; c1 < c; c1++) {
+ Vector viewC1 = mx.viewColumn(c1);
+ col.assign(col.minus(viewC1.times(viewC1.dot(col))));
+
+ }
+ final double norm2 = col.norm(2);
+ col.assign(new DoubleFunction() {
+ @Override
+ public double apply(double x) {
+ return x / norm2;
+ }
+ });
+ }
+ }
+
+}
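
A hypothetical usage sketch for the helper above: orthonormalize the columns
of a small dense matrix in place, then check that the columns are orthogonal
with unit 2-norm.

    import org.apache.mahout.math.DenseMatrix;
    import org.apache.mahout.math.Matrix;
    import org.apache.mahout.math.hadoop.stochasticsvd.qr.GramSchmidt;

    final class GramSchmidtSketch {
      public static void main(String[] args) {
        Matrix mx = new DenseMatrix(new double[][] {{1, 1}, {1, 0}, {0, 1}});
        GramSchmidt.orthonormalizeColumns(mx);
        // the dot product of the two columns is ~0, both norms are ~1
        System.out.println(mx.viewColumn(0).dot(mx.viewColumn(1)));
        System.out.println(mx.viewColumn(0).norm(2) + " " + mx.viewColumn(1).norm(2));
      }
    }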
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/qr/QRFirstStep.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/qr/QRFirstStep.java b/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/qr/QRFirstStep.java
new file mode 100644
index 0000000..8509e0a
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/qr/QRFirstStep.java
@@ -0,0 +1,284 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.math.hadoop.stochasticsvd.qr;
+
+import java.io.Closeable;
+import java.io.File;
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Deque;
+import java.util.List;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.NullWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.SequenceFile.CompressionType;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.mapred.JobConf;
+import org.apache.hadoop.mapred.OutputCollector;
+import org.apache.hadoop.mapred.lib.MultipleOutputs;
+import org.apache.mahout.common.IOUtils;
+import org.apache.mahout.common.iterator.CopyConstructorIterator;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.Vector.Element;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.hadoop.stochasticsvd.DenseBlockWritable;
+import org.apache.mahout.math.UpperTriangular;
+
+import com.google.common.collect.Lists;
+import com.google.common.io.Closeables;
+
+/**
+ * QR first step without MR abstractions, done just in terms of iterators and
+ * collectors (although Collector is probably an outdated API).
+ */
+@SuppressWarnings("deprecation")
+public class QRFirstStep implements Closeable, OutputCollector<Writable, Vector> {
+
+ public static final String PROP_K = "ssvd.k";
+ public static final String PROP_P = "ssvd.p";
+ public static final String PROP_AROWBLOCK_SIZE = "ssvd.arowblock.size";
+
+ private int kp;
+ private List<double[]> yLookahead;
+ private GivensThinSolver qSolver;
+ private int blockCnt;
+ private final DenseBlockWritable value = new DenseBlockWritable();
+ private final Writable tempKey = new IntWritable();
+ private MultipleOutputs outputs;
+ private final Deque<Closeable> closeables = Lists.newLinkedList();
+ private SequenceFile.Writer tempQw;
+ private Path tempQPath;
+ private final List<UpperTriangular> rSubseq = Lists.newArrayList();
+ private final Configuration jobConf;
+
+ private final OutputCollector<? super Writable, ? super DenseBlockWritable> qtHatOut;
+ private final OutputCollector<? super Writable, ? super VectorWritable> rHatOut;
+
+ public QRFirstStep(Configuration jobConf,
+ OutputCollector<? super Writable, ? super DenseBlockWritable> qtHatOut,
+ OutputCollector<? super Writable, ? super VectorWritable> rHatOut) {
+ this.jobConf = jobConf;
+ this.qtHatOut = qtHatOut;
+ this.rHatOut = rHatOut;
+ setup();
+ }
+
+ @Override
+ public void close() throws IOException {
+ cleanup();
+ }
+
+ public int getKP() {
+ return kp;
+ }
+
+ private void flushSolver() throws IOException {
+ UpperTriangular r = qSolver.getRTilde();
+ double[][] qt = qSolver.getThinQtTilde();
+
+ rSubseq.add(r);
+
+ value.setBlock(qt);
+ getTempQw().append(tempKey, value);
+
+ /*
+ * This probably should be a sparse row matrix, but the compressor should
+ * take care of it on disk, and in memory we want it dense anyway; sparse
+ * random-access implementations would mostly be a memory-management
+ * disaster of rehashes and GC thrashing (IMHO).
+ */
+ value.setBlock(null);
+ qSolver.reset();
+ }
+
+ // second pass to run a modified version of computeQHatSequence.
+ private void flushQBlocks() throws IOException {
+ if (blockCnt == 1) {
+ /*
+ * Only one block, no temp file, no second pass; this should be the default
+ * mode for efficiency in most cases. Surely a mapper should be able to
+ * load the entire split in memory -- and we don't even require that.
+ */
+ value.setBlock(qSolver.getThinQtTilde());
+ outputQHat(value);
+ outputR(new VectorWritable(new DenseVector(qSolver.getRTilde().getData(),
+ true)));
+
+ } else {
+ secondPass();
+ }
+ }
+
+ private void outputQHat(DenseBlockWritable value) throws IOException {
+ qtHatOut.collect(NullWritable.get(), value);
+ }
+
+ private void outputR(VectorWritable value) throws IOException {
+ rHatOut.collect(NullWritable.get(), value);
+ }
+
+ private void secondPass() throws IOException {
+ qSolver = null; // release mem
+ FileSystem localFs = FileSystem.getLocal(jobConf);
+ SequenceFile.Reader tempQr =
+ new SequenceFile.Reader(localFs, tempQPath, jobConf);
+ closeables.addFirst(tempQr);
+ int qCnt = 0;
+ while (tempQr.next(tempKey, value)) {
+ value
+ .setBlock(GivensThinSolver.computeQtHat(value.getBlock(),
+ qCnt,
+ new CopyConstructorIterator<>(rSubseq.iterator())));
+ if (qCnt == 1) {
+ /*
+ * Just merge r[0] <- r[1] so the merge doesn't have to be repeated in
+ * subsequent computeQtHat iterations.
+ */
+ GivensThinSolver.mergeR(rSubseq.get(0), rSubseq.remove(1));
+ } else {
+ qCnt++;
+ }
+ outputQHat(value);
+ }
+
+ assert rSubseq.size() == 1;
+
+ outputR(new VectorWritable(new DenseVector(rSubseq.get(0).getData(), true)));
+
+ }
+
+ protected void map(Vector incomingYRow) throws IOException {
+ double[] yRow;
+ if (yLookahead.size() == kp) {
+ if (qSolver.isFull()) {
+
+ flushSolver();
+ blockCnt++;
+
+ }
+ yRow = yLookahead.remove(0);
+
+ qSolver.appendRow(yRow);
+ } else {
+ yRow = new double[kp];
+ }
+
+ if (incomingYRow.isDense()) {
+ for (int i = 0; i < kp; i++) {
+ yRow[i] = incomingYRow.get(i);
+ }
+ } else {
+ Arrays.fill(yRow, 0);
+ for (Element yEl : incomingYRow.nonZeroes()) {
+ yRow[yEl.index()] = yEl.get();
+ }
+ }
+
+ yLookahead.add(yRow);
+ }
+
+ protected void setup() {
+
+ int r = Integer.parseInt(jobConf.get(PROP_AROWBLOCK_SIZE));
+ int k = Integer.parseInt(jobConf.get(PROP_K));
+ int p = Integer.parseInt(jobConf.get(PROP_P));
+ kp = k + p;
+
+ yLookahead = Lists.newArrayListWithCapacity(kp);
+ qSolver = new GivensThinSolver(r, kp);
+ outputs = new MultipleOutputs(new JobConf(jobConf));
+ closeables.addFirst(new Closeable() {
+ @Override
+ public void close() throws IOException {
+ outputs.close();
+ }
+ });
+
+ }
+
+ protected void cleanup() throws IOException {
+ try {
+ if (qSolver == null && yLookahead.isEmpty()) {
+ return;
+ }
+ if (qSolver == null) {
+ qSolver = new GivensThinSolver(yLookahead.size(), kp);
+ }
+ // grow q solver up if necessary
+
+ qSolver.adjust(qSolver.getCnt() + yLookahead.size());
+ while (!yLookahead.isEmpty()) {
+
+ qSolver.appendRow(yLookahead.remove(0));
+
+ }
+ assert qSolver.isFull();
+ if (++blockCnt > 1) {
+ flushSolver();
+ assert tempQw != null;
+ closeables.remove(tempQw);
+ Closeables.close(tempQw, false);
+ }
+ flushQBlocks();
+
+ } finally {
+ IOUtils.close(closeables);
+ }
+
+ }
+
+ private SequenceFile.Writer getTempQw() throws IOException {
+ if (tempQw == null) {
+ /*
+ * The temporary Q output will hopefully not exceed the size of the I/O
+ * cache, in which case all is well, since it is managed by the kernel
+ * rather than the Java GC. And if the I/O cache is not big enough, then at
+ * least the access is always sequential.
+ */
+ String taskTmpDir = System.getProperty("java.io.tmpdir");
+
+ FileSystem localFs = FileSystem.getLocal(jobConf);
+ Path parent = new Path(taskTmpDir);
+ Path sub = new Path(parent, "qw_" + System.currentTimeMillis());
+ tempQPath = new Path(sub, "q-temp.seq");
+ tempQw =
+ SequenceFile.createWriter(localFs,
+ jobConf,
+ tempQPath,
+ IntWritable.class,
+ DenseBlockWritable.class,
+ CompressionType.BLOCK);
+ closeables.addFirst(tempQw);
+ closeables.addFirst(new IOUtils.DeleteFileOnClose(new File(tempQPath
+ .toString())));
+ }
+ return tempQw;
+ }
+
+ @Override
+ public void collect(Writable key, Vector vw) throws IOException {
+ map(vw);
+ }
+
+}
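
Because QRFirstStep is decoupled from the MR machinery, it can also be driven
directly. A hypothetical sketch of such a driver (inside some method), assuming
a Configuration carrying the three ssvd.* properties, in-memory collectors
standing in for the MR output channels, and a hypothetical Iterable<Vector>
yRows of Y rows:

    Configuration conf = new Configuration();
    conf.setInt(QRFirstStep.PROP_K, 10);
    conf.setInt(QRFirstStep.PROP_P, 5);
    conf.setInt(QRFirstStep.PROP_AROWBLOCK_SIZE, 1000);

    final List<DenseBlockWritable> qtBlocks = Lists.newArrayList();
    final List<VectorWritable> rHats = Lists.newArrayList();

    QRFirstStep step = new QRFirstStep(conf,
        new OutputCollector<Writable, DenseBlockWritable>() {
          @Override
          public void collect(Writable key, DenseBlockWritable block) {
            qtBlocks.add(block); // caution: the writable may be reused; copy if retained
          }
        },
        new OutputCollector<Writable, VectorWritable>() {
          @Override
          public void collect(Writable key, VectorWritable rHat) {
            rHats.add(rHat);
          }
        });
    for (Vector yRow : yRows) {
      step.collect(null, yRow);
    }
    step.close(); // flushes remaining rows and emits the final R-hat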
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/qr/QRLastStep.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/qr/QRLastStep.java b/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/qr/QRLastStep.java
new file mode 100644
index 0000000..545f1f9
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/qr/QRLastStep.java
@@ -0,0 +1,144 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.hadoop.stochasticsvd.qr;
+
+import java.io.Closeable;
+import java.io.IOException;
+import java.util.Iterator;
+import java.util.List;
+import java.util.NoSuchElementException;
+
+import org.apache.commons.lang3.Validate;
+import org.apache.mahout.common.iterator.CopyConstructorIterator;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.hadoop.stochasticsvd.DenseBlockWritable;
+import org.apache.mahout.math.UpperTriangular;
+
+import com.google.common.collect.Lists;
+
+/**
+ * Second/last step of the QR iterations. Takes the qtHat and rHat inputs and
+ * provides an iterator for pulling ready rows of the final Q.
+ */
+public class QRLastStep implements Closeable, Iterator<Vector> {
+
+ private final Iterator<DenseBlockWritable> qHatInput;
+
+ private final List<UpperTriangular> mRs = Lists.newArrayList();
+ private final int blockNum;
+ private double[][] mQt;
+ private int cnt;
+ private int r;
+ private int kp;
+ private Vector qRow;
+
+ /**
+ * @param qHatInput
+ * the Q-hat input that was output in the first step
+ * @param rHatInput
+ * all R-hat outputs in the group, in group order
+ * @param blockNum
+ * our R-hat number in the group
+ */
+ public QRLastStep(Iterator<DenseBlockWritable> qHatInput,
+ Iterator<VectorWritable> rHatInput,
+ int blockNum) {
+ this.blockNum = blockNum;
+ this.qHatInput = qHatInput;
+ /*
+ * in this implementation we actually preload all Rs into memory to make R
+ * sequence modifications more efficient.
+ */
+ int block = 0;
+ while (rHatInput.hasNext()) {
+ Vector value = rHatInput.next().get();
+ if (block < blockNum && block > 0) {
+ GivensThinSolver.mergeR(mRs.get(0), new UpperTriangular(value));
+ } else {
+ mRs.add(new UpperTriangular(value));
+ }
+ block++;
+ }
+
+ }
+
+ private boolean loadNextQt() {
+ boolean more = qHatInput.hasNext();
+ if (!more) {
+ return false;
+ }
+ DenseBlockWritable v = qHatInput.next();
+ mQt =
+ GivensThinSolver
+ .computeQtHat(v.getBlock(),
+ blockNum == 0 ? 0 : 1,
+ new CopyConstructorIterator<>(mRs.iterator()));
+ r = mQt[0].length;
+ kp = mQt.length;
+ if (qRow == null) {
+ qRow = new DenseVector(kp);
+ }
+ return true;
+ }
+
+ @Override
+ public boolean hasNext() {
+ if (mQt != null && cnt == r) {
+ mQt = null;
+ }
+ boolean result = true;
+ if (mQt == null) {
+ result = loadNextQt();
+ cnt = 0;
+ }
+ return result;
+ }
+
+ @Override
+ public Vector next() {
+ if (!hasNext()) {
+ throw new NoSuchElementException();
+ }
+ Validate.isTrue(hasNext(), "Q input overrun");
+ /*
+ * because Q blocks are initially stored in inverse order
+ */
+ int qRowIndex = r - cnt - 1;
+ for (int j = 0; j < kp; j++) {
+ qRow.setQuick(j, mQt[j][qRowIndex]);
+ }
+ cnt++;
+ return qRow;
+ }
+
+ @Override
+ public void remove() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void close() throws IOException {
+ mQt = null;
+ mRs.clear();
+ }
+
+}
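
On the consuming side, a hypothetical sketch of pulling rows of the final Q
for block 0, given iterators over the first step's Q-hat blocks and R-hat
vectors (qtBlockIterator and rHatIterator are hypothetical names):

    QRLastStep lastStep = new QRLastStep(qtBlockIterator, rHatIterator, 0);
    while (lastStep.hasNext()) {
      Vector qRow = lastStep.next();
      // qRow is reused between next() calls; clone it if it must outlive
      // the current iteration
      System.out.println(qRow.clone());
    }
    lastStep.close();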
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/neighborhood/BruteSearch.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/neighborhood/BruteSearch.java b/mr/src/main/java/org/apache/mahout/math/neighborhood/BruteSearch.java
new file mode 100644
index 0000000..51484c7
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/neighborhood/BruteSearch.java
@@ -0,0 +1,186 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.neighborhood;
+
+import java.util.Iterator;
+import java.util.List;
+import java.util.PriorityQueue;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Ordering;
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.WeightedVector;
+import org.apache.mahout.math.random.WeightedThing;
+
+/**
+ * Search for nearest neighbors using a complete search (i.e. looping through
+ * the references and comparing each vector to the query).
+ */
+public class BruteSearch extends UpdatableSearcher {
+ /**
+ * The list of reference vectors.
+ */
+ private final List<Vector> referenceVectors;
+
+ public BruteSearch(DistanceMeasure distanceMeasure) {
+ super(distanceMeasure);
+ referenceVectors = Lists.newArrayList();
+ }
+
+ @Override
+ public void add(Vector vector) {
+ referenceVectors.add(vector);
+ }
+
+ @Override
+ public int size() {
+ return referenceVectors.size();
+ }
+
+ /**
+ * Scans the list of reference vectors one at a time for @limit neighbors of
+ * the query vector.
+ * The weights of the WeightedVectors are not taken into account.
+ *
+ * @param query The query vector.
+ * @param limit The number of results to return; must be at least 1.
+ * @return A list of the closest @limit neighbors for the given query.
+ */
+ @Override
+ public List<WeightedThing<Vector>> search(Vector query, int limit) {
+ Preconditions.checkArgument(limit > 0, "limit must be greater than 0!");
+ limit = Math.min(limit, referenceVectors.size());
+ // A priority queue of the best @limit elements, ordered from worst to best so that the worst
+ // element is always on top and can easily be removed.
+ PriorityQueue<WeightedThing<Integer>> bestNeighbors =
+ new PriorityQueue<>(limit, Ordering.natural().reverse());
+ // The resulting list of weighted WeightedVectors (the weight is the distance from the query).
+ List<WeightedThing<Vector>> results =
+ Lists.newArrayListWithCapacity(limit);
+ int rowNumber = 0;
+ for (Vector row : referenceVectors) {
+ double distance = distanceMeasure.distance(query, row);
+ // Only add a new neighbor if the result is better than the worst element
+ // in the queue or the queue isn't full.
+ if (bestNeighbors.size() < limit || bestNeighbors.peek().getWeight() > distance) {
+ bestNeighbors.add(new WeightedThing<>(rowNumber, distance));
+ if (bestNeighbors.size() > limit) {
+ bestNeighbors.poll();
+ } else {
+ // Increase the size of the results list by 1 so we can add elements in the reverse
+ // order from the queue.
+ results.add(null);
+ }
+ }
+ ++rowNumber;
+ }
+ for (int i = limit - 1; i >= 0; --i) {
+ WeightedThing<Integer> neighbor = bestNeighbors.poll();
+ results.set(i, new WeightedThing<>(
+ referenceVectors.get(neighbor.getValue()), neighbor.getWeight()));
+ }
+ return results;
+ }
+
+ /**
+ * Returns the closest vector to the query.
+ * When only the single nearest vector is needed, use this method, NOT search(query, limit), because
+ * it's faster (less overhead).
+ *
+ * @param query the vector to search for
+ * @param differentThanQuery if true, returns the closest vector different than the query (this
+ * only matters if the query is among the searched vectors), otherwise,
+ * returns the closest vector to the query (even the same vector).
+ * @return the weighted vector closest to the query
+ */
+ @Override
+ public WeightedThing<Vector> searchFirst(Vector query, boolean differentThanQuery) {
+ double bestDistance = Double.POSITIVE_INFINITY;
+ Vector bestVector = null;
+ for (Vector row : referenceVectors) {
+ double distance = distanceMeasure.distance(query, row);
+ if (distance < bestDistance && (!differentThanQuery || !row.equals(query))) {
+ bestDistance = distance;
+ bestVector = row;
+ }
+ }
+ return new WeightedThing<>(bestVector, bestDistance);
+ }
+
+ /**
+ * Searches for a list of queries in a threaded fashion.
+ *
+ * @param queries The queries to search for.
+ * @param limit The number of results to return.
+ * @param numThreads Number of threads to use in searching.
+ * @return A list of result lists.
+ */
+ public List<List<WeightedThing<Vector>>> search(Iterable<WeightedVector> queries,
+ final int limit, int numThreads) throws InterruptedException {
+ ExecutorService executor = Executors.newFixedThreadPool(numThreads);
+ List<Callable<Object>> tasks = Lists.newArrayList();
+
+ final List<List<WeightedThing<Vector>>> results = Lists.newArrayList();
+ int i = 0;
+ for (final Vector query : queries) {
+ results.add(null);
+ final int index = i++;
+ tasks.add(new Callable<Object>() {
+ @Override
+ public Object call() throws Exception {
+ results.set(index, BruteSearch.this.search(query, limit));
+ return null;
+ }
+ });
+ }
+
+ executor.invokeAll(tasks);
+ executor.shutdown();
+
+ return results;
+ }
+
+ @Override
+ public Iterator<Vector> iterator() {
+ return referenceVectors.iterator();
+ }
+
+ @Override
+ public boolean remove(Vector query, double epsilon) {
+ int rowNumber = 0;
+ for (Vector row : referenceVectors) {
+ double distance = distanceMeasure.distance(query, row);
+ if (distance < epsilon) {
+ referenceVectors.remove(rowNumber);
+ return true;
+ }
+ rowNumber++;
+ }
+ return false;
+ }
+
+ @Override
+ public void clear() {
+ referenceVectors.clear();
+ }
+}
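
A hypothetical usage sketch for the searcher above, assuming the stock
EuclideanDistanceMeasure from org.apache.mahout.common.distance:

    BruteSearch searcher = new BruteSearch(new EuclideanDistanceMeasure());
    searcher.add(new DenseVector(new double[] {0, 0}));
    searcher.add(new DenseVector(new double[] {3, 4}));
    List<WeightedThing<Vector>> hits =
        searcher.search(new DenseVector(new double[] {1, 1}), 2);
    // results are ordered closest first; each weight is the distance to the query
    System.out.println(hits.get(0).getValue() + " at distance " + hits.get(0).getWeight());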
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/neighborhood/FastProjectionSearch.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/neighborhood/FastProjectionSearch.java b/mr/src/main/java/org/apache/mahout/math/neighborhood/FastProjectionSearch.java
new file mode 100644
index 0000000..006f4b6
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/neighborhood/FastProjectionSearch.java
@@ -0,0 +1,326 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.neighborhood;
+
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Set;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.AbstractIterator;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Sets;
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.random.RandomProjector;
+import org.apache.mahout.math.random.WeightedThing;
+
+/**
+ * Does approximate nearest neighbor search by projecting the vectors, similar to ProjectionSearch.
+ * The main difference between this class and ProjectionSearch is the use of sorted arrays
+ * instead of binary search trees to implement the sets of scalar projections.
+ *
+ * Instead of taking log n time to add a vector to each of the sorted arrays, the pending additions
+ * are kept separate and are searched using a brute search. When there are "enough" pending
+ * additions, they're committed into the main pool of vectors.
+ */
+public class FastProjectionSearch extends UpdatableSearcher {
+ // The list of vectors that have not yet been projected (that are pending).
+ private final List<Vector> pendingAdditions = Lists.newArrayList();
+
+ // The list of basis vectors. Populated once the first vector's dimension is known, by calling
+ // initialize once.
+ private Matrix basisMatrix = null;
+
+ // The list of sorted lists of scalar projections. The outer list has one entry for each basis
+ // vector that all the other vectors will be projected on.
+ // For each basis vector, the inner list has an entry for each vector that has been projected.
+ // These entries are WeightedThing<Vector> where the weight is the value of the scalar
+ // projection and the value is the vector being referred to.
+ private List<List<WeightedThing<Vector>>> scalarProjections;
+
+ // The number of projections used for approximating the distance.
+ private final int numProjections;
+
+ // The number of elements to keep on both sides of the closest estimated distance as possible
+ // candidates for the best actual distance.
+ private final int searchSize;
+
+ // Initially, the dimension of the vectors searched by this searcher is unknown. After adding
+ // the first vector, the basis will be initialized. This marks whether initialization has
+ // happened or not so we only do it once.
+ private boolean initialized = false;
+
+ // Removing vectors from the searcher is done lazily to avoid the linear time cost of removing
+ // elements from an array. This member keeps track of the number of removed vectors (marked as
+ // "impossible" values in the array) so they can be removed when updating the structure.
+ private int numPendingRemovals = 0;
+
+ private static final double ADDITION_THRESHOLD = 0.05;
+ private static final double REMOVAL_THRESHOLD = 0.02;
+
+ public FastProjectionSearch(DistanceMeasure distanceMeasure, int numProjections, int searchSize) {
+ super(distanceMeasure);
+ Preconditions.checkArgument(numProjections > 0 && numProjections < 100,
+ "Unreasonable value for number of projections. Must be: 0 < numProjections < 100");
+ this.numProjections = numProjections;
+ this.searchSize = searchSize;
+ scalarProjections = Lists.newArrayListWithCapacity(numProjections);
+ for (int i = 0; i < numProjections; ++i) {
+ scalarProjections.add(Lists.<WeightedThing<Vector>>newArrayList());
+ }
+ }
+
+ private void initialize(int numDimensions) {
+ if (initialized) {
+ return;
+ }
+ basisMatrix = RandomProjector.generateBasisNormal(numProjections, numDimensions);
+ initialized = true;
+ }
+
+ /**
+ * Add a new Vector to the Searcher that will be checked when getting
+ * the nearest neighbors.
+ * <p/>
+ * The vector IS NOT CLONED. Do not modify the vector externally, otherwise the internal
+ * Searcher data structures could be invalidated.
+ */
+ @Override
+ public void add(Vector vector) {
+ initialize(vector.size());
+ pendingAdditions.add(vector);
+ }
+
+ /**
+ * Returns the number of WeightedVectors being searched for nearest neighbors.
+ */
+ @Override
+ public int size() {
+ return pendingAdditions.size() + scalarProjections.get(0).size() - numPendingRemovals;
+ }
+
+ /**
+ * When querying the Searcher for the closest vectors, a list of WeightedThing<Vector>s is
+ * returned. The value of the WeightedThing is the neighbor and the weight is
+ * the distance (calculated by some metric - see a concrete implementation) between the query
+ * and the neighbor.
+ * The actual type of vector in the pair is the same as the vector added to the Searcher.
+ */
+ @Override
+ public List<WeightedThing<Vector>> search(Vector query, int limit) {
+ reindex(false);
+
+ Set<Vector> candidates = Sets.newHashSet();
+ Vector projection = basisMatrix.times(query);
+ for (int i = 0; i < basisMatrix.numRows(); ++i) {
+ List<WeightedThing<Vector>> currProjections = scalarProjections.get(i);
+ int middle = Collections.binarySearch(currProjections,
+ new WeightedThing<Vector>(projection.get(i)));
+ if (middle < 0) {
+ middle = -(middle + 1);
+ }
+ for (int j = Math.max(0, middle - searchSize);
+ j < Math.min(currProjections.size(), middle + searchSize + 1); ++j) {
+ if (currProjections.get(j).getValue() == null) {
+ continue;
+ }
+ candidates.add(currProjections.get(j).getValue());
+ }
+ }
+
+ List<WeightedThing<Vector>> top =
+ Lists.newArrayListWithCapacity(candidates.size() + pendingAdditions.size());
+ for (Vector candidate : Iterables.concat(candidates, pendingAdditions)) {
+ top.add(new WeightedThing<>(candidate, distanceMeasure.distance(candidate, query)));
+ }
+ Collections.sort(top);
+
+ return top.subList(0, Math.min(top.size(), limit));
+ }
+
+ /**
+ * Returns the closest vector to the query.
+ * When only the single nearest vector is needed, use this method, NOT search(query, limit), because
+ * it's faster (less overhead).
+ *
+ * @param query the vector to search for
+ * @param differentThanQuery if true, returns the closest vector different than the query (this
+ * only matters if the query is among the searched vectors), otherwise,
+ * returns the closest vector to the query (even the same vector).
+ * @return the weighted vector closest to the query
+ */
+ @Override
+ public WeightedThing<Vector> searchFirst(Vector query, boolean differentThanQuery) {
+ reindex(false);
+
+ double bestDistance = Double.POSITIVE_INFINITY;
+ Vector bestVector = null;
+
+ Vector projection = basisMatrix.times(query);
+ for (int i = 0; i < basisMatrix.numRows(); ++i) {
+ List<WeightedThing<Vector>> currProjections = scalarProjections.get(i);
+ int middle = Collections.binarySearch(currProjections,
+ new WeightedThing<Vector>(projection.get(i)));
+ if (middle < 0) {
+ middle = -(middle + 1);
+ }
+ for (int j = Math.max(0, middle - searchSize);
+ j < Math.min(currProjections.size(), middle + searchSize + 1); ++j) {
+ if (currProjections.get(j).getValue() == null) {
+ continue;
+ }
+ Vector vector = currProjections.get(j).getValue();
+ double distance = distanceMeasure.distance(vector, query);
+ if (distance < bestDistance && (!differentThanQuery || !vector.equals(query))) {
+ bestDistance = distance;
+ bestVector = vector;
+ }
+ }
+ }
+
+ for (Vector vector : pendingAdditions) {
+ double distance = distanceMeasure.distance(vector, query);
+ if (distance < bestDistance && (!differentThanQuery || !vector.equals(query))) {
+ bestDistance = distance;
+ bestVector = vector;
+ }
+ }
+
+ return new WeightedThing<>(bestVector, bestDistance);
+ }
+
+ @Override
+ public boolean remove(Vector vector, double epsilon) {
+ WeightedThing<Vector> closestPair = searchFirst(vector, false);
+ if (distanceMeasure.distance(closestPair.getValue(), vector) > epsilon) {
+ return false;
+ }
+
+ boolean isProjected = true;
+ Vector projection = basisMatrix.times(vector);
+ for (int i = 0; i < basisMatrix.numRows(); ++i) {
+ List<WeightedThing<Vector>> currProjections = scalarProjections.get(i);
+ WeightedThing<Vector> searchedThing = new WeightedThing<>(projection.get(i));
+ int middle = Collections.binarySearch(currProjections, searchedThing);
+ if (middle < 0) {
+ isProjected = false;
+ break;
+ }
+ // Elements to be removed are kept in the sorted array until the next reindex, but their inner vector
+ // is set to null.
+ scalarProjections.get(i).set(middle, searchedThing);
+ }
+ if (isProjected) {
+ ++numPendingRemovals;
+ return true;
+ }
+
+ for (int i = 0; i < pendingAdditions.size(); ++i) {
+ if (pendingAdditions.get(i).equals(vector)) {
+ pendingAdditions.remove(i);
+ break;
+ }
+ }
+ return true;
+ }
+
+ private void reindex(boolean force) {
+ int numProjected = scalarProjections.get(0).size();
+ if (force || pendingAdditions.size() > ADDITION_THRESHOLD * numProjected
+ || numPendingRemovals > REMOVAL_THRESHOLD * numProjected) {
+
+ // We only need to copy the first list because when iterating we use only that list for the Vector
+ // references.
+ // see public Iterator<Vector> iterator()
+ List<List<WeightedThing<Vector>>> scalarProjections = Lists.newArrayListWithCapacity(numProjections);
+ for (int i = 0; i < numProjections; ++i) {
+ if (i == 0) {
+ scalarProjections.add(Lists.newArrayList(this.scalarProjections.get(i)));
+ } else {
+ scalarProjections.add(this.scalarProjections.get(i));
+ }
+ }
+
+ // Project every pending vector onto every basis vector.
+ for (Vector pending : pendingAdditions) {
+ Vector projection = basisMatrix.times(pending);
+ for (int i = 0; i < numProjections; ++i) {
+ scalarProjections.get(i).add(new WeightedThing<>(pending, projection.get(i)));
+ }
+ }
+ pendingAdditions.clear();
+ // For each basis vector, sort the resulting list (for binary search) and trim the pending
+ // removals (the count is the same for every basis vector) from the end. Entries nulled by
+ // remove() get weight Double.POSITIVE_INFINITY here so that they sort to the end.
+ for (int i = 0; i < numProjections; ++i) {
+ List<WeightedThing<Vector>> currProjections = scalarProjections.get(i);
+ for (WeightedThing<Vector> v : currProjections) {
+ if (v.getValue() == null) {
+ v.setWeight(Double.POSITIVE_INFINITY);
+ }
+ }
+ Collections.sort(currProjections);
+ for (int j = 0; j < numPendingRemovals; ++j) {
+ currProjections.remove(currProjections.size() - 1);
+ }
+ }
+ numPendingRemovals = 0;
+
+ this.scalarProjections = scalarProjections;
+ }
+ }
+
+ @Override
+ public void clear() {
+ pendingAdditions.clear();
+ for (int i = 0; i < numProjections; ++i) {
+ scalarProjections.get(i).clear();
+ }
+ numPendingRemovals = 0;
+ }
+
+ /**
+ * Iterates over a snapshot of the contents taken when the iterator is first instantiated,
+ * regardless of any future modifications. Changes made after the iterator is created will not be
+ * visible to the iterator but will be visible when searching.
+ * @return iterator through the vectors in this searcher.
+ */
+ @Override
+ public Iterator<Vector> iterator() {
+ reindex(true);
+ return new AbstractIterator<Vector>() {
+ private final Iterator<WeightedThing<Vector>> data = scalarProjections.get(0).iterator();
+ @Override
+ protected Vector computeNext() {
+ do {
+ if (!data.hasNext()) {
+ return endOfData();
+ }
+ WeightedThing<Vector> next = data.next();
+ if (next.getValue() != null) {
+ return next.getValue();
+ }
+ } while (true);
+ }
+ };
+ }
+}
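
A minimal usage sketch of the projection searcher above, assuming the surrounding class is
FastProjectionSearch with a (DistanceMeasure, int numProjections, int searchSize) constructor and
an inherited add(Vector) method; none of those appear in this hunk, so treat those names as
assumptions rather than confirmed API.

    // A minimal sketch, assuming FastProjectionSearch and its constructor signature.
    import org.apache.mahout.common.distance.EuclideanDistanceMeasure;
    import org.apache.mahout.math.DenseVector;
    import org.apache.mahout.math.Vector;
    import org.apache.mahout.math.neighborhood.FastProjectionSearch;
    import org.apache.mahout.math.random.WeightedThing;

    public class ProjectionSearchSketch {
      public static void main(String[] args) {
        // 4 random basis vectors; scan searchSize entries on each side of every binary-search hit.
        FastProjectionSearch searcher =
            new FastProjectionSearch(new EuclideanDistanceMeasure(), 4, 10);
        for (int i = 0; i < 1000; i++) {
          searcher.add(new DenseVector(new double[] {Math.random(), Math.random(), Math.random()}));
        }
        Vector query = new DenseVector(new double[] {0.5, 0.5, 0.5});
        // With differentThanQuery == false, the query itself may come back if it was added.
        WeightedThing<Vector> nearest = searcher.searchFirst(query, false);
        System.out.println("nearest distance = " + nearest.getWeight());
        // remove() nulls the projected entry; the slot is reclaimed at the next reindex.
        searcher.remove(nearest.getValue(), 1.0e-9);
      }
    }

The searchSize argument here is the same field searchFirst uses above to bound the scan window
around each binary-search hit.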
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/neighborhood/HashedVector.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/neighborhood/HashedVector.java b/mr/src/main/java/org/apache/mahout/math/neighborhood/HashedVector.java
new file mode 100644
index 0000000..eb91813
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/neighborhood/HashedVector.java
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.neighborhood;
+
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.WeightedVector;
+
+/**
+ * Decorates a weighted vector with a locality sensitive hash.
+ *
+ * The LSH function implemented is the random hyperplane based hash function.
+ * See "Similarity Estimation Techniques from Rounding Algorithms" by Moses S. Charikar, section 3.
+ * http://www.cs.princeton.edu/courses/archive/spring04/cos598B/bib/CharikarEstim.pdf
+ */
+public class HashedVector extends WeightedVector {
+ protected static final int INVALID_INDEX = -1;
+
+ /**
+ * Value of the locality sensitive hash. It is 64 bits wide.
+ */
+ private final long hash;
+
+ public HashedVector(Vector vector, long hash, int index) {
+ super(vector, 1, index);
+ this.hash = hash;
+ }
+
+ public HashedVector(Vector vector, Matrix projection, int index, long mask) {
+ super(vector, 1, index);
+ this.hash = mask & computeHash64(vector, projection);
+ }
+
+ public HashedVector(WeightedVector weightedVector, Matrix projection, long mask) {
+ super(weightedVector.getVector(), weightedVector.getWeight(), weightedVector.getIndex());
+ this.hash = mask & computeHash64(weightedVector, projection);
+ }
+
+ public static long computeHash64(Vector vector, Matrix projection) {
+ long hash = 0;
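+ // projection.times(vector) yields the dot product of the vector with each random
+ // hyperplane; bit i of the hash records whether the vector lies on the positive
+ // side of hyperplane i.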
+ for (Element element : projection.times(vector).nonZeroes()) {
+ if (element.get() > 0) {
+ hash += 1L << element.index();
+ }
+ }
+ return hash;
+ }
+
+ public static HashedVector hash(WeightedVector v, Matrix projection) {
+ return hash(v, projection, 0);
+ }
+
+ public static HashedVector hash(WeightedVector v, Matrix projection, long mask) {
+ return new HashedVector(v, projection, mask);
+ }
+
+ public int hammingDistance(long otherHash) {
+ return Long.bitCount(hash ^ otherHash);
+ }
+
+ public long getHash() {
+ return hash;
+ }
+
+ @Override
+ public String toString() {
+ return String.format("index=%d, hash=%08x, v=%s", getIndex(), hash, getVector());
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (!(o instanceof HashedVector)) {
+ return o instanceof Vector && this.minus((Vector) o).norm(1) == 0;
+ }
+ HashedVector v = (HashedVector) o;
+ return v.hash == this.hash && this.minus(v).norm(1) == 0;
+ }
+
+ @Override
+ public int hashCode() {
+ int result = super.hashCode();
+ result = 31 * result + (int) (hash ^ (hash >>> 32));
+ return result;
+ }
+}
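
To make the hashing concrete, here is a small sketch of the random-hyperplane hash by itself. It
uses only computeHash64 above and RandomProjector.generateBasisNormal (the same generator the LSH
searcher below uses); the 10-dimensional toy vectors are invented for illustration.

    import org.apache.mahout.math.DenseVector;
    import org.apache.mahout.math.Matrix;
    import org.apache.mahout.math.Vector;
    import org.apache.mahout.math.neighborhood.HashedVector;
    import org.apache.mahout.math.random.RandomProjector;

    public class HashedVectorSketch {
      public static void main(String[] args) {
        // 64 random hyperplanes in 10 dimensions, matching the 64-bit hash.
        Matrix projection = RandomProjector.generateBasisNormal(64, 10);

        Vector a = randomVector(10);
        Vector b = a.plus(randomVector(10).times(0.01));  // a slightly perturbed copy of a

        long hashA = HashedVector.computeHash64(a, projection);
        long hashB = HashedVector.computeHash64(b, projection);

        // Nearby vectors fall on the same side of most hyperplanes, so the Hamming
        // distance between their hashes is usually small.
        System.out.println("hamming(a, b) = " + Long.bitCount(hashA ^ hashB));
      }

      private static Vector randomVector(int n) {
        Vector v = new DenseVector(n);
        for (int i = 0; i < n; i++) {
          v.set(i, Math.random() - 0.5);
        }
        return v;
      }
    }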
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/neighborhood/LocalitySensitiveHashSearch.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/neighborhood/LocalitySensitiveHashSearch.java b/mr/src/main/java/org/apache/mahout/math/neighborhood/LocalitySensitiveHashSearch.java
new file mode 100644
index 0000000..aa1f103
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/neighborhood/LocalitySensitiveHashSearch.java
@@ -0,0 +1,295 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.neighborhood;
+
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.List;
+
+import com.google.common.base.Function;
+import com.google.common.base.Preconditions;
+import com.google.common.collect.HashMultiset;
+import com.google.common.collect.Iterators;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Multiset;
+import org.apache.lucene.util.PriorityQueue;
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.random.RandomProjector;
+import org.apache.mahout.math.random.WeightedThing;
+import org.apache.mahout.math.stats.OnlineSummarizer;
+
+/**
+ * Implements a Searcher that uses a locality sensitive hash as a first-pass approximation
+ * to estimate distance without floating point math. The clever bit about this implementation
+ * is that the cutoff on the bitwise (Hamming) distance is adaptive. Making this cutoff
+ * adaptive means that we only need to make a single pass through the data.
+ */
+public class LocalitySensitiveHashSearch extends UpdatableSearcher {
+ /**
+ * Number of bits in the locality sensitive hash. 64 bits fit neatly into a long.
+ */
+ private static final int BITS = 64;
+
+ /**
+ * Bit mask for the computed hash. Currently, it's 0xffffffffffffffff, i.e. all 64 bits are kept.
+ */
+ private static final long BIT_MASK = -1L;
+
+ /**
+ * The maximum Hamming distance between two hashes that the adaptive hash limit can grow back to.
+ * The limit starts at BITS and decreases as more points than are needed are added to the candidate
+ * priority queue. But when the observed distribution of distances becomes too good (we're computing
+ * full distances for fewer than some percentage, roughly 25%, of the total number of points), the
+ * limit is raised again, up to this bound, to compute more distances. This is because a cutoff that
+ * is too aggressive can starve the candidate queue of points that are genuinely close to the query.
+ */
+ private static final int MAX_HASH_LIMIT = 32;
+
+ /**
+ * Minimum number of points at a given Hamming distance from the query that must be observed
+ * before the distance statistics at that Hamming distance are used to decide whether to raise
+ * the hash limit.
+ */
+ private static final int MIN_DISTRIBUTION_COUNT = 10;
+
+ private final Multiset<HashedVector> trainingVectors = HashMultiset.create();
+
+ /**
+ * This matrix of BITS random vectors is used to compute the locality sensitive hash.
+ * We compute the dot product with these vectors using a matrix multiplication and then use just
+ * the sign of each result as one bit in the hash.
+ */
+ private Matrix projection;
+
+ /**
+ * The search size determines how many top results we retain. We do this because the hash distance
+ * isn't guaranteed to be entirely monotonic with respect to the real distance. To the extent that
+ * the actual distance is well approximated by the hash distance, searchSize can be decreased to
+ * roughly the number of results that you want.
+ */
+ private int searchSize;
+
+ /**
+ * Controls how the hash limit is raised. 0 means use minimum of distribution, 1 means use first quartile.
+ * Intermediate values indicate an interpolation should be used. Negative values mean to never increase.
+ */
+ private double hashLimitStrategy = 0.9;
+
+ /**
+ * Number of times the full distance between two points had to be evaluated.
+ */
+ private int distanceEvaluations = 0;
+
+ /**
+ * Whether the projection matrix was initialized. This has to be deferred until the size of the vectors is known,
+ * effectively until the first vector is added.
+ */
+ private boolean initialized = false;
+
+ public LocalitySensitiveHashSearch(DistanceMeasure distanceMeasure, int searchSize) {
+ super(distanceMeasure);
+ this.searchSize = searchSize;
+ this.projection = null;
+ }
+
+ private void initialize(int numDimensions) {
+ if (initialized) {
+ return;
+ }
+ initialized = true;
+ projection = RandomProjector.generateBasisNormal(BITS, numDimensions);
+ }
+
+ private PriorityQueue<WeightedThing<Vector>> searchInternal(Vector query) {
+ long queryHash = HashedVector.computeHash64(query, projection);
+
+ // We keep an approximation of the closest vectors here.
+ PriorityQueue<WeightedThing<Vector>> top = Searcher.getCandidateQueue(getSearchSize());
+
+ // We scan the vectors using bit counts as an approximation of the dot product so we can do as few
+ // full distance computations as possible. Our goal is to only do full distance computations for
+ // vectors whose hash distance is no larger than the largest hash distance among the
+ // searchSize closest candidates seen so far.
+
+ OnlineSummarizer[] distribution = new OnlineSummarizer[BITS + 1];
+ for (int i = 0; i < BITS + 1; i++) {
+ distribution[i] = new OnlineSummarizer();
+ }
+
+ distanceEvaluations = 0;
+
+ // We keep the counts of the hash distances here. This lets us accurately
+ // judge what hash distance cutoff we should use.
+ int[] hashCounts = new int[BITS + 1];
+
+ // Maximum number of different bits to still consider a vector a candidate for nearest neighbor.
+ // Starts at the maximum number of bits, but decreases and can increase.
+ int hashLimit = BITS;
+ int limitCount = 0;
+ double distanceLimit = Double.POSITIVE_INFINITY;
+
+ // In this loop, we have the invariants that:
+ //
+ // limitCount = sum_{i<hashLimit} hashCounts[i]
+ // and
+ // limitCount >= searchSize && limitCount - hashCounts[hashLimit-1] < searchSize
+ for (HashedVector vector : trainingVectors) {
+ // This computes the Hamming Distance between the vector's hash and the query's hash.
+ // The result is correlated with the angle between the vectors.
+ int bitDot = vector.hammingDistance(queryHash);
+ if (bitDot <= hashLimit) {
+ distanceEvaluations++;
+
+ double distance = distanceMeasure.distance(query, vector);
+ distribution[bitDot].add(distance);
+
+ if (distance < distanceLimit) {
+ top.insertWithOverflow(new WeightedThing<Vector>(vector, distance));
+ if (top.size() == searchSize) {
+ distanceLimit = top.top().getWeight();
+ }
+
+ hashCounts[bitDot]++;
+ limitCount++;
+ while (hashLimit > 0 && limitCount - hashCounts[hashLimit - 1] > searchSize) {
+ hashLimit--;
+ limitCount -= hashCounts[hashLimit];
+ }
+
+ if (hashLimitStrategy >= 0) {
+ while (hashLimit < MAX_HASH_LIMIT && distribution[hashLimit].getCount() > MIN_DISTRIBUTION_COUNT
+ && ((1 - hashLimitStrategy) * distribution[hashLimit].getQuartile(0)
+ + hashLimitStrategy * distribution[hashLimit].getQuartile(1)) < distanceLimit) {
+ limitCount += hashCounts[hashLimit];
+ hashLimit++;
+ }
+ }
+ }
+ }
+ }
+ return top;
+ }
+
+ @Override
+ public List<WeightedThing<Vector>> search(Vector query, int limit) {
+ PriorityQueue<WeightedThing<Vector>> top = searchInternal(query);
+ List<WeightedThing<Vector>> results = Lists.newArrayListWithExpectedSize(top.size());
+ while (top.size() != 0) {
+ WeightedThing<Vector> wv = top.pop();
+ results.add(new WeightedThing<>(((HashedVector) wv.getValue()).getVector(), wv.getWeight()));
+ }
+ Collections.reverse(results);
+ if (limit < results.size()) {
+ results = results.subList(0, limit);
+ }
+ return results;
+ }
+
+ /**
+ * Returns the closest vector to the query.
+ * When only the single nearest vector is needed, use this method rather than search(query, limit)
+ * because it's faster (less overhead). Internally it is nearly the same as search().
+ *
+ * @param query the vector to search for
+ * @param differentThanQuery if true, returns the closest vector different from the query (this
+ * only matters if the query is among the searched vectors); otherwise,
+ * returns the closest vector to the query (possibly the query itself).
+ * @return the weighted vector closest to the query
+ */
+ @Override
+ public WeightedThing<Vector> searchFirst(Vector query, boolean differentThanQuery) {
+ // We get the top searchSize neighbors.
+ PriorityQueue<WeightedThing<Vector>> top = searchInternal(query);
+ // We then cut the number down to just the best 2.
+ while (top.size() > 2) {
+ top.pop();
+ }
+ // If there are fewer than 2 results, we just return the one we have.
+ if (top.size() < 2) {
+ return removeHash(top.pop());
+ }
+ // There are exactly 2 results.
+ WeightedThing<Vector> secondBest = top.pop();
+ WeightedThing<Vector> best = top.pop();
+ // If the best result is the query itself but the caller doesn't want the query back, fall back to the second best.
+ if (differentThanQuery && best.getValue().equals(query)) {
+ best = secondBest;
+ }
+ return removeHash(best);
+ }
+
+ protected static WeightedThing<Vector> removeHash(WeightedThing<Vector> input) {
+ return new WeightedThing<>(((HashedVector) input.getValue()).getVector(), input.getWeight());
+ }
+
+ @Override
+ public void add(Vector vector) {
+ initialize(vector.size());
+ trainingVectors.add(new HashedVector(vector, projection, HashedVector.INVALID_INDEX, BIT_MASK));
+ }
+
+ @Override
+ public int size() {
+ return trainingVectors.size();
+ }
+
+ public int getSearchSize() {
+ return searchSize;
+ }
+
+ public void setSearchSize(int size) {
+ searchSize = size;
+ }
+
+ public void setRaiseHashLimitStrategy(double strategy) {
+ hashLimitStrategy = strategy;
+ }
+
+ /**
+ * This is only for testing. Returns the current count and resets it to zero.
+ * @return the number of times the actual distance between two vectors was computed since the last reset.
+ */
+ public int resetEvaluationCount() {
+ int result = distanceEvaluations;
+ distanceEvaluations = 0;
+ return result;
+ }
+
+ @Override
+ public Iterator<Vector> iterator() {
+ return Iterators.transform(trainingVectors.iterator(), new Function<HashedVector, Vector>() {
+ @Override
+ public Vector apply(HashedVector input) {
+ Preconditions.checkNotNull(input);
+ //noinspection ConstantConditions
+ return input.getVector();
+ }
+ });
+ }
+
+ @Override
+ public boolean remove(Vector v, double epsilon) {
+ return trainingVectors.remove(new HashedVector(v, projection, HashedVector.INVALID_INDEX, BIT_MASK));
+ }
+
+ @Override
+ public void clear() {
+ trainingVectors.clear();
+ }
+}
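
Finally, a short end-to-end sketch of LocalitySensitiveHashSearch. The constructor, add, search
and resetEvaluationCount calls are exactly the ones defined above; the 50-dimensional random data
is made up for illustration.

    import java.util.List;

    import org.apache.mahout.common.distance.EuclideanDistanceMeasure;
    import org.apache.mahout.math.DenseVector;
    import org.apache.mahout.math.Vector;
    import org.apache.mahout.math.neighborhood.LocalitySensitiveHashSearch;
    import org.apache.mahout.math.random.WeightedThing;

    public class LshSearchSketch {
      public static void main(String[] args) {
        // Keep the 20 best candidates while scanning; the adaptive hash limit does the pruning.
        LocalitySensitiveHashSearch searcher =
            new LocalitySensitiveHashSearch(new EuclideanDistanceMeasure(), 20);

        for (int i = 0; i < 10000; i++) {
          searcher.add(randomVector(50));
        }

        Vector query = randomVector(50);
        List<WeightedThing<Vector>> top5 = searcher.search(query, 5);
        for (WeightedThing<Vector> hit : top5) {
          System.out.println("distance = " + hit.getWeight());
        }
        // For testing only: how many full distance computations the search actually did,
        // typically far fewer than the 10000 stored vectors.
        System.out.println("evaluations = " + searcher.resetEvaluationCount());
      }

      private static Vector randomVector(int n) {
        Vector v = new DenseVector(n);
        for (int i = 0; i < n; i++) {
          v.set(i, Math.random());
        }
        return v;
      }
    }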