Posted to commits@lucene.apache.org by jb...@apache.org on 2016/09/30 00:02:55 UTC

[2/2] lucene-solr:branch_6x: SOLR-9258: Optimizing, storing and deploying AI models with Streaming Expressions

SOLR-9258: Optimizing, storing and deploying AI models with Streaming Expressions


Project: http://git-wip-us.apache.org/repos/asf/lucene-solr/repo
Commit: http://git-wip-us.apache.org/repos/asf/lucene-solr/commit/9d1fb907
Tree: http://git-wip-us.apache.org/repos/asf/lucene-solr/tree/9d1fb907
Diff: http://git-wip-us.apache.org/repos/asf/lucene-solr/diff/9d1fb907

Branch: refs/heads/branch_6x
Commit: 9d1fb90784b02f03dbac2bc2604943864c00cfd0
Parents: f12f5e3
Author: Joel Bernstein <jb...@apache.org>
Authored: Thu Sep 29 17:44:20 2016 -0400
Committer: Joel Bernstein <jb...@apache.org>
Committed: Thu Sep 29 18:58:44 2016 -0400

----------------------------------------------------------------------
 .../org/apache/solr/handler/ClassifyStream.java | 230 +++++++++++++++++++
 .../org/apache/solr/handler/StreamHandler.java  |  12 +-
 .../apache/solr/client/solrj/io/ModelCache.java | 154 +++++++++++++
 .../client/solrj/io/stream/ModelStream.java     | 200 ++++++++++++++++
 .../client/solrj/io/stream/StreamContext.java   |  11 +
 .../client/solrj/io/stream/TopicStream.java     |  10 +-
 .../solrj/solr/configsets/ml/conf/schema.xml    |   2 +-
 .../solrj/io/stream/StreamExpressionTest.java   | 145 +++++++++++-
 8 files changed, 758 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/9d1fb907/solr/core/src/java/org/apache/solr/handler/ClassifyStream.java
----------------------------------------------------------------------
diff --git a/solr/core/src/java/org/apache/solr/handler/ClassifyStream.java b/solr/core/src/java/org/apache/solr/handler/ClassifyStream.java
new file mode 100644
index 0000000..6b0a02a
--- /dev/null
+++ b/solr/core/src/java/org/apache/solr/handler/ClassifyStream.java
@@ -0,0 +1,230 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.solr.handler;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Set;
+import java.util.ArrayList;
+import java.util.Map;
+import java.util.Locale;
+
+import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
+import org.apache.solr.client.solrj.io.Tuple;
+import org.apache.solr.client.solrj.io.comp.StreamComparator;
+import org.apache.solr.client.solrj.io.stream.StreamContext;
+import org.apache.solr.client.solrj.io.stream.TupleStream;
+import org.apache.solr.client.solrj.io.stream.expr.Explanation;
+import org.apache.solr.client.solrj.io.stream.expr.Expressible;
+import org.apache.solr.client.solrj.io.stream.expr.StreamExplanation;
+import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
+import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionNamedParameter;
+import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionParameter;
+import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
+import org.apache.solr.common.SolrException;
+import org.apache.solr.core.SolrCore;
+import org.apache.lucene.analysis.*;
+
+/**
+ *  The classify expression retrieves a model trained by the train expression and uses it to classify documents from a stream.
+ *  Syntax:
+ *  classify(model(...), anyStream(...), field="body")
+ **/
+
+public class ClassifyStream extends TupleStream implements Expressible {
+  private TupleStream docStream;
+  private TupleStream modelStream;
+
+  private String field;
+  private String analyzerField;
+  private Tuple  modelTuple;
+
+  Analyzer analyzer;
+  private Map<CharSequence, Integer> termToIndex;
+  private List<Double> idfs;
+  private List<Double> modelWeights;
+
+  public ClassifyStream(StreamExpression expression, StreamFactory factory) throws IOException {
+    List<StreamExpression> streamExpressions = factory.getExpressionOperandsRepresentingTypes(expression, Expressible.class, TupleStream.class);
+    if (streamExpressions.size() != 2) {
+      throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - expecting two stream but found %d",expression, streamExpressions.size()));
+    }
+
+    modelStream = factory.constructStream(streamExpressions.get(0));
+    docStream = factory.constructStream(streamExpressions.get(1));
+
+    StreamExpressionNamedParameter fieldParameter = factory.getNamedOperand(expression, "field");
+    if (fieldParameter == null) {
+      throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - field parameter must be specified",expression, streamExpressions.size()));
+    }
+    analyzerField = field = fieldParameter.getParameter().toString();
+
+    StreamExpressionNamedParameter analyzerFieldParameter = factory.getNamedOperand(expression, "analyzerField");
+    if (analyzerFieldParameter != null) {
+      analyzerField = analyzerFieldParameter.getParameter().toString();
+    }
+  }
+
+  @Override
+  public void setStreamContext(StreamContext context) {
+    Object solrCoreObj = context.get("solr-core");
+    if (!(solrCoreObj instanceof SolrCore)) { // instanceof is false for null as well
+      throw new SolrException(SolrException.ErrorCode.INVALID_STATE, "StreamContext must have SolrCore in solr-core key");
+    }
+    SolrCore solrCore = (SolrCore) solrCoreObj;
+    analyzer = solrCore.getLatestSchema().getFieldType(analyzerField).getIndexAnalyzer();
+
+    this.docStream.setStreamContext(context);
+    this.modelStream.setStreamContext(context);
+  }
+
+  @Override
+  public List<TupleStream> children() {
+    List<TupleStream> l = new ArrayList<>();
+    l.add(docStream);
+    l.add(modelStream);
+    return l;
+  }
+
+  @Override
+  public void open() throws IOException {
+    this.docStream.open();
+    this.modelStream.open();
+  }
+
+  @Override
+  public void close() throws IOException {
+    this.docStream.close();
+    this.modelStream.close();
+  }
+
+  @Override
+  public Tuple read() throws IOException {
+    if (modelTuple == null) {
+
+      modelTuple = modelStream.read();
+      if (modelTuple == null || modelTuple.EOF) {
+        throw new IOException("Model tuple not found for classify stream!");
+      }
+
+      termToIndex = new HashMap<>();
+
+      List<String> terms = modelTuple.getStrings("terms_ss");
+
+      for (int i = 0; i < terms.size(); i++) {
+        termToIndex.put(terms.get(i), i);
+      }
+
+      idfs = modelTuple.getDoubles("idfs_ds");
+      modelWeights = modelTuple.getDoubles("weights_ds");
+    }
+
+    Tuple docTuple = docStream.read();
+    if (docTuple.EOF) return docTuple;
+
+    String text = docTuple.getString(field);
+
+    double[] tfs = new double[termToIndex.size()];
+
+    TokenStream tokenStream = analyzer.tokenStream(analyzerField, text);
+    CharTermAttribute termAtt = tokenStream.getAttribute(CharTermAttribute.class);
+    tokenStream.reset();
+
+    int termCount = 0;
+    while (tokenStream.incrementToken()) {
+      termCount++;
+      if (termToIndex.containsKey(termAtt.toString())) {
+        tfs[termToIndex.get(termAtt.toString())]++;
+      }
+    }
+
+    tokenStream.end();
+    tokenStream.close();
+
+    List<Double> tfidfs = new ArrayList<>(termToIndex.size());
+    tfidfs.add(1.0); // constant feature that pairs with the model's intercept weight
+    for (int i = 0; i < tfs.length; i++) {
+      if (tfs[i] != 0) {
+        tfs[i] = 1 + Math.log(tfs[i]);
+      }
+      tfidfs.add(this.idfs.get(i) * tfs[i]);
+    }
+
+    double total = 0.0;
+    for (int i = 0; i < tfidfs.size(); i++) {
+      total += tfidfs.get(i) * modelWeights.get(i);
+    }
+
+    double score = total * ((float) (1.0 / Math.sqrt(termCount))); // length-normalize the raw margin by the token count
+    double positiveProb = sigmoid(total);
+
+    docTuple.put("probability_d", positiveProb);
+    docTuple.put("score_d",  score);
+
+    return docTuple;
+  }
+
+  // logistic function: maps the raw margin to a probability in (0, 1)
+  private double sigmoid(double in) {
+    return 1.0 / (1 + Math.exp(-in));
+  }
+
+  @Override
+  public StreamComparator getStreamSort() {
+    return null;
+  }
+
+  @Override
+  public StreamExpressionParameter toExpression(StreamFactory factory) throws IOException {
+    return toExpression(factory, true);
+  }
+
+  private StreamExpression toExpression(StreamFactory factory, boolean includeStreams) throws IOException {
+    // function name
+    StreamExpression expression = new StreamExpression(factory.getFunctionName(this.getClass()));
+
+    if (includeStreams) {
+      if (docStream instanceof Expressible && modelStream instanceof Expressible) {
+        expression.addParameter(((Expressible)modelStream).toExpression(factory));
+        expression.addParameter(((Expressible)docStream).toExpression(factory));
+      } else {
+        throw new IOException("This ClassifyStream contains a non-expressible TupleStream - it cannot be converted to an expression");
+      }
+    }
+
+    expression.addParameter(new StreamExpressionNamedParameter("field", field));
+    expression.addParameter(new StreamExpressionNamedParameter("analyzerField", analyzerField));
+
+    return expression;
+  }
+
+  @Override
+  public Explanation toExplanation(StreamFactory factory) throws IOException {
+    StreamExplanation explanation = new StreamExplanation(getStreamNodeId().toString());
+
+    explanation.setFunctionName(factory.getFunctionName(this.getClass()));
+    explanation.setImplementingClass(this.getClass().getName());
+    explanation.setExpressionType(Explanation.ExpressionType.STREAM_DECORATOR);
+    explanation.setExpression(toExpression(factory, false).toString());
+
+    explanation.addChild(docStream.toExplanation(factory));
+    explanation.addChild(modelStream.toExplanation(factory));
+
+    return explanation;
+  }
+}
\ No newline at end of file
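
The scoring in read() above reduces to a logistic regression over log-scaled
tf-idf features. A minimal standalone sketch of the same arithmetic, assuming
hypothetical values in place of the model tuple's terms_ss/idfs_ds/weights_ds
fields:

    public class ClassifyScoringSketch {
      public static void main(String[] args) {
        double[] tfs     = {2.0, 0.0, 1.0};        // raw counts of the model's terms in one doc
        double[] idfs    = {1.2, 0.9, 1.7};        // stand-in for idfs_ds
        double[] weights = {0.5, 1.3, -0.4, 2.1};  // stand-in for weights_ds; weights[0] is the intercept
        int termCount = 3;                         // total tokens produced by the analyzer

        double total = weights[0];                 // the constant 1.0 feature pairs with the intercept
        for (int i = 0; i < tfs.length; i++) {
          double tf = tfs[i] == 0 ? 0 : 1 + Math.log(tfs[i]);  // dampen raw term counts
          total += idfs[i] * tf * weights[i + 1];
        }
        double score = total / Math.sqrt(termCount);        // score_d: length-normalized margin
        double probability = 1.0 / (1 + Math.exp(-total));  // probability_d: sigmoid of the margin
        System.out.println(score + " " + probability);
      }
    }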

http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/9d1fb907/solr/core/src/java/org/apache/solr/handler/StreamHandler.java
----------------------------------------------------------------------
diff --git a/solr/core/src/java/org/apache/solr/handler/StreamHandler.java b/solr/core/src/java/org/apache/solr/handler/StreamHandler.java
index a88ee96..dfacc1e 100644
--- a/solr/core/src/java/org/apache/solr/handler/StreamHandler.java
+++ b/solr/core/src/java/org/apache/solr/handler/StreamHandler.java
@@ -25,6 +25,7 @@ import java.util.List;
 import java.util.Map;
 import java.util.Map.Entry;
 
+import org.apache.solr.client.solrj.io.ModelCache;
 import org.apache.solr.client.solrj.io.SolrClientCache;
 import org.apache.solr.client.solrj.io.Tuple;
 import org.apache.solr.client.solrj.io.comp.StreamComparator;
@@ -36,10 +37,10 @@ import org.apache.solr.client.solrj.io.ops.GroupOperation;
 import org.apache.solr.client.solrj.io.ops.ReplaceOperation;
 import org.apache.solr.client.solrj.io.stream.*;
 import org.apache.solr.client.solrj.io.stream.expr.Explanation;
+import org.apache.solr.client.solrj.io.stream.expr.Explanation.ExpressionType;
 import org.apache.solr.client.solrj.io.stream.expr.Expressible;
 import org.apache.solr.client.solrj.io.stream.expr.StreamExplanation;
 import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
-import org.apache.solr.client.solrj.io.stream.expr.Explanation.ExpressionType;
 import org.apache.solr.client.solrj.io.stream.metrics.CountMetric;
 import org.apache.solr.client.solrj.io.stream.metrics.MaxMetric;
 import org.apache.solr.client.solrj.io.stream.metrics.MeanMetric;
@@ -64,6 +65,7 @@ import org.slf4j.LoggerFactory;
 public class StreamHandler extends RequestHandlerBase implements SolrCoreAware, PermissionNameProvider {
 
   static SolrClientCache clientCache = new SolrClientCache();
+  static ModelCache modelCache = null;
   private StreamFactory streamFactory = new StreamFactory();
   private static final Logger logger = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());
   private String coreName;
@@ -96,6 +98,9 @@ public class StreamHandler extends RequestHandlerBase implements SolrCoreAware,
       defaultZkhost = core.getCoreDescriptor().getCoreContainer().getZkController().getZkServerAddress();
       streamFactory.withCollectionZkHost(defaultCollection, defaultZkhost);
       streamFactory.withDefaultZkHost(defaultZkhost);
+      modelCache = new ModelCache(250,
+                                  defaultZkhost,
+                                  clientCache);
     }
 
      streamFactory
@@ -130,6 +135,8 @@ public class StreamHandler extends RequestHandlerBase implements SolrCoreAware,
          .withFunctionName("shortestPath", ShortestPathStream.class)
          .withFunctionName("gatherNodes", GatherNodesStream.class)
          .withFunctionName("scoreNodes", ScoreNodesStream.class)
+         .withFunctionName("model", ModelStream.class)
+         .withFunctionName("classify", ClassifyStream.class)
 
       // metrics
       .withFunctionName("min", MinMetric.class)
@@ -197,7 +204,9 @@ public class StreamHandler extends RequestHandlerBase implements SolrCoreAware,
     context.workerID = worker;
     context.numWorkers = numWorkers;
     context.setSolrClientCache(clientCache);
+    context.setModelCache(modelCache);
     context.put("core", this.coreName);
+    context.put("solr-core", req.getCore());
     tupleStream.setStreamContext(context);
     
     // if asking for explanation then go get it
@@ -454,5 +463,4 @@ public class StreamHandler extends RequestHandlerBase implements SolrCoreAware,
       return tuple;
     }
   }
-
 }
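
With model and classify registered, a SolrJ client can run a classify
expression through the /stream handler. A sketch, assuming a hypothetical node
URL, collections, model id, and field names (the pattern mirrors the tests
further below):

    import org.apache.solr.client.solrj.io.Tuple;
    import org.apache.solr.client.solrj.io.stream.SolrStream;
    import org.apache.solr.common.params.ModifiableSolrParams;

    public class ClassifyClientSketch {
      public static void main(String[] args) throws Exception {
        String url = "http://localhost:8983/solr/collection1";  // hypothetical node/collection
        ModifiableSolrParams params = new ModifiableSolrParams();
        params.set("qt", "/stream");
        params.set("expr", "classify("
            + "model(modelCollection, id=\"myModel\"),"
            + "search(collection1, q=\"*:*\", fl=\"id,body\", sort=\"id asc\"),"
            + "field=\"body\")");
        SolrStream stream = new SolrStream(url, params);
        try {
          stream.open();
          for (Tuple t = stream.read(); !t.EOF; t = stream.read()) {
            System.out.println(t.getString("id") + " -> " + t.getDouble("probability_d"));
          }
        } finally {
          stream.close();
        }
      }
    }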

http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/9d1fb907/solr/solrj/src/java/org/apache/solr/client/solrj/io/ModelCache.java
----------------------------------------------------------------------
diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/ModelCache.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/ModelCache.java
new file mode 100644
index 0000000..4fe3d8a
--- /dev/null
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/ModelCache.java
@@ -0,0 +1,154 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.solr.client.solrj.io;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.lang.invoke.MethodHandles;
+import java.util.Collections;
+import java.util.Date;
+import java.util.LinkedHashMap;
+import java.util.Map;
+import java.util.HashMap;
+
+import org.apache.solr.client.solrj.SolrClient;
+import org.apache.solr.client.solrj.impl.CloudSolrClient;
+import org.apache.solr.client.solrj.impl.HttpSolrClient;
+import org.apache.solr.client.solrj.io.stream.CloudSolrStream;
+import org.apache.solr.client.solrj.io.stream.StreamContext;
+import org.apache.solr.client.solrj.io.stream.TopicStream;
+import org.apache.solr.common.params.ModifiableSolrParams;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+
+/**
+ *  The ModelCache keeps a local, in-memory copy of models.
+ */
+
+public class ModelCache implements Serializable {
+
+  private final LRU models;
+  private String defaultZkHost;
+  private SolrClientCache solrClientCache;
+
+  public ModelCache(int size,
+                    String defaultZkHost,
+                    SolrClientCache solrClientCache) {
+    this.models = new LRU(size);
+    this.defaultZkHost = defaultZkHost;
+    this.solrClientCache = solrClientCache;
+  }
+
+  public Tuple getModel(String collection,
+                        String modelID,
+                        long checkMillis) throws IOException {
+    return getModel(defaultZkHost, collection, modelID, checkMillis);
+  }
+
+  public Tuple getModel(String zkHost,
+                        String collection,
+                        String modelID,
+                        long checkMillis) throws IOException {
+    Model model = null;
+    long currentTime = new Date().getTime();
+    synchronized (this) {
+      model = models.get(modelID);
+      if(model != null && ((currentTime - model.getLastChecked()) <= checkMillis)) {
+        return model.getTuple();
+      }
+
+      if(model != null){
+        //model is expired
+        models.remove(modelID);
+      }
+    }
+
+    //Model is not in cache or has expired so fetch the model
+    ModifiableSolrParams params = new ModifiableSolrParams();
+    params.set("q","name_s:"+modelID);
+    params.set("fl", "terms_ss, idfs_ds, weights_ds, iteration_i, _version_");
+    params.set("sort", "iteration_i desc");
+    StreamContext streamContext = new StreamContext();
+    streamContext.setSolrClientCache(solrClientCache);
+    CloudSolrStream stream = new CloudSolrStream(zkHost, collection, params);
+    stream.setStreamContext(streamContext);
+    Tuple tuple = null;
+    try {
+      stream.open();
+      tuple = stream.read();
+      if (tuple.EOF) {
+        return null;
+      }
+    } finally {
+      stream.close();
+    }
+
+    synchronized (this) {
+      //check again to see if another thread has updated the same model
+      Model m = models.get(modelID);
+      if (m != null) {
+        Tuple t = m.getTuple();
+        long v = t.getLong("_version_");
+        if (v >= tuple.getLong("_version_")) {
+          return t;
+        } else {
+          models.put(modelID, new Model(tuple, currentTime));
+          return tuple;
+        }
+      } else {
+        models.put(modelID, new Model(tuple, currentTime));
+        return tuple;
+      }
+    }
+  }
+
+  private static class Model {
+    private Tuple tuple;
+    private long lastChecked;
+
+    public Model(Tuple tuple, long lastChecked) {
+      this.tuple = tuple;
+      this.lastChecked = lastChecked;
+    }
+
+    public Tuple getTuple() {
+      return tuple;
+    }
+
+    public long getLastChecked() {
+      return lastChecked;
+    }
+  }
+
+  private static class LRU extends LinkedHashMap<String, Model> {
+
+    private final int maxSize;
+
+    public LRU(int maxSize) {
+      // access-order iteration so the eldest entry is the least recently used
+      super(16, 0.75f, true);
+      this.maxSize = maxSize;
+    }
+
+    // called by LinkedHashMap after each insertion; evict when over capacity
+    @Override
+    public boolean removeEldestEntry(Map.Entry<String, Model> eldest) {
+      return size() > maxSize;
+    }
+  }
+}
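
Typical use, as wired up in StreamHandler above: one ModelCache per node,
shared across requests. A short sketch, assuming a placeholder zkHost,
collection, and model id:

    import java.util.List;
    import org.apache.solr.client.solrj.io.ModelCache;
    import org.apache.solr.client.solrj.io.SolrClientCache;
    import org.apache.solr.client.solrj.io.Tuple;

    public class ModelCacheSketch {
      public static void main(String[] args) throws Exception {
        SolrClientCache clientCache = new SolrClientCache();
        ModelCache cache = new ModelCache(250, "localhost:9983", clientCache);
        // Serve the in-memory copy for up to five minutes before re-reading
        // the highest-iteration model document from the collection.
        Tuple model = cache.getModel("modelCollection", "myModel", 300000);
        if (model == null) {
          System.out.println("no model stored under that id");  // fetch hit EOF
        } else {
          List<Double> weights = model.getDoubles("weights_ds");
          System.out.println("model has " + weights.size() + " weights");
        }
      }
    }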

http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/9d1fb907/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/ModelStream.java
----------------------------------------------------------------------
diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/ModelStream.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/ModelStream.java
new file mode 100644
index 0000000..70b740d
--- /dev/null
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/ModelStream.java
@@ -0,0 +1,200 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.solr.client.solrj.io.stream;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Locale;
+import java.util.Map;
+
+import org.apache.solr.client.solrj.io.ModelCache;
+import org.apache.solr.client.solrj.io.SolrClientCache;
+import org.apache.solr.client.solrj.io.Tuple;
+import org.apache.solr.client.solrj.io.comp.StreamComparator;
+import org.apache.solr.client.solrj.io.stream.expr.Explanation;
+import org.apache.solr.client.solrj.io.stream.expr.Expressible;
+import org.apache.solr.client.solrj.io.stream.expr.StreamExplanation;
+import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
+import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionNamedParameter;
+import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionParameter;
+import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionValue;
+import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
+
+/**
+*  The ModelStream retrieves a stored model from a SolrCloud collection.
+*
+*  Syntax: model(collection, id="modelID")
+**/
+
+public class ModelStream extends TupleStream implements Expressible {
+
+  private static final long serialVersionUID = 1;
+
+  protected String zkHost;
+  protected String collection;
+  protected String modelID;
+  protected ModelCache modelCache;
+  protected SolrClientCache solrClientCache;
+  protected Tuple model;
+  protected long cacheMillis;
+
+  public ModelStream(String zkHost,
+                     String collectionName,
+                     String modelID,
+                     long cacheMillis) throws IOException {
+
+    init(collectionName, zkHost, modelID, cacheMillis);
+  }
+
+
+  public ModelStream(StreamExpression expression, StreamFactory factory) throws IOException{
+    // grab all parameters out
+    String collectionName = factory.getValueOperand(expression, 0);
+    List<StreamExpressionNamedParameter> namedParams = factory.getNamedOperands(expression);
+    StreamExpressionNamedParameter zkHostExpression = factory.getNamedOperand(expression, "zkHost");
+
+    // Collection Name
+    if(null == collectionName){
+      throw new IOException(String.format(Locale.ROOT,"invalid expression %s - collectionName expected as first operand",expression));
+    }
+
+    // Named parameters - id (required) and cacheMillis (optional) are consumed locally, not passed on to Solr
+    if(0 == namedParams.size()){
+      throw new IOException(String.format(Locale.ROOT,"invalid expression %s - at least one named parameter expected. eg. 'q=*:*'",expression));
+    }
+
+    Map<String,String> params = new HashMap<String,String>();
+    for(StreamExpressionNamedParameter namedParam : namedParams){
+      if(!namedParam.getName().equals("zkHost")) {
+        params.put(namedParam.getName(), namedParam.getParameter().toString().trim());
+      }
+    }
+
+    String modelID = params.get("id");
+    if (modelID == null) {
+      throw new IOException("id param cannot be null for ModelStream");
+    }
+
+    long cacheMillis = 300000; // default: trust the cached model for five minutes
+    String cacheMillisParam = params.get("cacheMillis");
+
+    if(cacheMillisParam != null) {
+      cacheMillis = Long.parseLong(cacheMillisParam);
+    }
+
+    // zkHost, optional - if not provided, resolve it from the factory's per-collection list or its default
+    String zkHost = null;
+    if(null == zkHostExpression) {
+      zkHost = factory.getCollectionZkHost(collectionName);
+    }
+    else if(zkHostExpression.getParameter() instanceof StreamExpressionValue){
+      zkHost = ((StreamExpressionValue)zkHostExpression.getParameter()).getValue();
+    }
+
+    if (zkHost == null) {
+      zkHost = factory.getDefaultZkHost();
+    }
+
+    if(null == zkHost){
+      throw new IOException(String.format(Locale.ROOT,"invalid expression %s - zkHost not found for collection '%s'",expression,collectionName));
+    }
+
+    // We've got all the required items
+    init(collectionName, zkHost, modelID, cacheMillis);
+  }
+
+  @Override
+  public StreamExpressionParameter toExpression(StreamFactory factory) throws IOException {
+    return toExpression(factory, true);
+  }
+
+  private StreamExpression toExpression(StreamFactory factory, boolean includeStreams) throws IOException {
+    // function name
+    StreamExpression expression = new StreamExpression(factory.getFunctionName(this.getClass()));
+    // collection
+    expression.addParameter(collection);
+
+    // zkHost
+    expression.addParameter(new StreamExpressionNamedParameter("zkHost", zkHost));
+    expression.addParameter(new StreamExpressionNamedParameter("id", modelID));
+    expression.addParameter(new StreamExpressionNamedParameter("cacheMillis", Long.toString(cacheMillis)));
+
+    return expression;
+  }
+
+  private void init(String collectionName,
+                    String zkHost,
+                    String modelID,
+                    long cacheMillis) throws IOException {
+    this.zkHost = zkHost;
+    this.collection = collectionName;
+    this.modelID = modelID;
+    this.cacheMillis = cacheMillis;
+  }
+
+  public void setStreamContext(StreamContext context) {
+    this.solrClientCache = context.getSolrClientCache();
+    this.modelCache = context.getModelCache();
+  }
+
+  public void open() throws IOException {
+    this.model = modelCache.getModel(collection, modelID, cacheMillis);
+  }
+
+  public List<TupleStream> children() {
+    List<TupleStream> l = new ArrayList<>();
+    return l;
+  }
+
+  public void close() throws IOException {
+
+  }
+
+  /** Return the stream sort, i.e. the order in which records are returned */
+  public StreamComparator getStreamSort(){
+    return null;
+  }
+
+  @Override
+  public Explanation toExplanation(StreamFactory factory) throws IOException {
+    StreamExplanation explanation = new StreamExplanation(getStreamNodeId().toString());
+    explanation.setFunctionName(factory.getFunctionName(this.getClass()));
+    explanation.setImplementingClass(this.getClass().getName());
+    explanation.setExpressionType(Explanation.ExpressionType.MACHINE_LEARNING_MODEL);
+    explanation.setExpression(toExpression(factory).toString());
+
+    return explanation;
+  }
+
+  public Tuple read() throws IOException {
+    Tuple tuple = null;
+
+    if(model != null) {
+      tuple = model;
+      model = null;
+    } else {
+      Map<String, Object> map = new HashMap<>();
+      map.put("EOF", true);
+      tuple = new Tuple(map);
+    }
+
+    return tuple;
+  }
+}
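
Outside the /stream handler, a ModelStream can be built directly from its
expression with a StreamFactory; it emits the single model tuple and then EOF.
A sketch, assuming placeholder names and a local ZooKeeper address:

    import org.apache.solr.client.solrj.io.ModelCache;
    import org.apache.solr.client.solrj.io.SolrClientCache;
    import org.apache.solr.client.solrj.io.Tuple;
    import org.apache.solr.client.solrj.io.stream.ModelStream;
    import org.apache.solr.client.solrj.io.stream.StreamContext;
    import org.apache.solr.client.solrj.io.stream.TupleStream;
    import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;

    public class ModelStreamSketch {
      public static void main(String[] args) throws Exception {
        StreamFactory factory = new StreamFactory()
            .withDefaultZkHost("localhost:9983")                 // placeholder
            .withFunctionName("model", ModelStream.class);
        TupleStream stream = factory.constructStream(
            "model(modelCollection, id=\"myModel\", cacheMillis=5000)");

        // ModelStream reads through the ModelCache, so the context must carry one.
        SolrClientCache clientCache = new SolrClientCache();
        StreamContext context = new StreamContext();
        context.setSolrClientCache(clientCache);
        context.setModelCache(new ModelCache(10, "localhost:9983", clientCache));
        stream.setStreamContext(context);

        stream.open();
        try {
          Tuple model = stream.read();   // the model tuple; the next read() is EOF
          System.out.println(model.getStrings("terms_ss"));
        } finally {
          stream.close();
        }
      }
    }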

http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/9d1fb907/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/StreamContext.java
----------------------------------------------------------------------
diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/StreamContext.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/StreamContext.java
index 8ca808f..6cbf090 100644
--- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/StreamContext.java
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/StreamContext.java
@@ -19,6 +19,8 @@ package org.apache.solr.client.solrj.io.stream;
 import java.io.Serializable;
 import java.util.Map;
 import java.util.HashMap;
+
+import org.apache.solr.client.solrj.io.ModelCache;
 import org.apache.solr.client.solrj.io.SolrClientCache;
 import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
 
@@ -37,6 +39,7 @@ public class StreamContext implements Serializable{
   public int workerID;
   public int numWorkers;
   private SolrClientCache clientCache;
+  private ModelCache modelCache;
   private StreamFactory streamFactory;
 
   public Object get(Object key) {
@@ -55,10 +58,18 @@ public class StreamContext implements Serializable{
     this.clientCache = clientCache;
   }
 
+  public void setModelCache(ModelCache modelCache) {
+    this.modelCache = modelCache;
+  }
+
   public SolrClientCache getSolrClientCache() {
     return this.clientCache;
   }
 
+  public ModelCache getModelCache() {
+    return this.modelCache;
+  }
+
   public void setStreamFactory(StreamFactory streamFactory) {
     this.streamFactory = streamFactory;
   }

http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/9d1fb907/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/TopicStream.java
----------------------------------------------------------------------
diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/TopicStream.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/TopicStream.java
index 97317a0..d81391d 100644
--- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/TopicStream.java
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/TopicStream.java
@@ -72,6 +72,7 @@ public class TopicStream extends CloudSolrStream implements Expressible  {
 
   private long count;
   private int runCount;
+  private boolean initialRun = true;
   private String id;
   protected long checkpointEvery;
   private Map<String, Long> checkpoints = new HashMap<String, Long>();
@@ -350,9 +351,14 @@ public class TopicStream extends CloudSolrStream implements Expressible  {
   }
 
   public void close() throws IOException {
-    runCount = 0;
     try {
-      persistCheckpoints();
+
+      if (initialRun || runCount > 0) {
+        persistCheckpoints();
+        initialRun = false;
+        runCount = 0;
+      }
+
     } finally {
 
       if(solrStreams != null) {

http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/9d1fb907/solr/solrj/src/test-files/solrj/solr/configsets/ml/conf/schema.xml
----------------------------------------------------------------------
diff --git a/solr/solrj/src/test-files/solrj/solr/configsets/ml/conf/schema.xml b/solr/solrj/src/test-files/solrj/solr/configsets/ml/conf/schema.xml
index 3206811..c70b9fd 100644
--- a/solr/solrj/src/test-files/solrj/solr/configsets/ml/conf/schema.xml
+++ b/solr/solrj/src/test-files/solrj/solr/configsets/ml/conf/schema.xml
@@ -50,7 +50,7 @@
 
 
     <field name="id" type="string" indexed="true" stored="true" multiValued="false" required="false"/>
-    <field name="_version_" type="long" indexed="true" stored="true"/>
+    <field name="_version_" type="long" indexed="true" stored="true" docValues="true"/>
 
     <!-- Dynamic field definitions.  If a field name is not found, dynamicFields
          will be used if the name matches any of the patterns.

http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/9d1fb907/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java
----------------------------------------------------------------------
diff --git a/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java b/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java
index 7c3a3a6..87fc951 100644
--- a/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java
+++ b/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java
@@ -19,6 +19,7 @@ package org.apache.solr.client.solrj.io.stream;
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Locale;
@@ -3072,6 +3073,148 @@ public class StreamExpressionTest extends SolrCloudTestCase {
 
   }
 
+  @Test
+  public void testClassifyStream() throws Exception {
+    CollectionAdminRequest.createCollection("modelCollection", "ml", 2, 1).process(cluster.getSolrClient());
+    AbstractDistribZkTestBase.waitForRecoveriesToFinish("modelCollection", cluster.getSolrClient().getZkStateReader(),
+        false, true, TIMEOUT);
+    CollectionAdminRequest.createCollection("uknownCollection", "ml", 2, 1).process(cluster.getSolrClient());
+    AbstractDistribZkTestBase.waitForRecoveriesToFinish("uknownCollection", cluster.getSolrClient().getZkStateReader(),
+        false, true, TIMEOUT);
+    CollectionAdminRequest.createCollection("checkpointCollection", "ml", 2, 1).process(cluster.getSolrClient());
+    AbstractDistribZkTestBase.waitForRecoveriesToFinish("checkpointCollection", cluster.getSolrClient().getZkStateReader(),
+        false, true, TIMEOUT);
+
+    UpdateRequest updateRequest = new UpdateRequest();
+
+    for (int i = 0; i < 500; i+=2) {
+      updateRequest.add(id, String.valueOf(i), "tv_text", "a b c c d", "out_i", "1");
+      updateRequest.add(id, String.valueOf(i+1), "tv_text", "a b e e f", "out_i", "0");
+    }
+
+    updateRequest.commit(cluster.getSolrClient(), COLLECTION);
+
+    updateRequest = new UpdateRequest();
+    updateRequest.add(id, String.valueOf(0), "text_s", "a b c c d");
+    updateRequest.add(id, String.valueOf(1), "text_s", "a b e e f");
+    updateRequest.commit(cluster.getSolrClient(), "uknownCollection");
+
+    String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString() + "/" + COLLECTION;
+    TupleStream updateTrainModelStream;
+    ModifiableSolrParams paramsLoc;
+
+    StreamFactory factory = new StreamFactory()
+        .withCollectionZkHost("collection1", cluster.getZkServer().getZkAddress())
+        .withCollectionZkHost("modelCollection", cluster.getZkServer().getZkAddress())
+        .withCollectionZkHost("uknownCollection", cluster.getZkServer().getZkAddress())
+        .withFunctionName("features", FeaturesSelectionStream.class)
+        .withFunctionName("train", TextLogitStream.class)
+        .withFunctionName("search", CloudSolrStream.class)
+        .withFunctionName("update", UpdateStream.class);
+
+    // train the model
+    String textLogitExpression = "train(" +
+        "collection1, " +
+        "features(collection1, q=\"*:*\", featureSet=\"first\", field=\"tv_text\", outcome=\"out_i\", numTerms=4),"+
+        "q=\"*:*\", " +
+        "name=\"model\", " +
+        "field=\"tv_text\", " +
+        "outcome=\"out_i\", " +
+        "maxIterations=100)";
+    updateTrainModelStream = factory.constructStream("update(modelCollection, batchSize=5, "+textLogitExpression+")");
+    getTuples(updateTrainModelStream);
+    cluster.getSolrClient().commit("modelCollection");
+
+    // classify unknown documents
+    String expr = "classify(" +
+        "model(modelCollection, id=\"model\", cacheMillis=5000)," +
+        "topic(checkpointCollection, uknownCollection, q=\"*:*\", fl=\"text_s, id\", id=\"1000000\", initialCheckpoint=\"0\")," +
+        "field=\"text_s\"," +
+        "analyzerField=\"tv_text\")";
+
+    paramsLoc = new ModifiableSolrParams();
+    paramsLoc.set("expr", expr);
+    paramsLoc.set("qt","/stream");
+    SolrStream classifyStream = new SolrStream(url, paramsLoc);
+    Map<String, Double> idToLabel = getIdToLabel(classifyStream, "probability_d");
+    assertEquals(2, idToLabel.size());
+    assertEquals(1.0, idToLabel.get("0"), 0.001);
+    assertEquals(0.0, idToLabel.get("1"), 0.001);
+
+    // Add more documents and classify them
+    updateRequest = new UpdateRequest();
+    updateRequest.add(id, String.valueOf(2), "text_s", "a b c c d");
+    updateRequest.add(id, String.valueOf(3), "text_s", "a b e e f");
+    updateRequest.commit(cluster.getSolrClient(), "uknownCollection");
+
+    classifyStream = new SolrStream(url, paramsLoc);
+    idToLabel = getIdToLabel(classifyStream, "probability_d");
+    assertEquals(2, idToLabel.size());
+    assertEquals(1.0, idToLabel.get("2"), 0.001);
+    assertEquals(0.0, idToLabel.get("3"), 0.001);
+
+
+    // Train another model
+    updateRequest = new UpdateRequest();
+    updateRequest.deleteByQuery("*:*");
+    updateRequest.commit(cluster.getSolrClient(), COLLECTION);
+
+    updateRequest = new UpdateRequest();
+    for (int i = 0; i < 500; i+=2) {
+      updateRequest.add(id, String.valueOf(i), "tv_text", "a b c c d", "out_i", "0");
+      updateRequest.add(id, String.valueOf(i+1), "tv_text", "a b e e f", "out_i", "1");
+    }
+    updateRequest.commit(cluster.getSolrClient(), COLLECTION);
+    updateTrainModelStream = factory.constructStream("update(modelCollection, batchSize=5, "+textLogitExpression+")");
+    getTuples(updateTrainModelStream);
+    cluster.getSolrClient().commit("modelCollection");
+
+    // Add more documents and classify them
+    updateRequest = new UpdateRequest();
+    updateRequest.add(id, String.valueOf(4), "text_s", "a b c c d");
+    updateRequest.add(id, String.valueOf(5), "text_s", "a b e e f");
+    updateRequest.commit(cluster.getSolrClient(), "uknownCollection");
+
+    // Sleep just over 5 seconds so the cached model (cacheMillis=5000) expires
+    Thread.sleep(5100);
+
+    classifyStream = new SolrStream(url, paramsLoc);
+    idToLabel = getIdToLabel(classifyStream, "probability_d");
+    assertEquals(2, idToLabel.size());
+    assertEquals(0.0, idToLabel.get("4"), 0.001);
+    assertEquals(1.0, idToLabel.get("5"), 0.001);
+
+    // Classify the unknown documents again, this time in parallel
+    // across two workers, partitioned on the id field (the topic stream
+    // supplies _version_ for the workers' sort)
+
+    expr = "parallel(collection1, workers=2, sort=\"_version_ asc\", classify(" +
+           "model(modelCollection, id=\"model\")," +
+           "topic(checkpointCollection, uknownCollection, q=\"id:(4 5)\", fl=\"text_s, id, _version_\", id=\"2000000\", partitionKeys=\"id\", initialCheckpoint=\"0\")," +
+           "field=\"text_s\"," +
+           "analyzerField=\"tv_text\"))";
+
+    paramsLoc.set("expr", expr);
+    classifyStream = new SolrStream(url, paramsLoc);
+    idToLabel = getIdToLabel(classifyStream, "probability_d");
+    assertEquals(2, idToLabel.size());
+    assertEquals(0.0, idToLabel.get("4"), 0.001);
+    assertEquals(1.0, idToLabel.get("5"), 0.001);
+
+    CollectionAdminRequest.deleteCollection("modelCollection").process(cluster.getSolrClient());
+    CollectionAdminRequest.deleteCollection("uknownCollection").process(cluster.getSolrClient());
+    CollectionAdminRequest.deleteCollection("checkpointCollection").process(cluster.getSolrClient());
+  }
+
+  private Map<String,Double> getIdToLabel(TupleStream stream, String outField) throws IOException {
+    Map<String, Double> idToLabel = new HashMap<>();
+    List<Tuple> tuples = getTuples(stream);
+    for (Tuple tuple : tuples) {
+      idToLabel.put(tuple.getString("id"), tuple.getDouble(outField));
+    }
+    return idToLabel;
+  }
+
 
   @Test
   public void testBasicTextLogitStream() throws Exception {
@@ -3357,7 +3500,7 @@ public class StreamExpressionTest extends SolrCloudTestCase {
     assertOrder(tuples, 2);
 
   }
-  
+
   protected List<Tuple> getTuples(TupleStream tupleStream) throws IOException {
     tupleStream.open();
     List<Tuple> tuples = new ArrayList<Tuple>();