You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by jb...@apache.org on 2016/09/30 00:02:55 UTC
[2/2] lucene-solr:branch_6x: SOLR-9258: Optimizing,
storing and deploying AI models with Streaming Expressions
SOLR-9258: Optimizing, storing and deploying AI models with Streaming Expressions
Project: http://git-wip-us.apache.org/repos/asf/lucene-solr/repo
Commit: http://git-wip-us.apache.org/repos/asf/lucene-solr/commit/9d1fb907
Tree: http://git-wip-us.apache.org/repos/asf/lucene-solr/tree/9d1fb907
Diff: http://git-wip-us.apache.org/repos/asf/lucene-solr/diff/9d1fb907
Branch: refs/heads/branch_6x
Commit: 9d1fb90784b02f03dbac2bc2604943864c00cfd0
Parents: f12f5e3
Author: Joel Bernstein <jb...@apache.org>
Authored: Thu Sep 29 17:44:20 2016 -0400
Committer: Joel Bernstein <jb...@apache.org>
Committed: Thu Sep 29 18:58:44 2016 -0400
----------------------------------------------------------------------
.../org/apache/solr/handler/ClassifyStream.java | 230 +++++++++++++++++++
.../org/apache/solr/handler/StreamHandler.java | 12 +-
.../apache/solr/client/solrj/io/ModelCache.java | 154 +++++++++++++
.../client/solrj/io/stream/ModelStream.java | 200 ++++++++++++++++
.../client/solrj/io/stream/StreamContext.java | 11 +
.../client/solrj/io/stream/TopicStream.java | 10 +-
.../solrj/solr/configsets/ml/conf/schema.xml | 2 +-
.../solrj/io/stream/StreamExpressionTest.java | 145 +++++++++++-
8 files changed, 758 insertions(+), 6 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/9d1fb907/solr/core/src/java/org/apache/solr/handler/ClassifyStream.java
----------------------------------------------------------------------
diff --git a/solr/core/src/java/org/apache/solr/handler/ClassifyStream.java b/solr/core/src/java/org/apache/solr/handler/ClassifyStream.java
new file mode 100644
index 0000000..6b0a02a
--- /dev/null
+++ b/solr/core/src/java/org/apache/solr/handler/ClassifyStream.java
@@ -0,0 +1,230 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.solr.handler;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Set;
+import java.util.ArrayList;
+import java.util.Map;
+import java.util.Locale;
+
+import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
+import org.apache.solr.client.solrj.io.Tuple;
+import org.apache.solr.client.solrj.io.comp.StreamComparator;
+import org.apache.solr.client.solrj.io.stream.StreamContext;
+import org.apache.solr.client.solrj.io.stream.TupleStream;
+import org.apache.solr.client.solrj.io.stream.expr.Explanation;
+import org.apache.solr.client.solrj.io.stream.expr.Expressible;
+import org.apache.solr.client.solrj.io.stream.expr.StreamExplanation;
+import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
+import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionNamedParameter;
+import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionParameter;
+import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
+import org.apache.solr.common.SolrException;
+import org.apache.solr.core.SolrCore;
+import org.apache.lucene.analysis.*;
+
+/**
+ * The classify expression retrieves a model trained by the train expression and uses it to
+ * classify the documents read from a stream.
+ *
+ * Syntax:
+ * classify(model(...), anyStream(...), field="body")
+ *
+ * The optional analyzerField parameter names the schema field whose index analyzer tokenizes the
+ * text; it defaults to the value of field. Every classified tuple is returned with the added
+ * fields probability_d and score_d.
+ **/
+public class ClassifyStream extends TupleStream implements Expressible {
+  private TupleStream docStream;
+  private TupleStream modelStream;
+
+  // Name of the tuple field that holds the text to classify.
+  private String field;
+  // Name of the schema field whose index analyzer is applied to the text.
+  private String analyzerField;
+  // The model tuple; stays null until a model has been read and unpacked successfully.
+  private Tuple modelTuple;
+
+  Analyzer analyzer;
+  // Position of each model term inside idfs (and, shifted by one, inside modelWeights).
+  private Map<CharSequence, Integer> termToIndex;
+  private List<Double> idfs;
+  private List<Double> modelWeights;
+
+  /**
+   * Builds the stream from a classify(...) expression.
+   *
+   * @param expression the expression; it must contain exactly two stream operands (the model
+   *                   stream first, the document stream second) and a "field" named parameter
+   * @param factory    the factory used to construct the two operand streams
+   * @throws IOException if the operand count is wrong or the "field" parameter is missing
+   */
+  public ClassifyStream(StreamExpression expression, StreamFactory factory) throws IOException {
+    List<StreamExpression> streamExpressions = factory.getExpressionOperandsRepresentingTypes(expression, Expressible.class, TupleStream.class);
+    if (streamExpressions.size() != 2) {
+      throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - expecting two stream but found %d",expression, streamExpressions.size()));
+    }
+
+    modelStream = factory.constructStream(streamExpressions.get(0));
+    docStream = factory.constructStream(streamExpressions.get(1));
+
+    StreamExpressionNamedParameter fieldParameter = factory.getNamedOperand(expression, "field");
+    if (fieldParameter == null) {
+      throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - field parameter must be specified",expression));
+    }
+    field = fieldParameter.getParameter().toString();
+
+    // Unless told otherwise, analyze the text with the analyzer of the field it came from.
+    StreamExpressionNamedParameter analyzerFieldParameter = factory.getNamedOperand(expression, "analyzerField");
+    analyzerField = analyzerFieldParameter != null ? analyzerFieldParameter.getParameter().toString() : field;
+  }
+
+  /**
+   * Looks up the index analyzer of analyzerField in the SolrCore stored under the "solr-core"
+   * key of the context, then passes the context on to both child streams.
+   *
+   * @throws SolrException if the context does not hold a SolrCore under "solr-core"
+   */
+  @Override
+  public void setStreamContext(StreamContext context) {
+    Object solrCoreObj = context.get("solr-core");
+    // instanceof is false for null, so a missing entry is rejected here as well.
+    if (!(solrCoreObj instanceof SolrCore)) {
+      throw new SolrException(SolrException.ErrorCode.INVALID_STATE, "StreamContext must have SolrCore in solr-core key");
+    }
+    SolrCore solrCore = (SolrCore) solrCoreObj;
+    analyzer = solrCore.getLatestSchema().getFieldType(analyzerField).getIndexAnalyzer();
+
+    this.docStream.setStreamContext(context);
+    this.modelStream.setStreamContext(context);
+  }
+
+  /** Returns the document stream followed by the model stream. */
+  @Override
+  public List<TupleStream> children() {
+    List<TupleStream> l = new ArrayList<>();
+    l.add(docStream);
+    l.add(modelStream);
+    return l;
+  }
+
+  @Override
+  public void open() throws IOException {
+    this.docStream.open();
+    this.modelStream.open();
+  }
+
+  @Override
+  public void close() throws IOException {
+    // Close the model stream even when closing the document stream fails.
+    try {
+      this.docStream.close();
+    } finally {
+      this.modelStream.close();
+    }
+  }
+
+  /**
+   * Reads the next document tuple and adds the classification result to it.
+   *
+   * The model is read from the model stream on the first call. The EOF tuple of the document
+   * stream is passed through untouched.
+   *
+   * @return the document tuple with probability_d and score_d added, or the EOF tuple
+   * @throws IOException if no model tuple is available or the tuple lacks the text field
+   */
+  @Override
+  public Tuple read() throws IOException {
+    if (modelTuple == null) {
+      loadModel();
+    }
+
+    Tuple docTuple = docStream.read();
+    if (docTuple.EOF) {
+      return docTuple;
+    }
+
+    Object text = docTuple.get(field);
+    if (text == null) {
+      throw new IOException(String.format(Locale.ROOT, "Field '%s' was not found in the tuple - nothing to classify", field));
+    }
+
+    // Raw count of every model term in the analyzed text; termCount counts all emitted tokens.
+    double[] tfs = new double[termToIndex.size()];
+    int termCount = 0;
+    try (TokenStream tokenStream = analyzer.tokenStream(analyzerField, text.toString())) {
+      CharTermAttribute termAtt = tokenStream.getAttribute(CharTermAttribute.class);
+      tokenStream.reset();
+      while (tokenStream.incrementToken()) {
+        termCount++;
+        Integer index = termToIndex.get(termAtt.toString());
+        if (index != null) {
+          tfs[index]++;
+        }
+      }
+      tokenStream.end();
+    }
+
+    // Slot 0 is a constant 1.0 input paired with modelWeights[0]; slot i+1 holds
+    // idf * (1 + ln(count)) for model term i, or 0 when the term does not occur.
+    double[] features = new double[tfs.length + 1];
+    features[0] = 1.0;
+    for (int i = 0; i < tfs.length; i++) {
+      double tf = tfs[i];
+      if (tf != 0) {
+        tf = 1 + Math.log(tf);
+      }
+      features[i + 1] = this.idfs.get(i) * tf;
+    }
+
+    double total = 0.0;
+    for (int i = 0; i < features.length; i++) {
+      total += features[i] * modelWeights.get(i);
+    }
+
+    // Without any token 1/sqrt(0) is infinite and the score would be Infinity or NaN.
+    double score = termCount == 0 ? 0.0 : total * ((float) (1.0 / Math.sqrt(termCount)));
+    double positiveProb = sigmoid(total);
+
+    docTuple.put("probability_d", positiveProb);
+    docTuple.put("score_d", score);
+
+    return docTuple;
+  }
+
+  /**
+   * Reads the model tuple and unpacks its terms_ss, idfs_ds and weights_ds fields. modelTuple is
+   * assigned last, so a failed load is attempted again by the next read() instead of leaving
+   * the stream half initialized.
+   */
+  private void loadModel() throws IOException {
+    Tuple tuple = modelStream.read();
+    if (tuple == null || tuple.EOF) {
+      throw new IOException("Model tuple not found for classify stream!");
+    }
+
+    List<String> terms = tuple.getStrings("terms_ss");
+    Map<CharSequence, Integer> positions = new HashMap<>();
+    for (int i = 0; i < terms.size(); i++) {
+      positions.put(terms.get(i), i);
+    }
+
+    termToIndex = positions;
+    idfs = tuple.getDoubles("idfs_ds");
+    modelWeights = tuple.getDoubles("weights_ds");
+    modelTuple = tuple;
+  }
+
+  /** Logistic function 1 / (1 + e^-in). */
+  private double sigmoid(double in) {
+    return 1.0 / (1 + Math.exp(-in));
+  }
+
+  /** This stream does not declare a sort order. */
+  @Override
+  public StreamComparator getStreamSort() {
+    return null;
+  }
+
+  @Override
+  public StreamExpressionParameter toExpression(StreamFactory factory) throws IOException {
+    return toExpression(factory, true);
+  }
+
+  /**
+   * Renders this stream as an expression.
+   *
+   * @param includeStreams whether the model and document streams are rendered as operands
+   * @throws IOException if includeStreams is set and a child stream is not Expressible
+   */
+  private StreamExpression toExpression(StreamFactory factory, boolean includeStreams) throws IOException {
+    // function name
+    StreamExpression expression = new StreamExpression(factory.getFunctionName(this.getClass()));
+
+    if (includeStreams) {
+      if (docStream instanceof Expressible && modelStream instanceof Expressible) {
+        // Same operand order as the constructor expects: model first, documents second.
+        expression.addParameter(((Expressible)modelStream).toExpression(factory));
+        expression.addParameter(((Expressible)docStream).toExpression(factory));
+      } else {
+        throw new IOException("This ClassifyStream contains a non-expressible TupleStream - it cannot be converted to an expression");
+      }
+    }
+
+    expression.addParameter(new StreamExpressionNamedParameter("field", field));
+    expression.addParameter(new StreamExpressionNamedParameter("analyzerField", analyzerField));
+
+    return expression;
+  }
+
+  @Override
+  public Explanation toExplanation(StreamFactory factory) throws IOException {
+    StreamExplanation explanation = new StreamExplanation(getStreamNodeId().toString());
+
+    explanation.setFunctionName(factory.getFunctionName(this.getClass()));
+    explanation.setImplementingClass(this.getClass().getName());
+    explanation.setExpressionType(Explanation.ExpressionType.STREAM_DECORATOR);
+    // The child streams are attached as child explanations, so leave them out of the expression.
+    explanation.setExpression(toExpression(factory, false).toString());
+
+    explanation.addChild(docStream.toExplanation(factory));
+    explanation.addChild(modelStream.toExplanation(factory));
+
+    return explanation;
+  }
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/9d1fb907/solr/core/src/java/org/apache/solr/handler/StreamHandler.java
----------------------------------------------------------------------
diff --git a/solr/core/src/java/org/apache/solr/handler/StreamHandler.java b/solr/core/src/java/org/apache/solr/handler/StreamHandler.java
index a88ee96..dfacc1e 100644
--- a/solr/core/src/java/org/apache/solr/handler/StreamHandler.java
+++ b/solr/core/src/java/org/apache/solr/handler/StreamHandler.java
@@ -25,6 +25,7 @@ import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
+import org.apache.solr.client.solrj.io.ModelCache;
import org.apache.solr.client.solrj.io.SolrClientCache;
import org.apache.solr.client.solrj.io.Tuple;
import org.apache.solr.client.solrj.io.comp.StreamComparator;
@@ -36,10 +37,10 @@ import org.apache.solr.client.solrj.io.ops.GroupOperation;
import org.apache.solr.client.solrj.io.ops.ReplaceOperation;
import org.apache.solr.client.solrj.io.stream.*;
import org.apache.solr.client.solrj.io.stream.expr.Explanation;
+import org.apache.solr.client.solrj.io.stream.expr.Explanation.ExpressionType;
import org.apache.solr.client.solrj.io.stream.expr.Expressible;
import org.apache.solr.client.solrj.io.stream.expr.StreamExplanation;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
-import org.apache.solr.client.solrj.io.stream.expr.Explanation.ExpressionType;
import org.apache.solr.client.solrj.io.stream.metrics.CountMetric;
import org.apache.solr.client.solrj.io.stream.metrics.MaxMetric;
import org.apache.solr.client.solrj.io.stream.metrics.MeanMetric;
@@ -64,6 +65,7 @@ import org.slf4j.LoggerFactory;
public class StreamHandler extends RequestHandlerBase implements SolrCoreAware, PermissionNameProvider {
static SolrClientCache clientCache = new SolrClientCache();
+ static ModelCache modelCache = null;
private StreamFactory streamFactory = new StreamFactory();
private static final Logger logger = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());
private String coreName;
@@ -96,6 +98,9 @@ public class StreamHandler extends RequestHandlerBase implements SolrCoreAware,
defaultZkhost = core.getCoreDescriptor().getCoreContainer().getZkController().getZkServerAddress();
streamFactory.withCollectionZkHost(defaultCollection, defaultZkhost);
streamFactory.withDefaultZkHost(defaultZkhost);
+ modelCache = new ModelCache(250,
+ defaultZkhost,
+ clientCache);
}
streamFactory
@@ -130,6 +135,8 @@ public class StreamHandler extends RequestHandlerBase implements SolrCoreAware,
.withFunctionName("shortestPath", ShortestPathStream.class)
.withFunctionName("gatherNodes", GatherNodesStream.class)
.withFunctionName("scoreNodes", ScoreNodesStream.class)
+ .withFunctionName("model", ModelStream.class)
+ .withFunctionName("classify", ClassifyStream.class)
// metrics
.withFunctionName("min", MinMetric.class)
@@ -197,7 +204,9 @@ public class StreamHandler extends RequestHandlerBase implements SolrCoreAware,
context.workerID = worker;
context.numWorkers = numWorkers;
context.setSolrClientCache(clientCache);
+ context.setModelCache(modelCache);
context.put("core", this.coreName);
+ context.put("solr-core", req.getCore());
tupleStream.setStreamContext(context);
// if asking for explanation then go get it
@@ -454,5 +463,4 @@ public class StreamHandler extends RequestHandlerBase implements SolrCoreAware,
return tuple;
}
}
-
}
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/9d1fb907/solr/solrj/src/java/org/apache/solr/client/solrj/io/ModelCache.java
----------------------------------------------------------------------
diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/ModelCache.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/ModelCache.java
new file mode 100644
index 0000000..4fe3d8a
--- /dev/null
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/ModelCache.java
@@ -0,0 +1,154 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.solr.client.solrj.io;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.lang.invoke.MethodHandles;
+import java.util.Collections;
+import java.util.Date;
+import java.util.LinkedHashMap;
+import java.util.Map;
+import java.util.HashMap;
+
+import org.apache.solr.client.solrj.SolrClient;
+import org.apache.solr.client.solrj.impl.CloudSolrClient;
+import org.apache.solr.client.solrj.impl.HttpSolrClient;
+import org.apache.solr.client.solrj.io.stream.CloudSolrStream;
+import org.apache.solr.client.solrj.io.stream.StreamContext;
+import org.apache.solr.client.solrj.io.stream.TopicStream;
+import org.apache.solr.common.params.ModifiableSolrParams;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+
+/**
+ * The Model cache keeps a local in-memory copy of models.
+ *
+ * Entries are evicted in least-recently-used order once more than the configured number of
+ * models is held. Every access to the underlying map is synchronized on this instance; the
+ * fetch of a model from Solr runs outside of that lock.
+ */
+public class ModelCache implements Serializable {
+
+  private final LRU models;
+  private final String defaultZkHost;
+  private final SolrClientCache solrClientCache;
+
+  /**
+   * @param size            maximum number of models kept in memory
+   * @param defaultZkHost   ZooKeeper host used by the getModel variant that takes no zkHost
+   * @param solrClientCache client cache handed to the streams that fetch models
+   */
+  public ModelCache(int size,
+                    String defaultZkHost,
+                    SolrClientCache solrClientCache) {
+    this.models = new LRU(size);
+    this.defaultZkHost = defaultZkHost;
+    this.solrClientCache = solrClientCache;
+  }
+
+  /** Same as the four argument variant, using the default ZooKeeper host of this cache. */
+  public Tuple getModel(String collection,
+                        String modelID,
+                        long checkMillis) throws IOException {
+    return getModel(defaultZkHost, collection, modelID, checkMillis);
+  }
+
+  /**
+   * Returns the model with the given ID, from the cache when the cached copy was checked no
+   * more than checkMillis milliseconds ago, otherwise freshly fetched from the collection.
+   *
+   * @param zkHost      ZooKeeper host of the cluster that holds the collection
+   * @param collection  collection the model is stored in
+   * @param modelID     value of the name_s field of the model
+   * @param checkMillis maximum age in milliseconds of a cached copy
+   * @return the model tuple, or null if the collection returned no matching tuple
+   * @throws IOException if fetching the model fails
+   */
+  public Tuple getModel(String zkHost,
+                        String collection,
+                        String modelID,
+                        long checkMillis) throws IOException {
+    // Models with the same ID may live in different collections or clusters.
+    String key = cacheKey(zkHost, collection, modelID);
+    long currentTime = System.currentTimeMillis();
+
+    synchronized (this) {
+      Model cached = models.get(key);
+      if (cached != null) {
+        if (currentTime - cached.getLastChecked() <= checkMillis) {
+          return cached.getTuple();
+        }
+        //model is expired
+        models.remove(key);
+      }
+    }
+
+    //Model is not in cache or has expired so fetch the model
+    Tuple fetched = fetchModel(zkHost, collection, modelID);
+    if (fetched == null) {
+      return null;
+    }
+
+    synchronized (this) {
+      //check again to see if another thread has updated the same model
+      Model current = models.get(key);
+      if (current != null) {
+        long cachedVersion = current.getTuple().getLong("_version_");
+        long fetchedVersion = fetched.getLong("_version_");
+        if (cachedVersion >= fetchedVersion) {
+          return current.getTuple();
+        }
+      }
+      models.put(key, new Model(fetched, currentTime));
+      return fetched;
+    }
+  }
+
+  /** Builds the cache key that identifies a model of one collection on one cluster. */
+  private static String cacheKey(String zkHost, String collection, String modelID) {
+    return zkHost + '|' + collection + '|' + modelID;
+  }
+
+  /**
+   * Reads the first tuple of a query for name_s:modelID sorted by iteration_i descending.
+   *
+   * @return the tuple, or null if the first tuple read is the EOF tuple
+   */
+  private Tuple fetchModel(String zkHost, String collection, String modelID) throws IOException {
+    ModifiableSolrParams params = new ModifiableSolrParams();
+    // NOTE(review): modelID goes into the query unescaped - confirm that callers only pass
+    // IDs without query syntax characters.
+    params.set("q","name_s:"+modelID);
+    params.set("fl", "terms_ss, idfs_ds, weights_ds, iteration_i, _version_");
+    params.set("sort", "iteration_i desc");
+
+    StreamContext streamContext = new StreamContext();
+    streamContext.setSolrClientCache(solrClientCache);
+    CloudSolrStream stream = new CloudSolrStream(zkHost, collection, params);
+    stream.setStreamContext(streamContext);
+    try {
+      stream.open();
+      Tuple tuple = stream.read();
+      return tuple.EOF ? null : tuple;
+    } finally {
+      stream.close();
+    }
+  }
+
+  /** A cached model tuple together with the time it was fetched. */
+  private static class Model {
+    private final Tuple tuple;
+    private final long lastChecked;
+
+    public Model(Tuple tuple, long lastChecked) {
+      this.tuple = tuple;
+      this.lastChecked = lastChecked;
+    }
+
+    public Tuple getTuple() {
+      return tuple;
+    }
+
+    public long getLastChecked() {
+      return lastChecked;
+    }
+  }
+
+  /** Map that drops its least recently accessed entry once it holds more than maxSize entries. */
+  private static class LRU extends LinkedHashMap<String, Model> {
+
+    private final int maxSize;
+
+    public LRU(int maxSize) {
+      // accessOrder=true: the default insertion order would evict the oldest insert instead of
+      // the least recently used entry. 16 and 0.75f are the LinkedHashMap defaults.
+      super(16, 0.75f, true);
+      this.maxSize = maxSize;
+    }
+
+    @Override
+    protected boolean removeEldestEntry(Map.Entry<String, Model> eldest) {
+      return size() > maxSize;
+    }
+  }
+}
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/9d1fb907/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/ModelStream.java
----------------------------------------------------------------------
diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/ModelStream.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/ModelStream.java
new file mode 100644
index 0000000..70b740d
--- /dev/null
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/ModelStream.java
@@ -0,0 +1,200 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.solr.client.solrj.io.stream;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Locale;
+import java.util.Map;
+
+import org.apache.solr.client.solrj.io.ModelCache;
+import org.apache.solr.client.solrj.io.SolrClientCache;
+import org.apache.solr.client.solrj.io.Tuple;
+import org.apache.solr.client.solrj.io.comp.StreamComparator;
+import org.apache.solr.client.solrj.io.stream.expr.Explanation;
+import org.apache.solr.client.solrj.io.stream.expr.Expressible;
+import org.apache.solr.client.solrj.io.stream.expr.StreamExplanation;
+import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
+import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionNamedParameter;
+import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionParameter;
+import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionValue;
+import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
+
+/**
+* The ModelStream retrieves a stored model from a Solr Cloud collection.
+*
+* Syntax: model(collection, id="modelID")
+*
+* Optional named parameters: cacheMillis (maximum age in milliseconds of a cached model,
+* default 300000) and zkHost.
+**/
+public class ModelStream extends TupleStream implements Expressible {
+
+  private static final long serialVersionUID = 1;
+
+  protected String zkHost;
+  protected String collection;
+  protected String modelID;
+  protected ModelCache modelCache;
+  protected SolrClientCache solrClientCache;
+  // The model fetched by open(); handed out once by read() and then reset to null.
+  protected Tuple model;
+  protected long cacheMillis;
+
+  /**
+   * @param zkHost         ZooKeeper host of the cluster that holds the collection
+   * @param collectionName collection the model is stored in
+   * @param modelID        ID of the model to retrieve
+   * @param cacheMillis    maximum age in milliseconds of a cached copy of the model
+   */
+  public ModelStream(String zkHost,
+                     String collectionName,
+                     String modelID,
+                     long cacheMillis) throws IOException {
+
+    init(collectionName, zkHost, modelID, cacheMillis);
+  }
+
+
+  /**
+   * Builds the stream from a model(...) expression.
+   *
+   * @throws IOException if the collection name or the id parameter is missing, cacheMillis is
+   *                     not a number, or no zkHost can be resolved for the collection
+   */
+  public ModelStream(StreamExpression expression, StreamFactory factory) throws IOException{
+    // grab all parameters out
+    String collectionName = factory.getValueOperand(expression, 0);
+    List<StreamExpressionNamedParameter> namedParams = factory.getNamedOperands(expression);
+    StreamExpressionNamedParameter zkHostExpression = factory.getNamedOperand(expression, "zkHost");
+
+    // Collection Name
+    if(null == collectionName){
+      throw new IOException(String.format(Locale.ROOT,"invalid expression %s - collectionName expected as first operand",expression));
+    }
+
+    // Named parameters - at least the id of the model is needed
+    if(0 == namedParams.size()){
+      throw new IOException(String.format(Locale.ROOT,"invalid expression %s - at least one named parameter expected. eg. 'id=modelID'",expression));
+    }
+
+    Map<String,String> params = new HashMap<String,String>();
+    for(StreamExpressionNamedParameter namedParam : namedParams){
+      if(!namedParam.getName().equals("zkHost")) {
+        params.put(namedParam.getName(), namedParam.getParameter().toString().trim());
+      }
+    }
+
+    String modelID = params.get("id");
+    if (modelID == null) {
+      throw new IOException("id param cannot be null for ModelStream");
+    }
+
+    long cacheMillis = 300000;
+    String cacheMillisParam = params.get("cacheMillis");
+
+    if(cacheMillisParam != null) {
+      try {
+        cacheMillis = Long.parseLong(cacheMillisParam);
+      } catch (NumberFormatException e) {
+        // Report a malformed value like the other invalid parameters of the expression.
+        throw new IOException(String.format(Locale.ROOT,"invalid expression %s - cacheMillis '%s' is not a valid long",expression,cacheMillisParam), e);
+      }
+    }
+
+    // zkHost, optional - if not provided then will look into factory list to get
+    String zkHost = null;
+    if(null == zkHostExpression) {
+      zkHost = factory.getCollectionZkHost(collectionName);
+    }
+    else if(zkHostExpression.getParameter() instanceof StreamExpressionValue){
+      zkHost = ((StreamExpressionValue)zkHostExpression.getParameter()).getValue();
+    }
+
+    if (zkHost == null) {
+      zkHost = factory.getDefaultZkHost();
+    }
+
+    if(null == zkHost){
+      throw new IOException(String.format(Locale.ROOT,"invalid expression %s - zkHost not found for collection '%s'",expression,collectionName));
+    }
+
+    // We've got all the required items
+    init(collectionName, zkHost, modelID, cacheMillis);
+  }
+
+  @Override
+  public StreamExpressionParameter toExpression(StreamFactory factory) throws IOException {
+    return toExpression(factory, true);
+  }
+
+  /** Renders this stream as an expression; includeStreams is unused, there are no child streams. */
+  private StreamExpression toExpression(StreamFactory factory, boolean includeStreams) throws IOException {
+    // function name
+    StreamExpression expression = new StreamExpression(factory.getFunctionName(this.getClass()));
+    // collection
+    expression.addParameter(collection);
+
+    // zkHost
+    expression.addParameter(new StreamExpressionNamedParameter("zkHost", zkHost));
+    expression.addParameter(new StreamExpressionNamedParameter("id", modelID));
+    expression.addParameter(new StreamExpressionNamedParameter("cacheMillis", Long.toString(cacheMillis)));
+
+    return expression;
+  }
+
+  private void init(String collectionName,
+                    String zkHost,
+                    String modelID,
+                    long cacheMillis) {
+    this.zkHost = zkHost;
+    this.collection = collectionName;
+    this.modelID = modelID;
+    this.cacheMillis = cacheMillis;
+  }
+
+  /** Takes the client cache and the model cache from the context. */
+  @Override
+  public void setStreamContext(StreamContext context) {
+    this.solrClientCache = context.getSolrClientCache();
+    this.modelCache = context.getModelCache();
+  }
+
+  /**
+   * Fetches the model through the model cache of the stream context.
+   *
+   * @throws IOException if the context provided no model cache or the fetch fails
+   */
+  @Override
+  public void open() throws IOException {
+    if (modelCache == null) {
+      throw new IOException("ModelCache not found in the StreamContext - cannot load model '" + modelID + "'");
+    }
+
+    // Honor the zkHost resolved for this stream; only a stream constructed without one falls
+    // back to the default host of the cache.
+    if (zkHost != null) {
+      this.model = modelCache.getModel(zkHost, collection, modelID, cacheMillis);
+    } else {
+      this.model = modelCache.getModel(collection, modelID, cacheMillis);
+    }
+  }
+
+  /** This stream has no child streams. */
+  @Override
+  public List<TupleStream> children() {
+    return new ArrayList<>();
+  }
+
+  /** Nothing to release; the stream holds no open resources of its own. */
+  @Override
+  public void close() throws IOException {
+
+  }
+
+  /** Return the stream sort - ie, the order in which records are returned */
+  @Override
+  public StreamComparator getStreamSort(){
+    return null;
+  }
+
+  @Override
+  public Explanation toExplanation(StreamFactory factory) throws IOException {
+    StreamExplanation explanation = new StreamExplanation(getStreamNodeId().toString());
+    explanation.setFunctionName(factory.getFunctionName(this.getClass()));
+    explanation.setImplementingClass(this.getClass().getName());
+    explanation.setExpressionType(Explanation.ExpressionType.MACHINE_LEARNING_MODEL);
+    explanation.setExpression(toExpression(factory).toString());
+
+    return explanation;
+  }
+
+  /**
+   * Returns the model tuple on the first call after a successful open(), and an EOF tuple on
+   * every other call (including when open() found no model).
+   */
+  @Override
+  public Tuple read() throws IOException {
+    if (model != null) {
+      Tuple tuple = model;
+      model = null;
+      return tuple;
+    }
+
+    Map<String, Object> eof = new HashMap<>();
+    eof.put("EOF", true);
+    return new Tuple(eof);
+  }
+}
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/9d1fb907/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/StreamContext.java
----------------------------------------------------------------------
diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/StreamContext.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/StreamContext.java
index 8ca808f..6cbf090 100644
--- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/StreamContext.java
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/StreamContext.java
@@ -19,6 +19,8 @@ package org.apache.solr.client.solrj.io.stream;
import java.io.Serializable;
import java.util.Map;
import java.util.HashMap;
+
+import org.apache.solr.client.solrj.io.ModelCache;
import org.apache.solr.client.solrj.io.SolrClientCache;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
@@ -37,6 +39,7 @@ public class StreamContext implements Serializable{
public int workerID;
public int numWorkers;
private SolrClientCache clientCache;
+ private ModelCache modelCache;
private StreamFactory streamFactory;
public Object get(Object key) {
@@ -55,10 +58,18 @@ public class StreamContext implements Serializable{
this.clientCache = clientCache;
}
+ public void setModelCache(ModelCache modelCache) {
+ this.modelCache = modelCache;
+ }
+
public SolrClientCache getSolrClientCache() {
return this.clientCache;
}
+ public ModelCache getModelCache() {
+ return this.modelCache;
+ }
+
public void setStreamFactory(StreamFactory streamFactory) {
this.streamFactory = streamFactory;
}
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/9d1fb907/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/TopicStream.java
----------------------------------------------------------------------
diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/TopicStream.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/TopicStream.java
index 97317a0..d81391d 100644
--- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/TopicStream.java
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/TopicStream.java
@@ -72,6 +72,7 @@ public class TopicStream extends CloudSolrStream implements Expressible {
private long count;
private int runCount;
+ private boolean initialRun = true;
private String id;
protected long checkpointEvery;
private Map<String, Long> checkpoints = new HashMap<String, Long>();
@@ -350,9 +351,14 @@ public class TopicStream extends CloudSolrStream implements Expressible {
}
public void close() throws IOException {
- runCount = 0;
try {
- persistCheckpoints();
+
+ if (initialRun || runCount > 0) {
+ persistCheckpoints();
+ initialRun = false;
+ runCount = 0;
+ }
+
} finally {
if(solrStreams != null) {
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/9d1fb907/solr/solrj/src/test-files/solrj/solr/configsets/ml/conf/schema.xml
----------------------------------------------------------------------
diff --git a/solr/solrj/src/test-files/solrj/solr/configsets/ml/conf/schema.xml b/solr/solrj/src/test-files/solrj/solr/configsets/ml/conf/schema.xml
index 3206811..c70b9fd 100644
--- a/solr/solrj/src/test-files/solrj/solr/configsets/ml/conf/schema.xml
+++ b/solr/solrj/src/test-files/solrj/solr/configsets/ml/conf/schema.xml
@@ -50,7 +50,7 @@
<field name="id" type="string" indexed="true" stored="true" multiValued="false" required="false"/>
- <field name="_version_" type="long" indexed="true" stored="true"/>
+ <field name="_version_" type="long" indexed="true" stored="true" docValues="true"/>
<!-- Dynamic field definitions. If a field name is not found, dynamicFields
will be used if the name matches any of the patterns.
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/9d1fb907/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java
----------------------------------------------------------------------
diff --git a/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java b/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java
index 7c3a3a6..87fc951 100644
--- a/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java
+++ b/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java
@@ -19,6 +19,7 @@ package org.apache.solr.client.solrj.io.stream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
+import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
@@ -3072,6 +3073,148 @@ public class StreamExpressionTest extends SolrCloudTestCase {
}
+ // End-to-end test of classify(): trains a model with train(), stores it in modelCollection via update(),
+ // and classifies the topic() output of uknownCollection, first directly and then through parallel().
+ // The model is retrained with flipped outcomes to check that it is picked up once the 5000 ms cache expires.
+ @Test
+ public void testClassifyStream() throws Exception {
+ CollectionAdminRequest.createCollection("modelCollection", "ml", 2, 1).process(cluster.getSolrClient());
+ AbstractDistribZkTestBase.waitForRecoveriesToFinish("modelCollection", cluster.getSolrClient().getZkStateReader(),
+ false, true, TIMEOUT);
+ CollectionAdminRequest.createCollection("uknownCollection", "ml", 2, 1).process(cluster.getSolrClient());
+ AbstractDistribZkTestBase.waitForRecoveriesToFinish("uknownCollection", cluster.getSolrClient().getZkStateReader(),
+ false, true, TIMEOUT);
+ CollectionAdminRequest.createCollection("checkpointCollection", "ml", 2, 1).process(cluster.getSolrClient());
+ AbstractDistribZkTestBase.waitForRecoveriesToFinish("checkpointCollection", cluster.getSolrClient().getZkStateReader(),
+ false, true, TIMEOUT);
+
+ UpdateRequest updateRequest = new UpdateRequest();
+
+ for (int i = 0; i < 500; i+=2) {
+ updateRequest.add(id, String.valueOf(i), "tv_text", "a b c c d", "out_i", "1");
+ updateRequest.add(id, String.valueOf(i+1), "tv_text", "a b e e f", "out_i", "0");
+ }
+
+ updateRequest.commit(cluster.getSolrClient(), COLLECTION);
+
+ updateRequest = new UpdateRequest();
+ updateRequest.add(id, String.valueOf(0), "text_s", "a b c c d");
+ updateRequest.add(id, String.valueOf(1), "text_s", "a b e e f");
+ updateRequest.commit(cluster.getSolrClient(), "uknownCollection");
+
+ String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString() + "/" + COLLECTION;
+ TupleStream updateTrainModelStream;
+ ModifiableSolrParams paramsLoc;
+
+ StreamFactory factory = new StreamFactory()
+ .withCollectionZkHost("collection1", cluster.getZkServer().getZkAddress())
+ .withCollectionZkHost("modelCollection", cluster.getZkServer().getZkAddress())
+ .withCollectionZkHost("uknownCollection", cluster.getZkServer().getZkAddress())
+ .withFunctionName("features", FeaturesSelectionStream.class)
+ .withFunctionName("train", TextLogitStream.class)
+ .withFunctionName("search", CloudSolrStream.class)
+ .withFunctionName("update", UpdateStream.class);
+
+ // Train the model and store it in modelCollection
+ String textLogitExpression = "train(" +
+ "collection1, " +
+ "features(collection1, q=\"*:*\", featureSet=\"first\", field=\"tv_text\", outcome=\"out_i\", numTerms=4),"+
+ "q=\"*:*\", " +
+ "name=\"model\", " +
+ "field=\"tv_text\", " +
+ "outcome=\"out_i\", " +
+ "maxIterations=100)";
+ updateTrainModelStream = factory.constructStream("update(modelCollection, batchSize=5, "+textLogitExpression+")");
+ getTuples(updateTrainModelStream);
+ cluster.getSolrClient().commit("modelCollection");
+
+ // Classify the unknown documents
+ String expr = "classify(" +
+ "model(modelCollection, id=\"model\", cacheMillis=5000)," +
+ "topic(checkpointCollection, uknownCollection, q=\"*:*\", fl=\"text_s, id\", id=\"1000000\", initialCheckpoint=\"0\")," +
+ "field=\"text_s\"," +
+ "analyzerField=\"tv_text\")";
+
+ paramsLoc = new ModifiableSolrParams();
+ paramsLoc.set("expr", expr);
+ paramsLoc.set("qt","/stream");
+ SolrStream classifyStream = new SolrStream(url, paramsLoc);
+ Map<String, Double> idToLabel = getIdToLabel(classifyStream, "probability_d");
+ assertEquals(2, idToLabel.size());
+ assertEquals(1.0, idToLabel.get("0"), 0.001);
+ assertEquals(0, idToLabel.get("1"), 0.001);
+
+ // Add more documents and classify them
+ updateRequest = new UpdateRequest();
+ updateRequest.add(id, String.valueOf(2), "text_s", "a b c c d");
+ updateRequest.add(id, String.valueOf(3), "text_s", "a b e e f");
+ updateRequest.commit(cluster.getSolrClient(), "uknownCollection");
+
+ classifyStream = new SolrStream(url, paramsLoc);
+ idToLabel = getIdToLabel(classifyStream, "probability_d");
+ assertEquals(2, idToLabel.size());
+ assertEquals(1.0, idToLabel.get("2"), 0.001);
+ assertEquals(0, idToLabel.get("3"), 0.001);
+
+ // Train another model, this time with the outcomes flipped
+ updateRequest = new UpdateRequest();
+ updateRequest.deleteByQuery("*:*");
+ updateRequest.commit(cluster.getSolrClient(), COLLECTION);
+
+ updateRequest = new UpdateRequest();
+ for (int i = 0; i < 500; i+=2) {
+ updateRequest.add(id, String.valueOf(i), "tv_text", "a b c c d", "out_i", "0");
+ updateRequest.add(id, String.valueOf(i+1), "tv_text", "a b e e f", "out_i", "1");
+ }
+ updateRequest.commit(cluster.getSolrClient(), COLLECTION);
+ updateTrainModelStream = factory.constructStream("update(modelCollection, batchSize=5, "+textLogitExpression+")");
+ getTuples(updateTrainModelStream);
+ cluster.getSolrClient().commit("modelCollection");
+
+ // Add more documents and classify them
+ updateRequest = new UpdateRequest();
+ updateRequest.add(id, String.valueOf(4), "text_s", "a b c c d");
+ updateRequest.add(id, String.valueOf(5), "text_s", "a b e e f");
+ updateRequest.commit(cluster.getSolrClient(), "uknownCollection");
+
+ // Sleep just over 5 seconds so the model cache (cacheMillis=5000) expires
+ Thread.sleep(5100);
+
+ classifyStream = new SolrStream(url, paramsLoc);
+ idToLabel = getIdToLabel(classifyStream, "probability_d");
+ assertEquals(2, idToLabel.size());
+ assertEquals(0, idToLabel.get("4"), 0.001);
+ assertEquals(1.0, idToLabel.get("5"), 0.001);
+
+ // Classify in parallel: a second topic (id 2000000) re-reads documents 4 and 5,
+ // partitioned on id across 2 workers
+ expr = "parallel(collection1, workers=2, sort=\"_version_ asc\", classify(" +
+ "model(modelCollection, id=\"model\")," +
+ "topic(checkpointCollection, uknownCollection, q=\"id:(4 5)\", fl=\"text_s, id, _version_\", id=\"2000000\", partitionKeys=\"id\", initialCheckpoint=\"0\")," +
+ "field=\"text_s\"," +
+ "analyzerField=\"tv_text\"))";
+
+ paramsLoc.set("expr", expr);
+ classifyStream = new SolrStream(url, paramsLoc);
+ idToLabel = getIdToLabel(classifyStream, "probability_d");
+ assertEquals(2, idToLabel.size());
+ assertEquals(0, idToLabel.get("4"), 0.001);
+ assertEquals(1.0, idToLabel.get("5"), 0.001);
+
+ CollectionAdminRequest.deleteCollection("modelCollection").process(cluster.getSolrClient());
+ CollectionAdminRequest.deleteCollection("uknownCollection").process(cluster.getSolrClient());
+ CollectionAdminRequest.deleteCollection("checkpointCollection").process(cluster.getSolrClient());
+ }
+
+ // Maps each tuple's "id" string to the Double read from outField, over the tuples getTuples(stream) returns.
+ private Map<String,Double> getIdToLabel(TupleStream stream, String outField) throws IOException {
+ Map<String, Double> labelsById = new HashMap<>();
+ for (Tuple t : getTuples(stream)) {
+ labelsById.put(t.getString("id"), t.getDouble(outField));
+ }
+ return labelsById;
+ }
+
@Test
public void testBasicTextLogitStream() throws Exception {
@@ -3357,7 +3500,7 @@ public class StreamExpressionTest extends SolrCloudTestCase {
assertOrder(tuples, 2);
}
-
+
protected List<Tuple> getTuples(TupleStream tupleStream) throws IOException {
tupleStream.open();
List<Tuple> tuples = new ArrayList<Tuple>();