You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@ignite.apache.org by ch...@apache.org on 2017/12/30 21:59:31 UTC
[1/2] ignite git commit: IGNITE-6783: Create common mechanism for
group training.
Repository: ignite
Updated Branches:
refs/heads/master 442716ed6 -> b0c5ef1ea
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0c5ef1e/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/ComputationsChain.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/ComputationsChain.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/ComputationsChain.java
new file mode 100644
index 0000000..080494c
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/ComputationsChain.java
@@ -0,0 +1,248 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.trainers.group.chain;
+
+import java.io.Serializable;
+import java.util.UUID;
+import java.util.stream.Stream;
+import javax.cache.processor.EntryProcessor;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.cluster.ClusterGroup;
+import org.apache.ignite.lang.IgniteBiPredicate;
+import org.apache.ignite.ml.math.functions.Functions;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.math.functions.IgniteBinaryOperator;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
+import org.apache.ignite.ml.math.functions.IgniteSupplier;
+import org.apache.ignite.ml.trainers.group.GroupTrainerCacheKey;
+import org.apache.ignite.ml.trainers.group.GroupTrainerEntriesProcessorTask;
+import org.apache.ignite.ml.trainers.group.GroupTrainerKeysProcessorTask;
+import org.apache.ignite.ml.trainers.group.GroupTrainingContext;
+import org.apache.ignite.ml.trainers.group.ResultAndUpdates;
+
+/**
+ * This class encapsulates convenient way for creating computations chain for distributed model training.
+ * Chain is meant in the sense that output of each non-final computation is fed as input to next computation.
+ * Chain is basically a bi-function from context and input to output, context is separated from input
+ * because input is specific to each individual step and context is something which is convenient to have access to in each of steps.
+ * Context is separated into two parts: local context and remote context.
+ * There are two kinds of computations: local and distributed.
+ * Local steps are just functions from two arguments: input and local context.
+ * Distributed steps are more sophisticated, but basically can be thought as functions of form
+ * localContext -> (function of remote context -> output), locally we fix local context and get function
+ * (function of remote context -> output) which is executed distributed.
+ * Chains are composable through 'then' method.
+ *
+ * @param <L> Type of local context.
+ * @param <K> Type of cache keys.
+ * @param <V> Type of cache values.
+ * @param <I> Type of input of this chain.
+ * @param <O> Type of output of this chain.
+ * // TODO: IGNITE-7322 check if it is possible to integrate with {@link EntryProcessor}.
+ */
+@FunctionalInterface
+public interface ComputationsChain<L extends HasTrainingUUID, K, V, I, O> {
+ /**
+ * Process given input and {@link GroupTrainingContext}.
+ *
+ * @param input Computation chain input.
+ * @param ctx {@link GroupTrainingContext}.
+ * @return Result of processing input and context.
+ */
+ O process(I input, GroupTrainingContext<K, V, L> ctx);
+
+ /**
+ * Add a local step to this chain.
+ *
+ * @param locStep Local step.
+ * @param <O1> Output of local step.
+ * @return Composition of this chain and local step.
+ */
+ default <O1> ComputationsChain<L, K, V, I, O1> thenLocally(IgniteBiFunction<O, L, O1> locStep) {
+ ComputationsChain<L, K, V, O, O1> nextStep = (input, context) -> locStep.apply(input, context.localContext());
+ return then(nextStep);
+ }
+
+ /**
+ * Add a distributed step which works in the following way:
+ * 1. apply local context and input to local context extractor and keys supplier to get corresponding suppliers;
+ * 2. on each node_n
+ * 2.1. get context object.
+ * 2.2. for each entry_i e located on node_n with key_i from keys stream compute worker((context, entry_i)) and get
+ * (cacheUpdates_i, result_i).
+ * 2.3. for all i on node_n merge cacheUpdates_i and apply them.
+ * 2.4. for all i on node_n, reduce result_i into result_n.
+ * 3. get all result_n, reduce them into result and return result.
+ *
+ * @param <O1> Type of worker output.
+ * @param <G> Type of context used by worker.
+ * @param workerCtxExtractor Extractor of context for worker.
+ * @param worker Function computed on each entry of cache used for training. Its argument bundles the entry with context:
+ * common part of data which is independent from key.
+ * @param ks Function from chain input and local context to supplier of keys for worker.
+ * @param reducer Function used for reducing results of worker.
+ * @param identity Identity for reducer.
+ * @return Combination of this chain and distributed step specified by given parameters.
+ */
+ default <O1 extends Serializable, G> ComputationsChain<L, K, V, I, O1> thenDistributedForEntries(
+ IgniteBiFunction<O, L, IgniteSupplier<G>> workerCtxExtractor,
+ IgniteFunction<EntryAndContext<K, V, G>, ResultAndUpdates<O1>> worker,
+ IgniteBiFunction<O, L, IgniteSupplier<Stream<GroupTrainerCacheKey<K>>>> ks,
+ IgniteBinaryOperator<O1> reducer, O1 identity) {
+ ComputationsChain<L, K, V, O, O1> nextStep = (input, context) -> {
+ L locCtx = context.localContext();
+ IgniteSupplier<Stream<GroupTrainerCacheKey<K>>> keysSupplier = ks.apply(input, locCtx);
+
+ Ignite ignite = context.ignite();
+ UUID trainingUUID = context.localContext().trainingUUID();
+ String cacheName = context.cache().getName();
+ ClusterGroup grp = ignite.cluster().forDataNodes(cacheName);
+
+ // Apply first two arguments locally because it is common for all nodes.
+ IgniteSupplier<G> extractor = Functions.curry(workerCtxExtractor).apply(input).apply(locCtx);
+
+ return ignite.compute(grp).execute(new GroupTrainerEntriesProcessorTask<>(trainingUUID, extractor, worker, keysSupplier, reducer, identity, cacheName, ignite), null);
+ };
+ return then(nextStep);
+ }
+
+ /**
+ * Add a distributed step which works in the following way:
+ * 1. apply local context and input to local context extractor and keys supplier to get corresponding suppliers;
+ * 2. on each node_n
+ * 2.1. get context object.
+ * 2.2. for each key_i from keys stream such that key_i located on node_n compute worker((context, entry_i)) and get
+ * (cacheUpdates_i, result_i).
+ * 2.3. for all i on node_n merge cacheUpdates_i and apply them.
+ * 2.4. for all i on node_n, reduce result_i into result_n.
+ * 3. get all result_n, reduce them into result and return result.
+ *
+ * @param <O1> Type of worker output.
+ * @param <G> Type of context used by worker.
+ * @param workerCtxExtractor Extractor of context for worker.
+ * @param worker Function computed on each key from keys stream. Its argument bundles the key with context:
+ * common part of data which is independent from key.
+ * @param keysSupplier Function from chain input and local context to supplier of keys for worker.
+ * @param reducer Function used for reducing results of worker.
+ * @param identity Identity for reducer.
+ * @return Combination of this chain and distributed step specified by given parameters.
+ */
+ default <O1 extends Serializable, G> ComputationsChain<L, K, V, I, O1> thenDistributedForKeys(
+ IgniteBiFunction<O, L, IgniteSupplier<G>> workerCtxExtractor,
+ IgniteFunction<KeyAndContext<K, G>, ResultAndUpdates<O1>> worker,
+ IgniteBiFunction<O, L, IgniteSupplier<Stream<GroupTrainerCacheKey<K>>>> keysSupplier,
+ IgniteBinaryOperator<O1> reducer, O1 identity) {
+ ComputationsChain<L, K, V, O, O1> nextStep = (input, context) -> {
+ L locCtx = context.localContext();
+ IgniteSupplier<Stream<GroupTrainerCacheKey<K>>> ks = keysSupplier.apply(input, locCtx);
+
+ Ignite ignite = context.ignite();
+ UUID trainingUUID = context.localContext().trainingUUID();
+ String cacheName = context.cache().getName();
+ ClusterGroup grp = ignite.cluster().forDataNodes(cacheName);
+
+ // Apply first two arguments locally because it is common for all nodes.
+ IgniteSupplier<G> extractor = Functions.curry(workerCtxExtractor).apply(input).apply(locCtx);
+
+ return ignite.compute(grp).execute(new GroupTrainerKeysProcessorTask<>(trainingUUID, extractor, worker, ks, reducer, identity, cacheName, ignite), null);
+ };
+ return then(nextStep);
+ }
+
+ /**
+ * Add a distributed step specified by {@link DistributedEntryProcessingStep}.
+ *
+ * @param step Distributed step.
+ * @param <O1> Type of output of distributed step.
+ * @param <G> Type of context of distributed step.
+ * @return Combination of this chain and distributed step specified by input.
+ */
+ default <O1 extends Serializable, G> ComputationsChain<L, K, V, I, O1> thenDistributedForEntries(
+ DistributedEntryProcessingStep<L, K, V, G, O, O1> step) {
+ return thenDistributedForEntries(step::remoteContextSupplier, step.worker(), step::keys, step.reducer(), step.identity());
+ }
+
+ /**
+ * Add a distributed step specified by {@link DistributedKeyProcessingStep}.
+ *
+ * @param step Distributed step.
+ * @param <O1> Type of output of distributed step.
+ * @param <G> Type of context of distributed step.
+ * @return Combination of this chain and distributed step specified by input.
+ */
+ default <O1 extends Serializable, G> ComputationsChain<L, K, V, I, O1> thenDistributedForKeys(
+ DistributedKeyProcessingStep<L, K, G, O, O1> step) {
+ return thenDistributedForKeys(step::remoteContextSupplier, step.worker(), step::keys, step.reducer(), step.identity());
+ }
+
+ /**
+ * Version of 'thenDistributedForKeys' where worker does not depend on context.
+ *
+ * @param worker Worker.
+ * @param kf Function from chain input and local context to supplier of keys for worker.
+ * @param reducer Function used for reducing results of worker.
+ * @param <O1> Type of worker output.
+ * @return Combination of this chain and distributed step specified by given parameters.
+ */
+ default <O1 extends Serializable> ComputationsChain<L, K, V, I, O1> thenDistributedForKeys(
+ IgniteFunction<GroupTrainerCacheKey<K>, ResultAndUpdates<O1>> worker,
+ IgniteBiFunction<O, L, IgniteSupplier<Stream<GroupTrainerCacheKey<K>>>> kf,
+ IgniteBinaryOperator<O1> reducer) {
+
+ return thenDistributedForKeys((o, lc) -> () -> o, (context) -> worker.apply(context.key()), kf, reducer, null);
+ }
+
+ /**
+ * Combine this computation chain with other computation chain in the following way:
+ * 1. perform this calculations chain and get result r.
+ * 2. while 'cond(r)' is true, r = otherChain(r, context)
+ * 3. return r.
+ *
+ * @param cond Condition checking if 'while' loop should continue.
+ * @param otherChain Chain to be combined with this chain.
+ * @return Combination of this chain and otherChain.
+ */
+ default ComputationsChain<L, K, V, I, O> thenWhile(IgniteBiPredicate<O, L> cond,
+ ComputationsChain<L, K, V, O, O> otherChain) {
+ ComputationsChain<L, K, V, I, O> me = this;
+ return (input, context) -> {
+ O res = me.process(input, context);
+
+ while (cond.apply(res, context.localContext()))
+ res = otherChain.process(res, context);
+
+ return res;
+ };
+ }
+
+ /**
+ * Combine this chain with another one: feed the output of this chain as input to the other chain and pass the same
+ * context as second argument to the process method of both chains.
+ *
+ * @param next Next chain.
+ * @param <O1> Type of next chain output.
+ * @return Combined chain.
+ */
+ default <O1> ComputationsChain<L, K, V, I, O1> then(ComputationsChain<L, K, V, O, O1> next) {
+ ComputationsChain<L, K, V, I, O> me = this;
+ return (input, context) -> {
+ O myRes = me.process(input, context);
+ return next.process(myRes, context);
+ };
+ }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0c5ef1e/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/DistributedEntryProcessingStep.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/DistributedEntryProcessingStep.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/DistributedEntryProcessingStep.java
new file mode 100644
index 0000000..8fd1264
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/DistributedEntryProcessingStep.java
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.trainers.group.chain;
+
+import java.io.Serializable;
+
+/**
+ * {@link DistributedStep} specialized to {@link EntryAndContext}.
+ *
+ * @param <L> Local context.
+ * @param <K> Type of keys of cache used for group training.
+ * @param <V> Type of values of cache used for group training.
+ * @param <C> Context used by worker.
+ * @param <I> Type of input to this step.
+ * @param <O> Type of output of this step.
+ */
+public interface DistributedEntryProcessingStep<L, K, V, C, I, O extends Serializable> extends
+ DistributedStep<EntryAndContext<K, V, C>, L, K, C, I, O> {
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0c5ef1e/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/DistributedKeyProcessingStep.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/DistributedKeyProcessingStep.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/DistributedKeyProcessingStep.java
new file mode 100644
index 0000000..fb8d867
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/DistributedKeyProcessingStep.java
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.trainers.group.chain;
+
+import java.io.Serializable;
+
+/**
+ * {@link DistributedStep} specialized to {@link KeyAndContext}.
+ *
+ * @param <L> Local context.
+ * @param <K> Type of keys of cache used for group training.
+ * @param <C> Context used by worker.
+ * @param <I> Type of input to this step.
+ * @param <O> Type of output of this step.
+ */
+public interface DistributedKeyProcessingStep<L, K, C, I, O extends Serializable> extends
+ DistributedStep<KeyAndContext<K, C>, L, K, C, I, O> {
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0c5ef1e/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/DistributedStep.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/DistributedStep.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/DistributedStep.java
new file mode 100644
index 0000000..7ddc649
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/DistributedStep.java
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.trainers.group.chain;
+
+import java.io.Serializable;
+import java.util.stream.Stream;
+import org.apache.ignite.ml.math.functions.IgniteBinaryOperator;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
+import org.apache.ignite.ml.math.functions.IgniteSupplier;
+import org.apache.ignite.ml.trainers.group.GroupTrainerCacheKey;
+import org.apache.ignite.ml.trainers.group.ResultAndUpdates;
+
+/**
+ * Class encapsulating logic of distributed step in {@link ComputationsChain}.
+ *
+ * @param <T> Type of elements to be processed by worker.
+ * @param <L> Local context.
+ * @param <K> Type of keys of cache used for group training.
+ * @param <C> Context used by worker.
+ * @param <I> Type of input to this step.
+ * @param <O> Type of output of this step.
+ */
+public interface DistributedStep<T, L, K, C, I, O extends Serializable> {
+ /**
+ * Create supplier of context used by worker.
+ *
+ * @param input Input.
+ * @param locCtx Local context.
+ * @return Context used by worker.
+ */
+ IgniteSupplier<C> remoteContextSupplier(I input, L locCtx);
+
+ /**
+ * Get function applied to each cache element specified by keys.
+ *
+ * @return Function applied to each cache entry specified by keys.
+ */
+ IgniteFunction<T, ResultAndUpdates<O>> worker();
+
+ /**
+ * Get supplier of keys for worker.
+ *
+ * @param input Input to this step.
+ * @param locCtx Local context.
+ * @return Keys for worker.
+ */
+ IgniteSupplier<Stream<GroupTrainerCacheKey<K>>> keys(I input, L locCtx);
+
+ /**
+ * Get function used to reduce results returned by worker.
+ *
+ * @return Function used to reduce results returned by worker.
+ */
+ IgniteBinaryOperator<O> reducer();
+
+ /**
+ * Identity for reduce.
+ *
+ * @return Identity for reduce.
+ */
+ O identity();
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0c5ef1e/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/EntryAndContext.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/EntryAndContext.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/EntryAndContext.java
new file mode 100644
index 0000000..59c3b34
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/EntryAndContext.java
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.trainers.group.chain;
+
+import java.util.Map;
+import org.apache.ignite.ml.trainers.group.GroupTrainerCacheKey;
+
+/**
+ * Entry of cache used for group training and context.
+ * This class is used as input for workers of distributed steps of {@link ComputationsChain}.
+ *
+ * @param <K> Type of cache keys used for training.
+ * @param <V> Type of cache values used for training.
+ * @param <C> Type of context.
+ */
+public class EntryAndContext<K, V, C> {
+ /**
+ * Entry of cache used for training.
+ */
+ private Map.Entry<GroupTrainerCacheKey<K>, V> entry;
+
+ /**
+ * Context.
+ */
+ private C ctx;
+
+ /**
+ * Construct instance of this class.
+ *
+ * @param entry Entry of cache used for training.
+ * @param ctx Context.
+ */
+ public EntryAndContext(Map.Entry<GroupTrainerCacheKey<K>, V> entry, C ctx) {
+ this.entry = entry;
+ this.ctx = ctx;
+ }
+
+ /**
+ * Get entry of cache used for training.
+ *
+ * @return Entry of cache used for training.
+ */
+ public Map.Entry<GroupTrainerCacheKey<K>, V> entry() {
+ return entry;
+ }
+
+ /**
+ * Get context.
+ *
+ * @return Context.
+ */
+ public C context() {
+ return ctx;
+ }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0c5ef1e/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/HasTrainingUUID.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/HasTrainingUUID.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/HasTrainingUUID.java
new file mode 100644
index 0000000..d855adf
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/HasTrainingUUID.java
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.trainers.group.chain;
+
+import java.util.UUID;
+
+/**
+ * Interface for classes which contain UUID of training.
+ */
+public interface HasTrainingUUID {
+ /**
+ * Get training UUID.
+ *
+ * @return Training UUID.
+ */
+ UUID trainingUUID();
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0c5ef1e/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/KeyAndContext.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/KeyAndContext.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/KeyAndContext.java
new file mode 100644
index 0000000..ba36e0e
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/KeyAndContext.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.trainers.group.chain;
+
+import org.apache.ignite.ml.trainers.group.GroupTrainerCacheKey;
+
+/**
+ * Class containing key and remote context (see explanation of remote context in {@link ComputationsChain}).
+ *
+ * @param <K> Cache key type.
+ * @param <C> Remote context.
+ */
+public class KeyAndContext<K, C> {
+ /**
+ * Key of group trainer.
+ */
+ private GroupTrainerCacheKey<K> key;
+
+ /**
+ * Remote context.
+ */
+ private C ctx;
+
+ /**
+ * Construct instance of this class.
+ *
+ * @param key Cache key.
+ * @param ctx Remote context.
+ */
+ public KeyAndContext(GroupTrainerCacheKey<K> key, C ctx) {
+ this.key = key;
+ this.ctx = ctx;
+ }
+
+ /**
+ * Get group trainer cache key.
+ *
+ * @return Group trainer cache key.
+ */
+ public GroupTrainerCacheKey<K> key() {
+ return key;
+ }
+
+ /**
+ * Get remote context.
+ *
+ * @return Remote context.
+ */
+ public C context() {
+ return ctx;
+ }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0c5ef1e/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/package-info.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/package-info.java
new file mode 100644
index 0000000..46dcc6b
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. -->
+ * Contains classes related to computations chain.
+ */
+package org.apache.ignite.ml.trainers.group.chain;
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0c5ef1e/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/package-info.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/package-info.java
new file mode 100644
index 0000000..9b7f7cd
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. -->
+ * Contains group trainers.
+ */
+package org.apache.ignite.ml.trainers.group;
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0c5ef1e/modules/ml/src/main/java/org/apache/ignite/ml/trainers/package-info.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/package-info.java
new file mode 100644
index 0000000..b6e4fe2
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. -->
+ * Contains model trainers.
+ */
+package org.apache.ignite.ml.trainers;
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0c5ef1e/modules/ml/src/main/java/org/apache/ignite/ml/util/MnistUtils.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/util/MnistUtils.java b/modules/ml/src/main/java/org/apache/ignite/ml/util/MnistUtils.java
index a3f1d21..a30cfe9 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/util/MnistUtils.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/util/MnistUtils.java
@@ -38,7 +38,7 @@ public class MnistUtils {
*
* @param imagesPath Path to the file with images.
* @param labelsPath Path to the file with labels.
- * @param rnd Random numbers generatror.
+ * @param rnd Random numbers generator.
* @param cnt Count of samples to read.
* @return Stream of MNIST samples.
* @throws IOException In case of exception.
@@ -85,7 +85,7 @@ public class MnistUtils {
* @param outPath Path to output path.
* @param rnd Random numbers generator.
* @param cnt Count of samples to read.
- * @throws IOException In case of exception.
+ * @throws IgniteException In case of exception.
*/
public static void asLIBSVM(String imagesPath, String labelsPath, String outPath, Random rnd, int cnt)
throws IOException {
@@ -121,4 +121,4 @@ public class MnistUtils {
private static int read4Bytes(FileInputStream is) throws IOException {
return (is.read() << 24) | (is.read() << 16) | (is.read() << 8) | (is.read());
}
-}
\ No newline at end of file
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0c5ef1e/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java
index fafd364..35ffdbc 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java
@@ -22,6 +22,7 @@ import org.apache.ignite.ml.knn.KNNTestSuite;
import org.apache.ignite.ml.math.MathImplMainTestSuite;
import org.apache.ignite.ml.nn.MLPTestSuite;
import org.apache.ignite.ml.regressions.RegressionsTestSuite;
+import org.apache.ignite.ml.trainers.group.TrainersGroupTestSuite;
import org.apache.ignite.ml.trees.DecisionTreesTestSuite;
import org.junit.runner.RunWith;
import org.junit.runners.Suite;
@@ -37,7 +38,8 @@ import org.junit.runners.Suite;
DecisionTreesTestSuite.class,
KNNTestSuite.class,
LocalModelsTest.class,
- MLPTestSuite.class
+ MLPTestSuite.class,
+ TrainersGroupTestSuite.class
})
public class IgniteMLTestSuite {
// No-op.
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0c5ef1e/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/DistributedWorkersChainTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/DistributedWorkersChainTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/DistributedWorkersChainTest.java
new file mode 100644
index 0000000..987595d
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/DistributedWorkersChainTest.java
@@ -0,0 +1,188 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.trainers.group;
+
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.UUID;
+import java.util.stream.Stream;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.Ignition;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.math.functions.IgniteBinaryOperator;
+import org.apache.ignite.ml.math.functions.IgniteSupplier;
+import org.apache.ignite.ml.trainers.group.chain.Chains;
+import org.apache.ignite.ml.trainers.group.chain.ComputationsChain;
+import org.apache.ignite.ml.trainers.group.chain.EntryAndContext;
+import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest;
+import org.junit.Assert;
+
+/** Tests for {@link ComputationsChain}: chains are built with {@link Chains} and processed on a test grid. */
+public class DistributedWorkersChainTest extends GridCommonAbstractTest {
+ /** Count of nodes. */
+ private static final int NODE_COUNT = 3;
+
+ /** Grid instance. */
+ private Ignite ignite;
+
+ /**
+ * Default constructor.
+ */
+ public DistributedWorkersChainTest() {
+ super(false);
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override protected void beforeTest() throws Exception {
+ ignite = grid(NODE_COUNT);
+ TestGroupTrainingCache.getOrCreate(ignite).removeAll();
+ TestGroupTrainingSecondCache.getOrCreate(ignite).removeAll();
+ }
+
+ /** {@inheritDoc} */
+ @Override protected void beforeTestsStarted() throws Exception {
+ for (int i = 1; i <= NODE_COUNT; i++)
+ startGrid(i);
+ }
+
+ /** {@inheritDoc} */
+ @Override protected void afterTestsStopped() throws Exception {
+ stopAllGrids();
+ }
+
+ /** A chain created without any steps returns its input unchanged. */
+ public void testId() {
+ ComputationsChain<TestLocalContext, Double, Integer, Integer, Integer> chain = Chains.create();
+
+ UUID trainingUUID = UUID.randomUUID();
+ Integer res = chain.process(1, new GroupTrainingContext<>(new TestLocalContext(0, trainingUUID), TestGroupTrainingCache.getOrCreate(ignite), ignite));
+
+ Assert.assertEquals(1L, (long)res);
+ }
+
+ /** A single local step transforms the input and leaves the local context data untouched. */
+ public void testSimpleLocal() {
+ ComputationsChain<TestLocalContext, Double, Integer, Integer, Integer> chain = Chains.create();
+
+ IgniteCache<GroupTrainerCacheKey<Double>, Integer> cache = TestGroupTrainingCache.getOrCreate(ignite);
+ int init = 1;
+ int initLocCtxData = 0;
+ UUID trainingUUID = UUID.randomUUID();
+ TestLocalContext locCtx = new TestLocalContext(initLocCtxData, trainingUUID);
+
+ Integer res = chain.
+ thenLocally((prev, lc) -> prev + 1).
+ process(init, new GroupTrainingContext<>(locCtx, cache, ignite));
+
+ Assert.assertEquals(init + 1, (long)res);
+ Assert.assertEquals(initLocCtxData, locCtx.data());
+ }
+
+ /** Two local steps are applied in the order in which they were added to the chain. */
+ public void testChainLocal() {
+ ComputationsChain<TestLocalContext, Double, Integer, Integer, Integer> chain = Chains.create();
+
+ IgniteCache<GroupTrainerCacheKey<Double>, Integer> cache = TestGroupTrainingCache.getOrCreate(ignite);
+ int init = 1;
+ int initLocCtxData = 0;
+ UUID trainingUUID = UUID.randomUUID();
+ TestLocalContext locCtx = new TestLocalContext(initLocCtxData, trainingUUID);
+
+ Integer res = chain.
+ thenLocally((prev, lc) -> prev + 1).
+ thenLocally((prev, lc) -> prev * 5).
+ process(init, new GroupTrainingContext<>(locCtx, cache, ignite));
+
+ Assert.assertEquals((init + 1) * 5, (long)res);
+ Assert.assertEquals(initLocCtxData, locCtx.data());
+ }
+
+ /** A local step can modify the local context while passing its input through unchanged. */
+ public void testChangeLocalContext() {
+ ComputationsChain<TestLocalContext, Double, Integer, Integer, Integer> chain = Chains.create();
+ IgniteCache<GroupTrainerCacheKey<Double>, Integer> cache = TestGroupTrainingCache.getOrCreate(ignite);
+ int init = 1;
+ int newData = 10;
+ UUID trainingUUID = UUID.randomUUID();
+ TestLocalContext locCtx = new TestLocalContext(0, trainingUUID);
+
+ Integer res = chain.
+ thenLocally((prev, lc) -> { lc.setData(newData); return prev;}).
+ process(init, new GroupTrainingContext<>(locCtx, cache, ignite));
+
+ Assert.assertEquals(newData, locCtx.data());
+ Assert.assertEquals(init, res.intValue());
+ }
+
+ /** Distributed step: the result is the maximum of the old entry values and every cached value grows by one. */
+ public void testDistributed() {
+ ComputationsChain<TestLocalContext, Double, Integer, Integer, Integer> chain = Chains.create();
+ IgniteCache<GroupTrainerCacheKey<Double>, Integer> cache = TestGroupTrainingCache.getOrCreate(ignite);
+ int init = 1;
+ UUID trainingUUID = UUID.randomUUID();
+ TestLocalContext locCtx = new TestLocalContext(0, trainingUUID);
+
+ Map<GroupTrainerCacheKey<Double>, Integer> m = new HashMap<>();
+ m.put(new GroupTrainerCacheKey<>(0L, 1.0, trainingUUID), 1);
+ m.put(new GroupTrainerCacheKey<>(1L, 2.0, trainingUUID), 2);
+ m.put(new GroupTrainerCacheKey<>(2L, 3.0, trainingUUID), 3);
+ m.put(new GroupTrainerCacheKey<>(3L, 4.0, trainingUUID), 4);
+
+ cache.putAll(m);
+
+ // Build a fresh stream on every supplier call: a stream instance can be consumed only once.
+ IgniteBiFunction<Integer, TestLocalContext, IgniteSupplier<Stream<GroupTrainerCacheKey<Double>>>> function =
+ (o, l) -> () -> m.keySet().stream();
+ IgniteBinaryOperator<Integer> max = Integer::max;
+
+ Integer res = chain.
+ thenDistributedForEntries((integer, context) -> () -> null, this::readAndIncrement, function, max, Integer.MIN_VALUE).
+ process(init, new GroupTrainingContext<>(locCtx, cache, ignite));
+
+ int locMax = m.values().stream().max(Comparator.comparingInt(i -> i)).orElse(Integer.MIN_VALUE);
+
+ assertEquals((long)locMax, (long)res);
+
+ for (GroupTrainerCacheKey<Double> key : m.keySet())
+ m.compute(key, (k, v) -> v + 1);
+
+ assertMapEqualsCache(m, cache);
+ }
+
+ /** Worker: uses the entry value as the result and registers an update of the same key to value + 1. */
+ private ResultAndUpdates<Integer> readAndIncrement(EntryAndContext<Double, Integer, Void> ec) {
+ Integer val = ec.entry().getValue();
+
+ ResultAndUpdates<Integer> res = ResultAndUpdates.of(val);
+ res.update(TestGroupTrainingCache.getOrCreate(Ignition.localIgnite()), ec.entry().getKey(), val + 1);
+
+ return res;
+ }
+
+ /** Asserts that the cache has the same size as the map and holds an equal value for every key of the map. */
+ private <K, V> void assertMapEqualsCache(Map<K, V> m, IgniteCache<K, V> cache) {
+ assertEquals(m.size(), cache.size());
+
+ for (K k : m.keySet())
+ assertEquals(m.get(k), cache.get(k));
+ }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0c5ef1e/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/GroupTrainerTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/GroupTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/GroupTrainerTest.java
new file mode 100644
index 0000000..5bb9a47
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/GroupTrainerTest.java
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.trainers.group;
+
+import java.util.HashMap;
+import java.util.Map;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest;
+
+/**
+ * Test of {@link GroupTrainer}.
+ */
+public class GroupTrainerTest extends GridCommonAbstractTest {
+ /** Count of nodes. */
+ private static final int NODE_COUNT = 3;
+
+ /** Grid instance. */
+ private Ignite ignite;
+
+ /**
+ * Default constructor.
+ */
+ public GroupTrainerTest() {
+ super(false);
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override protected void beforeTest() throws Exception {
+ ignite = grid(NODE_COUNT);
+ TestGroupTrainingCache.getOrCreate(ignite).removeAll();
+ TestGroupTrainingSecondCache.getOrCreate(ignite).removeAll();
+ }
+
+ /** {@inheritDoc} */
+ @Override protected void beforeTestsStarted() throws Exception {
+ for (int idx = 1; idx <= NODE_COUNT; idx++)
+ startGrid(idx);
+ }
+
+ /** {@inheritDoc} */
+ @Override protected void afterTestsStopped() throws Exception {
+ stopAllGrids();
+ }
+
+ /** Trains {@link TestGroupTrainer} and compares the output of its model with a locally computed value. */
+ public void testGroupTrainer() {
+ int limit = 5;
+ int eachNumCnt = 3;
+ int iterCnt = 2;
+
+ SimpleGroupTrainerInput input = new SimpleGroupTrainerInput(limit, eachNumCnt, iterCnt);
+ ConstModel<Integer> mdl = new TestGroupTrainer(ignite).train(input);
+
+ int expected = computeLocally(limit, eachNumCnt, iterCnt);
+ assertEquals(expected, (int)mdl.apply(10));
+ }
+
+ /** Computes on a local map the value that the test expects the trained model to return. */
+ private int computeLocally(int limit, int eachNumCnt, int iterCnt) {
+ Map<GroupTrainerCacheKey<Double>, Integer> vals = new HashMap<>();
+
+ for (int num = 0; num < limit; num++) {
+ for (int copy = 0; copy < eachNumCnt; copy++)
+ vals.put(new GroupTrainerCacheKey<>(num, (double)copy, null), num);
+ }
+
+ // Every iteration replaces each stored value with its square.
+ for (int iter = 0; iter < iterCnt; iter++)
+ vals.replaceAll((key, val) -> val * val);
+
+ return vals.values().stream().mapToInt(Integer::intValue).filter(v -> v % 2 == 0).sum();
+ }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0c5ef1e/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/SimpleGroupTrainerInput.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/SimpleGroupTrainerInput.java b/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/SimpleGroupTrainerInput.java
new file mode 100644
index 0000000..efca26a
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/SimpleGroupTrainerInput.java
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.trainers.group;
+
+import java.util.UUID;
+import java.util.stream.IntStream;
+import java.util.stream.Stream;
+import org.apache.ignite.ml.math.functions.IgniteSupplier;
+
+/** {@link GroupTrainerInput} implementation for tests: an immutable holder of three integer parameters. */
+class SimpleGroupTrainerInput implements GroupTrainerInput<Double> {
+ /** Exclusive upper bound of the indexes for which initial keys are generated. */
+ private final int limit;
+ /** Value returned by {@link #eachNumberCount()}. */
+ private final int eachNumCnt;
+ /** Value returned by {@link #iterCnt()}. */
+ private final int iterCnt;
+
+ /** Stores the given parameters without validation. */
+ SimpleGroupTrainerInput(int limit, int eachNumCnt, int iterCnt) {
+ this.limit = limit;
+ this.eachNumCnt = eachNumCnt;
+ this.iterCnt = iterCnt;
+ }
+
+ /** {@inheritDoc} */
+ @Override public IgniteSupplier<Stream<GroupTrainerCacheKey<Double>>> initialKeys(UUID trainingUUID) {
+ return () -> IntStream.range(0, limit).mapToObj(i -> new GroupTrainerCacheKey<>(i, 0.0, trainingUUID));
+ }
+
+ /** Returns the exclusive upper bound used by {@link #initialKeys(UUID)}. */
+ int limit() {
+ return limit;
+ }
+
+ /** Returns the iteration count passed to the constructor. */
+ int iterCnt() {
+ return iterCnt;
+ }
+
+ /** Returns the per-number count passed to the constructor. */
+ int eachNumberCount() {
+ return eachNumCnt;
+ }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0c5ef1e/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/TestGroupTrainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/TestGroupTrainer.java b/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/TestGroupTrainer.java
new file mode 100644
index 0000000..75be373
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/TestGroupTrainer.java
@@ -0,0 +1,144 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.trainers.group;
+
+import java.util.UUID;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.Ignition;
+import org.apache.ignite.ml.math.functions.IgniteBinaryOperator;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
+import org.apache.ignite.ml.math.functions.IgniteSupplier;
+import org.apache.ignite.ml.trainers.group.chain.Chains;
+import org.apache.ignite.ml.trainers.group.chain.ComputationsChain;
+import org.apache.ignite.ml.trainers.group.chain.EntryAndContext;
+
+/**
+ * Test group trainer.
+ */
+class TestGroupTrainer extends GroupTrainer<TestGroupTrainerLocalContext, Double, Integer, Integer, Integer, Double,
+ ConstModel<Integer>, SimpleGroupTrainerInput, Void> {
+ /**
+ * Construct instance of this class with given parameters.
+ *
+ * @param ignite Ignite instance.
+ */
+ TestGroupTrainer(Ignite ignite) {
+ super(TestGroupTrainingCache.getOrCreate(ignite), ignite);
+ }
+
+ /** {@inheritDoc} */
+ @Override protected TestGroupTrainerLocalContext initialLocalContext(SimpleGroupTrainerInput data,
+ UUID trainingUUID) {
+ return new TestGroupTrainerLocalContext(data.iterCnt(), data.eachNumberCount(), data.limit(), trainingUUID);
+ }
+
+ /** {@inheritDoc} */
+ @Override protected IgniteFunction<GroupTrainerCacheKey<Double>, ResultAndUpdates<Integer>> distributedInitializer(
+ SimpleGroupTrainerInput data) {
+ return key -> {
+ long i = key.nodeLocalEntityIndex();
+ UUID trainingUUID = key.trainingUUID();
+ IgniteCache<GroupTrainerCacheKey<Double>, Integer> cache = TestGroupTrainingCache.getOrCreate(Ignition.localIgnite());
+
+ long sum = i * data.eachNumberCount();
+
+ ResultAndUpdates<Integer> res = ResultAndUpdates.of((int)sum);
+
+ for (int j = 0; j < data.eachNumberCount(); j++)
+ res.update(cache, new GroupTrainerCacheKey<>(i, (double)j, trainingUUID), (int)i);
+
+ return res;
+ };
+ }
+
+ /** {@inheritDoc} */
+ @Override protected IgniteBinaryOperator<Integer> reduceDistributedInitData() {
+ return (a, b) -> a + b;
+ }
+
+ /** {@inheritDoc} */
+ @Override protected Double locallyProcessInitData(Integer data, TestGroupTrainerLocalContext locCtx) {
+ return data.doubleValue();
+ }
+
+ /** {@inheritDoc} */
+ @Override protected ComputationsChain<TestGroupTrainerLocalContext,
+ Double, Integer, Double, Double> trainingLoopStep() {
+ ComputationsChain<TestGroupTrainerLocalContext, Double, Integer, Double, Double> chain = Chains.
+ create(new TestTrainingLoopStep());
+ return chain.
+ thenLocally((aDouble, context) -> {
+ context.incCnt();
+ return aDouble;
+ });
+ }
+
+ /** {@inheritDoc} */
+ @Override protected boolean shouldContinue(Double data, TestGroupTrainerLocalContext locCtx) {
+ return locCtx.cnt() < locCtx.maxCnt();
+ }
+
+ /** {@inheritDoc} */
+ @Override protected IgniteSupplier<Void> extractContextForFinalResultCreation(Double data,
+ TestGroupTrainerLocalContext locCtx) {
+ // No context is needed.
+ return () -> null;
+ }
+
+ /** {@inheritDoc} */
+ @Override protected IgniteSupplier<Stream<GroupTrainerCacheKey<Double>>> finalResultKeys(Double data,
+ TestGroupTrainerLocalContext locCtx) {
+ int limit = locCtx.limit();
+ int cnt = locCtx.eachNumberCnt();
+ UUID uuid = locCtx.trainingUUID();
+
+ return () -> TestGroupTrainingCache.allKeys(limit, cnt, uuid);
+ }
+
+ /** {@inheritDoc} */
+ @Override protected IgniteFunction<EntryAndContext<Double, Integer, Void>, ResultAndUpdates<Integer>> finalResultsExtractor() {
+ return entryAndCtx -> {
+ Integer val = entryAndCtx.entry().getValue();
+ return ResultAndUpdates.of(val % 2 == 0 ? val : 0);
+ };
+ }
+
+ /** {@inheritDoc} */
+ @Override protected Integer defaultFinalResult() {
+ return 0;
+ }
+
+ /** {@inheritDoc} */
+ @Override protected IgniteBinaryOperator<Integer> finalResultsReducer() {
+ return (a, b) -> a + b;
+ }
+
+ /** {@inheritDoc} */
+ @Override protected ConstModel<Integer> mapFinalResult(Integer res, TestGroupTrainerLocalContext locCtx) {
+ return new ConstModel<>(res);
+ }
+
+ /** {@inheritDoc} */
+ @Override protected void cleanup(TestGroupTrainerLocalContext locCtx) {
+ Stream<GroupTrainerCacheKey<Double>> toRemote = TestGroupTrainingCache.allKeys(locCtx.limit(), locCtx.eachNumberCnt(), locCtx.trainingUUID());
+ TestGroupTrainingCache.getOrCreate(ignite).removeAll(toRemote.collect(Collectors.toSet()));
+ }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0c5ef1e/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/TestGroupTrainerLocalContext.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/TestGroupTrainerLocalContext.java b/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/TestGroupTrainerLocalContext.java
new file mode 100644
index 0000000..b8db56e
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/TestGroupTrainerLocalContext.java
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.trainers.group;
+
+import java.util.UUID;
+import org.apache.ignite.ml.trainers.group.chain.HasTrainingUUID;
+
+/** Local context of the test group trainer: a mutable counter plus immutable training parameters. */
+class TestGroupTrainerLocalContext implements HasTrainingUUID {
+ /** Counter incremented by {@link #incCnt()}. */
+ private int cnt;
+ /** Value returned by {@link #maxCnt()}. */
+ private final int maxCnt;
+ /** Value returned by {@link #eachNumberCnt()}. */
+ private final int eachNumCnt;
+ /** Value returned by {@link #limit()}. */
+ private final int limit;
+ /** UUID of the training this context belongs to. */
+ private final UUID trainingUUID;
+
+ /** Stores the given parameters; the counter starts at zero. */
+ TestGroupTrainerLocalContext(int maxCnt, int eachNumCnt, int limit, UUID trainingUUID) {
+ // The counter is not assigned here: an int field is initialized to 0 by default.
+ this.trainingUUID = trainingUUID;
+ this.limit = limit;
+ this.eachNumCnt = eachNumCnt;
+ this.maxCnt = maxCnt;
+ }
+
+ /** Returns the current counter value. */
+ int cnt() {
+ return cnt;
+ }
+
+ /** Increments the counter by one. */
+ void incCnt() {
+ cnt++;
+ }
+
+ /** Returns the maximal counter value passed to the constructor. */
+ int maxCnt() {
+ return maxCnt;
+ }
+
+ /** Returns the per-number count passed to the constructor. */
+ int eachNumberCnt() {
+ return eachNumCnt;
+ }
+
+ /** Returns the limit passed to the constructor. */
+ int limit() {
+ return limit;
+ }
+
+ /** {@inheritDoc} */
+ @Override public UUID trainingUUID() {
+ return trainingUUID;
+ }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0c5ef1e/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/TestGroupTrainingCache.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/TestGroupTrainingCache.java b/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/TestGroupTrainingCache.java
new file mode 100644
index 0000000..e7826ae
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/TestGroupTrainingCache.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.trainers.group;
+
+import java.util.Arrays;
+import java.util.UUID;
+import java.util.stream.Stream;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.cache.CacheAtomicityMode;
+import org.apache.ignite.cache.CacheMode;
+import org.apache.ignite.cache.CacheWriteSynchronizationMode;
+import org.apache.ignite.configuration.CacheConfiguration;
+
+/** */
+class TestGroupTrainingCache {
+ /** */
+ private static final String CACHE_NAME = "TEST_GROUP_TRAINING_CACHE";
+
+ /** */
+ static IgniteCache<GroupTrainerCacheKey<Double>, Integer> getOrCreate(Ignite ignite) {
+ CacheConfiguration<GroupTrainerCacheKey<Double>, Integer> cfg = new CacheConfiguration<>();
+
+ // Write to primary.
+ cfg.setWriteSynchronizationMode(CacheWriteSynchronizationMode.PRIMARY_SYNC);
+
+ // Atomic transactions only.
+ cfg.setAtomicityMode(CacheAtomicityMode.ATOMIC);
+
+ // No copying of values.
+ cfg.setCopyOnRead(false);
+
+ // Cache is partitioned.
+ cfg.setCacheMode(CacheMode.PARTITIONED);
+
+ cfg.setBackups(0);
+
+ cfg.setOnheapCacheEnabled(true);
+
+ cfg.setName(CACHE_NAME);
+
+ return ignite.getOrCreateCache(cfg);
+ }
+
+ /** */
+ static Stream<GroupTrainerCacheKey<Double>> allKeys(int limit, int eachNumCnt, UUID trainingUUID) {
+ @SuppressWarnings("unchecked")
+ GroupTrainerCacheKey<Double>[] a = new GroupTrainerCacheKey[limit * eachNumCnt];
+
+ for (int num = 0; num < limit; num++)
+ for (int i = 0; i < eachNumCnt; i++)
+ a[num * eachNumCnt + i] = new GroupTrainerCacheKey<>(num, (double)i, trainingUUID);
+
+ return Arrays.stream(a);
+ }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0c5ef1e/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/TestGroupTrainingSecondCache.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/TestGroupTrainingSecondCache.java b/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/TestGroupTrainingSecondCache.java
new file mode 100644
index 0000000..fea931e
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/TestGroupTrainingSecondCache.java
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.trainers.group;
+
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.cache.CacheAtomicityMode;
+import org.apache.ignite.cache.CacheMode;
+import org.apache.ignite.cache.CacheWriteSynchronizationMode;
+import org.apache.ignite.configuration.CacheConfiguration;
+
+/** Factory of the second cache used by group training tests; its keys carry {@link Character} data. */
+class TestGroupTrainingSecondCache {
+ /** Name of the cache. */
+ private static final String CACHE_NAME = "TEST_GROUP_TRAINING_SECOND_CACHE";
+
+ /** Gets the second test cache from the given Ignite instance, creating it from the configuration below if needed. */
+ static IgniteCache<GroupTrainerCacheKey<Character>, Integer> getOrCreate(Ignite ignite) {
+ CacheConfiguration<GroupTrainerCacheKey<Character>, Integer> cfg = new CacheConfiguration<>();
+
+ cfg.setName(CACHE_NAME);
+
+ // Partitioned cache without backups.
+ cfg.setCacheMode(CacheMode.PARTITIONED);
+ cfg.setBackups(0);
+
+ // Atomic atomicity mode.
+ cfg.setAtomicityMode(CacheAtomicityMode.ATOMIC);
+
+ // Write to primary.
+ cfg.setWriteSynchronizationMode(CacheWriteSynchronizationMode.PRIMARY_SYNC);
+
+ // No copying of values.
+ cfg.setCopyOnRead(false);
+
+ // Enable the on-heap cache.
+ cfg.setOnheapCacheEnabled(true);
+
+ return ignite.getOrCreateCache(cfg);
+ }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0c5ef1e/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/TestLocalContext.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/TestLocalContext.java b/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/TestLocalContext.java
new file mode 100644
index 0000000..8348334
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/TestLocalContext.java
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.trainers.group;
+
+import java.util.UUID;
+import org.apache.ignite.ml.trainers.group.chain.HasTrainingUUID;
+
+/** Local context for chain tests: a mutable integer together with the UUID of the training. */
+class TestLocalContext implements HasTrainingUUID {
+ /** Mutable payload, read via {@link #data()} and changed via {@link #setData(int)}. */
+ private int data;
+ /** UUID of the training this context belongs to. */
+ private final UUID trainingUUID;
+
+ /** Creates a context with the given initial data and training UUID. */
+ TestLocalContext(int data, UUID trainingUUID) {
+ this.trainingUUID = trainingUUID;
+ this.data = data;
+ }
+
+ /** Returns the current data value. */
+ int data() {
+ return data;
+ }
+
+ /** Replaces the data value. */
+ void setData(int data) {
+ this.data = data;
+ }
+
+ /** {@inheritDoc} */
+ @Override public UUID trainingUUID() {
+ return trainingUUID;
+ }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0c5ef1e/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/TestTrainingLoopStep.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/TestTrainingLoopStep.java b/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/TestTrainingLoopStep.java
new file mode 100644
index 0000000..21f328a
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/TestTrainingLoopStep.java
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.trainers.group;
+
+import java.util.UUID;
+import java.util.stream.Stream;
+import org.apache.ignite.Ignition;
+import org.apache.ignite.ml.math.functions.IgniteBinaryOperator;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
+import org.apache.ignite.ml.math.functions.IgniteSupplier;
+import org.apache.ignite.ml.trainers.group.chain.DistributedEntryProcessingStep;
+import org.apache.ignite.ml.trainers.group.chain.EntryAndContext;
+
+/** Test loop step: squares each entry value, registers an update of that key in TestGroupTrainingCache with the square, and sums the squares. */
+class TestTrainingLoopStep implements DistributedEntryProcessingStep<TestGroupTrainerLocalContext, Double, Integer, Void, Double, Double> {
+ /** {@inheritDoc} */
+ @Override public IgniteSupplier<Void> remoteContextSupplier(Double input, TestGroupTrainerLocalContext locCtx) {
+ // No context is needed.
+ return () -> null;
+ }
+
+ /** {@inheritDoc} */
+ @Override public IgniteFunction<EntryAndContext<Double, Integer, Void>, ResultAndUpdates<Double>> worker() {
+ return entryAndContext -> {
+ Integer oldVal = entryAndContext.entry().getValue();
+ double v = oldVal * oldVal; // Product is computed in int arithmetic, then widened to double.
+ ResultAndUpdates<Double> res = ResultAndUpdates.of(v);
+ res.update(TestGroupTrainingCache.getOrCreate(Ignition.localIgnite()), entryAndContext.entry().getKey(), (int)v);
+ return res;
+ };
+ }
+
+ /** {@inheritDoc} */
+ @Override public IgniteSupplier<Stream<GroupTrainerCacheKey<Double>>> keys(Double input,
+ TestGroupTrainerLocalContext locCtx) {
+ // Copying here because otherwise locCtx will be serialized with supplier returned in result.
+ int limit = locCtx.limit();
+ int cnt = locCtx.eachNumberCnt();
+ UUID uuid = locCtx.trainingUUID();
+
+ return () -> TestGroupTrainingCache.allKeys(limit, cnt, uuid);
+ }
+
+ /** {@inheritDoc} */
+ @Override public Double identity() {
+ return 0.0;
+ }
+
+ /** {@inheritDoc} */
+ @Override public IgniteBinaryOperator<Double> reducer() {
+ return (a, b) -> a + b;
+ }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0c5ef1e/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/TrainersGroupTestSuite.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/TrainersGroupTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/TrainersGroupTestSuite.java
new file mode 100644
index 0000000..0ec5afb
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/TrainersGroupTestSuite.java
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.trainers.group;
+
+import org.junit.runner.RunWith;
+import org.junit.runners.Suite;
+
+/**
+ * Test suite for group trainer tests.
+ */
+@RunWith(Suite.class)
+@Suite.SuiteClasses({
+ DistributedWorkersChainTest.class,
+ GroupTrainerTest.class
+})
+public class TrainersGroupTestSuite {
+}
[2/2] ignite git commit: IGNITE-6783: Create common mechanism for
group training.
Posted by ch...@apache.org.
IGNITE-6783: Create common mechanism for group training.
this closes #3297
Project: http://git-wip-us.apache.org/repos/asf/ignite/repo
Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/b0c5ef1e
Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/b0c5ef1e
Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/b0c5ef1e
Branch: refs/heads/master
Commit: b0c5ef1ea273530f48a8014c004ed073be9c6d6e
Parents: 442716e
Author: Yury Babak <yb...@gridgain.com>
Authored: Sun Dec 31 00:59:21 2017 +0300
Committer: Yury Babak <yb...@gridgain.com>
Committed: Sun Dec 31 00:59:21 2017 +0300
----------------------------------------------------------------------
.../ignite/ml/math/functions/Functions.java | 66 ++++-
.../functions/IgniteCurriedTriFunction.java | 28 +++
.../trainers/group/BaseLocalProcessorJob.java | 154 ++++++++++++
.../ignite/ml/trainers/group/ConstModel.java | 46 ++++
.../ignite/ml/trainers/group/GroupTrainer.java | 206 +++++++++++++++
.../group/GroupTrainerBaseProcessorTask.java | 151 +++++++++++
.../ml/trainers/group/GroupTrainerCacheKey.java | 125 ++++++++++
.../group/GroupTrainerEntriesProcessorTask.java | 64 +++++
.../ml/trainers/group/GroupTrainerInput.java | 37 +++
.../group/GroupTrainerKeysProcessorTask.java | 63 +++++
.../ml/trainers/group/GroupTrainingContext.java | 98 ++++++++
.../group/LocalEntriesProcessorJob.java | 86 +++++++
.../trainers/group/LocalKeysProcessorJob.java | 79 ++++++
.../ignite/ml/trainers/group/Metaoptimizer.java | 100 ++++++++
.../group/MetaoptimizerDistributedStep.java | 94 +++++++
.../group/MetaoptimizerGroupTrainer.java | 129 ++++++++++
.../ml/trainers/group/ResultAndUpdates.java | 173 +++++++++++++
.../ignite/ml/trainers/group/chain/Chains.java | 56 +++++
.../trainers/group/chain/ComputationsChain.java | 248 +++++++++++++++++++
.../chain/DistributedEntryProcessingStep.java | 34 +++
.../chain/DistributedKeyProcessingStep.java | 33 +++
.../trainers/group/chain/DistributedStep.java | 77 ++++++
.../trainers/group/chain/EntryAndContext.java | 70 ++++++
.../trainers/group/chain/HasTrainingUUID.java | 32 +++
.../ml/trainers/group/chain/KeyAndContext.java | 67 +++++
.../ml/trainers/group/chain/package-info.java | 22 ++
.../ignite/ml/trainers/group/package-info.java | 22 ++
.../apache/ignite/ml/trainers/package-info.java | 22 ++
.../org/apache/ignite/ml/util/MnistUtils.java | 6 +-
.../org/apache/ignite/ml/IgniteMLTestSuite.java | 4 +-
.../group/DistributedWorkersChainTest.java | 188 ++++++++++++++
.../ml/trainers/group/GroupTrainerTest.java | 90 +++++++
.../trainers/group/SimpleGroupTrainerInput.java | 60 +++++
.../ml/trainers/group/TestGroupTrainer.java | 144 +++++++++++
.../group/TestGroupTrainerLocalContext.java | 74 ++++++
.../trainers/group/TestGroupTrainingCache.java | 71 ++++++
.../group/TestGroupTrainingSecondCache.java | 56 +++++
.../ml/trainers/group/TestLocalContext.java | 50 ++++
.../ml/trainers/group/TestTrainingLoopStep.java | 68 +++++
.../trainers/group/TrainersGroupTestSuite.java | 32 +++
40 files changed, 3217 insertions(+), 8 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0c5ef1e/modules/ml/src/main/java/org/apache/ignite/ml/math/functions/Functions.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/functions/Functions.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/functions/Functions.java
index fa7ee76..f723166 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/math/functions/Functions.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/functions/Functions.java
@@ -100,7 +100,7 @@ public final class Functions {
* @return Minimum between {@code a} and {@code b} in terms of comparator {@code f}.
*/
public static <T> T MIN_GENERIC(T a, T b, Comparator<T> f) {
- return f.compare(a, b) > 0 ? a : b;
+ return f.compare(a, b) < 0 ? a : b;
}
/** Function that returns {@code min(abs(a), abs(b))}. */
@@ -215,15 +215,73 @@ public final class Functions {
}
/**
- * Curry bifunction.
+ * Curry bi-function.
*
- * @param f Bifunction to curry.
+ * @param f Bi-function to curry.
* @param <A> Type of first argument of {@code f}.
* @param <B> Type of second argument of {@code f}.
* @param <C> Return type of {@code f}.
- * @return Curried bifunction.
+ * @return Curried bi-function.
*/
public static <A, B, C> IgniteCurriedBiFunction<A, B, C> curry(BiFunction<A, B, C> f) {
return a -> b -> f.apply(a, b);
}
+
+ /**
+ * Transform bi-function of the form (a, b) -> c into a function of form a -> (b -> c).
+ *
+ * @param f Function to be curried.
+ * @param <A> Type of first argument of function to be transformed.
+ * @param <B> Type of second argument of function to be transformed.
+ * @param <C> Return type of function to be transformed.
+ * @return Curried bi-function.
+ */
+ public static <A, B, C> IgniteCurriedBiFunction<A, B, C> curry(IgniteBiFunction<A, B, C> f) {
+ return a -> b -> f.apply(a, b);
+ }
+
+ /**
+ * Transform tri-function of the form (a, b, c) -> d into a function of form a -> (b -> (c -> d)).
+ *
+ * @param f Function to be curried.
+ * @param <A> Type of first argument of function to be transformed.
+ * @param <B> Type of second argument of function to be transformed.
+ * @param <C> Type of third argument of function to be transformed.
+ * @param <D> Type of output of function to be transformed.
+ * @return Curried tri-function.
+ */
+ public static <A, B, C, D> IgniteCurriedTriFunction<A, B, C, D> curry(IgniteTriFunction<A, B, C, D> f) {
+ return a -> b -> c -> f.apply(a, b, c);
+ }
+
+ /**
+ * Transform function of form a -> b into a -> (() -> b).
+ *
+ * @param f Function to be transformed.
+ * @param <A> Type of input of function to be transformed.
+ * @param <B> Type of output of function to be transformed.
+ * @return Transformed function.
+ */
+ public static <A, B> IgniteFunction<A, IgniteSupplier<B>> outputSupplier(IgniteFunction<A, B> f) {
+ return a -> {
+ B res = f.apply(a);
+ return () -> res;
+ };
+ }
+
+ /**
+ * Transform function of form (a, b) -> c into (a, b) -> (() -> c).
+ *
+ * @param f Function to be transformed.
+ * @param <A> Type of first argument of function to be transformed.
+ * @param <B> Type of second argument of function to be transformed.
+ * @param <C> Type of output of function to be transformed.
+ * @return Transformed function.
+ */
+ public static <A, B, C> IgniteBiFunction<A, B, IgniteSupplier<C>> outputSupplier(IgniteBiFunction<A, B, C> f) {
+ return (a, b) -> {
+ C res = f.apply(a, b);
+ return () -> res;
+ };
+ }
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0c5ef1e/modules/ml/src/main/java/org/apache/ignite/ml/math/functions/IgniteCurriedTriFunction.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/functions/IgniteCurriedTriFunction.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/functions/IgniteCurriedTriFunction.java
new file mode 100644
index 0000000..cddffcd
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/functions/IgniteCurriedTriFunction.java
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.math.functions;
+
+import java.io.Serializable;
+
+/**
+ * Serializable curried tri-function.
+ *
+ * @see IgniteCurriedBiFunction
+ */
+public interface IgniteCurriedTriFunction<A, B, C, D> extends IgniteFunction<A, IgniteCurriedBiFunction<B, C, D>>, Serializable {
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0c5ef1e/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/BaseLocalProcessorJob.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/BaseLocalProcessorJob.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/BaseLocalProcessorJob.java
new file mode 100644
index 0000000..d12252d
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/BaseLocalProcessorJob.java
@@ -0,0 +1,154 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.trainers.group;
+
+import java.io.Serializable;
+import java.util.List;
+import java.util.UUID;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.IgniteException;
+import org.apache.ignite.Ignition;
+import org.apache.ignite.cache.affinity.Affinity;
+import org.apache.ignite.compute.ComputeJob;
+import org.apache.ignite.ml.math.functions.IgniteBinaryOperator;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
+import org.apache.ignite.ml.math.functions.IgniteSupplier;
+
+/**
+ * Base job for group training.
+ * Its purpose is to apply worker to each element (cache key or cache entry) of given cache specified
+ * by keySupplier. Worker produces {@link ResultAndUpdates} object which contains 'side effects' which are updates
+ * needed to apply to caches and computation result.
+ * After we get all {@link ResultAndUpdates} we merge all 'update' parts of them for each node
+ * and apply them on corresponding node, also we reduce all 'result' by some given reducer.
+ *
+ * @param <K> Type of keys of cache used for group trainer.
+ * @param <V> Type of values of cache used for group trainer.
+ * @param <T> Type of elements to which workers are applied.
+ * @param <R> Type of result of worker.
+ */
+public abstract class BaseLocalProcessorJob<K, V, T, R extends Serializable> implements ComputeJob {
+ /**
+ * UUID of group training.
+ */
+ protected UUID trainingUUID;
+
+ /**
+ * Worker.
+ */
+ protected IgniteFunction<T, ResultAndUpdates<R>> worker;
+
+ /**
+ * Supplier of keys determining elements to which worker should be applied.
+ */
+ protected IgniteSupplier<Stream<GroupTrainerCacheKey<K>>> keySupplier;
+
+ /**
+ * Operator used to reduce results from worker.
+ */
+ protected IgniteBinaryOperator<R> reducer;
+
+ /**
+ * Identity for reducer.
+ */
+ protected final R identity;
+
+ /**
+ * Name of cache used for training.
+ */
+ protected String cacheName;
+
+ /**
+ * Construct instance of this class with given arguments.
+ *
+ * @param worker Worker.
+ * @param keySupplier Supplier of keys.
+ * @param reducer Reducer.
+ * @param identity Identity for reducer.
+ * @param trainingUUID UUID of training.
+ * @param cacheName Name of cache used for training.
+ */
+ public BaseLocalProcessorJob(
+ IgniteFunction<T, ResultAndUpdates<R>> worker,
+ IgniteSupplier<Stream<GroupTrainerCacheKey<K>>> keySupplier,
+ IgniteBinaryOperator<R> reducer,
+ R identity,
+ UUID trainingUUID, String cacheName) {
+ this.worker = worker;
+ this.keySupplier = keySupplier;
+ this.identity = identity;
+ this.reducer = reducer;
+ this.trainingUUID = trainingUUID;
+ this.cacheName = cacheName;
+ }
+
+ /** {@inheritDoc} */
+ @Override public void cancel() {
+ // NO-OP.
+ }
+
+ /** {@inheritDoc} */
+ @Override public R execute() throws IgniteException {
+ List<ResultAndUpdates<R>> resultsAndUpdates = toProcess().
+ map(worker).
+ collect(Collectors.toList());
+
+ ResultAndUpdates<R> totalRes = ResultAndUpdates.sum(reducer, identity, resultsAndUpdates);
+
+ totalRes.applyUpdates(ignite());
+
+ return totalRes.result();
+ }
+
+ /**
+ * Get stream of elements to process.
+ *
+ * @return Stream of elements to process.
+ */
+ protected abstract Stream<T> toProcess();
+
+ /**
+ * Ignite instance.
+ *
+ * @return Ignite instance.
+ */
+ protected static Ignite ignite() {
+ return Ignition.localIgnite();
+ }
+
+ /**
+ * Get cache used for training.
+ *
+ * @return Cache used for training.
+ */
+ protected IgniteCache<GroupTrainerCacheKey<K>, V> cache() {
+ return ignite().getOrCreateCache(cacheName);
+ }
+
+ /**
+ * Get affinity function for cache used in group training.
+ *
+ * @return Affinity function for cache used in group training.
+ */
+ protected Affinity<GroupTrainerCacheKey> affinity() {
+ return ignite().affinity(cacheName);
+ }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0c5ef1e/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/ConstModel.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/ConstModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/ConstModel.java
new file mode 100644
index 0000000..75f8179
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/ConstModel.java
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.trainers.group;
+
+import org.apache.ignite.ml.Model;
+
+/**
+ * Model which outputs given constant.
+ *
+ * @param <T> Type of constant.
+ */
+public class ConstModel<T> implements Model<T, T> {
+ /**
+ * Constant to be returned by this model.
+ */
+ private T c;
+
+ /**
+ * Create instance of this class specified by input parameters.
+ *
+ * @param c Constant to be returned by this model.
+ */
+ public ConstModel(T c) {
+ this.c = c;
+ }
+
+ /** {@inheritDoc} */
+ @Override public T apply(T val) {
+ return c;
+ }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0c5ef1e/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/GroupTrainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/GroupTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/GroupTrainer.java
new file mode 100644
index 0000000..ca8dcf0
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/GroupTrainer.java
@@ -0,0 +1,206 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.trainers.group;
+
+import java.io.Serializable;
+import java.util.UUID;
+import java.util.stream.Stream;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.ml.Model;
+import org.apache.ignite.ml.Trainer;
+import org.apache.ignite.ml.math.functions.IgniteBinaryOperator;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
+import org.apache.ignite.ml.math.functions.IgniteSupplier;
+import org.apache.ignite.ml.trainers.group.chain.ComputationsChain;
+import org.apache.ignite.ml.trainers.group.chain.EntryAndContext;
+import org.apache.ignite.ml.trainers.group.chain.HasTrainingUUID;
+
+/**
+ * Class encapsulating synchronous distributed group training.
+ * Training is performed by following scheme:
+ * 1. For specified set of keys distributed initialization is done. For each key some initialization result is returned.
+ * 2. All initialization results are processed locally and reduced into some object of type I.
+ * 3. While 'shouldContinue' condition is true, training loop step is executed.
+ * 4. After loop is finished, data from each key from final key set is collected.
+ * 5. Data collected on previous step is transformed into a model which is returned as final result.
+ * Note that all methods returning functions, suppliers etc should return values with minimal dependencies because they are serialized
+ * with all dependent objects.
+ *
+ * @param <LC> Type of local context of the training.
+ * @param <K> Type of cache keys on which the training is done.
+ * @param <V> Type of cache values on which the training is done.
+ * @param <IN> Type of data returned after initializing of distributed context.
+ * @param <R> Type of result returned after training from each node.
+ * @param <I> Type of data which is fed into each training loop step and returned from it.
+ * @param <M> Type of model returned after training.
+ * @param <T> Type of input to this trainer.
+ * @param <G> Type of distributed context which is needed for forming final result which is sent from each node to trainer for final model creation.
+ */
+abstract class GroupTrainer<LC extends HasTrainingUUID, K, V, IN extends Serializable, R extends Serializable, I extends Serializable, M extends Model, T extends GroupTrainerInput<K>, G> implements Trainer<M, T> {
+ /**
+ * Cache on which training is performed. For example it can be cache of neural networks.
+ */
+ IgniteCache<GroupTrainerCacheKey<K>, V> cache;
+
+ /**
+ * Ignite instance.
+ */
+ Ignite ignite;
+
+ /**
+ * Construct an instance of this class.
+ *
+ * @param cache Cache on which training is performed.
+ * @param ignite Ignite instance.
+ */
+ GroupTrainer(
+ IgniteCache<GroupTrainerCacheKey<K>, V> cache,
+ Ignite ignite) {
+ this.cache = cache;
+ this.ignite = ignite;
+ }
+
+ /** {@inheritDoc} */
+ @Override public M train(T data) {
+ UUID trainingUUID = UUID.randomUUID();
+ LC locCtx = initialLocalContext(data, trainingUUID);
+
+ GroupTrainingContext<K, V, LC> ctx = new GroupTrainingContext<>(locCtx, cache, ignite);
+ ComputationsChain<LC, K, V, T, T> chain = (i, c) -> i;
+ IgniteFunction<GroupTrainerCacheKey<K>, ResultAndUpdates<IN>> distributedInitializer = distributedInitializer(data);
+
+ M res = chain.
+ thenDistributedForKeys(distributedInitializer, (t, lc) -> data.initialKeys(trainingUUID), reduceDistributedInitData()).
+ thenLocally(this::locallyProcessInitData).
+ thenWhile(this::shouldContinue, trainingLoopStep()).
+ thenDistributedForEntries(this::extractContextForFinalResultCreation, finalResultsExtractor(), this::finalResultKeys, finalResultsReducer(), defaultFinalResult()).
+ thenLocally(this::mapFinalResult).
+ process(data, ctx);
+
+ cleanup(locCtx);
+
+ return res;
+ }
+
+ /**
+ * Create initial local context from data given as input to trainer.
+ *
+ * @param data Data given as input to this trainer.
+ * @param trainingUUID UUID of this training.
+ * @return Initial local context.
+ */
+ protected abstract LC initialLocalContext(T data, UUID trainingUUID);
+
+ /**
+ * Get function for initialization for each of keys specified in initial key set.
+ *
+ * @param data Data given to this trainer as input.
+ * @return Function for initialization for each of keys specified in initial key set.
+ */
+ protected abstract IgniteFunction<GroupTrainerCacheKey<K>, ResultAndUpdates<IN>> distributedInitializer(T data);
+
+ /**
+ * Get reducer to reduce data collected from initialization of each key specified in initial key set.
+ *
+ * @return Reducer to reduce data collected from initialization of each key specified in initial key set.
+ */
+ protected abstract IgniteBinaryOperator<IN> reduceDistributedInitData();
+
+ /**
+ * Transform data from initialization step into data which is fed as input to first step of training loop.
+ *
+ * @param data Data from initialization step.
+ * @param locCtx Local context.
+ * @return Data which is fed as input to first step of training loop.
+ */
+ protected abstract I locallyProcessInitData(IN data, LC locCtx);
+
+ /**
+ * Training loop step.
+ *
+ * @return Computations chain performing one training loop step.
+ */
+ protected abstract ComputationsChain<LC, K, V, I, I> trainingLoopStep();
+
+ /**
+ * Condition specifying if training loop should continue.
+ *
+ * @param data First time, data returned by locallyProcessInitData then data returned by last step of loop.
+ * @param locCtx Local context.
+ * @return Boolean value indicating if training loop should continue.
+ */
+ protected abstract boolean shouldContinue(I data, LC locCtx);
+
+ /**
+ * Extract context for final result creation. Each key from the final keys set will be processed with
+ * finalResultsExtractor. While entry data (i.e. key and value) for each key varies, some data can be common for all
+ * processed entries. This data is called context.
+ *
+ * @param data Data returned from last training loop step.
+ * @param locCtx Local context.
+ * @return Context.
+ */
+ protected abstract IgniteSupplier<G> extractContextForFinalResultCreation(I data, LC locCtx);
+
+ /**
+ * Keys for final result creation.
+ *
+ * @param data Data returned from the last training loop step.
+ * @param locCtx Local context.
+ * @return Stream of keys for final result creation.
+ */
+ protected abstract IgniteSupplier<Stream<GroupTrainerCacheKey<K>>> finalResultKeys(I data, LC locCtx);
+
+ /**
+ * Get function for extracting final result from each key specified in finalResultKeys.
+ *
+ * @return Function for extracting final result from each key specified in finalResultKeys.
+ */
+ protected abstract IgniteFunction<EntryAndContext<K, V, G>, ResultAndUpdates<R>> finalResultsExtractor();
+
+ /**
+ * Default final result. Should be identity for finalResultsReducer.
+ *
+ * @return Default final result.
+ */
+ protected abstract R defaultFinalResult();
+
+ /**
+ * Get function for reducing final results.
+ *
+ * @return Function for reducing final results.
+ */
+ protected abstract IgniteBinaryOperator<R> finalResultsReducer();
+
+ /**
+ * Map final result to model which is returned by trainer.
+ *
+ * @param res Final result.
+ * @param locCtx Local context.
+ * @return Model resulted from training.
+ */
+ protected abstract M mapFinalResult(R res, LC locCtx);
+
+ /**
+ * Performs cleanups of temporary objects created by this trainer.
+ *
+ * @param locCtx Local context.
+ */
+ protected abstract void cleanup(LC locCtx);
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0c5ef1e/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/GroupTrainerBaseProcessorTask.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/GroupTrainerBaseProcessorTask.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/GroupTrainerBaseProcessorTask.java
new file mode 100644
index 0000000..2b49ac9
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/GroupTrainerBaseProcessorTask.java
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.trainers.group;
+
+import java.io.Serializable;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.UUID;
+import java.util.stream.Stream;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteException;
+import org.apache.ignite.cache.affinity.Affinity;
+import org.apache.ignite.cluster.ClusterNode;
+import org.apache.ignite.compute.ComputeJob;
+import org.apache.ignite.compute.ComputeJobResult;
+import org.apache.ignite.compute.ComputeTaskAdapter;
+import org.apache.ignite.ml.math.functions.IgniteBinaryOperator;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
+import org.apache.ignite.ml.math.functions.IgniteSupplier;
+import org.jetbrains.annotations.Nullable;
+
+/**
+ * Base compute task for group trainer. It puts one job (see {@link #createJob()}) on every node of the subgrid
+ * and folds non-null job results with the reducer; the identity is returned when there is nothing to fold.
+ *
+ * @param <K> Type of cache keys of cache used for training.
+ * @param <V> Type of cache values of cache used for training.
+ * @param <C> Type of context (common part of data needed for computation).
+ * @param <T> Type of arguments of workers.
+ * @param <R> Type of computation result.
+ */
+public abstract class GroupTrainerBaseProcessorTask<K, V, C, T, R extends Serializable> extends ComputeTaskAdapter<Void, R> {
+    /** Context supplier. */
+    protected final IgniteSupplier<C> ctxSupplier;
+
+    /** UUID of training. */
+    protected final UUID trainingUUID;
+
+    /** Worker. */
+    protected IgniteFunction<T, ResultAndUpdates<R>> worker;
+
+    /** Reducer used for reducing of computations on specified keys. */
+    protected final IgniteBinaryOperator<R> reducer;
+
+    /** Identity for reducer. */
+    protected final R identity;
+
+    /** Name of cache on which training is done. */
+    protected final String cacheName;
+
+    /** Supplier of keys on which worker should be executed. */
+    protected final IgniteSupplier<Stream<GroupTrainerCacheKey<K>>> keysSupplier;
+
+    /** Ignite instance. */
+    protected final Ignite ignite;
+
+    /**
+     * Create a task from the given building blocks.
+     *
+     * @param trainingUUID UUID of training.
+     * @param ctxSupplier Supplier of context.
+     * @param worker Function calculated on each of specified keys.
+     * @param keysSupplier Supplier of keys on which training is done.
+     * @param reducer Reducer used for reducing results of computation performed on each of specified keys.
+     * @param identity Identity for reducer.
+     * @param cacheName Name of cache on which training is done.
+     * @param ignite Ignite instance.
+     */
+    public GroupTrainerBaseProcessorTask(
+        UUID trainingUUID,
+        IgniteSupplier<C> ctxSupplier,
+        IgniteFunction<T, ResultAndUpdates<R>> worker,
+        IgniteSupplier<Stream<GroupTrainerCacheKey<K>>> keysSupplier,
+        IgniteBinaryOperator<R> reducer,
+        R identity,
+        String cacheName,
+        Ignite ignite
+    ) {
+        this.ignite = ignite;
+        this.cacheName = cacheName;
+        this.trainingUUID = trainingUUID;
+        this.keysSupplier = keysSupplier;
+        this.ctxSupplier = ctxSupplier;
+        this.worker = worker;
+        this.reducer = reducer;
+        this.identity = identity;
+    }
+
+    /** {@inheritDoc} */
+    @Nullable @Override public Map<? extends ComputeJob, ClusterNode> map(List<ClusterNode> subgrid,
+        @Nullable Void arg) throws IgniteException {
+        Map<ComputeJob, ClusterNode> jobsToNodes = new HashMap<>();
+
+        // A separate job instance is created for every node of the subgrid.
+        subgrid.forEach(node -> jobsToNodes.put(createJob(), node));
+
+        return jobsToNodes;
+    }
+
+    /** {@inheritDoc} */
+    @Nullable @Override public R reduce(List<ComputeJobResult> results) throws IgniteException {
+        // Null job results are dropped before folding.
+        Stream<R> nonNullResults = results.stream()
+            .map(jobRes -> (R)jobRes.getData())
+            .filter(Objects::nonNull);
+
+        return nonNullResults.reduce(reducer).orElse(identity);
+    }
+
+    /**
+     * Create job for execution on subgrid.
+     *
+     * @return Job for execution on subgrid.
+     */
+    protected abstract BaseLocalProcessorJob<K, V, T, R> createJob();
+
+    /**
+     * Get affinity function of cache on which training is done.
+     *
+     * @return Affinity function of cache on which training is done.
+     */
+    protected Affinity<GroupTrainerCacheKey> affinity() {
+        return ignite.affinity(cacheName);
+    }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0c5ef1e/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/GroupTrainerCacheKey.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/GroupTrainerCacheKey.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/GroupTrainerCacheKey.java
new file mode 100644
index 0000000..5e4cb76
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/GroupTrainerCacheKey.java
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.trainers.group;
+
+import java.util.Objects;
+import java.util.UUID;
+import org.apache.ignite.cache.affinity.AffinityKeyMapped;
+
+/**
+ * Class used as a key for caches on which {@link GroupTrainer} works.
+ * Structurally it is a triple: (nodeLocalEntityIndex, trainingUUID, data);
+ * nodeLocalEntityIndex is used to map key to node;
+ * trainingUUID is id of training;
+ * data is some custom data stored in this key, for example if we want to store three neural networks on one node
+ * for training with training UUID == trainingUUID, we can use keys
+ * (1, trainingUUID, networkIdx1), (1, trainingUUID, networkIdx2), (1, trainingUUID, networkIdx3).
+ *
+ * @param <K> Type of data part of this key.
+ */
+public class GroupTrainerCacheKey<K> {
+    /** Part of key for key-to-node affinity. Final: a key component must not change after construction. */
+    @AffinityKeyMapped
+    private final Long nodeLocEntityIdx;
+
+    /** UUID of training. Final: a key component must not change after construction. */
+    private final UUID trainingUUID;
+
+    /**
+     * Data.
+     * NOTE(review): declared package-private and non-final; consider {@code private final} once it is confirmed
+     * that no code in this package accesses the field directly instead of calling {@link #data()}.
+     */
+    K data;
+
+    /**
+     * Construct instance of this class.
+     *
+     * @param nodeLocEntityIdx Part of key for key-to-node affinity.
+     * @param data Data.
+     * @param trainingUUID Training UUID.
+     */
+    public GroupTrainerCacheKey(long nodeLocEntityIdx, K data, UUID trainingUUID) {
+        this.nodeLocEntityIdx = nodeLocEntityIdx;
+        this.trainingUUID = trainingUUID;
+        this.data = data;
+    }
+
+    /**
+     * Construct instance of this class. The index is widened to {@code long} and passed to the main constructor.
+     *
+     * @param nodeLocEntityIdx Part of key for key-to-node affinity.
+     * @param data Data.
+     * @param trainingUUID Training UUID.
+     */
+    public GroupTrainerCacheKey(int nodeLocEntityIdx, K data, UUID trainingUUID) {
+        this((long)nodeLocEntityIdx, data, trainingUUID);
+    }
+
+    /**
+     * Get part of key used for key-to-node affinity.
+     *
+     * @return Part of key used for key-to-node affinity.
+     */
+    public Long nodeLocalEntityIndex() {
+        return nodeLocEntityIdx;
+    }
+
+    /**
+     * Get UUID of training.
+     *
+     * @return UUID of training.
+     */
+    public UUID trainingUUID() {
+        return trainingUUID;
+    }
+
+    /**
+     * Get data.
+     *
+     * @return Data.
+     */
+    public K data() {
+        return data;
+    }
+
+    /** {@inheritDoc} */
+    @Override public boolean equals(Object o) {
+        if (this == o)
+            return true;
+
+        if (o == null || getClass() != o.getClass())
+            return false;
+
+        GroupTrainerCacheKey<?> that = (GroupTrainerCacheKey<?>)o;
+
+        // All three components take part in equality; each comparison is null-safe.
+        return Objects.equals(nodeLocEntityIdx, that.nodeLocEntityIdx)
+            && Objects.equals(trainingUUID, that.trainingUUID)
+            && Objects.equals(data, that.data);
+    }
+
+    /** {@inheritDoc} */
+    @Override public int hashCode() {
+        // Objects.hash(...) is deliberately not used: it starts from 1 and would produce different values
+        // than the 31-based combination below (null component contributes 0).
+        int res = Objects.hashCode(nodeLocEntityIdx);
+
+        res = 31 * res + Objects.hashCode(trainingUUID);
+        res = 31 * res + Objects.hashCode(data);
+
+        return res;
+    }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0c5ef1e/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/GroupTrainerEntriesProcessorTask.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/GroupTrainerEntriesProcessorTask.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/GroupTrainerEntriesProcessorTask.java
new file mode 100644
index 0000000..7518ac2
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/GroupTrainerEntriesProcessorTask.java
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.trainers.group;
+
+import java.io.Serializable;
+import java.util.UUID;
+import java.util.stream.Stream;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.ml.math.functions.IgniteBinaryOperator;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
+import org.apache.ignite.ml.math.functions.IgniteSupplier;
+import org.apache.ignite.ml.trainers.group.chain.EntryAndContext;
+
+/**
+ * Task for processing entries of cache used for training.
+ *
+ * @param <K> Type of cache keys of cache used for training.
+ * @param <V> Type of cache values of cache used for training.
+ * @param <C> Type of context (common part of data needed for computation).
+ * @param <R> Type of computation result.
+ */
+public class GroupTrainerEntriesProcessorTask<K, V, C, R extends Serializable> extends GroupTrainerBaseProcessorTask<K, V, C, EntryAndContext<K, V, C>, R> {
+ /**
+ * Construct instance of this class with given parameters.
+ *
+ * @param trainingUUID UUID of training.
+ * @param ctxSupplier Supplier of context.
+ * @param worker Function calculated on each of specified keys.
+ * @param keysSupplier Supplier of keys on which training is done.
+ * @param reducer Reducer used for reducing results of computation performed on each of specified keys.
+ * @param identity Identity for reducer.
+ * @param cacheName Name of cache on which training is done.
+ * @param ignite Ignite instance.
+ */
+ public GroupTrainerEntriesProcessorTask(UUID trainingUUID,
+ IgniteSupplier<C> ctxSupplier,
+ IgniteFunction<EntryAndContext<K, V, C>, ResultAndUpdates<R>> worker,
+ IgniteSupplier<Stream<GroupTrainerCacheKey<K>>> keysSupplier,
+ IgniteBinaryOperator<R> reducer, R identity,
+ String cacheName,
+ Ignite ignite) {
+ super(trainingUUID, ctxSupplier, worker, keysSupplier, reducer, identity, cacheName, ignite);
+ }
+
+ /** {@inheritDoc} */
+ @Override protected BaseLocalProcessorJob<K, V, EntryAndContext<K, V, C>, R> createJob() {
+ return new LocalEntriesProcessorJob<>(ctxSupplier, worker, keysSupplier, reducer, identity, trainingUUID, cacheName);
+ }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0c5ef1e/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/GroupTrainerInput.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/GroupTrainerInput.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/GroupTrainerInput.java
new file mode 100644
index 0000000..ae75f16
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/GroupTrainerInput.java
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.trainers.group;
+
+import java.util.UUID;
+import java.util.stream.Stream;
+import org.apache.ignite.ml.math.functions.IgniteSupplier;
+
+/**
+ * Interface for {@link GroupTrainer} inputs.
+ *
+ * @param <K> Type of data part of cache keys ({@link GroupTrainerCacheKey}) used for group training.
+ */
+public interface GroupTrainerInput<K> {
+ /**
+ * Get supplier of stream of keys used for initialization of {@link GroupTrainer}.
+ *
+ * @param trainingUUID UUID of training.
+ * @return Supplier of stream of keys used for initialization of {@link GroupTrainer}.
+ */
+ IgniteSupplier<Stream<GroupTrainerCacheKey<K>>> initialKeys(UUID trainingUUID);
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0c5ef1e/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/GroupTrainerKeysProcessorTask.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/GroupTrainerKeysProcessorTask.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/GroupTrainerKeysProcessorTask.java
new file mode 100644
index 0000000..0ce3315
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/GroupTrainerKeysProcessorTask.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.trainers.group;
+
+import java.io.Serializable;
+import java.util.UUID;
+import java.util.stream.Stream;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.ml.math.functions.IgniteBinaryOperator;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
+import org.apache.ignite.ml.math.functions.IgniteSupplier;
+import org.apache.ignite.ml.trainers.group.chain.KeyAndContext;
+
+/**
+ * Task for processing entries of cache used for training.
+ *
+ * @param <K> Type of cache keys of cache used for training.
+ * @param <C> Type of context (common part of data needed for computation).
+ * @param <R> Type of computation result.
+ */
+public class GroupTrainerKeysProcessorTask<K, C, R extends Serializable> extends GroupTrainerBaseProcessorTask<K, Object, C, KeyAndContext<K, C>, R> {
+ /**
+ * Construct instance of this class with specified parameters.
+ *
+ * @param trainingUUID UUID of training.
+ * @param ctxSupplier Context supplier.
+ * @param worker Function calculated on each of specified keys.
+ * @param keysSupplier Supplier of keys on which computations should be done.
+ * @param reducer Reducer used for reducing results of computation performed on each of specified keys.
+ * @param identity Identity for reducer.
+ * @param cacheName Name of cache on which training is done.
+ * @param ignite Ignite instance.
+ */
+ public GroupTrainerKeysProcessorTask(UUID trainingUUID,
+ IgniteSupplier<C> ctxSupplier,
+ IgniteFunction<KeyAndContext<K, C>, ResultAndUpdates<R>> worker,
+ IgniteSupplier<Stream<GroupTrainerCacheKey<K>>> keysSupplier,
+ IgniteBinaryOperator<R> reducer, R identity,
+ String cacheName,
+ Ignite ignite) {
+ super(trainingUUID, ctxSupplier, worker, keysSupplier, reducer, identity, cacheName, ignite);
+ }
+
+ /** {@inheritDoc} */
+ @Override protected BaseLocalProcessorJob<K, Object, KeyAndContext<K, C>, R> createJob() {
+ return new LocalKeysProcessorJob<>(ctxSupplier, worker, keysSupplier, reducer, identity, trainingUUID, cacheName);
+ }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0c5ef1e/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/GroupTrainingContext.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/GroupTrainingContext.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/GroupTrainingContext.java
new file mode 100644
index 0000000..cbd04b2
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/GroupTrainingContext.java
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.trainers.group;
+
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.ml.trainers.group.chain.HasTrainingUUID;
+
+/**
+ * Context for group training: an immutable holder of the local context, the cache used for training and the
+ * Ignite instance.
+ *
+ * @param <K> Type of keys of cache used for group training.
+ * @param <V> Type of values of cache used for group training.
+ * @param <L> Type of local context used for training.
+ */
+public class GroupTrainingContext<K, V, L extends HasTrainingUUID> {
+    /** Local context. */
+    private final L locCtx;
+
+    /** Cache used for training. */
+    private final IgniteCache<GroupTrainerCacheKey<K>, V> cache;
+
+    /** Ignite instance. */
+    private final Ignite ignite;
+
+    /**
+     * Construct instance of this class.
+     *
+     * @param locCtx Local context.
+     * @param cache Cache used for training.
+     * @param ignite Ignite instance.
+     */
+    public GroupTrainingContext(L locCtx, IgniteCache<GroupTrainerCacheKey<K>, V> cache, Ignite ignite) {
+        this.locCtx = locCtx;
+        this.cache = cache;
+        this.ignite = ignite;
+    }
+
+    /**
+     * Construct new training context with same local context and Ignite instance but with new cache.
+     * This context is left unchanged.
+     *
+     * @param newCache New cache.
+     * @param <K1> Type of keys of new cache.
+     * @param <V1> Type of values of new cache.
+     * @return New training context with same parameters but with new cache.
+     */
+    public <K1, V1> GroupTrainingContext<K1, V1, L> withCache(IgniteCache<GroupTrainerCacheKey<K1>, V1> newCache) {
+        return new GroupTrainingContext<>(locCtx, newCache, ignite);
+    }
+
+    /**
+     * Get local context.
+     *
+     * @return Local context.
+     */
+    public L localContext() {
+        return locCtx;
+    }
+
+    /**
+     * Get cache used for training.
+     *
+     * @return Cache used for training.
+     */
+    public IgniteCache<GroupTrainerCacheKey<K>, V> cache() {
+        return cache;
+    }
+
+    /**
+     * Get ignite instance.
+     *
+     * @return Ignite instance.
+     */
+    public Ignite ignite() {
+        return ignite;
+    }
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0c5ef1e/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/LocalEntriesProcessorJob.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/LocalEntriesProcessorJob.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/LocalEntriesProcessorJob.java
new file mode 100644
index 0000000..ef0405f
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/LocalEntriesProcessorJob.java
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.trainers.group;
+
+import java.io.Serializable;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Set;
+import java.util.UUID;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+import org.apache.ignite.ml.math.functions.IgniteBinaryOperator;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
+import org.apache.ignite.ml.math.functions.IgniteSupplier;
+import org.apache.ignite.ml.trainers.group.chain.EntryAndContext;
+
+/**
+ * {@link BaseLocalProcessorJob} specified to entry processing.
+ *
+ * @param <K> Type of cache used for group training.
+ * @param <V> Type of values used for group training.
+ * @param <C> Type of context.
+ * @param <R> Type of result returned by worker.
+ */
+public class LocalEntriesProcessorJob<K, V, C, R extends Serializable> extends BaseLocalProcessorJob<K, V, EntryAndContext<K, V, C>, R> {
+ /**
+ * Supplier of context for worker.
+ */
+ private final IgniteSupplier<C> ctxSupplier;
+
+ /**
+ * Construct an instance of this class.
+ *
+ * @param ctxSupplier Supplier for context for worker.
+ * @param worker Worker.
+ * @param keySupplier Supplier of keys.
+ * @param reducer Reducer.
+ * @param identity Identity for reducer.
+ * @param trainingUUID UUID for training.
+ * @param cacheName Name of cache used for training.
+ */
+ public LocalEntriesProcessorJob(IgniteSupplier<C> ctxSupplier,
+ IgniteFunction<EntryAndContext<K, V, C>, ResultAndUpdates<R>> worker,
+ IgniteSupplier<Stream<GroupTrainerCacheKey<K>>> keySupplier,
+ IgniteBinaryOperator<R> reducer, R identity,
+ UUID trainingUUID, String cacheName) {
+ super(worker, keySupplier, reducer, identity, trainingUUID, cacheName);
+ this.ctxSupplier = ctxSupplier;
+ }
+
+ /** {@inheritDoc} */
+ @Override protected Stream<EntryAndContext<K, V, C>> toProcess() {
+ C ctx = ctxSupplier.get();
+
+ return selectLocalEntries().map(e -> new EntryAndContext<>(e, ctx));
+ }
+
+ /**
+ * Select entries for processing by worker.
+ *
+ * @return Entries for processing by worker.
+ */
+ private Stream<Map.Entry<GroupTrainerCacheKey<K>, V>> selectLocalEntries() {
+ Set<GroupTrainerCacheKey<K>> keys = keySupplier.get().
+ filter(k -> Objects.requireNonNull(affinity().mapKeyToNode(k)).isLocal()).
+ filter(k -> k.trainingUUID().equals(trainingUUID)).
+ collect(Collectors.toSet());
+
+ return cache().getAll(keys).entrySet().stream();
+ }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0c5ef1e/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/LocalKeysProcessorJob.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/LocalKeysProcessorJob.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/LocalKeysProcessorJob.java
new file mode 100644
index 0000000..842dadf
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/LocalKeysProcessorJob.java
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.trainers.group;
+
+import java.io.Serializable;
+import java.util.Objects;
+import java.util.UUID;
+import java.util.stream.Stream;
+import org.apache.ignite.ml.math.functions.IgniteBinaryOperator;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
+import org.apache.ignite.ml.math.functions.IgniteSupplier;
+import org.apache.ignite.ml.trainers.group.chain.KeyAndContext;
+
+/**
+ * {@link BaseLocalProcessorJob} specified to keys processing.
+ *
+ * @param <K> Type of cache used for group training.
+ * @param <V> Type of values used for group training.
+ * @param <C> Type of context.
+ * @param <R> Type of result returned by worker.
+ */
+public class LocalKeysProcessorJob<K, V, C, R extends Serializable> extends BaseLocalProcessorJob<K, V, KeyAndContext<K, C>, R> {
+ /**
+ * Supplier of worker context.
+ */
+ private final IgniteSupplier<C> ctxSupplier;
+
+ /**
+ * Construct instance of this class with given arguments.
+ *
+ * @param worker Worker.
+ * @param keySupplier Supplier of keys.
+ * @param reducer Reducer.
+ * @param identity Identity for reducer.
+ * @param trainingUUID UUID of training.
+ * @param cacheName Name of cache used for training.
+ */
+ public LocalKeysProcessorJob(IgniteSupplier<C> ctxSupplier,
+ IgniteFunction<KeyAndContext<K, C>, ResultAndUpdates<R>> worker,
+ IgniteSupplier<Stream<GroupTrainerCacheKey<K>>> keySupplier,
+ IgniteBinaryOperator<R> reducer, R identity,
+ UUID trainingUUID, String cacheName) {
+ super(worker, keySupplier, reducer, identity, trainingUUID, cacheName);
+ this.ctxSupplier = ctxSupplier;
+ }
+
+ /** {@inheritDoc} */
+ @Override protected Stream<KeyAndContext<K, C>> toProcess() {
+ C ctx = ctxSupplier.get();
+
+ return selectLocalKeys().map(k -> new KeyAndContext<>(k, ctx));
+ }
+
+ /**
+ * Get subset of keys provided by keySupplier which are mapped to node on which code is executed.
+ *
+ * @return Subset of keys provided by keySupplier which are mapped to node on which code is executed.
+ */
+ private Stream<GroupTrainerCacheKey<K>> selectLocalKeys() {
+ return keySupplier.get().
+ filter(k -> Objects.requireNonNull(affinity().mapKeyToNode(k)).isLocal()).
+ filter(k -> k.trainingUUID().equals(trainingUUID));
+ }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0c5ef1e/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/Metaoptimizer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/Metaoptimizer.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/Metaoptimizer.java
new file mode 100644
index 0000000..33e312a
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/Metaoptimizer.java
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.trainers.group;
+
+import org.apache.ignite.ml.math.functions.IgniteBinaryOperator;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
+
+/**
+ * Class encapsulating data transformations in group training in {@link MetaoptimizerGroupTrainer}, which is adapter of
+ * {@link GroupTrainer}.
+ *
+ * @param <LC> Local context of {@link GroupTrainer}.
+ * @param <X> Type of data which is processed in training loop step.
+ * @param <Y> Type of data returned by training loop step data processor.
+ * @param <I> Type of data to which data returned by distributed initialization is mapped.
+ * @param <D> Type of data returned by initialization.
+ * @param <O> Type of data to which data returned by data processor is mapped.
+ */
+public interface Metaoptimizer<LC, X, Y, I, D, O> {
+ /**
+ * Get function used to reduce distributed initialization results.
+ *
+ * @return Function used to reduce distributed initialization results.
+ */
+ IgniteBinaryOperator<D> initialReducer();
+
+ /**
+ * Maps data returned by distributed initialization to data consumed by training loop step.
+ *
+ * @param data Data returned by distributed initialization.
+ * @param locCtx Local context.
+ * @return Mapping of data returned by distributed initialization to data consumed by training loop step.
+ */
+ I locallyProcessInitData(D data, LC locCtx);
+
+ /**
+ * Get function used to preprocess data for {@link MetaoptimizerGroupTrainer#dataProcessor()}.
+ * The default implementation returns a function which leaves its argument unchanged.
+ *
+ * @return Function preprocessing data for {@link MetaoptimizerGroupTrainer#dataProcessor()}.
+ */
+ default IgniteFunction<X, X> distributedPreprocessor() {
+ return x -> x;
+ }
+
+ /**
+ * Get function used to map values returned by {@link MetaoptimizerGroupTrainer#dataProcessor()}.
+ *
+ * @return Function used to map values returned by {@link MetaoptimizerGroupTrainer#dataProcessor()}.
+ */
+ IgniteFunction<Y, O> distributedPostprocessor();
+
+ /**
+ * Get binary operator used for reducing results returned by distributedPostprocessor.
+ *
+ * @return Binary operator used for reducing results returned by distributedPostprocessor.
+ */
+ IgniteBinaryOperator<O> postProcessReducer();
+
+ /**
+ * Get identity of postProcessReducer.
+ *
+ * @return Identity of postProcessReducer.
+ */
+ O postProcessIdentity();
+
+ /**
+ * Transform data returned by distributed part of training loop step into input fed into distributed part of training
+ * loop step.
+ *
+ * @param input Output of distributed part of training loop step.
+ * @param locCtx Local context.
+ * @return Result of transforming data returned by distributed part of training loop step into input fed into
+ * distributed part of training loop step.
+ */
+ I localProcessor(O input, LC locCtx);
+
+ /**
+ * Returns value of predicate 'should training loop continue given previous step output and local context'.
+ *
+ * @param input Input of previous step.
+ * @param locCtx Local context.
+ * @return Value of predicate 'should training loop continue given previous step output and local context'.
+ */
+ boolean shouldContinue(I input, LC locCtx);
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0c5ef1e/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/MetaoptimizerDistributedStep.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/MetaoptimizerDistributedStep.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/MetaoptimizerDistributedStep.java
new file mode 100644
index 0000000..2bf4db6
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/MetaoptimizerDistributedStep.java
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.trainers.group;
+
+import java.io.Serializable;
+import java.util.stream.Stream;
+import org.apache.ignite.ml.math.functions.IgniteBinaryOperator;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
+import org.apache.ignite.ml.math.functions.IgniteSupplier;
+import org.apache.ignite.ml.trainers.group.chain.DistributedEntryProcessingStep;
+import org.apache.ignite.ml.trainers.group.chain.EntryAndContext;
+import org.apache.ignite.ml.trainers.group.chain.HasTrainingUUID;
+
+/**
+ * Distributed step of training loop which combines a {@link Metaoptimizer} with a {@link MetaoptimizerGroupTrainer}:
+ * trainer supplies keys, remote context and data processor; metaoptimizer supplies pre/postprocessing and reduction.
+ * TODO: IGNITE-7322: add full description.
+ */
+class MetaoptimizerDistributedStep<L extends HasTrainingUUID, K, V, G, I extends Serializable,
+    O extends Serializable, X, Y, D extends Serializable> implements DistributedEntryProcessingStep<L, K, V, G, I, O> {
+    /**
+     * {@link Metaoptimizer} providing distributed pre/postprocessors, reducer and identity of this step.
+     */
+    private final Metaoptimizer<L, X, Y, I, D, O> metaoptimizer;
+
+    /**
+     * {@link MetaoptimizerGroupTrainer} for which this distributed step is used.
+     */
+    private final MetaoptimizerGroupTrainer<L, K, V, D, ?, I, ?, ?, G, O, X, Y> trainer;
+
+    /**
+     * Construct instance of this class with given parameters.
+     *
+     * @param metaoptimizer Metaoptimizer.
+     * @param trainer {@link MetaoptimizerGroupTrainer} for which this distributed step is used.
+     */
+    MetaoptimizerDistributedStep(Metaoptimizer<L, X, Y, I, D, O> metaoptimizer,
+        MetaoptimizerGroupTrainer<L, K, V, D, ?, I, ?, ?, G, O, X, Y> trainer) {
+        this.metaoptimizer = metaoptimizer;
+        this.trainer = trainer;
+    }
+
+    /** {@inheritDoc} */
+    @Override public IgniteSupplier<G> remoteContextSupplier(I input, L locCtx) {
+        return trainer.remoteContextExtractor(input, locCtx);
+    }
+
+    /** {@inheritDoc} */
+    @Override public IgniteFunction<EntryAndContext<K, V, G>, ResultAndUpdates<O>> worker() {
+        IgniteFunction<X, ResultAndUpdates<Y>> dataProcessor = trainer.dataProcessor();
+        IgniteFunction<X, X> preprocessor = metaoptimizer.distributedPreprocessor();
+        IgniteFunction<Y, O> postprocessor = metaoptimizer.distributedPostprocessor();
+        IgniteFunction<EntryAndContext<K, V, G>, X> ctxExtractor = trainer.trainingLoopStepDataExtractor();
+
+        return entryAndCtx -> {
+            X extracted = ctxExtractor.apply(entryAndCtx);
+            // Data processor must consume the preprocessor's return value, not the raw extracted data.
+            ResultAndUpdates<Y> res = dataProcessor.apply(preprocessor.apply(extracted));
+            O postprocessRes = postprocessor.apply(res.result());
+
+            return ResultAndUpdates.of(postprocessRes).setUpdates(res.updates());
+        };
+    }
+
+    /** {@inheritDoc} */
+    @Override public IgniteSupplier<Stream<GroupTrainerCacheKey<K>>> keys(I input, L locCtx) {
+        return trainer.keysToProcessInTrainingLoop(locCtx);
+    }
+
+    /** {@inheritDoc} */
+    @Override public O identity() {
+        return metaoptimizer.postProcessIdentity();
+    }
+
+    /** {@inheritDoc} */
+    @Override public IgniteBinaryOperator<O> reducer() {
+        return metaoptimizer.postProcessReducer();
+    }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0c5ef1e/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/MetaoptimizerGroupTrainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/MetaoptimizerGroupTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/MetaoptimizerGroupTrainer.java
new file mode 100644
index 0000000..310ff94
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/MetaoptimizerGroupTrainer.java
@@ -0,0 +1,129 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.trainers.group;
+
+import java.io.Serializable;
+import java.util.stream.Stream;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.ml.Model;
+import org.apache.ignite.ml.math.functions.IgniteBinaryOperator;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
+import org.apache.ignite.ml.math.functions.IgniteSupplier;
+import org.apache.ignite.ml.trainers.group.chain.Chains;
+import org.apache.ignite.ml.trainers.group.chain.ComputationsChain;
+import org.apache.ignite.ml.trainers.group.chain.EntryAndContext;
+import org.apache.ignite.ml.trainers.group.chain.HasTrainingUUID;
+
+/**
+ * Group trainer using {@link Metaoptimizer}.
+ * Main purpose of this trainer is to extract various transformations (normalizations for example) of data which is
+ * processed in the training loop step into distinct entity called metaoptimizer and only fix the main part of logic in
+ * trainers extending this class. This way we'll be able to quickly switch between these transformations by using
+ * different metaoptimizers without touching main logic.
+ *
+ * @param <LC> Type of local context.
+ * @param <K> Type of keys of cache used in group training.
+ * @param <V> Type of values of cache used in group training.
+ * @param <IN> Data type which is returned by distributed initializer.
+ * @param <R> Type of final result returned by nodes on which training is done.
+ * @param <I> Type of data which is fed into each training loop step and returned from it.
+ * @param <M> Type of model returned after training.
+ * @param <T> Type of input of this trainer.
+ * @param <G> Type of distributed context which is needed for forming final result which is sent from each node
+ * to trainer for final model creation.
+ * @param <O> Type of output of postprocessor.
+ * @param <X> Type of data which is processed by dataProcessor.
+ * @param <Y> Type of result which is produced by dataProcessor and consumed by postprocessor.
+ */
+public abstract class MetaoptimizerGroupTrainer<LC extends HasTrainingUUID, K, V, IN extends Serializable,
+    R extends Serializable, I extends Serializable,
+    M extends Model, T extends GroupTrainerInput<K>,
+    G, O extends Serializable, X, Y> extends
+    GroupTrainer<LC, K, V, IN, R, I, M, T, G> {
+    /**
+     * Metaoptimizer to which init data processing, loop continuation check and local processing are delegated.
+     */
+    private final Metaoptimizer<LC, X, Y, I, IN, O> metaoptimizer;
+
+    /**
+     * Construct instance of this class.
+     *
+     * @param metaoptimizer Metaoptimizer.
+     * @param cache Cache on which group training is done.
+     * @param ignite Ignite instance.
+     */
+    public MetaoptimizerGroupTrainer(Metaoptimizer<LC, X, Y, I, IN, O> metaoptimizer,
+        IgniteCache<GroupTrainerCacheKey<K>, V> cache,
+        Ignite ignite) {
+        super(cache, ignite);
+        this.metaoptimizer = metaoptimizer;
+    }
+
+    /**
+     * Get function used to map EntryAndContext to type which is processed by dataProcessor.
+     *
+     * @return Function used to map EntryAndContext to type which is processed by dataProcessor.
+     */
+    abstract IgniteFunction<EntryAndContext<K, V, G>, X> trainingLoopStepDataExtractor();
+
+    /**
+     * Get supplier of keys which should be processed by training loop.
+     *
+     * @param locCtx Local context.
+     * @return Supplier of keys which should be processed by training loop.
+     */
+    protected abstract IgniteSupplier<Stream<GroupTrainerCacheKey<K>>> keysToProcessInTrainingLoop(LC locCtx);
+
+    /**
+     * Get supplier of context used in training loop step.
+     *
+     * @param input Input.
+     * @param ctx Local context.
+     * @return Supplier of context used in training loop step.
+     */
+    protected abstract IgniteSupplier<G> remoteContextExtractor(I input, LC ctx);
+
+    /**
+     * Get function used to process data in training loop step.
+     *
+     * @return Function used to process data in training loop step.
+     */
+    protected abstract IgniteFunction<X, ResultAndUpdates<Y>> dataProcessor();
+
+    /** {@inheritDoc} */
+    @Override protected ComputationsChain<LC, K, V, I, I> trainingLoopStep() {
+        ComputationsChain<LC, K, V, I, O> chain = Chains.create(new MetaoptimizerDistributedStep<>(metaoptimizer, this));
+        return chain.thenLocally(metaoptimizer::localProcessor);
+    }
+
+    /** {@inheritDoc} */
+    @Override protected I locallyProcessInitData(IN data, LC locCtx) {
+        return metaoptimizer.locallyProcessInitData(data, locCtx);
+    }
+
+    /** {@inheritDoc} */
+    @Override protected boolean shouldContinue(I data, LC locCtx) {
+        return metaoptimizer.shouldContinue(data, locCtx);
+    }
+
+    /** {@inheritDoc} */
+    @Override protected IgniteBinaryOperator<IN> reduceDistributedInitData() {
+        return metaoptimizer.initialReducer();
+    }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0c5ef1e/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/ResultAndUpdates.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/ResultAndUpdates.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/ResultAndUpdates.java
new file mode 100644
index 0000000..411cc42
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/ResultAndUpdates.java
@@ -0,0 +1,173 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.trainers.group;
+
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.ml.math.functions.IgniteBinaryOperator;
+
+/**
+ * Class containing result of computation and updates which should be made for caches.
+ * Purpose of this class is mainly performance optimization: suppose we have multiple computations which run in parallel
+ * and do some updates to caches. It is more efficient to collect all changes from all these computations and perform
+ * them in batch.
+ *
+ * @param <R> Type of computation result.
+ */
+public class ResultAndUpdates<R> {
+    /**
+     * Result of computation.
+     */
+    private final R res;
+
+    /**
+     * Updates in the form cache name -> (key -> new value).
+     */
+    private Map<String, Map> updates = new ConcurrentHashMap<>();
+
+    /**
+     * Construct an instance of this class with no updates.
+     *
+     * @param res Computation result.
+     */
+    public ResultAndUpdates(R res) {
+        this.res = res;
+    }
+
+    /**
+     * Construct an instance of this class.
+     *
+     * @param res Computation result.
+     * @param updates Map of updates in the form cache name -> (key -> new value); stored as is, not copied.
+     */
+    ResultAndUpdates(R res, Map<String, Map> updates) {
+        this.res = res;
+        this.updates = updates;
+    }
+
+    /**
+     * Construct an empty result, i.e. an object with {@code null} result and no updates.
+     *
+     * @param <R> Result type.
+     * @return Empty result.
+     */
+    public static <R> ResultAndUpdates<R> empty() {
+        return new ResultAndUpdates<>(null);
+    }
+
+    /**
+     * Construct {@link ResultAndUpdates} object from given result.
+     *
+     * @param res Result of computation.
+     * @param <R> Type of result of computation.
+     * @return ResultAndUpdates object.
+     */
+    public static <R> ResultAndUpdates<R> of(R res) {
+        return new ResultAndUpdates<>(res);
+    }
+
+    /**
+     * Add a cache update to this object. Update is only recorded here, cache itself is not touched by this method.
+     *
+     * @param cache Cache to be updated.
+     * @param key Key of cache to be updated.
+     * @param val New value.
+     * @param <K> Type of key of cache to be updated.
+     * @param <V> Type of new value.
+     */
+    @SuppressWarnings("unchecked")
+    public <K, V> void update(IgniteCache<K, V> cache, K key, V val) {
+        String name = cache.getName();
+
+        // Put into the map returned by computeIfAbsent rather than looking it up a second time.
+        updates.computeIfAbsent(name, s -> new ConcurrentHashMap()).put(key, val);
+    }
+
+    /**
+     * Get result of computation.
+     *
+     * @return Result of computation.
+     */
+    public R result() {
+        return res;
+    }
+
+    /**
+     * Sum collection of ResultAndUpdate into one: results are reduced by specified binary operator and updates are merged.
+     * If the collection is empty, identity is returned as the result.
+     * @param op Binary operator used to combine computation results.
+     * @param identity Identity for op.
+     * @param resultsAndUpdates ResultAndUpdates to be combined with.
+     * @param <R> Type of computation result.
+     * @return Sum of collection ResultAndUpdate objects.
+     */
+    @SuppressWarnings("unchecked")
+    static <R> ResultAndUpdates<R> sum(IgniteBinaryOperator<R> op, R identity,
+        Collection<ResultAndUpdates<R>> resultsAndUpdates) {
+        Map<String, Map> allUpdates = new HashMap<>();
+
+        for (ResultAndUpdates<R> ru : resultsAndUpdates) {
+            for (Map.Entry<String, Map> cacheUpdates : ru.updates.entrySet()) {
+                // For the same cache and key the value from the later element of the collection wins.
+                allUpdates.computeIfAbsent(cacheUpdates.getKey(), s -> new HashMap())
+                    .putAll(cacheUpdates.getValue());
+            }
+        }
+
+        R res = resultsAndUpdates.stream().map(ResultAndUpdates::result).reduce(op).orElse(identity);
+
+        return new ResultAndUpdates<>(res, allUpdates);
+    }
+
+    /**
+     * Get updates map. The internal map itself is returned, not a copy.
+     *
+     * @return Updates map.
+     */
+    public Map<String, Map> updates() {
+        return updates;
+    }
+
+    /**
+     * Set updates map. Given map replaces the current one and is stored as is, not copied.
+     *
+     * @param updates New updates map.
+     * @return This object.
+     */
+    ResultAndUpdates<R> setUpdates(Map<String, Map> updates) {
+        this.updates = updates;
+        return this;
+    }
+
+    /**
+     * Apply updates to caches: each cache is obtained (or created) by its name and updated in one putAll call.
+     *
+     * @param ignite Ignite instance.
+     */
+    void applyUpdates(Ignite ignite) {
+        for (Map.Entry<String, Map> entry : updates.entrySet()) {
+            IgniteCache<Object, Object> cache = ignite.getOrCreateCache(entry.getKey());
+
+            cache.putAll(entry.getValue());
+        }
+    }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/b0c5ef1e/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/Chains.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/Chains.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/Chains.java
new file mode 100644
index 0000000..db4f13f
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/Chains.java
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.trainers.group.chain;
+
+import java.io.Serializable;
+
+/**
+ * Factory methods for building {@link ComputationsChain} instances.
+ */
+public class Chains {
+    /**
+     * Create an identity computation chain, i.e. a chain whose output is exactly the input passed to it.
+     *
+     * @param <L> Local context type of the resulting chain.
+     * @param <K> Key type of the cache the chain works with.
+     * @param <V> Value type of the cache the chain works with.
+     * @param <I> Input (and therefore output) type of the resulting chain.
+     * @return Chain which returns its input unchanged.
+     */
+    public static <L extends HasTrainingUUID, K, V, I> ComputationsChain<L, K, V, I, I> create() {
+        return (in, ctx) -> in;
+    }
+
+    /**
+     * Create a {@link ComputationsChain} out of a single {@link DistributedEntryProcessingStep}.
+     *
+     * @param step Distributed step to be appended to an identity chain.
+     * @param <L> Local context type of the resulting chain.
+     * @param <K> Key type of the cache the chain works with.
+     * @param <V> Value type of the cache the chain works with.
+     * @param <C> Context type used by the worker of the {@link DistributedEntryProcessingStep}.
+     * @param <I> Input type of the resulting chain.
+     * @param <O> Output type of the resulting chain.
+     * @return Chain built from the given {@link DistributedEntryProcessingStep}.
+     */
+    public static <L extends HasTrainingUUID, K, V, C, I, O extends Serializable> ComputationsChain<L, K, V, I, O> create(
+        DistributedEntryProcessingStep<L, K, V, C, I, O> step) {
+        return Chains.<L, K, V, I>create()
+            .thenDistributedForEntries(step);
+    }
+}