You are viewing a plain-text version of this content; the link to its canonical version was not preserved in this copy.
Posted to commits@ignite.apache.org by is...@apache.org on 2017/08/07 14:24:59 UTC
ignite git commit: IGNITE-5880: BLAS integration phase 2
Repository: ignite
Updated Branches:
refs/heads/master 9b730e732 -> 488a6d271
IGNITE-5880: BLAS integration phase 2
Project: http://git-wip-us.apache.org/repos/asf/ignite/repo
Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/488a6d27
Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/488a6d27
Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/488a6d27
Branch: refs/heads/master
Commit: 488a6d271bbeb60d424745fb20ef856c624e2fcb
Parents: 9b730e7
Author: Yury Babak <yb...@gridgain.com>
Authored: Mon Aug 7 17:24:21 2017 +0300
Committer: Igor Sapego <is...@gridgain.com>
Committed: Mon Aug 7 17:24:21 2017 +0300
----------------------------------------------------------------------
.../java/org/apache/ignite/ml/math/Blas.java | 157 ++++---------------
.../ml/math/impls/matrix/AbstractMatrix.java | 3 +-
.../storage/vector/MatrixVectorStorage.java | 11 ++
.../vector/SparseLocalOnHeapVectorStorage.java | 9 ++
.../impls/matrix/MatrixImplementationsTest.java | 5 +-
.../RandomAccessSparseVectorStorageTest.java | 4 +-
6 files changed, 56 insertions(+), 133 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/ignite/blob/488a6d27/modules/ml/src/main/java/org/apache/ignite/ml/math/Blas.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/Blas.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/Blas.java
index 29312e5..a61d796 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/math/Blas.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/Blas.java
@@ -19,8 +19,6 @@ package org.apache.ignite.ml.math;
import com.github.fommil.netlib.BLAS;
import com.github.fommil.netlib.F2jBLAS;
-import it.unimi.dsi.fastutil.ints.IntIterator;
-import it.unimi.dsi.fastutil.ints.IntSet;
import java.util.Set;
import org.apache.ignite.ml.math.exceptions.CardinalityException;
import org.apache.ignite.ml.math.exceptions.MathIllegalArgumentException;
@@ -30,7 +28,10 @@ import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix;
import org.apache.ignite.ml.math.impls.matrix.SparseBlockDistributedMatrix;
import org.apache.ignite.ml.math.impls.matrix.SparseDistributedMatrix;
import org.apache.ignite.ml.math.impls.matrix.SparseLocalOnHeapMatrix;
+import org.apache.ignite.ml.math.impls.vector.CacheVector;
+import org.apache.ignite.ml.math.impls.vector.DenseLocalOffHeapVector;
import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
+import org.apache.ignite.ml.math.impls.vector.SparseLocalOffHeapVector;
import org.apache.ignite.ml.math.impls.vector.SparseLocalVector;
import org.apache.ignite.ml.math.util.MatrixUtil;
@@ -244,9 +245,9 @@ public class Blas {
else if (alpha == 0.0)
scal(c, beta);
else {
- checkTypes(a, "gemm");
- checkTypes(b, "gemm");
- checkTypes(c, "gemm");
+ checkMatrixType(a, "gemm");
+ checkMatrixType(b, "gemm");
+ checkMatrixType(c, "gemm");
double[] fA = a.getStorage().data();
double[] fB = b.getStorage().data();
@@ -265,7 +266,7 @@ public class Blas {
/**
* Currently we support only local onheap matrices for BLAS.
*/
- private static void checkTypes(Matrix a, String op){
+ private static void checkMatrixType(Matrix a, String op){
if (a instanceof DenseLocalOffHeapMatrix || a instanceof SparseDistributedMatrix
|| a instanceof SparseBlockDistributedMatrix)
throw new IllegalArgumentException("Operation doesn't support for matrix [class="
@@ -273,37 +274,12 @@ public class Blas {
}
/**
- * y := alpha * A * x + beta * y.
- *
- * @param alpha Alpha.
- * @param a Matrix a.
- * @param x Vector x.
- * @param beta Beta.
- * @param y Vector y.
+ * Currently we support only local onheap vectors for BLAS.
*/
- public static void gemv(double alpha, Matrix a, Vector x, double beta, DenseLocalOnHeapVector y) {
- checkCardinality(a, x);
- checkCardinality(a, y);
-
- if (alpha == 0.0 && beta == 1.0)
- return;
-
- if (alpha == 0.0) {
- scal(y, beta);
- return;
- }
-
- if (a instanceof SparseLocalOnHeapMatrix && x instanceof DenseLocalOnHeapVector)
- gemv(alpha, (SparseLocalOnHeapMatrix)a, (DenseLocalOnHeapVector)x, beta, y);
- else if (a instanceof SparseLocalOnHeapMatrix && x instanceof SparseLocalVector)
- gemv(alpha, (SparseLocalOnHeapMatrix)a, (SparseLocalVector)x, beta, y);
- else if (a instanceof DenseLocalOnHeapMatrix && x instanceof DenseLocalOnHeapVector)
- gemv(alpha, (DenseLocalOnHeapMatrix)a, (DenseLocalOnHeapVector)x, beta, y);
- else if (a instanceof DenseLocalOnHeapMatrix && x instanceof SparseLocalVector)
- gemv(alpha, (DenseLocalOnHeapMatrix)a, (SparseLocalVector)x, beta, y);
- else
- throw new IllegalArgumentException("Operation gemv doesn't support running thist input [matrix=" +
- a.getClass().getSimpleName() + ", vector=" + x.getClass().getSimpleName()+"].");
+ private static void checkVectorType(Vector a, String op){
+ if (a instanceof DenseLocalOffHeapVector || a instanceof SparseLocalOffHeapVector || a instanceof CacheVector)
+ throw new IllegalArgumentException("Operation doesn't support for vector [class="
+ + a.getClass().getName() + ", operation="+op+"].");
}
/**
@@ -315,106 +291,32 @@ public class Blas {
* @param beta Beta.
* @param y Vector y.
*/
- private static void gemv(double alpha, SparseLocalOnHeapMatrix a, DenseLocalOnHeapVector x, double beta,
- DenseLocalOnHeapVector y) {
-
- if (beta != 1.0)
- scal(y, beta);
-
- IntIterator rowIter = a.indexesMap().keySet().iterator();
- while (rowIter.hasNext()) {
- int row = rowIter.nextInt();
-
- double sum = 0.0;
- IntIterator colIter = a.indexesMap().get(row).iterator();
- while (colIter.hasNext()) {
- int col = colIter.nextInt();
- sum += alpha * a.getX(row, col) * x.getX(col);
- }
-
- y.setX(row, y.getX(row) + sum);
- }
- }
+ public static void gemv(double alpha, Matrix a, Vector x, double beta, Vector y) {
+ checkCardinality(a, x);
- /**
- * y := alpha * A * x + beta * y.
- *
- * @param alpha Alpha.
- * @param a Matrix a.
- * @param x Vector x.
- * @param beta Beta.
- * @param y Vector y.
- */
- private static void gemv(double alpha, DenseLocalOnHeapMatrix a, DenseLocalOnHeapVector x, double beta,
- DenseLocalOnHeapVector y) {
- nativeBlas.dgemv("N", a.rowSize(), a.columnSize(), alpha, a.getStorage().data(), a.rowSize(), x.getStorage().data(), 1, beta,
- y.getStorage().data(), 1);
- }
+ if (a.rowSize() != y.size())
+ throw new CardinalityException(a.columnSize(), y.size());
- /**
- * y := alpha * A * x + beta * y.
- *
- * @param alpha Alpha.
- * @param a Matrix a.
- * @param x Vector x.
- * @param beta Beta.
- * @param y Vector y.
- */
- private static void gemv(double alpha, SparseLocalOnHeapMatrix a, SparseLocalVector x, double beta,
- DenseLocalOnHeapVector y) {
+ checkMatrixType(a, "gemv");
+ checkVectorType(x,"gemv");
+ checkVectorType(y, "gemv");
+ if (alpha == 0.0 && beta == 1.0)
+ return;
- if (beta != 1.0)
+ if (alpha == 0.0) {
scal(y, beta);
-
- IntIterator rowIter = a.indexesMap().keySet().iterator();
- while (rowIter.hasNext()) {
- int row = rowIter.nextInt();
-
- double sum = 0.0;
- IntIterator colIter = a.indexesMap().get(row).iterator();
- while (colIter.hasNext()) {
- int col = colIter.nextInt();
-
- sum += alpha * a.getX(row, col) * x.getX(col);
- }
-
- y.set(row, y.get(row) + sum);
+ return;
}
- }
-
- /**
- * y := alpha * A * x + beta * y.
- *
- * @param alpha Alpha.
- * @param a Matrix a.
- * @param x Vector x.
- * @param beta Beta.
- * @param y Vector y.
- */
- private static void gemv(double alpha, DenseLocalOnHeapMatrix a, SparseLocalVector x, double beta,
- DenseLocalOnHeapVector y) {
- int rowCntrForA = 0;
- int mA = a.rowSize();
-
- double[] aData = a.getStorage().data();
-
- IntSet indexes = x.indexes();
- double[] yValues = y.getStorage().data();
+ double[] fA = a.getStorage().data();
+ double[] fX = x.getStorage().data();
+ double[] fY = y.getStorage().data();
- while (rowCntrForA < mA) {
- double sum = 0.0;
+ nativeBlas.dgemv("N", a.rowSize(), a.columnSize(), alpha, fA, a.rowSize(), fX, 1, beta, fY, 1);
- IntIterator iter = indexes.iterator();
- while (iter.hasNext()) {
- int xIdx = iter.nextInt();
- sum += x.getX(xIdx) * aData[xIdx * mA + rowCntrForA];
- }
-
- yValues[rowCntrForA] = sum * alpha + beta * yValues[rowCntrForA];
- rowCntrForA++;
- }
+ if (y instanceof SparseLocalVector)
+ y.assign(fY);
}
/**
@@ -427,7 +329,6 @@ public class Blas {
for (int i = 0; i < m.rowSize(); i++)
for (int j = 0; j < m.columnSize(); j++)
m.setX(i, j, m.getX(i, j) * alpha);
-
}
/**
http://git-wip-us.apache.org/repos/asf/ignite/blob/488a6d27/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/matrix/AbstractMatrix.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/matrix/AbstractMatrix.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/matrix/AbstractMatrix.java
index b1680f4..2195a70 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/matrix/AbstractMatrix.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/matrix/AbstractMatrix.java
@@ -770,8 +770,7 @@ public abstract class AbstractMatrix implements Matrix {
Vector res = likeVector(rows);
- for (int x = 0; x < rows; x++)
- res.setX(x, vec.dot(viewRow(x)));
+ Blas.gemv(1,this,vec,0,res);
return res;
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/488a6d27/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/storage/vector/MatrixVectorStorage.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/storage/vector/MatrixVectorStorage.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/storage/vector/MatrixVectorStorage.java
index 7700a7c..1e3680a 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/storage/vector/MatrixVectorStorage.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/storage/vector/MatrixVectorStorage.java
@@ -163,6 +163,17 @@ public class MatrixVectorStorage implements VectorStorage {
}
/** {@inheritDoc} */
+ //TODO: IGNITE-5925, tmp solution, wait this ticket.
+ @Override public double[] data() {
+ double[] res = new double[size];
+
+ for (int i = 0; i < size; i++)
+ res[i] = get(i);
+
+ return res;
+ }
+
+ /** {@inheritDoc} */
@Override public void writeExternal(ObjectOutput out) throws IOException {
out.writeObject(parent);
out.writeInt(row);
http://git-wip-us.apache.org/repos/asf/ignite/blob/488a6d27/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/storage/vector/SparseLocalOnHeapVectorStorage.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/storage/vector/SparseLocalOnHeapVectorStorage.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/storage/vector/SparseLocalOnHeapVectorStorage.java
index 3323a07..272f08d 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/storage/vector/SparseLocalOnHeapVectorStorage.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/impls/storage/vector/SparseLocalOnHeapVectorStorage.java
@@ -158,6 +158,15 @@ public class SparseLocalOnHeapVectorStorage implements VectorStorage, StorageCon
}
/** {@inheritDoc} */
+ @Override public double[] data() {
+ double[] data = new double[size];
+
+ sto.forEach((idx, val) -> data[idx]=val);
+
+ return data;
+ }
+
+ /** {@inheritDoc} */
@Override public boolean equals(Object o) {
if (this == o)
return true;
http://git-wip-us.apache.org/repos/asf/ignite/blob/488a6d27/modules/ml/src/test/java/org/apache/ignite/ml/math/impls/matrix/MatrixImplementationsTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/math/impls/matrix/MatrixImplementationsTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/math/impls/matrix/MatrixImplementationsTest.java
index e4c2938..89b6224 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/math/impls/matrix/MatrixImplementationsTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/math/impls/matrix/MatrixImplementationsTest.java
@@ -287,6 +287,9 @@ public class MatrixImplementationsTest extends ExternalizeTest<Matrix> {
if (ignore(m.getClass()))
return;
+ if (m instanceof DenseLocalOffHeapMatrix)
+ return; //TODO: IGNITE-5535, waiting offheap support.
+
double[][] data = fillAndReturn(m);
double[] arr = fillArray(m.columnSize());
@@ -302,7 +305,7 @@ public class MatrixImplementationsTest extends ExternalizeTest<Matrix> {
exp += arr[j] * data[i][j];
assertEquals("Unexpected value for " + desc + " at " + i,
- times.get(i), exp, 0d);
+ times.get(i), exp, DEFAULT_DELTA);
}
testInvalidCardinality(() -> m.times(new DenseLocalOnHeapVector(m.columnSize() + 1)), desc);
http://git-wip-us.apache.org/repos/asf/ignite/blob/488a6d27/modules/ml/src/test/java/org/apache/ignite/ml/math/impls/storage/vector/RandomAccessSparseVectorStorageTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/math/impls/storage/vector/RandomAccessSparseVectorStorageTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/math/impls/storage/vector/RandomAccessSparseVectorStorageTest.java
index 6578e14..1c09ce9 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/math/impls/storage/vector/RandomAccessSparseVectorStorageTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/math/impls/storage/vector/RandomAccessSparseVectorStorageTest.java
@@ -22,7 +22,7 @@ import org.apache.ignite.ml.math.impls.MathTestConstants;
import org.junit.Test;
import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertNotNull;
/**
* Unit tests for {@link SparseLocalOnHeapVectorStorage}.
@@ -36,7 +36,7 @@ public class RandomAccessSparseVectorStorageTest extends VectorBaseStorageTest<S
/** */
@Test
public void data() throws Exception {
- assertNull(MathTestConstants.NULL_VAL, storage.data());
+ assertNotNull(MathTestConstants.NULL_VAL, storage.data());
}
/** */