Posted to commits@mahout.apache.org by dl...@apache.org on 2015/10/20 07:36:46 UTC

[03/32] mahout git commit: MAHOUT-1701: Flink: AtB implemented, ABt and AtA expressed via AtB

MAHOUT-1701: Flink: AtB implemented, ABt and AtA expressed via AtB


Project: http://git-wip-us.apache.org/repos/asf/mahout/repo
Commit: http://git-wip-us.apache.org/repos/asf/mahout/commit/f836481b
Tree: http://git-wip-us.apache.org/repos/asf/mahout/tree/f836481b
Diff: http://git-wip-us.apache.org/repos/asf/mahout/diff/f836481b

Branch: refs/heads/flink-binding
Commit: f836481b823a1aaaa70d9ab87c030f60c459de0d
Parents: 98d4ff0
Author: Alexey Grigorev <al...@gmail.com>
Authored: Tue May 5 20:05:21 2015 +0200
Committer: Alexey Grigorev <al...@gmail.com>
Committed: Fri Sep 25 17:41:39 2015 +0200

----------------------------------------------------------------------
 .../mahout/flinkbindings/FlinkEngine.scala      | 32 ++++++-
 .../mahout/flinkbindings/blas/FlinkOpAt.scala   | 11 ++-
 .../mahout/flinkbindings/blas/FlinkOpAtB.scala  | 85 +++++++++++++++++++
 .../mahout/flinkbindings/blas/package.scala     | 15 ++++
 .../mahout/flinkbindings/RLikeOpsSuite.scala    | 88 +++++++++++---------
 .../mahout/flinkbindings/UseCasesSuite.scala    | 79 ++++++++++++++++++
 .../mahout/flinkbindings/blas/LATestSuit.scala  | 24 +++++-
 7 files changed, 285 insertions(+), 49 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/mahout/blob/f836481b/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala
----------------------------------------------------------------------
diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala
index 17bf0b6..a124a7c 100644
--- a/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala
+++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala
@@ -33,6 +33,10 @@ import org.apache.mahout.flinkbindings.drm.RowsFlinkDrm
 import org.apache.mahout.math.drm.logical.OpAt
 import org.apache.mahout.math.drm.logical.OpAtx
 import org.apache.mahout.math.drm.logical.OpAtx
+import org.apache.mahout.math.drm.logical.OpAtB
+import org.apache.mahout.math.drm.logical.OpABt
+import org.apache.mahout.math.drm.logical.OpAtA
 
 object FlinkEngine extends DistributedEngine {
 
@@ -56,14 +60,40 @@ object FlinkEngine extends DistributedEngine {
     case op @ OpAx(a, x) => FlinkOpAx.blockifiedBroadcastAx(op, flinkTranslate(a)(op.classTagA))
     case op @ OpAt(a) => FlinkOpAt.sparseTrick(op, flinkTranslate(a)(op.classTagA))
     case op @ OpAtx(a, x) => {
+      // express Atx as (A.t) %*% x
+      // TODO: create specific implementation of Atx
       val opAt = OpAt(a)
       val at = FlinkOpAt.sparseTrick(opAt, flinkTranslate(a)(op.classTagA))
       val atCast = new CheckpointedFlinkDrm(at.deblockify.ds, _nrow=opAt.nrow, _ncol=opAt.ncol)
       val opAx = OpAx(atCast, x)
       FlinkOpAx.blockifiedBroadcastAx(opAx, flinkTranslate(atCast)(op.classTagA))
     }
+    case op @ OpAtB(a, b) => FlinkOpAtB.notZippable(op, flinkTranslate(a)(op.classTagA), 
+        flinkTranslate(b)(op.classTagA))
+    case op @ OpABt(a, b) => {
+      // express ABt via AtB: let C=At and D=Bt, and calculate CtD
+      // TODO: create specific implementation of ABt
+      val opAt = OpAt(a.asInstanceOf[DrmLike[Int]]) // TODO: casts!
+      val at = FlinkOpAt.sparseTrick(opAt, flinkTranslate(a.asInstanceOf[DrmLike[Int]]))
+      val c = new CheckpointedFlinkDrm(at.deblockify.ds, _nrow=opAt.nrow, _ncol=opAt.ncol)
+
+      val opBt = OpAt(b.asInstanceOf[DrmLike[Int]]) // TODO: casts!
+      val bt = FlinkOpAt.sparseTrick(opBt, flinkTranslate(b.asInstanceOf[DrmLike[Int]]))
+      val d = new CheckpointedFlinkDrm(bt.deblockify.ds, _nrow=opBt.nrow, _ncol=opBt.ncol)
+
+      FlinkOpAtB.notZippable(OpAtB(c, d), flinkTranslate(c), flinkTranslate(d))
+                .asInstanceOf[FlinkDrm[K]]
+    }
+    case op @ OpAtA(a) => {
+      // express AtA via AtB
+      // TODO: create specific implementation of AtA
+      val aInt = a.asInstanceOf[DrmLike[Int]] // TODO: casts!
+      val opAtB = OpAtB(aInt, aInt)
+      val aTranslated = flinkTranslate(aInt)
+      FlinkOpAtB.notZippable(opAtB, aTranslated, aTranslated)
+    }
     case cp: CheckpointedFlinkDrm[K] => new RowsFlinkDrm(cp.ds, cp.ncol)
-    case _ => ???
+    case _ => throw new NotImplementedError(s"operator $oper is not implemented yet")
   }
   
 
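The three rewrite cases above lean on simple transpose identities: A.t %*% x
is evaluated as OpAt followed by OpAx; A %*% B.t as C.t %*% D with C = A.t and
D = B.t; and A.t %*% A as AtB applied to (A, A). A quick in-core check of the
ABt identity in Mahout's R-like DSL (a sketch, not part of the commit):

    import org.apache.mahout.math._
    import scalabindings._
    import RLikeOps._

    val a = dense((1, 2), (2, 3), (3, 4))
    val b = dense((1, 2), (3, 4), (11, 4))

    // ABt via AtB: with c = a.t and d = b.t, a %*% b.t equals c.t %*% d
    val c = a.t
    val d = b.t
    assert(((a %*% b.t) - (c.t %*% d)).norm < 1e-6)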

http://git-wip-us.apache.org/repos/asf/mahout/blob/f836481b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAt.scala
----------------------------------------------------------------------
diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAt.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAt.scala
index be7fc8f..08aea73 100644
--- a/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAt.scala
+++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAt.scala
@@ -20,6 +20,7 @@ import org.apache.flink.api.java.functions.KeySelector
 import java.util.ArrayList
 import org.apache.flink.shaded.com.google.common.collect.Lists
 
+
 /**
  * Taken from
  */
@@ -40,7 +41,7 @@ object FlinkOpAt {
             val columnVector: Vector = new SequentialAccessSparseVector(ncol)
 
             keys.zipWithIndex.foreach { case (key, idx) =>
-                columnVector(key) = block(idx, columnIdx)
+              columnVector(key) = block(idx, columnIdx)
             }
 
             out.collect(new Tuple2(columnIdx, columnVector))
@@ -49,12 +50,10 @@ object FlinkOpAt {
       }
     })
 
-    val regrouped = sparseParts.groupBy(new KeySelector[Tuple2[Int, Vector], Integer] {
-      def getKey(tuple: Tuple2[Int, Vector]): Integer = tuple._1
-    })
+    val regrouped = sparseParts.groupBy(tuple_1[Vector])
 
-    val sparseTotal = regrouped.reduceGroup(new GroupReduceFunction[Tuple2[Int, Vector], DrmTuple[Int]] {
-      def reduce(values: Iterable[DrmTuple[Int]], out: Collector[DrmTuple[Int]]): Unit = {
+    val sparseTotal = regrouped.reduceGroup(new GroupReduceFunction[(Int, Vector), DrmTuple[Int]] {
+      def reduce(values: Iterable[(Int, Vector)], out: Collector[DrmTuple[Int]]): Unit = {
         val it = Lists.newArrayList(values).asScala
         val (idx, _) = it.head
         val vector = it map { case (idx, vec) => vec } reduce (_ + _)

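For orientation: sparseTrick has each vertical block of A emit partial columns
(columnIndex, sparseVector) with entries placed at the block's global row keys;
grouping by column index and summing the partials assembles the rows of A.t.
A minimal in-core sketch of that assembly, assuming a 3 x 3 A split into two
row blocks (not part of the commit):

    import org.apache.mahout.math._
    import scalabindings._
    import RLikeOps._

    val a = dense((1, 2, 3), (2, 3, 4), (3, 4, 5))
    // A split into two vertical blocks with global row keys (0, 1) and (2)
    val blocks = Seq(Array(0, 1) -> dense((1, 2, 3), (2, 3, 4)),
                     Array(2)    -> dense((3, 4, 5)))

    // partial column j of one block: block(_, j) placed at the block's keys
    def partial(keys: Array[Int], block: Matrix, j: Int): Vector = {
      val v: Vector = new SequentialAccessSparseVector(a.nrow)
      keys.zipWithIndex.foreach { case (key, idx) => v(key) = block(idx, j) }
      v
    }

    // summing the partials of column 0 across blocks yields row 0 of A.t
    val row0 = blocks.map { case (keys, b) => partial(keys, b, 0) }.reduce(_ + _)
    assert((row0 - a.t(0, ::)).norm(2) < 1e-6)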
http://git-wip-us.apache.org/repos/asf/mahout/blob/f836481b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAtB.scala
----------------------------------------------------------------------
diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAtB.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAtB.scala
new file mode 100644
index 0000000..3b353fc
--- /dev/null
+++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAtB.scala
@@ -0,0 +1,85 @@
+package org.apache.mahout.flinkbindings.blas
+
+import scala.reflect.ClassTag
+import org.apache.mahout.flinkbindings.drm.FlinkDrm
+import org.apache.mahout.math.drm.logical.OpAtB
+import org.apache.flink.api.common.functions.MapFunction
+import org.apache.flink.api.java.tuple.Tuple2
+import org.apache.mahout.math.Vector
+import org.apache.mahout.math.Matrix
+import org.apache.flink.api.common.functions.FlatMapFunction
+import org.apache.flink.util.Collector
+import org.apache.mahout.math.drm._
+import org.apache.mahout.math.scalabindings._
+import RLikeOps._
+import org.apache.flink.api.common.functions.GroupReduceFunction
+import java.lang.Iterable
+import scala.collection.JavaConverters._
+import com.google.common.collect.Lists
+import org.apache.mahout.flinkbindings.drm.BlockifiedFlinkDrm
+import org.apache.mahout.flinkbindings.BlockifiedDrmDataSet
+import org.apache.flink.api.scala._
+import org.apache.flink.api.common.typeinfo.TypeInformation
+
+
+object FlinkOpAtB {
+
+  def notZippable[K: ClassTag](op: OpAtB[K], At: FlinkDrm[K], B: FlinkDrm[K]): FlinkDrm[Int] = {
+    // map keys to Int to help Flink's type inference
+    // TODO: only Int row keys are supported for now
+    val rowsAt = At.deblockify.ds.map(new DrmTupleToDrmTupleInt())
+    val rowsB = B.deblockify.ds.map(new DrmTupleToDrmTupleInt())
+    val joined = rowsAt.join(rowsB).where(tuple_1[Vector]).equalTo(tuple_1[Vector])
+
+    val ncol = op.ncol
+    val nrow = op.nrow
+    val blockHeight = 10
+    // vertical blocks tile the rows of the result, which number op.nrow
+    val blockCount = safeToNonNegInt((nrow - 1) / blockHeight + 1)
+
+    val preProduct = joined.flatMap(new FlatMapFunction[Tuple2[(Int, Vector), (Int, Vector)], 
+                                                        (Int, Matrix)] {
+      def flatMap(in: Tuple2[(Int, Vector), (Int, Vector)],
+                  out: Collector[(Int, Matrix)]): Unit = {
+        val avec = in.f0._2
+        val bvec = in.f1._2
+
+        0.until(blockCount) map { blockKey =>
+          val blockStart = blockKey * blockHeight
+          val blockEnd = Math.min(nrow, blockStart + blockHeight).toInt
+
+          // create the block as the cross product of the proper slice of avec with bvec
+          val outer = avec(blockStart until blockEnd) cross bvec
+          out.collect((blockKey, outer))
+        }
+      }
+    })
+
+    val res: BlockifiedDrmDataSet[Int] = preProduct.groupBy(tuple_1[Matrix]).reduceGroup(
+            new GroupReduceFunction[(Int, Matrix), BlockifiedDrmTuple[Int]] {
+      def reduce(values: Iterable[(Int, Matrix)], out: Collector[BlockifiedDrmTuple[Int]]): Unit = {
+        val it = Lists.newArrayList(values).asScala
+        val (idx, _) = it.head
+
+        val block = it.map(t => t._2).reduce((m1, m2) => m1 + m2)
+
+        // global row keys of this vertical block start at blockKey * blockHeight
+        val blockStart = idx * blockHeight
+        val keys = blockStart.until(blockStart + block.nrow).toArray[Int]
+        out.collect((keys, block))
+      }
+    })
+
+    new BlockifiedFlinkDrm(res, ncol)
+  }
+
+}
+
+class DrmTupleToDrmTupleInt[K: ClassTag] extends MapFunction[(K, Vector), (Int, Vector)] {
+  def map(tuple: (K, Vector)): (Int, Vector) = tuple match {
+    case (key, vec) => (key.asInstanceOf[Int], vec)
+  }
+}
+
+class DrmTupleToFlinkTupleMapper[K: ClassTag] extends MapFunction[(K, Vector), Tuple2[Int, Vector]] {
+  def map(tuple: (K, Vector)): Tuple2[Int, Vector] = tuple match {
+    case (key, vec) => new Tuple2[Int, Vector](key.asInstanceOf[Int], vec)
+  }
+}
\ No newline at end of file
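
The operator computes A.t %*% B from co-indexed rows: the join pairs row k of
the first operand with row k of the second, each pair contributes the outer
(cross) product of the two rows, sliced into vertical blocks of blockHeight
rows, and the group-reduce sums the contributions per block key. In other
words, A.t %*% B = sum over k of (a_k cross b_k), with a_k and b_k the k-th
rows. An in-core sketch of that identity (not part of the commit):

    import org.apache.mahout.math._
    import scalabindings._
    import RLikeOps._

    val a = dense((1, 2), (2, 3), (3, 4))
    val b = dense((1, 2), (3, 4), (11, 4))

    // sum of outer products of co-indexed rows equals A.t %*% B
    val sumOfOuters = (0 until a.nrow)
      .map(k => a(k, ::) cross b(k, ::))
      .reduce(_ + _)

    assert((sumOfOuters - (a.t %*% b)).norm < 1e-6)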

http://git-wip-us.apache.org/repos/asf/mahout/blob/f836481b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/package.scala
----------------------------------------------------------------------
diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/package.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/package.scala
new file mode 100644
index 0000000..af5ccc8
--- /dev/null
+++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/package.scala
@@ -0,0 +1,15 @@
+package org.apache.mahout.flinkbindings
+
+import org.apache.flink.api.java.functions.KeySelector
+import org.apache.mahout.math.Vector
+import scala.reflect.ClassTag
+
+
+package object blas {
+
+  // TODO: remove this once we figure out how to make Flink accept interfaces (Vector here)
+  def tuple_1[K: ClassTag] = new KeySelector[(Int, K), Integer] {
+    def getKey(tuple: Tuple2[Int, K]): Integer = tuple._1
+  }
+  
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/mahout/blob/f836481b/flink/src/test/scala/org/apache/mahout/flinkbindings/RLikeOpsSuite.scala
----------------------------------------------------------------------
diff --git a/flink/src/test/scala/org/apache/mahout/flinkbindings/RLikeOpsSuite.scala b/flink/src/test/scala/org/apache/mahout/flinkbindings/RLikeOpsSuite.scala
index 07d6a84..2624077 100644
--- a/flink/src/test/scala/org/apache/mahout/flinkbindings/RLikeOpsSuite.scala
+++ b/flink/src/test/scala/org/apache/mahout/flinkbindings/RLikeOpsSuite.scala
@@ -33,64 +33,72 @@ class RLikeOpsSuite extends FunSuite with DistributedFlinkSuit {
     assert(b == dvec(8, 11, 14))
   }
 
-  test("Power interation 1000 x 1000 matrix") {
-    val dim = 1000
+  test("A.t") {
+    val inCoreA = dense((1, 2, 3), (2, 3, 4))
+    val A = drmParallelize(m = inCoreA, numPartitions = 2)
+    val res = A.t.collect
 
-    // we want a symmetric matrix so we can have real eigenvalues
-    val inCoreA = symmtericMatrix(dim, max = 2000)
+    val expected = inCoreA.t
+    assert((res - expected).norm < 1e-6)
+  }
 
+  test("A.t %*% x") {
+    val inCoreA = dense((1, 2, 3), (2, 3, 4))
     val A = drmParallelize(m = inCoreA, numPartitions = 2)
+    val x = dvec(3, 11)
+    val res = (A.t %*% x).collect(::, 0)
 
-    var x: Vector = 1 to dim map (_ => 1.0 / Math.sqrt(dim))
-    var converged = false
+    val expected = inCoreA.t %*% x 
+    assert((res - expected).norm(2) < 1e-6)
+  }
 
-    var iteration = 1
+  test("A.t %*% B") {
+    val inCoreA = dense((1, 2), (2, 3), (3, 4))
+    val inCoreB = dense((1, 2), (3, 4), (11, 4))
 
-    while (!converged) {
-      LOGGER.info(s"iteration #$iteration...")
+    val A = drmParallelize(m = inCoreA, numPartitions = 2)
+    val B = drmParallelize(m = inCoreB, numPartitions = 2)
 
-      val Ax = A %*% x
-      var x_new = Ax.collect(::, 0)
-      x_new = x_new / x_new.norm(2)
+    val res = A.t %*% B
 
-      val diff = (x_new - x).norm(2)
-      LOGGER.info(s"difference norm is $diff")
+    val expected = inCoreA.t %*% inCoreB
+    assert((res.collect - expected).norm < 1e-6)
+  }
 
-      converged = diff < 1e-6
-      iteration = iteration + 1
-      x = x_new
-    }
+  test("A %*% B.t") {
+    val inCoreA = dense((1, 2), (2, 3), (3, 4))
+    val inCoreB = dense((1, 2), (3, 4), (11, 4))
 
-    LOGGER.info("converged")
-    // TODO: add test that it's the 1st PC
-  }
+    val A = drmParallelize(m = inCoreA, numPartitions = 2)
+    val B = drmParallelize(m = inCoreB, numPartitions = 2)
+
+    val res = A %*% B.t
 
-  def symmtericMatrix(dim: Int, max: Int, seed: Int = 0x31337) = {
-    Matrices.functionalMatrixView(dim, dim, new IntIntFunction {
-      def apply(i: Int, j: Int): Double = {
-        val arr = Array(i + j, i * j, i + j + 31, i / (j + 1) + j / (i + 1))
-        Math.abs(MurmurHash3.arrayHash(arr, seed) % max)
-      }
-    })
+    val expected = inCoreA %*% inCoreB.t
+    assert((res.collect - expected).norm < 1e-6)
   }
 
-  test("A.t") {
-    val inCoreA = dense((1, 2, 3), (2, 3, 4))
+  test("A.t %*% A") {
+    val inCoreA = dense((1, 2), (2, 3), (3, 4))
     val A = drmParallelize(m = inCoreA, numPartitions = 2)
-    val res = A.t.collect
 
-    val expected = inCoreA.t
-    assert((res - expected).norm < 1e-6)
+    val res = A.t %*% A
+
+    val expected = inCoreA.t %*% inCoreA
+    assert((res.collect - expected).norm < 1e-6)
   }
 
-  test("A.t %*% x") {
-    val inCoreA = dense((1, 2, 3), (2, 3, 4))
+  test("A %*% B") {
+    val inCoreA = dense((1, 2), (2, 3), (3, 4)).t
+    val inCoreB = dense((1, 2), (3, 4), (11, 4))
+
     val A = drmParallelize(m = inCoreA, numPartitions = 2)
-    val x = dvec(3, 11)
-    val res = (A.t %*% x).collect(::, 0)
+    val B = drmParallelize(m = inCoreB, numPartitions = 2)
 
-    val expected = inCoreA.t %*% x 
-    assert((res - expected).norm(2) < 1e-6)
+    val res = A %*% B
+
+    val expected = inCoreA %*% inCoreB
+    assert((res.collect - expected).norm < 1e-6)
   }
-  
+
 }
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/mahout/blob/f836481b/flink/src/test/scala/org/apache/mahout/flinkbindings/UseCasesSuite.scala
----------------------------------------------------------------------
diff --git a/flink/src/test/scala/org/apache/mahout/flinkbindings/UseCasesSuite.scala b/flink/src/test/scala/org/apache/mahout/flinkbindings/UseCasesSuite.scala
new file mode 100644
index 0000000..8cdaca3
--- /dev/null
+++ b/flink/src/test/scala/org/apache/mahout/flinkbindings/UseCasesSuite.scala
@@ -0,0 +1,79 @@
+package org.apache.mahout.flinkbindings
+
+import org.junit.runner.RunWith
+import org.scalatest.junit.JUnitRunner
+import org.scalatest.FunSuite
+import org.apache.mahout.math._
+import scalabindings._
+import RLikeOps._
+import org.apache.mahout.math.drm._
+import RLikeDrmOps._
+import org.apache.mahout.flinkbindings._
+import org.apache.mahout.math.function.IntIntFunction
+import scala.util.Random
+import scala.util.hashing.MurmurHash3
+import org.slf4j.Logger
+import org.slf4j.LoggerFactory
+import org.scalatest.Ignore
+
+@RunWith(classOf[JUnitRunner])
+class UseCasesSuite extends FunSuite with DistributedFlinkSuit {
+
+  val LOGGER = LoggerFactory.getLogger(getClass())
+
+  test("use case: Power iteration 1000 x 1000 matrix") {
+    val dim = 1000
+
+    // we want a symmetric matrix so we can have real eigenvalues
+    val inCoreA = symmetricMatrix(dim, max = 2000)
+
+    val A = drmParallelize(m = inCoreA, numPartitions = 2)
+
+    var x: Vector = 1 to dim map (_ => 1.0 / Math.sqrt(dim))
+    var converged = false
+
+    var iteration = 1
+
+    while (!converged) {
+      LOGGER.info(s"iteration #$iteration...")
+
+      val Ax = A %*% x
+      var x_new = Ax.collect(::, 0)
+      x_new = x_new / x_new.norm(2)
+
+      val diff = (x_new - x).norm(2)
+      LOGGER.info(s"difference norm is $diff")
+
+      converged = diff < 1e-6
+      iteration = iteration + 1
+      x = x_new
+    }
+
+    LOGGER.info("converged")
+    // TODO: add test that it's the 1st PC
+  }
+
+  def symmetricMatrix(dim: Int, max: Int, seed: Int = 0x31337) = {
+    Matrices.functionalMatrixView(dim, dim, new IntIntFunction {
+      def apply(i: Int, j: Int): Double = {
+        val arr = Array(i + j, i * j, i + j + 31, i / (j + 1) + j / (i + 1))
+        Math.abs(MurmurHash3.arrayHash(arr, seed) % max)
+      }
+    })
+  }
+
+  test("use case: OLS Regression") {
+    val inCoreA = dense((1, 2), (2, 3), (3, 4), (5, 6), (7, 8), (9, 10))
+    val x = dvec(1, 2, 2, 3, 3, 3)
+    val A = drmParallelize(m = inCoreA, numPartitions = 2)
+    val AtA = A.t %*% A
+    val Atx = A.t %*% x
+
+    val w = solve(AtA, Atx)
+
+    val expected = solve(inCoreA.t %*% inCoreA, inCoreA.t %*% x)
+    assert((w(::, 0) - expected).norm(2) < 1e-6)
+  }
+
+}
\ No newline at end of file
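
Both use cases exercise only the operators wired up in this commit. The power
iteration loop implements the recurrence x_{k+1} = A x_k / ||A x_k||_2, whose
iterates converge to the dominant eigenvector of the symmetric A; the OLS test
solves the normal equations (A.t %*% A) w = A.t %*% x. An in-core sanity check
of the normal equations, with solve as used in the test above (a sketch):

    import org.apache.mahout.math._
    import scalabindings._
    import RLikeOps._

    val a = dense((1, 2), (2, 3), (3, 4), (5, 6), (7, 8), (9, 10))
    val x = dvec(1, 2, 2, 3, 3, 3)

    // the least-squares w satisfies A.t (A w - x) = 0, i.e. (A.t A) w = A.t x
    val w = solve(a.t %*% a, a.t %*% x)
    assert((a.t %*% (a %*% w - x)).norm(2) < 1e-6)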

http://git-wip-us.apache.org/repos/asf/mahout/blob/f836481b/flink/src/test/scala/org/apache/mahout/flinkbindings/blas/LATestSuit.scala
----------------------------------------------------------------------
diff --git a/flink/src/test/scala/org/apache/mahout/flinkbindings/blas/LATestSuit.scala b/flink/src/test/scala/org/apache/mahout/flinkbindings/blas/LATestSuit.scala
index 3ce8895..baf23d6 100644
--- a/flink/src/test/scala/org/apache/mahout/flinkbindings/blas/LATestSuit.scala
+++ b/flink/src/test/scala/org/apache/mahout/flinkbindings/blas/LATestSuit.scala
@@ -12,11 +12,12 @@ import org.apache.mahout.math.drm.logical.OpAx
 import org.apache.mahout.flinkbindings.drm.CheckpointedFlinkDrm
 import org.apache.mahout.flinkbindings.drm.RowsFlinkDrm
 import org.apache.mahout.math.drm.logical.OpAt
+import org.apache.mahout.math.drm.logical.OpAtB
 
 @RunWith(classOf[JUnitRunner])
 class LATestSuit extends FunSuite with DistributedFlinkSuit {
 
-  test("Ax") {
+  test("Ax blockified") {
     val inCoreA = dense((1, 2, 3), (2, 3, 4), (3, 4, 5))
     val A = drmParallelize(m = inCoreA, numPartitions = 2)
     val x: Vector = (0, 1, 2)
@@ -30,7 +31,7 @@ class LATestSuit extends FunSuite with DistributedFlinkSuit {
     assert(b == dvec(8, 11, 14))
   }
 
-  test("At") {
+  test("At sparseTrick") {
     val inCoreA = dense((1, 2, 3), (2, 3, 4))
     val A = drmParallelize(m = inCoreA, numPartitions = 2)
 
@@ -42,4 +43,23 @@ class LATestSuit extends FunSuite with DistributedFlinkSuit {
     assert((output - inCoreA.t).norm < 1e-6)
   }
 
+  test("AtB notZippable") {
+    val inCoreAt = dense((1, 2), (2, 3), (3, 4))
+
+    val At = drmParallelize(m = inCoreAt, numPartitions = 2)
+
+    val inCoreB = dense((1, 2), (3, 4), (11, 4))
+    val B = drmParallelize(m = inCoreB, numPartitions = 2)
+
+    val opAtB = new OpAtB(At, B)
+    val res = FlinkOpAtB.notZippable(opAtB, At, B)
+
+    val drm = new CheckpointedFlinkDrm(res.deblockify.ds, _nrow=inCoreAt.ncol, _ncol=inCoreB.ncol)
+    val output = drm.collect
+
+    val expected = inCoreAt.t %*% inCoreB
+    assert((output - expected).norm < 1e-6)
+  }
+  
+
 }
\ No newline at end of file