Posted to commits@buildstream.apache.org by gi...@apache.org on 2020/12/29 13:09:21 UTC

[buildstream] 05/06: artifactcache: Move pull logic into CASRemote

This is an automated email from the ASF dual-hosted git repository.

github-bot pushed a commit to branch raoul/802-refactor-artifactcache
in repository https://gitbox.apache.org/repos/asf/buildstream.git

commit 15db919f6fafd0c2b0abbddde78a084d5b5379a5
Author: Raoul Hidalgo Charman <ra...@codethink.co.uk>
AuthorDate: Thu Dec 13 12:04:26 2018 +0000

    artifactcache: Move pull logic into CASRemote
    
    Separates the pull logic into a remote/local API, so that the artifact cache
    iterates over blob digests, checks whether it already has them, and requests
    them if not. The request command allows batching of blobs where appropriate.
    
    Tests have been updated to ensure the correct tmpdir is set up in the process
    wrappers; otherwise invalid cross-link errors occur in CI. Additional asserts
    have been added to check that the temporary directories are cleared by the end
    of a pull.
    
    Part of #802
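
    For reference, the core of the new flow is the cas_directory_download helper
    added in buildstream/_cas/transfer.py below, reproduced here with explanatory
    comments (the CASCache and CASRemote instances are assumed to be set up by the
    caller, as in ArtifactCache.pull()):

        def cas_directory_download(caslocal, casremote, root_digest, excluded_subdirs):
            # Walk every blob digest reachable from the root directory digest
            for blob_digest in casremote.yield_directory_digests(
                    root_digest, excluded_subdirs=excluded_subdirs):
                # Skip blobs that are already present in the local CAS
                if caslocal.check_blob(blob_digest):
                    continue
                # Queue the blob for download; small blobs are batched, large ones
                # are fetched in an independent bytestream request
                casremote.request_blob(blob_digest)
                # Add any blobs that have already landed in the remote's tmpdir
                for blob_file in casremote.get_blobs():
                    caslocal.add_object(path=blob_file.name, link_directly=True)

            # Flush the final outstanding batch request
            for blob_file in casremote.get_blobs(complete_batch=True):
                caslocal.add_object(path=blob_file.name, link_directly=True)

    The tree variant, cas_tree_download, follows the same pattern but first fetches
    the Tree message and returns the digest of its root directory.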
---
 buildstream/_artifactcache.py         |  54 +++++---
 buildstream/_cas/__init__.py          |   2 +-
 buildstream/_cas/cascache.py          | 224 ++--------------------------------
 buildstream/_cas/casremote.py         | 147 ++++++++++++++++++++++
 buildstream/_cas/transfer.py          |  51 ++++++++
 buildstream/sandbox/_sandboxremote.py |   4 +-
 conftest.py                           |   7 ++
 tests/artifactcache/pull.py           |  26 +++-
 tests/artifactcache/push.py           |  28 +++--
 tests/integration/pullbuildtrees.py   |   6 +
 10 files changed, 299 insertions(+), 250 deletions(-)

diff --git a/buildstream/_artifactcache.py b/buildstream/_artifactcache.py
index dd5b4b5..21db707 100644
--- a/buildstream/_artifactcache.py
+++ b/buildstream/_artifactcache.py
@@ -28,7 +28,8 @@ from ._message import Message, MessageType
 from . import utils
 from . import _yaml
 
-from ._cas import CASRemote, CASRemoteSpec
+from ._cas import BlobNotFound, CASRemote, CASRemoteSpec
+from ._cas.transfer import cas_directory_download, cas_tree_download
 
 
 CACHE_SIZE_FILE = "cache_size"
@@ -644,19 +645,31 @@ class ArtifactCache():
                 display_key = element._get_brief_display_key()
                 element.status("Pulling artifact {} <- {}".format(display_key, remote.spec.url))
 
-                if self.cas.pull(ref, remote, progress=progress, subdir=subdir, excluded_subdirs=excluded_subdirs):
-                    element.info("Pulled artifact {} <- {}".format(display_key, remote.spec.url))
-                    if subdir:
-                        # Attempt to extract subdir into artifact extract dir if it already exists
-                        # without containing the subdir. If the respective artifact extract dir does not
-                        # exist a complete extraction will complete.
-                        self.extract(element, key, subdir)
-                    # no need to pull from additional remotes
-                    return True
-                else:
+                root_digest = remote.get_reference(ref)
+
+                if not root_digest:
                     element.info("Remote ({}) does not have {} cached".format(
-                        remote.spec.url, element._get_brief_display_key()
-                    ))
+                        remote.spec.url, element._get_brief_display_key()))
+                    continue
+
+                try:
+                    cas_directory_download(self.cas, remote, root_digest, excluded_subdirs)
+                except BlobNotFound:
+                    element.info("Remote ({}) is missing blobs for {}".format(
+                        remote.spec.url, element._get_brief_display_key()))
+                    continue
+
+                self.cas.set_ref(ref, root_digest)
+
+                if subdir:
+                    # Attempt to extract subdir into artifact extract dir if it already exists
+                    # without containing the subdir. If the respective artifact extract dir does not
+                    # exist a complete extraction will complete.
+                    self.extract(element, key, subdir)
+
+                element.info("Pulled artifact {} <- {}".format(display_key, remote.spec.url))
+                # no need to pull from additional remotes
+                return True
 
             except CASError as e:
                 raise ArtifactError("Failed to pull artifact {}: {}".format(
@@ -671,15 +684,16 @@ class ArtifactCache():
     #
     # Args:
     #     project (Project): The current project
-    #     digest (Digest): The digest of the tree
+    #     tree_digest (Digest): The digest of the tree
     #
-    def pull_tree(self, project, digest):
+    def pull_tree(self, project, tree_digest):
         for remote in self._remotes[project]:
-            digest = self.cas.pull_tree(remote, digest)
-
-            if digest:
-                # no need to pull from additional remotes
-                return digest
+            try:
+                root_digest = cas_tree_download(self.cas, remote, tree_digest)
+            except BlobNotFound:
+                continue
+            else:
+                return root_digest
 
         return None
 
diff --git a/buildstream/_cas/__init__.py b/buildstream/_cas/__init__.py
index a88e413..20c0279 100644
--- a/buildstream/_cas/__init__.py
+++ b/buildstream/_cas/__init__.py
@@ -18,4 +18,4 @@
 #        Tristan Van Berkom <tr...@codethink.co.uk>
 
 from .cascache import CASCache
-from .casremote import CASRemote, CASRemoteSpec
+from .casremote import CASRemote, CASRemoteSpec, BlobNotFound
diff --git a/buildstream/_cas/cascache.py b/buildstream/_cas/cascache.py
index adbd34c..e3b0332 100644
--- a/buildstream/_cas/cascache.py
+++ b/buildstream/_cas/cascache.py
@@ -33,7 +33,7 @@ from .._protos.buildstream.v2 import buildstream_pb2
 from .. import utils
 from .._exceptions import CASCacheError
 
-from .casremote import BlobNotFound, _CASBatchRead, _CASBatchUpdate
+from .casremote import _CASBatchUpdate
 
 
 # A CASCache manages a CAS repository as specified in the Remote Execution API.
@@ -183,73 +183,6 @@ class CASCache():
 
         return modified, removed, added
 
-    # pull():
-    #
-    # Pull a ref from a remote repository.
-    #
-    # Args:
-    #     ref (str): The ref to pull
-    #     remote (CASRemote): The remote repository to pull from
-    #     progress (callable): The progress callback, if any
-    #     subdir (str): The optional specific subdir to pull
-    #     excluded_subdirs (list): The optional list of subdirs to not pull
-    #
-    # Returns:
-    #   (bool): True if pull was successful, False if ref was not available
-    #
-    def pull(self, ref, remote, *, progress=None, subdir=None, excluded_subdirs=None):
-        try:
-            remote.init()
-
-            request = buildstream_pb2.GetReferenceRequest(instance_name=remote.spec.instance_name)
-            request.key = ref
-            response = remote.ref_storage.GetReference(request)
-
-            tree = remote_execution_pb2.Digest()
-            tree.hash = response.digest.hash
-            tree.size_bytes = response.digest.size_bytes
-
-            # Check if the element artifact is present, if so just fetch the subdir.
-            if subdir and os.path.exists(self.objpath(tree)):
-                self._fetch_subdir(remote, tree, subdir)
-            else:
-                # Fetch artifact, excluded_subdirs determined in pullqueue
-                self._fetch_directory(remote, tree, excluded_subdirs=excluded_subdirs)
-
-            self.set_ref(ref, tree)
-
-            return True
-        except grpc.RpcError as e:
-            if e.code() != grpc.StatusCode.NOT_FOUND:
-                raise CASCacheError("Failed to pull ref {}: {}".format(ref, e)) from e
-            else:
-                return False
-        except BlobNotFound as e:
-            return False
-
-    # pull_tree():
-    #
-    # Pull a single Tree rather than a ref.
-    # Does not update local refs.
-    #
-    # Args:
-    #     remote (CASRemote): The remote to pull from
-    #     digest (Digest): The digest of the tree
-    #
-    def pull_tree(self, remote, digest):
-        try:
-            remote.init()
-
-            digest = self._fetch_tree(remote, digest)
-
-            return digest
-
-        except grpc.RpcError as e:
-            if e.code() != grpc.StatusCode.NOT_FOUND:
-                raise
-
-        return None
-
     # link_ref():
     #
     # Add an alias for an existing ref.
@@ -591,6 +524,16 @@ class CASCache():
         reachable = set()
         self._reachable_refs_dir(reachable, tree, update_mtime=True)
 
+    # Check to see if a blob is in the local CAS
+    # return None if not
+    def check_blob(self, digest):
+        objpath = self.objpath(digest)
+        if os.path.exists(objpath):
+            # already in local repository
+            return objpath
+        else:
+            return None
+
     ################################################
     #             Local Private Methods            #
     ################################################
@@ -780,151 +723,6 @@ class CASCache():
         for dirnode in directory.directories:
             yield from self._required_blobs(dirnode.digest)
 
-    # _ensure_blob():
-    #
-    # Fetch and add blob if it's not already local.
-    #
-    # Args:
-    #     remote (Remote): The remote to use.
-    #     digest (Digest): Digest object for the blob to fetch.
-    #
-    # Returns:
-    #     (str): The path of the object
-    #
-    def _ensure_blob(self, remote, digest):
-        objpath = self.objpath(digest)
-        if os.path.exists(objpath):
-            # already in local repository
-            return objpath
-
-        with tempfile.NamedTemporaryFile(dir=self.tmpdir) as f:
-            remote._fetch_blob(digest, f)
-
-            added_digest = self.add_object(path=f.name, link_directly=True)
-            assert added_digest.hash == digest.hash
-
-        return objpath
-
-    def _batch_download_complete(self, batch):
-        for digest, data in batch.send():
-            with tempfile.NamedTemporaryFile(dir=self.tmpdir) as f:
-                f.write(data)
-                f.flush()
-
-                added_digest = self.add_object(path=f.name, link_directly=True)
-                assert added_digest.hash == digest.hash
-
-    # Helper function for _fetch_directory().
-    def _fetch_directory_batch(self, remote, batch, fetch_queue, fetch_next_queue):
-        self._batch_download_complete(batch)
-
-        # All previously scheduled directories are now locally available,
-        # move them to the processing queue.
-        fetch_queue.extend(fetch_next_queue)
-        fetch_next_queue.clear()
-        return _CASBatchRead(remote)
-
-    # Helper function for _fetch_directory().
-    def _fetch_directory_node(self, remote, digest, batch, fetch_queue, fetch_next_queue, *, recursive=False):
-        in_local_cache = os.path.exists(self.objpath(digest))
-
-        if in_local_cache:
-            # Skip download, already in local cache.
-            pass
-        elif (digest.size_bytes >= remote.max_batch_total_size_bytes or
-              not remote.batch_read_supported):
-            # Too large for batch request, download in independent request.
-            self._ensure_blob(remote, digest)
-            in_local_cache = True
-        else:
-            if not batch.add(digest):
-                # Not enough space left in batch request.
-                # Complete pending batch first.
-                batch = self._fetch_directory_batch(remote, batch, fetch_queue, fetch_next_queue)
-                batch.add(digest)
-
-        if recursive:
-            if in_local_cache:
-                # Add directory to processing queue.
-                fetch_queue.append(digest)
-            else:
-                # Directory will be available after completing pending batch.
-                # Add directory to deferred processing queue.
-                fetch_next_queue.append(digest)
-
-        return batch
-
-    # _fetch_directory():
-    #
-    # Fetches remote directory and adds it to content addressable store.
-    #
-    # Fetches files, symbolic links and recursively other directories in
-    # the remote directory and adds them to the content addressable
-    # store.
-    #
-    # Args:
-    #     remote (Remote): The remote to use.
-    #     dir_digest (Digest): Digest object for the directory to fetch.
-    #     excluded_subdirs (list): The optional list of subdirs to not fetch
-    #
-    def _fetch_directory(self, remote, dir_digest, *, excluded_subdirs=None):
-        fetch_queue = [dir_digest]
-        fetch_next_queue = []
-        batch = _CASBatchRead(remote)
-        if not excluded_subdirs:
-            excluded_subdirs = []
-
-        while len(fetch_queue) + len(fetch_next_queue) > 0:
-            if not fetch_queue:
-                batch = self._fetch_directory_batch(remote, batch, fetch_queue, fetch_next_queue)
-
-            dir_digest = fetch_queue.pop(0)
-
-            objpath = self._ensure_blob(remote, dir_digest)
-
-            directory = remote_execution_pb2.Directory()
-            with open(objpath, 'rb') as f:
-                directory.ParseFromString(f.read())
-
-            for dirnode in directory.directories:
-                if dirnode.name not in excluded_subdirs:
-                    batch = self._fetch_directory_node(remote, dirnode.digest, batch,
-                                                       fetch_queue, fetch_next_queue, recursive=True)
-
-            for filenode in directory.files:
-                batch = self._fetch_directory_node(remote, filenode.digest, batch,
-                                                   fetch_queue, fetch_next_queue)
-
-        # Fetch final batch
-        self._fetch_directory_batch(remote, batch, fetch_queue, fetch_next_queue)
-
-    def _fetch_subdir(self, remote, tree, subdir):
-        subdirdigest = self._get_subdir(tree, subdir)
-        self._fetch_directory(remote, subdirdigest)
-
-    def _fetch_tree(self, remote, digest):
-        # download but do not store the Tree object
-        with tempfile.NamedTemporaryFile(dir=self.tmpdir) as out:
-            remote._fetch_blob(digest, out)
-
-            tree = remote_execution_pb2.Tree()
-
-            with open(out.name, 'rb') as f:
-                tree.ParseFromString(f.read())
-
-            tree.children.extend([tree.root])
-            for directory in tree.children:
-                for filenode in directory.files:
-                    self._ensure_blob(remote, filenode.digest)
-
-                # place directory blob only in final location when we've downloaded
-                # all referenced blobs to avoid dangling references in the repository
-                dirbuffer = directory.SerializeToString()
-                dirdigest = self.add_object(buffer=dirbuffer)
-                assert dirdigest.size_bytes == len(dirbuffer)
-
-        return dirdigest
-
     def _send_directory(self, remote, digest, u_uid=uuid.uuid4()):
         required_blobs = self._required_blobs(digest)
 
diff --git a/buildstream/_cas/casremote.py b/buildstream/_cas/casremote.py
index f7af253..0e75b09 100644
--- a/buildstream/_cas/casremote.py
+++ b/buildstream/_cas/casremote.py
@@ -3,6 +3,7 @@ import io
 import os
 import multiprocessing
 import signal
+import tempfile
 from urllib.parse import urlparse
 import uuid
 
@@ -96,6 +97,11 @@ class CASRemote():
         self.tmpdir = str(tmpdir)
         os.makedirs(self.tmpdir, exist_ok=True)
 
+        self.__tmp_downloads = []  # files in the tmpdir waiting to be added to local caches
+
+        self.__batch_read = None
+        self.__batch_update = None
+
     def init(self):
         if not self._initialized:
             url = urlparse(self.spec.url)
@@ -153,6 +159,7 @@ class CASRemote():
                 request = remote_execution_pb2.BatchReadBlobsRequest()
                 response = self.cas.BatchReadBlobs(request)
                 self.batch_read_supported = True
+                self.__batch_read = _CASBatchRead(self)
             except grpc.RpcError as e:
                 if e.code() != grpc.StatusCode.UNIMPLEMENTED:
                     raise
@@ -163,6 +170,7 @@ class CASRemote():
                 request = remote_execution_pb2.BatchUpdateBlobsRequest()
                 response = self.cas.BatchUpdateBlobs(request)
                 self.batch_update_supported = True
+                self.__batch_update = _CASBatchUpdate(self)
             except grpc.RpcError as e:
                 if (e.code() != grpc.StatusCode.UNIMPLEMENTED and
                         e.code() != grpc.StatusCode.PERMISSION_DENIED):
@@ -259,6 +267,136 @@ class CASRemote():
 
         return message_digest
 
+    # get_reference():
+    #
+    # Args:
+    #    ref (str): The ref to request
+    #
+    # Returns:
+    #    (digest): digest of ref, None if not found
+    #
+    def get_reference(self, ref):
+        try:
+            self.init()
+
+            request = buildstream_pb2.GetReferenceRequest()
+            request.key = ref
+            return self.ref_storage.GetReference(request).digest
+        except grpc.RpcError as e:
+            if e.code() != grpc.StatusCode.NOT_FOUND:
+                raise CASRemoteError("Failed to find ref {}: {}".format(ref, e)) from e
+            else:
+                return None
+
+    def get_tree_blob(self, tree_digest):
+        self.init()
+        f = tempfile.NamedTemporaryFile(dir=self.tmpdir)
+        self._fetch_blob(tree_digest, f)
+
+        tree = remote_execution_pb2.Tree()
+        with open(f.name, 'rb') as tmp:
+            tree.ParseFromString(tmp.read())
+
+        return tree
+
+    # yield_directory_digests():
+    #
+    # Recursively iterates over digests for files, symbolic links and other
+    # directories starting from a root digest
+    #
+    # Args:
+    #     root_digest (digest): The root_digest to get a tree of
+    #     progress (callable): The progress callback, if any
+    #     subdir (str): The optional specific subdir to pull
+    #     excluded_subdirs (list): The optional list of subdirs to not pull
+    #
+    # Returns:
+    #     (iter digests): recursively iterates over digests contained in root directory
+    #
+    def yield_directory_digests(self, root_digest, *, progress=None,
+                                subdir=None, excluded_subdirs=None):
+        self.init()
+
+        # Fetch artifact, excluded_subdirs determined in pullqueue
+        if excluded_subdirs is None:
+            excluded_subdirs = []
+
+        # get directory blob
+        f = tempfile.NamedTemporaryFile(dir=self.tmpdir)
+        self._fetch_blob(root_digest, f)
+
+        directory = remote_execution_pb2.Directory()
+        with open(f.name, 'rb') as tmp:
+            directory.ParseFromString(tmp.read())
+
+        yield root_digest
+        for filenode in directory.files:
+            yield filenode.digest
+
+        for dirnode in directory.directories:
+            if dirnode.name not in excluded_subdirs:
+                yield from self.yield_directory_digests(dirnode.digest)
+
+    # yield_tree_digests():
+    #
+    # Iterates over the file digests in a Tree message, staging directory blobs for the local cache
+    #
+    # Args:
+    #     tree (Tree): the Tree message to iterate over
+    #
+    # Returns:
+    #     (iter digests): iterates over file digests in the tree message
+    def yield_tree_digests(self, tree):
+        self.init()
+
+        tree.children.extend([tree.root])
+        for directory in tree.children:
+            for filenode in directory.files:
+                yield filenode.digest
+
+            # add the directory to downloaded tmp files to be added
+            f = tempfile.NamedTemporaryFile(dir=self.tmpdir)
+            f.write(directory.SerializeToString())
+            f.flush()
+            self.__tmp_downloads.append(f)
+
+    # request_blob():
+    #
+    # Request a blob, triggering download via bytestream or CAS
+    # BatchReadBlobs depending on size.
+    #
+    # Args:
+    #    digest (Digest): digest of the requested blob
+    #
+    def request_blob(self, digest):
+        if (not self.batch_read_supported or
+                digest.size_bytes > self.max_batch_total_size_bytes):
+            f = tempfile.NamedTemporaryFile(dir=self.tmpdir)
+            self._fetch_blob(digest, f)
+            self.__tmp_downloads.append(f)
+        elif self.__batch_read.add(digest) is False:
+            self._download_batch()
+            self.__batch_read.add(digest)
+
+    # get_blobs():
+    #
+    # Yield over downloaded blobs in the tmp file locations, causing the files
+    # to be deleted once they go out of scope.
+    #
+    # Args:
+    #    complete_batch (bool): download any outstanding batch read request
+    #
+    # Returns:
+    #    iterator over NamedTemporaryFile
+    def get_blobs(self, complete_batch=False):
+        # Send read batch request and download
+        if (complete_batch is True and
+                self.batch_read_supported is True):
+            self._download_batch()
+
+        while self.__tmp_downloads:
+            yield self.__tmp_downloads.pop()
+
     ################################################
     #             Local Private Methods            #
     ################################################
@@ -301,6 +439,15 @@ class CASRemote():
 
         assert response.committed_size == digest.size_bytes
 
+    def _download_batch(self):
+        for _, data in self.__batch_read.send():
+            f = tempfile.NamedTemporaryFile(dir=self.tmpdir)
+            f.write(data)
+            f.flush()
+            self.__tmp_downloads.append(f)
+
+        self.__batch_read = _CASBatchRead(self)
+
 
 # Represents a batch of blobs queued for fetching.
 #
diff --git a/buildstream/_cas/transfer.py b/buildstream/_cas/transfer.py
new file mode 100644
index 0000000..5eaaf09
--- /dev/null
+++ b/buildstream/_cas/transfer.py
@@ -0,0 +1,51 @@
+#
+#  Copyright (C) 2017-2018 Codethink Limited
+#
+#  This program is free software; you can redistribute it and/or
+#  modify it under the terms of the GNU Lesser General Public
+#  License as published by the Free Software Foundation; either
+#  version 2 of the License, or (at your option) any later version.
+#
+#  This library is distributed in the hope that it will be useful,
+#  but WITHOUT ANY WARRANTY; without even the implied warranty of
+#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.	 See the GNU
+#  Lesser General Public License for more details.
+#
+#  You should have received a copy of the GNU Lesser General Public
+#  License along with this library. If not, see <http://www.gnu.org/licenses/>.
+#
+#  Authors:
+#        Raoul Hidalgo Charman <ra...@codethink.co.uk>
+
+from ..utils import _message_digest
+
+
+def cas_directory_download(caslocal, casremote, root_digest, excluded_subdirs):
+    for blob_digest in casremote.yield_directory_digests(
+            root_digest, excluded_subdirs=excluded_subdirs):
+        if caslocal.check_blob(blob_digest):
+            continue
+        casremote.request_blob(blob_digest)
+        for blob_file in casremote.get_blobs():
+            caslocal.add_object(path=blob_file.name, link_directly=True)
+
+    # Request final CAS batch
+    for blob_file in casremote.get_blobs(complete_batch=True):
+        caslocal.add_object(path=blob_file.name, link_directly=True)
+
+
+def cas_tree_download(caslocal, casremote, tree_digest):
+    tree = casremote.get_tree_blob(tree_digest)
+    for blob_digest in casremote.yield_tree_digests(tree):
+        if caslocal.check_blob(blob_digest):
+            continue
+        casremote.request_blob(blob_digest)
+        for blob_file in casremote.get_blobs():
+            caslocal.add_object(path=blob_file.name, link_directly=True)
+
+    # Get the last batch
+    for blob_file in casremote.get_blobs(complete_batch=True):
+        caslocal.add_object(path=blob_file.name, link_directly=True)
+
+    # get root digest from tree and return that
+    return _message_digest(tree.root.SerializeToString())
diff --git a/buildstream/sandbox/_sandboxremote.py b/buildstream/sandbox/_sandboxremote.py
index 8c21041..bea1754 100644
--- a/buildstream/sandbox/_sandboxremote.py
+++ b/buildstream/sandbox/_sandboxremote.py
@@ -39,6 +39,7 @@ from .._exceptions import SandboxError
 from .. import _yaml
 from .._protos.google.longrunning import operations_pb2, operations_pb2_grpc
 from .._cas import CASRemote, CASRemoteSpec
+from .._cas.transfer import cas_tree_download
 
 
 class RemoteExecutionSpec(namedtuple('RemoteExecutionSpec', 'exec_service storage_service action_service')):
@@ -281,8 +282,7 @@ class SandboxRemote(Sandbox):
         cascache = context.get_cascache()
         casremote = CASRemote(self.storage_remote_spec, context.tmpdir)
 
-        # Now do a pull to ensure we have the necessary parts.
-        dir_digest = cascache.pull_tree(casremote, tree_digest)
+        dir_digest = cas_tree_download(cascache, casremote, tree_digest)
         if dir_digest is None or not dir_digest.hash or not dir_digest.size_bytes:
             raise SandboxError("Output directory structure pulling from remote failed.")
 
diff --git a/conftest.py b/conftest.py
index f3c09a5..9fb7f12 100755
--- a/conftest.py
+++ b/conftest.py
@@ -46,6 +46,13 @@ def integration_cache(request):
     else:
         cache_dir = os.path.abspath('./integration-cache')
 
+    # Clean up the tmp dir, should be empty but something in CI tests is
+    # leaving files here
+    try:
+        shutil.rmtree(os.path.join(cache_dir, 'tmp'))
+    except FileNotFoundError:
+        pass
+
     yield cache_dir
 
     # Clean up the artifacts after each test run - we only want to
diff --git a/tests/artifactcache/pull.py b/tests/artifactcache/pull.py
index 4c332bf..15d5c67 100644
--- a/tests/artifactcache/pull.py
+++ b/tests/artifactcache/pull.py
@@ -110,7 +110,7 @@ def test_pull(cli, tmpdir, datafiles):
         # See https://github.com/grpc/grpc/blob/master/doc/fork_support.md for details
         process = multiprocessing.Process(target=_queue_wrapper,
                                           args=(_test_pull, queue, user_config_file, project_dir,
-                                                artifact_dir, 'target.bst', element_key))
+                                                artifact_dir, tmpdir, 'target.bst', element_key))
 
         try:
             # Keep SIGINT blocked in the child process
@@ -126,14 +126,18 @@ def test_pull(cli, tmpdir, datafiles):
         assert not error
         assert cas.contains(element, element_key)
 
+        # Check that the tmp dir is cleared out
+        assert os.listdir(os.path.join(str(tmpdir), 'cache', 'tmp')) == []
 
-def _test_pull(user_config_file, project_dir, artifact_dir,
+
+def _test_pull(user_config_file, project_dir, artifact_dir, tmpdir,
                element_name, element_key, queue):
     # Fake minimal context
     context = Context()
     context.load(config=user_config_file)
     context.artifactdir = artifact_dir
     context.set_message_handler(message_handler)
+    context.tmpdir = os.path.join(str(tmpdir), 'cache', 'tmp')
 
     # Load the project manually
     project = Project(project_dir, context)
@@ -218,7 +222,7 @@ def test_pull_tree(cli, tmpdir, datafiles):
         # See https://github.com/grpc/grpc/blob/master/doc/fork_support.md for details
         process = multiprocessing.Process(target=_queue_wrapper,
                                           args=(_test_push_tree, queue, user_config_file, project_dir,
-                                                artifact_dir, artifact_digest))
+                                                artifact_dir, tmpdir, artifact_digest))
 
         try:
             # Keep SIGINT blocked in the child process
@@ -239,6 +243,9 @@ def test_pull_tree(cli, tmpdir, datafiles):
         # Assert that we are not cached locally anymore
         assert cli.get_element_state(project_dir, 'target.bst') != 'cached'
 
+        # Check that the tmp dir is cleared out
+        assert os.listdir(os.path.join(str(tmpdir), 'cache', 'tmp')) == []
+
         tree_digest = remote_execution_pb2.Digest(hash=tree_hash,
                                                   size_bytes=tree_size)
 
@@ -246,7 +253,7 @@ def test_pull_tree(cli, tmpdir, datafiles):
         # Use subprocess to avoid creation of gRPC threads in main BuildStream process
         process = multiprocessing.Process(target=_queue_wrapper,
                                           args=(_test_pull_tree, queue, user_config_file, project_dir,
-                                                artifact_dir, tree_digest))
+                                                artifact_dir, tmpdir, tree_digest))
 
         try:
             # Keep SIGINT blocked in the child process
@@ -267,13 +274,18 @@ def test_pull_tree(cli, tmpdir, datafiles):
         # Ensure the entire Tree stucture has been pulled
         assert os.path.exists(cas.objpath(directory_digest))
 
+        # Check that the tmp dir is cleared out
+        assert os.listdir(os.path.join(str(tmpdir), 'cache', 'tmp')) == []
+
 
-def _test_push_tree(user_config_file, project_dir, artifact_dir, artifact_digest, queue):
+def _test_push_tree(user_config_file, project_dir, artifact_dir, tmpdir,
+                    artifact_digest, queue):
     # Fake minimal context
     context = Context()
     context.load(config=user_config_file)
     context.artifactdir = artifact_dir
     context.set_message_handler(message_handler)
+    context.tmpdir = os.path.join(str(tmpdir), 'cache', 'tmp')
 
     # Load the project manually
     project = Project(project_dir, context)
@@ -304,12 +316,14 @@ def _test_push_tree(user_config_file, project_dir, artifact_dir, artifact_digest
         queue.put("No remote configured")
 
 
-def _test_pull_tree(user_config_file, project_dir, artifact_dir, artifact_digest, queue):
+def _test_pull_tree(user_config_file, project_dir, artifact_dir, tmpdir,
+                    artifact_digest, queue):
     # Fake minimal context
     context = Context()
     context.load(config=user_config_file)
     context.artifactdir = artifact_dir
     context.set_message_handler(message_handler)
+    context.tmpdir = os.path.join(str(tmpdir), 'cache', 'tmp')
 
     # Load the project manually
     project = Project(project_dir, context)
diff --git a/tests/artifactcache/push.py b/tests/artifactcache/push.py
index 116fa78..f97a231 100644
--- a/tests/artifactcache/push.py
+++ b/tests/artifactcache/push.py
@@ -89,7 +89,7 @@ def test_push(cli, tmpdir, datafiles):
         # See https://github.com/grpc/grpc/blob/master/doc/fork_support.md for details
         process = multiprocessing.Process(target=_queue_wrapper,
                                           args=(_test_push, queue, user_config_file, project_dir,
-                                                artifact_dir, 'target.bst', element_key))
+                                                artifact_dir, tmpdir, 'target.bst', element_key))
 
         try:
             # Keep SIGINT blocked in the child process
@@ -105,14 +105,18 @@ def test_push(cli, tmpdir, datafiles):
         assert not error
         assert share.has_artifact('test', 'target.bst', element_key)
 
+        # Check tmpdir for downloads is cleared
+        assert os.listdir(os.path.join(str(tmpdir), 'cache', 'tmp')) == []
 
-def _test_push(user_config_file, project_dir, artifact_dir,
+
+def _test_push(user_config_file, project_dir, artifact_dir, tmpdir,
                element_name, element_key, queue):
     # Fake minimal context
     context = Context()
     context.load(config=user_config_file)
     context.artifactdir = artifact_dir
     context.set_message_handler(message_handler)
+    context.tmpdir = os.path.join(str(tmpdir), 'cache', 'tmp')
 
     # Load the project manually
     project = Project(project_dir, context)
@@ -196,9 +200,10 @@ def test_push_directory(cli, tmpdir, datafiles):
         queue = multiprocessing.Queue()
         # Use subprocess to avoid creation of gRPC threads in main BuildStream process
         # See https://github.com/grpc/grpc/blob/master/doc/fork_support.md for details
-        process = multiprocessing.Process(target=_queue_wrapper,
-                                          args=(_test_push_directory, queue, user_config_file,
-                                                project_dir, artifact_dir, artifact_digest))
+        process = multiprocessing.Process(
+            target=_queue_wrapper,
+            args=(_test_push_directory, queue, user_config_file, project_dir,
+                  artifact_dir, tmpdir, artifact_digest))
 
         try:
             # Keep SIGINT blocked in the child process
@@ -215,13 +220,17 @@ def test_push_directory(cli, tmpdir, datafiles):
         assert artifact_digest.hash == directory_hash
         assert share.has_object(artifact_digest)
 
+        assert os.listdir(os.path.join(str(tmpdir), 'cache', 'tmp')) == []
 
-def _test_push_directory(user_config_file, project_dir, artifact_dir, artifact_digest, queue):
+
+def _test_push_directory(user_config_file, project_dir, artifact_dir, tmpdir,
+                         artifact_digest, queue):
     # Fake minimal context
     context = Context()
     context.load(config=user_config_file)
     context.artifactdir = artifact_dir
     context.set_message_handler(message_handler)
+    context.tmpdir = os.path.join(str(tmpdir), 'cache', 'tmp')
 
     # Load the project manually
     project = Project(project_dir, context)
@@ -273,7 +282,7 @@ def test_push_message(cli, tmpdir, datafiles):
         # See https://github.com/grpc/grpc/blob/master/doc/fork_support.md for details
         process = multiprocessing.Process(target=_queue_wrapper,
                                           args=(_test_push_message, queue, user_config_file,
-                                                project_dir, artifact_dir))
+                                                project_dir, artifact_dir, tmpdir))
 
         try:
             # Keep SIGINT blocked in the child process
@@ -291,13 +300,16 @@ def test_push_message(cli, tmpdir, datafiles):
                                                      size_bytes=message_size)
         assert share.has_object(message_digest)
 
+        assert os.listdir(os.path.join(str(tmpdir), 'cache', 'tmp')) == []
+
 
-def _test_push_message(user_config_file, project_dir, artifact_dir, queue):
+def _test_push_message(user_config_file, project_dir, artifact_dir, tmpdir, queue):
     # Fake minimal context
     context = Context()
     context.load(config=user_config_file)
     context.artifactdir = artifact_dir
     context.set_message_handler(message_handler)
+    context.tmpdir = os.path.join(str(tmpdir), 'cache', 'tmp')
 
     # Load the project manually
     project = Project(project_dir, context)
diff --git a/tests/integration/pullbuildtrees.py b/tests/integration/pullbuildtrees.py
index f6fc712..de13a3d 100644
--- a/tests/integration/pullbuildtrees.py
+++ b/tests/integration/pullbuildtrees.py
@@ -77,6 +77,8 @@ def test_pullbuildtrees(cli, tmpdir, datafiles, integration_cache):
         result = cli.run(project=project, args=['--pull-buildtrees', 'pull', element_name])
         assert element_name in result.get_pulled_elements()
         assert os.path.isdir(buildtreedir)
+        # Check tmpdir for downloads is cleared
+        assert os.listdir(os.path.join(str(tmpdir), 'artifacts', 'tmp')) == []
         default_state(cli, tmpdir, share1)
 
         # Pull artifact with pullbuildtrees set in user config, then assert
@@ -89,6 +91,8 @@ def test_pullbuildtrees(cli, tmpdir, datafiles, integration_cache):
         assert element_name not in result.get_pulled_elements()
         result = cli.run(project=project, args=['--pull-buildtrees', 'pull', element_name])
         assert element_name not in result.get_pulled_elements()
+        # Check tmpdir for downloads is cleared
+        assert os.listdir(os.path.join(str(tmpdir), 'artifacts', 'tmp')) == []
         default_state(cli, tmpdir, share1)
 
         # Pull artifact with default config and buildtrees cli flag set, then assert
@@ -99,6 +103,8 @@ def test_pullbuildtrees(cli, tmpdir, datafiles, integration_cache):
         cli.configure({'cache': {'pull-buildtrees': True}})
         result = cli.run(project=project, args=['pull', element_name])
         assert element_name not in result.get_pulled_elements()
+        # Check tmpdir for downloads is cleared
+        assert os.listdir(os.path.join(str(tmpdir), 'artifacts', 'tmp')) == []
         default_state(cli, tmpdir, share1)
 
         # Assert that a partial build element (not containing a populated buildtree dir)