You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@datasketches.apache.org by jm...@apache.org on 2021/05/26 21:58:20 UTC

[datasketches-vector] 01/02: Vector normalization to support ridge regression

This is an automated email from the ASF dual-hosted git repository.

jmalkin pushed a commit to branch fdrr
in repository https://gitbox.apache.org/repos/asf/datasketches-vector.git

commit 718b647534d00fec47891cacd56222769204865e
Author: Jon Malkin <jm...@users.noreply.github.com>
AuthorDate: Wed May 26 14:55:44 2021 -0700

    Vector normalization to support ridge regression
---
 .../apache/datasketches/vector/MatrixFamily.java   |  16 +-
 .../vector/regression/VectorNormalizer.java        | 402 ++++++++++++++++++
 .../vector/regression/RidgeRegressionTest.java     |   2 +
 .../vector/regression/VectorNormalizerTest.java    | 454 +++++++++++++++++++++
 4 files changed, 869 insertions(+), 5 deletions(-)

diff --git a/src/main/java/org/apache/datasketches/vector/MatrixFamily.java b/src/main/java/org/apache/datasketches/vector/MatrixFamily.java
index 9c1d141..0d57d75 100644
--- a/src/main/java/org/apache/datasketches/vector/MatrixFamily.java
+++ b/src/main/java/org/apache/datasketches/vector/MatrixFamily.java
@@ -44,15 +44,21 @@ public enum MatrixFamily {
   /**
    * Select Frequent Directions Family
    */
-  FREQUENTDIRECTIONS(129, "FrequentDirections", 2, 4);
+  FREQUENTDIRECTIONS(129, "FrequentDirections", 2, 4),
+
+  /**
+   * Aggregation for vector means/variances.
+   */
+  VECTORNORMALIZER(130, "VectorNormalizer", 1, 2)
+  ;
 
 
   private static final Map<Integer, MatrixFamily> lookupID = new HashMap<>();
   private static final Map<String, MatrixFamily> lookupFamName = new HashMap<>();
-  private int id_;
-  private String famName_;
-  private int minPreLongs_;
-  private int maxPreLongs_;
+  private final int id_;
+  private final String famName_;
+  private final int minPreLongs_;
+  private final int maxPreLongs_;
 
   static {
     for (MatrixFamily f : values()) {
diff --git a/src/main/java/org/apache/datasketches/vector/regression/VectorNormalizer.java b/src/main/java/org/apache/datasketches/vector/regression/VectorNormalizer.java
new file mode 100644
index 0000000..49a5eec
--- /dev/null
+++ b/src/main/java/org/apache/datasketches/vector/regression/VectorNormalizer.java
@@ -0,0 +1,402 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.datasketches.vector.regression;
+
+import static org.apache.datasketches.memory.UnsafeUtil.unsafe;
+
+import org.apache.datasketches.memory.Memory;
+import org.apache.datasketches.memory.WritableMemory;
+import org.apache.datasketches.vector.MatrixFamily;
+
+/**
+ * Computes the mean and variance of each of the d dimensions of a stream of input vectors using
+ * Welford's online algorithm (see https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance).
+ * Objects built over separate streams can be combined with {@link #merge(VectorNormalizer)}, which uses
+ * the parallel form of the algorithm from the same reference.
+ *
+ * <p>For serialized images, multi-byte integers (<tt>int</tt> and <tt>long</tt>) are stored in native byte
+ * order. All <tt>byte</tt> values are treated as unsigned.</p>
+ *
+ * <p>An empty object requires 8 bytes. A non-empty object requires 16 bytes of preamble, followed by
+ * d <tt>double</tt> values holding the per-dimension means and then d <tt>double</tt> values holding the
+ * per-dimension sums of squared deviations from the mean (M2).</p>
+ *
+ * <pre>
+ * Long || Start Byte Adr:
+ * Adr:
+ *      ||       0        |    1   |    2   |    3   |   4   |    5   |    6   |   7   |
+ *  0   || Preamble_Longs | SerVer | FamID  | Flags  |---------Vector Dim. (d)---------|
+ *
+ *      ||       8        |   9    |   10   |   11   |   12  |   13   |   14   |  15   |
+ *  1   ||-------------------------Num. Vectors Processed (n)--------------------------|
+ * </pre>
+ *
+ * @author Jon Malkin
+ */
+public class VectorNormalizer {
+  private final int d_;
+  // Running per-dimension mean of every vector processed so far
+  private final double[] mean_;
+  // Running per-dimension sum of squared deviations from the mean (Welford's M2)
+  private final double[] M2_;
+  // Number of vectors processed so far
+  private long n_;
+
+  // Preamble byte Addresses
+  static final int PREAMBLE_LONGS_BYTE   = 0;
+  static final int SER_VER_BYTE          = 1;
+  static final int FAMILY_BYTE           = 2;
+  static final int FLAGS_BYTE            = 3;
+  static final int D_INT                 = 4;
+  static final int N_LONG                = 8;
+
+  // flag bit masks
+  static final int EMPTY_FLAG_MASK       = 4;
+
+  // Other constants
+  static final int SER_VER               = 1;
+
+
+  /**
+   * Creates a new, empty VectorNormalizer
+   * @param d The number of dimensions the VectorNormalizer holds
+   * @throws IllegalArgumentException if d is less than 1
+   */
+  public VectorNormalizer(final int d) {
+    if (d < 1) {
+      throw new IllegalArgumentException("d cannot be < 1. Found: " + d);
+    }
+
+    d_ = d;
+    mean_ = new double[d_];
+    M2_ = new double[d_];
+    n_ = 0;
+  }
+
+  /**
+   * Copy constructor. The new object holds its own copies of the internal arrays.
+   * @param other The VectorNormalizer to copy
+   */
+  public VectorNormalizer(final VectorNormalizer other) {
+    d_ = other.d_;
+    n_ = other.n_;
+    mean_ = other.mean_.clone();
+    M2_ = other.M2_.clone();
+  }
+
+  // Takes ownership of (does not copy) freshly deserialized arrays; used by heapify()
+  private VectorNormalizer(final int d, final long n, final double[] mean, final double[] M2) {
+    d_ = d;
+    n_ = n;
+    mean_ = mean;
+    M2_ = M2;
+  }
+
+  /**
+   * Instantiates a VectorNormalizer object from a serialized image, such as one produced by
+   * {@link #toByteArray()}.
+   * @param srcMem Memory containing the serialized image of a VectorNormalizer object
+   * @return A VectorNormalizer, or null if srcMem is null
+   * @throws IllegalArgumentException if the image is too small or its header is inconsistent
+   */
+  public static VectorNormalizer heapify(final Memory srcMem) {
+    if (srcMem == null) { return null; }
+
+    final int preLongs = getAndCheckPreLongs(srcMem);
+    if (preLongs < MatrixFamily.VECTORNORMALIZER.getMinPreLongs()
+        || preLongs > MatrixFamily.VECTORNORMALIZER.getMaxPreLongs()) {
+      throw new IllegalArgumentException("Possible corruption: Invalid number of preamble longs: " + preLongs);
+    }
+
+    final int serVer = extractSerVer(srcMem);
+    if (serVer != SER_VER) {
+      throw new IllegalArgumentException("Invalid serialization version: " + serVer);
+    }
+
+    final int family = extractFamilyID(srcMem);
+    if (family != MatrixFamily.VECTORNORMALIZER.getID()) {
+      throw new IllegalArgumentException("Possible corruption: Family id (" + family + ") "
+          + "is not a VectorNormalization image");
+    }
+
+    final boolean empty = (extractFlags(srcMem) & EMPTY_FLAG_MASK) != 0;
+    final int d = extractD(srcMem);
+    if (d < 1) {
+      throw new IllegalArgumentException("Possible corruption: d cannot be < 1. Found: " + d);
+    }
+
+    if (empty) {
+      if (preLongs != MatrixFamily.VECTORNORMALIZER.getMinPreLongs()) {
+        throw new IllegalArgumentException("Possible corruption: Empty flag set but header indicates image has data.");
+      }
+      return new VectorNormalizer(d);
+    }
+
+    if (preLongs == MatrixFamily.VECTORNORMALIZER.getMinPreLongs()) {
+      throw new IllegalArgumentException("Possible corruption: Non-empty image too small to contain serialized data");
+    }
+
+    final long n = extractN(srcMem);
+    if (n <= 0) {
+      throw new IllegalArgumentException("Possible corruption: n must be positive for a non-empty sketch. Found: " + n);
+    }
+
+    long offsetBytes = (long) preLongs * Long.BYTES;
+
+    // check capacity for the two arrays of d doubles that follow the preamble
+    final long bytesNeeded = offsetBytes + (2L * d * Double.BYTES);
+    if (srcMem.getCapacity() < bytesNeeded) {
+      throw new IllegalArgumentException(
+          "Possible Corruption: Size of Memory not large enough: Size: " + srcMem.getCapacity()
+              + ", Required: " + bytesNeeded);
+    }
+
+    final double[] mean = new double[d];
+    srcMem.getDoubleArray(offsetBytes, mean, 0, d);
+    offsetBytes += (long) d * Double.BYTES;
+
+    final double[] M2 = new double[d];
+    srcMem.getDoubleArray(offsetBytes, M2, 0, d);
+
+    return new VectorNormalizer(d, n, mean, M2);
+  }
+
+  /**
+   * Returns an array of bytes with a serialized image of this object.
+   * @return A <tt>byte[]</tt> containing the serialized image of this object.
+   * @throws ArithmeticException if the image would be larger than {@link Integer#MAX_VALUE} bytes
+   */
+  public byte[] toByteArray() {
+    final boolean empty = isEmpty();
+    final int familyId = MatrixFamily.VECTORNORMALIZER.getID();
+
+    final int preLongs = empty
+        ? MatrixFamily.VECTORNORMALIZER.getMinPreLongs()
+        : MatrixFamily.VECTORNORMALIZER.getMaxPreLongs();
+
+    // Share the size computation so the reported and actual image sizes can never disagree
+    final byte[] outArr = new byte[getSerializedSizeBytes()];
+    final WritableMemory memOut = WritableMemory.wrap(outArr);
+    final Object memObj = memOut.getArray();
+    final long memAddr = memOut.getCumulativeOffset(0L);
+
+    insertPreLongs(memObj, memAddr, preLongs);
+    insertSerVer(memObj, memAddr, SER_VER);
+    insertFamilyID(memObj, memAddr, familyId);
+    insertFlags(memObj, memAddr, (empty ? EMPTY_FLAG_MASK : 0));
+    insertD(memObj, memAddr, d_);
+
+    if (!empty) {
+      insertN(memObj, memAddr, n_);
+      long offset = (long) preLongs * Long.BYTES;
+      memOut.putDoubleArray(offset, mean_, 0, d_);
+      offset += (long) d_ * Double.BYTES;
+      memOut.putDoubleArray(offset, M2_, 0, d_);
+    }
+
+    return outArr;
+  }
+
+  /**
+   * Returns true if the object has no data, otherwise false
+   * @return True if the object has no data, otherwise false.
+   */
+  public boolean isEmpty() {
+    return n_ == 0;
+  }
+
+  /**
+   * Returns the number of dimensions configured for this object
+   * @return The number of dimensions
+   */
+  public long getD() {
+    return d_;
+  }
+
+  /**
+   * Returns the number of input vectors processed by this object
+   * @return The number of input vectors processed
+   */
+  public long getN() {
+    return n_;
+  }
+
+  /**
+   * Returns the per-dimension means held by this object. Returns an array of NaN if N = 0.
+   * @return A new array holding the mean of each dimension
+   */
+  public double[] getMean() {
+    return isEmpty() ? nanFilledArray() : mean_.clone();
+  }
+
+  /**
+   * Returns the per-dimension sample (Bessel-corrected) variance, M2 / (N - 1). Returns an array of NaN
+   * if N = 0 and an array of zeros if N = 1.
+   * @return A new array holding the sample variance of each dimension
+   */
+  public double[] getSampleVariance() {
+    return scaledM2(n_ - 1);
+  }
+
+  /**
+   * Returns the per-dimension population variance, M2 / N. Returns an array of NaN if N = 0 and an
+   * array of zeros if N = 1.
+   * @return A new array holding the population variance of each dimension
+   */
+  public double[] getPopulationVariance() {
+    return scaledM2(n_);
+  }
+
+  /**
+   * Updates the running statistics with a single input vector. A null input is ignored.
+   * @param x The input vector, which must have length d
+   * @throws IllegalArgumentException if x is non-null and its length is not d
+   */
+  public void update(final double[] x) {
+    if (x == null) {
+      return;
+    }
+
+    if (x.length != d_) {
+      throw new IllegalArgumentException("Input vector length must be " + d_ + ". Found: " + x.length);
+    }
+
+    ++n_;
+    for (int i = 0; i < d_; ++i) {
+      final double d1 = x[i] - mean_[i]; // deviation from the old mean
+      mean_[i] += d1 / n_;
+      final double d2 = x[i] - mean_[i]; // deviation from the updated mean
+      M2_[i] += d1 * d2;
+    }
+  }
+
+  /**
+   * Merges another VectorNormalizer into this one, so that this object describes the combination of
+   * both input streams. <tt>other</tt> is not modified. A null input is ignored.
+   * @param other The VectorNormalizer to merge into this one
+   * @throws IllegalArgumentException if other has a different number of dimensions
+   */
+  public void merge(final VectorNormalizer other) {
+    if (other == null) {
+      return;
+    }
+
+    if (other.d_ != d_) {
+      throw new IllegalArgumentException("Input VectorNormalizer must have d= " + d_ + ". Found: " + other.d_);
+    }
+
+    // Nothing to add. Returning here also avoids a 0/0 = NaN below when both objects are empty,
+    // which would otherwise corrupt mean_ and M2_ for every subsequent update.
+    if (other.isEmpty()) {
+      return;
+    }
+
+    // Adopt other's state exactly instead of recomputing it through the weighted average
+    if (isEmpty()) {
+      System.arraycopy(other.mean_, 0, mean_, 0, d_);
+      System.arraycopy(other.M2_, 0, M2_, 0, d_);
+      n_ = other.n_;
+      return;
+    }
+
+    final long combinedN = n_ + other.n_;
+    // n_A * n_B / (n_A + n_B), multiplied in double so the product cannot overflow a long
+    final double varCountScalar = ((double) n_ * other.n_) / combinedN;
+    for (int i = 0; i < d_; ++i) {
+      final double meanDiff = other.mean_[i] - mean_[i];
+      mean_[i] = ((n_ * mean_[i]) + (other.n_ * other.mean_[i])) / combinedN;
+      M2_[i] += other.M2_[i] + (meanDiff * meanDiff * varCountScalar);
+    }
+    n_ = combinedN;
+  }
+
+  /**
+   * Returns the number of bytes {@link #toByteArray()} produces for the current state of this object.
+   * @return The size of the serialized image in bytes
+   * @throws ArithmeticException if the image would be larger than {@link Integer#MAX_VALUE} bytes
+   */
+  public int getSerializedSizeBytes() {
+    if (isEmpty()) {
+      return MatrixFamily.VECTORNORMALIZER.getMinPreLongs() * Long.BYTES;
+    }
+    // Computed in long: 2 * d * 8 exceeds the int range once d passes ~134 million
+    final long bytes = ((long) MatrixFamily.VECTORNORMALIZER.getMaxPreLongs() * Long.BYTES)
+        + (2L * d_ * Double.BYTES);
+    return Math.toIntExact(bytes);
+  }
+
+  // Returns a new length-d array filled with NaN, the value reported by an empty object
+  private double[] nanFilledArray() {
+    final double[] result = new double[d_];
+    for (int i = 0; i < d_; ++i) {
+      result[i] = Double.NaN;
+    }
+    return result;
+  }
+
+  // Returns M2 / denominator per dimension, with the documented NaN (N = 0) and zero (N = 1) special cases
+  private double[] scaledM2(final long denominator) {
+    if (n_ == 0) {
+      return nanFilledArray();
+    }
+    if (n_ == 1) {
+      return new double[d_]; // array of zeros
+    }
+    final double[] result = new double[d_];
+    for (int i = 0; i < d_; ++i) {
+      result[i] = M2_[i] / denominator;
+    }
+    return result;
+  }
+
+  // Extraction methods
+  // Single-byte fields are read with getByte: reading an int and masking it only yields the
+  // intended byte on little-endian platforms.
+  static int extractPreLongs(final Memory mem) {
+    return mem.getByte(PREAMBLE_LONGS_BYTE) & 0xFF;
+  }
+
+  static int extractSerVer(final Memory mem) {
+    return mem.getByte(SER_VER_BYTE) & 0xFF;
+  }
+
+  static int extractFamilyID(final Memory mem) {
+    return mem.getByte(FAMILY_BYTE) & 0xFF;
+  }
+
+  static int extractFlags(final Memory mem) {
+    return mem.getByte(FLAGS_BYTE) & 0xFF;
+  }
+
+  static int extractD(final Memory mem) {
+    return mem.getInt(D_INT);
+  }
+
+  static long extractN(final Memory mem) {
+    return mem.getLong(N_LONG);
+  }
+
+
+  // Insertion methods
+  private static void insertPreLongs(final Object memObj, final long memAddr, final int preLongs) {
+    unsafe.putByte(memObj, memAddr + PREAMBLE_LONGS_BYTE, (byte) preLongs);
+  }
+
+  private static void insertSerVer(final Object memObj, final long memAddr, final int serVer) {
+    unsafe.putByte(memObj, memAddr + SER_VER_BYTE, (byte) serVer);
+  }
+
+  private static void insertFamilyID(final Object memObj, final long memAddr, final int matrixFamId) {
+    unsafe.putByte(memObj, memAddr + FAMILY_BYTE, (byte) matrixFamId);
+  }
+
+  private static void insertFlags(final Object memObj, final long memAddr, final int flags) {
+    unsafe.putByte(memObj, memAddr + FLAGS_BYTE, (byte) flags);
+  }
+
+  private static void insertD(final Object memObj, final long memAddr, final int d) {
+    unsafe.putInt(memObj, memAddr + D_INT, d);
+  }
+
+  private static void insertN(final Object memObj, final long memAddr, final long n) {
+    unsafe.putLong(memObj, memAddr + N_LONG, n);
+  }
+
+  /**
+   * Checks Memory for capacity to hold the preamble and returns the extracted preLongs.
+   * @param mem the given Memory
+   * @return the extracted prelongs value.
+   * @throws IllegalArgumentException if the Memory is smaller than the preamble it declares
+   */
+  private static int getAndCheckPreLongs(final Memory mem) {
+    final long cap = mem.getCapacity();
+    if (cap < Long.BYTES) { throwNotBigEnough(cap, Long.BYTES); }
+    final int preLongs = extractPreLongs(mem);
+    // preLongs counts 8-byte longs, so the preamble spans preLongs * Long.BYTES bytes
+    final int required = Math.max(preLongs * Long.BYTES, Long.BYTES);
+    if (cap < required) { throwNotBigEnough(cap, required); }
+    return preLongs;
+  }
+
+  private static void throwNotBigEnough(final long cap, final int required) {
+    throw new IllegalArgumentException(
+        "Possible Corruption: Size of byte array or Memory not large enough: Size: " + cap
+            + ", Required: " + required);
+  }
+}
diff --git a/src/test/java/org/apache/datasketches/vector/regression/RidgeRegressionTest.java b/src/test/java/org/apache/datasketches/vector/regression/RidgeRegressionTest.java
new file mode 100644
index 0000000..94da914
--- /dev/null
+++ b/src/test/java/org/apache/datasketches/vector/regression/RidgeRegressionTest.java
@@ -0,0 +1,2 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.datasketches.vector.regression;
+
+/**
+ * Tests for ridge regression. No test cases have been added yet.
+ */
+public class RidgeRegressionTest {
+}
diff --git a/src/test/java/org/apache/datasketches/vector/regression/VectorNormalizerTest.java b/src/test/java/org/apache/datasketches/vector/regression/VectorNormalizerTest.java
new file mode 100644
index 0000000..e6a7807
--- /dev/null
+++ b/src/test/java/org/apache/datasketches/vector/regression/VectorNormalizerTest.java
@@ -0,0 +1,454 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.datasketches.vector.regression;
+
+import static org.apache.datasketches.vector.regression.VectorNormalizer.D_INT;
+import static org.apache.datasketches.vector.regression.VectorNormalizer.EMPTY_FLAG_MASK;
+import static org.apache.datasketches.vector.regression.VectorNormalizer.FAMILY_BYTE;
+import static org.apache.datasketches.vector.regression.VectorNormalizer.FLAGS_BYTE;
+import static org.apache.datasketches.vector.regression.VectorNormalizer.N_LONG;
+import static org.apache.datasketches.vector.regression.VectorNormalizer.PREAMBLE_LONGS_BYTE;
+import static org.apache.datasketches.vector.regression.VectorNormalizer.SER_VER;
+import static org.apache.datasketches.vector.regression.VectorNormalizer.SER_VER_BYTE;
+import static org.testng.Assert.assertEquals;
+import static org.testng.Assert.assertFalse;
+import static org.testng.Assert.assertNotNull;
+import static org.testng.Assert.assertNull;
+import static org.testng.Assert.assertTrue;
+import static org.testng.Assert.fail;
+
+import javax.imageio.plugins.jpeg.JPEGImageWriteParam;
+import java.util.concurrent.ThreadLocalRandom;
+
+import org.apache.datasketches.memory.Memory;
+import org.apache.datasketches.memory.WritableMemory;
+import org.apache.datasketches.vector.MatrixFamily;
+import org.testng.annotations.Test;
+
+
+import com.google.common.primitives.Longs;
+
+/**
+ * Unit tests for {@link VectorNormalizer}.
+ */
+public class VectorNormalizerTest {
+
+  @Test
+  public void instantiationTest() {
+    final int d = 5;
+    final VectorNormalizer vn = new VectorNormalizer(d);
+    assertNotNull(vn);
+    assertEquals(vn.getD(), d);
+    assertEquals(vn.getN(), 0);
+    assertTrue(vn.isEmpty());
+
+
+    final double[] mean = vn.getMean();
+    assertNotNull(mean);
+
+    final double[] sampleVar = vn.getSampleVariance();
+    assertNotNull(sampleVar);
+
+    final double[] popVar = vn.getPopulationVariance();
+    assertNotNull(popVar);
+
+    // no data, so everything should be Double.NaN
+    for (int i = 0; i < d; ++i) {
+      assertTrue(Double.isNaN(mean[i]));
+      assertTrue(Double.isNaN(sampleVar[i]));
+      assertTrue(Double.isNaN(popVar[i]));
+    }
+
+    // error case
+    try {
+      new VectorNormalizer(0);
+      fail();
+    } catch (IllegalArgumentException e) {
+      // expected
+    }
+  }
+
+  @Test
+  public void singleUpdateTest() {
+    final int d = 3;
+    final VectorNormalizer vn = new VectorNormalizer(d);
+
+    final double[] input = {-1, 0, 0.5};
+    vn.update(input);
+    assertEquals(vn.getN(), 1);
+    assertFalse(vn.isEmpty());
+
+    final double[] mean = vn.getMean();
+    assertNotNull(mean);
+
+    final double[] sampleVar = vn.getSampleVariance();
+    assertNotNull(sampleVar);
+
+    final double[] popVar = vn.getPopulationVariance();
+    assertNotNull(popVar);
+
+    // mean should equal input, others should be 0.0
+    for (int i = 0; i < d; ++i) {
+      assertEquals(mean[i], input[i]);
+      assertEquals(sampleVar[i], 0.0);
+      assertEquals(popVar[i], 0.0);
+    }
+  }
+
+  @Test
+  public void multipleUpdateTest() {
+    final int n = 100000;
+    final int d = 3;
+    final double tol = 0.01;
+
+    final VectorNormalizer vn = new VectorNormalizer(d);
+
+    final ThreadLocalRandom rand = ThreadLocalRandom.current();
+    final double[] input = new double[d];
+    for (int i = 0; i < n; ++i) {
+      input[0] = rand.nextGaussian();      // mean = 0.0, var = 1.0
+      input[1] = rand.nextDouble() * 2.0;  // mean = 1.0, var = (2-0)^2/12 = 1/3
+      input[2] = rand.nextDouble() - 0.5;  // mean = 0.0, var = (1-0)^2/12
+      vn.update(input);
+    }
+    assertFalse(vn.isEmpty());
+
+    final double[] mean = vn.getMean();
+    assertNotNull(mean);
+    assertEquals(mean[0], 0.0, tol);
+    assertEquals(mean[1], 1.0, tol);
+    assertEquals(mean[2], 0.0, tol);
+
+    // n is large enough that sample vs population variance won't matter for testing
+    final double[] sampleVar = vn.getSampleVariance();
+    assertNotNull(sampleVar);
+    assertEquals(sampleVar[0], 1.0, tol);
+    assertEquals(sampleVar[1], 1.0 / 3.0, tol);
+    assertEquals(sampleVar[2], 1.0 / 12.0, tol);
+
+    final double[] popVar = vn.getPopulationVariance();
+    assertNotNull(popVar);
+    assertEquals(popVar[0], 1.0, tol);
+    assertEquals(popVar[1], 1.0 / 3.0, tol);
+    assertEquals(popVar[2], 1.0 / 12.0, tol);
+
+    // sample variance divides M2 by (n - 1) and population variance divides it by n, so the sample
+    // variance is strictly larger and the two differ by exactly the factor n / (n - 1)
+    for (int i = 0; i < d; ++i) {
+      assertTrue(sampleVar[i] > popVar[i]);
+      assertEquals(sampleVar[i], popVar[i] * n / (n - 1.0), 1e-12);
+    }
+  }
+
+  @Test
+  public void mergeTest() {
+    final int n = 1000000;
+    final int d = 2;
+    final double tol = 0.01;
+    final VectorNormalizer vn1 = new VectorNormalizer(d);
+    final VectorNormalizer vn2 = new VectorNormalizer(d);
+
+    final ThreadLocalRandom rand = ThreadLocalRandom.current();
+
+    // data expectations:
+    // dimension 0: zero-mean, unit-variance Gaussian, even after merging
+    // dimension 1: U[0,2] + U[2,4) -> U[0,4), so mean = 2.0 and var = 4^2/12=4/3
+    final double[] input = new double[d];
+    for (int i = 0; i < n; ++i) {
+      input[0] = rand.nextGaussian();
+      input[1] = (rand.nextDouble() * 2.0) + 2.0;
+      vn1.update(input);
+
+      input[0] = rand.nextGaussian();
+      input[1] = rand.nextDouble() * 2.0;
+      vn2.update(input);
+    }
+
+    vn1.merge(vn2);
+    assertEquals(vn1.getN(), 2 * n);
+
+    final double[] mean = vn1.getMean();
+    assertEquals(mean[0], 0.0, tol);
+    assertEquals(mean[1], 2.0, tol);
+
+    // n is large enough that sample vs population variance won't matter for testing
+    final double[] sampleVar = vn1.getSampleVariance();
+    assertEquals(sampleVar[0], 1.0, tol);
+    assertEquals(sampleVar[1], 4.0 / 3.0, tol);
+
+    final double[] popVar = vn1.getPopulationVariance();
+    assertEquals(popVar[0], 1.0, tol);
+    assertEquals(popVar[1], 4.0 / 3.0, tol);
+  }
+
+  @Test
+  public void invalidUpdateSizeTest() {
+    final int d = 5;
+    final VectorNormalizer vn = new VectorNormalizer(d);
+
+    final double[] input = new double[d];
+    for (int i = 0; i < d; ++i) { input[i] = 1.0 * i; }
+    vn.update(input);
+    assertEquals(vn.getN(), 1);
+
+    vn.update(null);
+    assertEquals(vn.getN(), 1);
+
+    try {
+      final double[] badInput = {1.0};
+      vn.update(badInput);
+      fail();
+    } catch (IllegalArgumentException e) {
+      // expected
+      assertEquals(vn.getN(), 1);
+    }
+  }
+
+  @Test
+  public void invalidMergeSizeTest() {
+    final int d = 3;
+    final VectorNormalizer vn1 = new VectorNormalizer(d);
+
+    double[] input = new double[d];
+    for (int i = 0; i < d; ++i) { input[i] = 1.0 * i; }
+    vn1.update(input);
+    assertEquals(vn1.getN(), 1);
+
+    vn1.merge(null);
+    assertEquals(vn1.getN(), 1);
+
+    // update with a non-empty VN with a different value of d
+    final int d2 = d + 3;
+    final VectorNormalizer vn2 = new VectorNormalizer(d2);
+    input = new double[d2];
+    for (int i = 0; i < d2; ++i) { input[i] = 1.0 * i; }
+    vn2.update(input);
+    assertEquals(vn2.getN(), 1);
+
+    try {
+      vn1.merge(vn2);
+      fail();
+    } catch (IllegalArgumentException e) {
+      // expected
+      assertEquals(vn1.getN(), 1);
+    }
+  }
+
+  @Test
+  public void copyConstructorTest() {
+    final int d = 5;
+    final int n = 100;
+
+    final VectorNormalizer vn = new VectorNormalizer(d);
+    final ThreadLocalRandom rand = ThreadLocalRandom.current();
+    final double[] input = new double[d];
+    for (int i = 0; i < n; ++i) {
+      for (int j = 0; j < d; ++j) {
+        input[j] = rand.nextDouble();
+      }
+      vn.update(input);
+    }
+
+    final VectorNormalizer vnCopy = new VectorNormalizer(vn);
+
+    // we'll assume serialization works for this test and compare serialized images for equality
+    final byte[] origBytes = vn.toByteArray();
+    final byte[] copyBytes = vnCopy.toByteArray();
+    assertEquals(copyBytes, origBytes);
+  }
+
+  @Test
+  public void serializationTest() {
+    final int d = 7;
+    final int n = 10;
+
+    // null memory should return null
+    assertNull(VectorNormalizer.heapify(null));
+
+    final VectorNormalizer vn = new VectorNormalizer(d);
+
+    // check empty size
+    byte[] outBytes = vn.toByteArray();
+    assertEquals(outBytes.length, MatrixFamily.VECTORNORMALIZER.getMinPreLongs() * Long.BYTES);
+    assertEquals(outBytes.length, vn.getSerializedSizeBytes());
+
+    VectorNormalizer rebuilt = VectorNormalizer.heapify(Memory.wrap(outBytes));
+    assertTrue(rebuilt.isEmpty());
+    assertEquals(rebuilt.getD(), d);
+
+    // test with data added
+    final double[] input = new double[d];
+    final ThreadLocalRandom rand = ThreadLocalRandom.current();
+    for (int i = 0; i < n; ++i) {
+      for (int j = 0; j < d; ++j) {
+        input[j] = rand.nextGaussian();
+      }
+      vn.update(input);
+    }
+
+    outBytes = vn.toByteArray();
+    assertEquals(outBytes.length, vn.getSerializedSizeBytes());
+
+    rebuilt = VectorNormalizer.heapify(Memory.wrap(outBytes));
+    assertFalse(rebuilt.isEmpty());
+    assertEquals(vn.getD(), rebuilt.getD());
+    assertEquals(vn.getN(), rebuilt.getN());
+
+    // compare the original against the deserialized copy (not against itself)
+    final double[] originalMean = vn.getMean();
+    final double[] rebuiltMean = rebuilt.getMean();
+    final double[] originalSampleVar = vn.getSampleVariance();
+    final double[] rebuiltSampleVar = rebuilt.getSampleVariance();
+    final double[] originalPopVar = vn.getPopulationVariance();
+    final double[] rebuiltPopVar = rebuilt.getPopulationVariance();
+
+    for (int i = 0; i < d; ++i) {
+      // the raw doubles are serialized, so expect identical bits meaning exact equality
+      assertEquals(rebuiltMean[i], originalMean[i]);
+      assertEquals(rebuiltSampleVar[i], originalSampleVar[i]);
+      assertEquals(rebuiltPopVar[i], originalPopVar[i]);
+    }
+  }
+
+  @Test
+  public void corruptPreambleTest() {
+    // memory too small
+    byte[] bytes = new byte[3];
+    try {
+      VectorNormalizer.heapify(Memory.wrap(bytes));
+      fail();
+    } catch (IllegalArgumentException e) {
+      // expected
+    }
+
+    // memory smaller than preLongs
+    bytes = new byte[10];
+    bytes[PREAMBLE_LONGS_BYTE] = 2;
+    try {
+      VectorNormalizer.heapify(Memory.wrap(bytes));
+      fail();
+    } catch (IllegalArgumentException e) {
+      // expected
+    }
+
+    // invalid preLongs
+    final int preLongs = MatrixFamily.VECTORNORMALIZER.getMaxPreLongs() + 1;
+    bytes = new byte[preLongs * Longs.BYTES];
+    bytes[PREAMBLE_LONGS_BYTE] = (byte) preLongs;
+    try {
+      VectorNormalizer.heapify(Memory.wrap(bytes));
+      fail();
+    } catch (IllegalArgumentException e) {
+      // expected
+    }
+
+    final int d = 12;
+    final VectorNormalizer vn = new VectorNormalizer(d);
+
+    // wrong serialization version
+    bytes = vn.toByteArray();
+    bytes[SER_VER_BYTE] = ~SER_VER; // any bits that don't match
+    try {
+      VectorNormalizer.heapify(Memory.wrap(bytes));
+      fail();
+    } catch (IllegalArgumentException e) {
+      // expected
+    }
+
+    // wrong family id
+    bytes = vn.toByteArray();
+    bytes[FAMILY_BYTE] = (byte) MatrixFamily.FREQUENTDIRECTIONS.getID();
+    try {
+      VectorNormalizer.heapify(Memory.wrap(bytes));
+      fail();
+    } catch (IllegalArgumentException e) {
+      // expected
+    }
+
+    // invalid d
+    bytes = vn.toByteArray();
+    WritableMemory mem = WritableMemory.wrap(bytes);
+    mem.putInt(D_INT, -1);
+    try {
+      VectorNormalizer.heapify(mem);
+      fail();
+    } catch (IllegalArgumentException e) {
+      // expected
+    }
+  }
+
+  @Test
+  public void corruptEmptyHeapifyTest() {
+    final int d = 7;
+    final VectorNormalizer vn = new VectorNormalizer(d);
+    byte[] outBytes = vn.toByteArray();
+    WritableMemory mem = WritableMemory.wrap(outBytes);
+
+    // clear empty flag
+    mem.putByte(FLAGS_BYTE, (byte) 0);
+    try {
+      VectorNormalizer.heapify(mem);
+      fail();
+    } catch (IllegalArgumentException e) {
+      // expected
+    }
+  }
+
+  @Test
+  public void corruptNonEmptyHeapifyTest() {
+    final int d = 1;
+    final int n = 100;
+
+    final VectorNormalizer vn = new VectorNormalizer(d);
+    final ThreadLocalRandom rand = ThreadLocalRandom.current();
+    final double[] input = new double[d];
+    for (int i = 0; i < n; ++i) {
+      for (int j = 0; j < d; ++j) {
+        input[j] = rand.nextDouble();
+      }
+      vn.update(input);
+    }
+    assertFalse(vn.isEmpty());
+
+    // force-set empty flag
+    byte[] bytes = vn.toByteArray();
+    WritableMemory mem = WritableMemory.wrap(bytes);
+    mem.putByte(FLAGS_BYTE, (byte) EMPTY_FLAG_MASK);
+    try {
+      VectorNormalizer.heapify(mem);
+      fail();
+    } catch (IllegalArgumentException e) {
+      // expected
+    }
+
+    // invalid n
+    bytes = vn.toByteArray();
+    mem = WritableMemory.wrap(bytes);
+    mem.putLong(N_LONG, -100);
+    try {
+      VectorNormalizer.heapify(mem);
+      fail();
+    } catch (IllegalArgumentException e) {
+      // expected
+    }
+
+    // capacity too small for vectors
+    bytes = vn.toByteArray();
+    mem = WritableMemory.allocate(bytes.length - 1);
+    mem.putByteArray(0, bytes, 0, bytes.length - 1);
+    try {
+      VectorNormalizer.heapify(mem);
+      fail();
+    } catch (IllegalArgumentException e) {
+      // expected
+    }
+  }
+
+}

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@datasketches.apache.org
For additional commands, e-mail: commits-help@datasketches.apache.org