You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by mb...@apache.org on 2021/06/03 11:28:47 UTC

[systemds] branch master updated: [SYSTEMDS-2984] Fix matrix/frame cleanup w/ nested lists

This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 7778fe1  [SYSTEMDS-2984] Fix matrix/frame cleanup w/ nested lists
7778fe1 is described below

commit 7778fe15d806369ab47a458dda458afad25a0f7e
Author: Matthias Boehm <mb...@gmail.com>
AuthorDate: Thu Jun 3 13:24:38 2021 +0200

    [SYSTEMDS-2984] Fix matrix/frame cleanup w/ nested lists
    
    This patch fixes the cleanup logic of rmvar instructions for matrices
    and frames, which are removed from the symbol table, buffer pool, and
    file representations when they go out of scope. We already handled
    lists, which can hide the liveness of the objects they contain, but
    only one level deep. This patch extends that handling to nested lists
    of arbitrary depth.
---
 .../runtime/controlprogram/LocalVariableMap.java   |  2 +-
 .../controlprogram/context/ExecutionContext.java   |  3 +--
 .../sysds/runtime/instructions/cp/ListObject.java  |  5 ++++
 .../{PrintMatrixTest.java => PrintListTest.java}   | 24 ++++++++++-------
 .../sysds/test/functions/misc/PrintMatrixTest.java |  5 +---
 src/test/scripts/functions/misc/PrintListTest1.dml | 27 +++++++++++++++++++
 src/test/scripts/functions/misc/PrintListTest2.dml | 31 ++++++++++++++++++++++
 7 files changed, 80 insertions(+), 17 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/LocalVariableMap.java b/src/main/java/org/apache/sysds/runtime/controlprogram/LocalVariableMap.java
index 7ff83f5..bac6759 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/LocalVariableMap.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/LocalVariableMap.java
@@ -118,7 +118,7 @@ public class LocalVariableMap implements Cloneable
 	public boolean hasReferences( Data d ) {
 		//perf: avoid java streams here for reduced overhead in rmvar
 		for( Data o : localMap.values() )
-			if( o instanceof ListObject ? ((ListObject)o).getData().contains(d) : o == d )
+			if( o instanceof ListObject ? ((ListObject)o).contains(d) : o == d )
 				return true;
 		return false;
 	}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java b/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
index 6b1054b..0bee2ef 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
@@ -765,8 +765,7 @@ public class ExecutionContext {
 			cleanupCacheableData( (CacheableData<?>)dat );
 		else if( dat instanceof ListObject )
 			for( Data dat2 : ((ListObject)dat).getData() )
-				if( dat2 instanceof CacheableData<?> )
-					cleanupCacheableData( (CacheableData<?>)dat2 );
+				cleanupDataObject(dat2);
 	}
 	
 	public void cleanupCacheableData(CacheableData<?> mo) {
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/ListObject.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/ListObject.java
index a8397cb..392e9cf 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ListObject.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ListObject.java
@@ -136,6 +136,11 @@ public class ListObject extends Data implements Externalizable {
 		return _lineage;
 	}
 
+	public boolean contains(Data d) {
+		return _data.stream().anyMatch(lo -> lo instanceof ListObject ?
+			(lo == d || ((ListObject)lo).contains(d)) : lo == d);
+	}
+	
 	public long getDataSize() {
 		return _data.stream().filter(data -> data instanceof CacheableData)
 			.mapToLong(data -> ((CacheableData<?>) data).getDataSize()).sum();
diff --git a/src/test/java/org/apache/sysds/test/functions/misc/PrintMatrixTest.java b/src/test/java/org/apache/sysds/test/functions/misc/PrintListTest.java
similarity index 72%
copy from src/test/java/org/apache/sysds/test/functions/misc/PrintMatrixTest.java
copy to src/test/java/org/apache/sysds/test/functions/misc/PrintListTest.java
index 4ef018b..b076845 100644
--- a/src/test/java/org/apache/sysds/test/functions/misc/PrintMatrixTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/misc/PrintListTest.java
@@ -24,32 +24,36 @@ import org.apache.sysds.api.DMLException;
 import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
 
-/**
- *   
- */
-public class PrintMatrixTest extends AutomatedTestBase
-{	
+public class PrintListTest extends AutomatedTestBase
+{
 	private final static String TEST_DIR = "functions/misc/";
-	private final static String TEST_NAME1 = "PrintMatrixTest";
-	private final static String TEST_CLASS_DIR = TEST_DIR + PrintMatrixTest.class.getSimpleName() + "/";
+	private final static String TEST_NAME1 = "PrintListTest1";
+	private final static String TEST_NAME2 = "PrintListTest2";
+	private final static String TEST_CLASS_DIR = TEST_DIR + PrintListTest.class.getSimpleName() + "/";
 	
 	@Override
 	public void setUp() {
 		addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {}));
+		addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {}));
 	}
 	
 	@Test
-	public void testPrintMatrix() {
+	public void testPrintNestedList() {
 		runTest( TEST_NAME1, false );
 	}
 	
+	@Test
+	public void testPrintMixedLists() {
+		runTest( TEST_NAME2, false );
+	}
+	
 	private void runTest( String testName, boolean exceptionExpected ) {
-		TestConfiguration config = getTestConfiguration(TEST_NAME1);
+		TestConfiguration config = getTestConfiguration(testName);
 		loadTestConfiguration(config);
 		
 		String HOME = SCRIPT_DIR + TEST_DIR;
-		fullDMLScriptName = HOME + TEST_NAME1 + ".dml";
-		programArgs = new String[]{""};
+		fullDMLScriptName = HOME + testName + ".dml";
+		programArgs = new String[]{"-explain"};
 		
 		//run tests
 		runTest(true, exceptionExpected, DMLException.class, -1);
diff --git a/src/test/java/org/apache/sysds/test/functions/misc/PrintMatrixTest.java b/src/test/java/org/apache/sysds/test/functions/misc/PrintMatrixTest.java
index 4ef018b..171ada2 100644
--- a/src/test/java/org/apache/sysds/test/functions/misc/PrintMatrixTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/misc/PrintMatrixTest.java
@@ -24,11 +24,8 @@ import org.apache.sysds.api.DMLException;
 import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
 
-/**
- *   
- */
 public class PrintMatrixTest extends AutomatedTestBase
-{	
+{
 	private final static String TEST_DIR = "functions/misc/";
 	private final static String TEST_NAME1 = "PrintMatrixTest";
 	private final static String TEST_CLASS_DIR = TEST_DIR + PrintMatrixTest.class.getSimpleName() + "/";
diff --git a/src/test/scripts/functions/misc/PrintListTest1.dml b/src/test/scripts/functions/misc/PrintListTest1.dml
new file mode 100644
index 0000000..6d63cc6
--- /dev/null
+++ b/src/test/scripts/functions/misc/PrintListTest1.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+a_1 = matrix(3, 1, 1)
+a_2 = matrix(2, 1, 2)
+a_3 = matrix(45, 1, 3)
+b = list(a_1, a_2)
+c = list(a_3, b)
+print(toString(c))
diff --git a/src/test/scripts/functions/misc/PrintListTest2.dml b/src/test/scripts/functions/misc/PrintListTest2.dml
new file mode 100644
index 0000000..48c481a
--- /dev/null
+++ b/src/test/scripts/functions/misc/PrintListTest2.dml
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+a_1 = matrix(3, 1, 1)
+a_2 = matrix(2, 1, 2)
+a_3 = matrix(45, 1, 3)
+b = list(a_1, a_2)
+c = list(a_3, b)
+print(toString(c))
+print(toString(b))
+print(toString(a_1))
+print(toString(a_2))
+print(toString(a_3))