You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by ra...@apache.org on 2018/06/27 14:52:15 UTC
[47/51] [partial] mahout git commit: MAHOUT-2042 and MAHOUT-2045
Delete directories which were moved/no longer in use
http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/impl/similarity/jdbc/AbstractJDBCInMemoryItemSimilarity.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/impl/similarity/jdbc/AbstractJDBCInMemoryItemSimilarity.java b/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/impl/similarity/jdbc/AbstractJDBCInMemoryItemSimilarity.java
new file mode 100644
index 0000000..3ae9990
--- /dev/null
+++ b/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/impl/similarity/jdbc/AbstractJDBCInMemoryItemSimilarity.java
@@ -0,0 +1,132 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.similarity.jdbc;
+
+import org.apache.mahout.cf.taste.common.Refreshable;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.jdbc.AbstractJDBCComponent;
+import org.apache.mahout.cf.taste.impl.common.jdbc.ResultSetIterator;
+import org.apache.mahout.cf.taste.impl.model.jdbc.ConnectionPoolDataSource;
+import org.apache.mahout.cf.taste.impl.similarity.GenericItemSimilarity;
+import org.apache.mahout.cf.taste.similarity.ItemSimilarity;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import javax.sql.DataSource;
+import java.sql.ResultSet;
+import java.sql.SQLException;
+import java.util.Collection;
+import java.util.Iterator;
+import java.util.concurrent.locks.ReentrantLock;
+
+/**
+ * An {@link ItemSimilarity} that loads all similarities from the database into RAM and answers every
+ * query from that in-memory copy.
+ *
+ * <p>The copy is built once by the constructor and rebuilt on every call to
+ * {@link #refresh(Collection)}.</p>
+ */
+abstract class AbstractJDBCInMemoryItemSimilarity extends AbstractJDBCComponent implements ItemSimilarity {
+
+  private static final Logger log = LoggerFactory.getLogger(AbstractJDBCInMemoryItemSimilarity.class);
+
+  /**
+   * In-memory snapshot that all queries are forwarded to. Declared {@code volatile} because
+   * {@link #reload()} replaces it while the query methods read it without holding the reload lock;
+   * without it, a reader thread is not guaranteed to observe a freshly loaded snapshot.
+   */
+  private volatile ItemSimilarity delegate;
+
+  private final DataSource dataSource;
+  private final String getAllItemSimilaritiesSQL;
+  /** Lets at most one thread rebuild the snapshot at a time; see {@link #reload()}. */
+  private final ReentrantLock reloadLock;
+
+  /**
+   * Stores the configuration and immediately loads the similarities via {@link #reload()}.
+   *
+   * @param dataSource source of JDBC connections; a warning is logged when it is not a
+   *  {@link ConnectionPoolDataSource}
+   * @param getAllItemSimilaritiesSQL query whose result rows are read as (long, long, double);
+   *  checked with {@code checkNotNullAndLog}
+   */
+  AbstractJDBCInMemoryItemSimilarity(DataSource dataSource, String getAllItemSimilaritiesSQL) {
+
+    AbstractJDBCComponent.checkNotNullAndLog("getAllItemSimilaritiesSQL", getAllItemSimilaritiesSQL);
+
+    if (!(dataSource instanceof ConnectionPoolDataSource)) {
+      log.warn("You are not using ConnectionPoolDataSource. Make sure your DataSource pools connections "
+               + "to the database itself, or database performance will be severely reduced.");
+    }
+
+    this.dataSource = dataSource;
+    this.getAllItemSimilaritiesSQL = getAllItemSimilaritiesSQL;
+    this.reloadLock = new ReentrantLock();
+
+    reload();
+  }
+
+  @Override
+  public double itemSimilarity(long itemID1, long itemID2) throws TasteException {
+    return delegate.itemSimilarity(itemID1, itemID2);
+  }
+
+  @Override
+  public double[] itemSimilarities(long itemID1, long[] itemID2s) throws TasteException {
+    return delegate.itemSimilarities(itemID1, itemID2s);
+  }
+
+  @Override
+  public long[] allSimilarItemIDs(long itemID) throws TasteException {
+    return delegate.allSimilarItemIDs(itemID);
+  }
+
+  /**
+   * Rebuilds the in-memory snapshot from the database. The {@code alreadyRefreshed} argument is
+   * not consulted.
+   */
+  @Override
+  public void refresh(Collection<Refreshable> alreadyRefreshed) {
+    log.debug("Reloading...");
+    reload();
+  }
+
+  /**
+   * Replaces the snapshot with a new {@link GenericItemSimilarity} built from the configured query.
+   * If another thread currently holds the reload lock, this call returns without doing anything
+   * instead of waiting.
+   */
+  protected void reload() {
+    if (reloadLock.tryLock()) {
+      try {
+        delegate = new GenericItemSimilarity(new JDBCSimilaritiesIterable(dataSource, getAllItemSimilaritiesSQL));
+      } finally {
+        reloadLock.unlock();
+      }
+    }
+  }
+
+  /** Iterable view of the query result; each call to {@link #iterator()} creates a new iterator. */
+  private static final class JDBCSimilaritiesIterable implements Iterable<GenericItemSimilarity.ItemItemSimilarity> {
+
+    private final DataSource dataSource;
+    private final String getAllItemSimilaritiesSQL;
+
+    private JDBCSimilaritiesIterable(DataSource dataSource, String getAllItemSimilaritiesSQL) {
+      this.dataSource = dataSource;
+      this.getAllItemSimilaritiesSQL = getAllItemSimilaritiesSQL;
+    }
+
+    @Override
+    public Iterator<GenericItemSimilarity.ItemItemSimilarity> iterator() {
+      try {
+        return new JDBCSimilaritiesIterator(dataSource, getAllItemSimilaritiesSQL);
+      } catch (SQLException sqle) {
+        // Iterable.iterator() cannot throw a checked exception, so wrap it
+        throw new IllegalStateException(sqle);
+      }
+    }
+  }
+
+  /** Maps each result row to an item-item similarity: columns 1 and 2 are the IDs, column 3 the value. */
+  private static final class JDBCSimilaritiesIterator
+      extends ResultSetIterator<GenericItemSimilarity.ItemItemSimilarity> {
+
+    private JDBCSimilaritiesIterator(DataSource dataSource, String getAllItemSimilaritiesSQL) throws SQLException {
+      super(dataSource, getAllItemSimilaritiesSQL);
+    }
+
+    @Override
+    protected GenericItemSimilarity.ItemItemSimilarity parseElement(ResultSet resultSet) throws SQLException {
+      return new GenericItemSimilarity.ItemItemSimilarity(resultSet.getLong(1),
+                                                          resultSet.getLong(2),
+                                                          resultSet.getDouble(3));
+    }
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/impl/similarity/jdbc/AbstractJDBCItemSimilarity.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/impl/similarity/jdbc/AbstractJDBCItemSimilarity.java b/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/impl/similarity/jdbc/AbstractJDBCItemSimilarity.java
new file mode 100644
index 0000000..1b8d109
--- /dev/null
+++ b/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/impl/similarity/jdbc/AbstractJDBCItemSimilarity.java
@@ -0,0 +1,213 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.similarity.jdbc;
+
+import java.sql.Connection;
+import java.sql.PreparedStatement;
+import java.sql.ResultSet;
+import java.sql.SQLException;
+import java.util.Collection;
+
+import javax.sql.DataSource;
+
+import org.apache.mahout.cf.taste.common.Refreshable;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.FastIDSet;
+import org.apache.mahout.cf.taste.impl.common.jdbc.AbstractJDBCComponent;
+import org.apache.mahout.cf.taste.impl.model.jdbc.ConnectionPoolDataSource;
+import org.apache.mahout.cf.taste.similarity.ItemSimilarity;
+import org.apache.mahout.common.IOUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * An {@link ItemSimilarity} which draws pre-computed item-item similarities from a database table via JDBC.
+ *
+ * <p>Every query opens a connection from the configured {@link DataSource}, runs one of the two SQL
+ * statements supplied by the subclass, and closes the connection again; nothing is cached.</p>
+ */
+public abstract class AbstractJDBCItemSimilarity extends AbstractJDBCComponent implements ItemSimilarity {
+
+  private static final Logger log = LoggerFactory.getLogger(AbstractJDBCItemSimilarity.class);
+
+  static final String DEFAULT_SIMILARITY_TABLE = "taste_item_similarity";
+  static final String DEFAULT_ITEM_A_ID_COLUMN = "item_id_a";
+  static final String DEFAULT_ITEM_B_ID_COLUMN = "item_id_b";
+  static final String DEFAULT_SIMILARITY_COLUMN = "similarity";
+
+  private final DataSource dataSource;
+  private final String similarityTable;
+  private final String itemAIDColumn;
+  private final String itemBIDColumn;
+  private final String similarityColumn;
+  private final String getItemItemSimilaritySQL;
+  private final String getAllSimilarItemIDsSQL;
+
+  /**
+   * Creates an instance that uses the default table and column names.
+   *
+   * @param dataSource source of JDBC connections
+   * @param getItemItemSimilaritySQL query with two {@code ?} parameters (smaller ID, larger ID) whose
+   *  first result column is the similarity
+   * @param getAllSimilarItemIDsSQL query with two {@code ?} parameters (both bound to the same item ID)
+   *  whose first two result columns are item IDs
+   */
+  protected AbstractJDBCItemSimilarity(DataSource dataSource,
+                                       String getItemItemSimilaritySQL,
+                                       String getAllSimilarItemIDsSQL) {
+    this(dataSource,
+         DEFAULT_SIMILARITY_TABLE,
+         DEFAULT_ITEM_A_ID_COLUMN,
+         DEFAULT_ITEM_B_ID_COLUMN,
+         DEFAULT_SIMILARITY_COLUMN,
+         getItemItemSimilaritySQL,
+         getAllSimilarItemIDsSQL);
+  }
+
+  /**
+   * Creates an instance with explicit table and column names. All string arguments are checked with
+   * {@code checkNotNullAndLog}; a warning is logged when {@code dataSource} is not a
+   * {@link ConnectionPoolDataSource}.
+   *
+   * @param dataSource source of JDBC connections
+   * @param similarityTable name of the similarity table
+   * @param itemAIDColumn name of the first item ID column
+   * @param itemBIDColumn name of the second item ID column
+   * @param similarityColumn name of the similarity value column
+   * @param getItemItemSimilaritySQL query with two {@code ?} parameters (smaller ID, larger ID) whose
+   *  first result column is the similarity
+   * @param getAllSimilarItemIDsSQL query with two {@code ?} parameters (both bound to the same item ID)
+   *  whose first two result columns are item IDs
+   */
+  protected AbstractJDBCItemSimilarity(DataSource dataSource,
+                                       String similarityTable,
+                                       String itemAIDColumn,
+                                       String itemBIDColumn,
+                                       String similarityColumn,
+                                       String getItemItemSimilaritySQL,
+                                       String getAllSimilarItemIDsSQL) {
+    AbstractJDBCComponent.checkNotNullAndLog("similarityTable", similarityTable);
+    AbstractJDBCComponent.checkNotNullAndLog("itemAIDColumn", itemAIDColumn);
+    AbstractJDBCComponent.checkNotNullAndLog("itemBIDColumn", itemBIDColumn);
+    AbstractJDBCComponent.checkNotNullAndLog("similarityColumn", similarityColumn);
+
+    AbstractJDBCComponent.checkNotNullAndLog("getItemItemSimilaritySQL", getItemItemSimilaritySQL);
+    AbstractJDBCComponent.checkNotNullAndLog("getAllSimilarItemIDsSQL", getAllSimilarItemIDsSQL);
+
+    if (!(dataSource instanceof ConnectionPoolDataSource)) {
+      log.warn("You are not using ConnectionPoolDataSource. Make sure your DataSource pools connections "
+               + "to the database itself, or database performance will be severely reduced.");
+    }
+
+    this.dataSource = dataSource;
+    this.similarityTable = similarityTable;
+    this.itemAIDColumn = itemAIDColumn;
+    this.itemBIDColumn = itemBIDColumn;
+    this.similarityColumn = similarityColumn;
+    this.getItemItemSimilaritySQL = getItemItemSimilaritySQL;
+    this.getAllSimilarItemIDsSQL = getAllSimilarItemIDsSQL;
+  }
+
+  protected String getSimilarityTable() {
+    return similarityTable;
+  }
+
+  protected String getItemAIDColumn() {
+    return itemAIDColumn;
+  }
+
+  protected String getItemBIDColumn() {
+    return itemBIDColumn;
+  }
+
+  protected String getSimilarityColumn() {
+    return similarityColumn;
+  }
+
+  /**
+   * Looks up the similarity of one item pair.
+   *
+   * @return 1.0 when both IDs are equal (no database access), the stored value when a row exists,
+   *  otherwise {@link Double#NaN}
+   * @throws TasteException if the database access fails
+   */
+  @Override
+  public double itemSimilarity(long itemID1, long itemID2) throws TasteException {
+    if (itemID1 == itemID2) {
+      return 1.0;
+    }
+    Connection conn = null;
+    PreparedStatement stmt = null;
+    try {
+      conn = dataSource.getConnection();
+      stmt = conn.prepareStatement(getItemItemSimilaritySQL, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY);
+      stmt.setFetchDirection(ResultSet.FETCH_FORWARD);
+      stmt.setFetchSize(getFetchSize());
+      return doItemSimilarity(stmt, itemID1, itemID2);
+    } catch (SQLException sqle) {
+      log.warn("Exception while retrieving similarity", sqle);
+      throw new TasteException(sqle);
+    } finally {
+      IOUtils.quietClose(null, stmt, conn);
+    }
+  }
+
+  /**
+   * Looks up the similarity of {@code itemID1} to each of {@code itemID2s}, reusing one prepared
+   * statement for all pairs.
+   *
+   * @return an array parallel to {@code itemID2s}; each entry follows the same rules as
+   *  {@link #itemSimilarity(long, long)}
+   * @throws TasteException if the database access fails
+   */
+  @Override
+  public double[] itemSimilarities(long itemID1, long[] itemID2s) throws TasteException {
+    double[] result = new double[itemID2s.length];
+    Connection conn = null;
+    PreparedStatement stmt = null;
+    try {
+      conn = dataSource.getConnection();
+      stmt = conn.prepareStatement(getItemItemSimilaritySQL, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY);
+      stmt.setFetchDirection(ResultSet.FETCH_FORWARD);
+      stmt.setFetchSize(getFetchSize());
+      for (int i = 0; i < itemID2s.length; i++) {
+        // Agree with itemSimilarity(): an item compared with itself is 1.0 and needs no lookup
+        result[i] = itemID1 == itemID2s[i] ? 1.0 : doItemSimilarity(stmt, itemID1, itemID2s[i]);
+      }
+    } catch (SQLException sqle) {
+      log.warn("Exception while retrieving item similarities", sqle);
+      throw new TasteException(sqle);
+    } finally {
+      IOUtils.quietClose(null, stmt, conn);
+    }
+    return result;
+  }
+
+  /**
+   * Collects the IDs found in the first two columns of every row returned by the "all similar item
+   * IDs" query, excluding {@code itemID} itself.
+   *
+   * @throws TasteException if the database access fails
+   */
+  @Override
+  public long[] allSimilarItemIDs(long itemID) throws TasteException {
+    FastIDSet allSimilarItemIDs = new FastIDSet();
+    Connection conn = null;
+    PreparedStatement stmt = null;
+    ResultSet rs = null;
+    try {
+      conn = dataSource.getConnection();
+      stmt = conn.prepareStatement(getAllSimilarItemIDsSQL, ResultSet.TYPE_FORWARD_ONLY,
+          ResultSet.CONCUR_READ_ONLY);
+      stmt.setFetchDirection(ResultSet.FETCH_FORWARD);
+      stmt.setFetchSize(getFetchSize());
+      stmt.setLong(1, itemID);
+      stmt.setLong(2, itemID);
+      rs = stmt.executeQuery();
+      while (rs.next()) {
+        // Both columns are added; the queried item itself is removed once after the loop
+        allSimilarItemIDs.add(rs.getLong(1));
+        allSimilarItemIDs.add(rs.getLong(2));
+      }
+    } catch (SQLException sqle) {
+      log.warn("Exception while retrieving all similar itemIDs", sqle);
+      throw new TasteException(sqle);
+    } finally {
+      IOUtils.quietClose(rs, stmt, conn);
+    }
+    allSimilarItemIDs.remove(itemID);
+    return allSimilarItemIDs.toArray();
+  }
+
+  @Override
+  public void refresh(Collection<Refreshable> alreadyRefreshed) {
+    // do nothing
+  }
+
+  /**
+   * Binds the pair to {@code stmt} with the smaller ID first, executes it and reads the similarity
+   * from the first result column.
+   */
+  private double doItemSimilarity(PreparedStatement stmt, long itemID1, long itemID2) throws SQLException {
+    // Order as smaller - larger
+    if (itemID1 > itemID2) {
+      long temp = itemID1;
+      itemID1 = itemID2;
+      itemID2 = temp;
+    }
+    stmt.setLong(1, itemID1);
+    stmt.setLong(2, itemID2);
+    log.debug("Executing SQL query: {}", getItemItemSimilaritySQL);
+    ResultSet rs = null;
+    try {
+      rs = stmt.executeQuery();
+      // If not found, perhaps the items exist but have no presence in the table,
+      // so NaN is appropriate
+      return rs.next() ? rs.getDouble(1) : Double.NaN;
+    } finally {
+      IOUtils.quietClose(rs);
+    }
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/impl/similarity/jdbc/MySQLJDBCInMemoryItemSimilarity.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/impl/similarity/jdbc/MySQLJDBCInMemoryItemSimilarity.java b/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/impl/similarity/jdbc/MySQLJDBCInMemoryItemSimilarity.java
new file mode 100644
index 0000000..cc831d9
--- /dev/null
+++ b/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/impl/similarity/jdbc/MySQLJDBCInMemoryItemSimilarity.java
@@ -0,0 +1,47 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.similarity.jdbc;
+
+import org.apache.mahout.cf.taste.common.TasteException;
+
+import javax.sql.DataSource;
+
+/**
+ * Variant of {@link SQL92JDBCInMemoryItemSimilarity} for MySQL. It differs from its parent only in
+ * the fetch size it reports.
+ */
+public class MySQLJDBCInMemoryItemSimilarity extends SQL92JDBCInMemoryItemSimilarity {
+
+  /** Delegates to the parent's no-argument constructor. */
+  public MySQLJDBCInMemoryItemSimilarity() throws TasteException {
+    super();
+  }
+
+  /** Delegates to the parent's constructor taking a data source name. */
+  public MySQLJDBCInMemoryItemSimilarity(String dataSourceName) throws TasteException {
+    super(dataSourceName);
+  }
+
+  /** Delegates to the parent's constructor taking a {@link DataSource}. */
+  public MySQLJDBCInMemoryItemSimilarity(DataSource dataSource) {
+    super(dataSource);
+  }
+
+  /** Delegates to the parent's constructor taking a {@link DataSource} and a custom query. */
+  public MySQLJDBCInMemoryItemSimilarity(DataSource dataSource, String getAllItemSimilaritiesSQL) {
+    super(dataSource, getAllItemSimilaritiesSQL);
+  }
+
+  /**
+   * @return {@link Integer#MIN_VALUE}, the value MySQL Connector/J needs in order to use
+   *  streaming mode
+   */
+  @Override
+  protected int getFetchSize() {
+    return Integer.MIN_VALUE;
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/impl/similarity/jdbc/MySQLJDBCItemSimilarity.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/impl/similarity/jdbc/MySQLJDBCItemSimilarity.java b/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/impl/similarity/jdbc/MySQLJDBCItemSimilarity.java
new file mode 100644
index 0000000..af0742e
--- /dev/null
+++ b/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/impl/similarity/jdbc/MySQLJDBCItemSimilarity.java
@@ -0,0 +1,103 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.similarity.jdbc;
+
+import javax.sql.DataSource;
+
+import org.apache.mahout.cf.taste.common.TasteException;
+
+/**
+ * <p>
+ * An {@link org.apache.mahout.cf.taste.similarity.ItemSimilarity} backed by a MySQL database
+ * and accessed via JDBC. It may work with other JDBC
+ * databases. By default, this class assumes that there is a {@link DataSource} available under the JNDI name
+ * "jdbc/taste", which gives access to a database with a "taste_item_similarity" table with the following
+ * schema:
+ * </p>
+ *
+ * <table>
+ * <tr>
+ * <th>item_id_a</th>
+ * <th>item_id_b</th>
+ * <th>similarity</th>
+ * </tr>
+ * <tr>
+ * <td>123</td>
+ * <td>456</td>
+ * <td>0.9</td>
+ * </tr>
+ * <tr>
+ * <td>456</td>
+ * <td>789</td>
+ * <td>0.1</td>
+ * </tr>
+ * </table>
+ *
+ * <p>
+ * For example, the following command sets up a suitable table in MySQL, complete with primary key:
+ * </p>
+ *
+ * <pre>
+ * CREATE TABLE taste_item_similarity (
+ *   item_id_a BIGINT NOT NULL,
+ *   item_id_b BIGINT NOT NULL,
+ *   similarity FLOAT NOT NULL,
+ *   PRIMARY KEY (item_id_a, item_id_b)
+ * )
+ * </pre>
+ *
+ * <p>
+ * Note that for each row, item_id_a should be less than item_id_b. It is redundant to store it both ways,
+ * so the pair is always stored as a pair with the lesser one first.
+ * </p>
+ *
+ * @see org.apache.mahout.cf.taste.impl.model.jdbc.MySQLJDBCDataModel
+ */
+public class MySQLJDBCItemSimilarity extends SQL92JDBCItemSimilarity {
+
+  /** Delegates to the parent's no-argument constructor. */
+  public MySQLJDBCItemSimilarity() throws TasteException {
+  }
+
+  /** Delegates to the parent's constructor taking a data source name. */
+  public MySQLJDBCItemSimilarity(String dataSourceName) throws TasteException {
+    super(dataSourceName);
+  }
+
+  /** Delegates to the parent's constructor taking a {@link DataSource}. */
+  public MySQLJDBCItemSimilarity(DataSource dataSource) {
+    super(dataSource);
+  }
+
+  /** Delegates to the parent's constructor taking explicit table and column names. */
+  public MySQLJDBCItemSimilarity(DataSource dataSource,
+                                 String similarityTable,
+                                 String itemAIDColumn,
+                                 String itemBIDColumn,
+                                 String similarityColumn) {
+    super(dataSource, similarityTable, itemAIDColumn, itemBIDColumn, similarityColumn);
+  }
+
+  @Override
+  protected int getFetchSize() {
+    // Need to return this for MySQL Connector/J to make it use streaming mode
+    return Integer.MIN_VALUE;
+  }
+
+}
+
http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/impl/similarity/jdbc/SQL92JDBCInMemoryItemSimilarity.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/impl/similarity/jdbc/SQL92JDBCInMemoryItemSimilarity.java b/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/impl/similarity/jdbc/SQL92JDBCInMemoryItemSimilarity.java
new file mode 100644
index 0000000..b311a5e
--- /dev/null
+++ b/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/impl/similarity/jdbc/SQL92JDBCInMemoryItemSimilarity.java
@@ -0,0 +1,51 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.similarity.jdbc;
+
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.jdbc.AbstractJDBCComponent;
+
+import javax.sql.DataSource;
+
+/**
+ * Concrete {@link AbstractJDBCInMemoryItemSimilarity} that, unless told otherwise, reads the default
+ * similarity table and columns declared in {@link AbstractJDBCItemSimilarity}.
+ */
+public class SQL92JDBCInMemoryItemSimilarity extends AbstractJDBCInMemoryItemSimilarity {
+
+  /** Selects item A, item B and similarity from the default table, in that column order. */
+  static final String DEFAULT_GET_ALL_ITEMSIMILARITIES_SQL =
+      "SELECT " + AbstractJDBCItemSimilarity.DEFAULT_ITEM_A_ID_COLUMN
+      + ", " + AbstractJDBCItemSimilarity.DEFAULT_ITEM_B_ID_COLUMN
+      + ", " + AbstractJDBCItemSimilarity.DEFAULT_SIMILARITY_COLUMN
+      + " FROM " + AbstractJDBCItemSimilarity.DEFAULT_SIMILARITY_TABLE;
+
+  /** Looks up the data source under the default name and uses the default query. */
+  public SQL92JDBCInMemoryItemSimilarity() throws TasteException {
+    this(DEFAULT_DATASOURCE_NAME);
+  }
+
+  /** Looks up the data source under {@code dataSourceName} and uses the default query. */
+  public SQL92JDBCInMemoryItemSimilarity(String dataSourceName) throws TasteException {
+    this(lookupDataSource(dataSourceName));
+  }
+
+  /** Uses the given data source with the default query. */
+  public SQL92JDBCInMemoryItemSimilarity(DataSource dataSource) {
+    this(dataSource, DEFAULT_GET_ALL_ITEMSIMILARITIES_SQL);
+  }
+
+  /** Uses the given data source and the given query. */
+  public SQL92JDBCInMemoryItemSimilarity(DataSource dataSource, String getAllItemSimilaritiesSQL) {
+    super(dataSource, getAllItemSimilaritiesSQL);
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/impl/similarity/jdbc/SQL92JDBCItemSimilarity.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/impl/similarity/jdbc/SQL92JDBCItemSimilarity.java b/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/impl/similarity/jdbc/SQL92JDBCItemSimilarity.java
new file mode 100644
index 0000000..f449561
--- /dev/null
+++ b/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/impl/similarity/jdbc/SQL92JDBCItemSimilarity.java
@@ -0,0 +1,57 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.similarity.jdbc;
+
+import org.apache.mahout.cf.taste.common.TasteException;
+
+import javax.sql.DataSource;
+
+public class SQL92JDBCItemSimilarity extends AbstractJDBCItemSimilarity {
+
+ public SQL92JDBCItemSimilarity() throws TasteException {
+ this(DEFAULT_DATASOURCE_NAME);
+ }
+
+ public SQL92JDBCItemSimilarity(String dataSourceName) throws TasteException {
+ this(lookupDataSource(dataSourceName));
+ }
+
+ public SQL92JDBCItemSimilarity(DataSource dataSource) {
+ this(dataSource,
+ DEFAULT_SIMILARITY_TABLE,
+ DEFAULT_ITEM_A_ID_COLUMN,
+ DEFAULT_ITEM_B_ID_COLUMN,
+ DEFAULT_SIMILARITY_COLUMN);
+ }
+
+ public SQL92JDBCItemSimilarity(DataSource dataSource,
+ String similarityTable,
+ String itemAIDColumn,
+ String itemBIDColumn,
+ String similarityColumn) {
+ super(dataSource,
+ similarityTable,
+ itemAIDColumn,
+ itemBIDColumn, similarityColumn,
+ "SELECT " + similarityColumn + " FROM " + similarityTable + " WHERE "
+ + itemAIDColumn + "=? AND " + itemBIDColumn + "=?",
+ "SELECT " + itemAIDColumn + ", " + itemBIDColumn + " FROM " + similarityTable + " WHERE "
+ + itemAIDColumn + "=? OR " + itemBIDColumn + "=?");
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/web/RecommenderServlet.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/web/RecommenderServlet.java b/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/web/RecommenderServlet.java
new file mode 100644
index 0000000..a5a89c6
--- /dev/null
+++ b/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/web/RecommenderServlet.java
@@ -0,0 +1,215 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.web;
+
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.model.Preference;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+import org.apache.mahout.cf.taste.recommender.RecommendedItem;
+import org.apache.mahout.cf.taste.recommender.Recommender;
+
+import javax.servlet.ServletConfig;
+import javax.servlet.ServletException;
+import javax.servlet.http.HttpServlet;
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
+import java.io.IOException;
+import java.io.PrintWriter;
+import java.util.List;
+
+/**
+ * <p>A servlet which returns recommendations, as its name implies. The servlet accepts GET and POST
+ * HTTP requests, and looks for the following parameters:</p>
+ *
+ * <ul>
+ * <li><em>userID</em>: (required) the user ID for which to produce recommendations</li>
+ * <li><em>howMany</em>: (optional) the number of recommendations to produce. Defaults to 20.</li>
+ * <li><em>debug</em>: (optional) output a lot of information that is useful in debugging.
+ * Only honoured for the text format. Defaults to false, of course.</li>
+ * <li><em>format</em>: (optional) one of {@code text}, {@code xml} or {@code json}.
+ * Defaults to {@code text}.</li>
+ * </ul>
+ *
+ * <p>In the default text format, the response contains one recommendation per line, in the order
+ * returned by the {@link Recommender}: the estimated value, a tab, and the item ID.</p>
+ *
+ * <p>For example, you can get 10 recommendations for user 123 from the following URL (assuming
+ * you are running taste in a web application running locally on port 8080):<br/>
+ * {@code http://localhost:8080/taste/RecommenderServlet?userID=123&howMany=10}</p>
+ *
+ * <p>This servlet requires one {@code init-param} in {@code web.xml}: it must find
+ * a parameter named "recommender-class" which is the name of a class that implements
+ * {@link Recommender} and has a no-arg constructor. The servlet will instantiate and use
+ * this {@link Recommender} to produce recommendations.</p>
+ */
+public final class RecommenderServlet extends HttpServlet {
+
+  private static final int NUM_TOP_PREFERENCES = 20;
+  private static final int DEFAULT_HOW_MANY = 20;
+
+  private Recommender recommender;
+
+  /**
+   * Reads the "recommender-class" init parameter and obtains the shared {@link Recommender} from
+   * {@link RecommenderSingleton}.
+   *
+   * @throws ServletException if the init parameter is not defined
+   */
+  @Override
+  public void init(ServletConfig config) throws ServletException {
+    super.init(config);
+    String recommenderClassName = config.getInitParameter("recommender-class");
+    if (recommenderClassName == null) {
+      throw new ServletException("Servlet init-param \"recommender-class\" is not defined");
+    }
+    RecommenderSingleton.initializeIfNeeded(recommenderClassName);
+    recommender = RecommenderSingleton.getInstance().getRecommender();
+  }
+
+  /**
+   * Computes recommendations for the requested user and writes them in the requested format.
+   *
+   * @throws ServletException if {@code userID} is missing, if {@code userID} or {@code howMany} is not
+   *  a valid integer, if {@code format} is not recognized, or if recommending or writing fails
+   */
+  @Override
+  public void doGet(HttpServletRequest request,
+                    HttpServletResponse response) throws ServletException {
+
+    String userIDString = request.getParameter("userID");
+    if (userIDString == null) {
+      throw new ServletException("userID was not specified");
+    }
+    long userID = parseUserID(userIDString);
+    int howMany = parseHowMany(request.getParameter("howMany"));
+    boolean debug = Boolean.parseBoolean(request.getParameter("debug"));
+    String format = request.getParameter("format");
+    if (format == null) {
+      format = "text";
+    }
+
+    try {
+      List<RecommendedItem> items = recommender.recommend(userID, howMany);
+      if ("text".equals(format)) {
+        writePlainText(response, userID, debug, items);
+      } else if ("xml".equals(format)) {
+        writeXML(response, items);
+      } else if ("json".equals(format)) {
+        writeJSON(response, items);
+      } else {
+        throw new ServletException("Bad format parameter: " + format);
+      }
+    } catch (TasteException | IOException te) {
+      throw new ServletException(te);
+    }
+
+  }
+
+  /**
+   * Parses the {@code userID} parameter, reporting a malformed value as a {@link ServletException}
+   * so it is handled the same way as a missing one.
+   */
+  private static long parseUserID(String userIDString) throws ServletException {
+    try {
+      return Long.parseLong(userIDString);
+    } catch (NumberFormatException nfe) {
+      throw new ServletException("Bad userID parameter: " + userIDString, nfe);
+    }
+  }
+
+  /**
+   * Parses the optional {@code howMany} parameter; {@code null} selects {@link #DEFAULT_HOW_MANY}
+   * and a malformed value is reported as a {@link ServletException}.
+   */
+  private static int parseHowMany(String howManyString) throws ServletException {
+    if (howManyString == null) {
+      return DEFAULT_HOW_MANY;
+    }
+    try {
+      return Integer.parseInt(howManyString);
+    } catch (NumberFormatException nfe) {
+      throw new ServletException("Bad howMany parameter: " + howManyString, nfe);
+    }
+  }
+
+  /** Writes the items as a single {@code recommendedItems} XML element with one {@code item} child each. */
+  private static void writeXML(HttpServletResponse response, Iterable<RecommendedItem> items) throws IOException {
+    response.setContentType("application/xml");
+    response.setCharacterEncoding("UTF-8");
+    response.setHeader("Cache-Control", "no-cache");
+    PrintWriter writer = response.getWriter();
+    writer.print("<?xml version=\"1.0\" encoding=\"UTF-8\"?><recommendedItems>");
+    for (RecommendedItem recommendedItem : items) {
+      writer.print("<item><value>");
+      writer.print(recommendedItem.getValue());
+      writer.print("</value><id>");
+      writer.print(recommendedItem.getItemID());
+      writer.print("</id></item>");
+    }
+    writer.println("</recommendedItems>");
+  }
+
+  /**
+   * Writes the items as a JSON object of the form {@code {"recommendedItems":{"item":[...]}}}.
+   * Both value and id are emitted as quoted strings.
+   */
+  private static void writeJSON(HttpServletResponse response, Iterable<RecommendedItem> items) throws IOException {
+    response.setContentType("application/json");
+    response.setCharacterEncoding("UTF-8");
+    response.setHeader("Cache-Control", "no-cache");
+    PrintWriter writer = response.getWriter();
+    writer.print("{\"recommendedItems\":{\"item\":[");
+    boolean first = true;
+    for (RecommendedItem recommendedItem : items) {
+      // Comma goes before every element except the first
+      if (first) {
+        first = false;
+      } else {
+        writer.print(',');
+      }
+      writer.print("{\"value\":\"");
+      writer.print(recommendedItem.getValue());
+      writer.print("\",\"id\":\"");
+      writer.print(recommendedItem.getItemID());
+      writer.print("\"}");
+    }
+    writer.println("]}}");
+  }
+
+  /** Writes the plain-text response, with the extended debug output when {@code debug} is set. */
+  private void writePlainText(HttpServletResponse response,
+                              long userID,
+                              boolean debug,
+                              Iterable<RecommendedItem> items) throws IOException, TasteException {
+    response.setContentType("text/plain");
+    response.setCharacterEncoding("UTF-8");
+    response.setHeader("Cache-Control", "no-cache");
+    PrintWriter writer = response.getWriter();
+    if (debug) {
+      writeDebugRecommendations(userID, items, writer);
+    } else {
+      writeRecommendations(items, writer);
+    }
+  }
+
+  /** Writes one line per item: value, tab, item ID. */
+  private static void writeRecommendations(Iterable<RecommendedItem> items, PrintWriter writer) {
+    for (RecommendedItem recommendedItem : items) {
+      writer.print(recommendedItem.getValue());
+      writer.print('\t');
+      writer.println(recommendedItem.getItemID());
+    }
+  }
+
+  /**
+   * Writes the user ID, the recommender, the user's highest-valued preferences (at most
+   * {@link #NUM_TOP_PREFERENCES}) and finally the recommendations themselves.
+   */
+  private void writeDebugRecommendations(long userID, Iterable<RecommendedItem> items, PrintWriter writer)
+    throws TasteException {
+    DataModel dataModel = recommender.getDataModel();
+    writer.print("User:");
+    writer.println(userID);
+    writer.print("Recommender: ");
+    writer.println(recommender);
+    writer.println();
+    writer.print("Top ");
+    writer.print(NUM_TOP_PREFERENCES);
+    writer.println(" Preferences:");
+    PreferenceArray rawPrefs = dataModel.getPreferencesFromUser(userID);
+    int length = rawPrefs.length();
+    // Sort a copy so the array obtained from the data model is left untouched
+    PreferenceArray sortedPrefs = rawPrefs.clone();
+    sortedPrefs.sortByValueReversed();
+    // Cap this at NUM_TOP_PREFERENCES just to be brief
+    int max = Math.min(NUM_TOP_PREFERENCES, length);
+    for (int i = 0; i < max; i++) {
+      Preference pref = sortedPrefs.get(i);
+      writer.print(pref.getValue());
+      writer.print('\t');
+      writer.println(pref.getItemID());
+    }
+    writer.println();
+    writer.println("Recommendations:");
+    writeRecommendations(items, writer);
+  }
+
+  /** Handles POST exactly like GET. */
+  @Override
+  public void doPost(HttpServletRequest request,
+                     HttpServletResponse response) throws ServletException {
+    doGet(request, response);
+  }
+
+  @Override
+  public String toString() {
+    return "RecommenderServlet[recommender:" + recommender + ']';
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/web/RecommenderSingleton.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/web/RecommenderSingleton.java b/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/web/RecommenderSingleton.java
new file mode 100644
index 0000000..265d7c0
--- /dev/null
+++ b/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/web/RecommenderSingleton.java
@@ -0,0 +1,57 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.web;
+
+import org.apache.mahout.cf.taste.recommender.Recommender;
+import org.apache.mahout.common.ClassUtils;
+
+/**
+ * <p>Holder for one shared {@link Recommender} instance. It lets {@link RecommenderServlet}
+ * and {@code RecommenderService.jws} use the same {@link Recommender}.</p>
+ *
+ * <p>{@link #initializeIfNeeded(String)} must have been called before {@link #getInstance()};
+ * both methods are synchronized on the class.</p>
+ */
+public final class RecommenderSingleton {
+
+  private static RecommenderSingleton instance;
+
+  private final Recommender recommender;
+
+  /**
+   * @param recommenderClassName name of the {@link Recommender} class to instantiate
+   * @throws IllegalArgumentException if the class name is null
+   */
+  private RecommenderSingleton(String recommenderClassName) {
+    if (recommenderClassName == null) {
+      throw new IllegalArgumentException("Recommender class name is null");
+    }
+    this.recommender = ClassUtils.instantiateAs(recommenderClassName, Recommender.class);
+  }
+
+  /**
+   * Creates the shared instance on the first call; later calls leave the existing instance
+   * in place and ignore their argument.
+   *
+   * @param recommenderClassName name of the {@link Recommender} class to instantiate
+   */
+  public static synchronized void initializeIfNeeded(String recommenderClassName) {
+    if (instance != null) {
+      return;
+    }
+    instance = new RecommenderSingleton(recommenderClassName);
+  }
+
+  /**
+   * @return the shared instance
+   * @throws IllegalStateException if {@link #initializeIfNeeded(String)} has not created it yet
+   */
+  public static synchronized RecommenderSingleton getInstance() {
+    if (instance != null) {
+      return instance;
+    }
+    throw new IllegalStateException("Not initialized");
+  }
+
+  /**
+   * @return the {@link Recommender} held by this singleton
+   */
+  public Recommender getRecommender() {
+    return recommender;
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/web/RecommenderWrapper.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/web/RecommenderWrapper.java b/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/web/RecommenderWrapper.java
new file mode 100644
index 0000000..e927098
--- /dev/null
+++ b/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/web/RecommenderWrapper.java
@@ -0,0 +1,126 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.web;
+
+import com.google.common.io.Files;
+import com.google.common.io.InputSupplier;
+import com.google.common.io.Resources;
+import org.apache.mahout.cf.taste.common.Refreshable;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.recommender.IDRescorer;
+import org.apache.mahout.cf.taste.recommender.RecommendedItem;
+import org.apache.mahout.cf.taste.recommender.Recommender;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.File;
+import java.io.IOException;
+import java.io.InputStream;
+import java.net.URL;
+import java.util.Collection;
+import java.util.List;
+
+/**
+ * Base class for {@link Recommender} implementations deployed through the packaging and
+ * deployment mechanism in this module. Such an implementation needs a no-arg constructor
+ * which internally builds the desired {@link Recommender} and delegates to it; this wrapper
+ * takes care of the delegation. Simply extend this class and implement
+ * {@link #buildRecommender()}.
+ */
+public abstract class RecommenderWrapper implements Recommender {
+
+  private static final Logger log = LoggerFactory.getLogger(RecommenderWrapper.class);
+
+  private final Recommender delegate;
+
+  /**
+   * Builds the delegate by calling {@link #buildRecommender()}. The call is made from this
+   * constructor, i.e. before the subclass's own field initializers and constructor body run.
+   */
+  protected RecommenderWrapper() throws TasteException, IOException {
+    delegate = buildRecommender();
+  }
+
+  /**
+   * @return the {@link Recommender} which should be used to produce recommendations
+   *  by this wrapper implementation
+   */
+  protected abstract Recommender buildRecommender() throws IOException, TasteException;
+
+  /** Delegates to the wrapped {@link Recommender}. */
+  @Override
+  public List<RecommendedItem> recommend(long userID, int howMany) throws TasteException {
+    return delegate.recommend(userID, howMany);
+  }
+
+  /** Delegates to the wrapped {@link Recommender}. */
+  @Override
+  public List<RecommendedItem> recommend(long userID, int howMany, IDRescorer rescorer) throws TasteException {
+    return delegate.recommend(userID, howMany, rescorer);
+  }
+
+  /** Delegates to the wrapped {@link Recommender}. */
+  @Override
+  public float estimatePreference(long userID, long itemID) throws TasteException {
+    return delegate.estimatePreference(userID, itemID);
+  }
+
+  /** Delegates to the wrapped {@link Recommender}. */
+  @Override
+  public void setPreference(long userID, long itemID, float value) throws TasteException {
+    delegate.setPreference(userID, itemID, value);
+  }
+
+  /** Delegates to the wrapped {@link Recommender}. */
+  @Override
+  public void removePreference(long userID, long itemID) throws TasteException {
+    delegate.removePreference(userID, itemID);
+  }
+
+  /** Delegates to the wrapped {@link Recommender}. */
+  @Override
+  public DataModel getDataModel() {
+    return delegate.getDataModel();
+  }
+
+  /** Delegates to the wrapped {@link Recommender}. */
+  @Override
+  public void refresh(Collection<Refreshable> alreadyRefreshed) {
+    delegate.refresh(alreadyRefreshed);
+  }
+
+  /**
+   * Copies the given resource into a temporary file. This is intended for data files which
+   * are stored as a resource on the classpath, such as in a JAR file. For convenience, when
+   * no such resource is found the name is interpreted as a relative path to a local file
+   * instead. This facilitates testing.
+   *
+   * @param resourceName name of resource in classpath, or relative path to file
+   * @return temporary {@link File} with resource data; it is registered for deletion on exit
+   * @throws IOException if an error occurs while reading or writing data
+   */
+  public static File readResourceToTempFile(String resourceName) throws IOException {
+    InputSupplier<? extends InputStream> source = supplierFor(resourceName);
+    File tempFile = File.createTempFile("taste", null);
+    tempFile.deleteOnExit();
+    Files.copy(source, tempFile);
+    return tempFile;
+  }
+
+  /**
+   * Locates the data behind {@code resourceName}: first as a classpath resource (with a
+   * leading '/' added if missing), then, if that lookup fails, as a local file.
+   */
+  private static InputSupplier<? extends InputStream> supplierFor(String resourceName) {
+    String absoluteResource;
+    if (resourceName.startsWith("/")) {
+      absoluteResource = resourceName;
+    } else {
+      absoluteResource = '/' + resourceName;
+    }
+    log.info("Loading resource {}", absoluteResource);
+    try {
+      URL resourceURL = Resources.getResource(RecommenderWrapper.class, absoluteResource);
+      return Resources.newInputStreamSupplier(resourceURL);
+    } catch (IllegalArgumentException iae) {
+      // Not on the classpath: fall back to the name as given, treated as a file path.
+      File resourceFile = new File(resourceName);
+      log.info("Falling back to load file {}", resourceFile.getAbsolutePath());
+      return Files.newInputStreamSupplier(resourceFile);
+    }
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/community/mahout-mr/integration/src/main/java/org/apache/mahout/classifier/ConfusionMatrixDumper.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/integration/src/main/java/org/apache/mahout/classifier/ConfusionMatrixDumper.java b/community/mahout-mr/integration/src/main/java/org/apache/mahout/classifier/ConfusionMatrixDumper.java
new file mode 100644
index 0000000..03a3000
--- /dev/null
+++ b/community/mahout-mr/integration/src/main/java/org/apache/mahout/classifier/ConfusionMatrixDumper.java
@@ -0,0 +1,425 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier;
+
+import com.google.common.collect.Lists;
+import org.apache.commons.io.Charsets;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.common.AbstractJob;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.MatrixWritable;
+
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.OutputStream;
+import java.io.PrintStream;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Export a ConfusionMatrix in various text formats: its toString() version, a grayscale HTML table,
+ * a summary HTML table and a table of counts. The HTML tables can optionally be wrapped in a complete HTML page.
+ *
+ * Input format: Hadoop SequenceFile with Text key and MatrixWritable value, 1 pair
+ *
+ * Intended to consume ConfusionMatrix SequenceFile output by Bayes TestClassifier class
+ */
+public final class ConfusionMatrixDumper extends AbstractJob {
+
+ // Column separator used by the plain-text export (a pipe character, despite the name).
+ private static final String TAB_SEPARATOR = "|";
+
+ // HTML wrapper - default CSS
+ // Opening part of the optional HTML page, including the default style sheet. The literal
+ // text TITLE is a placeholder which printHeader replaces with the page title.
+ private static final String HEADER = "<html>"
+ + "<head>\n"
+ + "<title>TITLE</title>\n"
+ + "</head>"
+ + "<body>\n"
+ + "<style type='text/css'> \n"
+ + "table\n"
+ + "{\n"
+ + "border:3px solid black; text-align:left;\n"
+ + "}\n"
+ + "th.normalHeader\n"
+ + "{\n"
+ + "border:1px solid black;border-collapse:collapse;text-align:center;"
+ + "background-color:white\n"
+ + "}\n"
+ + "th.tallHeader\n"
+ + "{\n"
+ + "border:1px solid black;border-collapse:collapse;text-align:center;"
+ + "background-color:white; height:6em\n"
+ + "}\n"
+ + "tr.label\n"
+ + "{\n"
+ + "border:1px solid black;border-collapse:collapse;text-align:center;"
+ + "background-color:white\n"
+ + "}\n"
+ + "tr.row\n"
+ + "{\n"
+ + "border:1px solid gray;text-align:center;background-color:snow\n"
+ + "}\n"
+ + "td\n"
+ + "{\n"
+ + "min-width:2em\n"
+ + "}\n"
+ + "td.cell\n"
+ + "{\n"
+ + "border:1px solid black;text-align:right;background-color:snow\n"
+ + "}\n"
+ + "td.empty\n"
+ + "{\n"
+ + "border:0px;text-align:right;background-color:snow\n"
+ + "}\n"
+ + "td.white\n"
+ + "{\n"
+ + "border:0px solid black;text-align:right;background-color:white\n"
+ + "}\n"
+ + "td.black\n"
+ + "{\n"
+ + "border:0px solid red;text-align:right;background-color:black\n"
+ + "}\n"
+ + "td.gray1\n"
+ + "{\n"
+ + "border:0px solid green;text-align:right; background-color:LightGray\n"
+ + "}\n" + "td.gray2\n" + "{\n"
+ + "border:0px solid blue;text-align:right;background-color:gray\n"
+ + "}\n" + "td.gray3\n" + "{\n"
+ + "border:0px solid red;text-align:right;background-color:DarkGray\n"
+ + "}\n" + "th" + "{\n" + " text-align: center;\n"
+ + " vertical-align: bottom;\n"
+ + " padding-bottom: 3px;\n" + " padding-left: 5px;\n"
+ + " padding-right: 5px;\n" + "}\n" + " .verticalText\n"
+ + " {\n" + " text-align: center;\n"
+ + " vertical-align: middle;\n" + " width: 20px;\n"
+ + " margin: 0px;\n" + " padding: 0px;\n"
+ + " padding-left: 3px;\n" + " padding-right: 3px;\n"
+ + " padding-top: 10px;\n" + " white-space: nowrap;\n"
+ + " -webkit-transform: rotate(-90deg); \n"
+ + " -moz-transform: rotate(-90deg); \n" + " };\n"
+ + "</style>\n";
+ // Closes the elements opened by HEADER, innermost (body) first so the page is well nested.
+ private static final String FOOTER = "</body></html>";
+
+ // CSS style names.
+ // These values are written into class attributes by the table printing methods of this class.
+ private static final String CSS_TABLE = "table";
+ private static final String CSS_LABEL = "label";
+ // NOTE(review): the style sheet in HEADER declares "th.tallHeader", not "tall" - confirm which is intended.
+ private static final String CSS_TALL_HEADER = "tall";
+ private static final String CSS_VERTICAL = "verticalText";
+ private static final String CSS_CELL = "cell";
+ private static final String CSS_EMPTY = "empty";
+ // Indexed by the rating that printGrayCell computes for a non-empty cell.
+ private static final String[] CSS_GRAY_CELLS = {"white", "gray1", "gray2", "gray3", "black"};
+
+ // Instances are only created by main().
+ private ConfusionMatrixDumper() {}
+
+ /**
+  * Command-line entry point: runs a new dumper through {@link ToolRunner} with the given arguments.
+  * The value returned by {@code ToolRunner.run} is not used.
+  */
+ public static void main(String[] args) throws Exception {
+ ToolRunner.run(new ConfusionMatrixDumper(), args);
+ }
+
+ /**
+  * Parses the command line and dumps the confusion matrix found at the input path.
+  * With {@code --text} a plain-text summary is written, otherwise the HTML tables, wrapped in
+  * a complete page when {@code --html} is given. Output goes to the {@code --output} file or,
+  * without that option, to {@code System.out}.
+  *
+  * @param args command-line arguments
+  * @return -1 if the arguments could not be parsed, 0 otherwise
+  * @throws IOException if reading the input or opening the output fails
+  */
+ @Override
+ public int run(String[] args) throws IOException {
+   addInputOption();
+   addOption("output", "o", "Output path", null); // AbstractJob output feature requires param
+   addOption(DefaultOptionCreator.overwriteOption().create());
+   addFlag("html", null, "Create complete HTML page");
+   addFlag("text", null, "Dump simple text");
+   Map<String,List<String>> parsedArgs = parseArguments(args);
+   if (parsedArgs == null) {
+     return -1;
+   }
+
+   Path inputPath = getInputPath();
+   String outputFile = hasOption("output") ? getOption("output") : null;
+   boolean text = parsedArgs.containsKey("--text");
+   boolean wrapHtml = parsedArgs.containsKey("--html");
+   PrintStream out = getPrintStream(outputFile);
+   try {
+     if (text) {
+       exportText(inputPath, out);
+     } else {
+       exportTable(inputPath, out, wrapHtml);
+     }
+   } finally {
+     // Flush and release the output even when the export fails; System.out is never closed.
+     out.flush();
+     if (out != System.out) {
+       out.close();
+     }
+   }
+   return 0;
+ }
+
+ /**
+  * Writes the plain-text report: a header line, one line per label with its total count,
+  * correct count and rounded accuracy, and finally the matrix's own toString() output.
+  * Columns are left-aligned, padded and separated by {@code TAB_SEPARATOR}.
+  *
+  * @param inputPath sequence file holding the confusion matrix
+  * @param out destination of the report
+  * @throws IOException if the matrix cannot be read
+  */
+ private static void exportText(Path inputPath, PrintStream out) throws IOException {
+   Text key = new Text();
+   MatrixWritable matrixWritable = new MatrixWritable();
+   readSeqFile(inputPath, key, matrixWritable);
+   ConfusionMatrix cm = new ConfusionMatrix(matrixWritable.get());
+
+   // 70 dashes, used above and below the per-label rows
+   String rule = String.format("%-70s", "-").replace(' ', '-');
+
+   StringBuilder header = new StringBuilder();
+   header.append(String.format("%-40s", "Label")).append(TAB_SEPARATOR);
+   header.append(String.format("%-10s", "Total")).append(TAB_SEPARATOR);
+   header.append(String.format("%-10s", "Correct")).append(TAB_SEPARATOR);
+   header.append(String.format("%-6s", "%")).append(TAB_SEPARATOR);
+   out.println(header.toString());
+   out.println(rule);
+
+   for (String label : stripDefault(cm)) {
+     int correct = cm.getCorrect(label);
+     double accuracy = cm.getAccuracy(label);
+     int count = getCount(cm, label);
+     StringBuilder row = new StringBuilder();
+     row.append(String.format("%-40s", label)).append(TAB_SEPARATOR);
+     row.append(String.format("%-10s", count)).append(TAB_SEPARATOR);
+     row.append(String.format("%-10s", correct)).append(TAB_SEPARATOR);
+     row.append(String.format("%-6s", (int) Math.round(accuracy))).append(TAB_SEPARATOR);
+     out.println(row.toString());
+   }
+   out.println(rule);
+   out.println(cm.toString());
+ }
+
+ /**
+  * Writes the HTML report: summary table, gray-scale table, counts table and the matrix's
+  * toString() output in a box, each followed by a paragraph tag.
+  *
+  * @param inputPath sequence file holding the confusion matrix; its name is used as page title
+  * @param out destination of the report
+  * @param wrapHtml whether to surround the tables with the HTML page header and footer
+  * @throws IOException if the matrix cannot be read
+  */
+ private static void exportTable(Path inputPath, PrintStream out, boolean wrapHtml) throws IOException {
+   Text key = new Text();
+   MatrixWritable matrixWritable = new MatrixWritable();
+   readSeqFile(inputPath, key, matrixWritable);
+   String name = inputPath.getName();
+   String title = name.substring(name.lastIndexOf('/') + 1);
+   ConfusionMatrix cm = new ConfusionMatrix(matrixWritable.get());
+   if (wrapHtml) {
+     printHeader(out, title);
+   }
+   out.println("<p/>");
+   printSummaryTable(cm, out);
+   out.println("<p/>");
+   printGrayTable(cm, out);
+   out.println("<p/>");
+   printCountsTable(cm, out);
+   out.println("<p/>");
+   printTextInBox(cm, out);
+   out.println("<p/>");
+   if (wrapHtml) {
+     printFooter(out);
+   }
+ }
+
+ /**
+  * Returns a fresh list of the matrix's labels, leaving out the default label unless the
+  * matrix reports a total greater than zero for it.
+  *
+  * @param cm the confusion matrix
+  * @return a new, modifiable list of labels
+  */
+ private static List<String> stripDefault(ConfusionMatrix cm) {
+   List<String> labels = Lists.newArrayList(cm.getLabels().iterator());
+   String defaultLabel = cm.getDefaultLabel();
+   if (cm.getTotal(defaultLabel) <= 0) {
+     labels.remove(defaultLabel);
+   }
+   return labels;
+ }
+
+ // TODO: test - this should work with HDFS files
+ /**
+  * Reads the first key/value pair of the sequence file at {@code path} into the given objects.
+  * The file system is resolved from the path itself, so a path that names its own scheme is
+  * read from that file system. The reader is closed before returning.
+  *
+  * @param path sequence file to read
+  * @param key receives the key of the first record
+  * @param m receives the value of the first record
+  * @throws IOException if the file cannot be read or contains no record
+  */
+ private static void readSeqFile(Path path, Text key, MatrixWritable m) throws IOException {
+   Configuration conf = new Configuration();
+   FileSystem fs = path.getFileSystem(conf);
+   try (SequenceFile.Reader reader = new SequenceFile.Reader(fs, path, conf)) {
+     if (!reader.next(key, m)) {
+       throw new IOException("No confusion matrix record found in " + path);
+     }
+   }
+ }
+
+ // TODO: test - this might not work with HDFS files?
+ // after all, it does no seeks
+ /**
+  * Opens the stream the report is written to.
+  *
+  * @param outputFilename local file to write, or null to use {@code System.out}
+  * @return a UTF-8 {@link PrintStream} on the file, or {@code System.out} itself
+  * @throws IOException if the file cannot be opened for writing
+  */
+ private static PrintStream getPrintStream(String outputFilename) throws IOException {
+   if (outputFilename == null) {
+     return System.out;
+   }
+   // FileOutputStream creates the file or truncates an existing one, so no separate
+   // delete/create step (whose result would have to be checked) is needed.
+   OutputStream os = new FileOutputStream(new File(outputFilename));
+   // name() is the canonical charset name; displayName() is meant for display to users.
+   return new PrintStream(os, false, Charsets.UTF_8.name());
+ }
+
+ /**
+  * Sums the counts in the row of {@code rowLabel} over all labels of the matrix.
+  * This is the same computation as {@code getCount}, so it simply delegates to it.
+  *
+  * @param cm the confusion matrix
+  * @param rowLabel label whose row is summed
+  * @return the row total
+  */
+ private static int getLabelTotal(ConfusionMatrix cm, String rowLabel) {
+   return getCount(cm, rowLabel);
+ }
+
+ // HTML generator code
+
+ /**
+  * Prints the matrix's toString() output inside a {@code pre} element within a scrolling {@code div}.
+  * The text is written as is, without HTML escaping.
+  */
+ private static void printTextInBox(ConfusionMatrix cm, PrintStream out) {
+ out.println("<div style='width:90%;overflow:scroll;'>");
+ out.println("<pre>");
+ out.println(cm.toString());
+ out.println("</pre>");
+ out.println("</div>");
+ }
+
+ /**
+  * Prints an HTML table with one row per label (see {@code stripDefault}) showing the
+  * label, its total count, its correct count and its accuracy.
+  *
+  * @param cm the confusion matrix
+  * @param out destination of the HTML
+  */
+ public static void printSummaryTable(ConfusionMatrix cm, PrintStream out) {
+   format("<table class='%s'>\n", out, CSS_TABLE);
+   format("<tr class='%s'>", out, CSS_LABEL);
+   out.println("<td>Label</td><td>Total</td><td>Correct</td><td>%</td>");
+   out.println("</tr>");
+   for (String rowLabel : stripDefault(cm)) {
+     printSummaryRow(cm, out, rowLabel);
+   }
+   out.println("</table>");
+ }
+
+ /**
+  * Prints one row of the summary table: the label, its total count, its correct count and
+  * its accuracy rounded to an int.
+  */
+ private static void printSummaryRow(ConfusionMatrix cm, PrintStream out, String label) {
+ format("<tr class='%s'>", out, CSS_CELL);
+ int correct = cm.getCorrect(label);
+ double accuracy = cm.getAccuracy(label);
+ int count = getCount(cm, label);
+ format("<td class='%s'>%s</td><td>%d</td><td>%d</td><td>%d</td>", out, CSS_CELL, label, count, correct,
+ (int) Math.round(accuracy));
+ out.println("</tr>");
+ }
+
+ /**
+  * Adds up {@code cm.getCount(label, other)} over every label of the matrix.
+  *
+  * @param cm the confusion matrix
+  * @param label label passed as first argument of each lookup
+  * @return the sum of the counts
+  */
+ private static int getCount(ConfusionMatrix cm, String label) {
+   int total = 0;
+   Iterator<String> others = cm.getLabels().iterator();
+   while (others.hasNext()) {
+     total += cm.getCount(label, others.next());
+   }
+   return total;
+ }
+
+ /**
+  * Prints the gray-scale HTML table: a header row with vertical labels followed by one
+  * shaded row per label.
+  *
+  * @param cm the confusion matrix
+  * @param out destination of the HTML
+  */
+ public static void printGrayTable(ConfusionMatrix cm, PrintStream out) {
+ format("<table class='%s'>\n", out, CSS_TABLE);
+ printCountsHeader(cm, out, true);
+ printGrayRows(cm, out);
+ out.println("</table>");
+ }
+
+ /**
+  * Prints one gray-scale row per label. Each non-empty cell is shaded with one of the
+  * {@code CSS_GRAY_CELLS} classes according to count/total of its row. Gives a mostly white matrix
+  * with grays in misclassified, and black in diagonal.
+  * TODO: Using the sqrt(count/max) as the rating is more stringent
+  */
+ private static void printGrayRows(ConfusionMatrix cm, PrintStream out) {
+   List<String> labels = stripDefault(cm);
+   for (int row = 0; row < labels.size(); row++) {
+     printGrayRow(cm, out, labels, labels.get(row));
+   }
+ }
+
+ /**
+  * Prints one row of the gray-scale table: the row label followed by a shaded cell for
+  * every column label.
+  *
+  * @param cm the confusion matrix
+  * @param out destination of the HTML
+  * @param labels column labels, in output order
+  * @param rowLabel label of the row being printed
+  */
+ private static void printGrayRow(ConfusionMatrix cm,
+                                  PrintStream out,
+                                  Iterable<String> labels,
+                                  String rowLabel) {
+   format("<tr class='%s'>", out, CSS_LABEL);
+   format("<td>%s</td>", out, rowLabel);
+   // every cell of the row is rated against the same row total
+   int rowTotal = getLabelTotal(cm, rowLabel);
+   Iterator<String> columns = labels.iterator();
+   while (columns.hasNext()) {
+     printGrayCell(cm, out, rowTotal, rowLabel, columns.next());
+   }
+   out.println("</tr>");
+ }
+
+ // assign white/light/medium/dark to 0,1/4,1/2,3/4 of total number of inputs
+ // assign black to count = total, meaning complete success
+ // alternative rating is to use sqrt(total) instead of total - this is more drastic
+ /**
+  * Prints one cell of the gray-scale table. A zero count gives an empty cell; any other
+  * count is shown in a cell whose CSS class is picked from {@code CSS_GRAY_CELLS} by
+  * {@code (int) (count / total * 4)}.
+  *
+  * @param cm the confusion matrix
+  * @param out destination of the HTML
+  * @param total total of the row, used as the denominator of the rating
+  * @param rowLabel label of the cell's row
+  * @param columnLabel label of the cell's column, also used as the cell's title
+  */
+ private static void printGrayCell(ConfusionMatrix cm,
+                                   PrintStream out,
+                                   int total,
+                                   String rowLabel,
+                                   String columnLabel) {
+   int count = cm.getCount(rowLabel, columnLabel);
+   if (count == 0) {
+     out.format("<td class='%s'/>", CSS_EMPTY);
+     return;
+   }
+   // 0 is white, full is black, everything else gray
+   double fraction = count / (double) total;
+   int rating = (int) (fraction * 4);
+   format("<td class='%s' title='%s'>%s</td>", out, CSS_GRAY_CELLS[rating], columnLabel, count);
+ }
+
+ /**
+  * Prints the HTML table of raw counts: a header row with horizontal labels followed by
+  * one row per label.
+  *
+  * @param cm the confusion matrix
+  * @param out destination of the HTML
+  */
+ public static void printCountsTable(ConfusionMatrix cm, PrintStream out) {
+ format("<table class='%s'>\n", out, CSS_TABLE);
+ printCountsHeader(cm, out, false);
+ printCountsRows(cm, out);
+ out.println("</table>");
+ }
+
+ /**
+  * Prints one row of the counts table for every label returned by {@code stripDefault}.
+  */
+ private static void printCountsRows(ConfusionMatrix cm, PrintStream out) {
+   List<String> labels = stripDefault(cm);
+   for (int row = 0; row < labels.size(); row++) {
+     printCountsRow(cm, out, labels, labels.get(row));
+   }
+ }
+
+ /**
+  * Prints one row of the counts table: the row label followed by a count cell for every
+  * column label.
+  *
+  * @param cm the confusion matrix
+  * @param out destination of the HTML
+  * @param labels column labels, in output order
+  * @param rowLabel label of the row being printed
+  */
+ private static void printCountsRow(ConfusionMatrix cm,
+                                    PrintStream out,
+                                    Iterable<String> labels,
+                                    String rowLabel) {
+   out.println("<tr>");
+   format("<td class='%s'>%s</td>", out, CSS_LABEL, rowLabel);
+   Iterator<String> columns = labels.iterator();
+   while (columns.hasNext()) {
+     printCountsCell(cm, out, rowLabel, columns.next());
+   }
+   out.println("</tr>");
+ }
+
+ /**
+  * Prints one cell of the counts table. The cell shows the count, or nothing when the
+  * count is zero; the column label is used as the cell's title.
+  */
+ private static void printCountsCell(ConfusionMatrix cm, PrintStream out, String rowLabel, String columnLabel) {
+   int count = cm.getCount(rowLabel, columnLabel);
+   String shown;
+   if (count == 0) {
+     shown = "";
+   } else {
+     shown = Integer.toString(count);
+   }
+   format("<td class='%s' title='%s'>%s</td>", out, CSS_CELL, columnLabel, shown);
+ }
+
+ /**
+  * Prints the header row of a table, one entry per label returned by {@code stripDefault}.
+  *
+  * @param cm the confusion matrix
+  * @param out destination of the HTML
+  * @param vertical true to print the labels in {@code th} cells wrapped in a div of class
+  *          {@code CSS_VERTICAL}, false to print them as plain {@code td} cells
+  */
+ private static void printCountsHeader(ConfusionMatrix cm, PrintStream out, boolean vertical) {
+   List<String> labels = stripDefault(cm);
+   int longest = getLongestHeader(labels);
+   if (!vertical) {
+     // header - empty cell in upper left
+     out.format("<tr class='%s'><td class='%s'></td>%n", CSS_TABLE, CSS_LABEL);
+     for (String label : labels) {
+       out.format("<td>%s</td>", label);
+     }
+     out.format("</tr>");
+     return;
+   }
+   // vertical labels: the row height (in em) is half the length of the longest label
+   out.format("<tr class='%s' style='height:%dem'><th> </th>%n", CSS_TALL_HEADER, longest / 2);
+   for (String label : labels) {
+     out.format("<th><div class='%s'>%s</div></th>", CSS_VERTICAL, label);
+   }
+   out.println("</tr>");
+ }
+
+ /**
+  * @param labels the labels to measure
+  * @return the length of the longest label, or 0 if there are no labels
+  */
+ private static int getLongestHeader(Iterable<String> labels) {
+   int longest = 0;
+   for (String label : labels) {
+     int length = label.length();
+     if (length > longest) {
+       longest = length;
+     }
+   }
+   return longest;
+ }
+
+ /**
+  * Formats {@code args} with {@code format} and prints the result followed by a line break.
+  * Note the argument order: the stream comes between the format string and its arguments.
+  */
+ private static void format(String format, PrintStream out, Object... args) {
+   out.println(String.format(format, args));
+ }
+
+ /**
+  * Prints {@code HEADER} with every occurrence of the text TITLE replaced by {@code title}.
+  * The title is inserted as is, without HTML escaping.
+  *
+  * @param out destination of the HTML
+  * @param title page title
+  */
+ public static void printHeader(PrintStream out, CharSequence title) {
+ out.println(HEADER.replace("TITLE", title));
+ }
+
+ /**
+  * Prints {@code FOOTER}, the closing counterpart of {@link #printHeader}.
+  *
+  * @param out destination of the HTML
+  */
+ public static void printFooter(PrintStream out) {
+ out.println(FOOTER);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/community/mahout-mr/integration/src/main/java/org/apache/mahout/clustering/cdbw/CDbwEvaluator.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/integration/src/main/java/org/apache/mahout/clustering/cdbw/CDbwEvaluator.java b/community/mahout-mr/integration/src/main/java/org/apache/mahout/clustering/cdbw/CDbwEvaluator.java
new file mode 100644
index 0000000..545c1ff
--- /dev/null
+++ b/community/mahout-mr/integration/src/main/java/org/apache/mahout/clustering/cdbw/CDbwEvaluator.java
@@ -0,0 +1,387 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.cdbw;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.mahout.clustering.Cluster;
+import org.apache.mahout.clustering.GaussianAccumulator;
+import org.apache.mahout.clustering.OnlineGaussianAccumulator;
+import org.apache.mahout.clustering.evaluation.RepresentativePointsDriver;
+import org.apache.mahout.clustering.evaluation.RepresentativePointsMapper;
+import org.apache.mahout.clustering.iterator.ClusterWritable;
+import org.apache.mahout.common.ClassUtils;
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+import org.apache.mahout.common.iterator.sequencefile.PathType;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterable;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.Vector.Element;
+import org.apache.mahout.math.VectorWritable;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.TreeMap;
+
+/**
+ * This class calculates the CDbw metric as defined in
+ * http://www.db-net.aueb.gr/index.php/corporate/content/download/227/833/file/HV_poster2002.pdf
+ */
+public final class CDbwEvaluator {
+
+ private static final Logger log = LoggerFactory.getLogger(CDbwEvaluator.class);
+
+ // representative points of each cluster, keyed by clusterId
+ private final Map<Integer,List<VectorWritable>> representativePoints;
+ // average standard deviation of each cluster's representative points, keyed by clusterId (filled by computeStd)
+ private final Map<Integer,Double> stDevs = new HashMap<>();
+ private final List<Cluster> clusters;
+ private final DistanceMeasure measure;
+ // cached result of interClusterDensity(); null until it has been computed
+ private Double interClusterDensity = null;
+ // these are symmetric so we only compute half of them
+ private Map<Integer,Map<Integer,Double>> minimumDistances = null;
+ // these are symmetric too
+ // cached result of interClusterDensities(); null until it has been computed
+ private Map<Integer,Map<Integer,Double>> interClusterDensities = null;
+ // these are symmetric too
+ private Map<Integer,Map<Integer,int[]>> closestRepPointIndices = null;
+
+ /**
+  * For testing only. Builds an evaluator from in-memory data and computes the standard
+  * deviation of the representative points of every cluster in {@code representativePoints}.
+  *
+  * @param representativePoints
+  *          a {@code Map<Integer,List<VectorWritable>>} of representative points keyed by clusterId
+  * @param clusters
+  *          a {@code List<Cluster>} of the clusters
+  * @param measure
+  *          an appropriate DistanceMeasure
+  */
+ public CDbwEvaluator(Map<Integer,List<VectorWritable>> representativePoints, List<Cluster> clusters,
+     DistanceMeasure measure) {
+   this.measure = measure;
+   this.clusters = clusters;
+   this.representativePoints = representativePoints;
+   for (int clusterId : this.representativePoints.keySet()) {
+     computeStd(clusterId);
+   }
+ }
+
+ /**
+  * Initialize a new instance from job information: the distance measure named in the
+  * configuration is instantiated, the representative points are obtained from
+  * {@link RepresentativePointsMapper}, the clusters are loaded from {@code clustersIn} and
+  * the standard deviation of every cluster's representative points is computed.
+  *
+  * @param conf
+  *          a Configuration with appropriate parameters
+  * @param clustersIn
+  *          a Path to the input clusters directory
+  */
+ public CDbwEvaluator(Configuration conf, Path clustersIn) {
+   String measureClassName = conf.get(RepresentativePointsDriver.DISTANCE_MEASURE_KEY);
+   this.measure = ClassUtils.instantiateAs(measureClassName, DistanceMeasure.class);
+   this.representativePoints = RepresentativePointsMapper.getRepresentativePoints(conf);
+   this.clusters = loadClusters(conf, clustersIn);
+   for (int clusterId : this.representativePoints.keySet()) {
+     computeStd(clusterId);
+   }
+ }
+
+ /**
+  * Load the clusters from their sequence files
+  *
+  * @param conf
+  *          the Configuration used to read the files
+  * @param clustersIn
+  *          a Path to the directory containing input cluster files
+  * @return a {@code List<Cluster>} of the clusters, in the order they were read
+  */
+ private static List<Cluster> loadClusters(Configuration conf, Path clustersIn) {
+   Iterable<ClusterWritable> stored = new SequenceFileDirValueIterable<ClusterWritable>(clustersIn,
+       PathType.LIST, PathFilters.logsCRCFilter(), conf);
+   List<Cluster> loaded = new ArrayList<>();
+   for (ClusterWritable writable : stored) {
+     loaded.add(writable.getValue());
+   }
+   return loaded;
+ }
+
+ /**
+  * Compute the average standard deviation of the representative points of the given
+  * cluster, as reported by an {@link OnlineGaussianAccumulator} fed with each point at
+  * weight 1.0, and store it in stDevs under cI.
+  *
+  * @param cI
+  *          a int clusterId.
+  */
+ private void computeStd(int cI) {
+   GaussianAccumulator stats = new OnlineGaussianAccumulator();
+   for (VectorWritable point : representativePoints.get(cI)) {
+     stats.observe(point.get(), 1.0);
+   }
+   stats.compute();
+   stDevs.put(cI, stats.getAverageStd());
+ }
+
+ /**
+  * Compute the density of points near the midpoint between the two closest points of the clusters (eqn 2) used for
+  * inter-cluster density calculation
+  *
+  * @param uIJ
+  *          the Vector midpoint between the closest representative points of the clusters; if null, no point
+  *          is counted
+  * @param cI
+  *          the int clusterId of the i-th cluster
+  * @param cJ
+  *          the int clusterId of the j-th cluster
+  * @param avgStd
+  *          the double average standard deviation of the two clusters
+  * @return the fraction of the two clusters' representative points which lie within avgStd of uIJ
+  */
+ private double density(Vector uIJ, int cI, int cJ, double avgStd) {
+   List<VectorWritable> pointsI = representativePoints.get(cI);
+   List<VectorWritable> pointsJ = representativePoints.get(cJ);
+   // count the number of representative points of the clusters which are within the
+   // average std of the two clusters from the midpoint uIJ (eqn 3)
+   int within = 0;
+   if (uIJ != null) {
+     for (VectorWritable point : pointsI) {
+       if (measure.distance(uIJ, point.get()) <= avgStd) {
+         within++;
+       }
+     }
+     for (VectorWritable point : pointsJ) {
+       if (measure.distance(uIJ, point.get()) <= avgStd) {
+         within++;
+       }
+     }
+   }
+   return (double) within / (pointsI.size() + pointsJ.size());
+ }
+
+ /**
+  * Compute the CDbw validity metric (eqn 8). The goal of this metric is to reward clusterings which have a high
+  * intraClusterDensity and also a high cluster separation.
+  *
+  * @return the product of {@link #intraClusterDensity()} and {@link #separation()}
+  */
+ public double getCDbw() {
+ return intraClusterDensity() * separation();
+ }
+
+ /**
+  * The average density within clusters is defined as the percentage of representative points that reside in the
+  * neighborhood of the clusters' centers. The goal is the density within clusters to be significantly high. (eqn 5)
+  *
+  * @return the mean of the non-zero, non-NaN entries of {@link #intraClusterDensities()}
+  */
+ public double intraClusterDensity() {
+   double total = 0;
+   int valid = 0;
+   for (Element entry : intraClusterDensities().nonZeroes()) {
+     double density = entry.get();
+     if (Double.isNaN(density)) {
+       continue;
+     }
+     total += density;
+     valid++;
+   }
+   return total / valid;
+ }
+
+ /**
+  * This function evaluates the density of points in the regions between each pair of clusters (eqn 1). The goal
+  * is the density in the area between clusters to be significantly low.
+  * <p>
+  * The result is computed on the first call and cached. Only the pairs (i, j) with j after i in the cluster
+  * list are stored, since the values are symmetric.
+  *
+  * @return a {@code Map<Integer,Map<Integer,Double>>} of the inter-cluster densities, keyed by the two clusterIds
+  */
+ public Map<Integer,Map<Integer,Double>> interClusterDensities() {
+   if (interClusterDensities != null) {
+     return interClusterDensities;
+   }
+   interClusterDensities = new TreeMap<>();
+   // find the closest representative points between the clusters
+   for (int i = 0; i < clusters.size(); i++) {
+     int cI = clusters.get(i).getId();
+     Map<Integer,Double> map = new TreeMap<>();
+     interClusterDensities.put(cI, map);
+     for (int j = i + 1; j < clusters.size(); j++) {
+       int cJ = clusters.get(j).getId();
+       double minDistance = minimumDistance(cI, cJ); // the distance between the closest representative points
+       Vector uIJ = midpointVector(cI, cJ); // the midpoint between the closest representative points
+       double stdSum = stDevs.get(cI) + stDevs.get(cJ);
+       double density = density(uIJ, cI, cJ, stdSum / 2);
+       double interDensity = minDistance * density / stdSum;
+       map.put(cJ, interDensity);
+       if (log.isDebugEnabled()) {
+         // each label names the value that is actually logged next to it
+         log.debug("minDistance[{},{}]={}", cI, cJ, minDistance);
+         log.debug("density[{},{}]={}", cI, cJ, density);
+         log.debug("interDensity[{},{}]={}", cI, cJ, interDensity);
+       }
+     }
+   }
+   return interClusterDensities;
+ }
+
+ /**
+ * Calculate the separation of clusters (eqn 4) taking into account both the distances between the clusters' closest
+ * points and the Inter-cluster density. The goal is the distances between clusters to be high while the
+ * representative point density in the areas between them are low.
+ *
+ * @return a double
+ */
+ public double separation() {
+   double distanceTotal = 0;
+   for (Map<Integer,Double> row : minimumDistances().values()) {
+     for (Double distance : row.values()) {
+       // NOTE(review): minimumDistances() seeds its search with Double.MAX_VALUE, which this filter
+       // does not match -- confirm whether unset pairs are meant to be skipped here.
+       if (Double.isInfinite(distance)) {
+         continue;
+       }
+       // each pair appears once in the triangular table, so it is counted twice
+       distanceTotal += 2 * distance;
+     }
+   }
+   return distanceTotal / (1.0 + interClusterDensity());
+ }
+
+ /**
+ * This function evaluates the average density of points in the regions between clusters (eqn 1). The goal is for
+ * the density in the area between clusters to be significantly low.
+ *
+ * @return a double
+ */
+ public double interClusterDensity() {
+   // The average is computed once and then served from the cached field.
+   if (interClusterDensity != null) {
+     return interClusterDensity;
+   }
+   double sum = 0.0;
+   int count = 0;
+   Map<Integer,Map<Integer,Double>> densities = interClusterDensities();
+   for (Map<Integer,Double> row : densities.values()) {
+     for (Double density : row.values()) {
+       // NaN entries are left out of both the sum and the count
+       if (!Double.isNaN(density)) {
+         sum += density;
+         count++;
+       }
+     }
+   }
+   // With no usable entries this is 0.0 / 0, i.e. NaN, and that value is what gets cached.
+   interClusterDensity = sum / count;
+   // Log the value that is actually returned (previously the un-averaged sum was logged under this label).
+   log.debug("interClusterDensity={}", interClusterDensity);
+   return interClusterDensity;
+ }
+
+ /**
+ * The average density within clusters is defined as the percentage of representative points that reside in the
+ * neighborhood of the clusters' centers. The goal is the density within clusters to be significantly high. (eqn 5)
+ *
+ * @return a Vector of the intra-densities of each clusterId
+ */
+ public Vector intraClusterDensities() {
+   Vector result = new RandomAccessSparseVector(Integer.MAX_VALUE);
+   // average standard deviation over every cluster id that has representative points
+   double stdevSum = 0.0;
+   for (Integer clusterId : representativePoints.keySet()) {
+     stdevSum += stDevs.get(clusterId);
+   }
+   double avgStdev = stdevSum / representativePoints.size();
+   for (Cluster cluster : clusters) {
+     Integer clusterId = cluster.getId();
+     List<VectorWritable> repPoints = representativePoints.get(clusterId);
+     int pointCount = repPoints.size();
+     double termSum = 0.0;
+     // density term (eqn 6): each representative point contributes f(x, vIJ) / stdev
+     for (VectorWritable repPoint : repPoints) {
+       Vector point = repPoint.get();
+       // f(x, vIJ) (eqn 7): 1 when the point lies within one average stdev of the center, else 0
+       boolean nearCenter = measure.distance(cluster.getCenter(), point) <= avgStdev;
+       termSum += (nearCenter ? 1.0 : 0.0) / avgStdev;
+     }
+     result.set(clusterId, termSum / pointCount);
+   }
+   return result;
+ }
+
+ /**
+ * Calculate and cache the distances between the clusters' closest representative points. Also cache the indices of
+ * the closest representative points for later use.
+ *
+ * @return a Map<Integer,Map<Integer,Double>> of the closest distances, keyed by the ids of the two clusters
+ */
+ private Map<Integer,Map<Integer,Double>> minimumDistances() {
+ // Both tables are built on the first call; later calls return the cached distances map.
+ if (minimumDistances != null) {
+ return minimumDistances;
+ }
+ minimumDistances = new TreeMap<>();
+ closestRepPointIndices = new TreeMap<>();
+ for (int i = 0; i < clusters.size(); i++) {
+ Integer cI = clusters.get(i).getId();
+ // Every cluster gets a row in both tables, even if the row stays empty (as it does for the last cluster).
+ Map<Integer,Double> map = new TreeMap<>();
+ Map<Integer,int[]> treeMap = new TreeMap<>();
+ closestRepPointIndices.put(cI, treeMap);
+ minimumDistances.put(cI, map);
+ List<VectorWritable> closRepI = representativePoints.get(cI);
+ // Only pairs with j > i are computed, so an entry exists for (cI, cJ) but not for (cJ, cI).
+ for (int j = i + 1; j < clusters.size(); j++) {
+ // find min{d(closRepI, closRepJ)}
+ Integer cJ = clusters.get(j).getId();
+ List<VectorWritable> closRepJ = representativePoints.get(cJ);
+ // Starting values; they are kept as-is if either list of representative points is empty.
+ double minDistance = Double.MAX_VALUE;
+ int[] midPointIndices = null;
+ for (int xI = 0; xI < closRepI.size(); xI++) {
+ VectorWritable aRepI = closRepI.get(xI);
+ for (int xJ = 0; xJ < closRepJ.size(); xJ++) {
+ VectorWritable aRepJ = closRepJ.get(xJ);
+ double distance = measure.distance(aRepI.get(), aRepJ.get());
+ // Strict comparison: the first pair found at the smallest distance is the one remembered.
+ if (distance < minDistance) {
+ minDistance = distance;
+ // {index into cI's representative points, index into cJ's representative points}
+ midPointIndices = new int[] {xI, xJ};
+ }
+ }
+ }
+ map.put(cJ, minDistance);
+ treeMap.put(cJ, midPointIndices);
+ }
+ }
+ return minimumDistances;
+ }
+
+ /**
+  * Look up the cached distance between the closest representative points of two clusters.
+  *
+  * The cache holds each pair only once, under the cluster that comes first in the cluster list, so the lookup is
+  * tried as (cI, cJ) and then as (cJ, cI).
+  *
+  * @param cI the id of one cluster
+  * @param cJ the id of the other cluster
+  * @return the cached minimum distance for the pair
+  */
+ private double minimumDistance(int cI, int cJ) {
+   Map<Integer,Map<Integer,Double>> all = minimumDistances();
+   Map<Integer,Double> row = all.get(cI);
+   if (row != null) {
+     Double direct = row.get(cJ);
+     if (direct != null) {
+       return direct;
+     }
+   }
+   // Every cluster has a row, so a non-null row alone does not mean the pair is stored in this order;
+   // fall back to the mirrored entry.
+   return all.get(cJ).get(cI);
+ }
+
+ /**
+  * Compute the midpoint between the closest representative points of two clusters.
+  *
+  * The closest-point indices are cached once per pair, under the cluster that comes first in the cluster list, so
+  * the lookup is tried as (cI, cJ) and then as (cJ, cI).
+  *
+  * @param cI the id of one cluster
+  * @param cJ the id of the other cluster
+  * @return the midpoint vector, or null if no closest-point indices are cached for the pair
+  */
+ private Vector midpointVector(int cI, int cJ) {
+   // closestRepPointIndices is filled in as a side effect of minimumDistances()
+   minimumDistances();
+   int first = cI;
+   int second = cJ;
+   Map<Integer,int[]> row = closestRepPointIndices.get(cI);
+   int[] ks = row == null ? null : row.get(cJ);
+   if (ks == null) {
+     // Not stored as (cI, cJ): try the mirrored entry, whose indices are {index in cJ, index in cI}.
+     row = closestRepPointIndices.get(cJ);
+     ks = row == null ? null : row.get(cI);
+     first = cJ;
+     second = cI;
+   }
+   if (ks == null) {
+     return null;
+   }
+   // ks[0] always indexes the cluster the entry was stored under, ks[1] the other one
+   Vector closestInFirst = representativePoints.get(first).get(ks[0]).get();
+   Vector closestInSecond = representativePoints.get(second).get(ks[1]).get();
+   return closestInFirst.plus(closestInSecond).divide(2);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/community/mahout-mr/integration/src/main/java/org/apache/mahout/clustering/conversion/InputDriver.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/integration/src/main/java/org/apache/mahout/clustering/conversion/InputDriver.java b/community/mahout-mr/integration/src/main/java/org/apache/mahout/clustering/conversion/InputDriver.java
new file mode 100644
index 0000000..6a2b376
--- /dev/null
+++ b/community/mahout-mr/integration/src/main/java/org/apache/mahout/clustering/conversion/InputDriver.java
@@ -0,0 +1,114 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.conversion;
+
+import java.io.IOException;
+
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.OptionException;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.mahout.common.CommandLineUtil;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+import org.apache.mahout.math.VectorWritable;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * This class converts text files containing space-delimited floating point numbers into
+ * Mahout sequence files of VectorWritable suitable for input to the clustering jobs in
+ * particular, and any Mahout job requiring this input in general.
+ *
+ */
+public final class InputDriver {
+
+ private static final Logger log = LoggerFactory.getLogger(InputDriver.class);
+
+ private InputDriver() {
+ }
+
+ public static void main(String[] args) throws IOException, InterruptedException, ClassNotFoundException {
+ DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
+ ArgumentBuilder abuilder = new ArgumentBuilder();
+ GroupBuilder gbuilder = new GroupBuilder();
+
+ Option inputOpt = DefaultOptionCreator.inputOption().withRequired(false).create();
+ Option outputOpt = DefaultOptionCreator.outputOption().withRequired(false).create();
+ Option vectorOpt = obuilder.withLongName("vector").withRequired(false).withArgument(
+ abuilder.withName("v").withMinimum(1).withMaximum(1).create()).withDescription(
+ "The vector implementation to use.").withShortName("v").create();
+
+ Option helpOpt = DefaultOptionCreator.helpOption();
+
+ Group group = gbuilder.withName("Options").withOption(inputOpt).withOption(outputOpt).withOption(
+ vectorOpt).withOption(helpOpt).create();
+
+ try {
+ Parser parser = new Parser();
+ parser.setGroup(group);
+ CommandLine cmdLine = parser.parse(args);
+ if (cmdLine.hasOption(helpOpt)) {
+ CommandLineUtil.printHelp(group);
+ return;
+ }
+
+ Path input = new Path(cmdLine.getValue(inputOpt, "testdata").toString());
+ Path output = new Path(cmdLine.getValue(outputOpt, "output").toString());
+ String vectorClassName = cmdLine.getValue(vectorOpt,
+ "org.apache.mahout.math.RandomAccessSparseVector").toString();
+ runJob(input, output, vectorClassName);
+ } catch (OptionException e) {
+ log.error("Exception parsing command line: ", e);
+ CommandLineUtil.printHelp(group);
+ }
+ }
+
+ public static void runJob(Path input, Path output, String vectorClassName)
+ throws IOException, InterruptedException, ClassNotFoundException {
+ Configuration conf = new Configuration();
+ conf.set("vector.implementation.class.name", vectorClassName);
+ Job job = new Job(conf, "Input Driver running over input: " + input);
+
+ job.setOutputKeyClass(Text.class);
+ job.setOutputValueClass(VectorWritable.class);
+ job.setOutputFormatClass(SequenceFileOutputFormat.class);
+ job.setMapperClass(InputMapper.class);
+ job.setNumReduceTasks(0);
+ job.setJarByClass(InputDriver.class);
+
+ FileInputFormat.addInputPath(job, input);
+ FileOutputFormat.setOutputPath(job, output);
+
+ boolean succeeded = job.waitForCompletion(true);
+ if (!succeeded) {
+ throw new IllegalStateException("Job failed!");
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/community/mahout-mr/integration/src/main/java/org/apache/mahout/clustering/conversion/InputMapper.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/integration/src/main/java/org/apache/mahout/clustering/conversion/InputMapper.java b/community/mahout-mr/integration/src/main/java/org/apache/mahout/clustering/conversion/InputMapper.java
new file mode 100644
index 0000000..e4c72c6
--- /dev/null
+++ b/community/mahout-mr/integration/src/main/java/org/apache/mahout/clustering/conversion/InputMapper.java
@@ -0,0 +1,81 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.conversion;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+import java.io.IOException;
+import java.lang.reflect.Constructor;
+import java.lang.reflect.InvocationTargetException;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.regex.Pattern;
+
+/**
+ * Mapper that turns each non-empty line of space-separated numbers into a VectorWritable. The vector class is
+ * read from the "vector.implementation.class.name" configuration entry.
+ */
+public class InputMapper extends Mapper<LongWritable, Text, Text, VectorWritable> {
+
+  private static final Pattern SPACE = Pattern.compile(" ");
+
+  /** The (int) constructor of the configured vector class, resolved in setup(). */
+  private Constructor<? extends Vector> constructor;
+
+  /**
+   * Resolve the configured vector class and its single-int constructor.
+   *
+   * @throws IllegalStateException if the class or the constructor cannot be found
+   */
+  @Override
+  protected void setup(Context context) throws IOException, InterruptedException {
+    super.setup(context);
+    Configuration conf = context.getConfiguration();
+    String vectorImplClassName = conf.get("vector.implementation.class.name");
+    try {
+      constructor = conf.getClassByName(vectorImplClassName).asSubclass(Vector.class).getConstructor(int.class);
+    } catch (NoSuchMethodException | ClassNotFoundException e) {
+      throw new IllegalStateException(e);
+    }
+  }
+
+  /**
+   * Parse one line into doubles and emit them as a vector; lines with no numbers emit nothing.
+   * The output key is the number of values on the line, rendered as text.
+   */
+  @Override
+  protected void map(LongWritable key, Text values, Context context) throws IOException, InterruptedException {
+    Collection<Double> parsed = new ArrayList<>();
+    for (String token : SPACE.split(values.toString())) {
+      // consecutive separator spaces produce empty tokens; skip them
+      if (token.isEmpty()) {
+        continue;
+      }
+      parsed.add(Double.valueOf(token));
+    }
+    // ignore empty lines in data file
+    if (parsed.isEmpty()) {
+      return;
+    }
+
+    Vector vector;
+    try {
+      vector = constructor.newInstance(parsed.size());
+    } catch (InstantiationException | IllegalAccessException | InvocationTargetException e) {
+      throw new IllegalStateException(e);
+    }
+    int position = 0;
+    for (Double number : parsed) {
+      vector.set(position++, number);
+    }
+    context.write(new Text(String.valueOf(position)), new VectorWritable(vector));
+  }
+
+}