You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by mb...@apache.org on 2020/07/25 18:29:45 UTC
[systemds] branch master updated: [SYSTEMDS-2589,2510] Removed dependency commons.collections
This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new 5d00ea1 [SYSTEMDS-2589,2510] Removed dependency commons.collections
5d00ea1 is described below
commit 5d00ea1732e66e5bbcc32dbd065b6fe2b8de0cdf
Author: Matthias Boehm <mb...@gmail.com>
AuthorDate: Sat Jul 25 20:27:03 2020 +0200
[SYSTEMDS-2589,2510] Removed dependency commons.collections
This patch removes the rarely used (and thus unnecessary) dependency on
commons.collections to simplify packaging/installation in various
environments. The needed functionality has been re-implemented from
scratch, which fixes previously undetected bugs (list comparisons that
ignored element order) and improves performance (containsAny,
constant-time lookups) through specialization.
Furthermore, this patch replaces all occurrences of 'blacklist' with
more inclusive language (e.g., 'excludeList').
---
pom.xml | 6 -
.../java/org/apache/sysds/api/jmlc/Connection.java | 4 +-
.../java/org/apache/sysds/api/jmlc/JMLCUtils.java | 6 +-
.../sysds/hops/codegen/cplan/CNodeMultiAgg.java | 4 +-
.../codegen/opt/PlanSelectionFuseCostBased.java | 13 +-
.../codegen/opt/PlanSelectionFuseCostBasedV2.java | 37 +++---
.../sysds/hops/codegen/opt/ReachabilityGraph.java | 2 +-
.../hops/codegen/template/CPlanMemoTable.java | 25 ++--
.../apache/sysds/hops/ipa/FunctionCallGraph.java | 8 +-
.../RewriteMatrixMultChainOptimization.java | 2 +-
.../ParameterizedBuiltinFunctionExpression.java | 8 +-
.../runtime/controlprogram/LocalVariableMap.java | 8 +-
.../runtime/controlprogram/ParForProgramBlock.java | 9 +-
.../parfor/CachedReuseVariables.java | 6 +-
.../parfor/RemoteParForSparkWorker.java | 6 +-
.../runtime/transform/decode/DecoderFactory.java | 10 +-
.../runtime/transform/encode/EncoderFactory.java | 13 +-
.../sysds/runtime/transform/meta/TfMetaUtils.java | 6 +-
.../apache/sysds/runtime/util/CollectionUtils.java | 144 +++++++++++++++++++++
.../sysds/runtime/util/ProgramConverter.java | 6 +-
.../apache/sysds/runtime/util/UtilFunctions.java | 51 --------
.../test/functions/frame/FrameConverterTest.java | 5 +-
.../transform/TransformCSVFrameEncodeReadTest.java | 2 +-
.../transform/TransformFrameEncodeApplyTest.java | 2 +-
24 files changed, 233 insertions(+), 150 deletions(-)
diff --git a/pom.xml b/pom.xml
index c32ffad..063e532 100644
--- a/pom.xml
+++ b/pom.xml
@@ -979,12 +979,6 @@
<scope>test</scope>
</dependency>
- <dependency>
- <groupId>commons-collections</groupId>
- <artifactId>commons-collections</artifactId>
- <version>3.2.1</version>
- </dependency>
-
<!-- fast java compiler for codegen, consistent version w/ spark -->
<dependency>
<groupId>org.codehaus.janino</groupId>
diff --git a/src/main/java/org/apache/sysds/api/jmlc/Connection.java b/src/main/java/org/apache/sysds/api/jmlc/Connection.java
index 22bdac8..44e9114 100644
--- a/src/main/java/org/apache/sysds/api/jmlc/Connection.java
+++ b/src/main/java/org/apache/sysds/api/jmlc/Connection.java
@@ -63,8 +63,8 @@ import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.transform.TfUtils;
import org.apache.sysds.runtime.transform.meta.TfMetaUtils;
+import org.apache.sysds.runtime.util.CollectionUtils;
import org.apache.sysds.runtime.util.DataConverter;
-import org.apache.sysds.runtime.util.UtilFunctions;
/**
* Interaction with SystemDS using the JMLC (Java Machine Learning Connector) API is initiated with
@@ -241,7 +241,7 @@ public class Connection implements Closeable
throw new LanguageException("Invalid argument names: "+Arrays.toString(invalidArgs));
//check for valid names of input and output variables
- String[] invalidVars = UtilFunctions.asSet(inputs, outputs).stream()
+ String[] invalidVars = CollectionUtils.asSet(inputs, outputs).stream()
.filter(k -> k==null || k.startsWith("$")).toArray(String[]::new);
if( invalidVars.length > 0 )
throw new LanguageException("Invalid variable names: "+Arrays.toString(invalidVars));
diff --git a/src/main/java/org/apache/sysds/api/jmlc/JMLCUtils.java b/src/main/java/org/apache/sysds/api/jmlc/JMLCUtils.java
index 0d9aa60..878a5fc 100644
--- a/src/main/java/org/apache/sysds/api/jmlc/JMLCUtils.java
+++ b/src/main/java/org/apache/sysds/api/jmlc/JMLCUtils.java
@@ -51,16 +51,16 @@ public class JMLCUtils
*/
public static void cleanupRuntimeProgram( Program prog, String[] outputs) {
Map<String, FunctionProgramBlock> funcMap = prog.getFunctionProgramBlocks();
- HashSet<String> blacklist = new HashSet<>(Arrays.asList(outputs));
+ HashSet<String> excludeList = new HashSet<>(Arrays.asList(outputs));
if( funcMap != null && !funcMap.isEmpty() ) {
for( Entry<String, FunctionProgramBlock> e : funcMap.entrySet() ) {
FunctionProgramBlock fpb = e.getValue();
for( ProgramBlock pb : fpb.getChildBlocks() )
- rCleanupRuntimeProgram(pb, blacklist);
+ rCleanupRuntimeProgram(pb, excludeList);
}
}
for( ProgramBlock pb : prog.getProgramBlocks() )
- rCleanupRuntimeProgram(pb, blacklist);
+ rCleanupRuntimeProgram(pb, excludeList);
}
/**
diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeMultiAgg.java b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeMultiAgg.java
index 7395e89..2a5dec8 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeMultiAgg.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeMultiAgg.java
@@ -22,10 +22,10 @@ package org.apache.sysds.hops.codegen.cplan;
import java.util.ArrayList;
import java.util.Arrays;
-import org.apache.commons.collections.CollectionUtils;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.common.Types.AggOp;
import org.apache.sysds.hops.codegen.SpoofFusedOp.SpoofOutputDimsType;
+import org.apache.sysds.runtime.util.CollectionUtils;
import org.apache.sysds.runtime.util.UtilFunctions;
public class CNodeMultiAgg extends CNodeTpl
@@ -181,7 +181,7 @@ public class CNodeMultiAgg extends CNodeTpl
return false;
CNodeMultiAgg that = (CNodeMultiAgg)o;
return super.equals(o)
- && CollectionUtils.isEqualCollection(_aggOps, that._aggOps)
+ && CollectionUtils.equals(_aggOps, that._aggOps)
&& equalInputReferences(
_outputs, that._outputs, _inputs, that._inputs);
}
diff --git a/src/main/java/org/apache/sysds/hops/codegen/opt/PlanSelectionFuseCostBased.java b/src/main/java/org/apache/sysds/hops/codegen/opt/PlanSelectionFuseCostBased.java
index 9f1df43..e19b5a5 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/opt/PlanSelectionFuseCostBased.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/opt/PlanSelectionFuseCostBased.java
@@ -31,7 +31,6 @@ import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
-import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
@@ -57,6 +56,7 @@ import org.apache.sysds.hops.codegen.template.TemplateBase.TemplateType;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence;
+import org.apache.sysds.runtime.util.CollectionUtils;
import org.apache.sysds.runtime.util.UtilFunctions;
import org.apache.sysds.utils.Statistics;
@@ -330,11 +330,11 @@ public class PlanSelectionFuseCostBased extends PlanSelection
MemoTableEntry me = memo.getBest(hopID, TemplateType.ROW);
if( me.type == TemplateType.ROW && memo.contains(hopID, TemplateType.CELL)
&& isRowTemplateWithoutAgg(memo, memo.getHopRefs().get(hopID), new HashSet<Long>())) {
- List<MemoTableEntry> blacklist = memo.get(hopID, TemplateType.ROW);
- memo.remove(memo.getHopRefs().get(hopID), new HashSet<>(blacklist));
+ List<MemoTableEntry> excludeList = memo.get(hopID, TemplateType.ROW);
+ memo.remove(memo.getHopRefs().get(hopID), new HashSet<>(excludeList));
if( LOG.isTraceEnabled() ) {
LOG.trace("Removed row memo table entries w/o aggregation: "
- + Arrays.toString(blacklist.toArray(new MemoTableEntry[0])));
+ + Arrays.toString(excludeList.toArray(new MemoTableEntry[0])));
}
}
}
@@ -350,7 +350,7 @@ public class PlanSelectionFuseCostBased extends PlanSelection
MemoTableEntry rmEntry = TemplateOuterProduct.dropAlternativePlan(memo, me1, me2);
if( rmEntry != null ) {
memo.remove(memo.getHopRefs().get(hopID), Collections.singleton(rmEntry));
- memo.getPlansBlacklisted().remove(rmEntry.input(rmEntry.getPlanRefIndex()));
+ memo.getPlansExcludeListed().remove(rmEntry.input(rmEntry.getPlanRefIndex()));
if( LOG.isTraceEnabled() )
LOG.trace("Removed dominated outer product memo table entry: " + rmEntry);
}
@@ -838,8 +838,7 @@ public class PlanSelectionFuseCostBased extends PlanSelection
for( Long hopID : _aggregates.keySet() )
ret &= !that._inputAggs.contains(hopID);
//check partial shared reads
- ret &= !CollectionUtils.intersection(
- _fusedInputs, that._fusedInputs).isEmpty();
+ ret &= CollectionUtils.containsAny(_fusedInputs, that._fusedInputs);
//check consistent sizes (result correctness)
Hop in1 = _aggregates.values().iterator().next();
Hop in2 = that._aggregates.values().iterator().next();
diff --git a/src/main/java/org/apache/sysds/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java b/src/main/java/org/apache/sysds/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
index 2f8437e..0b20876 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
@@ -32,7 +32,6 @@ import java.util.List;
import java.util.Map.Entry;
import java.util.stream.Collectors;
-import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.logging.Log;
@@ -70,6 +69,7 @@ import org.apache.sysds.runtime.codegen.LibSpoofPrimitives;
import org.apache.sysds.runtime.controlprogram.caching.LazyWriteBuffer;
import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence;
+import org.apache.sysds.runtime.util.CollectionUtils;
import org.apache.sysds.runtime.util.UtilFunctions;
import org.apache.sysds.utils.Statistics;
@@ -661,17 +661,17 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection
private static HashSet<Long> collectIrreplaceableRowOps(CPlanMemoTable memo, PlanPartition part) {
//get row entries that are (a) reachable from rowwise ops (top down) other than
//operator root nodes, or dependent upon row-wise ops (bottom up)
- HashSet<Long> blacklist = new HashSet<>();
+ HashSet<Long> excludeList = new HashSet<>();
HashSet<Pair<Long, Integer>> visited = new HashSet<>();
for( Long hopID : part.getRoots() ) {
rCollectDependentRowOps(memo.getHopRefs().get(hopID),
- memo, part, blacklist, visited, null, false);
+ memo, part, excludeList, visited, null, false);
}
- return blacklist;
+ return excludeList;
}
private static void rCollectDependentRowOps(Hop hop, CPlanMemoTable memo, PlanPartition part,
- HashSet<Long> blacklist, HashSet<Pair<Long, Integer>> visited, TemplateType type, boolean foundRowOp)
+ HashSet<Long> excludeList, HashSet<Pair<Long, Integer>> visited, TemplateType type, boolean foundRowOp)
{
//avoid redundant evaluation of processed and non-partition nodes
Pair<Long, Integer> key = Pair.of(hop.getHopID(),
@@ -688,9 +688,9 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection
&& memo.contains(hop.getHopID(), TemplateType.ROW)
&& !memo.hasOnlyExactMatches(hop.getHopID(), TemplateType.ROW, TemplateType.CELL);
if( inRow && foundRowOp )
- blacklist.add(hop.getHopID());
+ excludeList.add(hop.getHopID());
if( isRowAggOp(hop, inRow) || diffPlans ) {
- blacklist.add(hop.getHopID());
+ excludeList.add(hop.getHopID());
foundRowOp = true;
}
@@ -699,16 +699,16 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection
boolean lfoundRowOp = foundRowOp && me != null
&& (me.isPlanRef(i) || isImplicitlyFused(hop, i, me.type));
rCollectDependentRowOps(hop.getInput().get(i), memo,
- part, blacklist, visited, me!=null?me.type:null, lfoundRowOp);
+ part, excludeList, visited, me!=null?me.type:null, lfoundRowOp);
}
//process node itself (bottom-up)
- if( !blacklist.contains(hop.getHopID()) ) {
+ if( !excludeList.contains(hop.getHopID()) ) {
for( int i=0; i<hop.getInput().size(); i++ )
if( me != null && me.type == TemplateType.ROW
&& (me.isPlanRef(i) || isImplicitlyFused(hop, i, me.type))
- && blacklist.contains(hop.getInput().get(i).getHopID()) ) {
- blacklist.add(hop.getHopID());
+ && excludeList.contains(hop.getInput().get(i).getHopID()) ) {
+ excludeList.add(hop.getHopID());
}
}
@@ -750,23 +750,23 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection
|| (hop instanceof AggBinaryOp && in.getDim1() <= in.getBlocksize()
&& HopRewriteUtils.isTransposeOperation(in));
if( isSpark && !validNcol ) {
- List<MemoTableEntry> blacklist = memo.get(hopID, TemplateType.ROW);
+ List<MemoTableEntry> excludeList = memo.get(hopID, TemplateType.ROW);
memo.remove(memo.getHopRefs().get(hopID), TemplateType.ROW);
memo.removeAllRefTo(hopID, TemplateType.ROW);
if( LOG.isTraceEnabled() ) {
LOG.trace("Removed row memo table entries w/ violated blocksize constraint ("+hopID+"): "
- + Arrays.toString(blacklist.toArray(new MemoTableEntry[0])));
+ + Arrays.toString(excludeList.toArray(new MemoTableEntry[0])));
}
}
}
}
//prune row aggregates with pure cellwise operations
- //(we determine a blacklist of all operators in a partition that either
+ //(we determine an excludeList of all operators in a partition that either
//depend upon row aggregates or on which row aggregates depend)
- HashSet<Long> blacklist = collectIrreplaceableRowOps(memo, part);
+ HashSet<Long> excludeList = collectIrreplaceableRowOps(memo, part);
for( Long hopID : part.getPartition() ) {
- if( blacklist.contains(hopID) ) continue;
+ if( excludeList.contains(hopID) ) continue;
MemoTableEntry me = memo.getBest(hopID, TemplateType.ROW);
if( me != null && me.type == TemplateType.ROW
&& memo.hasOnlyExactMatches(hopID, TemplateType.ROW, TemplateType.CELL) ) {
@@ -790,7 +790,7 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection
MemoTableEntry rmEntry = TemplateOuterProduct.dropAlternativePlan(memo, me1, me2);
if( rmEntry != null ) {
memo.remove(memo.getHopRefs().get(hopID), Collections.singleton(rmEntry));
- memo.getPlansBlacklisted().remove(rmEntry.input(rmEntry.getPlanRefIndex()));
+ memo.getPlansExcludeListed().remove(rmEntry.input(rmEntry.getPlanRefIndex()));
if( LOG.isTraceEnabled() )
LOG.trace("Removed dominated outer product memo table entry: " + rmEntry);
}
@@ -1309,8 +1309,7 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection
for( Long hopID : _aggregates.keySet() )
ret &= !that._inputAggs.contains(hopID);
//check partial shared reads
- ret &= !CollectionUtils.intersection(
- _fusedInputs, that._fusedInputs).isEmpty();
+ ret &= CollectionUtils.containsAny(_fusedInputs, that._fusedInputs);
//check consistent sizes (result correctness)
Hop in1 = _aggregates.values().iterator().next();
Hop in2 = that._aggregates.values().iterator().next();
diff --git a/src/main/java/org/apache/sysds/hops/codegen/opt/ReachabilityGraph.java b/src/main/java/org/apache/sysds/hops/codegen/opt/ReachabilityGraph.java
index 3ee6adf..e62b98c 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/opt/ReachabilityGraph.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/opt/ReachabilityGraph.java
@@ -27,13 +27,13 @@ import java.util.HashSet;
import java.util.List;
import java.util.stream.Collectors;
-import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang.ArrayUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.codegen.opt.PlanSelection.VisitMarkCost;
import org.apache.sysds.hops.codegen.template.CPlanMemoTable;
import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence;
+import org.apache.sysds.runtime.util.CollectionUtils;
import org.apache.sysds.runtime.util.UtilFunctions;
/**
diff --git a/src/main/java/org/apache/sysds/hops/codegen/template/CPlanMemoTable.java b/src/main/java/org/apache/sysds/hops/codegen/template/CPlanMemoTable.java
index d133dd9..c6db903 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/template/CPlanMemoTable.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/template/CPlanMemoTable.java
@@ -41,6 +41,7 @@ import org.apache.sysds.hops.codegen.opt.InterestingPoint;
import org.apache.sysds.hops.codegen.opt.PlanSelection;
import org.apache.sysds.hops.codegen.template.TemplateBase.CloseType;
import org.apache.sysds.hops.codegen.template.TemplateBase.TemplateType;
+import org.apache.sysds.runtime.util.CollectionUtils;
import org.apache.sysds.runtime.util.UtilFunctions;
public class CPlanMemoTable
@@ -49,20 +50,20 @@ public class CPlanMemoTable
protected HashMap<Long, List<MemoTableEntry>> _plans;
protected HashMap<Long, Hop> _hopRefs;
- protected HashSet<Long> _plansBlacklist;
+ protected HashSet<Long> _plansExcludeList;
public CPlanMemoTable() {
_plans = new HashMap<>();
_hopRefs = new HashMap<>();
- _plansBlacklist = new HashSet<>();
+ _plansExcludeList = new HashSet<>();
}
public HashMap<Long, List<MemoTableEntry>> getPlans() {
return _plans;
}
- public HashSet<Long> getPlansBlacklisted() {
- return _plansBlacklist;
+ public HashSet<Long> getPlansExcludeListed() {
+ return _plansExcludeList;
}
public HashMap<Long, Hop> getHopRefs() {
@@ -95,7 +96,7 @@ public class CPlanMemoTable
public boolean contains(long hopID, boolean checkClose, TemplateType... type) {
if( !checkClose && type.length==1 )
return contains(hopID, type[0]);
- Set<TemplateType> probe = UtilFunctions.asSet(type);
+ Set<TemplateType> probe = CollectionUtils.asSet(type);
return contains(hopID) && get(hopID).stream()
.anyMatch(p -> (!checkClose||!p.isClosed()) && probe.contains(p.type));
}
@@ -126,7 +127,7 @@ public class CPlanMemoTable
}
public boolean containsTopLevel(long hopID) {
- return !_plansBlacklist.contains(hopID)
+ return !_plansExcludeList.contains(hopID)
&& getBest(hopID) != null;
}
@@ -161,9 +162,9 @@ public class CPlanMemoTable
_plans.get(hop.getHopID()).addAll(P.plans);
}
- public void remove(Hop hop, Set<MemoTableEntry> blackList) {
+ public void remove(Hop hop, Set<MemoTableEntry> excludeList) {
_plans.get(hop.getHopID())
- .removeIf(p -> blackList.contains(p));
+ .removeIf(p -> excludeList.contains(p));
}
public void remove(Hop hop, TemplateType type) {
@@ -252,14 +253,14 @@ public class CPlanMemoTable
}
//prune dominated plans (e.g., plan referenced by other plan and this
- //other plan is single consumer) by marking it as blacklisted because
+ //other plan is single consumer) by marking it as exclude-listed because
//the chain of entries is still required for cplan construction
if( SpoofCompiler.PLAN_SEL_POLICY.isHeuristic() ) {
for( Entry<Long, List<MemoTableEntry>> e : _plans.entrySet() )
for( MemoTableEntry me : e.getValue() ) {
for( int i=0; i<=2; i++ )
if( me.isPlanRef(i) && _hopRefs.get(me.input(i)).getParent().size()==1 )
- _plansBlacklist.add(me.input(i));
+ _plansExcludeList.add(me.input(i));
}
}
@@ -367,8 +368,8 @@ public class CPlanMemoTable
sb.append(Arrays.toString(e.getValue().toArray(new MemoTableEntry[0]))+"\n");
}
sb.append("----------------------------------\n");
- sb.append("Blacklisted Plans: ");
- sb.append(Arrays.toString(_plansBlacklist.toArray(new Long[0]))+"\n");
+ sb.append("ExcludeListed Plans: ");
+ sb.append(Arrays.toString(_plansExcludeList.toArray(new Long[0]))+"\n");
sb.append("----------------------------------\n");
return sb.toString();
}
diff --git a/src/main/java/org/apache/sysds/hops/ipa/FunctionCallGraph.java b/src/main/java/org/apache/sysds/hops/ipa/FunctionCallGraph.java
index 06c59ec..394cd70 100644
--- a/src/main/java/org/apache/sysds/hops/ipa/FunctionCallGraph.java
+++ b/src/main/java/org/apache/sysds/hops/ipa/FunctionCallGraph.java
@@ -273,14 +273,14 @@ public class FunctionCallGraph
/**
* Returns all functions that are reachable either directly or indirectly
* form the main program, except the main program itself and the given
- * blacklist of function names.
+ * exclude-list of function names.
*
- * @param blacklist list of function keys to exclude
+ * @param excludeList list of function keys to exclude
* @return set of function keys (namespace and name)
*/
- public Set<String> getReachableFunctions(Set<String> blacklist) {
+ public Set<String> getReachableFunctions(Set<String> excludeList) {
return _fGraph.keySet().stream()
- .filter(p -> !blacklist.contains(p) && !MAIN_FUNCTION_KEY.equals(p))
+ .filter(p -> !excludeList.contains(p) && !MAIN_FUNCTION_KEY.equals(p))
.collect(Collectors.toSet());
}
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteMatrixMultChainOptimization.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteMatrixMultChainOptimization.java
index 21cd973..7ae06ca 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteMatrixMultChainOptimization.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteMatrixMultChainOptimization.java
@@ -22,13 +22,13 @@ package org.apache.sysds.hops.rewrite;
import java.util.ArrayList;
import java.util.Arrays;
-import org.apache.commons.collections.CollectionUtils;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.sysds.hops.AggBinaryOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.HopsException;
+import org.apache.sysds.runtime.util.CollectionUtils;
import org.apache.sysds.utils.Explain;
/**
diff --git a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
index 47d4972..2e33676 100644
--- a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
@@ -35,7 +35,7 @@ import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.ParamBuiltinOp;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.parser.LanguageException.LanguageErrorCodes;
-import org.apache.sysds.runtime.util.UtilFunctions;
+import org.apache.sysds.runtime.util.CollectionUtils;
public class ParameterizedBuiltinFunctionExpression extends DataIdentifier
@@ -286,7 +286,7 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier
raiseValidateError("Should provide more arguments for function " + fname, false, LanguageErrorCodes.INVALID_PARAMETERS);
}
//check for invalid parameters
- Set<String> valid = UtilFunctions.asSet(Statement.PS_MODEL, Statement.PS_FEATURES, Statement.PS_LABELS, Statement.PS_VAL_FEATURES, Statement.PS_VAL_LABELS, Statement.PS_UPDATE_FUN, Statement.PS_AGGREGATION_FUN, Statement.PS_MODE, Statement.PS_UPDATE_TYPE, Statement.PS_FREQUENCY, Statement.PS_EPOCHS, Statement.PS_BATCH_SIZE, Statement.PS_PARALLELISM, Statement.PS_SCHEME, Statement.PS_HYPER_PARAMS, Statement.PS_CHECKPOINTING);
+ Set<String> valid = CollectionUtils.asSet(Statement.PS_MODEL, Statement.PS_FEATURES, Statement.PS_LABELS, Statement.PS_VAL_FEATURES, Statement.PS_VAL_LABELS, Statement.PS_UPDATE_FUN, Statement.PS_AGGREGATION_FUN, Statement.PS_MODE, Statement.PS_UPDATE_TYPE, Statement.PS_FREQUENCY, Statement.PS_EPOCHS, Statement.PS_BATCH_SIZE, Statement.PS_PARALLELISM, Statement.PS_SCHEME, Statement.PS_HYPER_PARAMS, Statement.PS_CHECKPOINTING);
checkInvalidParameters(getOpCode(), getVarParams(), valid);
// check existence and correctness of parameters
@@ -429,7 +429,7 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier
private void validateExtractTriangular(DataIdentifier output, Builtins op, boolean conditional) {
//check for invalid parameters
- Set<String> valid = UtilFunctions.asSet("target", "diag", "values");
+ Set<String> valid = CollectionUtils.asSet("target", "diag", "values");
checkInvalidParameters(op, getVarParams(), valid);
//check existence and correctness of arguments
@@ -524,7 +524,7 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier
private void validateRemoveEmpty(DataIdentifier output, boolean conditional) {
//check for invalid parameters
- Set<String> valid = UtilFunctions.asSet("target", "margin", "select", "empty.return");
+ Set<String> valid = CollectionUtils.asSet("target", "margin", "select", "empty.return");
Set<String> invalid = _varParams.keySet().stream()
.filter(k -> !valid.contains(k)).collect(Collectors.toSet());
if( !invalid.isEmpty() )
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/LocalVariableMap.java b/src/main/java/org/apache/sysds/runtime/controlprogram/LocalVariableMap.java
index 1ac47b7..e5afdd1 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/LocalVariableMap.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/LocalVariableMap.java
@@ -106,14 +106,14 @@ public class LocalVariableMap implements Cloneable
localMap.clear();
}
- public void removeAllIn(Set<String> blacklist) {
+ public void removeAllIn(Set<String> excludeList) {
localMap.entrySet().removeIf(
- e -> blacklist.contains(e.getKey()));
+ e -> excludeList.contains(e.getKey()));
}
- public void removeAllNotIn(Set<String> blacklist) {
+ public void removeAllNotIn(Set<String> excludeList) {
localMap.entrySet().removeIf(
- e -> !blacklist.contains(e.getKey()));
+ e -> !excludeList.contains(e.getKey()));
}
public boolean hasReferences( Data d ) {
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java b/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java
index ffec99d..a92a7ee 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java
@@ -83,6 +83,7 @@ import org.apache.sysds.runtime.lineage.Lineage;
import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.lineage.LineageItemUtils;
import org.apache.sysds.runtime.meta.DataCharacteristics;
+import org.apache.sysds.runtime.util.CollectionUtils;
import org.apache.sysds.runtime.util.ProgramConverter;
import org.apache.sysds.runtime.util.UtilFunctions;
import org.apache.sysds.utils.Statistics;
@@ -1104,10 +1105,10 @@ public class ParForProgramBlock extends ForProgramBlock
}
}
- private void exportMatricesToHDFS(ExecutionContext ec, String... blacklistNames)
+ private void exportMatricesToHDFS(ExecutionContext ec, String... excludeListNames)
{
ParForStatementBlock sb = (ParForStatementBlock)getStatementBlock();
- Set<String> blacklist = UtilFunctions.asSet(blacklistNames);
+ Set<String> excludeList = CollectionUtils.asSet(excludeListNames);
if( LIVEVAR_AWARE_EXPORT && sb != null)
{
@@ -1115,7 +1116,7 @@ public class ParForProgramBlock extends ForProgramBlock
//export only variables that are read in the body
VariableSet varsRead = sb.variablesRead();
for (String key : ec.getVariables().keySet() ) {
- if( varsRead.containsVariable(key) && !blacklist.contains(key) ) {
+ if( varsRead.containsVariable(key) && !excludeList.contains(key) ) {
Data d = ec.getVariable(key);
if( d.getDataType() == DataType.MATRIX )
((MatrixObject)d).exportData(_replicationExport);
@@ -1126,7 +1127,7 @@ public class ParForProgramBlock extends ForProgramBlock
{
//export all matrices in symbol table
for (String key : ec.getVariables().keySet() ) {
- if( !blacklist.contains(key) ) {
+ if( !excludeList.contains(key) ) {
Data d = ec.getVariable(key);
if( d.getDataType() == DataType.MATRIX )
((MatrixObject)d).exportData(_replicationExport);
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/CachedReuseVariables.java b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/CachedReuseVariables.java
index 775aaf8..1cdf594 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/CachedReuseVariables.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/CachedReuseVariables.java
@@ -46,7 +46,7 @@ public class CachedReuseVariables
}
@SuppressWarnings("unused")
- public synchronized void reuseVariables(long pfid, LocalVariableMap vars, Collection<String> blacklist, Map<String, Broadcast<CacheBlock>> _brInputs, boolean cleanCache) {
+ public synchronized void reuseVariables(long pfid, LocalVariableMap vars, Collection<String> excludeList, Map<String, Broadcast<CacheBlock>> _brInputs, boolean cleanCache) {
//fetch the broadcast variables
if (ParForProgramBlock.ALLOW_BROADCAST_INPUTS && !containsVars(pfid)) {
@@ -63,8 +63,8 @@ public class CachedReuseVariables
if( cleanCache )
_data.clear();
tmp = new LocalVariableMap(vars);
- tmp.removeAllIn((blacklist instanceof HashSet) ?
- (HashSet<String>)blacklist : new HashSet<>(blacklist));
+ tmp.removeAllIn((excludeList instanceof HashSet) ?
+ (HashSet<String>)excludeList : new HashSet<>(excludeList));
_data.put(pfid, new SoftReference<>(tmp));
}
//reuse existing reuse map
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/RemoteParForSparkWorker.java b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/RemoteParForSparkWorker.java
index ee3cbb4..b002f04 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/RemoteParForSparkWorker.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/RemoteParForSparkWorker.java
@@ -39,8 +39,8 @@ import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.apache.sysds.runtime.lineage.Lineage;
+import org.apache.sysds.runtime.util.CollectionUtils;
import org.apache.sysds.runtime.util.ProgramConverter;
-import org.apache.sysds.runtime.util.UtilFunctions;
import scala.Tuple2;
@@ -134,9 +134,9 @@ public class RemoteParForSparkWorker extends ParWorker implements PairFlatMapFun
//reuse shared inputs (to read shared inputs once per process instead of once per core;
//we reuse everything except result variables and partitioned input matrices)
- Collection<String> blacklist = UtilFunctions.asSet(_resultVars.stream()
+ Collection<String> excludeList = CollectionUtils.asSet(_resultVars.stream()
.map(v -> v._name).collect(Collectors.toList()), _ec.getVarListPartitioned());
- reuseVars.reuseVariables(_jobid, _ec.getVariables(), blacklist, _brInputs, _cleanCache);
+ reuseVars.reuseVariables(_jobid, _ec.getVariables(), excludeList, _brInputs, _cleanCache);
//setup the buffer pool
RemoteParForUtils.setupBufferPool(_workerID);
diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderFactory.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderFactory.java
index 5ed7318..977d494 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderFactory.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderFactory.java
@@ -23,7 +23,6 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
-import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang.ArrayUtils;
import org.apache.wink.json4j.JSONObject;
import org.apache.sysds.common.Types.ValueType;
@@ -32,7 +31,8 @@ import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.transform.TfUtils.TfMethod;
import org.apache.sysds.runtime.transform.meta.TfMetaUtils;
import org.apache.sysds.runtime.util.UtilFunctions;
-
+import static org.apache.sysds.runtime.util.CollectionUtils.except;
+import static org.apache.sysds.runtime.util.CollectionUtils.unionDistinct;
public class DecoderFactory
{
@@ -40,7 +40,6 @@ public class DecoderFactory
return createDecoder(spec, colnames, schema, meta, meta.getNumColumns());
}
- @SuppressWarnings("unchecked")
public static Decoder createDecoder(String spec, String[] colnames, ValueType[] schema, FrameBlock meta, int clen)
{
Decoder decoder = null;
@@ -56,10 +55,9 @@ public class DecoderFactory
TfMetaUtils.parseJsonIDList(jSpec, colnames, TfMethod.RECODE.toString())));
List<Integer> dcIDs = Arrays.asList(ArrayUtils.toObject(
TfMetaUtils.parseJsonIDList(jSpec, colnames, TfMethod.DUMMYCODE.toString())));
- rcIDs = new ArrayList<Integer>(CollectionUtils.union(rcIDs, dcIDs));
+ rcIDs = unionDistinct(rcIDs, dcIDs);
int len = dcIDs.isEmpty() ? Math.min(meta.getNumColumns(), clen) : meta.getNumColumns();
- List<Integer> ptIDs = new ArrayList<Integer>(CollectionUtils
- .subtract(UtilFunctions.getSeqList(1, len, 1), rcIDs));
+ List<Integer> ptIDs = except(UtilFunctions.getSeqList(1, len, 1), rcIDs);
//create default schema if unspecified (with double columns for pass-through)
if( schema == null ) {
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java
index c90a170..b7443f4 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java
@@ -24,7 +24,6 @@ import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
-import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang.ArrayUtils;
import org.apache.wink.json4j.JSONObject;
import org.apache.sysds.common.Types.ValueType;
@@ -33,6 +32,9 @@ import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.transform.TfUtils.TfMethod;
import org.apache.sysds.runtime.transform.meta.TfMetaUtils;
import org.apache.sysds.runtime.util.UtilFunctions;
+import static org.apache.sysds.runtime.util.CollectionUtils.except;
+import static org.apache.sysds.runtime.util.CollectionUtils.unionDistinct;
+
public class EncoderFactory
{
@@ -45,7 +47,6 @@ public class EncoderFactory
return createEncoder(spec, colnames, lschema, meta);
}
- @SuppressWarnings("unchecked")
public static Encoder createEncoder(String spec, String[] colnames, ValueType[] schema, FrameBlock meta) {
Encoder encoder = null;
int clen = schema.length;
@@ -64,11 +65,9 @@ public class EncoderFactory
TfMetaUtils.parseJsonIDList(jSpec, colnames, TfMethod.DUMMYCODE.toString())));
List<Integer> binIDs = TfMetaUtils.parseBinningColIDs(jSpec, colnames);
//note: any dummycode column requires recode as preparation, unless it follows binning
- rcIDs = new ArrayList<Integer>(CollectionUtils.subtract(
- CollectionUtils.union(rcIDs, CollectionUtils.subtract(dcIDs, binIDs)), haIDs));
- List<Integer> ptIDs = new ArrayList<Integer>(CollectionUtils.subtract(
- CollectionUtils.subtract(UtilFunctions.getSeqList(1, clen, 1),
- CollectionUtils.union(rcIDs,haIDs)), binIDs));
+ rcIDs = except(unionDistinct(rcIDs, except(dcIDs, binIDs)), haIDs);
+ List<Integer> ptIDs = except(except(UtilFunctions.getSeqList(1, clen, 1),
+ unionDistinct(rcIDs,haIDs)), binIDs);
List<Integer> oIDs = Arrays.asList(ArrayUtils.toObject(
TfMetaUtils.parseJsonIDList(jSpec, colnames, TfMethod.OMIT.toString())));
List<Integer> mvIDs = Arrays.asList(ArrayUtils.toObject(
diff --git a/src/main/java/org/apache/sysds/runtime/transform/meta/TfMetaUtils.java b/src/main/java/org/apache/sysds/runtime/transform/meta/TfMetaUtils.java
index 0e9db76..3f10dd0 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/meta/TfMetaUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/meta/TfMetaUtils.java
@@ -25,13 +25,11 @@ import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
-import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map.Entry;
-import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang.ArrayUtils;
import org.apache.wink.json4j.JSONArray;
import org.apache.wink.json4j.JSONException;
@@ -46,6 +44,7 @@ import org.apache.sysds.runtime.matrix.data.Pair;
import org.apache.sysds.runtime.transform.TfUtils;
import org.apache.sysds.runtime.transform.TfUtils.TfMethod;
import org.apache.sysds.runtime.transform.decode.DecoderRecode;
+import org.apache.sysds.runtime.util.CollectionUtils;
import org.apache.sysds.runtime.util.HDFSTool;
import org.apache.sysds.runtime.util.UtilFunctions;
@@ -336,7 +335,6 @@ public class TfMetaUtils
* @return list of column ids
* @throws IOException if IOException occurs
*/
- @SuppressWarnings("unchecked")
private static List<Integer> parseRecodeColIDs(String spec, String[] colnames)
throws IOException
{
@@ -352,7 +350,7 @@ public class TfMetaUtils
TfMetaUtils.parseJsonIDList(jSpec, colnames, TfMethod.RECODE.toString())));
List<Integer> dcIDs = Arrays.asList(ArrayUtils.toObject(
TfMetaUtils.parseJsonIDList(jSpec, colnames, TfMethod.DUMMYCODE.toString())));
- specRecodeIDs = new ArrayList<Integer>(CollectionUtils.union(rcIDs, dcIDs));
+ specRecodeIDs = CollectionUtils.unionDistinct(rcIDs, dcIDs);
}
catch(Exception ex) {
throw new IOException(ex);
diff --git a/src/main/java/org/apache/sysds/runtime/util/CollectionUtils.java b/src/main/java/org/apache/sysds/runtime/util/CollectionUtils.java
new file mode 100644
index 0000000..a26f3b8
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/util/CollectionUtils.java
@@ -0,0 +1,259 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.util;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.List;
+import java.util.ListIterator;
+import java.util.Objects;
+import java.util.Set;
+import java.util.stream.Stream;
+import java.util.stream.StreamSupport;
+
+/**
+ * Static helpers for constructing lists/sets and for order-preserving
+ * set operations on lists (replacement for commons-collections).
+ */
+public class CollectionUtils {
+
+	private CollectionUtils() {
+		//static utility class, no instances
+	}
+
+	/**
+	 * Concatenates the given lists, in argument order, into a new list.
+	 *
+	 * @param <T> element type
+	 * @param inputs lists to concatenate
+	 * @return new mutable list with all elements of all inputs
+	 */
+	@SafeVarargs
+	public static <T> List<T> asList(List<T>... inputs) {
+		List<T> ret = new ArrayList<>();
+		for( List<T> list : inputs )
+			ret.addAll(list);
+		return ret;
+	}
+
+	/**
+	 * Creates a new mutable array list holding the given elements.
+	 *
+	 * @param <T> element type
+	 * @param inputs elements, in insertion order
+	 * @return new array list
+	 */
+	@SafeVarargs
+	public static <T> ArrayList<T> asArrayList(T... inputs) {
+		ArrayList<T> ret = new ArrayList<>(inputs.length);
+		for( T item : inputs )
+			ret.add(item);
+		return ret;
+	}
+
+	/**
+	 * Creates a new hash set of the given elements (duplicates collapse).
+	 *
+	 * @param <T> element type
+	 * @param inputs elements
+	 * @return new hash set
+	 */
+	@SafeVarargs
+	public static <T> Set<T> asSet(T... inputs) {
+		Set<T> ret = new HashSet<>();
+		for( T element : inputs )
+			ret.add(element);
+		return ret;
+	}
+
+	/**
+	 * Creates a new hash set of all elements of the given arrays.
+	 *
+	 * @param <T> element type
+	 * @param inputs arrays of elements
+	 * @return new hash set
+	 */
+	@SafeVarargs
+	public static <T> Set<T> asSet(T[]... inputs) {
+		Set<T> ret = new HashSet<>();
+		for( T[] input : inputs )
+			for( T element : input )
+				ret.add(element);
+		return ret;
+	}
+
+	/**
+	 * Creates a new hash set of all elements of the given lists.
+	 *
+	 * @param <T> element type
+	 * @param inputs lists of elements
+	 * @return new hash set
+	 */
+	@SafeVarargs
+	public static <T> Set<T> asSet(List<T>... inputs) {
+		Set<T> ret = new HashSet<>();
+		for( List<T> list : inputs )
+			ret.addAll(list);
+		return ret;
+	}
+
+	/**
+	 * Wraps an iterator into a sequential stream; elements are pulled
+	 * lazily from the iterator.
+	 *
+	 * @param <T> element type
+	 * @param iter source iterator
+	 * @return sequential stream over the remaining elements of iter
+	 */
+	public static <T> Stream<T> getStream(Iterator<T> iter) {
+		Iterable<T> iterable = () -> iter;
+		return StreamSupport.stream(iterable.spliterator(), false);
+	}
+
+	/**
+	 * Order-sensitive list comparison: true if both lists are null, or both
+	 * have the same size and pairwise equal elements (nulls allowed).
+	 *
+	 * @param <T> element type
+	 * @param a first list, may be null
+	 * @param b second list, may be null
+	 * @return true if the lists are equal
+	 */
+	public static <T> boolean equals(List<T> a, List<T> b) {
+		//basic checks for early abort
+		if( a == b ) return true; //incl both null
+		if( a == null || b == null || a.size() != b.size() )
+			return false;
+		ListIterator<T> iter1 = a.listIterator();
+		ListIterator<T> iter2 = b.listIterator();
+		while( iter1.hasNext() ) //equal length
+			if( !Objects.equals(iter1.next(), iter2.next()) )
+				return false;
+		return true;
+	}
+
+	/**
+	 * Indicates if the two collections share at least one element.
+	 *
+	 * @param <T> element type
+	 * @param a first collection, not null
+	 * @param b second collection, not null
+	 * @return true if the intersection of a and b is non-empty
+	 */
+	public static <T> boolean containsAny(Collection<T> a, Collection<T> b) {
+		//empty inputs cannot intersect (avoids building a probe table)
+		if( a.isEmpty() || b.isEmpty() )
+			return false;
+		//build probe table over the smaller input for constant-time
+		//lookups (reuse existing hash sets instead of copying them)
+		Collection<T> tmp1 = a.size() < b.size() ? a : b;
+		Set<T> probe = (tmp1 instanceof HashSet) ?
+			(Set<T>) tmp1 : new HashSet<>(tmp1);
+		//probe if there is a non-empty intersection
+		Collection<T> tmp2 = (a.size() < b.size() ? b : a);
+		for( T item : tmp2 )
+			if( probe.contains(item) )
+				return true;
+		return false;
+	}
+
+	/**
+	 * Order-preserving distinct union: the elements of a followed by those
+	 * of b, keeping only the first occurrence of each element.
+	 *
+	 * @param <T> element type
+	 * @param a first list
+	 * @param b second list
+	 * @return new list without duplicates
+	 */
+	public static <T> List<T> unionDistinct(List<T> a, List<T> b) {
+		List<T> ret = new ArrayList<>(); // in-order results
+		Set<T> probe = new HashSet<>(); // constant-time probe table
+		addDistinct(a, probe, ret);
+		addDistinct(b, probe, ret);
+		return ret;
+	}
+
+	//appends all items of src to ret that were not yet seen in probe
+	private static <T> void addDistinct(List<T> src, Set<T> probe, List<T> ret) {
+		for( T item : src )
+			if( probe.add(item) ) //true only on first occurrence
+				ret.add(item);
+	}
+
+	/**
+	 * Union with duplicates (bag semantics), i.e., concatenation of a and b.
+	 *
+	 * @param <T> element type
+	 * @param a first list
+	 * @param b second list
+	 * @return new list with all elements of a followed by all of b
+	 */
+	public static <T> List<T> unionAll(List<T> a, List<T> b) {
+		return CollectionUtils.asList(a, b);
+	}
+
+	/**
+	 * Order-preserving difference: the elements of a that are not contained
+	 * in exceptions; duplicates within a are retained.
+	 *
+	 * @param <T> element type
+	 * @param a input list
+	 * @param exceptions elements to remove
+	 * @return new list of the remaining elements of a
+	 */
+	public static <T> List<T> except(List<T> a, List<T> exceptions) {
+		List<T> ret = new ArrayList<>();
+		Set<T> probe = new HashSet<>(exceptions);
+		for( T item : a )
+			if( !probe.contains(item) )
+				ret.add(item);
+		return ret;
+	}
+
+	/**
+	 * Adds all elements of the array b to the collection a.
+	 *
+	 * @param <T> element type
+	 * @param a target collection
+	 * @param b source array
+	 */
+	public static <T> void addAll(Collection<T> a, T[] b) {
+		for( T item : b )
+			a.add(item);
+	}
+
+	/**
+	 * Counts the occurrences of a in b (null-safe).
+	 *
+	 * @param <T> element type
+	 * @param a element to count, may be null
+	 * @param b list to search
+	 * @return number of elements in b equal to a
+	 */
+	public static <T> int cardinality(T a, List<T> b) {
+		int count = 0;
+		for(T item : b)
+			count += Objects.equals(a, item) ? 1 : 0;
+		return count;
+	}
+}
diff --git a/src/main/java/org/apache/sysds/runtime/util/ProgramConverter.java b/src/main/java/org/apache/sysds/runtime/util/ProgramConverter.java
index 38fc256..144f347 100644
--- a/src/main/java/org/apache/sysds/runtime/util/ProgramConverter.java
+++ b/src/main/java/org/apache/sysds/runtime/util/ProgramConverter.java
@@ -521,7 +521,7 @@ public class ProgramConverter
IfStatement origstmt = (IfStatement) orig.getStatement(0);
IfStatement istmt = new IfStatement(); //only shallow
istmt.setConditionalPredicate(origstmt.getConditionalPredicate());
- isb.setStatements(UtilFunctions.asArrayList(istmt));
+ isb.setStatements(CollectionUtils.asArrayList(istmt));
for( StatementBlock c : origstmt.getIfBody() )
istmt.addStatementBlockIfBody(rCreateDeepCopyStatementBlock(c));
for( StatementBlock c : origstmt.getElseBody() )
@@ -534,7 +534,7 @@ public class ProgramConverter
WhileStatement origstmt = (WhileStatement) orig.getStatement(0);
WhileStatement wstmt = new WhileStatement(); //only shallow
wstmt.setPredicate(origstmt.getConditionalPredicate());
- wsb.setStatements(UtilFunctions.asArrayList(wstmt));
+ wsb.setStatements(CollectionUtils.asArrayList(wstmt));
for( StatementBlock c : origstmt.getBody() )
wstmt.addStatementBlock(rCreateDeepCopyStatementBlock(c));
ret = wsb;
@@ -545,7 +545,7 @@ public class ProgramConverter
ForStatement origstmt = (ForStatement) orig.getStatement(0);
ForStatement fstmt = new ForStatement(); //only shallow
fstmt.setPredicate(origstmt.getIterablePredicate());
- fsb.setStatements(UtilFunctions.asArrayList(fstmt));
+ fsb.setStatements(CollectionUtils.asArrayList(fstmt));
for( StatementBlock c : origstmt.getBody() )
fstmt.addStatementBlock(rCreateDeepCopyStatementBlock(c));
ret = fsb;
diff --git a/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java b/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java
index 885d584..88b7041 100644
--- a/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java
+++ b/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java
@@ -39,14 +39,9 @@ import org.apache.sysds.runtime.meta.TensorCharacteristics;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet;
-import java.util.HashSet;
-import java.util.Iterator;
import java.util.List;
import java.util.Map;
-import java.util.Set;
import java.util.concurrent.Future;
-import java.util.stream.Stream;
-import java.util.stream.StreamSupport;
public class UtilFunctions
{
@@ -760,52 +755,6 @@ public class UtilFunctions
return false;
}
- @SafeVarargs
- public static <T> List<T> asList(List<T>... inputs) {
- List<T> ret = new ArrayList<>();
- for( List<T> list : inputs )
- ret.addAll(list);
- return ret;
- }
-
- @SafeVarargs
- public static <T> ArrayList<T> asArrayList(T... inputs) {
- ArrayList<T> ret = new ArrayList<>();
- for( T list : inputs )
- ret.add(list);
- return ret;
- }
-
- @SafeVarargs
- public static <T> Set<T> asSet(List<T>... inputs) {
- Set<T> ret = new HashSet<>();
- for( List<T> list : inputs )
- ret.addAll(list);
- return ret;
- }
-
- @SafeVarargs
- public static <T> Set<T> asSet(T[]... inputs) {
- Set<T> ret = new HashSet<>();
- for( T[] input : inputs )
- for( T element : input )
- ret.add(element);
- return ret;
- }
-
- @SafeVarargs
- public static <T> Set<T> asSet(T... inputs) {
- Set<T> ret = new HashSet<>();
- for( T element : inputs )
- ret.add(element);
- return ret;
- }
-
- public static <T> Stream<T> getStream(Iterator<T> iter) {
- Iterable<T> iterable = () -> iter;
- return StreamSupport.stream(iterable.spliterator(), false);
- }
-
public static long prod(long[] arr) {
long ret = 1;
for(int i=0; i<arr.length; i++)
diff --git a/src/test/java/org/apache/sysds/test/functions/frame/FrameConverterTest.java b/src/test/java/org/apache/sysds/test/functions/frame/FrameConverterTest.java
index 95ae502..8e098a6 100644
--- a/src/test/java/org/apache/sysds/test/functions/frame/FrameConverterTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/frame/FrameConverterTest.java
@@ -61,6 +61,7 @@ import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.runtime.util.CollectionUtils;
import org.apache.sysds.runtime.util.UtilFunctions;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
@@ -83,11 +84,11 @@ public class FrameConverterTest extends AutomatedTestBase
private final static List<ValueType> schemaMixedLargeListInt = Collections.nCopies(200, ValueType.INT64);
private final static List<ValueType> schemaMixedLargeListBool = Collections.nCopies(200, ValueType.BOOLEAN);
- private static final List<ValueType> schemaMixedLargeList = UtilFunctions.asList(
+ private static final List<ValueType> schemaMixedLargeList = CollectionUtils.asList(
schemaMixedLargeListStr, schemaMixedLargeListDble, schemaMixedLargeListInt, schemaMixedLargeListBool);
private static final ValueType[] schemaMixedLarge = schemaMixedLargeList.toArray(new ValueType[0]);
- private static final List<ValueType> schemaMixedLargeListDFrame = UtilFunctions.asList(
+ private static final List<ValueType> schemaMixedLargeListDFrame = CollectionUtils.asList(
schemaMixedLargeListStr.subList(0, 100), schemaMixedLargeListDble.subList(0, 100),
schemaMixedLargeListInt.subList(0, 100), schemaMixedLargeListBool.subList(0, 100));
private static final ValueType[] schemaMixedLargeDFrame = schemaMixedLargeListDFrame.toArray(new ValueType[0]);
diff --git a/src/test/java/org/apache/sysds/test/functions/transform/TransformCSVFrameEncodeReadTest.java b/src/test/java/org/apache/sysds/test/functions/transform/TransformCSVFrameEncodeReadTest.java
index b09bc65..1dba969 100644
--- a/src/test/java/org/apache/sysds/test/functions/transform/TransformCSVFrameEncodeReadTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/transform/TransformCSVFrameEncodeReadTest.java
@@ -128,7 +128,7 @@ public class TransformCSVFrameEncodeReadTest extends AutomatedTestBase
String HOME = SCRIPT_DIR + TEST_DIR;
int nrows = subset ? 4 : 13;
fullDMLScriptName = HOME + TEST_NAME1 + ".dml";
- programArgs = new String[]{"-explain", "-stats","-args",
+ programArgs = new String[]{"-stats","-args",
HOME + "input/" + DATASET, String.valueOf(nrows), output("R") };
runTest(true, false, null, -1);
diff --git a/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeApplyTest.java b/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeApplyTest.java
index 74156d4..92b0b2d 100644
--- a/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeApplyTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeApplyTest.java
@@ -388,7 +388,7 @@ public class TransformFrameEncodeApplyTest extends AutomatedTestBase
String HOME = SCRIPT_DIR + TEST_DIR;
fullDMLScriptName = HOME + TEST_NAME1 + ".dml";
- programArgs = new String[]{"-explain", "recompile_hops", "-nvargs",
+ programArgs = new String[]{"-nvargs",
"DATA=" + HOME + "input/" + DATASET,
"TFSPEC=" + HOME + "input/" + SPEC,
"TFDATA1=" + output("tfout1"),