You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@commons.apache.org by er...@apache.org on 2015/09/20 21:50:29 UTC
[3/4] [math] MATH-1278
MATH-1278
Deep copy of "Neuron", "Network" and "NeuronSquareMesh2D".
Project: http://git-wip-us.apache.org/repos/asf/commons-math/repo
Commit: http://git-wip-us.apache.org/repos/asf/commons-math/commit/f13693fd
Tree: http://git-wip-us.apache.org/repos/asf/commons-math/tree/f13693fd
Diff: http://git-wip-us.apache.org/repos/asf/commons-math/diff/f13693fd
Branch: refs/heads/MATH_3_X
Commit: f13693fdc3cd00c6acbcd51672076e38a333778c
Parents: 78b9d81
Author: Gilles <er...@apache.org>
Authored: Sun Sep 20 21:45:31 2015 +0200
Committer: Gilles <er...@apache.org>
Committed: Sun Sep 20 21:45:31 2015 +0200
----------------------------------------------------------------------
.../commons/math3/ml/neuralnet/Network.java | 24 ++++++++++
.../commons/math3/ml/neuralnet/Neuron.java | 16 +++++++
.../ml/neuralnet/twod/NeuronSquareMesh2D.java | 48 ++++++++++++++++++++
.../commons/math3/ml/neuralnet/NetworkTest.java | 41 +++++++++++++++++
.../commons/math3/ml/neuralnet/NeuronTest.java | 26 +++++++++++
5 files changed, 155 insertions(+)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/commons-math/blob/f13693fd/src/main/java/org/apache/commons/math3/ml/neuralnet/Network.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/Network.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/Network.java
index 6c4b8e9..70d8bb2 100644
--- a/src/main/java/org/apache/commons/math3/ml/neuralnet/Network.java
+++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/Network.java
@@ -28,6 +28,7 @@ import java.util.Collection;
import java.util.Iterator;
import java.util.Comparator;
import java.util.Collections;
+import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
import org.apache.commons.math3.exception.DimensionMismatchException;
@@ -135,6 +136,29 @@ public class Network
}
/**
+ * Performs a deep copy of this instance.
+ * Upon return, the copied and original instances will be independent:
+ * Updating one will not affect the other.
+ * The new network is created with the current value of {@code nextId}
+ * and the same {@code featureSize} as this instance; each neuron is
+ * duplicated through {@link Neuron#copy()} and each set of links is
+ * duplicated into a new {@code HashSet}.
+ *
+ * @return a new instance with the same state as this instance.
+ */
+ public synchronized Network copy() {
+ final Network copy = new Network(nextId.get(),
+ featureSize);
+
+
+ for (Map.Entry<Long, Neuron> e : neuronMap.entrySet()) {
+ copy.neuronMap.put(e.getKey(), e.getValue().copy());
+ }
+
+ for (Map.Entry<Long, Set<Long>> e : linkMap.entrySet()) {
+ copy.linkMap.put(e.getKey(), new HashSet<Long>(e.getValue()));
+ }
+
+ return copy;
+ }
+
+ /**
* {@inheritDoc}
*/
public Iterator<Neuron> iterator() {
http://git-wip-us.apache.org/repos/asf/commons-math/blob/f13693fd/src/main/java/org/apache/commons/math3/ml/neuralnet/Neuron.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/Neuron.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/Neuron.java
index 3fd0c0a..300fa50 100644
--- a/src/main/java/org/apache/commons/math3/ml/neuralnet/Neuron.java
+++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/Neuron.java
@@ -67,6 +67,22 @@ public class Neuron implements Serializable {
}
/**
+ * Performs a deep copy of this instance.
+ * Upon return, the copied and original instances will be independent:
+ * Updating one will not affect the other.
+ * The copy is constructed from the identifier and the features of this
+ * neuron; its counters of attempted and successful updates are then set
+ * to the current values of this instance's counters.
+ *
+ * @return a new instance with the same state as this instance.
+ */
+ public synchronized Neuron copy() {
+ final Neuron copy = new Neuron(getIdentifier(),
+ getFeatures());
+ copy.numberOfAttemptedUpdates.set(numberOfAttemptedUpdates.get());
+ copy.numberOfSuccessfulUpdates.set(numberOfSuccessfulUpdates.get());
+
+ return copy;
+ }
+
+ /**
* Gets the neuron's identifier.
*
* @return the identifier.
http://git-wip-us.apache.org/repos/asf/commons-math/blob/f13693fd/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/NeuronSquareMesh2D.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/NeuronSquareMesh2D.java b/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/NeuronSquareMesh2D.java
index 2f4dd2d..d1c692e 100644
--- a/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/NeuronSquareMesh2D.java
+++ b/src/main/java/org/apache/commons/math3/ml/neuralnet/twod/NeuronSquareMesh2D.java
@@ -197,6 +197,54 @@ public class NeuronSquareMesh2D
createLinks();
}
+ /**
+ * Constructor with restricted access, solely used for making a
+ * {@link #copy() deep copy}.
+ * The number of rows and columns are taken from the dimensions of
+ * {@code idGrid}; all arguments are stored as passed (no copy is made).
+ *
+ * @param wrapRowDim Whether to wrap the first dimension (i.e. the first
+ * and last neurons will be linked together).
+ * @param wrapColDim Whether to wrap the second dimension (i.e. the first
+ * and last neurons will be linked together).
+ * @param neighbourhoodType Neighbourhood type.
+ * @param net Underlying network.
+ * @param idGrid Neuron identifiers. Must contain at least one row, since
+ * the number of columns is read from the first row.
+ */
+ private NeuronSquareMesh2D(boolean wrapRowDim,
+ boolean wrapColDim,
+ SquareNeighbourhood neighbourhoodType,
+ Network net,
+ long[][] idGrid) {
+ numberOfRows = idGrid.length;
+ numberOfColumns = idGrid[0].length;
+ wrapRows = wrapRowDim;
+ wrapColumns = wrapColDim;
+ neighbourhood = neighbourhoodType;
+ network = net;
+ identifiers = idGrid;
+ }
+
+ /**
+ * Performs a deep copy of this instance.
+ * Upon return, the copied and original instances will be independent:
+ * Updating one will not affect the other.
+ * The grid of neuron identifiers is duplicated element by element into
+ * a new array and the underlying network is duplicated through
+ * {@code network.copy()}; the wrapping flags and the neighbourhood type
+ * are passed on unchanged.
+ *
+ * @return a new instance with the same state as this instance.
+ */
+ public synchronized NeuronSquareMesh2D copy() {
+ final long[][] idGrid = new long[numberOfRows][numberOfColumns];
+ for (int r = 0; r < numberOfRows; r++) {
+ for (int c = 0; c < numberOfColumns; c++) {
+ idGrid[r][c] = identifiers[r][c];
+ }
+ }
+
+ return new NeuronSquareMesh2D(wrapRows,
+ wrapColumns,
+ neighbourhood,
+ network.copy(),
+ idGrid);
+ }
+
/** {@inheritDoc} */
public Iterator<Neuron> iterator() {
return network.iterator();
http://git-wip-us.apache.org/repos/asf/commons-math/blob/f13693fd/src/test/java/org/apache/commons/math3/ml/neuralnet/NetworkTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/commons/math3/ml/neuralnet/NetworkTest.java b/src/test/java/org/apache/commons/math3/ml/neuralnet/NetworkTest.java
index 7f2bec9..aa83196 100644
--- a/src/test/java/org/apache/commons/math3/ml/neuralnet/NetworkTest.java
+++ b/src/test/java/org/apache/commons/math3/ml/neuralnet/NetworkTest.java
@@ -127,6 +127,47 @@ public class NetworkTest {
Assert.assertFalse(isUnspecifiedOrder);
}
+ /*
+ * Test assumes that the network (a 2x2 mesh, without wrapping, using
+ * the von Neumann neighbourhood) is
+ *
+ * 0-----1
+ * |     |
+ * |     |
+ * 2-----3
+ */
+ @Test
+ public void testCopy() {
+ final FeatureInitializer[] initArray = { init };
+ final Network net = new NeuronSquareMesh2D(2, false,
+ 2, false,
+ SquareNeighbourhood.VON_NEUMANN,
+ initArray).getNetwork();
+
+ final Network copy = net.copy();
+
+ final Neuron netNeuron0 = net.getNeuron(0);
+ final Neuron copyNeuron0 = copy.getNeuron(0);
+ final Neuron netNeuron1 = net.getNeuron(1);
+ final Neuron copyNeuron1 = copy.getNeuron(1);
+ Collection<Neuron> netNeighbours;
+ Collection<Neuron> copyNeighbours;
+
+ // Check that neuron 1 is a neighbour of neuron 0 in both networks.
+ netNeighbours = net.getNeighbours(netNeuron0);
+ copyNeighbours = copy.getNeighbours(copyNeuron0);
+ Assert.assertTrue(netNeighbours.contains(netNeuron1));
+ Assert.assertTrue(copyNeighbours.contains(copyNeuron1));
+
+ // Delete neuron 1 from original.
+ net.deleteNeuron(netNeuron1);
+
+ // Check that neuron 1 is no longer a neighbour of neuron 0 in the original, but still is in the copy.
+ netNeighbours = net.getNeighbours(netNeuron0);
+ copyNeighbours = copy.getNeighbours(copyNeuron0);
+ Assert.assertFalse(netNeighbours.contains(netNeuron1));
+ Assert.assertTrue(copyNeighbours.contains(copyNeuron1));
+ }
+
@Test
public void testSerialize()
throws IOException,
http://git-wip-us.apache.org/repos/asf/commons-math/blob/f13693fd/src/test/java/org/apache/commons/math3/ml/neuralnet/NeuronTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/commons/math3/ml/neuralnet/NeuronTest.java b/src/test/java/org/apache/commons/math3/ml/neuralnet/NeuronTest.java
index b03f07d..376d91c 100644
--- a/src/test/java/org/apache/commons/math3/ml/neuralnet/NeuronTest.java
+++ b/src/test/java/org/apache/commons/math3/ml/neuralnet/NeuronTest.java
@@ -86,6 +86,32 @@ public class NeuronTest {
}
@Test
+ public void testCopy() {
+ final Neuron n = new Neuron(1, new double[] { 9.87 });
+
+ // Update original (before the copy is made).
+ double[] update = new double[] { n.getFeatures()[0] + 2.34 };
+ n.compareAndSetFeatures(n.getFeatures(), update);
+
+ // Create a copy.
+ final Neuron copy = n.copy();
+
+ // Check that original and copy have the same feature value and the same number of attempted updates.
+ Assert.assertTrue(n.getFeatures()[0] == copy.getFeatures()[0]);
+ Assert.assertEquals(n.getNumberOfAttemptedUpdates(),
+ copy.getNumberOfAttemptedUpdates());
+
+ // Update original again (after the copy was made).
+ update = new double[] { 1.23 * n.getFeatures()[0] };
+ n.compareAndSetFeatures(n.getFeatures(), update);
+
+ // Check that original and copy now differ (feature value and number of successful updates).
+ Assert.assertFalse(n.getFeatures()[0] == copy.getFeatures()[0]);
+ Assert.assertNotEquals(n.getNumberOfSuccessfulUpdates(),
+ copy.getNumberOfSuccessfulUpdates());
+ }
+
+ @Test
public void testSerialize()
throws IOException,
ClassNotFoundException {