You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hivemall.apache.org by my...@apache.org on 2017/09/28 03:17:28 UTC
[2/3] incubator-hivemall git commit: Close #117,
Close #111: [HIVEMALL-17] Support SLIM neighborhood-learning
recommendation algorithm
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/recommend/SlimUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/recommend/SlimUDTF.java b/core/src/main/java/hivemall/recommend/SlimUDTF.java
new file mode 100644
index 0000000..e205c18
--- /dev/null
+++ b/core/src/main/java/hivemall/recommend/SlimUDTF.java
@@ -0,0 +1,759 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.recommend;
+
+import hivemall.UDTFWithOptions;
+import hivemall.annotations.VisibleForTesting;
+import hivemall.common.ConversionState;
+import hivemall.math.matrix.sparse.DoKFloatMatrix;
+import hivemall.math.vector.VectorProcedure;
+import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
+import hivemall.utils.collections.maps.IntOpenHashTable;
+import hivemall.utils.collections.maps.IntOpenHashTable.IMapIterator;
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.io.FileUtils;
+import hivemall.utils.io.NioStatefullSegment;
+import hivemall.utils.lang.NumberUtils;
+import hivemall.utils.lang.Primitives;
+import hivemall.utils.lang.SizeOf;
+import hivemall.utils.lang.mutable.MutableDouble;
+import hivemall.utils.lang.mutable.MutableInt;
+import hivemall.utils.lang.mutable.MutableObject;
+
+import java.io.File;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.Options;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
+import org.apache.hadoop.io.FloatWritable;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.mapred.Counters;
+import org.apache.hadoop.mapred.Reporter;
+
+/**
+ * Sparse Linear Methods (SLIM) for Top-N Recommender Systems.
+ *
+ * <pre>
+ * Xia Ning and George Karypis, SLIM: Sparse Linear Methods for Top-N Recommender Systems, Proc. ICDM, 2011.
+ * </pre>
+ */
+@Description(
+ name = "train_slim",
+ value = "_FUNC_( int i, map<int, double> r_i, map<int, map<int, double>> topKRatesOfI, int j, map<int, double> r_j [, constant string options]) "
+ + "- Returns row index, column index and non-zero weight value of prediction model")
+public class SlimUDTF extends UDTFWithOptions {
+ private static final Log logger = LogFactory.getLog(SlimUDTF.class);
+
+ //--------------------------------------------
+ // input OIs (object inspectors resolved from the UDTF arguments in initialize())
+
+ // argument 0: item i
+ private PrimitiveObjectInspector itemIOI;
+ // argument 3: item j
+ private PrimitiveObjectInspector itemJOI;
+ // argument 1: r_i map
+ private MapObjectInspector riOI;
+ // argument 4: r_j map
+ private MapObjectInspector rjOI;
+
+ // argument 2: nested map (outer key -> inner map of int -> double-compatible values)
+ private MapObjectInspector knnItemsOI;
+ private PrimitiveObjectInspector knnItemsKeyOI;
+ private MapObjectInspector knnItemsValueOI;
+ private PrimitiveObjectInspector knnItemsValueKeyOI;
+ private PrimitiveObjectInspector knnItemsValueValueOI;
+
+ // key/value inspectors of the r_i map
+ private PrimitiveObjectInspector riKeyOI;
+ private PrimitiveObjectInspector riValueOI;
+
+ // key/value inspectors of the r_j map
+ private PrimitiveObjectInspector rjKeyOI;
+ private PrimitiveObjectInspector rjValueOI;
+
+ //--------------------------------------------
+ // hyperparameters (assigned in processOptions(); defaults are 0.001, 0.0005 and 30)
+
+ private double l1;
+ private double l2;
+ private int numIterations;
+
+ //--------------------------------------------
+ // model parameters and else
+
+ /** item-item weight matrix */
+ private transient DoKFloatMatrix _weightMatrix;
+
+ //--------------------------------------------
+ // caching for each item i
+
+ // item id the cached _ri/_kNNi belong to; reset to Integer.MIN_VALUE in initialize()
+ private int _previousItemId;
+
+ @Nullable
+ private transient Int2FloatOpenHashTable _ri;
+ @Nullable
+ private transient IntOpenHashTable<Int2FloatOpenHashTable> _kNNi;
+ /** The number of elements in kNNi */
+ @Nullable
+ private transient MutableInt _nnzKNNi;
+
+ //--------------------------------------------
+ // variables for iteration supports
+
+ /** item-user matrix holding the input data */
+ @Nullable
+ private transient DoKFloatMatrix _dataMatrix;
+
+ // used to store KNN data into temporary file for iterative training
+ private transient NioStatefullSegment _fileIO;
+ // staging buffer for KNN records; allocated lazily by recordTrainingInput()
+ private transient ByteBuffer _inputBuf;
+
+ // accumulates the training loss and decides convergence
+ private ConversionState _cvState;
+ // number of process() invocations that reached train()
+ private long _observedTrainingExamples;
+
+ //--------------------------------------------
+
+ public SlimUDTF() {}
+
+ /**
+  * Validates the argument list, resolves the object inspectors of every argument, parses the
+  * optional option string and declares the output schema {@code (j int, nn int, w float)}.
+  *
+  * @param argOIs inspectors of the 5 mandatory arguments plus the optional option string
+  * @return inspector of the forwarded struct
+  * @throws UDFArgumentException if the arity is wrong or an argument has an unexpected type
+  */
+ @Override
+ public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
+     final int argc = argOIs.length;
+
+     // a single constant string is handed to the option parser (for the -help option)
+     final boolean optionsOnly = (argc == 1) && HiveUtils.isConstString(argOIs[0]);
+     if (optionsOnly) {
+         parseOptions(HiveUtils.getConstString(argOIs[0]));
+     }
+
+     final boolean validArity = (argc == 5) || (argc == 6);
+     if (!validArity) {
+         throw new UDFArgumentException(
+             "_FUNC_ takes 5 or 6 arguments: int i, map<int, double> r_i, map<int, map<int, double>> topKRatesOfI, int j, map<int, double> r_j [, constant string options]: "
+                     + Arrays.toString(argOIs));
+     }
+
+     // argument 0: item i
+     this.itemIOI = HiveUtils.asIntCompatibleOI(argOIs[0]);
+
+     // argument 1: r_i
+     this.riOI = HiveUtils.asMapOI(argOIs[1]);
+     this.riKeyOI = HiveUtils.asIntCompatibleOI(riOI.getMapKeyObjectInspector());
+     this.riValueOI = HiveUtils.asPrimitiveObjectInspector(riOI.getMapValueObjectInspector());
+
+     // argument 2: nested kNN map
+     this.knnItemsOI = HiveUtils.asMapOI(argOIs[2]);
+     this.knnItemsKeyOI = HiveUtils.asIntCompatibleOI(knnItemsOI.getMapKeyObjectInspector());
+     this.knnItemsValueOI = HiveUtils.asMapOI(knnItemsOI.getMapValueObjectInspector());
+     this.knnItemsValueKeyOI = HiveUtils.asIntCompatibleOI(
+         knnItemsValueOI.getMapKeyObjectInspector());
+     this.knnItemsValueValueOI = HiveUtils.asDoubleCompatibleOI(
+         knnItemsValueOI.getMapValueObjectInspector());
+
+     // argument 3: item j
+     this.itemJOI = HiveUtils.asIntCompatibleOI(argOIs[3]);
+
+     // argument 4: r_j
+     this.rjOI = HiveUtils.asMapOI(argOIs[4]);
+     this.rjKeyOI = HiveUtils.asIntCompatibleOI(rjOI.getMapKeyObjectInspector());
+     this.rjValueOI = HiveUtils.asPrimitiveObjectInspector(rjOI.getMapValueObjectInspector());
+
+     processOptions(argOIs);
+
+     // reset the training state
+     this._observedTrainingExamples = 0L;
+     this._previousItemId = Integer.MIN_VALUE;
+     this._weightMatrix = null;
+     this._dataMatrix = null;
+
+     // output schema
+     final List<String> outNames = new ArrayList<>();
+     final List<ObjectInspector> outOIs = new ArrayList<>();
+     outNames.add("j");
+     outNames.add("nn");
+     outNames.add("w");
+     outOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
+     outOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
+     outOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
+
+     return ObjectInspectorFactory.getStandardStructObjectInspector(outNames, outOIs);
+ }
+
+ /**
+  * Declares the command line options understood by this UDTF: the l1/l2 regularization
+  * coefficients, the iteration count and the convergence-check settings.
+  *
+  * @return the option definitions
+  */
+ @Override
+ protected Options getOptions() {
+     return new Options()
+         .addOption("l1", "l1coefficient", true,
+             "Coefficient for l1 regularizer [default: 0.001]")
+         .addOption("l2", "l2coefficient", true,
+             "Coefficient for l2 regularizer [default: 0.0005]")
+         .addOption("iters", "iterations", true,
+             "The number of iterations for coordinate descent [default: 30]")
+         .addOption("disable_cv", "disable_cvtest", false,
+             "Whether to disable convergence check [default: enabled]")
+         .addOption("cv_rate", "convergence_rate", true,
+             "Threshold to determine convergence [default: 0.005]");
+ }
+
+ /**
+  * Parses the optional option string (6th argument) and stores the hyperparameters. When no
+  * option string is given the defaults are used. Nothing is stored if validation fails.
+  *
+  * @param argOIs inspectors of the UDTF arguments
+  * @return the parsed command line, or {@code null} when no option string was supplied
+  * @throws UDFArgumentException if an option value is out of its valid range
+  */
+ @Override
+ protected CommandLine processOptions(@Nonnull ObjectInspector[] argOIs)
+         throws UDFArgumentException {
+     // defaults
+     double regL1 = 0.001d;
+     double regL2 = 0.0005d;
+     int iters = 30;
+     boolean checkConvergence = true;
+     double convergenceRate = 0.005d;
+
+     CommandLine parsed = null;
+     if (argOIs.length >= 6) {
+         parsed = parseOptions(HiveUtils.getConstString(argOIs[5]));
+
+         regL1 = Primitives.parseDouble(parsed.getOptionValue("l1"), regL1);
+         if (regL1 < 0.d) {
+             throw new UDFArgumentException("Argument `double l1` must be non-negative: "
+                     + regL1);
+         }
+
+         regL2 = Primitives.parseDouble(parsed.getOptionValue("l2"), regL2);
+         if (regL2 < 0.d) {
+             throw new UDFArgumentException("Argument `double l2` must be non-negative: "
+                     + regL2);
+         }
+
+         iters = Primitives.parseInt(parsed.getOptionValue("iters"), iters);
+         if (iters <= 0) {
+             throw new UDFArgumentException("Argument `int iters` must be greater than 0: "
+                     + iters);
+         }
+
+         checkConvergence = !parsed.hasOption("disable_cvtest");
+
+         convergenceRate = Primitives.parseDouble(parsed.getOptionValue("cv_rate"),
+             convergenceRate);
+         if (convergenceRate <= 0) {
+             throw new UDFArgumentException(
+                 "Argument `double cv_rate` must be greater than 0.0: " + convergenceRate);
+         }
+     }
+
+     // publish only after every value passed validation
+     this.l1 = regL1;
+     this.l2 = regL2;
+     this.numIterations = iters;
+     this._cvState = new ConversionState(checkConvergence, convergenceRate);
+
+     return parsed;
+ }
+
+ /**
+  * Consumes one input row {@code (i, r_i, kNN_i, j, r_j)} and performs one update of the weight
+  * of the pair (i, j). The ratings and kNN entries of item i are cached and only rebuilt when the
+  * item id changes; they are also recorded for replay when more than one iteration is requested.
+  *
+  * @param args the row values, in the order declared in {@code initialize}
+  * @throws HiveException if recording the training input fails
+  */
+ @Override
+ public void process(@Nonnull Object[] args) throws HiveException {
+     final boolean iterative = numIterations >= 2;
+
+     // lazily allocate the model state on the first row
+     if (_weightMatrix == null) {
+         this._weightMatrix = new DoKFloatMatrix();
+         if (iterative) {
+             this._dataMatrix = new DoKFloatMatrix();
+         }
+         this._nnzKNNi = new MutableInt();
+     }
+
+     final int itemI = PrimitiveObjectInspectorUtils.getInt(args[0], itemIOI);
+
+     final boolean cacheHit = (itemI == _previousItemId) && (_ri != null);
+     if (!cacheHit) {
+         // rebuild the caches of Ri and kNNi for the new item
+         final Map<?, ?> riMap = riOI.getMap(args[1]);
+         this._ri = int2floatMap(itemI, riMap, riKeyOI, riValueOI, _dataMatrix, _ri);
+         this._kNNi = kNNentries(args[2], knnItemsOI, knnItemsKeyOI, knnItemsValueOI,
+             knnItemsValueKeyOI, knnItemsValueValueOI, _kNNi, _nnzKNNi);
+
+         final int nnz = _nnzKNNi.getValue();
+         if (iterative && nnz >= 1) {
+             recordTrainingInput(itemI, _kNNi, nnz);
+         }
+         this._previousItemId = itemI;
+     }
+
+     final int itemJ = PrimitiveObjectInspectorUtils.getInt(args[3], itemJOI);
+     final Map<?, ?> rjMap = rjOI.getMap(args[4]);
+     final Int2FloatOpenHashTable rj = int2floatMap(itemJ, rjMap, rjKeyOI, rjValueOI,
+         _dataMatrix);
+
+     train(itemI, _ri, _kNNi, itemJ, rj);
+     _observedTrainingExamples++;
+ }
+
+ private void recordTrainingInput(final int itemI,
+ @Nonnull final IntOpenHashTable<Int2FloatOpenHashTable> knnItems, final int numKNNItems)
+ throws HiveException {
+ ByteBuffer buf = this._inputBuf;
+ NioStatefullSegment dst = this._fileIO;
+
+ if (buf == null) {
+ // invoke only at task node (initialize is also invoked in compilation)
+ final File file;
+ try {
+ file = File.createTempFile("hivemall_slim", ".sgmt"); // to save KNN data
+ file.deleteOnExit();
+ if (!file.canWrite()) {
+ throw new UDFArgumentException("Cannot write a temporary file: "
+ + file.getAbsolutePath());
+ }
+ } catch (IOException ioe) {
+ throw new UDFArgumentException(ioe);
+ }
+
+ this._inputBuf = buf = ByteBuffer.allocateDirect(8 * 1024 * 1024); // 8MB
+ this._fileIO = dst = new NioStatefullSegment(file, false);
+ }
+
+ int recordBytes = SizeOf.INT + SizeOf.INT + SizeOf.INT * 2 * knnItems.size()
+ + (SizeOf.INT + SizeOf.FLOAT) * numKNNItems;
+ int requiredBytes = SizeOf.INT + recordBytes; // need to allocate space for "recordBytes" itself
+
+ int remain = buf.remaining();
+ if (remain < requiredBytes) {
+ writeBuffer(buf, dst);
+ }
+
+ buf.putInt(recordBytes);
+ buf.putInt(itemI);
+ buf.putInt(knnItems.size());
+
+ final IMapIterator<Int2FloatOpenHashTable> entries = knnItems.entries();
+ while (entries.next() != -1) {
+ int user = entries.getKey();
+ buf.putInt(user);
+
+ Int2FloatOpenHashTable ru = entries.getValue();
+ buf.putInt(ru.size());
+ final Int2FloatOpenHashTable.IMapIterator itor = ru.entries();
+ while (itor.next() != -1) {
+ buf.putInt(itor.getKey());
+ buf.putFloat(itor.getValue());
+ }
+ }
+ }
+
+ /**
+  * Drains the content of {@code srcBuf} into {@code dst} and resets the buffer for writing.
+  *
+  * @param srcBuf buffer in write mode; it is flipped, written and cleared
+  * @param dst segment receiving the bytes
+  * @throws HiveException if the write fails (the buffer is left flipped in that case)
+  */
+ private static void writeBuffer(@Nonnull final ByteBuffer srcBuf,
+         @Nonnull final NioStatefullSegment dst) throws HiveException {
+     try {
+         srcBuf.flip(); // switch from write mode to read mode
+         dst.write(srcBuf);
+     } catch (IOException ioe) {
+         throw new HiveException("Exception causes while writing a buffer to file", ioe);
+     }
+     srcBuf.clear(); // back to an empty buffer in write mode
+ }
+
+ /**
+  * Performs one update of the weight w(i, j) from the in-flight row: accumulates the gradient,
+  * the squared ratings and the squared errors over the users that rated item j, adds the
+  * regularized loss to the convergence state and stores the new weight.
+  *
+  * @param itemI target item i
+  * @param ri user -> rating of item i
+  * @param kNNi user -> (item -> rating) entries used for the prediction
+  * @param itemJ item j; excluded from the prediction
+  * @param rj user -> rating of item j
+  */
+ private void train(final int itemI, @Nonnull final Int2FloatOpenHashTable ri,
+         @Nonnull final IntOpenHashTable<Int2FloatOpenHashTable> kNNi, final int itemJ,
+         @Nonnull final Int2FloatOpenHashTable rj) {
+     final int numRaters = rj.size();
+     if (numRaters == 0) {
+         return; // nobody rated item j; nothing to learn from
+     }
+
+     final DoKFloatMatrix weights = _weightMatrix;
+
+     double sumGrad = 0.d;
+     double sumSqRate = 0.d;
+     double sumSqErr = 0.d;
+     for (Int2FloatOpenHashTable.IMapIterator it = rj.entries(); it.next() != -1;) {
+         final int user = it.getKey();
+         final double ruj = it.getValue();
+         final double rui = ri.get(user, 0.f);
+
+         final double err = rui - predict(user, itemI, kNNi, itemJ, weights);
+         sumGrad += ruj * err;
+         sumSqRate += ruj * ruj;
+         sumSqErr += err * err;
+     }
+
+     final double meanGrad = sumGrad / numRaters;
+     final double meanSqRate = sumSqRate / numRaters;
+
+     // the loss is evaluated with the weight before this update
+     final double wij = weights.get(itemI, itemJ, 0.d);
+     final double loss = sumSqErr / numRaters + 0.5d * l2 * wij * wij + l1 * wij;
+     _cvState.incrLoss(loss);
+
+     weights.set(itemI, itemJ, getUpdateTerm(meanGrad, meanSqRate, l1, l2));
+ }
+
+ /**
+  * Replay variant of the weight update for w(i, j): the ratings of items i and j are read from
+  * the item-user data matrix instead of the in-flight row.
+  *
+  * @param itemI target item i
+  * @param knnItems user -> (item -> rating) entries used for the prediction
+  * @param itemJ item j; excluded from the prediction
+  */
+ private void train(final int itemI,
+         @Nonnull final IntOpenHashTable<Int2FloatOpenHashTable> knnItems, final int itemJ) {
+     final DoKFloatMatrix data = _dataMatrix;
+     final DoKFloatMatrix weights = _weightMatrix;
+
+     final int numRaters = data.numColumns(itemJ);
+     if (numRaters == 0) {
+         return; // no recorded rating for item j
+     }
+
+     // accumulators are boxed because they are updated from the anonymous callback
+     final MutableDouble sumGrad = new MutableDouble(0.d);
+     final MutableDouble sumSqRate = new MutableDouble(0.d);
+     final MutableDouble sumSqErr = new MutableDouble(0.d);
+
+     data.eachNonZeroInRow(itemJ, new VectorProcedure() {
+         @Override
+         public void apply(int user, double ruj) {
+             final double rui = data.get(itemI, user, 0.d);
+             final double err = rui - predict(user, itemI, knnItems, itemJ, weights);
+
+             sumGrad.addValue(ruj * err);
+             sumSqRate.addValue(ruj * ruj);
+             sumSqErr.addValue(err * err);
+         }
+     });
+
+     final double meanGrad = sumGrad.getValue() / numRaters;
+     final double meanSqRate = sumSqRate.getValue() / numRaters;
+
+     // the loss is evaluated with the weight before this update
+     final double wij = weights.get(itemI, itemJ, 0.d);
+     final double loss = sumSqErr.getValue() / numRaters + 0.5 * l2 * wij * wij + l1 * wij;
+     _cvState.incrLoss(loss);
+
+     weights.set(itemI, itemJ, getUpdateTerm(meanGrad, meanSqRate, l1, l2));
+ }
+
+ /**
+  * Predicts the rating of {@code user} for item i as the weighted sum of the user's kNN ratings,
+  * leaving out the contribution of {@code excludeIndex}.
+  *
+  * @param user user whose rating is predicted
+  * @param itemI row of the weight matrix to use
+  * @param knnItems user -> (item -> rating) entries
+  * @param excludeIndex item whose contribution is skipped
+  * @param weightMatrix item-item weights
+  * @return the prediction, or 0 when the user has no kNN entry
+  */
+ private static double predict(final int user, final int itemI,
+         @Nonnull final IntOpenHashTable<Int2FloatOpenHashTable> knnItems,
+         final int excludeIndex, @Nonnull final DoKFloatMatrix weightMatrix) {
+     final Int2FloatOpenHashTable userRatings = knnItems.get(user);
+     if (userRatings == null) {
+         return 0.d;
+     }
+
+     double sum = 0.d;
+     for (Int2FloatOpenHashTable.IMapIterator it = userRatings.entries(); it.next() != -1;) {
+         final int itemK = it.getKey();
+         if (itemK != excludeIndex) {
+             final float ruk = it.getValue();
+             sum += ruk * weightMatrix.get(itemI, itemK, 0.d);
+         }
+     }
+     return sum;
+ }
+
+ private static double getUpdateTerm(final double gradSum, final double rateSum,
+ final double l1, final double l2) {
+ double update = 0.d;
+ if (Math.abs(gradSum) > l1) {
+ if (gradSum > 0.d) {
+ update = (gradSum - l1) / (rateSum + l2);
+ } else {
+ update = (gradSum + l1) / (rateSum + l2);
+ }
+ // non-negative constraints
+ if (update < 0.d) {
+ update = 0.d;
+ }
+ }
+ return update;
+ }
+
+ /**
+  * Finishes training (running the remaining iterations when requested), forwards the non-zero
+  * weights and releases the model. Does nothing when no row was processed, because the weight
+  * matrix is only allocated by the first {@code process} call.
+  *
+  * @throws HiveException if the iterative training or forwarding fails
+  */
+ @Override
+ public void close() throws HiveException {
+     if (_weightMatrix == null) {
+         // no input row reached process(): there is no input to replay and no model to emit
+         return;
+     }
+     finalizeTraining();
+     forwardModel();
+     this._weightMatrix = null;
+ }
+
+ /**
+  * Runs the remaining training iterations when more than one iteration is configured, dropping
+  * the per-item caches beforehand and the data matrix afterwards.
+  *
+  * @throws HiveException if the iterative training fails
+  */
+ @VisibleForTesting
+ void finalizeTraining() throws HiveException {
+     if (numIterations <= 1) {
+         return; // single-pass training: nothing left to do
+     }
+
+     // the per-item caches are not used by the replay
+     this._kNNi = null;
+     this._ri = null;
+
+     runIterativeTraining();
+
+     this._dataMatrix = null;
+ }
+
+ /**
+  * Replays the recorded KNN entries for iterations 2..{@code numIterations} (the pass performed
+  * in {@code process} counts as iteration 1), stopping early once the convergence state reports
+  * convergence. The records are replayed from the staging buffer when nothing was spilled, and
+  * from the temporary file otherwise. The temporary file is always closed and deleted on exit.
+  *
+  * @throws HiveException if reading/writing the temporary file or the training itself fails
+  */
+ private void runIterativeTraining() throws HiveException {
+     final ByteBuffer buf = this._inputBuf;
+     final NioStatefullSegment dst = this._fileIO;
+     if (buf == null || dst == null) {
+         // recordTrainingInput() never ran (no row carried KNN entries), so nothing to replay
+         return;
+     }
+
+     final Reporter reporter = getReporter();
+     final Counters.Counter iterCounter = (reporter == null) ? null : reporter.getCounter(
+         "hivemall.recommend.slim$Counter", "iteration");
+
+     try {
+         if (dst.getPosition() == 0L) {// run iterations w/o temporary file
+             if (buf.position() == 0) {
+                 return; // no training example
+             }
+             buf.flip();
+             // inclusive upper bound: iteration 1 already happened in process()
+             for (int iter = 2; iter <= numIterations; iter++) {
+                 _cvState.next();
+                 reportProgress(reporter);
+                 setCounterValue(iterCounter, iter);
+
+                 while (buf.remaining() > 0) {
+                     int recordBytes = buf.getInt();
+                     assert (recordBytes > 0) : recordBytes;
+                     replayTrain(buf);
+                 }
+                 buf.rewind();
+                 if (_cvState.isConverged(_observedTrainingExamples)) {
+                     break;
+                 }
+             }
+             logger.info("Performed "
+                     + _cvState.getCurrentIteration()
+                     + " iterations of "
+                     + NumberUtils.formatNumber(_observedTrainingExamples)
+                     + " training examples on memory (thus "
+                     + NumberUtils.formatNumber(_observedTrainingExamples
+                             * _cvState.getCurrentIteration()) + " training updates in total) ");
+
+         } else { // read training examples in the temporary file and invoke train for each example
+             // write KNNi in buffer to a temporary file
+             // buf is in write mode here, so pending bytes are measured by position();
+             // remaining() would be the free space and is 0 for an exactly full buffer
+             if (buf.position() > 0) {
+                 writeBuffer(buf, dst);
+             }
+
+             try {
+                 dst.flush();
+             } catch (IOException e) {
+                 throw new HiveException("Failed to flush a file: "
+                         + dst.getFile().getAbsolutePath(), e);
+             }
+
+             if (logger.isInfoEnabled()) {
+                 File tmpFile = dst.getFile();
+                 logger.info("Wrote KNN entries of axis items to a temporary file for iterative training: "
+                         + tmpFile.getAbsolutePath()
+                         + " ("
+                         + FileUtils.prettyFileSize(tmpFile)
+                         + ")");
+             }
+
+             // run iterations
+             // inclusive upper bound: iteration 1 already happened in process()
+             for (int iter = 2; iter <= numIterations; iter++) {
+                 _cvState.next();
+                 setCounterValue(iterCounter, iter);
+
+                 buf.clear();
+                 dst.resetPosition();
+                 while (true) {
+                     reportProgress(reporter);
+                     // load a KNNi to a buffer in the temporary file
+                     final int bytesRead;
+                     try {
+                         bytesRead = dst.read(buf);
+                     } catch (IOException e) {
+                         throw new HiveException("Failed to read a file: "
+                                 + dst.getFile().getAbsolutePath(), e);
+                     }
+                     if (bytesRead == 0) { // reached file EOF
+                         break;
+                     }
+                     assert (bytesRead > 0) : bytesRead;
+
+                     // reads training examples from a buffer
+                     buf.flip();
+                     int remain = buf.remaining();
+                     if (remain < SizeOf.INT) {
+                         throw new HiveException("Illegal file format was detected");
+                     }
+                     while (remain >= SizeOf.INT) {
+                         int pos = buf.position();
+                         int recordBytes = buf.getInt();
+                         remain -= SizeOf.INT;
+                         if (remain < recordBytes) {
+                             // the record is only partially loaded: rewind to its header and
+                             // let compact() carry it over to the next read
+                             buf.position(pos);
+                             break;
+                         }
+
+                         replayTrain(buf);
+                         remain -= recordBytes;
+                     }
+                     buf.compact();
+                 }
+                 if (_cvState.isConverged(_observedTrainingExamples)) {
+                     break;
+                 }
+             }
+             logger.info("Performed "
+                     + _cvState.getCurrentIteration()
+                     + " iterations of "
+                     + NumberUtils.formatNumber(_observedTrainingExamples)
+                     + " training examples on memory and KNNi data on secondary storage (thus "
+                     + NumberUtils.formatNumber(_observedTrainingExamples
+                             * _cvState.getCurrentIteration()) + " training updates in total) ");
+
+         }
+     } catch (Throwable e) {
+         throw new HiveException("Exception caused in the iterative training", e);
+     } finally {
+         // delete the temporary file and release resources
+         try {
+             dst.close(true);
+         } catch (IOException e) {
+             throw new HiveException("Failed to close a file: "
+                     + dst.getFile().getAbsolutePath(), e);
+         }
+         this._inputBuf = null;
+         this._fileIO = null;
+     }
+ }
+
+ /**
+  * Decodes one recorded KNN record (everything after its length header) from {@code buf} and
+  * retrains the weights of item i against every item that appears in the record.
+  *
+  * @param buf buffer positioned right after the record length
+  */
+ private void replayTrain(@Nonnull final ByteBuffer buf) {
+     final int itemI = buf.getInt();
+     final int numUsers = buf.getInt();
+
+     final IntOpenHashTable<Int2FloatOpenHashTable> kNNi = new IntOpenHashTable<>(1024);
+     final Set<Integer> candidateItems = new HashSet<>();
+     for (int u = 0; u < numUsers; u++) {
+         final int user = buf.getInt();
+         final int numRatings = buf.getInt();
+
+         final Int2FloatOpenHashTable ru = new Int2FloatOpenHashTable(numRatings);
+         ru.defaultReturnValue(0.f);
+         for (int left = numRatings; left > 0; left--) {
+             final int itemK = buf.getInt();
+             final float ruk = buf.getFloat();
+             candidateItems.add(itemK);
+             ru.put(itemK, ruk);
+         }
+         kNNi.put(user, ru);
+     }
+
+     // one weight update per distinct item seen in the record
+     for (final int itemJ : candidateItems) {
+         train(itemI, kNNi, itemJ);
+     }
+ }
+
+ /**
+  * Forwards every non-zero cell of the weight matrix as a {@code (row, column, weight)} row.
+  * All cells are visited even if forwarding fails; the first failure is rethrown afterwards.
+  *
+  * @throws HiveException the first exception raised while forwarding
+  */
+ private void forwardModel() throws HiveException {
+     // writables are reused across rows
+     final IntWritable rowIdx = new IntWritable();
+     final IntWritable colIdx = new IntWritable();
+     final FloatWritable weight = new FloatWritable();
+     final Object[] outRow = new Object[] {rowIdx, colIdx, weight};
+
+     final MutableObject<HiveException> firstFailure = new MutableObject<>();
+     _weightMatrix.eachNonZeroCell(new VectorProcedure() {
+         @Override
+         public void apply(int i, int j, float value) {
+             if (value != 0.f) {
+                 rowIdx.set(i);
+                 colIdx.set(j);
+                 weight.set(value);
+                 try {
+                     forward(outRow);
+                 } catch (HiveException e) {
+                     firstFailure.setIfAbsent(e);
+                 }
+             }
+         }
+     });
+
+     final HiveException failure = firstFailure.get();
+     if (failure != null) {
+         throw failure;
+     }
+     logger.info("Forwarded SLIM's weights matrix");
+ }
+
+ @Nonnull
+ private static IntOpenHashTable<Int2FloatOpenHashTable> kNNentries(
+ @Nonnull final Object kNNiObj, @Nonnull final MapObjectInspector knnItemsOI,
+ @Nonnull final PrimitiveObjectInspector knnItemsKeyOI,
+ @Nonnull final MapObjectInspector knnItemsValueOI,
+ @Nonnull final PrimitiveObjectInspector knnItemsValueKeyOI,
+ @Nonnull final PrimitiveObjectInspector knnItemsValueValueOI,
+ @Nullable IntOpenHashTable<Int2FloatOpenHashTable> knnItems,
+ @Nonnull final MutableInt nnzKNNi) {
+ if (knnItems == null) {
+ knnItems = new IntOpenHashTable<>(1024);
+ } else {
+ knnItems.clear();
+ }
+
+ int numElementOfKNNItems = 0;
+ for (Map.Entry<?, ?> entry : knnItemsOI.getMap(kNNiObj).entrySet()) {
+ int user = PrimitiveObjectInspectorUtils.getInt(entry.getKey(), knnItemsKeyOI);
+ Int2FloatOpenHashTable ru = int2floatMap(knnItemsValueOI.getMap(entry.getValue()),
+ knnItemsValueKeyOI, knnItemsValueValueOI);
+ knnItems.put(user, ru);
+ numElementOfKNNItems += ru.size();
+ }
+
+ nnzKNNi.setValue(numElementOfKNNItems);
+ return knnItems;
+ }
+
+ /**
+  * Copies the non-zero entries of {@code map} into a new int -> float table whose default
+  * return value is 0.
+  *
+  * @param map source map
+  * @param keyOI inspector of the keys
+  * @param valueOI inspector of the values
+  * @return the new table
+  */
+ @Nonnull
+ private static Int2FloatOpenHashTable int2floatMap(@Nonnull final Map<?, ?> map,
+         @Nonnull final PrimitiveObjectInspector keyOI,
+         @Nonnull final PrimitiveObjectInspector valueOI) {
+     final Int2FloatOpenHashTable table = new Int2FloatOpenHashTable(map.size());
+     table.defaultReturnValue(0.f);
+
+     for (final Map.Entry<?, ?> e : map.entrySet()) {
+         final float value = PrimitiveObjectInspectorUtils.getFloat(e.getValue(), valueOI);
+         if (value != 0.f) {// zero entries are not stored
+             final int key = PrimitiveObjectInspectorUtils.getInt(e.getKey(), keyOI);
+             table.put(key, value);
+         }
+     }
+
+     return table;
+ }
+
+ /**
+  * Convenience overload that always allocates a new destination table; delegates to the
+  * six-argument variant with a {@code null} destination.
+  */
+ @Nonnull
+ private static Int2FloatOpenHashTable int2floatMap(final int item,
+ @Nonnull final Map<?, ?> map, @Nonnull final PrimitiveObjectInspector keyOI,
+ @Nonnull final PrimitiveObjectInspector valueOI,
+ @Nullable final DoKFloatMatrix dataMatrix) {
+ return int2floatMap(item, map, keyOI, valueOI, dataMatrix, null);
+ }
+
+ @Nonnull
+ private static Int2FloatOpenHashTable int2floatMap(final int item,
+ @Nonnull final Map<?, ?> map, @Nonnull final PrimitiveObjectInspector keyOI,
+ @Nonnull final PrimitiveObjectInspector valueOI,
+ @Nullable final DoKFloatMatrix dataMatrix, @Nullable Int2FloatOpenHashTable dst) {
+ if (dst == null) {
+ dst = new Int2FloatOpenHashTable(map.size());
+ dst.defaultReturnValue(0.f);
+ } else {
+ dst.clear();
+ }
+
+ for (Map.Entry<?, ?> entry : map.entrySet()) {
+ float rating = PrimitiveObjectInspectorUtils.getFloat(entry.getValue(), valueOI);
+ if (rating == 0.f) {
+ continue;
+ }
+ int user = PrimitiveObjectInspectorUtils.getInt(entry.getKey(), keyOI);
+ dst.put(user, rating);
+ if (dataMatrix != null) {
+ dataMatrix.set(item, user, rating);
+ }
+ }
+
+ return dst;
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/utils/collections/maps/Int2DoubleOpenHashTable.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/maps/Int2DoubleOpenHashTable.java b/core/src/main/java/hivemall/utils/collections/maps/Int2DoubleOpenHashTable.java
new file mode 100644
index 0000000..3b5585e
--- /dev/null
+++ b/core/src/main/java/hivemall/utils/collections/maps/Int2DoubleOpenHashTable.java
@@ -0,0 +1,427 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.collections.maps;
+
+import hivemall.utils.math.Primes;
+
+import java.io.Externalizable;
+import java.io.IOException;
+import java.io.ObjectInput;
+import java.io.ObjectOutput;
+import java.util.Arrays;
+
+/**
+ * An open-addressing hash table using double hashing.
+ *
+ * <pre>
+ * Primary hash function: h1(k) = k mod m
+ * Secondary hash function: h2(k) = 1 + (k mod(m-2))
+ * </pre>
+ *
+ * @see http://en.wikipedia.org/wiki/Double_hashing
+ */
+public class Int2DoubleOpenHashTable implements Externalizable {
+
+ // slot states stored in _states
+ // FREE: never used since the last clear()/rehash
+ protected static final byte FREE = 0;
+ // FULL: holds a live key/value pair
+ protected static final byte FULL = 1;
+ // REMOVED: entry was deleted; the stale key stays in _keys
+ protected static final byte REMOVED = 2;
+
+ private static final float DEFAULT_LOAD_FACTOR = 0.75f;
+ private static final float DEFAULT_GROW_FACTOR = 2.0f;
+
+ // fraction of the capacity used to derive _threshold
+ protected final transient float _loadFactor;
+ // multiplier applied to the capacity when the table grows
+ protected final transient float _growFactor;
+
+ // number of FULL slots
+ protected int _used = 0;
+ // the table grows once (_used + 1) reaches this value
+ protected int _threshold;
+ // value returned by get(key)/put/remove when the key is absent
+ protected double defaultReturnValue = -1.d;
+
+ // parallel arrays indexed by slot
+ protected int[] _keys;
+ protected double[] _values;
+ protected byte[] _states;
+
+ /**
+  * Creates a table with the given initial capacity.
+  *
+  * @param size requested capacity; must be at least 1
+  * @param loadFactor fraction of the capacity at which the table grows
+  * @param growFactor multiplier applied to the capacity when growing
+  * @param forcePrime whether to replace {@code size} with the prime returned by
+  *        {@code Primes.findLeastPrimeNumber(size)}
+  * @throws IllegalArgumentException if {@code size} is less than 1
+  */
+ protected Int2DoubleOpenHashTable(int size, float loadFactor, float growFactor,
+         boolean forcePrime) {
+     if (size < 1) {
+         // report the offending value instead of a message-less exception
+         throw new IllegalArgumentException("size must be greater than 0: " + size);
+     }
+     this._loadFactor = loadFactor;
+     this._growFactor = growFactor;
+     int actualSize = forcePrime ? Primes.findLeastPrimeNumber(size) : size;
+     this._keys = new int[actualSize];
+     this._values = new double[actualSize];
+     this._states = new byte[actualSize];
+     this._threshold = (int) (actualSize * _loadFactor);
+ }
+
+ /** Creates a table with a prime capacity derived from {@code size} and the given factors. */
+ public Int2DoubleOpenHashTable(int size, float loadFactor, float growFactor) {
+ this(size, loadFactor, growFactor, true);
+ }
+
+ /** Creates a table with a prime capacity derived from {@code size} and the default factors. */
+ public Int2DoubleOpenHashTable(int size) {
+ this(size, DEFAULT_LOAD_FACTOR, DEFAULT_GROW_FACTOR, true);
+ }
+
+ /**
+ * Only for {@link Externalizable}
+ */
+ public Int2DoubleOpenHashTable() {// required for serialization
+ // the key/value/state arrays are not allocated by this constructor
+ this._loadFactor = DEFAULT_LOAD_FACTOR;
+ this._growFactor = DEFAULT_GROW_FACTOR;
+ }
+
+ /** Sets the value returned by {@code get(key)}, {@code put} and {@code remove} for absent keys. */
+ public void defaultReturnValue(double v) {
+ this.defaultReturnValue = v;
+ }
+
+ /**
+  * @return {@code true} if a live entry exists for {@code key}
+  */
+ public boolean containsKey(final int key) {
+     final int slot = findKey(key);
+     return slot >= 0;
+ }
+
+ /**
+ * @return the value mapped to {@code key}, or the configured default return value
+ *         (initially -1.d) if not found
+ */
+ public double get(final int key) {
+ return get(key, defaultReturnValue);
+ }
+
+ /**
+  * @return the value mapped to {@code key}, or {@code defaultValue} if not found
+  */
+ public double get(final int key, final double defaultValue) {
+     final int slot = findKey(key);
+     return (slot < 0) ? defaultValue : _values[slot];
+ }
+
+ /**
+ * Associates {@code value} with {@code key}, growing the table beforehand when it is too
+ * filled.
+ *
+ * @return the previous value if the key was already present, otherwise the default return
+ *         value
+ */
+ public double put(final int key, final double value) {
+ final int hash = keyHash(key);
+ int keyLength = _keys.length;
+ int keyIdx = hash % keyLength;
+
+ // the table may be rehashed here, which invalidates the slot computed above
+ boolean expanded = preAddEntry(keyIdx);
+ if (expanded) {
+ keyLength = _keys.length;
+ keyIdx = hash % keyLength;
+ }
+
+ final int[] keys = _keys;
+ final double[] values = _values;
+ final byte[] states = _states;
+
+ // NOTE(review): the probe below is entered only when the home slot is FULL. When the home
+ // slot is REMOVED (left behind by a different key) the entry is written there without
+ // checking whether `key` already lives further along its probe sequence, which looks like
+ // it can leave two entries for the same key -- confirm.
+ if (states[keyIdx] == FULL) {// double hashing
+ if (keys[keyIdx] == key) {
+ double old = values[keyIdx];
+ values[keyIdx] = value;
+ return old;
+ }
+ // try second hash
+ final int decr = 1 + (hash % (keyLength - 2));
+ for (;;) {
+ // step backwards by decr, wrapping around the array
+ keyIdx -= decr;
+ if (keyIdx < 0) {
+ keyIdx += keyLength;
+ }
+ if (isFree(keyIdx, key)) {
+ break;
+ }
+ if (states[keyIdx] == FULL && keys[keyIdx] == key) {
+ double old = values[keyIdx];
+ values[keyIdx] = value;
+ return old;
+ }
+ }
+ }
+ // new entry
+ keys[keyIdx] = key;
+ values[keyIdx] = value;
+ states[keyIdx] = FULL;
+ ++_used;
+ return defaultReturnValue;
+ }
+
+ /**
+  * Returns whether the slot can take a new entry for {@code key}: it is either FREE, or it is a
+  * REMOVED slot that previously held the same key.
+  */
+ protected boolean isFree(final int index, final int key) {
+     final byte state = _states[index];
+     return state == FREE || (state == REMOVED && _keys[index] == key);
+ }
+
+ /** @return expanded or not */
+ protected boolean preAddEntry(final int index) {
+ if ((_used + 1) >= _threshold) {// too filled
+ int newCapacity = Math.round(_keys.length * _growFactor);
+ ensureCapacity(newCapacity);
+ return true;
+ }
+ return false;
+ }
+
+ /**
+ * Looks up the slot that holds a live entry for {@code key}.
+ *
+ * @return the slot index, or -1 if the key is absent
+ */
+ protected int findKey(final int key) {
+ final int[] keys = _keys;
+ final byte[] states = _states;
+ final int keyLength = keys.length;
+
+ final int hash = keyHash(key);
+ int keyIdx = hash % keyLength;
+ if (states[keyIdx] != FREE) {
+ if (states[keyIdx] == FULL && keys[keyIdx] == key) {
+ return keyIdx;
+ }
+ // try second hash
+ final int decr = 1 + (hash % (keyLength - 2));
+ for (;;) {
+ // step backwards by decr, wrapping around the array
+ keyIdx -= decr;
+ if (keyIdx < 0) {
+ keyIdx += keyLength;
+ }
+ // a FREE slot, or a REMOVED slot of this very key, ends the search
+ if (isFree(keyIdx, key)) {
+ return -1;
+ }
+ if (states[keyIdx] == FULL && keys[keyIdx] == key) {
+ return keyIdx;
+ }
+ }
+ }
+ return -1;
+ }
+
+ /**
+ * Removes the entry of {@code key} by marking its slot REMOVED; the key and value stay in
+ * the arrays.
+ *
+ * @return the removed value, or the default return value if the key was absent
+ */
+ public double remove(final int key) {
+ final int[] keys = _keys;
+ final double[] values = _values;
+ final byte[] states = _states;
+ final int keyLength = keys.length;
+
+ final int hash = keyHash(key);
+ int keyIdx = hash % keyLength;
+ if (states[keyIdx] != FREE) {
+ if (states[keyIdx] == FULL && keys[keyIdx] == key) {
+ double old = values[keyIdx];
+ states[keyIdx] = REMOVED;
+ --_used;
+ return old;
+ }
+ // second hash
+ final int decr = 1 + (hash % (keyLength - 2));
+ for (;;) {
+ // step backwards by decr, wrapping around the array
+ keyIdx -= decr;
+ if (keyIdx < 0) {
+ keyIdx += keyLength;
+ }
+ // unlike findKey(), only a FREE slot ends the search here
+ if (states[keyIdx] == FREE) {
+ return defaultReturnValue;
+ }
+ if (states[keyIdx] == FULL && keys[keyIdx] == key) {
+ double old = values[keyIdx];
+ states[keyIdx] = REMOVED;
+ --_used;
+ return old;
+ }
+ }
+ }
+ return defaultReturnValue;
+ }
+
+ /** @return the number of live entries */
+ public int size() {
+ return _used;
+ }
+
+ /** Removes all entries by marking every slot FREE; the capacity is kept. */
+ public void clear() {
+ Arrays.fill(_states, FREE);
+ this._used = 0;
+ }
+
+ /** @return a new iterator over the entries of this table */
+ public IMapIterator entries() {
+ return new MapIterator();
+ }
+
+ /**
+  * @return the entries formatted as {@code {k1=v1,k2=v2,...}}
+  */
+ @Override
+ public String toString() {
+     // rough capacity estimate: ~10 chars per entry plus the braces
+     final StringBuilder sb = new StringBuilder(size() * 10 + 2).append('{');
+     for (IMapIterator it = entries(); it.next() != -1;) {
+         sb.append(it.getKey()).append('=').append(it.getValue());
+         if (it.hasNext()) {
+             sb.append(',');
+         }
+     }
+     return sb.append('}').toString();
+ }
+
+ protected void ensureCapacity(final int newCapacity) {
+ int prime = Primes.findLeastPrimeNumber(newCapacity);
+ rehash(prime);
+ this._threshold = Math.round(prime * _loadFactor);
+ }
+
+ /**
+ * Re-inserts every FULL entry into freshly allocated arrays of the given size.
+ *
+ * @throws IllegalArgumentException if {@code newCapacity} is not larger than the current
+ * table length
+ */
+ private void rehash(final int newCapacity) {
+ final int oldCapacity = _keys.length;
+ if (newCapacity <= oldCapacity) {
+ throw new IllegalArgumentException("new: " + newCapacity + ", old: " + oldCapacity);
+ }
+ final int[] dstKeys = new int[newCapacity];
+ final double[] dstValues = new double[newCapacity];
+ final byte[] dstStates = new byte[newCapacity];
+ int moved = 0;
+ for (int src = 0; src < oldCapacity; src++) {
+ if (_states[src] != FULL) {
+ continue; // only FULL slots are carried over
+ }
+ moved++;
+ final int k = _keys[src];
+ final double v = _values[src];
+ final int h = keyHash(k);
+ int dst = h % newCapacity;
+ if (dstStates[dst] == FULL) {// second hashing
+ final int step = 1 + (h % (newCapacity - 2));
+ do {
+ dst -= step;
+ if (dst < 0) {
+ dst += newCapacity;
+ }
+ } while (dstStates[dst] != FREE);
+ }
+ dstKeys[dst] = k;
+ dstValues[dst] = v;
+ dstStates[dst] = FULL;
+ }
+ this._keys = dstKeys;
+ this._values = dstValues;
+ this._states = dstStates;
+ this._used = moved;
+ }
+
+ /** Clears the sign bit of {@code key} so the result is never negative. */
+ private static int keyHash(int key) {
+ return key & Integer.MAX_VALUE;
+ }
+
+ /**
+ * Writes the threshold, the entry count and the table length, followed by one
+ * (int key, double value) pair for each entry returned by {@link #entries()}.
+ */
+ public void writeExternal(ObjectOutput out) throws IOException {
+ out.writeInt(_threshold);
+ out.writeInt(_used);
+ out.writeInt(_keys.length);
+
+ for (IMapIterator itor = entries(); itor.next() != -1;) {
+ out.writeInt(itor.getKey());
+ out.writeDouble(itor.getValue());
+ }
+ }
+
+ /**
+ * Reads the threshold, the entry count and the table length, then re-inserts that many
+ * (int key, double value) pairs into newly allocated arrays.
+ */
+ public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
+ this._threshold = in.readInt();
+ this._used = in.readInt();
+
+ final int capacity = in.readInt();
+ final int[] ks = new int[capacity];
+ final double[] vs = new double[capacity];
+ final byte[] st = new byte[capacity];
+ for (int n = 0; n < _used; n++) {
+ final int k = in.readInt();
+ final double v = in.readDouble();
+ final int h = keyHash(k);
+ int pos = h % capacity;
+ if (st[pos] != FREE) {// second hash
+ final int step = 1 + (h % (capacity - 2));
+ do {
+ pos -= step;
+ if (pos < 0) {
+ pos += capacity;
+ }
+ } while (st[pos] != FREE);
+ }
+ st[pos] = FULL;
+ ks[pos] = k;
+ vs[pos] = v;
+ }
+ this._keys = ks;
+ this._values = vs;
+ this._states = st;
+ }
+
+ /** Cursor over map entries; call {@link #next()} before reading a key or a value. */
+ public interface IMapIterator {
+
+ /** @return true if another entry is available */
+ boolean hasNext();
+
+ /**
+ * Advances to the next entry.
+ *
+ * @return -1 if not found
+ */
+ int next();
+
+ /** @return key of the entry reached by the last call to {@link #next()} */
+ int getKey();
+
+ /** @return value of the entry reached by the last call to {@link #next()} */
+ double getValue();
+
+ }
+
+ /** Iterator that visits the FULL slots in ascending slot order. */
+ private final class MapIterator implements IMapIterator {
+
+ int nextEntry;
+ int lastEntry = -1;
+
+ MapIterator() {
+ this.nextEntry = nextEntry(0);
+ }
+
+ /** find the index of the first full entry at or after {@code index} */
+ int nextEntry(int index) {
+ for (; index < _keys.length; index++) {
+ if (_states[index] == FULL) {
+ break;
+ }
+ }
+ return index;
+ }
+
+ public boolean hasNext() {
+ return nextEntry < _keys.length;
+ }
+
+ /** @return slot index of the entry moved to, or -1 when no entry is left */
+ public int next() {
+ if (nextEntry >= _keys.length) {
+ return -1;
+ }
+ final int current = nextEntry;
+ this.lastEntry = current;
+ this.nextEntry = nextEntry(current + 1);
+ return current;
+ }
+
+ public int getKey() {
+ return _keys[checkedLastEntry()];
+ }
+
+ public double getValue() {
+ return _values[checkedLastEntry()];
+ }
+
+ /** @throws IllegalStateException if {@link #next()} has not yet returned an entry */
+ private int checkedLastEntry() {
+ if (lastEntry == -1) {
+ throw new IllegalStateException();
+ }
+ return lastEntry;
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/utils/collections/maps/Int2FloatOpenHashTable.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/maps/Int2FloatOpenHashTable.java b/core/src/main/java/hivemall/utils/collections/maps/Int2FloatOpenHashTable.java
index e9b5c8a..22de115 100644
--- a/core/src/main/java/hivemall/utils/collections/maps/Int2FloatOpenHashTable.java
+++ b/core/src/main/java/hivemall/utils/collections/maps/Int2FloatOpenHashTable.java
@@ -90,23 +90,27 @@ public class Int2FloatOpenHashTable implements Externalizable {
this.defaultReturnValue = v;
}
- public boolean containsKey(int key) {
+ public boolean containsKey(final int key) {
return findKey(key) >= 0;
}
/**
* @return -1.f if not found
*/
- public float get(int key) {
- int i = findKey(key);
+ public float get(final int key) {
+ return get(key, defaultReturnValue);
+ }
+
+ public float get(final int key, final float defaultValue) {
+ final int i = findKey(key);
if (i < 0) {
- return defaultReturnValue;
+ return defaultValue;
}
return _values[i];
}
- public float put(int key, float value) {
- int hash = keyHash(key);
+ public float put(final int key, final float value) {
+ final int hash = keyHash(key);
int keyLength = _keys.length;
int keyIdx = hash % keyLength;
@@ -116,9 +120,9 @@ public class Int2FloatOpenHashTable implements Externalizable {
keyIdx = hash % keyLength;
}
- int[] keys = _keys;
- float[] values = _values;
- byte[] states = _states;
+ final int[] keys = _keys;
+ final float[] values = _values;
+ final byte[] states = _states;
if (states[keyIdx] == FULL) {// double hashing
if (keys[keyIdx] == key) {
@@ -127,7 +131,7 @@ public class Int2FloatOpenHashTable implements Externalizable {
return old;
}
// try second hash
- int decr = 1 + (hash % (keyLength - 2));
+ final int decr = 1 + (hash % (keyLength - 2));
for (;;) {
keyIdx -= decr;
if (keyIdx < 0) {
@@ -151,8 +155,8 @@ public class Int2FloatOpenHashTable implements Externalizable {
}
/** Return weather the required slot is free for new entry */
- protected boolean isFree(int index, int key) {
- byte stat = _states[index];
+ protected boolean isFree(final int index, final int key) {
+ final byte stat = _states[index];
if (stat == FREE) {
return true;
}
@@ -163,7 +167,7 @@ public class Int2FloatOpenHashTable implements Externalizable {
}
/** @return expanded or not */
- protected boolean preAddEntry(int index) {
+ protected boolean preAddEntry(final int index) {
if ((_used + 1) >= _threshold) {// too filled
int newCapacity = Math.round(_keys.length * _growFactor);
ensureCapacity(newCapacity);
@@ -172,19 +176,19 @@ public class Int2FloatOpenHashTable implements Externalizable {
return false;
}
- protected int findKey(int key) {
- int[] keys = _keys;
- byte[] states = _states;
- int keyLength = keys.length;
+ protected int findKey(final int key) {
+ final int[] keys = _keys;
+ final byte[] states = _states;
+ final int keyLength = keys.length;
- int hash = keyHash(key);
+ final int hash = keyHash(key);
int keyIdx = hash % keyLength;
if (states[keyIdx] != FREE) {
if (states[keyIdx] == FULL && keys[keyIdx] == key) {
return keyIdx;
}
// try second hash
- int decr = 1 + (hash % (keyLength - 2));
+ final int decr = 1 + (hash % (keyLength - 2));
for (;;) {
keyIdx -= decr;
if (keyIdx < 0) {
@@ -201,13 +205,13 @@ public class Int2FloatOpenHashTable implements Externalizable {
return -1;
}
- public float remove(int key) {
- int[] keys = _keys;
- float[] values = _values;
- byte[] states = _states;
- int keyLength = keys.length;
+ public float remove(final int key) {
+ final int[] keys = _keys;
+ final float[] values = _values;
+ final byte[] states = _states;
+ final int keyLength = keys.length;
- int hash = keyHash(key);
+ final int hash = keyHash(key);
int keyIdx = hash % keyLength;
if (states[keyIdx] != FREE) {
if (states[keyIdx] == FULL && keys[keyIdx] == key) {
@@ -217,7 +221,7 @@ public class Int2FloatOpenHashTable implements Externalizable {
return old;
}
// second hash
- int decr = 1 + (hash % (keyLength - 2));
+ final int decr = 1 + (hash % (keyLength - 2));
for (;;) {
keyIdx -= decr;
if (keyIdx < 0) {
@@ -242,6 +246,9 @@ public class Int2FloatOpenHashTable implements Externalizable {
}
public void clear() {
+ if (_used == 0) {
+ return; // no need to clear
+ }
Arrays.fill(_states, FREE);
this._used = 0;
}
@@ -274,21 +281,21 @@ public class Int2FloatOpenHashTable implements Externalizable {
this._threshold = Math.round(prime * _loadFactor);
}
- private void rehash(int newCapacity) {
+ private void rehash(final int newCapacity) {
int oldCapacity = _keys.length;
if (newCapacity <= oldCapacity) {
throw new IllegalArgumentException("new: " + newCapacity + ", old: " + oldCapacity);
}
- int[] newkeys = new int[newCapacity];
- float[] newValues = new float[newCapacity];
- byte[] newStates = new byte[newCapacity];
+ final int[] newkeys = new int[newCapacity];
+ final float[] newValues = new float[newCapacity];
+ final byte[] newStates = new byte[newCapacity];
int used = 0;
for (int i = 0; i < oldCapacity; i++) {
if (_states[i] == FULL) {
used++;
int k = _keys[i];
float v = _values[i];
- int hash = keyHash(k);
+ final int hash = keyHash(k);
int keyIdx = hash % newCapacity;
if (newStates[keyIdx] == FULL) {// second hashing
int decr = 1 + (hash % (newCapacity - 2));
@@ -310,7 +317,7 @@ public class Int2FloatOpenHashTable implements Externalizable {
this._used = used;
}
- private static int keyHash(int key) {
+ private static int keyHash(final int key) {
return key & 0x7fffffff;
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/utils/collections/maps/Int2IntOpenHashTable.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/maps/Int2IntOpenHashTable.java b/core/src/main/java/hivemall/utils/collections/maps/Int2IntOpenHashTable.java
index 8e87fce..73431d1 100644
--- a/core/src/main/java/hivemall/utils/collections/maps/Int2IntOpenHashTable.java
+++ b/core/src/main/java/hivemall/utils/collections/maps/Int2IntOpenHashTable.java
@@ -77,7 +77,10 @@ public final class Int2IntOpenHashTable implements Externalizable {
this(size, DEFAULT_LOAD_FACTOR, DEFAULT_GROW_FACTOR, true);
}
- public Int2IntOpenHashTable() {// required for serialization
+ /**
+ * Only for {@link Externalizable}
+ */
+ public Int2IntOpenHashTable() {
this._loadFactor = DEFAULT_LOAD_FACTOR;
this._growFactor = DEFAULT_GROW_FACTOR;
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/utils/collections/maps/IntOpenHashTable.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/maps/IntOpenHashTable.java b/core/src/main/java/hivemall/utils/collections/maps/IntOpenHashTable.java
index dbade74..1c90ae0 100644
--- a/core/src/main/java/hivemall/utils/collections/maps/IntOpenHashTable.java
+++ b/core/src/main/java/hivemall/utils/collections/maps/IntOpenHashTable.java
@@ -58,7 +58,10 @@ public final class IntOpenHashTable<V> implements Externalizable {
protected V[] _values;
protected byte[] _states;
- public IntOpenHashTable() {} // for Externalizable
+ /**
+ * Only for {@link Externalizable}
+ */
+ public IntOpenHashTable() {}
public IntOpenHashTable(int size) {
this(size, DEFAULT_LOAD_FACTOR, DEFAULT_GROW_FACTOR, true);
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/utils/collections/maps/Long2DoubleOpenHashTable.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/maps/Long2DoubleOpenHashTable.java b/core/src/main/java/hivemall/utils/collections/maps/Long2DoubleOpenHashTable.java
index b4356ff..115571e 100644
--- a/core/src/main/java/hivemall/utils/collections/maps/Long2DoubleOpenHashTable.java
+++ b/core/src/main/java/hivemall/utils/collections/maps/Long2DoubleOpenHashTable.java
@@ -78,6 +78,9 @@ public final class Long2DoubleOpenHashTable implements Externalizable {
this(size, DEFAULT_LOAD_FACTOR, DEFAULT_GROW_FACTOR, true);
}
+ /**
+ * Only for {@link Externalizable}
+ */
public Long2DoubleOpenHashTable() {// required for serialization
this._loadFactor = DEFAULT_LOAD_FACTOR;
this._growFactor = DEFAULT_GROW_FACTOR;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/utils/collections/maps/Long2FloatOpenHashTable.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/maps/Long2FloatOpenHashTable.java b/core/src/main/java/hivemall/utils/collections/maps/Long2FloatOpenHashTable.java
index 6b0ab59..ba2de76 100644
--- a/core/src/main/java/hivemall/utils/collections/maps/Long2FloatOpenHashTable.java
+++ b/core/src/main/java/hivemall/utils/collections/maps/Long2FloatOpenHashTable.java
@@ -78,7 +78,10 @@ public final class Long2FloatOpenHashTable implements Externalizable {
this(size, DEFAULT_LOAD_FACTOR, DEFAULT_GROW_FACTOR, true);
}
- public Long2FloatOpenHashTable() {// required for serialization
+ /**
+ * Only for {@link Externalizable}
+ */
+ public Long2FloatOpenHashTable() {
this._loadFactor = DEFAULT_LOAD_FACTOR;
this._growFactor = DEFAULT_GROW_FACTOR;
}
@@ -113,7 +116,23 @@ public final class Long2FloatOpenHashTable implements Externalizable {
return _values[index];
}
+ /**
+ * Overwrites the value stored at slot {@code index}; the slot state is not inspected.
+ *
+ * @return the value previously held in that slot
+ */
+ public float _set(final int index, final float value) {
+ final float previous = _values[index];
+ _values[index] = value;
+ return previous;
+ }
+
+ /**
+ * Flags slot {@code index} as REMOVED and decrements the entry counter; the previous slot
+ * state is not checked.
+ *
+ * @return the value held in that slot
+ */
+ public float _remove(final int index) {
+ _states[index] = REMOVED;
+ _used--;
+ final float removed = _values[index];
+ return removed;
+ }
+
public float put(final long key, final float value) {
+ return put(key, value, defaultReturnValue);
+ }
+
+ public float put(final long key, final float value, final float defaultValue) {
final int hash = keyHash(key);
int keyLength = _keys.length;
int keyIdx = hash % keyLength;
@@ -155,7 +174,7 @@ public final class Long2FloatOpenHashTable implements Externalizable {
values[keyIdx] = value;
states[keyIdx] = FULL;
++_used;
- return defaultReturnValue;
+ return defaultValue;
}
/** Return weather the required slot is free for new entry */
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/utils/collections/maps/Long2IntOpenHashTable.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/maps/Long2IntOpenHashTable.java b/core/src/main/java/hivemall/utils/collections/maps/Long2IntOpenHashTable.java
index 1ca4c40..6445231 100644
--- a/core/src/main/java/hivemall/utils/collections/maps/Long2IntOpenHashTable.java
+++ b/core/src/main/java/hivemall/utils/collections/maps/Long2IntOpenHashTable.java
@@ -77,6 +77,9 @@ public final class Long2IntOpenHashTable implements Externalizable {
this(size, DEFAULT_LOAD_FACTOR, DEFAULT_GROW_FACTOR, true);
}
+ /**
+ * Only for {@link Externalizable}
+ */
public Long2IntOpenHashTable() {// required for serialization
this._loadFactor = DEFAULT_LOAD_FACTOR;
this._growFactor = DEFAULT_GROW_FACTOR;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/utils/collections/maps/OpenHashTable.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/collections/maps/OpenHashTable.java b/core/src/main/java/hivemall/utils/collections/maps/OpenHashTable.java
index 4599bfc..c16567a 100644
--- a/core/src/main/java/hivemall/utils/collections/maps/OpenHashTable.java
+++ b/core/src/main/java/hivemall/utils/collections/maps/OpenHashTable.java
@@ -59,7 +59,10 @@ public final class OpenHashTable<K, V> implements Externalizable {
protected V[] _values;
protected byte[] _states;
- public OpenHashTable() {} // for Externalizable
+ /**
+ * Only for {@link Externalizable}
+ */
+ public OpenHashTable() {}
public OpenHashTable(int size) {
this(size, DEFAULT_LOAD_FACTOR, DEFAULT_GROW_FACTOR);
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/utils/lang/mutable/MutableObject.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/lang/mutable/MutableObject.java b/core/src/main/java/hivemall/utils/lang/mutable/MutableObject.java
new file mode 100644
index 0000000..bea2a9d
--- /dev/null
+++ b/core/src/main/java/hivemall/utils/lang/mutable/MutableObject.java
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.lang.mutable;
+
+import javax.annotation.Nullable;
+
+/**
+ * A mutable holder for a single reference that may be null.
+ *
+ * @param <T> type of the held value
+ */
+public final class MutableObject<T> {
+
+ @Nullable
+ private T _value;
+
+ /** Creates a holder with no value set. */
+ public MutableObject() {}
+
+ /** Creates a holder initialized with {@code obj}, which may be null. */
+ public MutableObject(@Nullable T obj) {
+ this._value = obj;
+ }
+
+ /** @return true if the held value is non-null */
+ public boolean isSet() {
+ return _value != null;
+ }
+
+ /** @return the held value, or null if none is set */
+ @Nullable
+ public T get() {
+ return _value;
+ }
+
+ /** Replaces the held value unconditionally. */
+ public void set(@Nullable T obj) {
+ this._value = obj;
+ }
+
+ /** Stores {@code obj} only when the held value is currently null. */
+ public void setIfAbsent(@Nullable T obj) {
+ if (_value == null) {
+ this._value = obj;
+ }
+ }
+
+ /**
+ * Two holders are equal when they are of the same class and their held values are equal;
+ * two holders without a value are equal to each other.
+ */
+ @Override
+ public boolean equals(@Nullable Object obj) {
+ if (this == obj) {
+ return true;
+ }
+ if (obj == null) {
+ return false;
+ }
+ if (getClass() != obj.getClass()) {
+ return false;
+ }
+ MutableObject<?> other = (MutableObject<?>) obj;
+ if (_value == null) {
+ // Decide here: dereferencing _value below would throw NullPointerException
+ // when both holders are unset.
+ return other._value == null;
+ }
+ return _value.equals(other._value);
+ }
+
+ /** @return hash code of the held value, or 0 when it is null */
+ @Override
+ public int hashCode() {
+ return _value == null ? 0 : _value.hashCode();
+ }
+
+ /** @return string form of the held value, or {@code "null"} when it is null */
+ @Override
+ public String toString() {
+ return _value == null ? "null" : _value.toString();
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/main/java/hivemall/utils/math/MathUtils.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/math/MathUtils.java b/core/src/main/java/hivemall/utils/math/MathUtils.java
index ee533dc..71d4c29 100644
--- a/core/src/main/java/hivemall/utils/math/MathUtils.java
+++ b/core/src/main/java/hivemall/utils/math/MathUtils.java
@@ -43,7 +43,7 @@ import javax.annotation.Nullable;
import org.apache.commons.math3.special.Gamma;
public final class MathUtils {
- private static final double LOG2 = Math.log(2);
+ public static final double LOG2 = Math.log(2);
private MathUtils() {}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/test/java/hivemall/evaluation/BinaryResponsesMeasuresTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/evaluation/BinaryResponsesMeasuresTest.java b/core/src/test/java/hivemall/evaluation/BinaryResponsesMeasuresTest.java
index 5e8f253..574fc04 100644
--- a/core/src/test/java/hivemall/evaluation/BinaryResponsesMeasuresTest.java
+++ b/core/src/test/java/hivemall/evaluation/BinaryResponsesMeasuresTest.java
@@ -103,7 +103,7 @@ public class BinaryResponsesMeasuresTest {
List<Integer> groundTruth = Arrays.asList(1, 2, 4);
double actual = BinaryResponsesMeasures.ReciprocalRank(rankedList, groundTruth,
- rankedList.size());
+ rankedList.size());
Assert.assertEquals(1.0d, actual, 0.0001d);
Collections.reverse(rankedList);
@@ -115,6 +115,22 @@ public class BinaryResponsesMeasuresTest {
Assert.assertEquals(0.0d, actual, 0.0001d);
}
+ // @Test is required: JUnit 4 only runs annotated methods, so without it this
+ // method would compile but never execute.
+ @Test
+ public void testHit() {
+ List<Integer> rankedList = Arrays.asList(1, 3, 2, 6);
+ List<Integer> groundTruth = Arrays.asList(1, 2, 4);
+
+ // whole ranked list considered
+ double actual = BinaryResponsesMeasures.Hit(rankedList, groundTruth, rankedList.size());
+ Assert.assertEquals(1.d, actual, 0.0001d);
+
+ // only the top 2 considered
+ actual = BinaryResponsesMeasures.Hit(rankedList, groundTruth, 2);
+ Assert.assertEquals(1.d, actual, 0.0001d);
+
+ // not hitting case
+ rankedList = Arrays.asList(5, 6);
+ actual = BinaryResponsesMeasures.Hit(rankedList, groundTruth, 2);
+ Assert.assertEquals(0.d, actual, 0.0001d);
+ }
+
@Test
public void testAP() {
List<Integer> rankedList = Arrays.asList(1, 3, 2, 6);
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/test/java/hivemall/evaluation/GradedResponsesMeasuresTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/evaluation/GradedResponsesMeasuresTest.java b/core/src/test/java/hivemall/evaluation/GradedResponsesMeasuresTest.java
index 6a7cc9d..96ac030 100644
--- a/core/src/test/java/hivemall/evaluation/GradedResponsesMeasuresTest.java
+++ b/core/src/test/java/hivemall/evaluation/GradedResponsesMeasuresTest.java
@@ -18,12 +18,12 @@
*/
package hivemall.evaluation;
-import org.junit.Assert;
-import org.junit.Test;
-
import java.util.Arrays;
import java.util.List;
+import org.junit.Assert;
+import org.junit.Test;
+
public class GradedResponsesMeasuresTest {
@Test
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/test/java/hivemall/math/matrix/MatrixBuilderTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/math/matrix/MatrixBuilderTest.java b/core/src/test/java/hivemall/math/matrix/MatrixBuilderTest.java
index decd7df..af3f024 100644
--- a/core/src/test/java/hivemall/math/matrix/MatrixBuilderTest.java
+++ b/core/src/test/java/hivemall/math/matrix/MatrixBuilderTest.java
@@ -225,7 +225,6 @@ public class MatrixBuilderTest {
Assert.assertEquals(Double.NaN, csc2.get(5, 4, Double.NaN), 0.d);
}
-
@Test
public void testDoKMatrixFromLibSVM() {
Matrix matrix = dokMatrixFromLibSVM();
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/test/java/hivemall/math/matrix/sparse/DoKFloatMatrixTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/math/matrix/sparse/DoKFloatMatrixTest.java b/core/src/test/java/hivemall/math/matrix/sparse/DoKFloatMatrixTest.java
new file mode 100644
index 0000000..c9e6afd
--- /dev/null
+++ b/core/src/test/java/hivemall/math/matrix/sparse/DoKFloatMatrixTest.java
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.math.matrix.sparse;
+
+import java.util.Random;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+public class DoKFloatMatrixTest {
+
+ /** Writes 1000 pseudo-random cells and reads each one back right after it is written. */
+ @Test
+ public void testGetSet() {
+ DoKFloatMatrix matrix = new DoKFloatMatrix();
+ Random rnd = new Random(43);
+
+ for (int i = 0; i < 1000; i++) {
+ // nextInt(bound) is never negative, whereas Math.abs(nextInt()) returns
+ // Integer.MIN_VALUE (negative) when nextInt() yields Integer.MIN_VALUE
+ int row = rnd.nextInt(Integer.MAX_VALUE);
+ int col = rnd.nextInt(Integer.MAX_VALUE);
+ double v = rnd.nextDouble();
+ matrix.set(row, col, v);
+ Assert.assertEquals(v, matrix.get(row, col), 0.00001d);
+ }
+
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/core/src/test/java/hivemall/recommend/SlimUDTFTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/recommend/SlimUDTFTest.java b/core/src/test/java/hivemall/recommend/SlimUDTFTest.java
new file mode 100644
index 0000000..00b78f0
--- /dev/null
+++ b/core/src/test/java/hivemall/recommend/SlimUDTFTest.java
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.recommend;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.junit.Test;
+
+public class SlimUDTFTest {
+ @Test
+ public void testAllSamples() throws HiveException {
+ SlimUDTF slim = new SlimUDTF();
+ ObjectInspector itemIOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
+ ObjectInspector itemJOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
+
+ ObjectInspector itemIRatesOI = ObjectInspectorFactory.getStandardMapObjectInspector(
+ PrimitiveObjectInspectorFactory.javaIntObjectInspector,
+ PrimitiveObjectInspectorFactory.javaFloatObjectInspector);
+ ObjectInspector itemJRatesOI = ObjectInspectorFactory.getStandardMapObjectInspector(
+ PrimitiveObjectInspectorFactory.javaIntObjectInspector,
+ PrimitiveObjectInspectorFactory.javaFloatObjectInspector);
+ ObjectInspector topKRatesOfIOI = ObjectInspectorFactory.getStandardMapObjectInspector(
+ PrimitiveObjectInspectorFactory.javaIntObjectInspector,
+ ObjectInspectorFactory.getStandardMapObjectInspector(
+ PrimitiveObjectInspectorFactory.javaIntObjectInspector,
+ PrimitiveObjectInspectorFactory.javaFloatObjectInspector));
+ ObjectInspector optionArgumentOI = ObjectInspectorUtils.getConstantObjectInspector(
+ PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-l2 0.01 -l1 0.01");
+
+ ObjectInspector[] argOIs = {itemIOI, itemIRatesOI, topKRatesOfIOI, itemJOI, itemJRatesOI,
+ optionArgumentOI};
+
+ slim.initialize(argOIs);
+ int numUser = 4;
+ int numItem = 5;
+
+ float[][] data = { {1.f, 4.f, 0.f, 0.f, 0.f}, {0.f, 3.f, 0.f, 1.f, 2.f},
+ {2.f, 2.f, 0.f, 0.f, 3.f}, {0.f, 1.f, 1.f, 0.f, 0.f}};
+
+ for (int i = 0; i < numItem; i++) {
+ Map<Integer, Float> Ri = new HashMap<>();
+ for (int u = 0; u < numUser; u++) {
+ if (data[u][i] != 0.) {
+ Ri.put(u, data[u][i]);
+ }
+ }
+
+ // most similar data
+ Map<Integer, Map<Integer, Float>> knnRatesOfI = new HashMap<>();
+ for (int u = 0; u < numUser; u++) {
+ Map<Integer, Float> Ru = new HashMap<>();
+ for (int k = 0; k < numItem; k++) {
+ if (k == i)
+ continue;
+ Ru.put(k, data[u][k]);
+ }
+ knnRatesOfI.put(u, Ru);
+ }
+
+ for (int j = 0; j < numItem; j++) {
+ if (i == j)
+ continue;
+ Map<Integer, Float> Rj = new HashMap<>();
+ for (int u = 0; u < numUser; u++) {
+ if (data[u][j] != 0.) {
+ Rj.put(u, data[u][j]);
+ }
+ }
+
+ Object[] args = {i, Ri, knnRatesOfI, j, Rj};
+ slim.process(args);
+ }
+ }
+ slim.finalizeTraining();
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/docs/gitbook/SUMMARY.md
----------------------------------------------------------------------
diff --git a/docs/gitbook/SUMMARY.md b/docs/gitbook/SUMMARY.md
index 3d640f8..8b76a7f 100644
--- a/docs/gitbook/SUMMARY.md
+++ b/docs/gitbook/SUMMARY.md
@@ -155,6 +155,7 @@
* [Item-based Collaborative Filtering](recommend/movielens_cf.md)
* [Matrix Factorization](recommend/movielens_mf.md)
* [Factorization Machine](recommend/movielens_fm.md)
+ * [SLIM for Fast Top-K Recommendation](recommend/movielens_slim.md)
* [10-fold Cross Validation (Matrix Factorization)](recommend/movielens_cv.md)
## Part X - Anomaly Detection
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/docs/gitbook/recommend/item_based_cf.md
----------------------------------------------------------------------
diff --git a/docs/gitbook/recommend/item_based_cf.md b/docs/gitbook/recommend/item_based_cf.md
index 053b225..dcd4f57 100644
--- a/docs/gitbook/recommend/item_based_cf.md
+++ b/docs/gitbook/recommend/item_based_cf.md
@@ -325,7 +325,7 @@ similarity as (
o.other,
cosine_similarity(t1.feature_vector, t2.feature_vector) as similarity
from
- cooccurrence_top100 o
+ cooccurrence_top100 o
-- cooccurrence_upper_triangular o
JOIN item_features t1 ON (o.itemid = t1.itemid)
JOIN item_features t2 ON (o.other = t2.itemid)
@@ -652,7 +652,8 @@ partial_result as ( -- launch DIMSUM in a MapReduce fashion
item_features f
left outer join item_magnitude m
),
-similarity as ( -- reduce (i.e., sum up) mappers' partial results
+similarity as (
+ -- reduce (i.e., sum up) mappers' partial results
select
itemid,
other,
@@ -702,7 +703,8 @@ partial_result as (
item_features f
left outer join item_magnitude m
),
-similarity_upper_triangular as ( -- if similarity of (i1, i2) pair is in this table, (i2, i1)'s similarity is omitted
+similarity_upper_triangular as (
+ -- if similarity of (i1, i2) pair is in this table, (i2, i1)'s similarity is omitted
select
itemid,
other,
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/docs/gitbook/recommend/movielens_cf.md
----------------------------------------------------------------------
diff --git a/docs/gitbook/recommend/movielens_cf.md b/docs/gitbook/recommend/movielens_cf.md
index 1cf5aee..0602611 100644
--- a/docs/gitbook/recommend/movielens_cf.md
+++ b/docs/gitbook/recommend/movielens_cf.md
@@ -66,7 +66,8 @@ partial_result as ( -- launch DIMSUM in a MapReduce fashion
movie_features f
left outer join movie_magnitude m
),
-similarity as ( -- reduce (i.e., sum up) mappers' partial results
+similarity as (
+ -- reduce (i.e., sum up) mappers' partial results
select
movieid,
other,
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/docs/gitbook/recommend/movielens_cv.md
----------------------------------------------------------------------
diff --git a/docs/gitbook/recommend/movielens_cv.md b/docs/gitbook/recommend/movielens_cv.md
index 6ac54c7..80c0d19 100644
--- a/docs/gitbook/recommend/movielens_cv.md
+++ b/docs/gitbook/recommend/movielens_cv.md
@@ -17,7 +17,7 @@
under the License.
-->
-[Cross-validation](http://en.wikipedia.org/wiki/Cross-validation_(statistics)#k-fold_cross-validationk-fold cross validation) is a model validation technique for assessing how a prediction model will generalize to an independent data set. This example shows a way to perform [k-fold cross validation](http://en.wikipedia.org/wiki/Cross-validation_(statistics)#k-fold_cross-validation) to evaluate prediction performance.
+[Cross-validation](http://en.wikipedia.org/wiki/Cross-validation_%28statistics%29) is a model validation technique for assessing how a prediction model will generalize to an independent data set. This example shows a way to perform [k-fold cross validation](http://en.wikipedia.org/wiki/Cross-validation_%28statistics%29#k-fold_cross-validation) to evaluate prediction performance.
*Caution:* Matrix factorization is supported in Hivemall v0.3 or later.
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/995b9a88/docs/gitbook/recommend/movielens_fm.md
----------------------------------------------------------------------
diff --git a/docs/gitbook/recommend/movielens_fm.md b/docs/gitbook/recommend/movielens_fm.md
index 64039fe..d3d2c82 100644
--- a/docs/gitbook/recommend/movielens_fm.md
+++ b/docs/gitbook/recommend/movielens_fm.md
@@ -19,6 +19,8 @@
_Caution: Factorization Machine is supported from Hivemall v0.4 or later._
+<!-- toc -->
+
# Data preparation
First of all, please create `ratings` table described in [this article](../recommend/movielens_dataset.html).
@@ -89,7 +91,7 @@ set hivevar:factor=10;
set hivevar:iters=50;
```
-## Build a prediction mdoel by Factorization Machine
+## Build a prediction model by Factorization Machine
```sql
drop table fm_model;