You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@datasketches.apache.org by jm...@apache.org on 2021/05/26 21:58:20 UTC
[datasketches-vector] 01/02: Vector normalization to support ridge
regression
This is an automated email from the ASF dual-hosted git repository.
jmalkin pushed a commit to branch fdrr
in repository https://gitbox.apache.org/repos/asf/datasketches-vector.git
commit 718b647534d00fec47891cacd56222769204865e
Author: Jon Malkin <jm...@users.noreply.github.com>
AuthorDate: Wed May 26 14:55:44 2021 -0700
Vector normalization to support ridge regression
---
.../apache/datasketches/vector/MatrixFamily.java | 16 +-
.../vector/regression/VectorNormalizer.java | 402 ++++++++++++++++++
.../vector/regression/RidgeRegressionTest.java | 2 +
.../vector/regression/VectorNormalizerTest.java | 454 +++++++++++++++++++++
4 files changed, 869 insertions(+), 5 deletions(-)
diff --git a/src/main/java/org/apache/datasketches/vector/MatrixFamily.java b/src/main/java/org/apache/datasketches/vector/MatrixFamily.java
index 9c1d141..0d57d75 100644
--- a/src/main/java/org/apache/datasketches/vector/MatrixFamily.java
+++ b/src/main/java/org/apache/datasketches/vector/MatrixFamily.java
@@ -44,15 +44,21 @@ public enum MatrixFamily {
/**
* Select Frequent Directions Family
*/
- FREQUENTDIRECTIONS(129, "FrequentDirections", 2, 4);
+ FREQUENTDIRECTIONS(129, "FrequentDirections", 2, 4),
+
+ /**
+ * Aggregation for vector means/variances.
+ */
+ VECTORNORMALIZER(130, "VectorNormalizer", 1, 2)
+ ;
private static final Map<Integer, MatrixFamily> lookupID = new HashMap<>();
private static final Map<String, MatrixFamily> lookupFamName = new HashMap<>();
- private int id_;
- private String famName_;
- private int minPreLongs_;
- private int maxPreLongs_;
+ private final int id_;
+ private final String famName_;
+ private final int minPreLongs_;
+ private final int maxPreLongs_;
static {
for (MatrixFamily f : values()) {
diff --git a/src/main/java/org/apache/datasketches/vector/regression/VectorNormalizer.java b/src/main/java/org/apache/datasketches/vector/regression/VectorNormalizer.java
new file mode 100644
index 0000000..49a5eec
--- /dev/null
+++ b/src/main/java/org/apache/datasketches/vector/regression/VectorNormalizer.java
@@ -0,0 +1,402 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.datasketches.vector.regression;
+
+import static org.apache.datasketches.memory.UnsafeUtil.unsafe;
+
+import org.apache.datasketches.memory.Memory;
+import org.apache.datasketches.memory.WritableMemory;
+import org.apache.datasketches.vector.MatrixFamily;
+
+/**
+ * Computes the mean and variance of each of the d dimensions of a stream of input vectors using
+ * Welford's online algorithm, as described in
+ * https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
+ *
+ * <p>For serialized images, multi-byte integers ({@code int} and {@code long}) are stored in native byte
+ * order. All {@code byte} values are treated as unsigned.</p>
+ *
+ * <p>An empty object requires 8 bytes. A non-empty object requires 16 bytes of preamble, followed by
+ * d doubles holding the means and then d doubles holding the sums of squared differences from the
+ * mean (M2).</p>
+ *
+ * <pre>
+ * Long || Start Byte Adr:
+ * Adr:
+ *      ||       0        |    1   |    2   |    3   |    4   |    5   |    6   |    7   |
+ *  0   || Preamble_Longs | SerVer | FamID  | Flags  |--------Vector Dim. (d)------------|
+ *
+ *      ||       8        |    9   |   10   |   11   |   12   |   13   |   14   |   15   |
+ *  1   ||----------------------Num. Vectors Processed (n)-------------------------------|
+ * </pre>
+ *
+ * @author Jon Malkin
+ */
+public class VectorNormalizer {
+ private final int d_;
+ private final double[] mean_;
+ private final double[] M2_;
+ private long n_;
+
+ // Preamble byte Addresses
+ static final int PREAMBLE_LONGS_BYTE = 0;
+ static final int SER_VER_BYTE = 1;
+ static final int FAMILY_BYTE = 2;
+ static final int FLAGS_BYTE = 3;
+ static final int D_INT = 4;
+ static final int N_LONG = 8;
+
+ // flag bit masks
+ static final int EMPTY_FLAG_MASK = 4;
+
+ // Other constants
+ static final int SER_VER = 1;
+
+
+  /**
+   * Constructs an empty VectorNormalizer that tracks the given number of dimensions.
+   * @param d The number of dimensions of each input vector; must be at least 1
+   * @throws IllegalArgumentException if d is less than 1
+   */
+  public VectorNormalizer(final int d) {
+    if (d < 1) {
+      throw new IllegalArgumentException("d cannot be < 1. Found: " + d);
+    }
+
+    d_ = d;
+    n_ = 0;
+    mean_ = new double[d];
+    M2_ = new double[d];
+  }
+
+  /**
+   * Copy constructor. The new object receives its own clones of the mean and M2 arrays, so later
+   * updates to either object do not affect the other.
+   * @param other The VectorNormalizer to copy
+   */
+  public VectorNormalizer(final VectorNormalizer other) {
+    this(other.d_, other.n_, other.mean_.clone(), other.M2_.clone());
+  }
+
+  /**
+   * Wraps already-built state without validation or copying; the arrays are stored as given.
+   * @param d The number of dimensions
+   * @param n The number of vectors already processed
+   * @param mean The per-dimension means
+   * @param M2 The per-dimension sums of squared differences from the mean
+   */
+  private VectorNormalizer(final int d, final long n, final double[] mean, final double[] M2) {
+    mean_ = mean;
+    M2_ = M2;
+    n_ = n;
+    d_ = d;
+  }
+
+  /**
+   * Reconstructs a VectorNormalizer on the heap from its serialized image.
+   * @param srcMem Memory holding the serialized image of a VectorNormalizer object
+   * @return The deserialized VectorNormalizer, or null if srcMem is null
+   * @throws IllegalArgumentException if the image fails any of the validity checks
+   */
+  static VectorNormalizer heapify(final Memory srcMem) {
+    if (srcMem == null) {
+      return null;
+    }
+
+    // preamble validation: size first, then version, then family
+    final int preLongs = getAndCheckPreLongs(srcMem);
+    final int minPreLongs = MatrixFamily.VECTORNORMALIZER.getMinPreLongs();
+    final int maxPreLongs = MatrixFamily.VECTORNORMALIZER.getMaxPreLongs();
+    if (preLongs < minPreLongs || preLongs > maxPreLongs) {
+      throw new IllegalArgumentException("Possible corruption: Invalid number of preamble longs: " + preLongs);
+    }
+
+    final int serVer = extractSerVer(srcMem);
+    if (serVer != SER_VER) {
+      throw new IllegalArgumentException("Invalid serialization version: " + serVer);
+    }
+
+    final int family = extractFamilyID(srcMem);
+    if (family != MatrixFamily.VECTORNORMALIZER.getID()) {
+      throw new IllegalArgumentException("Possible corruption: Family id (" + family + ") "
+          + "is not a VectorNormalization image");
+    }
+
+    final int flags = extractFlags(srcMem);
+    final int d = extractD(srcMem);
+    if (d < 1) {
+      throw new IllegalArgumentException("Possible corruption: d cannot be < 1. Found: " + d);
+    }
+
+    // the empty flag and the preamble size must agree with each other
+    final boolean minimalPreamble = (preLongs == minPreLongs);
+    if ((flags & EMPTY_FLAG_MASK) != 0) {
+      if (!minimalPreamble) {
+        throw new IllegalArgumentException("Possible corruption: Empty flag set but header indicates image has data.");
+      }
+      return new VectorNormalizer(d);
+    }
+    if (minimalPreamble) {
+      throw new IllegalArgumentException("Possible corruption: Non-empty image too small to contain serialized data");
+    }
+
+    final long n = extractN(srcMem);
+    if (n <= 0) {
+      throw new IllegalArgumentException("Possible corruption: n must be positive for a non-empty sketch. Found: " + n);
+    }
+
+    // the data section is d means followed by d M2 values, starting right after the preamble
+    final long meanOffset = (long) preLongs * Long.BYTES;
+    final long m2Offset = meanOffset + ((long) d * Double.BYTES);
+    final long bytesNeeded = meanOffset + (2L * d * Double.BYTES);
+    final long capacity = srcMem.getCapacity();
+    if (capacity < bytesNeeded) {
+      throw new IllegalArgumentException(
+          "Possible Corruption: Size of Memory not large enough: Size: " + capacity
+          + ", Required: " + bytesNeeded);
+    }
+
+    final double[] mean = new double[d];
+    srcMem.getDoubleArray(meanOffset, mean, 0, d);
+
+    final double[] M2 = new double[d];
+    srcMem.getDoubleArray(m2Offset, M2, 0, d);
+
+    return new VectorNormalizer(d, n, mean, M2);
+  }
+
+  /**
+   * Serializes this object. An empty object is written as the minimum preamble only; otherwise the
+   * full preamble is followed by the mean array and then the M2 array.
+   * @return A <tt>byte[]</tt> containing the serialized image of this object.
+   */
+  public byte[] toByteArray() {
+    final boolean empty = isEmpty();
+    final int preLongs = empty
+        ? MatrixFamily.VECTORNORMALIZER.getMinPreLongs()
+        : MatrixFamily.VECTORNORMALIZER.getMaxPreLongs();
+
+    final byte[] outArr = new byte[getSerializedSizeBytes()];
+    final WritableMemory memOut = WritableMemory.wrap(outArr);
+    final Object memObj = memOut.getArray();
+    final long memAddr = memOut.getCumulativeOffset(0L);
+
+    // first preamble long: always present
+    insertPreLongs(memObj, memAddr, preLongs);
+    insertSerVer(memObj, memAddr, SER_VER);
+    insertFamilyID(memObj, memAddr, MatrixFamily.VECTORNORMALIZER.getID());
+    insertFlags(memObj, memAddr, (empty ? EMPTY_FLAG_MASK : 0));
+    insertD(memObj, memAddr, d_);
+
+    if (empty) {
+      return outArr;
+    }
+
+    // second preamble long, then the two data arrays back to back
+    insertN(memObj, memAddr, n_);
+    final long meanOffset = (long) preLongs * Long.BYTES;
+    final long m2Offset = meanOffset + ((long) d_ * Double.BYTES);
+    memOut.putDoubleArray(meanOffset, mean_, 0, d_);
+    memOut.putDoubleArray(m2Offset, M2_, 0, d_);
+
+    return outArr;
+  }
+
+ /**
+ * Reports whether any input vectors have been processed, i.e. whether n is zero.
+ * @return True if the object has no data, otherwise false.
+ */
+ public boolean isEmpty() {
+ return n_ == 0;
+ }
+
+ /**
+ * Returns the number of dimensions configured for this object. The value is held internally as
+ * an int and is widened to long on return.
+ * @return The number of dimensions
+ */
+ public long getD() {
+ return d_;
+ }
+
+ /**
+ * Returns the number of input vectors processed by this object, including vectors absorbed
+ * through merges.
+ * @return The number of input vectors processed
+ */
+ public long getN() {
+ return n_;
+ }
+
+  /**
+   * Returns a copy of the per-dimension means. If no vectors have been processed, every entry of
+   * the returned array is NaN.
+   * @return The array of means
+   */
+  public double[] getMean() {
+    if (n_ != 0) {
+      return mean_.clone();
+    }
+
+    final double[] nans = new double[d_];
+    java.util.Arrays.fill(nans, Double.NaN);
+    return nans;
+  }
+
+  /**
+   * Returns the sample (unbiased) variance of each dimension, computed as M2 / (n - 1).
+   * Returns an array of NaN if N = 0 and an array of zeros if N = 1.
+   * @return The sample variance array represented in this object
+   */
+  public double[] getSampleVariance() {
+    return computeVariance(n_ - 1);
+  }
+
+  /**
+   * Returns the population variance of each dimension, computed as M2 / n.
+   * Returns an array of NaN if N = 0 and an array of zeros if N = 1.
+   * @return The population variance array represented in this object
+   */
+  public double[] getPopulationVariance() {
+    return computeVariance(n_);
+  }
+
+  /**
+   * Divides each accumulated sum of squared differences by the given divisor, handling the
+   * N = 0 (all NaN) and N = 1 (all zeros) cases shared by both variance flavors.
+   * @param divisor n - 1 for the sample variance, n for the population variance
+   * @return A new array of length d holding the variances
+   */
+  private double[] computeVariance(final long divisor) {
+    final double[] result = new double[d_];
+    if (n_ == 0) {
+      for (int i = 0; i < d_; ++i) {
+        result[i] = Double.NaN;
+      }
+    } else if (n_ != 1) {
+      for (int i = 0; i < d_; ++i) {
+        result[i] = M2_[i] / divisor;
+      }
+    }
+    // n_ == 1: a single vector has no spread, so the zero-initialized array is returned as is
+    return result;
+  }
+
+  /**
+   * Folds one input vector into the running per-dimension mean and M2 using Welford's update.
+   * A null input is ignored.
+   * @param x The input vector, which must have exactly d entries
+   * @throws IllegalArgumentException if the length of x differs from d
+   */
+  public void update(double[] x) {
+    if (x == null) {
+      return;
+    }
+
+    if (x.length != d_) {
+      throw new IllegalArgumentException("Input vector length must be " + d_ + ". Found: " + x.length );
+    }
+
+    // the count is bumped first so that the mean update divides by the new n
+    ++n_;
+    for (int dim = 0; dim < d_; ++dim) {
+      final double value = x[dim];
+      final double deltaOld = value - mean_[dim]; // distance from the mean before this update
+      mean_[dim] += deltaOld / n_;
+      final double deltaNew = value - mean_[dim]; // distance from the mean after this update
+      M2_[dim] += deltaOld * deltaNew;
+    }
+  }
+
+  /**
+   * Merges the statistics held by another VectorNormalizer into this one by combining the
+   * per-dimension counts, means and M2 values of the two objects. A null or empty input leaves
+   * this object unchanged. The other object is not modified.
+   * @param other The VectorNormalizer to merge in, which must have the same number of dimensions
+   * @throws IllegalArgumentException if other has a different number of dimensions
+   */
+  public void merge(VectorNormalizer other) {
+    if (other == null) {
+      return;
+    }
+
+    if (other.d_ != d_) {
+      throw new IllegalArgumentException("Input VectorNormalizer must have d= " + d_ + ". Found: " + other.d_);
+    }
+
+    // nothing to add; this also prevents 0/0 = NaN in the formulas below when both objects are empty
+    if (other.n_ == 0) {
+      return;
+    }
+
+    // this object is empty: adopt the other's state exactly, avoiding rounding in the weighted mean
+    if (n_ == 0) {
+      System.arraycopy(other.mean_, 0, mean_, 0, d_);
+      System.arraycopy(other.M2_, 0, M2_, 0, d_);
+      n_ = other.n_;
+      return;
+    }
+
+    final long otherN = other.n_; // read once, since other may be this very object
+    final long combinedN = n_ + otherN;
+    // n_A * n_B / (n_A + n_B); the product is formed in double so it cannot overflow a long
+    final double varCountScalar = ((double) n_ * otherN) / combinedN;
+    for (int i = 0; i < d_; ++i) {
+      final double meanDiff = other.mean_[i] - mean_[i];
+      mean_[i] = ((n_ * mean_[i]) + (otherN * other.mean_[i])) / combinedN;
+      M2_[i] += other.M2_[i] + meanDiff * meanDiff * varCountScalar;
+    }
+    n_ = combinedN;
+  }
+
+  /**
+   * Returns the number of bytes in the serialized image of this object in its current state: the
+   * minimum preamble alone when empty, otherwise the full preamble plus two arrays of d doubles.
+   * @return The serialized size in bytes
+   */
+  public int getSerializedSizeBytes() {
+    final boolean empty = (n_ == 0);
+    final int preLongs = empty
+        ? MatrixFamily.VECTORNORMALIZER.getMinPreLongs()
+        : MatrixFamily.VECTORNORMALIZER.getMaxPreLongs();
+    final int preambleBytes = preLongs * Long.BYTES;
+    return empty ? preambleBytes : preambleBytes + (2 * d_ * Double.BYTES);
+  }
+
+  // Extraction methods
+  //
+  // Single-byte fields are read with getByte so that the result does not depend on the platform's
+  // byte order and no bytes beyond the field itself are touched.
+
+  /** Reads the number of preamble longs as an unsigned byte. */
+  static int extractPreLongs(final Memory mem) {
+    return mem.getByte(PREAMBLE_LONGS_BYTE) & 0xFF;
+  }
+
+  /** Reads the serialization version as an unsigned byte. */
+  static int extractSerVer(final Memory mem) {
+    return mem.getByte(SER_VER_BYTE) & 0xFF;
+  }
+
+  /** Reads the family ID as an unsigned byte. */
+  static int extractFamilyID(final Memory mem) {
+    return mem.getByte(FAMILY_BYTE) & 0xFF;
+  }
+
+  /** Reads the flags byte as an unsigned value. */
+  static int extractFlags(final Memory mem) {
+    return mem.getByte(FLAGS_BYTE) & 0xFF;
+  }
+
+  /** Reads the number of dimensions, a 4-byte int. */
+  static int extractD(final Memory mem) {
+    return mem.getInt(D_INT);
+  }
+
+  /** Reads the number of vectors processed, an 8-byte long in the second preamble long. */
+  static long extractN(final Memory mem) {
+    return mem.getLong(N_LONG);
+  }
+
+
+ // Insertion methods
+ // Each helper writes one preamble field at its fixed byte offset from memAddr within memObj.
+ // The int arguments of the single-byte fields are truncated to their low 8 bits by the cast.
+
+ // Writes the number of preamble longs into its byte.
+ private void insertPreLongs(final Object memObj, final long memAddr, final int preLongs) {
+ unsafe.putByte(memObj, memAddr + PREAMBLE_LONGS_BYTE, (byte) preLongs);
+ }
+
+ // Writes the serialization version into its byte.
+ private void insertSerVer(final Object memObj, final long memAddr, final int serVer) {
+ unsafe.putByte(memObj, memAddr + SER_VER_BYTE, (byte) serVer);
+ }
+
+ // Writes the family ID into its byte.
+ private void insertFamilyID(final Object memObj, final long memAddr, final int matrixFamId) {
+ unsafe.putByte(memObj, memAddr + FAMILY_BYTE, (byte) matrixFamId);
+ }
+
+ // Writes the flags into their byte.
+ private void insertFlags(final Object memObj, final long memAddr, final int flags) {
+ unsafe.putByte(memObj, memAddr + FLAGS_BYTE, (byte) flags);
+ }
+
+ // Writes the number of dimensions as a 4-byte int.
+ private void insertD(final Object memObj, final long memAddr, final int d) {
+ unsafe.putInt(memObj, memAddr + D_INT, d);
+ }
+
+ // Writes the number of vectors processed as an 8-byte long.
+ private void insertN(final Object memObj, final long memAddr, final long n) {
+ unsafe.putLong(memObj, memAddr + N_LONG, n);
+ }
+
+  /**
+   * Checks Memory for capacity to hold the preamble and returns the extracted preLongs.
+   * @param mem the given Memory
+   * @return the extracted prelongs value.
+   * @throws IllegalArgumentException if mem is smaller than 8 bytes or smaller than the preamble
+   * size that its own preLongs field declares
+   */
+  private static int getAndCheckPreLongs(final Memory mem) {
+    final long cap = mem.getCapacity();
+    if (cap < Long.BYTES) { throwNotBigEnough(cap, Long.BYTES); }
+    final int preLongs = extractPreLongs(mem);
+    // preLongs counts 8-byte longs and is read from a single byte, so the product fits in an int
+    final int required = Math.max(preLongs * Long.BYTES, Long.BYTES);
+    if (cap < required) { throwNotBigEnough(cap, required); }
+    return preLongs;
+  }
+
+ private static void throwNotBigEnough(final long cap, final int required) {
+ throw new IllegalArgumentException(
+ "Possible Corruption: Size of byte array or Memory not large enough: Size: " + cap
+ + ", Required: " + required);
+ }
+}
diff --git a/src/test/java/org/apache/datasketches/vector/regression/RidgeRegressionTest.java b/src/test/java/org/apache/datasketches/vector/regression/RidgeRegressionTest.java
new file mode 100644
index 0000000..94da914
--- /dev/null
+++ b/src/test/java/org/apache/datasketches/vector/regression/RidgeRegressionTest.java
@@ -0,0 +1,2 @@
+package org.apache.datasketches.vector.regression;public class RidgeRegressionTest {
+ // Placeholder: this class declares no test methods.
+}
diff --git a/src/test/java/org/apache/datasketches/vector/regression/VectorNormalizerTest.java b/src/test/java/org/apache/datasketches/vector/regression/VectorNormalizerTest.java
new file mode 100644
index 0000000..e6a7807
--- /dev/null
+++ b/src/test/java/org/apache/datasketches/vector/regression/VectorNormalizerTest.java
@@ -0,0 +1,454 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.datasketches.vector.regression;
+
+import static org.apache.datasketches.vector.regression.VectorNormalizer.D_INT;
+import static org.apache.datasketches.vector.regression.VectorNormalizer.EMPTY_FLAG_MASK;
+import static org.apache.datasketches.vector.regression.VectorNormalizer.FAMILY_BYTE;
+import static org.apache.datasketches.vector.regression.VectorNormalizer.FLAGS_BYTE;
+import static org.apache.datasketches.vector.regression.VectorNormalizer.N_LONG;
+import static org.apache.datasketches.vector.regression.VectorNormalizer.PREAMBLE_LONGS_BYTE;
+import static org.apache.datasketches.vector.regression.VectorNormalizer.SER_VER;
+import static org.apache.datasketches.vector.regression.VectorNormalizer.SER_VER_BYTE;
+import static org.testng.Assert.assertEquals;
+import static org.testng.Assert.assertFalse;
+import static org.testng.Assert.assertNotNull;
+import static org.testng.Assert.assertNull;
+import static org.testng.Assert.assertTrue;
+import static org.testng.Assert.fail;
+
+import javax.imageio.plugins.jpeg.JPEGImageWriteParam;
+import java.util.concurrent.ThreadLocalRandom;
+
+import org.apache.datasketches.memory.Memory;
+import org.apache.datasketches.memory.WritableMemory;
+import org.apache.datasketches.vector.MatrixFamily;
+import org.testng.annotations.Test;
+
+
+import com.google.common.primitives.Longs;
+
+public class VectorNormalizerTest {
+
+  @Test
+  public void instantiationTest() {
+    final int d = 5;
+    final VectorNormalizer vn = new VectorNormalizer(d);
+    assertNotNull(vn);
+    assertTrue(vn.isEmpty());
+    assertEquals(vn.getN(), 0);
+    assertEquals(vn.getD(), d);
+
+    final double[] mean = vn.getMean();
+    final double[] sampleVar = vn.getSampleVariance();
+    final double[] popVar = vn.getPopulationVariance();
+    assertNotNull(mean);
+    assertNotNull(sampleVar);
+    assertNotNull(popVar);
+
+    // an empty normalizer reports NaN for every statistic in every dimension
+    for (final double[] stats : new double[][] {mean, sampleVar, popVar}) {
+      for (int i = 0; i < d; ++i) {
+        assertTrue(Double.isNaN(stats[i]));
+      }
+    }
+
+    // a dimension below 1 must be rejected
+    try {
+      new VectorNormalizer(0);
+      fail();
+    } catch (final IllegalArgumentException expected) {
+      // expected
+    }
+  }
+
+  @Test
+  public void singleUpdateTest() {
+    final double[] input = {-1, 0, 0.5};
+    final int d = 3;
+
+    final VectorNormalizer vn = new VectorNormalizer(d);
+    vn.update(input);
+    assertFalse(vn.isEmpty());
+    assertEquals(vn.getN(), 1);
+
+    final double[] mean = vn.getMean();
+    final double[] sampleVar = vn.getSampleVariance();
+    final double[] popVar = vn.getPopulationVariance();
+    assertNotNull(mean);
+    assertNotNull(sampleVar);
+    assertNotNull(popVar);
+
+    // after one vector the mean is that vector and both variances are exactly zero
+    for (int dim = 0; dim < d; ++dim) {
+      assertEquals(mean[dim], input[dim]);
+      assertEquals(sampleVar[dim], 0.0);
+      assertEquals(popVar[dim], 0.0);
+    }
+  }
+
+  @Test
+  public void multipleUpdateTest() {
+    final int n = 100000;
+    final int d = 3;
+    final double tol = 0.01;
+
+    final VectorNormalizer vn = new VectorNormalizer(d);
+
+    final ThreadLocalRandom rand = ThreadLocalRandom.current();
+    final double[] input = new double[d];
+    for (int i = 0; i < n; ++i) {
+      input[0] = rand.nextGaussian(); // mean = 0.0, var = 1.0
+      input[1] = rand.nextDouble() * 2.0; // mean = 1.0, var = (2-0)^2/12 = 1/3
+      input[2] = rand.nextDouble() - 0.5; // mean = 0.0, var = (1-0)^2/12
+      vn.update(input);
+    }
+    assertFalse(vn.isEmpty());
+
+    final double[] mean = vn.getMean();
+    assertNotNull(mean);
+    assertEquals(mean[0], 0.0, tol);
+    assertEquals(mean[1], 1.0, tol);
+    assertEquals(mean[2], 0.0, tol);
+
+    // n is large enough that both variance flavors fall within the same tolerance
+    final double[] sampleVar = vn.getSampleVariance();
+    assertNotNull(sampleVar);
+    assertEquals(sampleVar[0], 1.0, tol);
+    assertEquals(sampleVar[1], 1.0 / 3.0, tol);
+    assertEquals(sampleVar[2], 1.0 / 12.0, tol);
+
+    final double[] popVar = vn.getPopulationVariance();
+    assertNotNull(popVar);
+    assertEquals(popVar[0], 1.0, tol);
+    assertEquals(popVar[1], 1.0 / 3.0, tol);
+    assertEquals(popVar[2], 1.0 / 12.0, tol);
+
+    // Bessel's correction: the sample variance divides by (n - 1) and the population variance by n,
+    // so for non-constant data the sample variance is strictly larger; n is still small enough
+    // for that difference to be visible in doubles
+    for (int i = 0; i < d; ++i) {
+      assertTrue(sampleVar[i] > popVar[i]);
+    }
+  }
+
+  @Test
+  public void mergeTest() {
+    final int n = 1000000;
+    final int d = 2;
+    final double tol = 0.01;
+    final VectorNormalizer upper = new VectorNormalizer(d);
+    final VectorNormalizer lower = new VectorNormalizer(d);
+
+    final ThreadLocalRandom rand = ThreadLocalRandom.current();
+
+    // expectations after merging:
+    // dimension 0: both inputs are standard Gaussians, so mean = 0.0 and var = 1.0
+    // dimension 1: U[2,4) and U[0,2) in equal parts give U[0,4), so mean = 2.0 and var = 4^2/12 = 4/3
+    final double[] vec = new double[d];
+    for (int i = 0; i < n; ++i) {
+      vec[0] = rand.nextGaussian();
+      vec[1] = (rand.nextDouble() * 2.0) + 2.0;
+      upper.update(vec);
+
+      vec[0] = rand.nextGaussian();
+      vec[1] = rand.nextDouble() * 2.0;
+      lower.update(vec);
+    }
+
+    upper.merge(lower);
+    assertEquals(upper.getN(), 2 * n);
+
+    final double[] mean = upper.getMean();
+    assertEquals(mean[0], 0.0, tol);
+    assertEquals(mean[1], 2.0, tol);
+
+    // with this many samples both variance flavors agree to within the tolerance
+    final double[] sampleVar = upper.getSampleVariance();
+    final double[] popVar = upper.getPopulationVariance();
+    assertEquals(sampleVar[0], 1.0, tol);
+    assertEquals(sampleVar[1], 4.0 / 3.0, tol);
+    assertEquals(popVar[0], 1.0, tol);
+    assertEquals(popVar[1], 4.0 / 3.0, tol);
+  }
+
+  @Test
+  public void invalidUpdateSizeTest() {
+    final int d = 5;
+    final VectorNormalizer vn = new VectorNormalizer(d);
+
+    final double[] valid = new double[d];
+    for (int i = 0; i < d; ++i) {
+      valid[i] = 1.0 * i;
+    }
+    vn.update(valid);
+    assertEquals(vn.getN(), 1);
+
+    // a null vector is silently ignored
+    vn.update(null);
+    assertEquals(vn.getN(), 1);
+
+    // a vector of the wrong length is rejected and must not change the count
+    final double[] tooShort = {1.0};
+    try {
+      vn.update(tooShort);
+      fail();
+    } catch (final IllegalArgumentException expected) {
+      assertEquals(vn.getN(), 1);
+    }
+  }
+
+  @Test
+  public void invalidMergeSizeTest() {
+    final int d = 3;
+    final VectorNormalizer target = new VectorNormalizer(d);
+
+    final double[] vec = new double[d];
+    for (int i = 0; i < d; ++i) {
+      vec[i] = 1.0 * i;
+    }
+    target.update(vec);
+    assertEquals(target.getN(), 1);
+
+    // merging null is a no-op
+    target.merge(null);
+    assertEquals(target.getN(), 1);
+
+    // build a non-empty normalizer with a different number of dimensions
+    final int otherD = d + 3;
+    final VectorNormalizer mismatched = new VectorNormalizer(otherD);
+    final double[] otherVec = new double[otherD];
+    for (int i = 0; i < otherD; ++i) {
+      otherVec[i] = 1.0 * i;
+    }
+    mismatched.update(otherVec);
+    assertEquals(mismatched.getN(), 1);
+
+    // the mismatched merge is rejected and must leave the target untouched
+    try {
+      target.merge(mismatched);
+      fail();
+    } catch (final IllegalArgumentException expected) {
+      assertEquals(target.getN(), 1);
+    }
+  }
+
+  @Test
+  public void copyConstructorTest() {
+    final int d = 5;
+    final int n = 100;
+
+    final ThreadLocalRandom rand = ThreadLocalRandom.current();
+    final VectorNormalizer original = new VectorNormalizer(d);
+    final double[] vec = new double[d];
+    for (int i = 0; i < n; ++i) {
+      for (int j = 0; j < d; ++j) {
+        vec[j] = rand.nextDouble();
+      }
+      original.update(vec);
+    }
+
+    final VectorNormalizer copy = new VectorNormalizer(original);
+
+    // serialization is assumed to work here: equal serialized images imply equal state
+    final byte[] originalImage = original.toByteArray();
+    final byte[] copyImage = copy.toByteArray();
+    assertEquals(copyImage, originalImage);
+  }
+
+  @Test
+  public void serializationTest() {
+    final int d = 7;
+    final int n = 10;
+
+    // null memory should return null
+    assertNull(VectorNormalizer.heapify(null));
+
+    final VectorNormalizer vn = new VectorNormalizer(d);
+
+    // check empty size
+    byte[] outBytes = vn.toByteArray();
+    assertEquals(outBytes.length, MatrixFamily.VECTORNORMALIZER.getMinPreLongs() * Long.BYTES);
+    assertEquals(outBytes.length, vn.getSerializedSizeBytes());
+
+    VectorNormalizer rebuilt = VectorNormalizer.heapify(Memory.wrap(outBytes));
+    assertTrue(rebuilt.isEmpty());
+    assertEquals(rebuilt.getD(), vn.getD());
+
+    // test with data added
+    final double[] input = new double[d];
+    final ThreadLocalRandom rand = ThreadLocalRandom.current();
+    for (int i = 0; i < n; ++i) {
+      for (int j = 0; j < d; ++j) {
+        input[j] = rand.nextGaussian();
+      }
+      vn.update(input);
+    }
+
+    outBytes = vn.toByteArray();
+    assertEquals(outBytes.length, vn.getSerializedSizeBytes());
+
+    rebuilt = VectorNormalizer.heapify(Memory.wrap(outBytes));
+    assertFalse(rebuilt.isEmpty());
+    assertEquals(rebuilt.getD(), vn.getD());
+    assertEquals(rebuilt.getN(), vn.getN());
+
+    // the statistics must be read from the rebuilt object, otherwise the round trip is not verified
+    final double[] originalMean = vn.getMean();
+    final double[] rebuiltMean = rebuilt.getMean();
+    final double[] originalVar = vn.getSampleVariance();
+    final double[] rebuiltVar = rebuilt.getSampleVariance();
+
+    for (int i = 0; i < d; ++i) {
+      // expecting identical bits meaning exact equality
+      assertEquals(rebuiltMean[i], originalMean[i]);
+      assertEquals(rebuiltVar[i], originalVar[i]);
+    }
+  }
+
+  @Test
+  public void corruptPreambleTest() {
+    // memory too small to hold even the first preamble long
+    byte[] bytes = new byte[3];
+    try {
+      VectorNormalizer.heapify(Memory.wrap(bytes));
+      fail();
+    } catch (IllegalArgumentException e) {
+      // expected
+    }
+
+    // memory smaller than the preamble size declared by preLongs
+    bytes = new byte[10];
+    bytes[PREAMBLE_LONGS_BYTE] = 2;
+    try {
+      VectorNormalizer.heapify(Memory.wrap(bytes));
+      fail();
+    } catch (IllegalArgumentException e) {
+      // expected
+    }
+
+    // invalid preLongs: one more than the family allows, with enough memory to back it
+    final int preLongs = MatrixFamily.VECTORNORMALIZER.getMaxPreLongs() + 1;
+    bytes = new byte[preLongs * Long.BYTES];
+    bytes[PREAMBLE_LONGS_BYTE] = (byte) preLongs;
+    try {
+      VectorNormalizer.heapify(Memory.wrap(bytes));
+      fail();
+    } catch (IllegalArgumentException e) {
+      // expected
+    }
+
+    final int d = 12;
+    final VectorNormalizer vn = new VectorNormalizer(d);
+
+    // wrong serialization version
+    bytes = vn.toByteArray();
+    bytes[SER_VER_BYTE] = ~SER_VER; // any bits that don't match
+    try {
+      VectorNormalizer.heapify(Memory.wrap(bytes));
+      fail();
+    } catch (IllegalArgumentException e) {
+      // expected
+    }
+
+    // wrong family id
+    bytes = vn.toByteArray();
+    bytes[FAMILY_BYTE] = (byte) MatrixFamily.FREQUENTDIRECTIONS.getID();
+    try {
+      VectorNormalizer.heapify(Memory.wrap(bytes));
+      fail();
+    } catch (IllegalArgumentException e) {
+      // expected
+    }
+
+    // invalid d
+    bytes = vn.toByteArray();
+    WritableMemory mem = WritableMemory.wrap(bytes);
+    mem.putInt(D_INT, -1);
+    try {
+      VectorNormalizer.heapify(mem);
+      fail();
+    } catch (IllegalArgumentException e) {
+      // expected
+    }
+  }
+
+  @Test
+  public void corruptEmptyHeapifyTest() {
+    final int d = 7;
+    final byte[] image = new VectorNormalizer(d).toByteArray();
+    final WritableMemory mem = WritableMemory.wrap(image);
+
+    // clearing the empty flag on an empty image makes it claim data that it cannot hold
+    mem.putByte(FLAGS_BYTE, (byte) 0);
+    try {
+      VectorNormalizer.heapify(mem);
+      fail();
+    } catch (final IllegalArgumentException expected) {
+      // expected
+    }
+  }
+
+  @Test
+  public void corruptNonEmptyHeapifyTest() {
+    final int d = 1;
+    final int n = 100;
+
+    final ThreadLocalRandom rand = ThreadLocalRandom.current();
+    final VectorNormalizer vn = new VectorNormalizer(d);
+    final double[] vec = new double[d];
+    for (int i = 0; i < n; ++i) {
+      for (int j = 0; j < d; ++j) {
+        vec[j] = rand.nextDouble();
+      }
+      vn.update(vec);
+    }
+    assertFalse(vn.isEmpty());
+
+    // empty flag forced on in an image that carries data
+    byte[] image = vn.toByteArray();
+    WritableMemory mem = WritableMemory.wrap(image);
+    mem.putByte(FLAGS_BYTE, (byte) EMPTY_FLAG_MASK);
+    try {
+      VectorNormalizer.heapify(mem);
+      fail();
+    } catch (final IllegalArgumentException expected) {
+      // expected
+    }
+
+    // negative n in a non-empty image
+    image = vn.toByteArray();
+    mem = WritableMemory.wrap(image);
+    mem.putLong(N_LONG, -100);
+    try {
+      VectorNormalizer.heapify(mem);
+      fail();
+    } catch (final IllegalArgumentException expected) {
+      // expected
+    }
+
+    // image truncated by one byte, so the data arrays no longer fit
+    image = vn.toByteArray();
+    final int truncatedLength = image.length - 1;
+    mem = WritableMemory.allocate(truncatedLength);
+    mem.putByteArray(0, image, 0, truncatedLength);
+    try {
+      VectorNormalizer.heapify(mem);
+      fail();
+    } catch (final IllegalArgumentException expected) {
+      // expected
+    }
+  }
+
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@datasketches.apache.org
For additional commands, e-mail: commits-help@datasketches.apache.org