You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by jb...@apache.org on 2016/04/18 22:59:33 UTC
[2/2] lucene-solr:master: SOLR-8925: Add gatherNodes Streaming
Expression to support breadth first traversals
SOLR-8925: Add gatherNodes Streaming Expression to support breadth first traversals
Project: http://git-wip-us.apache.org/repos/asf/lucene-solr/repo
Commit: http://git-wip-us.apache.org/repos/asf/lucene-solr/commit/8659ea33
Tree: http://git-wip-us.apache.org/repos/asf/lucene-solr/tree/8659ea33
Diff: http://git-wip-us.apache.org/repos/asf/lucene-solr/diff/8659ea33
Branch: refs/heads/master
Commit: 8659ea33d909ca76c793a778c694feea0c74af3b
Parents: 84a6ff6
Author: jbernste <jb...@apache.org>
Authored: Mon Apr 18 15:36:12 2016 -0400
Committer: jbernste <jb...@apache.org>
Committed: Mon Apr 18 16:09:56 2016 -0400
----------------------------------------------------------------------
.../org/apache/solr/handler/StreamHandler.java | 18 +-
.../solrj/io/graph/GatherNodesStream.java | 580 +++++++++++++++++++
.../apache/solr/client/solrj/io/graph/Node.java | 90 +++
.../solr/client/solrj/io/graph/Traversal.java | 96 +++
.../solrj/io/graph/TraversalIterator.java | 120 ++++
.../client/solrj/io/stream/StreamContext.java | 4 +
.../solrj/io/stream/metrics/CountMetric.java | 4 +
.../solrj/io/stream/metrics/MaxMetric.java | 5 +
.../solrj/io/stream/metrics/MeanMetric.java | 5 +
.../client/solrj/io/stream/metrics/Metric.java | 2 +
.../solrj/io/stream/metrics/MinMetric.java | 7 +-
.../solrj/io/stream/metrics/SumMetric.java | 5 +
.../solrj/io/graph/GraphExpressionTest.java | 402 ++++++++++++-
13 files changed, 1331 insertions(+), 7 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/8659ea33/solr/core/src/java/org/apache/solr/handler/StreamHandler.java
----------------------------------------------------------------------
diff --git a/solr/core/src/java/org/apache/solr/handler/StreamHandler.java b/solr/core/src/java/org/apache/solr/handler/StreamHandler.java
index 5ddd312..7c47c76 100644
--- a/solr/core/src/java/org/apache/solr/handler/StreamHandler.java
+++ b/solr/core/src/java/org/apache/solr/handler/StreamHandler.java
@@ -28,6 +28,7 @@ import java.util.Map.Entry;
import org.apache.solr.client.solrj.io.SolrClientCache;
import org.apache.solr.client.solrj.io.Tuple;
import org.apache.solr.client.solrj.io.comp.StreamComparator;
+import org.apache.solr.client.solrj.io.graph.GatherNodesStream;
import org.apache.solr.client.solrj.io.graph.ShortestPathStream;
import org.apache.solr.client.solrj.io.ops.ConcatOperation;
import org.apache.solr.client.solrj.io.ops.DistinctOperation;
@@ -117,11 +118,10 @@ public class StreamHandler extends RequestHandlerBase implements SolrCoreAware,
.withFunctionName("outerHashJoin", OuterHashJoinStream.class)
.withFunctionName("intersect", IntersectStream.class)
.withFunctionName("complement", ComplementStream.class)
- .withFunctionName("daemon", DaemonStream.class)
- .withFunctionName("sort", SortStream.class)
-
- // graph streams
- .withFunctionName("shortestPath", ShortestPathStream.class)
+ .withFunctionName("sort", SortStream.class)
+ .withFunctionName("daemon", DaemonStream.class)
+ .withFunctionName("shortestPath", ShortestPathStream.class)
+ .withFunctionName("gatherNodes", GatherNodesStream.class)
// metrics
.withFunctionName("min", MinMetric.class)
@@ -274,6 +274,14 @@ public class StreamHandler extends RequestHandlerBase implements SolrCoreAware,
public Tuple read() {
String msg = e.getMessage();
+
+ Throwable t = e.getCause();
+ while(t != null) {
+ msg = t.getMessage();
+ t = t.getCause();
+ }
+
+
Map m = new HashMap();
m.put("EOF", true);
m.put("EXCEPTION", msg);
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/8659ea33/solr/solrj/src/java/org/apache/solr/client/solrj/io/graph/GatherNodesStream.java
----------------------------------------------------------------------
diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/graph/GatherNodesStream.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/graph/GatherNodesStream.java
new file mode 100644
index 0000000..759aa0f
--- /dev/null
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/graph/GatherNodesStream.java
@@ -0,0 +1,580 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.solr.client.solrj.io.graph;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Locale;
+import java.util.Map;
+import java.util.Set;
+import java.util.HashSet;
+import java.util.ArrayList;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Future;
+
+import org.apache.solr.client.solrj.io.eq.MultipleFieldEqualitor;
+import org.apache.solr.client.solrj.io.stream.*;
+import org.apache.solr.client.solrj.io.stream.metrics.*;
+import org.apache.solr.client.solrj.io.Tuple;
+import org.apache.solr.client.solrj.io.comp.StreamComparator;
+import org.apache.solr.client.solrj.io.eq.FieldEqualitor;
+import org.apache.solr.client.solrj.io.stream.expr.Expressible;
+import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
+import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionNamedParameter;
+import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionParameter;
+import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionValue;
+import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
+import org.apache.solr.common.util.ExecutorUtil;
+import org.apache.solr.common.util.SolrjNamedThreadFactory;
+
+public class GatherNodesStream extends TupleStream implements Expressible {
+
+ private String zkHost;
+ private String collection;
+ private StreamContext streamContext;
+ private Map queryParams;
+ private String traverseFrom;
+ private String traverseTo;
+ private String gather;
+ private boolean trackTraversal;
+ private boolean useDefaultTraversal;
+
+ private TupleStream tupleStream;
+ private Set<Traversal.Scatter> scatter;
+ private Iterator<Tuple> out;
+ private Traversal traversal;
+ private List<Metric> metrics;
+
+  /**
+   * Creates a gatherNodes stream from already-resolved arguments.
+   *
+   * <p>The arguments are handed to {@code init} unchanged; nothing is validated here.
+   */
+  public GatherNodesStream(String zkHost,
+                           String collection,
+                           TupleStream tupleStream,
+                           String traverseFrom,
+                           String traverseTo,
+                           String gather,
+                           Map queryParams,
+                           List<Metric> metrics,
+                           boolean trackTraversal,
+                           Set<Traversal.Scatter> scatter) {
+    this.init(zkHost, collection, tupleStream, traverseFrom, traverseTo,
+              gather, queryParams, metrics, trackTraversal, scatter);
+  }
+
+  /**
+   * Builds a gatherNodes stream from a parsed streaming expression.
+   *
+   * <p>Operands read from the expression:
+   * <ul>
+   *   <li>first value operand: the collection name (required)</li>
+   *   <li>{@code gather}: the field to gather (required)</li>
+   *   <li>{@code walk}: {@code from->to} (required). When an inner stream expression is present,
+   *       {@code from} names a field of that stream; otherwise it is a comma separated list of
+   *       root node ids which are emitted through the "node" field</li>
+   *   <li>{@code scatter}: comma separated {@code branches} / {@code leaves} (case-insensitive);
+   *       defaults to leaves, unrecognized values are ignored</li>
+   *   <li>{@code trackTraversal}: boolean; when absent {@code useDefaultTraversal} is switched on</li>
+   *   <li>{@code zkHost}: optional, falls back to the factory's collection or default zkHost</li>
+   *   <li>metric expressions; every other named parameter is kept as a query parameter</li>
+   * </ul>
+   *
+   * @param expression the gatherNodes expression to interpret
+   * @param factory    factory used to resolve operands, inner streams and metrics
+   * @throws IOException if a required operand is missing or malformed
+   */
+  public GatherNodesStream(StreamExpression expression, StreamFactory factory) throws IOException {
+
+    String collectionName = factory.getValueOperand(expression, 0);
+    List<StreamExpressionNamedParameter> namedParams = factory.getNamedOperands(expression);
+    StreamExpressionNamedParameter zkHostExpression = factory.getNamedOperand(expression, "zkHost");
+
+    List<StreamExpression> streamExpressions = factory.getExpressionOperandsRepresentingTypes(expression, Expressible.class, TupleStream.class);
+    // Collection Name
+    if(null == collectionName) {
+      throw new IOException(String.format(Locale.ROOT,"invalid expression %s - collectionName expected as first operand",expression));
+    }
+
+    // scatter, optional - defaults to emitting only the leaves
+    Set<Traversal.Scatter> scatter = new HashSet<>();
+    StreamExpressionNamedParameter scatterExpression = factory.getNamedOperand(expression, "scatter");
+    if(scatterExpression == null) {
+      scatter.add(Traversal.Scatter.LEAVES);
+    } else {
+      String s = ((StreamExpressionValue)scatterExpression.getParameter()).getValue();
+      for(String sv : s.split(",")) {
+        sv = sv.trim();
+        if(Traversal.Scatter.BRANCHES.toString().equalsIgnoreCase(sv)) {
+          scatter.add(Traversal.Scatter.BRANCHES);
+        } else if (Traversal.Scatter.LEAVES.toString().equalsIgnoreCase(sv)) {
+          scatter.add(Traversal.Scatter.LEAVES);
+        }
+      }
+    }
+
+    // gather, required
+    StreamExpressionNamedParameter gatherExpression = factory.getNamedOperand(expression, "gather");
+    if(gatherExpression == null) {
+      throw new IOException(String.format(Locale.ROOT,"invalid expression %s - gather param is required",expression));
+    }
+    String gather = ((StreamExpressionValue)gatherExpression.getParameter()).getValue();
+
+    // walk, required - always of the form from->to
+    StreamExpressionNamedParameter edgeExpression = factory.getNamedOperand(expression, "walk");
+    if(edgeExpression == null) {
+      throw new IOException(String.format(Locale.ROOT,"invalid expression %s - walk param is required", expression));
+    }
+
+    String edge = ((StreamExpressionValue) edgeExpression.getParameter()).getValue();
+    String[] fields = edge.split("->");
+    if (fields.length != 2) {
+      throw new IOException(String.format(Locale.ROOT, "invalid expression %s - walk param separated by an -> and must contain two fields", expression));
+    }
+
+    String traverseTo = fields[1].trim();
+    String traverseFrom;
+    TupleStream stream;
+
+    if(streamExpressions.size() > 0) {
+      // The left side of the walk names a field of the inner stream.
+      stream = factory.constructStream(streamExpressions.get(0));
+      traverseFrom = fields[0].trim();
+    } else {
+      // No inner stream: the left side of the walk is a literal list of root node ids.
+      List<String> l = new ArrayList<>();
+      for(String n : fields[0].split(",")) {
+        l.add(n.trim());
+      }
+
+      stream = new NodeStream(l);
+      traverseFrom = "node";
+    }
+
+    List<StreamExpression> metricExpressions = factory.getExpressionOperandsRepresentingTypes(expression, Expressible.class, Metric.class);
+    List<Metric> metrics = new ArrayList<>();
+    for(StreamExpression metricExpression : metricExpressions) {
+      metrics.add(factory.constructMetric(metricExpression));
+    }
+
+    // trackTraversal, optional - when absent the default traversal setting is used
+    boolean trackTraversal = false;
+    StreamExpressionNamedParameter trackExpression = factory.getNamedOperand(expression, "trackTraversal");
+    if(trackExpression != null) {
+      trackTraversal = Boolean.parseBoolean(((StreamExpressionValue) trackExpression.getParameter()).getValue());
+    } else {
+      useDefaultTraversal = true;
+    }
+
+    // Every named parameter that is not consumed above is passed through as a query parameter.
+    Map<String,String> params = new HashMap<String,String>();
+    for(StreamExpressionNamedParameter namedParam : namedParams){
+      String name = namedParam.getName();
+      if(!name.equals("zkHost") &&
+         !name.equals("gather") &&
+         !name.equals("walk") &&
+         !name.equals("scatter") &&
+         !name.equals("trackTraversal"))
+      {
+        params.put(name, namedParam.getParameter().toString().trim());
+      }
+    }
+
+    // zkHost, optional - if not provided then will look into factory list to get
+    String zkHost = null;
+    if(null == zkHostExpression){
+      zkHost = factory.getCollectionZkHost(collectionName);
+      if(zkHost == null) {
+        zkHost = factory.getDefaultZkHost();
+      }
+    } else if(zkHostExpression.getParameter() instanceof StreamExpressionValue) {
+      zkHost = ((StreamExpressionValue)zkHostExpression.getParameter()).getValue();
+    }
+
+    if(null == zkHost){
+      throw new IOException(String.format(Locale.ROOT,"invalid expression %s - zkHost not found for collection '%s'",expression,collectionName));
+    }
+
+    // We've got all the required items
+    init(zkHost,
+         collectionName,
+         stream,
+         traverseFrom,
+         traverseTo,
+         gather,
+         params,
+         metrics,
+         trackTraversal,
+         scatter);
+  }
+
+  /** Copies the given arguments onto the matching instance fields; used by both constructors. */
+  private void init(String zkHost,
+                    String collection,
+                    TupleStream tupleStream,
+                    String traverseFrom,
+                    String traverseTo,
+                    String gather,
+                    Map queryParams,
+                    List<Metric> metrics,
+                    boolean trackTraversal,
+                    Set<Traversal.Scatter> scatter) {
+    // where to query
+    this.zkHost = zkHost;
+    this.collection = collection;
+    this.queryParams = queryParams;
+
+    // what to walk
+    this.tupleStream = tupleStream;
+    this.traverseFrom = traverseFrom;
+    this.traverseTo = traverseTo;
+    this.gather = gather;
+
+    // what to emit
+    this.metrics = metrics;
+    this.trackTraversal = trackTraversal;
+    this.scatter = scatter;
+  }
+
+ @Override
+ public StreamExpressionParameter toExpression(StreamFactory factory) throws IOException {
+
+ StreamExpression expression = new StreamExpression(factory.getFunctionName(this.getClass()));
+
+ // collection
+ expression.addParameter(collection);
+
+ if(tupleStream instanceof Expressible){
+ expression.addParameter(((Expressible)tupleStream).toExpression(factory));
+ }
+ else{
+ throw new IOException("This GatherNodesStream contains a non-expressible TupleStream - it cannot be converted to an expression");
+ }
+
+ Set<Map.Entry> entries = queryParams.entrySet();
+ // parameters
+ for(Map.Entry param : entries){
+ String value = param.getValue().toString();
+
+ // SOLR-8409: This is a special case where the params contain a " character
+ // Do note that in any other BASE streams with parameters where a " might come into play
+ // that this same replacement needs to take place.
+ value = value.replace("\"", "\\\"");
+
+ expression.addParameter(new StreamExpressionNamedParameter(param.getKey().toString(), value));
+ }
+
+ if(metrics != null) {
+ for (Metric metric : metrics) {
+ expression.addParameter(metric.toExpression(factory));
+ }
+ }
+
+ expression.addParameter(new StreamExpressionNamedParameter("zkHost", zkHost));
+ expression.addParameter(new StreamExpressionNamedParameter("gather", zkHost));
+ expression.addParameter(new StreamExpressionNamedParameter("walk", traverseFrom+"->"+traverseTo));
+ expression.addParameter(new StreamExpressionNamedParameter("trackTraversal", Boolean.toString(trackTraversal)));
+
+ StringBuilder buf = new StringBuilder();
+ for(Traversal.Scatter sc : scatter) {
+ if(buf.length() > 0 ) {
+ buf.append(",");
+ }
+ buf.append(sc.toString());
+ }
+
+ expression.addParameter(new StreamExpressionNamedParameter("scatter", buf.toString()));
+
+ return expression;
+ }
+
+  /**
+   * Attaches this stream, and its inner stream, to a stream context.
+   *
+   * <p>If the context already holds a "traversal" entry it is reused together with the context.
+   * Otherwise a private copy of the context is created and a fresh {@link Traversal} is stored
+   * in it, so that separate traversals in one expression do not share traversal state.
+   */
+  public void setStreamContext(StreamContext context) {
+    Traversal inherited = (Traversal) context.get("traversal");
+    this.traversal = inherited;
+
+    if (inherited != null) {
+      // Part of an existing traversal: share the caller's context as-is.
+      this.tupleStream.setStreamContext(context);
+      this.streamContext = context;
+      return;
+    }
+
+    // No traversal in the context yet: clone the context and start a new traversal in the clone.
+    StreamContext isolated = new StreamContext();
+    isolated.numWorkers = context.numWorkers;
+    isolated.workerID = context.workerID;
+    isolated.setSolrClientCache(context.getSolrClientCache());
+    isolated.setStreamFactory(context.getStreamFactory());
+
+    for (Object entryKey : context.getEntries().keySet()) {
+      isolated.put(entryKey, context.get(entryKey));
+    }
+
+    this.traversal = new Traversal();
+    isolated.put("traversal", this.traversal);
+
+    this.tupleStream.setStreamContext(isolated);
+    this.streamContext = isolated;
+  }
+
+  /** Returns a new list whose only element is the inner tuple stream. */
+  public List<TupleStream> children() {
+    List<TupleStream> inner = new ArrayList<TupleStream>();
+    inner.add(this.tupleStream);
+    return inner;
+  }
+
+ /**
+  * Opens the inner tuple stream. The traversal itself is not run here; it is performed
+  * by the first call to read().
+  */
+ public void open() throws IOException {
+ tupleStream.open();
+ }
+
+  /**
+   * Fetches the edges for one batch of node ids.
+   *
+   * <p>Runs a query against the /export handler of the collection that matches the batch on the
+   * {@code traverseTo} field and collects the distinct (gather, traverseTo) tuples.
+   */
+  private class JoinRunner implements Callable<List<Tuple>> {
+
+    private List<String> nodes;
+    private List<Tuple> edges = new ArrayList<Tuple>();
+
+    public JoinRunner(List<String> nodes) {
+      this.nodes = nodes;
+    }
+
+    /**
+     * Executes the batch query and returns every edge tuple read from it.
+     *
+     * @throws RuntimeException wrapping any failure while building, opening, reading or closing the stream
+     */
+    public List<Tuple> call() {
+
+      // Field list: the two join fields, the metric columns and any user supplied fl.
+      Set<String> flSet = new HashSet<String>();
+      flSet.add(gather);
+      flSet.add(traverseTo);
+
+      if(metrics != null) {
+        for(Metric metric : metrics) {
+          for(String column : metric.getColumns()) {
+            flSet.add(column);
+          }
+        }
+      }
+
+      if(queryParams.containsKey("fl")) {
+        String flString = (String)queryParams.get("fl");
+        for(String f : flString.split(",")) {
+          flSet.add(f.trim());
+        }
+      }
+
+      StringBuilder fl = new StringBuilder();
+      for(String field : flSet) {
+        if(fl.length() > 0) {
+          fl.append(",");
+        }
+        fl.append(field);
+      }
+
+      // User params first so that the values below always win.
+      Map<String, Object> joinParams = new HashMap<String, Object>();
+      joinParams.putAll(queryParams);
+      joinParams.put("fl", fl.toString());
+      joinParams.put("qt", "/export");
+      joinParams.put("sort", gather + " asc,"+traverseTo +" asc");
+
+      StringBuilder nodeQuery = new StringBuilder();
+      for(String node : nodes) {
+        nodeQuery.append(node).append(" ");
+      }
+
+      // NOTE(review): node ids are placed into the query unescaped - confirm ids cannot contain query syntax.
+      String q = traverseTo + ":(" + nodeQuery.toString().trim() + ")";
+      joinParams.put("q", q);
+
+      TupleStream stream = null;
+      RuntimeException failure = null;
+      try {
+        stream = new UniqueStream(new CloudSolrStream(zkHost, collection, joinParams), new MultipleFieldEqualitor(new FieldEqualitor(gather), new FieldEqualitor(traverseTo)));
+        stream.setStreamContext(streamContext);
+        stream.open();
+        while (true) {
+          Tuple tuple = stream.read();
+          if (tuple.EOF) {
+            break;
+          }
+
+          edges.add(tuple);
+        }
+      } catch (Exception e) {
+        failure = new RuntimeException(e);
+        throw failure;
+      } finally {
+        // stream is still null when its construction failed; there is nothing to close then.
+        if (stream != null) {
+          try {
+            stream.close();
+          } catch(Exception ce) {
+            if (failure != null) {
+              // Keep the original failure as the primary exception.
+              failure.addSuppressed(ce);
+            } else {
+              throw new RuntimeException(ce);
+            }
+          }
+        }
+      }
+      return edges;
+    }
+  }
+
+
+ /** Closes the inner tuple stream. */
+ public void close() throws IOException {
+ tupleStream.close();
+ }
+
+  /**
+   * Returns the next gathered node, or an EOF tuple once all nodes have been emitted.
+   *
+   * <p>The first call performs the whole step of the traversal: it drains the inner stream,
+   * joins its {@code traverseFrom} values against the collection in batches of 400 on a pool of
+   * 4 threads, folds the resulting edges into a new level of the shared {@link Traversal} and
+   * obtains the output iterator from it. Later calls only advance that iterator.
+   *
+   * @throws RuntimeException wrapping any failure raised while performing the traversal step
+   */
+  public Tuple read() throws IOException {
+
+    if (out == null) {
+      List<String> joinBatch = new ArrayList<String>();
+      List<Future<List<Tuple>>> futures = new ArrayList<Future<List<Tuple>>>();
+      Map<String, Node> level = new HashMap<String, Node>();
+
+      ExecutorService threadPool = null;
+      try {
+        threadPool = ExecutorUtil.newMDCAwareFixedThreadPool(4, new SolrjNamedThreadFactory("GatherNodesStream"));
+
+        Map<String, Node> roots = new HashMap<String, Node>();
+
+        while (true) {
+          Tuple tuple = tupleStream.read();
+          if (tuple.EOF) {
+            // Flush the last, partially filled batch.
+            if (joinBatch.size() > 0) {
+              futures.add(threadPool.submit(new JoinRunner(joinBatch)));
+            }
+            break;
+          }
+
+          String value = tuple.getString(traverseFrom);
+
+          if(traversal.getDepth() == 0) {
+            //This gathers the root nodes
+            //We check to see if there are dupes in the root nodes because root streams may not have been uniqued.
+            String key = collection+"."+value;
+            if(roots.containsKey(key)) {
+              continue;
+            }
+            roots.put(key, newNode(value));
+          }
+
+          joinBatch.add(value);
+          if (joinBatch.size() == 400) {
+            futures.add(threadPool.submit(new JoinRunner(joinBatch)));
+            joinBatch = new ArrayList<String>();
+          }
+        }
+
+        if(traversal.getDepth() == 0) {
+          traversal.addLevel(roots, collection, traverseFrom);
+        }
+
+        this.traversal.setScatter(scatter);
+
+        if(useDefaultTraversal) {
+          this.trackTraversal = traversal.getTrackTraversal();
+        } else {
+          this.traversal.setTrackTraversal(trackTraversal);
+        }
+
+        for (Future<List<Tuple>> future : futures) {
+          for (Tuple tuple : future.get()) {
+            String _traverseTo = tuple.getString(traverseTo);
+            String _gather = tuple.getString(gather);
+            String key = collection + "." + _gather;
+            // visited() records the edge on the existing node when the node was seen on an earlier level.
+            if (traversal.visited(key, _traverseTo, tuple)) {
+              continue;
+            }
+
+            Node node = level.get(key);
+            if (node == null) {
+              node = newNode(_gather);
+              level.put(key, node);
+            }
+            node.add((traversal.getDepth()-1)+"^"+_traverseTo, tuple);
+          }
+        }
+
+        traversal.addLevel(level, collection, gather);
+        out = traversal.iterator();
+      } catch(Exception e) {
+        if (e instanceof InterruptedException) {
+          // Restore the interrupt flag so callers further up can still observe it.
+          Thread.currentThread().interrupt();
+        }
+        throw new RuntimeException(e);
+      } finally {
+        // The pool is still null when its creation failed.
+        if (threadPool != null) {
+          threadPool.shutdown();
+        }
+      }
+    }
+
+    if (out.hasNext()) {
+      return out.next();
+    } else {
+      Map<String, Object> map = new HashMap<String, Object>();
+      map.put("EOF", true);
+      return new Tuple(map);
+    }
+  }
+
+  /** Creates a node for the given id with its own fresh instance of every configured metric. */
+  private Node newNode(String id) {
+    Node node = new Node(id, trackTraversal);
+    if (metrics != null) {
+      List<Metric> nodeMetrics = new ArrayList<Metric>();
+      for (Metric metric : metrics) {
+        nodeMetrics.add(metric.newInstance());
+      }
+      node.setMetrics(nodeMetrics);
+    }
+    return node;
+  }
+
+ /** Always reports a cost of 0. */
+ public int getCost() {
+ return 0;
+ }
+
+ /** Always returns null (presumably meaning the output has no defined sort order - TODO confirm). */
+ @Override
+ public StreamComparator getStreamSort() {
+ return null;
+ }
+
+  /**
+   * In-memory stream over a fixed list of node ids.
+   *
+   * <p>Each id is emitted as a tuple with a single "node" field; once the ids are exhausted
+   * every read returns an EOF tuple. The stream context is ignored.
+   */
+  class NodeStream extends TupleStream {
+
+    private List<String> ids;
+    private Iterator<String> it;
+
+    public NodeStream(List<String> ids) {
+      this.ids = ids;
+    }
+
+    /** Positions the stream at the first id. */
+    public void open() {
+      this.it = ids.iterator();
+    }
+
+    public void close() {
+    }
+
+    public StreamComparator getStreamSort() {
+      return null;
+    }
+
+    public List<TupleStream> children() {
+      return new ArrayList();
+    }
+
+    public void setStreamContext(StreamContext context) {
+    }
+
+    /** Returns the next id wrapped in a tuple, or an EOF tuple when no ids remain. */
+    public Tuple read() {
+      HashMap tupleFields = new HashMap();
+      if(!it.hasNext()) {
+        tupleFields.put("EOF", true);
+      } else {
+        tupleFields.put("node", it.next());
+      }
+      return new Tuple(tupleFields);
+    }
+  }
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/8659ea33/solr/solrj/src/java/org/apache/solr/client/solrj/io/graph/Node.java
----------------------------------------------------------------------
diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/graph/Node.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/graph/Node.java
new file mode 100644
index 0000000..befa5a7
--- /dev/null
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/graph/Node.java
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.solr.client.solrj.io.graph;
+
+import org.apache.solr.client.solrj.io.Tuple;
+import org.apache.solr.client.solrj.io.stream.metrics.*;
+import java.util.*;
+
+/**
+ * A node gathered during a graph traversal.
+ *
+ * <p>A node has an id, optionally a set of ancestor keys and optionally a list of metrics
+ * that are updated with every tuple added to the node.
+ */
+public class Node {
+
+  private String id;
+  private List<Metric> metrics;
+  private Set<String> ancestors;
+
+  /**
+   * @param id    the node id
+   * @param track when true the ancestors passed to {@link #add(String, Tuple)} are remembered
+   */
+  public Node(String id, boolean track) {
+    this.id = id;
+    if (track) {
+      ancestors = new HashSet<String>();
+    }
+  }
+
+  /** Sets the metrics that are updated by {@link #add(String, Tuple)}. */
+  public void setMetrics(List<Metric> metrics) {
+    this.metrics = metrics;
+  }
+
+  /**
+   * Records one incoming edge.
+   *
+   * @param ancestor ancestor key; {@link #toTuple} expects the form {@code <levelIndex>^<ancestorId>}
+   * @param tuple    tuple the metrics are updated with
+   */
+  public void add(String ancestor, Tuple tuple) {
+    if (ancestors != null) {
+      ancestors.add(ancestor);
+    }
+
+    if (metrics != null) {
+      for (Metric metric : metrics) {
+        metric.update(tuple);
+      }
+    }
+  }
+
+  /**
+   * Renders this node as a tuple with the fields "node", "collection", "field" and "level",
+   * plus "ancestors" when ancestors are tracked and one field per metric.
+   *
+   * <p>When the traversal spans more than one collection every ancestor is written as
+   * {@code collection/ancestorId}, otherwise only the ancestor id is written.
+   */
+  public Tuple toTuple(String collection, String field, int level, Traversal traversal) {
+    Map<String, Object> map = new HashMap<String, Object>();
+
+    map.put("node", id);
+    map.put("collection", collection);
+    map.put("field", field);
+    map.put("level", level);
+
+    boolean prependCollection = traversal.isMultiCollection();
+    List<String> cols = traversal.getCollections();
+
+    if (ancestors != null) {
+      List<String> l = new ArrayList<String>(ancestors.size());
+      for (String ancestor : ancestors) {
+        // Split on the first '^' only: the ancestor id itself may contain '^' or be empty.
+        String[] ancestorParts = ancestor.split("\\^", 2);
+
+        if (prependCollection) {
+          //prepend the collection
+          int colIndex = Integer.parseInt(ancestorParts[0]);
+          l.add(cols.get(colIndex)+"/"+ancestorParts[1]);
+        } else {
+          // Use only the ancestor id.
+          l.add(ancestorParts[1]);
+        }
+      }
+
+      map.put("ancestors", l);
+    }
+
+    if (metrics != null) {
+      for (Metric metric : metrics) {
+        map.put(metric.getIdentifier(), metric.getValue());
+      }
+    }
+
+    return new Tuple(map);
+  }
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/8659ea33/solr/solrj/src/java/org/apache/solr/client/solrj/io/graph/Traversal.java
----------------------------------------------------------------------
diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/graph/Traversal.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/graph/Traversal.java
new file mode 100644
index 0000000..43d23b3
--- /dev/null
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/graph/Traversal.java
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.solr.client.solrj.io.graph;
+
+import org.apache.solr.client.solrj.io.Tuple;
+import java.util.*;
+
+public class Traversal {
+
+ private List<Map<String, Node>> graph = new ArrayList();
+ private List<String> fields = new ArrayList();
+ private List<String> collections = new ArrayList();
+ private Set<Scatter> scatter = new HashSet();
+ private Set<String> collectionSet = new HashSet();
+ private boolean trackTraversal;
+ private int depth;
+
+ public void addLevel(Map<String, Node> level, String collection, String field) {
+ graph.add(level);
+ collections.add(collection);
+ collectionSet.add(collection);
+ fields.add(field);
+ ++depth;
+ }
+
+ public int getDepth() {
+ return depth;
+ }
+
+ public boolean getTrackTraversal() {
+ return this.trackTraversal;
+ }
+
+ public boolean visited(String nodeId, String ancestorId, Tuple tuple) {
+ for(Map<String, Node> level : graph) {
+ Node node = level.get(nodeId);
+ if(node != null) {
+ node.add(depth+"^"+ancestorId, tuple);
+ return true;
+ }
+ }
+ return false;
+ }
+
+ public boolean isMultiCollection() {
+ return collectionSet.size() > 1;
+ }
+
+ public List<Map<String, Node>> getGraph() {
+ return graph;
+ }
+
+ public void setScatter(Set<Scatter> scatter) {
+ this.scatter = scatter;
+ }
+
+ public Set<Scatter> getScatter() {
+ return this.scatter;
+ }
+
+ public void setTrackTraversal(boolean trackTraversal) {
+ this.trackTraversal = trackTraversal;
+ }
+
+ public List<String> getCollections() {
+ return this.collections;
+ }
+
+ public List<String> getFields() {
+ return this.fields;
+ }
+
+ public enum Scatter {
+ BRANCHES,
+ LEAVES;
+ }
+
+ public Iterator<Tuple> iterator() {
+ return new TraversalIterator(this, scatter);
+ }
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/8659ea33/solr/solrj/src/java/org/apache/solr/client/solrj/io/graph/TraversalIterator.java
----------------------------------------------------------------------
diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/graph/TraversalIterator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/graph/TraversalIterator.java
new file mode 100644
index 0000000..7cfe375
--- /dev/null
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/graph/TraversalIterator.java
@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.solr.client.solrj.io.graph;
+
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+import org.apache.solr.client.solrj.io.Tuple;
+import org.apache.solr.client.solrj.io.graph.Traversal.Scatter;
+
+/**
+ * Iterates over the nodes of a {@link Traversal} level by level and renders each node as a tuple.
+ *
+ * <p>With {@code BRANCHES} every level except the last is visited, with {@code LEAVES} the last
+ * level is visited; branches come before leaves. When no level is selected, or the graph is
+ * empty, the iterator is simply empty.
+ */
+class TraversalIterator implements Iterator<Tuple> {
+
+  private final Traversal traversal;
+
+  // Parallel iterators: one entry per selected level.
+  private final Iterator<Iterator<Node>> graphIterator;
+  private final Iterator<String> fieldIterator;
+  private final Iterator<String> collectionIterator;
+  private final Iterator<Integer> levelNumIterator;
+
+  // Cursor over the current level; null until the first level has been loaded.
+  private Iterator<Node> levelIterator;
+  private String outField;
+  private String outCollection;
+  private int outLevel;
+
+  public TraversalIterator(Traversal traversal, Set<Scatter> scatter) {
+    this.traversal = traversal;
+    List<Map<String,Node>> graph = traversal.getGraph();
+    List<String> collections = traversal.getCollections();
+    List<String> fields = traversal.getFields();
+
+    List<String> outCollections = new ArrayList<String>();
+    List<String> outFields = new ArrayList<String>();
+    List<Integer> levelNums = new ArrayList<Integer>();
+    List<Iterator<Node>> levelIterators = new ArrayList<Iterator<Node>>();
+
+    if(scatter.contains(Scatter.BRANCHES)) {
+      // Every level but the last one is a branch level.
+      for(int i=0; i<graph.size()-1; i++) {
+        outCollections.add(collections.get(i));
+        outFields.add(fields.get(i));
+        levelNums.add(i);
+        levelIterators.add(graph.get(i).values().iterator());
+      }
+    }
+
+    if(scatter.contains(Scatter.LEAVES) && !graph.isEmpty()) {
+      int leavesLevel = graph.size()-1;
+      outCollections.add(collections.get(leavesLevel));
+      outFields.add(fields.get(leavesLevel));
+      levelNums.add(leavesLevel);
+      levelIterators.add(graph.get(leavesLevel).values().iterator());
+    }
+
+    graphIterator = levelIterators.iterator();
+    fieldIterator = outFields.iterator();
+    collectionIterator = outCollections.iterator();
+    levelNumIterator = levelNums.iterator();
+  }
+
+  /** Moves on to the next non-empty level if needed and reports whether a node is available. */
+  @Override
+  public boolean hasNext() {
+    while(levelIterator == null || !levelIterator.hasNext()) {
+      if(!graphIterator.hasNext()) {
+        return false;
+      }
+      levelIterator = graphIterator.next();
+      outField = fieldIterator.next();
+      outCollection = collectionIterator.next();
+      outLevel = levelNumIterator.next();
+    }
+    return true;
+  }
+
+  /**
+   * Returns the next node as a tuple.
+   *
+   * @throws java.util.NoSuchElementException if the iteration has no more nodes
+   */
+  @Override
+  public Tuple next() {
+    if(!hasNext()) {
+      throw new java.util.NoSuchElementException();
+    }
+    Node node = levelIterator.next();
+    return node.toTuple(outCollection, outField, outLevel, traversal);
+  }
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/8659ea33/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/StreamContext.java
----------------------------------------------------------------------
diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/StreamContext.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/StreamContext.java
index ff0aefa..87e3035 100644
--- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/StreamContext.java
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/StreamContext.java
@@ -49,6 +49,10 @@ public class StreamContext implements Serializable{
this.entries.put(key, value);
}
+ /** Returns the entries map of this context itself (the same instance, not a copy). */
+ public Map getEntries() {
+   return entries;
+ }
+
public void setSolrClientCache(SolrClientCache clientCache) {
this.clientCache = clientCache;
}
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/8659ea33/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/CountMetric.java
----------------------------------------------------------------------
diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/CountMetric.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/CountMetric.java
index 0e19177..445b530 100644
--- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/CountMetric.java
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/CountMetric.java
@@ -49,6 +49,10 @@ public class CountMetric extends Metric implements Serializable {
init(functionName);
}
+
+ /** Returns an empty array: this metric reports no columns. */
+ @Override
+ public String[] getColumns() {
+   return new String[0];
+ }
private void init(String functionName){
setFunctionName(functionName);
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/8659ea33/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/MaxMetric.java
----------------------------------------------------------------------
diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/MaxMetric.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/MaxMetric.java
index 8f2069e..0594bf4 100644
--- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/MaxMetric.java
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/MaxMetric.java
@@ -67,6 +67,11 @@ public class MaxMetric extends Metric implements Serializable {
}
}
+ /** Returns a one-element array holding this metric's columnName. */
+ @Override
+ public String[] getColumns() {
+   return new String[] {columnName};
+ }
+
public void update(Tuple tuple) {
Object o = tuple.get(columnName);
if(o instanceof Double) {
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/8659ea33/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/MeanMetric.java
----------------------------------------------------------------------
diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/MeanMetric.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/MeanMetric.java
index 0a5726c..097e04b 100644
--- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/MeanMetric.java
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/MeanMetric.java
@@ -80,6 +80,11 @@ public class MeanMetric extends Metric implements Serializable {
return new MeanMetric(columnName);
}
+ /** Returns a one-element array holding this metric's columnName. */
+ @Override
+ public String[] getColumns() {
+   return new String[] {columnName};
+ }
+
public double getValue() {
double dcount = (double)count;
if(longSum == 0) {
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/8659ea33/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/Metric.java
----------------------------------------------------------------------
diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/Metric.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/Metric.java
index e732182..07a400a 100644
--- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/Metric.java
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/Metric.java
@@ -54,4 +54,6 @@ public abstract class Metric implements Serializable, Expressible {
public abstract double getValue();
public abstract void update(Tuple tuple);
public abstract Metric newInstance();
+ public abstract String[] getColumns();
+
}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/8659ea33/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/MinMetric.java
----------------------------------------------------------------------
diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/MinMetric.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/MinMetric.java
index 7c6060e..0a56580 100644
--- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/MinMetric.java
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/MinMetric.java
@@ -56,7 +56,12 @@ public class MinMetric extends Metric {
setFunctionName(functionName);
setIdentifier(functionName, "(", columnName, ")");
}
-
+
+
+ /** Returns a one-element array holding this metric's columnName. */
+ @Override
+ public String[] getColumns() {
+   return new String[] {columnName};
+ }
public double getValue() {
if(longMin == Long.MAX_VALUE) {
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/8659ea33/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/SumMetric.java
----------------------------------------------------------------------
diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/SumMetric.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/SumMetric.java
index 805f978..578dae7 100644
--- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/SumMetric.java
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/SumMetric.java
@@ -58,6 +58,11 @@ public class SumMetric extends Metric implements Serializable {
setIdentifier(functionName, "(", columnName, ")");
}
+ /** Returns a one-element array holding this metric's columnName. */
+ @Override
+ public String[] getColumns() {
+   return new String[] {columnName};
+ }
+
public void update(Tuple tuple) {
Object o = tuple.get(columnName);
if(o instanceof Double) {
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/8659ea33/solr/solrj/src/test/org/apache/solr/client/solrj/io/graph/GraphExpressionTest.java
----------------------------------------------------------------------
diff --git a/solr/solrj/src/test/org/apache/solr/client/solrj/io/graph/GraphExpressionTest.java b/solr/solrj/src/test/org/apache/solr/client/solrj/io/graph/GraphExpressionTest.java
index db58a90..b5231e2 100644
--- a/solr/solrj/src/test/org/apache/solr/client/solrj/io/graph/GraphExpressionTest.java
+++ b/solr/solrj/src/test/org/apache/solr/client/solrj/io/graph/GraphExpressionTest.java
@@ -20,6 +20,7 @@ package org.apache.solr.client.solrj.io.graph;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
+import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
@@ -31,8 +32,15 @@ import org.apache.lucene.util.LuceneTestCase;
import org.apache.lucene.util.LuceneTestCase.Slow;
import org.apache.solr.client.solrj.io.SolrClientCache;
import org.apache.solr.client.solrj.io.Tuple;
+import org.apache.solr.client.solrj.io.comp.ComparatorOrder;
+import org.apache.solr.client.solrj.io.comp.FieldComparator;
import org.apache.solr.client.solrj.io.stream.*;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
+import org.apache.solr.client.solrj.io.stream.metrics.CountMetric;
+import org.apache.solr.client.solrj.io.stream.metrics.MaxMetric;
+import org.apache.solr.client.solrj.io.stream.metrics.MeanMetric;
+import org.apache.solr.client.solrj.io.stream.metrics.MinMetric;
+import org.apache.solr.client.solrj.io.stream.metrics.SumMetric;
import org.apache.solr.cloud.AbstractFullDistribZkTestBase;
import org.apache.solr.cloud.AbstractZkTestCase;
import org.apache.solr.common.SolrInputDocument;
@@ -117,6 +125,8 @@ public class GraphExpressionTest extends AbstractFullDistribZkTestBase {
commit();
testShortestPathStream();
+ testGatherNodesStream();
+ testGatherNodesFriendsStream();
}
private void testShortestPathStream() throws Exception {
@@ -265,9 +275,399 @@ public class GraphExpressionTest extends AbstractFullDistribZkTestBase {
commit();
}
+
+ /**
+  * Exercises gatherNodes over basket/product documents: a walk from a single
+  * root, a nested walk with metrics, a walk from a list of roots, and a walk
+  * restricted by a negative filter query. Removes all documents when done.
+  */
+ private void testGatherNodesStream() throws Exception {
+
+   indexr(id, "0", "basket_s", "basket1", "product_s", "product1", "price_f", "20");
+   indexr(id, "1", "basket_s", "basket1", "product_s", "product3", "price_f", "30");
+   indexr(id, "2", "basket_s", "basket1", "product_s", "product5", "price_f", "1");
+   indexr(id, "3", "basket_s", "basket2", "product_s", "product1", "price_f", "2");
+   indexr(id, "4", "basket_s", "basket2", "product_s", "product6", "price_f", "5");
+   indexr(id, "5", "basket_s", "basket2", "product_s", "product7", "price_f", "10");
+   indexr(id, "6", "basket_s", "basket3", "product_s", "product4", "price_f", "20");
+   indexr(id, "7", "basket_s", "basket3", "product_s", "product3", "price_f", "10");
+   indexr(id, "8", "basket_s", "basket3", "product_s", "product1", "price_f", "10");
+   indexr(id, "9", "basket_s", "basket4", "product_s", "product4", "price_f", "40");
+   indexr(id, "10", "basket_s", "basket4", "product_s", "product3", "price_f", "10");
+   indexr(id, "11", "basket_s", "basket4", "product_s", "product1", "price_f", "10");
+
+   commit();
+
+   List<Tuple> tuples = null;
+   GatherNodesStream stream = null;
+   StreamContext context = new StreamContext();
+   SolrClientCache cache = new SolrClientCache();
+   context.setSolrClientCache(cache);
+
+   // Close the client cache even when an assertion or stream call fails,
+   // so a failing test does not leave its clients open.
+   try {
+     StreamFactory factory = new StreamFactory()
+         .withCollectionZkHost("collection1", zkServer.getZkAddress())
+         .withFunctionName("gatherNodes", GatherNodesStream.class)
+         .withFunctionName("search", CloudSolrStream.class)
+         .withFunctionName("count", CountMetric.class)
+         .withFunctionName("avg", MeanMetric.class)
+         .withFunctionName("sum", SumMetric.class)
+         .withFunctionName("min", MinMetric.class)
+         .withFunctionName("max", MaxMetric.class);
+
+     //Test a single root node
+     String expr = "gatherNodes(collection1, " +
+         "walk=\"product1->product_s\"," +
+         "gather=\"basket_s\")";
+
+     stream = (GatherNodesStream)factory.constructStream(expr);
+     stream.setStreamContext(context);
+
+     tuples = getTuples(stream);
+
+     Collections.sort(tuples, new FieldComparator("node", ComparatorOrder.ASCENDING));
+     assertTrue(tuples.size() == 4);
+     assertTrue(tuples.get(0).getString("node").equals("basket1"));
+     assertTrue(tuples.get(1).getString("node").equals("basket2"));
+     assertTrue(tuples.get(2).getString("node").equals("basket3"));
+     assertTrue(tuples.get(3).getString("node").equals("basket4"));
+
+     //Test a nested traversal with metrics
+     String expr2 = "gatherNodes(collection1, " +
+         expr+","+
+         "walk=\"node->basket_s\"," +
+         "gather=\"product_s\", count(*), avg(price_f), sum(price_f), min(price_f), max(price_f))";
+
+     stream = (GatherNodesStream)factory.constructStream(expr2);
+
+     context = new StreamContext();
+     context.setSolrClientCache(cache);
+
+     stream.setStreamContext(context);
+
+     tuples = getTuples(stream);
+
+     Collections.sort(tuples, new FieldComparator("node", ComparatorOrder.ASCENDING));
+
+     assertTrue(tuples.size() == 5);
+
+     assertTrue(tuples.get(0).getString("node").equals("product3"));
+     assertTrue(tuples.get(0).getDouble("count(*)").equals(3.0D));
+
+     assertTrue(tuples.get(1).getString("node").equals("product4"));
+     assertTrue(tuples.get(1).getDouble("count(*)").equals(2.0D));
+     assertTrue(tuples.get(1).getDouble("avg(price_f)").equals(30.0D));
+     assertTrue(tuples.get(1).getDouble("sum(price_f)").equals(60.0D));
+     assertTrue(tuples.get(1).getDouble("min(price_f)").equals(20.0D));
+     assertTrue(tuples.get(1).getDouble("max(price_f)").equals(40.0D));
+
+     assertTrue(tuples.get(2).getString("node").equals("product5"));
+     assertTrue(tuples.get(2).getDouble("count(*)").equals(1.0D));
+     assertTrue(tuples.get(3).getString("node").equals("product6"));
+     assertTrue(tuples.get(3).getDouble("count(*)").equals(1.0D));
+     assertTrue(tuples.get(4).getString("node").equals("product7"));
+     assertTrue(tuples.get(4).getDouble("count(*)").equals(1.0D));
+
+     //Test list of root nodes
+     expr = "gatherNodes(collection1, " +
+         "walk=\"product4, product7->product_s\"," +
+         "gather=\"basket_s\")";
+
+     stream = (GatherNodesStream)factory.constructStream(expr);
+
+     context = new StreamContext();
+     context.setSolrClientCache(cache);
+     stream.setStreamContext(context);
+     tuples = getTuples(stream);
+     Collections.sort(tuples, new FieldComparator("node", ComparatorOrder.ASCENDING));
+     assertTrue(tuples.size() == 3);
+     assertTrue(tuples.get(0).getString("node").equals("basket2"));
+     assertTrue(tuples.get(1).getString("node").equals("basket3"));
+     assertTrue(tuples.get(2).getString("node").equals("basket4"));
+
+     //Test with negative filter query
+
+     expr = "gatherNodes(collection1, " +
+         "walk=\"product4, product7->product_s\"," +
+         "gather=\"basket_s\", fq=\"-basket_s:basket4\")";
+
+     stream = (GatherNodesStream)factory.constructStream(expr);
+
+     context = new StreamContext();
+     context.setSolrClientCache(cache);
+     stream.setStreamContext(context);
+     tuples = getTuples(stream);
+
+     Collections.sort(tuples, new FieldComparator("node", ComparatorOrder.ASCENDING));
+     assertTrue(tuples.size() == 2);
+     assertTrue(tuples.get(0).getString("node").equals("basket2"));
+     assertTrue(tuples.get(1).getString("node").equals("basket3"));
+   } finally {
+     cache.close();
+   }
+
+   del("*:*");
+   commit();
+ }
+
+ /**
+  * Exercises gatherNodes over from/to message documents: a single-root walk,
+  * scatter of branches and leaves with trackTraversal, a search stream as the
+  * root, nested traversals, two traversals inside one hashJoin, and ancestor
+  * tracking once edges back to the root are added. Removes all documents when done.
+  */
+ private void testGatherNodesFriendsStream() throws Exception {
+
+   indexr(id, "0", "from_s", "bill", "to_s", "jim", "message_t", "Hello jim");
+   indexr(id, "1", "from_s", "bill", "to_s", "sam", "message_t", "Hello sam");
+   indexr(id, "2", "from_s", "bill", "to_s", "max", "message_t", "Hello max");
+   indexr(id, "3", "from_s", "max", "to_s", "kip", "message_t", "Hello kip");
+   indexr(id, "4", "from_s", "sam", "to_s", "steve", "message_t", "Hello steve");
+   indexr(id, "5", "from_s", "jim", "to_s", "ann", "message_t", "Hello steve");
+
+   commit();
+
+   List<Tuple> tuples = null;
+   GatherNodesStream stream = null;
+   StreamContext context = new StreamContext();
+   SolrClientCache cache = new SolrClientCache();
+   context.setSolrClientCache(cache);
+
+   // Close the client cache even when an assertion or stream call fails,
+   // so a failing test does not leave its clients open.
+   try {
+     StreamFactory factory = new StreamFactory()
+         .withCollectionZkHost("collection1", zkServer.getZkAddress())
+         .withFunctionName("gatherNodes", GatherNodesStream.class)
+         .withFunctionName("search", CloudSolrStream.class)
+         .withFunctionName("count", CountMetric.class)
+         .withFunctionName("hashJoin", HashJoinStream.class)
+         .withFunctionName("avg", MeanMetric.class)
+         .withFunctionName("sum", SumMetric.class)
+         .withFunctionName("min", MinMetric.class)
+         .withFunctionName("max", MaxMetric.class);
+
+     String expr = "gatherNodes(collection1, " +
+         "walk=\"bill->from_s\"," +
+         "gather=\"to_s\")";
+
+     stream = (GatherNodesStream)factory.constructStream(expr);
+     stream.setStreamContext(context);
+
+     tuples = getTuples(stream);
+
+     Collections.sort(tuples, new FieldComparator("node", ComparatorOrder.ASCENDING));
+     assertTrue(tuples.size() == 3);
+     assertTrue(tuples.get(0).getString("node").equals("jim"));
+     assertTrue(tuples.get(1).getString("node").equals("max"));
+     assertTrue(tuples.get(2).getString("node").equals("sam"));
+
+     //Test scatter branches, leaves and trackTraversal
+
+     expr = "gatherNodes(collection1, " +
+         "walk=\"bill->from_s\"," +
+         "gather=\"to_s\","+
+         "scatter=\"branches, leaves\", trackTraversal=\"true\")";
+
+     stream = (GatherNodesStream)factory.constructStream(expr);
+     context = new StreamContext();
+     context.setSolrClientCache(cache);
+     stream.setStreamContext(context);
+
+     tuples = getTuples(stream);
+
+     Collections.sort(tuples, new FieldComparator("node", ComparatorOrder.ASCENDING));
+     assertTrue(tuples.size() == 4);
+     assertTrue(tuples.get(0).getString("node").equals("bill"));
+     assertTrue(tuples.get(0).getLong("level").equals(new Long(0)));
+     assertTrue(tuples.get(0).getStrings("ancestors").size() == 0);
+     assertTrue(tuples.get(1).getString("node").equals("jim"));
+     assertTrue(tuples.get(1).getLong("level").equals(new Long(1)));
+     // assertTrue rather than the assert keyword, so the checks run even without -ea.
+     List<String> ancestors = tuples.get(1).getStrings("ancestors");
+     assertTrue(ancestors.size() == 1);
+     assertTrue(ancestors.get(0).equals("bill"));
+
+     assertTrue(tuples.get(2).getString("node").equals("max"));
+     assertTrue(tuples.get(2).getLong("level").equals(new Long(1)));
+     ancestors = tuples.get(2).getStrings("ancestors");
+     assertTrue(ancestors.size() == 1);
+     assertTrue(ancestors.get(0).equals("bill"));
+
+     assertTrue(tuples.get(3).getString("node").equals("sam"));
+     assertTrue(tuples.get(3).getLong("level").equals(new Long(1)));
+     ancestors = tuples.get(3).getStrings("ancestors");
+     assertTrue(ancestors.size() == 1);
+     assertTrue(ancestors.get(0).equals("bill"));
+
+     // Test query root
+
+     expr = "gatherNodes(collection1, " +
+         "search(collection1, q=\"message_t:jim\", fl=\"from_s\", sort=\"from_s asc\"),"+
+         "walk=\"from_s->from_s\"," +
+         "gather=\"to_s\")";
+
+     stream = (GatherNodesStream)factory.constructStream(expr);
+     context = new StreamContext();
+     context.setSolrClientCache(cache);
+     stream.setStreamContext(context);
+
+     tuples = getTuples(stream);
+
+     Collections.sort(tuples, new FieldComparator("node", ComparatorOrder.ASCENDING));
+     assertTrue(tuples.size() == 3);
+     assertTrue(tuples.get(0).getString("node").equals("jim"));
+     assertTrue(tuples.get(1).getString("node").equals("max"));
+     assertTrue(tuples.get(2).getString("node").equals("sam"));
+
+     // Test query root scatter branches
+
+     expr = "gatherNodes(collection1, " +
+         "search(collection1, q=\"message_t:jim\", fl=\"from_s\", sort=\"from_s asc\"),"+
+         "walk=\"from_s->from_s\"," +
+         "gather=\"to_s\", scatter=\"branches, leaves\")";
+
+     stream = (GatherNodesStream)factory.constructStream(expr);
+     context = new StreamContext();
+     context.setSolrClientCache(cache);
+     stream.setStreamContext(context);
+
+     tuples = getTuples(stream);
+
+     Collections.sort(tuples, new FieldComparator("node", ComparatorOrder.ASCENDING));
+     assertTrue(tuples.size() == 4);
+     assertTrue(tuples.get(0).getString("node").equals("bill"));
+     assertTrue(tuples.get(0).getLong("level").equals(new Long(0)));
+     assertTrue(tuples.get(1).getString("node").equals("jim"));
+     assertTrue(tuples.get(1).getLong("level").equals(new Long(1)));
+     assertTrue(tuples.get(2).getString("node").equals("max"));
+     assertTrue(tuples.get(2).getLong("level").equals(new Long(1)));
+     assertTrue(tuples.get(3).getString("node").equals("sam"));
+     assertTrue(tuples.get(3).getLong("level").equals(new Long(1)));
+
+     // Test a nested traversal
+
+     expr = "gatherNodes(collection1, " +
+         "search(collection1, q=\"message_t:jim\", fl=\"from_s\", sort=\"from_s asc\"),"+
+         "walk=\"from_s->from_s\"," +
+         "gather=\"to_s\")";
+
+     String expr2 = "gatherNodes(collection1, " +
+         expr+","+
+         "walk=\"node->from_s\"," +
+         "gather=\"to_s\")";
+
+     stream = (GatherNodesStream)factory.constructStream(expr2);
+     context = new StreamContext();
+     context.setSolrClientCache(cache);
+     stream.setStreamContext(context);
+
+     tuples = getTuples(stream);
+     Collections.sort(tuples, new FieldComparator("node", ComparatorOrder.ASCENDING));
+
+     assertTrue(tuples.size() == 3);
+     assertTrue(tuples.get(0).getString("node").equals("ann"));
+     assertTrue(tuples.get(1).getString("node").equals("kip"));
+     assertTrue(tuples.get(2).getString("node").equals("steve"));
+
+     //Test two traversals in the same expression
+     String expr3 = "hashJoin("+expr2+", hashed="+expr2+", on=\"node\")";
+
+     HashJoinStream hstream = (HashJoinStream)factory.constructStream(expr3);
+     context = new StreamContext();
+     context.setSolrClientCache(cache);
+     hstream.setStreamContext(context);
+
+     tuples = getTuples(hstream);
+     Collections.sort(tuples, new FieldComparator("node", ComparatorOrder.ASCENDING));
+
+     assertTrue(tuples.size() == 3);
+     assertTrue(tuples.get(0).getString("node").equals("ann"));
+     assertTrue(tuples.get(1).getString("node").equals("kip"));
+     assertTrue(tuples.get(2).getString("node").equals("steve"));
+
+     //=================================
+
+     // Test a nested traversal that scatters branches and leaves
+
+     expr = "gatherNodes(collection1, " +
+         "search(collection1, q=\"message_t:jim\", fl=\"from_s\", sort=\"from_s asc\"),"+
+         "walk=\"from_s->from_s\"," +
+         "gather=\"to_s\")";
+
+     expr2 = "gatherNodes(collection1, " +
+         expr+","+
+         "walk=\"node->from_s\"," +
+         "gather=\"to_s\", scatter=\"branches, leaves\")";
+
+     stream = (GatherNodesStream)factory.constructStream(expr2);
+     context = new StreamContext();
+     context.setSolrClientCache(cache);
+     stream.setStreamContext(context);
+
+     tuples = getTuples(stream);
+     Collections.sort(tuples, new FieldComparator("node", ComparatorOrder.ASCENDING));
+
+     assertTrue(tuples.size() == 7);
+     assertTrue(tuples.get(0).getString("node").equals("ann"));
+     assertTrue(tuples.get(0).getLong("level").equals(new Long(2)));
+     assertTrue(tuples.get(1).getString("node").equals("bill"));
+     assertTrue(tuples.get(1).getLong("level").equals(new Long(0)));
+     assertTrue(tuples.get(2).getString("node").equals("jim"));
+     assertTrue(tuples.get(2).getLong("level").equals(new Long(1)));
+     assertTrue(tuples.get(3).getString("node").equals("kip"));
+     assertTrue(tuples.get(3).getLong("level").equals(new Long(2)));
+     assertTrue(tuples.get(4).getString("node").equals("max"));
+     assertTrue(tuples.get(4).getLong("level").equals(new Long(1)));
+     assertTrue(tuples.get(5).getString("node").equals("sam"));
+     assertTrue(tuples.get(5).getLong("level").equals(new Long(1)));
+     assertTrue(tuples.get(6).getString("node").equals("steve"));
+     assertTrue(tuples.get(6).getLong("level").equals(new Long(2)));
+
+     //Add cycles from jim and sam back to bill
+     indexr(id, "6", "from_s", "jim", "to_s", "bill", "message_t", "Hello steve");
+     indexr(id, "7", "from_s", "sam", "to_s", "bill", "message_t", "Hello steve");
+
+     commit();
+
+     expr = "gatherNodes(collection1, " +
+         "search(collection1, q=\"message_t:jim\", fl=\"from_s\", sort=\"from_s asc\"),"+
+         "walk=\"from_s->from_s\"," +
+         "gather=\"to_s\", trackTraversal=\"true\")";
+
+     expr2 = "gatherNodes(collection1, " +
+         expr+","+
+         "walk=\"node->from_s\"," +
+         "gather=\"to_s\", scatter=\"branches, leaves\", trackTraversal=\"true\")";
+
+     stream = (GatherNodesStream)factory.constructStream(expr2);
+     context = new StreamContext();
+     context.setSolrClientCache(cache);
+     stream.setStreamContext(context);
+
+     tuples = getTuples(stream);
+     Collections.sort(tuples, new FieldComparator("node", ComparatorOrder.ASCENDING));
+
+     assertTrue(tuples.size() == 7);
+     assertTrue(tuples.get(0).getString("node").equals("ann"));
+     assertTrue(tuples.get(0).getLong("level").equals(new Long(2)));
+     //Bill should now have two ancestors (jim and sam)
+     assertTrue(tuples.get(1).getString("node").equals("bill"));
+     assertTrue(tuples.get(1).getLong("level").equals(new Long(0)));
+     assertTrue(tuples.get(1).getStrings("ancestors").size() == 2);
+     List<String> anc = tuples.get(1).getStrings("ancestors");
+
+     // Sort so the two ancestors can be checked in a fixed order.
+     Collections.sort(anc);
+     assertTrue(anc.get(0).equals("jim"));
+     assertTrue(anc.get(1).equals("sam"));
+
+     assertTrue(tuples.get(2).getString("node").equals("jim"));
+     assertTrue(tuples.get(2).getLong("level").equals(new Long(1)));
+     assertTrue(tuples.get(3).getString("node").equals("kip"));
+     assertTrue(tuples.get(3).getLong("level").equals(new Long(2)));
+     assertTrue(tuples.get(4).getString("node").equals("max"));
+     assertTrue(tuples.get(4).getLong("level").equals(new Long(1)));
+     assertTrue(tuples.get(5).getString("node").equals("sam"));
+     assertTrue(tuples.get(5).getLong("level").equals(new Long(1)));
+     assertTrue(tuples.get(6).getString("node").equals("steve"));
+     assertTrue(tuples.get(6).getLong("level").equals(new Long(2)));
+   } finally {
+     cache.close();
+   }
+
+   del("*:*");
+   commit();
+ }
+
+
+
protected List<Tuple> getTuples(TupleStream tupleStream) throws IOException {
tupleStream.open();
- List<Tuple> tuples = new ArrayList<Tuple>();
+ List<Tuple> tuples = new ArrayList();
for(Tuple t = tupleStream.read(); !t.EOF; t = tupleStream.read()) {
tuples.add(t);
}