You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@ignite.apache.org by ch...@apache.org on 2018/11/27 11:05:24 UTC
ignite git commit: IGNITE-9284: [ML] Add a Standard Scaler
Repository: ignite
Updated Branches:
refs/heads/master 46a84fddb -> 41f4225c4
IGNITE-9284: [ML] Add a Standard Scaler
this closes #4964
Project: http://git-wip-us.apache.org/repos/asf/ignite/repo
Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/41f4225c
Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/41f4225c
Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/41f4225c
Branch: refs/heads/master
Commit: 41f4225c4b2f2735bce4ce861b9a51afc80d5815
Parents: 46a84fd
Author: Ravil Galeyev <de...@yandex.ru>
Authored: Tue Nov 27 14:05:17 2018 +0300
Committer: Yury Babak <yb...@gridgain.com>
Committed: Tue Nov 27 14:05:17 2018 +0300
----------------------------------------------------------------------
.../ml/preprocessing/StandardScalerExample.java | 84 +++++++++++++++
.../standardscaling/StandardScalerData.java | 56 ++++++++++
.../StandardScalerPreprocessor.java | 91 +++++++++++++++++
.../standardscaling/StandardScalerTrainer.java | 101 +++++++++++++++++++
.../standardscaling/package-info.java | 22 ++++
.../StandardScalerPreprocessorTest.java | 59 +++++++++++
.../StandardScalerTrainerTest.java | 85 ++++++++++++++++
7 files changed, 498 insertions(+)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/ignite/blob/41f4225c/examples/src/main/java/org/apache/ignite/examples/ml/preprocessing/StandardScalerExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/preprocessing/StandardScalerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/preprocessing/StandardScalerExample.java
new file mode 100644
index 0000000..13d8635
--- /dev/null
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/preprocessing/StandardScalerExample.java
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.examples.ml.preprocessing;
+
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.Ignition;
+import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
+import org.apache.ignite.configuration.CacheConfiguration;
+import org.apache.ignite.examples.ml.dataset.model.Person;
+import org.apache.ignite.examples.ml.util.DatasetHelper;
+import org.apache.ignite.ml.dataset.DatasetFactory;
+import org.apache.ignite.ml.dataset.primitive.SimpleDataset;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.apache.ignite.ml.preprocessing.standardscaling.StandardScalerTrainer;
+
+/**
+ * Example that shows how to use StandardScaler preprocessor to scale the given data.
+ *
+ * Machine learning preprocessors are built as a chain. Most often the first preprocessor is a feature extractor as
+ * shown in this example. The second preprocessor here is a {@code StandardScaler} preprocessor which is built on top of
+ * the feature extractor and represents a chain of itself and the underlying feature extractor.
+ */
+public class StandardScalerExample {
+    /** Run example. */
+    public static void main(String[] args) throws Exception {
+        try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
+            System.out.println(">>> Standard scaler example started.");
+
+            IgniteCache<Integer, Person> persons = createCache(ignite);
+
+            try {
+                // Defines first preprocessor that extracts features from an upstream data.
+                IgniteBiFunction<Integer, Person, Vector> featureExtractor = (k, v) -> VectorUtils.of(
+                    v.getAge(),
+                    v.getSalary()
+                );
+
+                // Defines second preprocessor that processes features.
+                IgniteBiFunction<Integer, Person, Vector> preprocessor = new StandardScalerTrainer<Integer, Person>()
+                    .fit(ignite, persons, featureExtractor);
+
+                // Creates a cache based simple dataset containing features and providing standard dataset API.
+                try (SimpleDataset<?> dataset = DatasetFactory.createSimpleDataset(ignite, persons, preprocessor)) {
+                    new DatasetHelper(dataset).describe();
+                }
+            }
+            finally {
+                // createCache() fails if a cache with the same name already exists in the cluster, so remove the
+                // cache to allow the example to be re-run while other cluster nodes are still up.
+                persons.destroy();
+            }
+
+            System.out.println(">>> Standard scaler example completed.");
+        }
+    }
+
+    /**
+     * Creates the "PERSONS" cache (2 partitions) and fills it with four sample persons.
+     *
+     * @param ignite Ignite instance.
+     * @return Created and populated cache.
+     */
+    private static IgniteCache<Integer, Person> createCache(Ignite ignite) {
+        CacheConfiguration<Integer, Person> cacheConfiguration = new CacheConfiguration<>();
+
+        cacheConfiguration.setName("PERSONS");
+        cacheConfiguration.setAffinity(new RendezvousAffinityFunction(false, 2));
+
+        IgniteCache<Integer, Person> persons = ignite.createCache(cacheConfiguration);
+
+        persons.put(1, new Person("Mike", 42, 10000));
+        persons.put(2, new Person("John", 32, 64000));
+        persons.put(3, new Person("George", 53, 120000));
+        persons.put(4, new Person("Karl", 24, 70000));
+
+        return persons;
+    }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/41f4225c/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/standardscaling/StandardScalerData.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/standardscaling/StandardScalerData.java b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/standardscaling/StandardScalerData.java
new file mode 100644
index 0000000..f96dcc5
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/standardscaling/StandardScalerData.java
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.ignite.ml.preprocessing.standardscaling;
+
+/** A Service class for {@link StandardScalerTrainer} which used for sums holing. */
+public class StandardScalerData implements AutoCloseable {
+ /** Sum values of every feature. */
+ double[] sum;
+ /** Sum of squared values of every feature. */
+ double[] squaredSum;
+ /** Rows count */
+ long cnt;
+
+ /**
+ * Creates {@code StandardScalerData}.
+ *
+ * @param sum Sum values of every feature.
+ * @param squaredSum Sum of squared values of every feature.
+ * @param cnt Rows count.
+ */
+ public StandardScalerData(double[] sum, double[] squaredSum, long cnt) {
+ this.sum = sum;
+ this.squaredSum = squaredSum;
+ this.cnt = cnt;
+ }
+
+ /** Merges to current. */
+ StandardScalerData merge(StandardScalerData that) {
+ for (int i = 0; i < sum.length; i++) {
+ sum[i] += that.sum[i];
+ squaredSum[i] += that.squaredSum[i];
+ }
+
+ cnt += that.cnt;
+ return this;
+ }
+
+ /** */
+ @Override public void close() {
+ // Do nothing, GC will clean up.
+ }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/41f4225c/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/standardscaling/StandardScalerPreprocessor.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/standardscaling/StandardScalerPreprocessor.java b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/standardscaling/StandardScalerPreprocessor.java
new file mode 100644
index 0000000..293e86a
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/standardscaling/StandardScalerPreprocessor.java
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.preprocessing.standardscaling;
+
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+
+/**
+ * The preprocessing function that makes standard scaling, transforms features to make {@code mean} equal to {@code 0}
+ * and {@code variance} equal to {@code 1}. From mathematical point of view it's the following function which is applied
+ * to every element in a dataset:
+ *
+ * {@code a_i = (a_i - mean_i) / sigma_i for all i},
+ *
+ * where {@code i} is a number of column, {@code mean_i} is the mean value this column and {@code sigma_i} is the
+ * standard deviation in this column.
+ *
+ * @param <K> Type of a key in {@code upstream} data.
+ * @param <V> Type of a value in {@code upstream} data.
+ */
+public class StandardScalerPreprocessor<K, V> implements IgniteBiFunction<K, V, Vector> {
+ /** */
+ private static final long serialVersionUID = -5977957318991608203L;
+
+ /** Means for each column. */
+ private final double[] means;
+ /** Standard deviation for each column. */
+ private final double[] sigmas;
+
+ /** Base preprocessor. */
+ private final IgniteBiFunction<K, V, Vector> basePreprocessor;
+
+ /**
+ * Constructs a new instance of standardscaling preprocessor.
+ *
+ * @param means Means of each column.
+ * @param sigmas Standard deviations in each column.
+ * @param basePreprocessor Base preprocessor.
+ */
+ public StandardScalerPreprocessor(double[] means, double[] sigmas,
+ IgniteBiFunction<K, V, Vector> basePreprocessor) {
+ assert means.length == sigmas.length;
+
+ this.means = means;
+ this.sigmas = sigmas;
+ this.basePreprocessor = basePreprocessor;
+ }
+
+ /**
+ * Applies this preprocessor.
+ *
+ * @param k Key.
+ * @param v Value.
+ * @return Preprocessed row.
+ */
+ @Override public Vector apply(K k, V v) {
+ Vector res = basePreprocessor.apply(k, v);
+
+ assert res.size() == means.length;
+
+ for (int i = 0; i < res.size(); i++)
+ res.set(i, (res.get(i) - means[i]) / sigmas[i]);
+
+ return res;
+ }
+
+ /** */
+ public double[] getMeans() {
+ return means;
+ }
+
+ /** */
+ public double[] getSigmas() {
+ return sigmas;
+ }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/41f4225c/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/standardscaling/StandardScalerTrainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/standardscaling/StandardScalerTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/standardscaling/StandardScalerTrainer.java
new file mode 100644
index 0000000..3661772
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/standardscaling/StandardScalerTrainer.java
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.preprocessing.standardscaling;
+
+import org.apache.ignite.ml.dataset.Dataset;
+import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.dataset.UpstreamEntry;
+import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.preprocessing.PreprocessingTrainer;
+
+/**
+ * Trainer of the standard scaler preprocessor.
+ *
+ * It makes a single pass over the dataset collecting per-feature sums, sums of squares and the row count, and derives
+ * per-feature means and population standard deviations from them.
+ *
+ * @param <K> Type of a key in {@code upstream} data.
+ * @param <V> Type of a value in {@code upstream} data.
+ */
+public class StandardScalerTrainer<K, V> implements PreprocessingTrainer<K, V, Vector, Vector> {
+    /**
+     * {@inheritDoc}
+     *
+     * @throws IllegalStateException If the dataset contains no rows.
+     */
+    @Override public StandardScalerPreprocessor<K, V> fit(DatasetBuilder<K, V> datasetBuilder,
+        IgniteBiFunction<K, V, Vector> basePreprocessor) {
+        StandardScalerData standardScalerData = computeSum(datasetBuilder, basePreprocessor);
+
+        // Without rows there are neither sums nor a non-zero count to divide by.
+        if (standardScalerData == null || standardScalerData.sum == null || standardScalerData.cnt == 0)
+            throw new IllegalStateException("Standard scaler can't be fitted on an empty dataset");
+
+        int n = standardScalerData.sum.length;
+        long cnt = standardScalerData.cnt;
+        double[] mean = new double[n];
+        double[] sigma = new double[n];
+
+        for (int i = 0; i < n; i++) {
+            double featureSum = standardScalerData.sum[i];
+
+            mean[i] = featureSum / cnt;
+
+            // Population variance as E[x^2] - E[x]^2. Rounding can make it slightly negative for (nearly) constant
+            // features, and sqrt of a negative number is NaN, so clamp it at zero.
+            double variance = (standardScalerData.squaredSum[i] - featureSum * featureSum / cnt) / cnt;
+
+            sigma[i] = Math.sqrt(Math.max(0.0, variance));
+        }
+
+        return new StandardScalerPreprocessor<>(mean, sigma, basePreprocessor);
+    }
+
+    /**
+     * Computes per-feature sums, sums of squares and the row count over the whole dataset.
+     *
+     * @param datasetBuilder Dataset builder.
+     * @param basePreprocessor Base preprocessor that extracts feature vectors from upstream entries.
+     * @return Aggregated statistics; may be {@code null} or have {@code null} sums when the dataset has no rows.
+     */
+    private StandardScalerData computeSum(DatasetBuilder<K, V> datasetBuilder,
+        IgniteBiFunction<K, V, Vector> basePreprocessor) {
+        try (Dataset<EmptyContext, StandardScalerData> dataset = datasetBuilder.build(
+            (upstream, upstreamSize) -> new EmptyContext(),
+            (upstream, upstreamSize, ctx) -> {
+                double[] sum = null;
+                double[] squaredSum = null;
+                long cnt = 0;
+
+                while (upstream.hasNext()) {
+                    UpstreamEntry<K, V> entity = upstream.next();
+                    Vector row = basePreprocessor.apply(entity.getKey(), entity.getValue());
+
+                    if (sum == null) {
+                        sum = new double[row.size()];
+                        squaredSum = new double[row.size()];
+                    }
+                    else {
+                        assert sum.length == row.size() : "Base preprocessor must return exactly " + sum.length
+                            + " features";
+                    }
+
+                    ++cnt;
+
+                    for (int i = 0; i < row.size(); i++) {
+                        double x = row.get(i);
+
+                        sum[i] += x;
+                        squaredSum[i] += x * x;
+                    }
+                }
+
+                // For a partition without rows, sum and squaredSum stay null.
+                return new StandardScalerData(sum, squaredSum, cnt);
+            }
+        )) {
+
+            return dataset.compute(data -> data,
+                (a, b) -> {
+                    // Partitions without rows carry no sums; skip them instead of merging null arrays.
+                    if (a == null || a.sum == null)
+                        return b;
+                    if (b == null || b.sum == null)
+                        return a;
+
+                    return a.merge(b);
+                });
+        }
+        catch (Exception e) {
+            throw new RuntimeException(e);
+        }
+    }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/41f4225c/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/standardscaling/package-info.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/standardscaling/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/standardscaling/package-info.java
new file mode 100644
index 0000000..5f5de3b
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/standardscaling/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. -->
+ * Contains Standard scaler preprocessor.
+ */
+package org.apache.ignite.ml.preprocessing.standardscaling;
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/ignite/blob/41f4225c/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/standardscaling/StandardScalerPreprocessorTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/standardscaling/StandardScalerPreprocessorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/standardscaling/StandardScalerPreprocessorTest.java
new file mode 100644
index 0000000..3c325b3
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/standardscaling/StandardScalerPreprocessorTest.java
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.preprocessing.standardscaling;
+
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.junit.Test;
+
+import static org.junit.Assert.assertArrayEquals;
+
+/**
+ * Unit tests of {@link StandardScalerPreprocessor} built from precomputed means and standard deviations.
+ */
+public class StandardScalerPreprocessorTest {
+    /** Checks that {@code apply()} centers every column by its mean and divides it by its standard deviation. */
+    @Test
+    public void testApply() {
+        double[] means = {0.5, 1.75, 4.5, 0.875};
+        double[] sigmas = {0.5, 1.47901995, 14.51723114, 0.93374247};
+
+        double[][] rows = {
+            {0, 2., 4., .1},
+            {0, 1., -18., 2.2},
+            {1, 4., 10., -.1},
+            {1, 0., 22., 1.3}
+        };
+
+        double[][] scaledRows = {
+            {-1., 0.16903085, -0.03444183, -0.82999331},
+            {-1., -0.50709255, -1.54988233, 1.41902081},
+            {1., 1.52127766, 0.37886012, -1.04418513},
+            {1., -1.18321596, 1.20546403, 0.45515762}
+        };
+
+        StandardScalerPreprocessor<Integer, Vector> scaler =
+            new StandardScalerPreprocessor<>(means, sigmas, (k, v) -> v);
+
+        int rowIdx = 0;
+
+        for (double[] row : rows) {
+            Vector scaled = scaler.apply(rowIdx, VectorUtils.of(row));
+
+            assertArrayEquals(scaledRows[rowIdx], scaled.asArray(), 1e-8);
+
+            rowIdx++;
+        }
+    }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/41f4225c/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/standardscaling/StandardScalerTrainerTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/standardscaling/StandardScalerTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/standardscaling/StandardScalerTrainerTest.java
new file mode 100644
index 0000000..679cc48
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/standardscaling/StandardScalerTrainerTest.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.preprocessing.standardscaling;
+
+import java.util.HashMap;
+import java.util.Map;
+import org.apache.ignite.ml.common.TrainerTest;
+import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.junit.Before;
+import org.junit.Test;
+
+import static org.junit.Assert.assertArrayEquals;
+
+/**
+ * Unit tests of {@link StandardScalerTrainer} on a small local dataset split into {@code parts} partitions.
+ */
+public class StandardScalerTrainerTest extends TrainerTest {
+    /** Tolerance used when comparing floating point results. */
+    private static final double PRECISION = 1e-8;
+
+    /** Builder of the local dataset under test. */
+    private DatasetBuilder<Integer, Vector> datasetBuilder;
+
+    /** Trainer under test. */
+    private StandardScalerTrainer<Integer, Vector> trainer;
+
+    /** Fills the dataset with four rows of four features each. */
+    @Before
+    public void prepareDataset() {
+        Map<Integer, Vector> rows = new HashMap<>();
+
+        rows.put(1, VectorUtils.of(0, 2., 4., .1));
+        rows.put(2, VectorUtils.of(0, 1., -18., 2.2));
+        rows.put(3, VectorUtils.of(1, 4., 10., -.1));
+        rows.put(4, VectorUtils.of(1, 0., 22., 1.3));
+
+        datasetBuilder = new LocalDatasetBuilder<>(rows, parts);
+    }
+
+    /** Creates a fresh trainer for every test. */
+    @Before
+    public void createTrainer() {
+        trainer = new StandardScalerTrainer<>();
+    }
+
+    /** Checks that {@code fit()} computes the mean of every column. */
+    @Test
+    public void testCalculatesCorrectMeans() {
+        double[] expMeans = {0.5, 1.75, 4.5, 0.875};
+
+        assertArrayEquals(expMeans, fitIdentity().getMeans(), PRECISION);
+    }
+
+    /** Checks that {@code fit()} computes the standard deviation of every column. */
+    @Test
+    public void testCalculatesCorrectStandardDeviations() {
+        double[] expSigmas = {0.5, 1.47901995, 14.51723114, 0.93374247};
+
+        assertArrayEquals(expSigmas, fitIdentity().getSigmas(), PRECISION);
+    }
+
+    /** Fits the trainer on the dataset using the identity base preprocessor. */
+    private StandardScalerPreprocessor<Integer, Vector> fitIdentity() {
+        return trainer.fit(datasetBuilder, (k, v) -> v);
+    }
+}