You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@ignite.apache.org by ch...@apache.org on 2019/01/24 14:12:47 UTC
[ignite] branch master updated: IGNITE-10834: [ML] Add NamedVectors
to replace HashMap in Model
This is an automated email from the ASF dual-hosted git repository.
chief pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/ignite.git
The following commit(s) were added to refs/heads/master by this push:
new c56801e IGNITE-10834: [ML] Add NamedVectors to replace HashMap in Model
c56801e is described below
commit c56801ed7a9e9d65dad51b2bb1b9b9074b07c901
Author: Anton Dmitriev <dm...@gmail.com>
AuthorDate: Thu Jan 24 17:12:28 2019 +0300
IGNITE-10834: [ML] Add NamedVectors to replace HashMap in Model
This closes #5881
---
.../ml/xgboost/XGBoostModelParserExample.java | 6 +-
.../examples/ml/mleap/MLeapModelParserExample.java | 6 +-
.../org/apache/ignite/ml/mleap/MLeapModel.java | 19 +++---
.../apache/ignite/ml/mleap/MLeapModelParser.java | 4 +-
.../ignite/ml/mleap/MLeapModelParserTest.java | 3 +-
.../ml/math/primitives/vector/NamedVector.java | 50 ++++++++++++++
.../ml/math/primitives/vector/VectorUtils.java | 25 +++++++
.../vector/impl/DelegatingNamedVector.java | 78 ++++++++++++++++++++++
.../ignite/ml/xgboost/XGModelComposition.java | 29 +++++---
.../ignite/ml/xgboost/parser/XGModelParser.java | 4 +-
.../ml/xgboost/parser/XGBoostModelParserTest.java | 3 +-
11 files changed, 198 insertions(+), 29 deletions(-)
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/xgboost/XGBoostModelParserExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/xgboost/XGBoostModelParserExample.java
index 0ec05c5..5f75310 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/xgboost/XGBoostModelParserExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/xgboost/XGBoostModelParserExample.java
@@ -31,6 +31,8 @@ import org.apache.ignite.ml.inference.builder.AsyncModelBuilder;
import org.apache.ignite.ml.inference.builder.IgniteDistributedModelBuilder;
import org.apache.ignite.ml.inference.reader.FileSystemModelReader;
import org.apache.ignite.ml.inference.reader.ModelReader;
+import org.apache.ignite.ml.math.primitives.vector.NamedVector;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.xgboost.parser.XGModelParser;
/**
@@ -69,7 +71,7 @@ public class XGBoostModelParserExample {
if (testExpRes == null)
throw new IllegalArgumentException("File not found [resource_path=" + TEST_ER_RES + "]");
- try (Model<HashMap<String, Double>, Future<Double>> mdl = mdlBuilder.build(reader, parser);
+ try (Model<NamedVector, Future<Double>> mdl = mdlBuilder.build(reader, parser);
Scanner testDataScanner = new Scanner(testData);
Scanner testExpResultsScanner = new Scanner(testExpRes)) {
@@ -86,7 +88,7 @@ public class XGBoostModelParserExample {
testObj.put("f" + keyVal[0], Double.parseDouble(keyVal[1]));
}
- double prediction = mdl.predict(testObj).get();
+ double prediction = mdl.predict(VectorUtils.of(testObj)).get();
double expPrediction = Double.parseDouble(testExpResultsStr);
diff --git a/examples/src/main/spark/org/apache/ignite/examples/ml/mleap/MLeapModelParserExample.java b/examples/src/main/spark/org/apache/ignite/examples/ml/mleap/MLeapModelParserExample.java
index 79958dd..2462bfd 100644
--- a/examples/src/main/spark/org/apache/ignite/examples/ml/mleap/MLeapModelParserExample.java
+++ b/examples/src/main/spark/org/apache/ignite/examples/ml/mleap/MLeapModelParserExample.java
@@ -29,6 +29,8 @@ import org.apache.ignite.ml.inference.builder.AsyncModelBuilder;
import org.apache.ignite.ml.inference.builder.IgniteDistributedModelBuilder;
import org.apache.ignite.ml.inference.reader.FileSystemModelReader;
import org.apache.ignite.ml.inference.reader.ModelReader;
+import org.apache.ignite.ml.math.primitives.vector.NamedVector;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.mleap.MLeapModelParser;
/**
@@ -53,7 +55,7 @@ public class MLeapModelParserExample {
AsyncModelBuilder mdlBuilder = new IgniteDistributedModelBuilder(ignite, 4, 4);
- try (Model<HashMap<String, Double>, Future<Double>> mdl = mdlBuilder.build(reader, parser)) {
+ try (Model<NamedVector, Future<Double>> mdl = mdlBuilder.build(reader, parser)) {
HashMap<String, Double> input = new HashMap<>();
input.put("bathrooms", 1.0);
input.put("bedrooms", 1.0);
@@ -64,7 +66,7 @@ public class MLeapModelParserExample {
input.put("square_feet", 1.0);
input.put("review_scores_rating", 1.0);
- Future<Double> prediction = mdl.predict(input);
+ Future<Double> prediction = mdl.predict(VectorUtils.of(input));
System.out.println("Predicted price: " + prediction.get());
}
diff --git a/modules/ml/mleap-model-parser/src/main/java/org/apache/ignite/ml/mleap/MLeapModel.java b/modules/ml/mleap-model-parser/src/main/java/org/apache/ignite/ml/mleap/MLeapModel.java
index 2ebd8c0..fe33a5f 100644
--- a/modules/ml/mleap-model-parser/src/main/java/org/apache/ignite/ml/mleap/MLeapModel.java
+++ b/modules/ml/mleap-model-parser/src/main/java/org/apache/ignite/ml/mleap/MLeapModel.java
@@ -18,7 +18,6 @@
package org.apache.ignite.ml.mleap;
import java.util.ArrayList;
-import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
@@ -31,6 +30,8 @@ import ml.combust.mleap.runtime.frame.Row;
import ml.combust.mleap.runtime.frame.Transformer;
import ml.combust.mleap.runtime.javadsl.LeapFrameBuilder;
import org.apache.ignite.ml.inference.Model;
+import org.apache.ignite.ml.math.primitives.vector.NamedVector;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import scala.collection.immutable.Set;
import scala.collection.immutable.Stream;
import scala.util.Try;
@@ -38,7 +39,7 @@ import scala.util.Try;
/**
* MLeap model imported and wrapped to be compatible with Apache Ignite infrastructure.
*/
-public class MLeapModel implements Model<HashMap<String, Double>, Double> {
+public class MLeapModel implements Model<NamedVector, Double> {
/** MLeap model (transformer in terms of MLeap). */
private final Transformer transformer;
@@ -57,23 +58,25 @@ public class MLeapModel implements Model<HashMap<String, Double>, Double> {
*/
public MLeapModel(Transformer transformer, List<String> schema, String outputFieldName) {
this.transformer = transformer;
- this.schema = schema;
+ this.schema = new ArrayList<>(schema);
this.outputFieldName = outputFieldName;
}
- // TODO: IGNITE-10834 Add NamedVectors to replace HashMap in Model.
/** {@inheritDoc} */
- @Override public Double predict(HashMap<String, Double> input) {
+ @Override public Double predict(NamedVector input) {
LeapFrameBuilder builder = new LeapFrameBuilder();
List<StructField> structFields = new ArrayList<>();
- for (String fieldName : input.keySet())
+ List<Object> values = new ArrayList<>();
+ for (String fieldName : input.getKeys()) {
structFields.add(new StructField(fieldName, ScalarType.Double()));
+ values.add(input.get(fieldName));
+ }
StructType schema = builder.createSchema(structFields);
List<Row> rows = new ArrayList<>();
- rows.add(builder.createRowFromIterable(new ArrayList<>(input.values())));
+ rows.add(builder.createRowFromIterable(values));
DefaultLeapFrame inputFrame = builder.createFrame(schema, rows);
@@ -94,7 +97,7 @@ public class MLeapModel implements Model<HashMap<String, Double>, Double> {
.boxed()
.collect(Collectors.toMap(schema::get, i -> input[i]));
- return predict(new HashMap<>(vec));
+ return predict(VectorUtils.of(vec));
}
/**
diff --git a/modules/ml/mleap-model-parser/src/main/java/org/apache/ignite/ml/mleap/MLeapModelParser.java b/modules/ml/mleap-model-parser/src/main/java/org/apache/ignite/ml/mleap/MLeapModelParser.java
index d3a1d81..7f7c258 100644
--- a/modules/ml/mleap-model-parser/src/main/java/org/apache/ignite/ml/mleap/MLeapModelParser.java
+++ b/modules/ml/mleap-model-parser/src/main/java/org/apache/ignite/ml/mleap/MLeapModelParser.java
@@ -21,7 +21,6 @@ import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.ArrayList;
-import java.util.HashMap;
import java.util.List;
import ml.combust.mleap.core.types.ScalarType;
import ml.combust.mleap.core.types.StructField;
@@ -32,12 +31,13 @@ import ml.combust.mleap.runtime.javadsl.BundleBuilder;
import ml.combust.mleap.runtime.javadsl.ContextBuilder;
import ml.combust.mleap.runtime.transformer.PipelineModel;
import org.apache.ignite.ml.inference.parser.ModelParser;
+import org.apache.ignite.ml.math.primitives.vector.NamedVector;
import scala.collection.JavaConverters;
/**
* MLeap model parser.
*/
-public class MLeapModelParser implements ModelParser<HashMap<String, Double>, Double, MLeapModel> {
+public class MLeapModelParser implements ModelParser<NamedVector, Double, MLeapModel> {
/** */
private static final long serialVersionUID = -370352744966205715L;
diff --git a/modules/ml/mleap-model-parser/src/test/java/org/apache/ignite/ml/mleap/MLeapModelParserTest.java b/modules/ml/mleap-model-parser/src/test/java/org/apache/ignite/ml/mleap/MLeapModelParserTest.java
index 989c48b..12f6ee1 100644
--- a/modules/ml/mleap-model-parser/src/test/java/org/apache/ignite/ml/mleap/MLeapModelParserTest.java
+++ b/modules/ml/mleap-model-parser/src/test/java/org/apache/ignite/ml/mleap/MLeapModelParserTest.java
@@ -23,6 +23,7 @@ import org.apache.ignite.ml.inference.builder.SingleModelBuilder;
import org.apache.ignite.ml.inference.builder.SyncModelBuilder;
import org.apache.ignite.ml.inference.reader.FileSystemModelReader;
import org.apache.ignite.ml.inference.reader.ModelReader;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
@@ -60,7 +61,7 @@ public class MLeapModelParserTest {
input.put("imp_square_feet", 1.0);
input.put("imp_review_scores_rating", 1.0);
- double prediction = mdl.predict(input);
+ double prediction = mdl.predict(VectorUtils.of(input));
assertEquals(95.3919, prediction, 1e-5);
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/primitives/vector/NamedVector.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/primitives/vector/NamedVector.java
new file mode 100644
index 0000000..6d45eed
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/primitives/vector/NamedVector.java
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.math.primitives.vector;
+
+import java.util.Set;
+
+/**
+ * A named vector interface based on {@link Vector}. In addition to the base vector functionality, it allows getting
+ * and setting elements using names (string indexes) as indexes.
+ */
+public interface NamedVector extends Vector {
+ /**
+ * Returns element with specified string index; the interface does not define behavior for an unknown index.
+ *
+ * @param idx Element string index.
+ * @return Element value.
+ */
+ public double get(String idx);
+
+ /**
+ * Sets element with specified string index to the given value.
+ *
+ * @param idx Element string index.
+ * @param val Element value.
+ * @return This vector.
+ */
+ public NamedVector set(String idx, double val);
+
+ /**
+ * Returns the set of string indexes (names) used in this vector.
+ *
+ * @return Set of string indexes used in this vector.
+ */
+ public Set<String> getKeys();
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/primitives/vector/VectorUtils.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/primitives/vector/VectorUtils.java
index eaf7f91..9525f60 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/math/primitives/vector/VectorUtils.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/primitives/vector/VectorUtils.java
@@ -18,10 +18,13 @@
package org.apache.ignite.ml.math.primitives.vector;
import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
import java.util.Objects;
import org.apache.ignite.internal.util.typedef.internal.A;
import org.apache.ignite.ml.math.StorageConstants;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.math.primitives.vector.impl.DelegatingNamedVector;
import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
import org.apache.ignite.ml.math.primitives.vector.impl.SparseVector;
@@ -219,6 +222,28 @@ public class VectorUtils {
}
/**
+ * Creates named vector based on map of keys and values.
+ *
+ * @param values Values.
+ * @return Named vector.
+ */
+ public static NamedVector of(Map<String, Double> values) {
+ SparseVector vector = new SparseVector(values.size(), StorageConstants.RANDOM_ACCESS_MODE);
+ for (int i = 0; i < values.size(); i++)
+ vector.set(i, Double.NaN);
+
+ Map<String, Integer> dict = new HashMap<>();
+ int idx = 0;
+ for (Map.Entry<String, Double> e : values.entrySet()) {
+ dict.put(e.getKey(), idx);
+ vector.set(idx, e.getValue());
+ idx++;
+ }
+
+ return new DelegatingNamedVector(vector, dict);
+ }
+
+ /**
* Concatenates two given vectors.
*
* @param v1 First vector.
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/primitives/vector/impl/DelegatingNamedVector.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/primitives/vector/impl/DelegatingNamedVector.java
new file mode 100644
index 0000000..afc5f09
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/primitives/vector/impl/DelegatingNamedVector.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.math.primitives.vector.impl;
+
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Set;
+import org.apache.ignite.ml.math.primitives.vector.NamedVector;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+
+/**
+ * Delegating named vector that delegates all operations to underlying vector and adds implementation of
+ * {@link NamedVector} functionality using embedded map that maps string index on real integer index.
+ */
+public class DelegatingNamedVector extends DelegatingVector implements NamedVector {
+ /** */
+ private static final long serialVersionUID = -3425468245964928754L;
+
+ /** Map that maps string index on real integer index. */
+ // NOTE(review): not written by any serialization hook in this class; confirm names survive (de)serialization.
+ private final Map<String, Integer> map;
+
+ /** Constructs a new instance of delegating named vector without string indexes. */
+ public DelegatingNamedVector() {
+ this.map = Collections.emptyMap();
+ }
+
+ /**
+ * Constructs a new instance of delegating named vector.
+ *
+ * @param vector Underlying vector.
+ * @param map Map that maps string index on real integer index (copied, later changes to it are not seen).
+ */
+ public DelegatingNamedVector(Vector vector, Map<String, Integer> map) {
+ super(vector);
+
+ this.map = new HashMap<>(Objects.requireNonNull(map));
+ }
+
+ /** {@inheritDoc} */
+ @Override public double get(String idx) {
+ return get(toIntIndex(idx));
+ }
+
+ /** {@inheritDoc} */
+ @Override public NamedVector set(String idx, double val) {
+ set(toIntIndex(idx), val);
+
+ return this;
+ }
+
+ /** {@inheritDoc} The returned set is a read-only view. */
+ @Override public Set<String> getKeys() {
+ return Collections.unmodifiableSet(map.keySet());
+ }
+
+ /** Resolves string index into integer index; throws {@link NullPointerException} if it is unknown. */
+ private int toIntIndex(String idx) {
+ return Objects.requireNonNull(map.get(idx), () -> "Index not found [name='" + idx + "']");
+ }
+}
diff --git a/modules/ml/xgboost-model-parser/src/main/java/org/apache/ignite/ml/xgboost/XGModelComposition.java b/modules/ml/xgboost-model-parser/src/main/java/org/apache/ignite/ml/xgboost/XGModelComposition.java
index 8001fcd..89e83be 100644
--- a/modules/ml/xgboost-model-parser/src/main/java/org/apache/ignite/ml/xgboost/XGModelComposition.java
+++ b/modules/ml/xgboost-model-parser/src/main/java/org/apache/ignite/ml/xgboost/XGModelComposition.java
@@ -17,12 +17,14 @@
package org.apache.ignite.ml.xgboost;
+import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.ignite.ml.IgniteModel;
import org.apache.ignite.ml.composition.ModelsComposition;
import org.apache.ignite.ml.composition.predictionsaggregator.PredictionsAggregator;
+import org.apache.ignite.ml.math.primitives.vector.NamedVector;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.impl.SparseVector;
import org.apache.ignite.ml.tree.DecisionTreeNode;
@@ -32,7 +34,10 @@ import static org.apache.ignite.ml.math.StorageConstants.RANDOM_ACCESS_MODE;
/**
* XGBoost model composition.
*/
-public class XGModelComposition implements IgniteModel<HashMap<String, Double>, Double> {
+public class XGModelComposition implements IgniteModel<NamedVector, Double> {
+ /** */
+ private static final long serialVersionUID = 6765344479174942051L;
+
/** Dictionary used for matching feature names and indexes. */
private final Map<String, Integer> dict;
@@ -45,18 +50,18 @@ public class XGModelComposition implements IgniteModel<HashMap<String, Double>,
* @param models Basic models.
*/
public XGModelComposition(Map<String, Integer> dict, List<DecisionTreeNode> models) {
- this.dict = dict;
+ this.dict = new HashMap<>(dict);
this.modelsComposition = new ModelsComposition(models, new XGModelPredictionsAggregator());
}
/** {@inheritDoc} */
- @Override public Double predict(HashMap<String, Double> map) {
- return modelsComposition.predict(toVector(map));
+ @Override public Double predict(NamedVector input) {
+ return modelsComposition.predict(reencode(input));
}
/** */
public Map<String, Integer> getDict() {
- return dict;
+ return Collections.unmodifiableMap(dict);
}
/** */
@@ -72,20 +77,19 @@ public class XGModelComposition implements IgniteModel<HashMap<String, Double>,
/**
* Converts hash map into sparse vector using dictionary.
*
- * @param input Hash map with pairs of feature name and feature value.
+ * @param vector Named vector.
* @return Sparse vector.
*/
- private Vector toVector(Map<String, Double> input) {
+ private Vector reencode(NamedVector vector) {
Vector inputVector = new SparseVector(dict.size(), RANDOM_ACCESS_MODE);
for (int i = 0; i < dict.size(); i++)
inputVector.set(i, Double.NaN);
- for (Map.Entry<String, Double> feature : input.entrySet()) {
- Integer idx = dict.get(feature.getKey());
+ for (String key : vector.getKeys()) {
+ Integer idx = dict.get(key);
if (idx != null)
- inputVector.set(idx, feature.getValue());
-
+ inputVector.set(idx, vector.get(key));
}
return inputVector;
@@ -95,6 +99,9 @@ public class XGModelComposition implements IgniteModel<HashMap<String, Double>,
* XG model predictions aggregator.
*/
private static class XGModelPredictionsAggregator implements PredictionsAggregator {
+ /** */
+ private static final long serialVersionUID = 1274109586500815229L;
+
/** {@inheritDoc} */
@Override public Double apply(double[] predictions) {
double res = 0;
diff --git a/modules/ml/xgboost-model-parser/src/main/java/org/apache/ignite/ml/xgboost/parser/XGModelParser.java b/modules/ml/xgboost-model-parser/src/main/java/org/apache/ignite/ml/xgboost/parser/XGModelParser.java
index c99b1ae..782a13e 100644
--- a/modules/ml/xgboost-model-parser/src/main/java/org/apache/ignite/ml/xgboost/parser/XGModelParser.java
+++ b/modules/ml/xgboost-model-parser/src/main/java/org/apache/ignite/ml/xgboost/parser/XGModelParser.java
@@ -19,11 +19,11 @@ package org.apache.ignite.ml.xgboost.parser;
import java.io.ByteArrayInputStream;
import java.io.IOException;
-import java.util.HashMap;
import org.antlr.v4.runtime.CharStream;
import org.antlr.v4.runtime.CharStreams;
import org.antlr.v4.runtime.CommonTokenStream;
import org.apache.ignite.ml.inference.parser.ModelParser;
+import org.apache.ignite.ml.math.primitives.vector.NamedVector;
import org.apache.ignite.ml.xgboost.XGModelComposition;
import org.apache.ignite.ml.xgboost.parser.visitor.XGModelVisitor;
@@ -63,7 +63,7 @@ import org.apache.ignite.ml.xgboost.parser.visitor.XGModelVisitor;
* xgModel : xgTree+ ;
* </pre>
*/
-public class XGModelParser implements ModelParser<HashMap<String, Double>, Double, XGModelComposition> {
+public class XGModelParser implements ModelParser<NamedVector, Double, XGModelComposition> {
/** */
private static final long serialVersionUID = -5819843559270294718L;
diff --git a/modules/ml/xgboost-model-parser/src/test/java/org/apache/ignite/ml/xgboost/parser/XGBoostModelParserTest.java b/modules/ml/xgboost-model-parser/src/test/java/org/apache/ignite/ml/xgboost/parser/XGBoostModelParserTest.java
index 6c8a65f..7a930d6 100644
--- a/modules/ml/xgboost-model-parser/src/test/java/org/apache/ignite/ml/xgboost/parser/XGBoostModelParserTest.java
+++ b/modules/ml/xgboost-model-parser/src/test/java/org/apache/ignite/ml/xgboost/parser/XGBoostModelParserTest.java
@@ -24,6 +24,7 @@ import org.apache.ignite.ml.inference.builder.SingleModelBuilder;
import org.apache.ignite.ml.inference.builder.SyncModelBuilder;
import org.apache.ignite.ml.inference.reader.FileSystemModelReader;
import org.apache.ignite.ml.inference.reader.ModelReader;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.xgboost.XGModelComposition;
import org.junit.Test;
@@ -73,7 +74,7 @@ public class XGBoostModelParserTest {
testObj.put("f" + keyVal[0], Double.parseDouble(keyVal[1]));
}
- double prediction = mdl.predict(testObj);
+ double prediction = mdl.predict(VectorUtils.of(testObj));
double expPrediction = Double.parseDouble(testExpResultsStr);