You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@ignite.apache.org by sb...@apache.org on 2017/12/25 11:46:47 UTC
[15/20] ignite git commit: IGNITE-7174: Local MLP
http://git-wip-us.apache.org/repos/asf/ignite/blob/e4f19215/modules/ml/src/main/java/org/apache/ignite/ml/nn/architecture/TransformationLayerArchitecture.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/architecture/TransformationLayerArchitecture.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/architecture/TransformationLayerArchitecture.java
new file mode 100644
index 0000000..639bed0
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/architecture/TransformationLayerArchitecture.java
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.nn.architecture;
+
+import org.apache.ignite.ml.math.functions.IgniteDifferentiableDoubleToDoubleFunction;
+
+/**
+ * Class encapsulating architecture of transformation layer (i.e. non-input layer).
+ */
+public class TransformationLayerArchitecture extends LayerArchitecture {
+ /**
+ * Flag indicating presence of bias in layer. Assigned once in the constructor.
+ */
+ private final boolean hasBias;
+
+ /**
+ * Activation function for layer. Assigned once in the constructor.
+ */
+ private final IgniteDifferentiableDoubleToDoubleFunction activationFunction;
+
+ /**
+ * Construct TransformationLayerArchitecture.
+ *
+ * @param neuronsCnt Count of neurons in this layer.
+ * @param hasBias Flag indicating presence of bias in layer.
+ * @param activationFunction Activation function for layer.
+ */
+ public TransformationLayerArchitecture(int neuronsCnt, boolean hasBias,
+ IgniteDifferentiableDoubleToDoubleFunction activationFunction) {
+ super(neuronsCnt);
+
+ this.hasBias = hasBias;
+ this.activationFunction = activationFunction;
+ }
+
+ /**
+ * Checks if this layer has a bias.
+ *
+ * @return Value of predicate "this layer has a bias".
+ */
+ public boolean hasBias() {
+ return hasBias;
+ }
+
+ /**
+ * Get activation function for this layer.
+ *
+ * @return Activation function for this layer.
+ */
+ public IgniteDifferentiableDoubleToDoubleFunction activationFunction() {
+ return activationFunction;
+ }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/e4f19215/modules/ml/src/main/java/org/apache/ignite/ml/nn/architecture/package-info.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/architecture/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/architecture/package-info.java
new file mode 100644
index 0000000..aff2d20
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/architecture/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. -->
+ * Contains multilayer perceptron architecture classes.
+ */
+package org.apache.ignite.ml.nn.architecture;
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/ignite/blob/e4f19215/modules/ml/src/main/java/org/apache/ignite/ml/nn/initializers/MLPInitializer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/initializers/MLPInitializer.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/initializers/MLPInitializer.java
new file mode 100644
index 0000000..680508c
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/initializers/MLPInitializer.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.nn.initializers;
+
+import org.apache.ignite.ml.math.Matrix;
+import org.apache.ignite.ml.math.Vector;
+
+/**
+ * Interface for classes encapsulating logic for initialization of weights and biases of MLP.
+ */
+public interface MLPInitializer {
+ /**
+ * In-place change values of matrix representing weights.
+ *
+ * @param weights Matrix representing weights.
+ */
+ void initWeights(Matrix weights);
+
+ /**
+ * In-place change values of vector representing biases.
+ *
+ * @param biases Vector representing biases.
+ */
+ void initBiases(Vector biases);
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/e4f19215/modules/ml/src/main/java/org/apache/ignite/ml/nn/initializers/RandomInitializer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/initializers/RandomInitializer.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/initializers/RandomInitializer.java
new file mode 100644
index 0000000..18cb8a6
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/initializers/RandomInitializer.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.nn.initializers;
+
+import java.util.Random;
+import org.apache.ignite.ml.math.Matrix;
+import org.apache.ignite.ml.math.Vector;
+
+/**
+ * Initializer which fills MLP weights and biases with random uniformly distributed numbers from -1 to 1.
+ */
+public class RandomInitializer implements MLPInitializer {
+ /**
+ * Random numbers generator all values are drawn from.
+ */
+ Random rnd;
+
+ /**
+ * Create an initializer backed by the given RNG.
+ *
+ * @param rnd RNG.
+ */
+ public RandomInitializer(Random rnd) {
+ this.rnd = rnd;
+ }
+
+ /** {@inheritDoc} */
+ @Override public void initWeights(Matrix weights) {
+ weights.map(ignored -> nextValue());
+ }
+
+ /** {@inheritDoc} */
+ @Override public void initBiases(Vector biases) {
+ biases.map(ignored -> nextValue());
+ }
+
+ /**
+ * Draw the next value from the RNG and rescale it from [0, 1) to [-1, 1).
+ *
+ * @return Next random value.
+ */
+ private double nextValue() {
+ return 2 * (rnd.nextDouble() - 0.5);
+ }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/e4f19215/modules/ml/src/main/java/org/apache/ignite/ml/nn/initializers/package-info.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/initializers/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/initializers/package-info.java
new file mode 100644
index 0000000..351783b
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/initializers/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. -->
+ * Contains multilayer perceptron parameters initializers.
+ */
+package org.apache.ignite.ml.nn.initializers;
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/ignite/blob/e4f19215/modules/ml/src/main/java/org/apache/ignite/ml/nn/package-info.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/package-info.java
new file mode 100644
index 0000000..1641147
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. -->
+ * Contains neural networks and related classes.
+ */
+package org.apache.ignite.ml.nn;
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/ignite/blob/e4f19215/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/local/LocalBatchTrainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/local/LocalBatchTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/local/LocalBatchTrainer.java
new file mode 100644
index 0000000..64a1956
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/local/LocalBatchTrainer.java
@@ -0,0 +1,180 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.nn.trainers.local;
+
+import org.apache.ignite.IgniteLogger;
+import org.apache.ignite.lang.IgniteBiTuple;
+import org.apache.ignite.ml.Model;
+import org.apache.ignite.ml.Trainer;
+import org.apache.ignite.ml.math.Matrix;
+import org.apache.ignite.ml.math.Vector;
+import org.apache.ignite.ml.math.functions.IgniteDifferentiableVectorToDoubleFunction;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
+import org.apache.ignite.ml.math.functions.IgniteSupplier;
+import org.apache.ignite.ml.math.util.MatrixUtil;
+import org.apache.ignite.ml.nn.LocalBatchTrainerInput;
+import org.apache.ignite.ml.nn.updaters.ParameterUpdater;
+import org.apache.ignite.ml.nn.updaters.UpdaterParams;
+
+/**
+ * Batch trainer. This trainer is not distributed on the cluster, but input can theoretically read data from
+ * Ignite cache.
+ *
+ * @param <M> Type of the trained model.
+ * @param <P> Type of the parameters produced by the parameter updater.
+ */
+public class LocalBatchTrainer<M extends Model<Matrix, Matrix>, P extends UpdaterParams<? super M>>
+ implements Trainer<M, LocalBatchTrainerInput<M>> {
+ /**
+ * Supplier for updater function.
+ */
+ private final IgniteSupplier<ParameterUpdater<? super M, P>> updaterSupplier;
+
+ /**
+ * Error threshold.
+ */
+ private final double errorThreshold;
+
+ /**
+ * Maximal iterations count.
+ */
+ private final int maxIterations;
+
+ /**
+ * Loss function.
+ */
+ private final IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss;
+
+ /**
+ * Logger. May be {@code null}, in which case debug messages are skipped.
+ */
+ private IgniteLogger log;
+
+ /**
+ * Construct a trainer.
+ *
+ * @param loss Loss function.
+ * @param updaterSupplier Supplier of updater function.
+ * @param errorThreshold Error threshold.
+ * @param maxIterations Maximal iterations count.
+ */
+ public LocalBatchTrainer(IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss,
+ IgniteSupplier<ParameterUpdater<? super M, P>> updaterSupplier, double errorThreshold, int maxIterations) {
+ this.loss = loss;
+ this.updaterSupplier = updaterSupplier;
+ this.errorThreshold = errorThreshold;
+ this.maxIterations = maxIterations;
+ }
+
+ /**
+ * {@inheritDoc}
+ * <p>
+ * On each iteration a batch is requested from the input, the updater computes new updater parameters and
+ * these are applied to the model. Training stops as soon as the error on the current batch drops below
+ * the error threshold or after the maximal iterations count is reached.
+ */
+ @Override public M train(LocalBatchTrainerInput<M> data) {
+ int i = 0;
+ M mdl = data.mdl();
+
+ ParameterUpdater<? super M, P> updater = updaterSupplier.get();
+
+ P updaterParams = updater.init(mdl, loss);
+
+ while (i < maxIterations) {
+ IgniteBiTuple<Matrix, Matrix> batch = data.getBatch();
+ Matrix input = batch.get1();
+ Matrix truth = batch.get2();
+
+ updaterParams = updater.updateParams(mdl, updaterParams, i, input, truth);
+
+ // Update mdl with updater parameters.
+ mdl = updaterParams.update(mdl);
+
+ Matrix predicted = mdl.apply(input);
+
+ int batchSize = input.columnSize();
+
+ // Loss is evaluated on corresponding columns of prediction and ground truth, summed up
+ // and divided by the count of columns in the input batch.
+ double err = MatrixUtil.zipFoldByColumns(predicted, truth, (predCol, truthCol) ->
+ loss.apply(truthCol).apply(predCol)).sum() / batchSize;
+
+ debug("Error: " + err);
+
+ if (err < errorThreshold)
+ break;
+
+ i++;
+ }
+
+ return mdl;
+ }
+
+ /**
+ * Construct new trainer with the same parameters as this trainer, but with new loss.
+ * The logger of this trainer is not passed to the new trainer.
+ *
+ * @param loss New loss function.
+ * @return new trainer with the same parameters as this trainer, but with new loss.
+ */
+ public LocalBatchTrainer<M, P> withLoss(IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss) {
+ return new LocalBatchTrainer<>(loss, updaterSupplier, errorThreshold, maxIterations);
+ }
+
+ /**
+ * Construct new trainer with the same parameters as this trainer, but with new updater supplier.
+ * The logger of this trainer is not passed to the new trainer.
+ *
+ * @param updaterSupplier New updater supplier.
+ * @return new trainer with the same parameters as this trainer, but with new updater supplier.
+ */
+ public LocalBatchTrainer<M, P> withUpdater(IgniteSupplier<ParameterUpdater<? super M, P>> updaterSupplier) {
+ return new LocalBatchTrainer<>(loss, updaterSupplier, errorThreshold, maxIterations);
+ }
+
+ /**
+ * Construct new trainer with the same parameters as this trainer, but with new error threshold.
+ * The logger of this trainer is not passed to the new trainer.
+ *
+ * @param errorThreshold New error threshold.
+ * @return new trainer with the same parameters as this trainer, but with new error threshold.
+ */
+ public LocalBatchTrainer<M, P> withErrorThreshold(double errorThreshold) {
+ return new LocalBatchTrainer<>(loss, updaterSupplier, errorThreshold, maxIterations);
+ }
+
+ /**
+ * Construct new trainer with the same parameters as this trainer, but with new maximal iterations count.
+ * The logger of this trainer is not passed to the new trainer.
+ *
+ * @param maxIterations New maximal iterations count.
+ * @return new trainer with the same parameters as this trainer, but with new maximal iterations count.
+ */
+ public LocalBatchTrainer<M, P> withMaxIterations(int maxIterations) {
+ return new LocalBatchTrainer<>(loss, updaterSupplier, errorThreshold, maxIterations);
+ }
+
+ /**
+ * Set logger.
+ *
+ * @param log Logger.
+ * @return This object.
+ */
+ public LocalBatchTrainer<M, P> setLogger(IgniteLogger log) {
+ this.log = log;
+
+ return this;
+ }
+
+ /**
+ * Output debug message. Does nothing when no logger is set or debug level is disabled.
+ *
+ * @param msg Message.
+ */
+ private void debug(String msg) {
+ if (log != null && log.isDebugEnabled())
+ log.debug(msg);
+ }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/e4f19215/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/local/MLPLocalBatchTrainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/local/MLPLocalBatchTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/local/MLPLocalBatchTrainer.java
new file mode 100644
index 0000000..7065e2f
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/local/MLPLocalBatchTrainer.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.nn.trainers.local;
+
+import org.apache.ignite.ml.math.Vector;
+import org.apache.ignite.ml.math.functions.IgniteDifferentiableVectorToDoubleFunction;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
+import org.apache.ignite.ml.math.functions.IgniteSupplier;
+import org.apache.ignite.ml.nn.LossFunctions;
+import org.apache.ignite.ml.nn.MultilayerPerceptron;
+import org.apache.ignite.ml.nn.updaters.ParameterUpdater;
+import org.apache.ignite.ml.nn.updaters.RPropUpdater;
+import org.apache.ignite.ml.nn.updaters.RPropUpdaterParams;
+import org.apache.ignite.ml.nn.updaters.UpdaterParams;
+
+/**
+ * Local batch trainer for MLP. Fixes the model type of {@link LocalBatchTrainer} to
+ * {@link MultilayerPerceptron} and provides a factory method with default settings.
+ *
+ * @param <P> Parameter updater parameters.
+ */
+public class MLPLocalBatchTrainer<P extends UpdaterParams<? super MultilayerPerceptron>>
+ extends LocalBatchTrainer<MultilayerPerceptron, P> {
+ /**
+ * Default loss function (MSE).
+ */
+ private static final IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> DEFAULT_LOSS =
+ LossFunctions.MSE;
+
+ /**
+ * Default error threshold.
+ */
+ private static final double DEFAULT_ERROR_THRESHOLD = 1E-5;
+
+ /**
+ * Default maximal iterations count.
+ */
+ private static final int DEFAULT_MAX_ITERATIONS = 100;
+
+
+ /**
+ * Construct a trainer. All arguments are passed to the {@link LocalBatchTrainer} constructor unchanged.
+ *
+ * @param loss Loss function.
+ * @param updaterSupplier Supplier of updater function.
+ * @param errorThreshold Error threshold.
+ * @param maxIterations Maximal iterations count.
+ */
+ public MLPLocalBatchTrainer(
+ IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss,
+ IgniteSupplier<ParameterUpdater<? super MultilayerPerceptron, P>> updaterSupplier,
+ double errorThreshold, int maxIterations) {
+ super(loss, updaterSupplier, errorThreshold, maxIterations);
+ }
+
+ /**
+ * Get MLPLocalBatchTrainer with default parameters: MSE loss, updater created by the no-argument
+ * {@link RPropUpdater} constructor, error threshold 1E-5 and at most 100 iterations.
+ *
+ * @return MLPLocalBatchTrainer with default parameters.
+ */
+ public static MLPLocalBatchTrainer<RPropUpdaterParams> getDefault() {
+ return new MLPLocalBatchTrainer<>(DEFAULT_LOSS, RPropUpdater::new, DEFAULT_ERROR_THRESHOLD, DEFAULT_MAX_ITERATIONS);
+ }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/e4f19215/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/local/package-info.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/local/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/local/package-info.java
new file mode 100644
index 0000000..b78adb8
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/local/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. -->
+ * Contains multilayer perceptron local trainers.
+ */
+package org.apache.ignite.ml.nn.trainers.local;
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/ignite/blob/e4f19215/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/package-info.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/package-info.java
new file mode 100644
index 0000000..c90f67a
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/trainers/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. -->
+ * Contains multilayer perceptron trainers.
+ */
+package org.apache.ignite.ml.nn.trainers;
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/ignite/blob/e4f19215/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/BaseSmoothParametrized.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/BaseSmoothParametrized.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/BaseSmoothParametrized.java
new file mode 100644
index 0000000..b33c2c7
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/BaseSmoothParametrized.java
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.nn.updaters;
+
+import org.apache.ignite.ml.math.Matrix;
+import org.apache.ignite.ml.math.Vector;
+import org.apache.ignite.ml.math.functions.IgniteDifferentiableVectorToDoubleFunction;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
+
+/**
+ * Interface for models which are smooth functions of their parameters.
+ *
+ * @param <M> Type of the model implementing this interface; it is the return type of
+ * {@link #setParameters(Vector)}.
+ */
+interface BaseSmoothParametrized<M extends BaseSmoothParametrized<M>> {
+ /**
+ * Compose this model with the loss function in the following way: the output of this model is fed
+ * as the second argument of the loss function.
+ * After that we have a function g of three arguments: input, ground truth, parameters.
+ * If we consider function
+ * h(w) = (1 / N) * sum_{i=1}^{N} g(input_i, groundTruth_i, w),
+ * where N is the number of entries in the batch, we get a function of one argument: parameters vector w.
+ * This function is being differentiated.
+ *
+ * @param loss Loss function.
+ * @param inputsBatch Batch of inputs.
+ * @param truthBatch Batch of ground truths.
+ * @return Gradient of h at current point in parameters space.
+ */
+ Vector differentiateByParameters(IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss, Matrix inputsBatch, Matrix truthBatch);
+
+ /**
+ * Get parameters vector.
+ *
+ * @return Parameters vector.
+ */
+ Vector parameters();
+
+ /**
+ * Set parameters.
+ *
+ * @param vector Parameters vector.
+ * @return Model of type M; presumably this model with the new parameters applied (TODO: confirm
+ * against implementations).
+ */
+ M setParameters(Vector vector);
+
+ /**
+ * Get count of parameters of this model.
+ *
+ * @return Count of parameters of this model.
+ */
+ int parametersCount();
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/e4f19215/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/NesterovUpdater.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/NesterovUpdater.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/NesterovUpdater.java
new file mode 100644
index 0000000..7b6a0c7
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/NesterovUpdater.java
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.nn.updaters;
+
+import org.apache.ignite.ml.math.Matrix;
+import org.apache.ignite.ml.math.Vector;
+import org.apache.ignite.ml.math.functions.IgniteDifferentiableVectorToDoubleFunction;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
+
+/**
+ * Class encapsulating Nesterov algorithm for MLP parameters update.
+ */
+public class NesterovUpdater implements ParameterUpdater<SmoothParametrized, NesterovUpdaterParams> {
+ /**
+ * Learning rate.
+ */
+ private final double learningRate;
+
+ /**
+ * Loss function.
+ */
+ private IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss;
+
+ /**
+ * Momentum constant.
+ */
+ protected double momentum;
+
+ /**
+ * Construct NesterovUpdater.
+ *
+ * @param momentum Momentum constant.
+ */
+ public NesterovUpdater(double learningRate, double momentum) {
+ this.learningRate = learningRate;
+ this.momentum = momentum;
+ }
+
+ /** {@inheritDoc} */
+ @Override public NesterovUpdaterParams init(SmoothParametrized mdl,
+ IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss) {
+ this.loss = loss;
+
+ return new NesterovUpdaterParams(mdl.parametersCount());
+ }
+
+ /** {@inheritDoc} */
+ @Override public NesterovUpdaterParams updateParams(SmoothParametrized mdl, NesterovUpdaterParams updaterParameters,
+ int iteration, Matrix inputs, Matrix groundTruth) {
+
+ if (iteration > 0) {
+ Vector curParams = mdl.parameters();
+ mdl.setParameters(curParams.minus(updaterParameters.prevIterationUpdates().times(momentum)));
+ }
+
+ Vector gradient = mdl.differentiateByParameters(loss, inputs, groundTruth);
+ updaterParameters.setPreviousUpdates(updaterParameters.prevIterationUpdates().plus(gradient.times(learningRate)));
+
+ return updaterParameters;
+ }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/e4f19215/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/NesterovUpdaterParams.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/NesterovUpdaterParams.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/NesterovUpdaterParams.java
new file mode 100644
index 0000000..d403ea1
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/NesterovUpdaterParams.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.nn.updaters;
+
+import org.apache.ignite.ml.math.Vector;
+import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
+
+/**
+ * State carried between iterations by the Nesterov parameters updater.
+ */
+public class NesterovUpdaterParams implements UpdaterParams<SmoothParametrized> {
+    /** Parameters updates stored on the previous step. */
+    protected Vector prevIterationUpdates;
+
+    /**
+     * Construct NesterovUpdaterParams with previous updates filled with zeros.
+     *
+     * @param paramsCnt Count of parameters on which update happens.
+     */
+    public NesterovUpdaterParams(int paramsCnt) {
+        Vector zeros = new DenseLocalOnHeapVector(paramsCnt);
+
+        prevIterationUpdates = zeros.assign(0);
+    }
+
+    /**
+     * Replace the stored parameters updates.
+     *
+     * @param updates Parameters updates.
+     * @return This object with updated parameters updates.
+     */
+    public NesterovUpdaterParams setPreviousUpdates(Vector updates) {
+        this.prevIterationUpdates = updates;
+
+        return this;
+    }
+
+    /**
+     * Get the stored parameters updates.
+     *
+     * @return Previous step parameters updates.
+     */
+    public Vector prevIterationUpdates() {
+        return this.prevIterationUpdates;
+    }
+
+    /** {@inheritDoc} */
+    @SuppressWarnings("unchecked")
+    @Override public <M extends SmoothParametrized> M update(M obj) {
+        // New parameters are the current ones minus the stored updates.
+        Vector shifted = obj.parameters().minus(prevIterationUpdates);
+
+        return (M)obj.setParameters(shifted);
+    }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/e4f19215/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/ParameterUpdater.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/ParameterUpdater.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/ParameterUpdater.java
new file mode 100644
index 0000000..e8e28fd
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/ParameterUpdater.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.nn.updaters;
+
+import org.apache.ignite.ml.math.Matrix;
+import org.apache.ignite.ml.math.Vector;
+import org.apache.ignite.ml.math.functions.IgniteDifferentiableVectorToDoubleFunction;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
+
+/**
+ * Interface for classes encapsulating parameters update logic.
+ *
+ * @param <M> Type of model to be updated.
+ * @param <P> Type of parameters needed for this updater.
+ */
+public interface ParameterUpdater<M, P extends UpdaterParams> {
+ /**
+ * Initializes the updater.
+ *
+ * @param mdl Model to be trained.
+ * @param loss Loss function.
+ * @return Initial updater parameters.
+ */
+ P init(M mdl, IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss);
+
+ /**
+ * Update updater parameters.
+ *
+ * @param mdl Model to be updated.
+ * @param updaterParameters Updater parameters to update.
+ * @param iteration Current trainer iteration.
+ * @param inputs Inputs.
+ * @param groundTruth True values.
+ * @return Updated parameters.
+ */
+ P updateParams(M mdl, P updaterParameters, int iteration, Matrix inputs, Matrix groundTruth);
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/e4f19215/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/RPropUpdater.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/RPropUpdater.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/RPropUpdater.java
new file mode 100644
index 0000000..c9d8843
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/RPropUpdater.java
@@ -0,0 +1,148 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.nn.updaters;
+
+import org.apache.ignite.ml.math.Matrix;
+import org.apache.ignite.ml.math.Vector;
+import org.apache.ignite.ml.math.VectorUtils;
+import org.apache.ignite.ml.math.functions.IgniteDifferentiableVectorToDoubleFunction;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
+import org.apache.ignite.ml.math.util.MatrixUtil;
+
+/**
+ * Class encapsulating RProp algorithm.
+ *
+ * @see <a href="https://paginas.fe.up.pt/~ee02162/dissertacao/RPROP%20paper.pdf">https://paginas.fe.up.pt/~ee02162/dissertacao/RPROP%20paper.pdf</a>.
+ */
+public class RPropUpdater implements ParameterUpdater<SmoothParametrized, RPropUpdaterParams> {
+    /** Default initial update. */
+    private static final double DFLT_INIT_UPDATE = 0.1;
+
+    /** Default acceleration rate. */
+    private static final double DFLT_ACCELERATION_RATE = 1.2;
+
+    /** Default deacceleration rate. */
+    private static final double DFLT_DEACCELERATION_RATE = 0.5;
+
+    /** Maximal value for update. */
+    private static final double UPDATE_MAX = 50.0;
+
+    /** Minimal value for update. */
+    private static final double UPDATE_MIN = 1E-6;
+
+    /** Initial update. */
+    private final double initUpdate;
+
+    /** Acceleration rate. */
+    private final double accelerationRate;
+
+    /** Deacceleration rate. */
+    private final double deaccelerationRate;
+
+    /** Loss function; assigned in {@code init} and read in {@code updateParams}. */
+    protected IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss;
+
+    /**
+     * Construct RPropUpdater.
+     *
+     * @param initUpdate Initial update.
+     * @param accelerationRate Acceleration rate.
+     * @param deaccelerationRate Deacceleration rate.
+     */
+    public RPropUpdater(double initUpdate, double accelerationRate, double deaccelerationRate) {
+        this.initUpdate = initUpdate;
+        this.accelerationRate = accelerationRate;
+        this.deaccelerationRate = deaccelerationRate;
+    }
+
+    /**
+     * Construct RPropUpdater with default parameters.
+     */
+    public RPropUpdater() {
+        this(DFLT_INIT_UPDATE, DFLT_ACCELERATION_RATE, DFLT_DEACCELERATION_RATE);
+    }
+
+    /** {@inheritDoc} */
+    @Override public RPropUpdaterParams init(SmoothParametrized mdl,
+        IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss) {
+        this.loss = loss;
+        return new RPropUpdaterParams(mdl.parametersCount(), initUpdate);
+    }
+
+    /** {@inheritDoc} */
+    @Override public RPropUpdaterParams updateParams(SmoothParametrized mdl, RPropUpdaterParams updaterParams,
+        int iteration, Matrix inputs, Matrix groundTruth) {
+        Vector gradient = mdl.differentiateByParameters(loss, inputs, groundTruth);
+        Vector prevGradient = updaterParams.prevIterationGradient();
+        Vector derSigns;
+
+        // Per-parameter sign of (previous derivative * current derivative):
+        // positive - direction kept, negative - direction flipped, zero - one of them is zero.
+        if (prevGradient != null)
+            derSigns = VectorUtils.zipWith(prevGradient, gradient, (x, y) -> Math.signum(x * y));
+        else
+            derSigns = gradient.like(gradient.size()).assign(1.0);
+
+        // Grow the step (capped by UPDATE_MAX) where the direction was kept and shrink it (floored by UPDATE_MIN)
+        // where it flipped. The returned vector is discarded, so this relies on map updating deltas in place.
+        updaterParams.deltas().map(derSigns, (prevDelta, sign) -> {
+            if (sign > 0)
+                return Math.min(prevDelta * accelerationRate, UPDATE_MAX);
+            else if (sign < 0)
+                return Math.max(prevDelta * deaccelerationRate, UPDATE_MIN);
+            else
+                return prevDelta;
+        });
+
+        // New update is a step of size delta against the derivative sign; where the direction flipped
+        // the previous update is kept instead.
+        updaterParams.setPrevIterationBiasesUpdates(MatrixUtil.zipWith(gradient, updaterParams.deltas(), (der, delta, i) -> {
+            if (derSigns.getX(i) >= 0)
+                return -Math.signum(der) * delta;
+
+            return updaterParams.prevIterationUpdates().getX(i);
+        }));
+
+        // Mask is -1.0 where the direction flipped and 1.0 elsewhere. As a side effect the gradient component
+        // is zeroed for flipped parameters before the gradient is stored for the next iteration.
+        Vector updatesMask = MatrixUtil.zipWith(derSigns, updaterParams.prevIterationUpdates(), (sign, upd, i) -> {
+            if (sign < 0)
+                gradient.setX(i, 0.0);
+
+            if (sign >= 0)
+                return 1.0;
+            else
+                return -1.0;
+        });
+
+        updaterParams.setUpdatesMask(updatesMask);
+        updaterParams.setPrevIterationWeightsDerivatives(gradient.copy());
+
+        return updaterParams;
+    }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/e4f19215/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/RPropUpdaterParams.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/RPropUpdaterParams.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/RPropUpdaterParams.java
new file mode 100644
index 0000000..cff5f5b
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/RPropUpdaterParams.java
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.nn.updaters;
+
+import org.apache.ignite.ml.math.Vector;
+import org.apache.ignite.ml.math.VectorUtils;
+import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
+
+/**
+ * State kept between iterations by the RProp updater.
+ *
+ * @see <a href="https://paginas.fe.up.pt/~ee02162/dissertacao/RPROP%20paper.pdf">https://paginas.fe.up.pt/~ee02162/dissertacao/RPROP%20paper.pdf</a>.
+ */
+public class RPropUpdaterParams implements UpdaterParams<SmoothParametrized> {
+    /** Previous iteration parameters updates. In original paper they are labeled with "delta w". */
+    protected Vector prevIterationUpdates;
+
+    /** Previous iteration partial derivatives by parameters. */
+    protected Vector prevIterationGradient;
+
+    /** Previous iteration parameters deltas. In original paper they are labeled with "delta". */
+    protected Vector deltas;
+
+    /** Updates mask (values by which update is multiplied). */
+    protected Vector updatesMask;
+
+    /**
+     * Construct RPropUpdaterParams.
+     *
+     * @param paramsCnt Parameters count.
+     * @param initUpdate Initial update (in original work labeled as "delta_0").
+     */
+    RPropUpdaterParams(int paramsCnt, double initUpdate) {
+        this.deltas = new DenseLocalOnHeapVector(paramsCnt).assign(initUpdate);
+        this.updatesMask = new DenseLocalOnHeapVector(paramsCnt);
+        this.prevIterationGradient = new DenseLocalOnHeapVector(paramsCnt);
+        this.prevIterationUpdates = new DenseLocalOnHeapVector(paramsCnt);
+    }
+
+    /**
+     * Get parameters deltas.
+     *
+     * @return Parameters deltas.
+     */
+    Vector deltas() {
+        return this.deltas;
+    }
+
+    /**
+     * Get previous iteration parameters updates. In original paper they are labeled with "delta w".
+     *
+     * @return Parameters updates.
+     */
+    Vector prevIterationUpdates() {
+        return this.prevIterationUpdates;
+    }
+
+    /**
+     * Set previous iteration parameters updates. In original paper they are labeled with "delta w".
+     *
+     * @param updates New parameters updates value.
+     * @return The vector that has just been stored.
+     */
+    Vector setPrevIterationBiasesUpdates(Vector updates) {
+        this.prevIterationUpdates = updates;
+
+        return updates;
+    }
+
+    /**
+     * Get previous iteration loss function partial derivatives by parameters.
+     *
+     * @return Previous iteration loss function partial derivatives by parameters.
+     */
+    Vector prevIterationGradient() {
+        return this.prevIterationGradient;
+    }
+
+    /**
+     * Set previous iteration loss function partial derivatives by parameters.
+     *
+     * @param gradient New value of partial derivatives by parameters.
+     * @return This object.
+     */
+    RPropUpdaterParams setPrevIterationWeightsDerivatives(Vector gradient) {
+        this.prevIterationGradient = gradient;
+
+        return this;
+    }
+
+    /**
+     * Get updates mask (values by which update is multiplied).
+     *
+     * @return Updates mask (values by which update is multiplied).
+     */
+    public Vector updatesMask() {
+        return this.updatesMask;
+    }
+
+    /**
+     * Set updates mask (values by which update is multiplied).
+     *
+     * @param updatesMask New updatesMask.
+     * @return This object.
+     */
+    public RPropUpdaterParams setUpdatesMask(Vector updatesMask) {
+        this.updatesMask = updatesMask;
+
+        return this;
+    }
+
+    /** {@inheritDoc} */
+    @SuppressWarnings("unchecked")
+    @Override public <M extends SmoothParametrized> M update(M obj) {
+        // The mask is copied before being passed to elementWiseTimes, so the stored mask is not the one handed over.
+        Vector maskCp = updatesMask.copy();
+        Vector updatesToAdd = VectorUtils.elementWiseTimes(maskCp, prevIterationUpdates);
+
+        Vector newParams = obj.parameters().plus(updatesToAdd);
+
+        return (M)obj.setParameters(newParams);
+    }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/e4f19215/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/SimpleGDParams.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/SimpleGDParams.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/SimpleGDParams.java
new file mode 100644
index 0000000..50a120a
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/SimpleGDParams.java
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.nn.updaters;
+
+import org.apache.ignite.ml.math.Vector;
+import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
+
+/**
+ * Parameters for {@link SimpleGDUpdater}: a gradient together with the learning rate to apply it with.
+ * Both values are fixed at construction time.
+ */
+public class SimpleGDParams implements UpdaterParams<SmoothParametrized> {
+    /** Gradient. */
+    private final Vector gradient;
+
+    /** Learning rate. */
+    private final double learningRate;
+
+    /**
+     * Construct instance of this class with a freshly allocated gradient vector of the given size.
+     *
+     * @param paramsCnt Count of parameters.
+     * @param learningRate Learning rate.
+     */
+    public SimpleGDParams(int paramsCnt, double learningRate) {
+        gradient = new DenseLocalOnHeapVector(paramsCnt);
+        this.learningRate = learningRate;
+    }
+
+    /**
+     * Construct instance of this class.
+     *
+     * @param gradient Gradient; stored by reference, not copied.
+     * @param learningRate Learning rate.
+     */
+    public SimpleGDParams(Vector gradient, double learningRate) {
+        this.gradient = gradient;
+        this.learningRate = learningRate;
+    }
+
+    /** {@inheritDoc} */
+    @SuppressWarnings("unchecked")
+    @Override public <M extends SmoothParametrized> M update(M obj) {
+        // Plain gradient descent step: params - learningRate * gradient.
+        Vector params = obj.parameters();
+        return (M)obj.setParameters(params.minus(gradient.times(learningRate)));
+    }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/e4f19215/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/SimpleGDUpdater.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/SimpleGDUpdater.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/SimpleGDUpdater.java
new file mode 100644
index 0000000..5bf9c3f
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/SimpleGDUpdater.java
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.nn.updaters;
+
+import org.apache.ignite.ml.math.Matrix;
+import org.apache.ignite.ml.math.Vector;
+import org.apache.ignite.ml.math.functions.IgniteDifferentiableVectorToDoubleFunction;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
+
+/**
+ * Simple gradient descent parameters updater.
+ */
+public class SimpleGDUpdater implements ParameterUpdater<SmoothParametrized, SimpleGDParams> {
+    /** Learning rate; fixed at construction time. */
+    private final double learningRate;
+
+    /** Loss function; assigned in {@code init} and read in {@code updateParams}. */
+    protected IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss;
+
+    /**
+     * Construct SimpleGDUpdater.
+     *
+     * @param learningRate Learning rate.
+     */
+    public SimpleGDUpdater(double learningRate) {
+        this.learningRate = learningRate;
+    }
+
+    /** {@inheritDoc} */
+    @Override public SimpleGDParams init(SmoothParametrized mlp,
+        IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss) {
+        this.loss = loss;
+        return new SimpleGDParams(mlp.parametersCount(), learningRate);
+    }
+
+    /** {@inheritDoc} */
+    @Override public SimpleGDParams updateParams(SmoothParametrized mlp, SimpleGDParams updaterParameters,
+        int iteration, Matrix inputs, Matrix groundTruth) {
+        // The incoming updaterParameters and iteration are not used: a new params object is built
+        // from the gradient of the loss on the given batch.
+        return new SimpleGDParams(mlp.differentiateByParameters(loss, inputs, groundTruth), learningRate);
+    }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/e4f19215/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/SmoothParametrized.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/SmoothParametrized.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/SmoothParametrized.java
new file mode 100644
index 0000000..5c4f59f
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/SmoothParametrized.java
@@ -0,0 +1,24 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.nn.updaters;
+
+/**
+ * Interface for models which are smooth functions of their parameters.
+ * Declares no members of its own; everything is inherited from {@link BaseSmoothParametrized}.
+ *
+ * @param <M> Model type, bounded recursively by this interface (presumably the implementing class itself).
+ */
+public interface SmoothParametrized<M extends SmoothParametrized<M>> extends BaseSmoothParametrized<M> {
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/e4f19215/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/UpdaterParams.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/UpdaterParams.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/UpdaterParams.java
new file mode 100644
index 0000000..cd5bc32
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/UpdaterParams.java
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.nn.updaters;
+
+/**
+ * A common interface for parameter updaters.
+ *
+ * @param <T> Type of object to be updated with this params.
+ */
+public interface UpdaterParams<T> {
+ /**
+ * Update given obj with this parameters.
+ *
+ * @param obj Object to be updated.
+ * @param <M> Concrete type of the object to be updated.
+ * @return Object with the update applied.
+ */
+ <M extends T> M update(M obj);
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/e4f19215/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/package-info.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/package-info.java
new file mode 100644
index 0000000..13bc3c8
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/updaters/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. -->
+ * Contains parameters updaters.
+ */
+package org.apache.ignite.ml.nn.updaters;
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/ignite/blob/e4f19215/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionModel.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionModel.java
index 76a90fc..b95cbf3 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionModel.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionModel.java
@@ -47,7 +47,7 @@ public class OLSMultipleLinearRegressionModel implements Model<Vector, Vector>,
}
/** {@inheritDoc} */
- @Override public Vector predict(Vector val) {
+ @Override public Vector apply(Vector val) {
return xMatrix.times(solver.solve(val));
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/e4f19215/modules/ml/src/main/java/org/apache/ignite/ml/trees/models/DecisionTreeModel.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/models/DecisionTreeModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/models/DecisionTreeModel.java
index 86e9326..572e64a 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/trees/models/DecisionTreeModel.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/trees/models/DecisionTreeModel.java
@@ -38,7 +38,7 @@ public class DecisionTreeModel implements Model<Vector, Double> {
}
/** {@inheritDoc} */
- @Override public Double predict(Vector val) {
+ @Override public Double apply(Vector val) {
return root.process(val);
}
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/e4f19215/modules/ml/src/main/java/org/apache/ignite/ml/util/Utils.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/util/Utils.java b/modules/ml/src/main/java/org/apache/ignite/ml/util/Utils.java
index 847b1f1..4472300 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/util/Utils.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/util/Utils.java
@@ -22,6 +22,7 @@ import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
+import java.util.Random;
import org.apache.ignite.IgniteException;
/**
@@ -57,4 +58,30 @@ public class Utils {
return (T)obj;
}
+
+ /**
+  * Select k distinct integers from range [0, n) with reservoir sampling:
+  * https://en.wikipedia.org/wiki/Reservoir_sampling.
+  *
+  * @param n Exclusive right end of the range of integers to pick values from.
+  * @param k Count specifying how many integers should be picked; must not be negative or exceed the range size.
+  * @return Array containing k distinct integers from range [0, n); the order of the elements is not specified.
+  * @throws IllegalArgumentException If {@code k} is negative or more integers are requested than the range holds.
+  */
+ public static int[] selectKDistinct(int n, int k) {
+     // Without this check k > n would silently return values outside of [0, n).
+     if (k < 0 || k > Math.max(n, 0))
+         throw new IllegalArgumentException("Cannot select " + k + " distinct integers from range [0, " + n + ")");
+
+     int[] res = new int[k];
+
+     // Fill the reservoir with the first k integers of the range.
+     for (int i = 0; i < k; i++)
+         res[i] = i;
+
+     Random rnd = new Random();
+
+     // Every following integer i replaces a random reservoir slot with probability k / (i + 1).
+     for (int i = k; i < n; i++) {
+         int j = rnd.nextInt(i + 1);
+
+         if (j < k)
+             res[j] = i;
+     }
+
+     return res;
+ }
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/e4f19215/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java
index 05c91bd..fafd364 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java
@@ -20,6 +20,7 @@ package org.apache.ignite.ml;
import org.apache.ignite.ml.clustering.ClusteringTestSuite;
import org.apache.ignite.ml.knn.KNNTestSuite;
import org.apache.ignite.ml.math.MathImplMainTestSuite;
+import org.apache.ignite.ml.nn.MLPTestSuite;
import org.apache.ignite.ml.regressions.RegressionsTestSuite;
import org.apache.ignite.ml.trees.DecisionTreesTestSuite;
import org.junit.runner.RunWith;
@@ -35,7 +36,8 @@ import org.junit.runners.Suite;
ClusteringTestSuite.class,
DecisionTreesTestSuite.class,
KNNTestSuite.class,
- LocalModelsTest.class
+ LocalModelsTest.class,
+ MLPTestSuite.class
})
public class IgniteMLTestSuite {
// No-op.
http://git-wip-us.apache.org/repos/asf/ignite/blob/e4f19215/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java
index e010553..28af6fa 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java
@@ -46,9 +46,9 @@ public class KNNClassificationTest extends BaseKNNTest {
KNNModel knnMdl = new KNNModel(3, new EuclideanDistance(), KNNStrategy.SIMPLE, training);
Vector firstVector = new DenseLocalOnHeapVector(new double[] {2.0, 2.0});
- assertEquals(knnMdl.predict(firstVector), 1.0);
+ assertEquals(knnMdl.apply(firstVector), 1.0);
Vector secondVector = new DenseLocalOnHeapVector(new double[] {-2.0, -2.0});
- assertEquals(knnMdl.predict(secondVector), 2.0);
+ assertEquals(knnMdl.apply(secondVector), 2.0);
}
/** */
@@ -69,9 +69,9 @@ public class KNNClassificationTest extends BaseKNNTest {
KNNModel knnMdl = new KNNModel(1, new EuclideanDistance(), KNNStrategy.SIMPLE, training);
Vector firstVector = new DenseLocalOnHeapVector(new double[] {2.0, 2.0});
- assertEquals(knnMdl.predict(firstVector), 1.0);
+ assertEquals(knnMdl.apply(firstVector), 1.0);
Vector secondVector = new DenseLocalOnHeapVector(new double[] {-2.0, -2.0});
- assertEquals(knnMdl.predict(secondVector), 2.0);
+ assertEquals(knnMdl.apply(secondVector), 2.0);
}
/** */
@@ -91,7 +91,7 @@ public class KNNClassificationTest extends BaseKNNTest {
KNNModel knnMdl = new KNNModel(3, new EuclideanDistance(), KNNStrategy.SIMPLE, training);
Vector vector = new DenseLocalOnHeapVector(new double[] {-1.01, -1.01});
- assertEquals(knnMdl.predict(vector), 2.0);
+ assertEquals(knnMdl.apply(vector), 2.0);
}
/** */
@@ -112,7 +112,7 @@ public class KNNClassificationTest extends BaseKNNTest {
KNNModel knnMdl = new KNNModel(3, new EuclideanDistance(), KNNStrategy.WEIGHTED, training);
Vector vector = new DenseLocalOnHeapVector(new double[] {-1.01, -1.01});
- assertEquals(knnMdl.predict(vector), 1.0);
+ assertEquals(knnMdl.apply(vector), 1.0);
}
/** */
@@ -122,7 +122,7 @@ public class KNNClassificationTest extends BaseKNNTest {
KNNModel knnMdl = new KNNModel(7, new EuclideanDistance(), KNNStrategy.SIMPLE, training);
Vector vector = new DenseLocalOnHeapVector(new double[] {5.15, 3.55, 1.45, 0.25});
- assertEquals(knnMdl.predict(vector), 1.0);
+ assertEquals(knnMdl.apply(vector), 1.0);
}
/** */
http://git-wip-us.apache.org/repos/asf/ignite/blob/e4f19215/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNMultipleLinearRegressionTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNMultipleLinearRegressionTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNMultipleLinearRegressionTest.java
index 9a918b6..d973686 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNMultipleLinearRegressionTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNMultipleLinearRegressionTest.java
@@ -56,8 +56,8 @@ public class KNNMultipleLinearRegressionTest extends BaseKNNTest {
KNNMultipleLinearRegression knnMdl = new KNNMultipleLinearRegression(1, new EuclideanDistance(), KNNStrategy.SIMPLE, training);
Vector vector = new SparseBlockDistributedVector(new double[] {0, 0, 0, 5.0, 0.0});
- System.out.println(knnMdl.predict(vector));
- Assert.assertEquals(15, knnMdl.predict(vector), 1E-12);
+ System.out.println(knnMdl.apply(vector));
+ Assert.assertEquals(15, knnMdl.apply(vector), 1E-12);
}
/** */
@@ -87,8 +87,8 @@ public class KNNMultipleLinearRegressionTest extends BaseKNNTest {
KNNMultipleLinearRegression knnMdl = new KNNMultipleLinearRegression(3, new EuclideanDistance(), KNNStrategy.SIMPLE, training);
Vector vector = new DenseLocalOnHeapVector(new double[] {104.6, 419180, 2822, 2857, 118734, 1956});
- System.out.println(knnMdl.predict(vector));
- Assert.assertEquals(67857, knnMdl.predict(vector), 2000);
+ System.out.println(knnMdl.apply(vector));
+ Assert.assertEquals(67857, knnMdl.apply(vector), 2000);
}
/** */
@@ -119,8 +119,8 @@ public class KNNMultipleLinearRegressionTest extends BaseKNNTest {
KNNMultipleLinearRegression knnMdl = new KNNMultipleLinearRegression(5, new EuclideanDistance(), KNNStrategy.SIMPLE, normalizedTrainingDataset);
Vector vector = new DenseLocalOnHeapVector(new double[] {104.6, 419180, 2822, 2857, 118734, 1956});
- System.out.println(knnMdl.predict(vector));
- Assert.assertEquals(67857, knnMdl.predict(vector), 2000);
+ System.out.println(knnMdl.apply(vector));
+ Assert.assertEquals(67857, knnMdl.apply(vector), 2000);
}
/** */
@@ -151,7 +151,7 @@ public class KNNMultipleLinearRegressionTest extends BaseKNNTest {
KNNMultipleLinearRegression knnMdl = new KNNMultipleLinearRegression(5, new EuclideanDistance(), KNNStrategy.WEIGHTED, normalizedTrainingDataset);
Vector vector = new DenseLocalOnHeapVector(new double[] {104.6, 419180, 2822, 2857, 118734, 1956});
- System.out.println(knnMdl.predict(vector));
- Assert.assertEquals(67857, knnMdl.predict(vector), 2000);
+ System.out.println(knnMdl.apply(vector));
+ Assert.assertEquals(67857, knnMdl.apply(vector), 2000);
}
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/e4f19215/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPConstInitializer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPConstInitializer.java b/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPConstInitializer.java
new file mode 100644
index 0000000..fa2b5e2
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPConstInitializer.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.nn;
+
+import org.apache.ignite.ml.math.Matrix;
+import org.apache.ignite.ml.math.Vector;
+import org.apache.ignite.ml.nn.initializers.MLPInitializer;
+
+/**
+ * Test {@link MLPInitializer} which fills all weights and all biases with the constants it was constructed with.
+ */
+public class MLPConstInitializer implements MLPInitializer {
+    /**
+     * Constant to be used as bias for all layers; set once in the constructor.
+     */
+    private final double bias;
+
+    /**
+     * Constant to be used as weight from any neuron to any neuron in next layer; set once in the constructor.
+     */
+    private final double weight;
+
+    /**
+     * Construct MLPConstInitializer with explicit weight and bias constants.
+     *
+     * @param weight Constant to be used as weight from any neuron to any neuron in next layer.
+     * @param bias Constant to be used as bias for all layers.
+     */
+    public MLPConstInitializer(double weight, double bias) {
+        this.weight = weight;
+        this.bias = bias;
+    }
+
+    /**
+     * Construct MLPConstInitializer with biases constant equal to 0.0.
+     *
+     * @param weight Constant to be used as weight from any neuron to any neuron in next layer.
+     */
+    public MLPConstInitializer(double weight) {
+        this(weight, 0.0);
+    }
+
+    /** {@inheritDoc} */
+    @Override public void initWeights(Matrix weights) {
+        weights.assign(weight);
+    }
+
+    /** {@inheritDoc} */
+    @Override public void initBiases(Vector biases) {
+        biases.assign(bias);
+    }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/e4f19215/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPLocalTrainerTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPLocalTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPLocalTrainerTest.java
new file mode 100644
index 0000000..2a6b55d
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPLocalTrainerTest.java
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.nn;
+
+import java.util.Random;
+import org.apache.ignite.internal.util.typedef.X;
+import org.apache.ignite.ml.TestUtils;
+import org.apache.ignite.ml.math.Matrix;
+import org.apache.ignite.ml.math.StorageConstants;
+import org.apache.ignite.ml.math.Tracer;
+import org.apache.ignite.ml.math.functions.IgniteSupplier;
+import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix;
+import org.apache.ignite.ml.nn.architecture.MLPArchitecture;
+import org.apache.ignite.ml.nn.trainers.local.MLPLocalBatchTrainer;
+import org.apache.ignite.ml.nn.updaters.NesterovUpdater;
+import org.apache.ignite.ml.nn.updaters.ParameterUpdater;
+import org.apache.ignite.ml.nn.updaters.RPropUpdater;
+import org.apache.ignite.ml.nn.updaters.SimpleGDUpdater;
+import org.apache.ignite.ml.nn.updaters.UpdaterParams;
+import org.junit.Test;
+
+/**
+ * Checks that {@link MLPLocalBatchTrainer} is able to fit the 'XOR' operation with different parameter updaters.
+ */
+public class MLPLocalTrainerTest {
+    /**
+     * 'XOR' operation training driven by {@link SimpleGDUpdater}.
+     */
+    @Test
+    public void testXORSimpleGD() {
+        xorTest(() -> new SimpleGDUpdater(0.3));
+    }
+
+    /**
+     * 'XOR' operation training driven by {@link RPropUpdater}.
+     */
+    @Test
+    public void testXORRProp() {
+        xorTest(RPropUpdater::new);
+    }
+
+    /**
+     * 'XOR' operation training driven by {@link NesterovUpdater}.
+     */
+    @Test
+    public void testXORNesterov() {
+        xorTest(() -> new NesterovUpdater(0.1, 0.7));
+    }
+
+    /**
+     * Trains a perceptron on the 'XOR' truth table with the supplied updater and checks its predictions.
+     * @param updaterSupplier Supplier of the updater under test.
+     * @param <P> Type of parameters used by the updater.
+     */
+    private <P extends UpdaterParams<? super MultilayerPerceptron>> void xorTest(IgniteSupplier<ParameterUpdater<? super MultilayerPerceptron, P>> updaterSupplier) {
+        MLPArchitecture arch = new MLPArchitecture(2).
+            withAddedLayer(10, true, Activators.RELU).
+            withAddedLayer(1, false, Activators.SIGMOID);
+
+        Matrix samples = new DenseLocalOnHeapMatrix(new double[][] {{0.0, 0.0}, {0.0, 1.0}, {1.0, 0.0}, {1.0, 1.0}},
+            StorageConstants.ROW_STORAGE_MODE).transpose();
+
+        Matrix expected = new DenseLocalOnHeapMatrix(new double[][] {{0.0}, {1.0}, {1.0}, {0.0}},
+            StorageConstants.ROW_STORAGE_MODE).transpose();
+
+        SimpleMLPLocalBatchTrainerInput trainingInput = new SimpleMLPLocalBatchTrainerInput(arch,
+            new Random(1234L), samples, expected, 4);
+
+        MultilayerPerceptron trained = new MLPLocalBatchTrainer<>(LossFunctions.MSE,
+            updaterSupplier,
+            0.0001,
+            16000).train(trainingInput);
+
+        Matrix actual = trained.apply(samples);
+
+        Tracer.showAscii(actual);
+
+        X.println(expected.getRow(0).minus(actual.getRow(0)).kNorm(2) + "");
+
+        TestUtils.checkIsInEpsilonNeighbourhood(expected.getRow(0), actual.getRow(0), 1E-1);
+    }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/e4f19215/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTest.java
new file mode 100644
index 0000000..d757fcb
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTest.java
@@ -0,0 +1,207 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.nn;
+
+import org.apache.ignite.ml.TestUtils;
+import org.apache.ignite.ml.math.Matrix;
+import org.apache.ignite.ml.math.Tracer;
+import org.apache.ignite.ml.math.Vector;
+import org.apache.ignite.ml.math.functions.IgniteTriFunction;
+import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix;
+import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
+import org.apache.ignite.ml.nn.architecture.MLPArchitecture;
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Tests for {@link MultilayerPerceptron}: prediction, stacking, parameters handling and differentiation.
+ */
+public class MLPTest {
+    /**
+     * Tests that MLP with 2 layers, 1 neuron in each layer and weight equal to 1 is equivalent to sigmoid function.
+     */
+    @Test
+    public void testSimpleMLPPrediction() {
+        MLPArchitecture conf = new MLPArchitecture(1).withAddedLayer(1, false, Activators.SIGMOID);
+
+        MultilayerPerceptron mlp = new MultilayerPerceptron(conf, new MLPConstInitializer(1));
+
+        int input = 2;
+
+        Matrix predict = mlp.apply(new DenseLocalOnHeapMatrix(new double[][] {{input}}));
+
+        Assert.assertEquals(new DenseLocalOnHeapMatrix(new double[][] {{Activators.SIGMOID.apply(input)}}), predict);
+    }
+
+    /**
+     * Tests that MLP with hand-picked weights and biases approximating 'XOR' is close to 'XOR' on 'XOR' domain.
+     */
+    @Test
+    public void testXOR() {
+        MLPArchitecture conf = new MLPArchitecture(2).
+            withAddedLayer(2, true, Activators.SIGMOID).
+            withAddedLayer(1, true, Activators.SIGMOID);
+
+        MultilayerPerceptron mlp = new MultilayerPerceptron(conf, new MLPConstInitializer(1, 2));
+
+        mlp.setWeights(1, new DenseLocalOnHeapMatrix(new double[][] {{20.0, 20.0}, {-20.0, -20.0}}));
+        mlp.setBiases(1, new DenseLocalOnHeapVector(new double[] {-10.0, 30.0}));
+
+        mlp.setWeights(2, new DenseLocalOnHeapMatrix(new double[][] {{20.0, 20.0}}));
+        mlp.setBiases(2, new DenseLocalOnHeapVector(new double[] {-30.0}));
+
+        Matrix input = new DenseLocalOnHeapMatrix(new double[][] {{0.0, 0.0}, {0.0, 1.0}, {1.0, 0.0}, {1.0, 1.0}}).transpose();
+
+        Matrix predict = mlp.apply(input);
+        Vector truth = new DenseLocalOnHeapVector(new double[] {0.0, 1.0, 1.0, 0.0});
+
+        TestUtils.checkIsInEpsilonNeighbourhood(predict.getRow(0), truth, 1E-4);
+    }
+
+    /**
+     * Tests that two layer MLP is equivalent to its subparts stacked on each other.
+     */
+    @Test
+    public void testStackedMLP() {
+        int firstLayerNeuronsCnt = 3;
+        int secondLayerNeuronsCnt = 2;
+        MLPConstInitializer initer = new MLPConstInitializer(1, 2);
+
+        MLPArchitecture conf = new MLPArchitecture(4).
+            withAddedLayer(firstLayerNeuronsCnt, false, Activators.SIGMOID).
+            withAddedLayer(secondLayerNeuronsCnt, false, Activators.SIGMOID);
+
+        MultilayerPerceptron mlp = new MultilayerPerceptron(conf, initer);
+
+        MLPArchitecture mlpLayer1Conf = new MLPArchitecture(4).
+            withAddedLayer(firstLayerNeuronsCnt, false, Activators.SIGMOID);
+        MLPArchitecture mlpLayer2Conf = new MLPArchitecture(firstLayerNeuronsCnt).
+            withAddedLayer(secondLayerNeuronsCnt, false, Activators.SIGMOID);
+
+        MultilayerPerceptron mlp1 = new MultilayerPerceptron(mlpLayer1Conf, initer);
+        MultilayerPerceptron mlp2 = new MultilayerPerceptron(mlpLayer2Conf, initer);
+
+        MultilayerPerceptron stackedMLP = mlp1.add(mlp2);
+
+        Matrix predict = mlp.apply(new DenseLocalOnHeapMatrix(new double[][] {{1, 2, 3, 4}}).transpose());
+        Matrix stackedPredict = stackedMLP.apply(new DenseLocalOnHeapMatrix(new double[][] {{1, 2, 3, 4}}).transpose());
+
+        Assert.assertEquals(predict, stackedPredict);
+    }
+
+    /**
+     * Tests that parameters count includes weights of every layer and biases only of the layer which has them.
+     */
+    @Test
+    public void paramsCountTest() {
+        int inputSize = 10;
+        int layerWithBiasNeuronsCnt = 13;
+        int layerWithoutBiasNeuronsCnt = 17;
+
+        MLPArchitecture conf = new MLPArchitecture(inputSize).
+            withAddedLayer(layerWithBiasNeuronsCnt, true, Activators.SIGMOID).
+            withAddedLayer(layerWithoutBiasNeuronsCnt, false, Activators.SIGMOID);
+
+        Assert.assertEquals(layerWithBiasNeuronsCnt * inputSize + layerWithBiasNeuronsCnt + (layerWithoutBiasNeuronsCnt * layerWithBiasNeuronsCnt),
+            conf.parametersCount());
+    }
+
+    /**
+     * Tests that flat parameters vector is read back unchanged and is unpacked into per-layer weights and biases.
+     */
+    @Test
+    public void setParamsFlattening() {
+        int inputSize = 3;
+        int firstLayerNeuronsCnt = 2;
+        int secondLayerNeurons = 1;
+
+        DenseLocalOnHeapVector paramsVector = new DenseLocalOnHeapVector(new double[] {
+            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, // First layer weight matrix.
+            7.0, 8.0, // Second layer weight matrix.
+            9.0 // Second layer biases.
+        });
+
+        DenseLocalOnHeapMatrix firstLayerWeights = new DenseLocalOnHeapMatrix(new double[][] {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}});
+        DenseLocalOnHeapMatrix secondLayerWeights = new DenseLocalOnHeapMatrix(new double[][] {{7.0, 8.0}});
+        DenseLocalOnHeapVector secondLayerBiases = new DenseLocalOnHeapVector(new double[] {9.0});
+
+        MLPArchitecture conf = new MLPArchitecture(inputSize).
+            withAddedLayer(firstLayerNeuronsCnt, false, Activators.SIGMOID).
+            withAddedLayer(secondLayerNeurons, true, Activators.SIGMOID);
+
+        MultilayerPerceptron mlp = new MultilayerPerceptron(conf, new MLPConstInitializer(100, 200));
+
+        mlp.setParameters(paramsVector);
+        Assert.assertEquals(paramsVector, mlp.parameters());
+
+        Assert.assertEquals(firstLayerWeights, mlp.weights(1));
+        Assert.assertEquals(secondLayerWeights, mlp.weights(2));
+        Assert.assertEquals(secondLayerBiases, mlp.biases(2));
+    }
+
+    /**
+     * Tests that gradient of MSE loss by parameters equals the analytically derived one for a single sigmoid neuron.
+     */
+    @Test
+    public void testDifferentiation() {
+        int inputSize = 2;
+        int firstLayerNeuronsCnt = 1;
+
+        double w10 = 0.1;
+        double w11 = 0.2;
+
+        MLPArchitecture conf = new MLPArchitecture(inputSize).
+            withAddedLayer(firstLayerNeuronsCnt, false, Activators.SIGMOID);
+
+        MultilayerPerceptron mlp = new MultilayerPerceptron(conf);
+
+        mlp.setWeight(1, 0, 0, w10);
+        mlp.setWeight(1, 1, 0, w11);
+        double x0 = 1.0;
+        double x1 = 3.0;
+
+        Matrix inputs = new DenseLocalOnHeapMatrix(new double[][] {{x0, x1}}).transpose();
+        double ytt = 1.0;
+        Matrix truth = new DenseLocalOnHeapMatrix(new double[][] {{ytt}}).transpose();
+
+        Vector grad = mlp.differentiateByParameters(LossFunctions.MSE, inputs, truth);
+
+        // Let yt be y ground truth value.
+        // d/dw1i [(yt - sigma(w10 * x0 + w11 * x1))^2] =
+        // 2 * (yt - sigma(w10 * x0 + w11 * x1)) * (-1) * (sigma(w10 * x0 + w11 * x1)) * (1 - sigma(w10 * x0 + w11 * x1)) * xi =
+        // let z = sigma(w10 * x0 + w11 * x1)
+        // -2 * (yt - z) * z * (1 - z) * xi.
+
+        IgniteTriFunction<Double, Vector, Vector, Vector> partialDer = (yt, w, x) -> {
+            Double z = Activators.SIGMOID.apply(w.dot(x));
+
+            return x.copy().map(xi -> -2 * (yt - z) * z * (1 - z) * xi);
+        };
+
+        Vector weightsVec = mlp.weights(1).getRow(0);
+        Tracer.showAscii(weightsVec);
+
+        Vector trueGrad = partialDer.apply(ytt, weightsVec, inputs.getCol(0));
+
+        Tracer.showAscii(trueGrad);
+        Tracer.showAscii(grad);
+
+        Assert.assertEquals(mlp.architecture().parametersCount(), grad.size());
+        Assert.assertEquals(trueGrad, grad);
+    }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/e4f19215/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTestSuite.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTestSuite.java
new file mode 100644
index 0000000..d006cd9
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTestSuite.java
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.nn;
+
+import org.junit.runner.RunWith;
+import org.junit.runners.Suite;
+
+/**
+ * JUnit suite which groups the tests for multilayer perceptrons.
+ */
+@RunWith(Suite.class)
+@Suite.SuiteClasses({
+    MLPTest.class,
+    MLPLocalTrainerTest.class
+})
+public class MLPTestSuite {
+    // Intentionally empty: the class only carries the suite annotations.
+}