Posted to commits@buildstream.apache.org by gi...@apache.org on 2020/12/29 13:09:16 UTC

[buildstream] branch raoul/802-refactor-artifactcache created (now 7329ef5)

This is an automated email from the ASF dual-hosted git repository.

github-bot pushed a change to branch raoul/802-refactor-artifactcache
in repository https://gitbox.apache.org/repos/asf/buildstream.git.


      at 7329ef5  artifactcache: implement new push methods

This branch includes the following new commits:

     new c69d12f  _cas: Rename artifactcache folder and move that to a root module
     new fe01405  casremote.py: Move remote CAS classes into their own file
     new cdc5b6f  cas: move remote only functions to CASRemote
     new 9944ddd  tmpdir: add tmpdir to context for CASRemote
     new 15db919  artifactcache: Move pull logic into CASRemote
     new 7329ef5  artifactcache: implement new push methods

The 6 revisions listed above as "new" are entirely new to this
repository and will be described in separate emails.  The revisions
listed as "add" were already present in the repository and have only
been added to this reference.



[buildstream] 03/06: cas: move remote only functions to CASRemote

Posted by gi...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

github-bot pushed a commit to branch raoul/802-refactor-artifactcache
in repository https://gitbox.apache.org/repos/asf/buildstream.git

commit cdc5b6f5bcf604841f6a291349231df60f407d18
Author: Raoul Hidalgo Charman <ra...@codethink.co.uk>
AuthorDate: Tue Dec 11 11:41:44 2018 +0000

    cas: move remote only functions to CASRemote
    
    Methods moved:
    * Initialization check: now a CASRemote class method that runs the check
      in a subprocess, for use when checking remotes from the main BuildStream
      process.
    * _fetch_blob
    * _send_blob
    * verify_digest_on_remote
    * push_message
    
    Part of #802
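
As an editorial sketch (not part of the commit), the new call pattern looks
roughly like this: the connectivity check is a CASRemote class method that
spawns its own subprocess, and remote-only operations are invoked on the
CASRemote instance instead of going through CASCache. The spec values, and the
'digest' and 'message' objects, are hypothetical placeholders.

    import multiprocessing

    from buildstream._cas.casremote import CASRemote, CASRemoteSpec

    spec = CASRemoteSpec('https://cache.example.com:11002', push=True)
    q = multiprocessing.Queue()

    # check_remote() forks a child process internally so that gRPC threads are
    # never created in the main BuildStream process; it reports the result
    # through the caller-supplied queue and returns an error string or None.
    error = CASRemote.check_remote(spec, q)

    if error is None:
        remote = CASRemote(spec)
        # Remote-only helpers now live on the CASRemote itself:
        present = remote.verify_digest_on_remote(digest)   # 'digest': a Digest proto (assumed)
        message_digest = remote.push_message(message)      # 'message': a protobuf message (assumed)
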
---
 buildstream/_artifactcache.py         |  18 +----
 buildstream/_cas/cascache.py          | 127 +-----------------------------
 buildstream/_cas/casremote.py         | 141 +++++++++++++++++++++++++++++++++-
 buildstream/sandbox/_sandboxremote.py |   6 +-
 4 files changed, 148 insertions(+), 144 deletions(-)

diff --git a/buildstream/_artifactcache.py b/buildstream/_artifactcache.py
index 1b2b55d..cdbf2d9 100644
--- a/buildstream/_artifactcache.py
+++ b/buildstream/_artifactcache.py
@@ -19,14 +19,12 @@
 
 import multiprocessing
 import os
-import signal
 import string
 from collections.abc import Mapping
 
 from .types import _KeyStrength
 from ._exceptions import ArtifactError, CASError, LoadError, LoadErrorReason
 from ._message import Message, MessageType
-from . import _signals
 from . import utils
 from . import _yaml
 
@@ -375,20 +373,8 @@ class ArtifactCache():
         remotes = {}
         q = multiprocessing.Queue()
         for remote_spec in remote_specs:
-            # Use subprocess to avoid creation of gRPC threads in main BuildStream process
-            # See https://github.com/grpc/grpc/blob/master/doc/fork_support.md for details
-            p = multiprocessing.Process(target=self.cas.initialize_remote, args=(remote_spec, q))
 
-            try:
-                # Keep SIGINT blocked in the child process
-                with _signals.blocked([signal.SIGINT], ignore=False):
-                    p.start()
-
-                error = q.get()
-                p.join()
-            except KeyboardInterrupt:
-                utils._kill_process_tree(p.pid)
-                raise
+            error = CASRemote.check_remote(remote_spec, q)
 
             if error and on_failure:
                 on_failure(remote_spec.url, error)
@@ -747,7 +733,7 @@ class ArtifactCache():
                                 "servers are configured as push remotes.")
 
         for remote in push_remotes:
-            message_digest = self.cas.push_message(remote, message)
+            message_digest = remote.push_message(message)
 
         return message_digest
 
diff --git a/buildstream/_cas/cascache.py b/buildstream/_cas/cascache.py
index 5f62e61..adbd34c 100644
--- a/buildstream/_cas/cascache.py
+++ b/buildstream/_cas/cascache.py
@@ -19,7 +19,6 @@
 
 import hashlib
 import itertools
-import io
 import os
 import stat
 import tempfile
@@ -28,14 +27,13 @@ import contextlib
 
 import grpc
 
-from .._protos.google.bytestream import bytestream_pb2
 from .._protos.build.bazel.remote.execution.v2 import remote_execution_pb2
 from .._protos.buildstream.v2 import buildstream_pb2
 
 from .. import utils
 from .._exceptions import CASCacheError
 
-from .casremote import CASRemote, BlobNotFound, _CASBatchRead, _CASBatchUpdate, _MAX_PAYLOAD_BYTES
+from .casremote import BlobNotFound, _CASBatchRead, _CASBatchUpdate
 
 
 # A CASCache manages a CAS repository as specified in the Remote Execution API.
@@ -185,29 +183,6 @@ class CASCache():
 
         return modified, removed, added
 
-    def initialize_remote(self, remote_spec, q):
-        try:
-            remote = CASRemote(remote_spec)
-            remote.init()
-
-            request = buildstream_pb2.StatusRequest(instance_name=remote_spec.instance_name)
-            response = remote.ref_storage.Status(request)
-
-            if remote_spec.push and not response.allow_updates:
-                q.put('CAS server does not allow push')
-            else:
-                # No error
-                q.put(None)
-
-        except grpc.RpcError as e:
-            # str(e) is too verbose for errors reported to the user
-            q.put(e.details())
-
-        except Exception as e:               # pylint: disable=broad-except
-            # Whatever happens, we need to return it to the calling process
-            #
-            q.put(str(e))
-
     # pull():
     #
     # Pull a ref from a remote repository.
@@ -355,50 +330,6 @@ class CASCache():
 
         self._send_directory(remote, directory.ref)
 
-    # push_message():
-    #
-    # Push the given protobuf message to a remote.
-    #
-    # Args:
-    #     remote (CASRemote): The remote to push to
-    #     message (Message): A protobuf message to push.
-    #
-    # Raises:
-    #     (CASCacheError): if there was an error
-    #
-    def push_message(self, remote, message):
-
-        message_buffer = message.SerializeToString()
-        message_digest = utils._message_digest(message_buffer)
-
-        remote.init()
-
-        with io.BytesIO(message_buffer) as b:
-            self._send_blob(remote, message_digest, b)
-
-        return message_digest
-
-    # verify_digest_on_remote():
-    #
-    # Check whether the object is already on the server in which case
-    # there is no need to upload it.
-    #
-    # Args:
-    #     remote (CASRemote): The remote to check
-    #     digest (Digest): The object digest.
-    #
-    def verify_digest_on_remote(self, remote, digest):
-        remote.init()
-
-        request = remote_execution_pb2.FindMissingBlobsRequest(instance_name=remote.spec.instance_name)
-        request.blob_digests.extend([digest])
-
-        response = remote.cas.FindMissingBlobs(request)
-        if digest in response.missing_blob_digests:
-            return False
-
-        return True
-
     # objpath():
     #
     # Return the path of an object based on its digest.
@@ -849,23 +780,6 @@ class CASCache():
         for dirnode in directory.directories:
             yield from self._required_blobs(dirnode.digest)
 
-    def _fetch_blob(self, remote, digest, stream):
-        resource_name_components = ['blobs', digest.hash, str(digest.size_bytes)]
-
-        if remote.spec.instance_name:
-            resource_name_components.insert(0, remote.spec.instance_name)
-
-        resource_name = '/'.join(resource_name_components)
-
-        request = bytestream_pb2.ReadRequest()
-        request.resource_name = resource_name
-        request.read_offset = 0
-        for response in remote.bytestream.Read(request):
-            stream.write(response.data)
-        stream.flush()
-
-        assert digest.size_bytes == os.fstat(stream.fileno()).st_size
-
     # _ensure_blob():
     #
     # Fetch and add blob if it's not already local.
@@ -884,7 +798,7 @@ class CASCache():
             return objpath
 
         with tempfile.NamedTemporaryFile(dir=self.tmpdir) as f:
-            self._fetch_blob(remote, digest, f)
+            remote._fetch_blob(digest, f)
 
             added_digest = self.add_object(path=f.name, link_directly=True)
             assert added_digest.hash == digest.hash
@@ -991,7 +905,7 @@ class CASCache():
     def _fetch_tree(self, remote, digest):
         # download but do not store the Tree object
         with tempfile.NamedTemporaryFile(dir=self.tmpdir) as out:
-            self._fetch_blob(remote, digest, out)
+            remote._fetch_blob(digest, out)
 
             tree = remote_execution_pb2.Tree()
 
@@ -1011,39 +925,6 @@ class CASCache():
 
         return dirdigest
 
-    def _send_blob(self, remote, digest, stream, u_uid=uuid.uuid4()):
-        resource_name_components = ['uploads', str(u_uid), 'blobs',
-                                    digest.hash, str(digest.size_bytes)]
-
-        if remote.spec.instance_name:
-            resource_name_components.insert(0, remote.spec.instance_name)
-
-        resource_name = '/'.join(resource_name_components)
-
-        def request_stream(resname, instream):
-            offset = 0
-            finished = False
-            remaining = digest.size_bytes
-            while not finished:
-                chunk_size = min(remaining, _MAX_PAYLOAD_BYTES)
-                remaining -= chunk_size
-
-                request = bytestream_pb2.WriteRequest()
-                request.write_offset = offset
-                # max. _MAX_PAYLOAD_BYTES chunks
-                request.data = instream.read(chunk_size)
-                request.resource_name = resname
-                request.finish_write = remaining <= 0
-
-                yield request
-
-                offset += chunk_size
-                finished = request.finish_write
-
-        response = remote.bytestream.Write(request_stream(resource_name, stream))
-
-        assert response.committed_size == digest.size_bytes
-
     def _send_directory(self, remote, digest, u_uid=uuid.uuid4()):
         required_blobs = self._required_blobs(digest)
 
@@ -1077,7 +958,7 @@ class CASCache():
                 if (digest.size_bytes >= remote.max_batch_total_size_bytes or
                         not remote.batch_update_supported):
                     # Too large for batch request, upload in independent request.
-                    self._send_blob(remote, digest, f, u_uid=u_uid)
+                    remote._send_blob(digest, f, u_uid=u_uid)
                 else:
                     if not batch.add(digest, f):
                         # Not enough space left in batch request.
diff --git a/buildstream/_cas/casremote.py b/buildstream/_cas/casremote.py
index 59eb7e3..56ba4c5 100644
--- a/buildstream/_cas/casremote.py
+++ b/buildstream/_cas/casremote.py
@@ -1,16 +1,22 @@
 from collections import namedtuple
+import io
 import os
+import multiprocessing
+import signal
 from urllib.parse import urlparse
+import uuid
 
 import grpc
 
 from .. import _yaml
 from .._protos.google.rpc import code_pb2
-from .._protos.google.bytestream import bytestream_pb2_grpc
+from .._protos.google.bytestream import bytestream_pb2, bytestream_pb2_grpc
 from .._protos.build.bazel.remote.execution.v2 import remote_execution_pb2, remote_execution_pb2_grpc
-from .._protos.buildstream.v2 import buildstream_pb2_grpc
+from .._protos.buildstream.v2 import buildstream_pb2, buildstream_pb2_grpc
 
 from .._exceptions import CASRemoteError, LoadError, LoadErrorReason
+from .. import _signals
+from .. import utils
 
 # The default limit for gRPC messages is 4 MiB.
 # Limit payload to 1 MiB to leave sufficient headroom for metadata.
@@ -159,6 +165,137 @@ class CASRemote():
 
             self._initialized = True
 
+    # check_remote
+    #
+    # Used when checking whether remote_specs work in the buildstream main
+    # thread, runs this in a separate process to avoid creation of gRPC threads
+    # in the main BuildStream process
+    # See https://github.com/grpc/grpc/blob/master/doc/fork_support.md for details
+    @classmethod
+    def check_remote(cls, remote_spec, q):
+
+        def __check_remote():
+            try:
+                remote = cls(remote_spec)
+                remote.init()
+
+                request = buildstream_pb2.StatusRequest()
+                response = remote.ref_storage.Status(request)
+
+                if remote_spec.push and not response.allow_updates:
+                    q.put('CAS server does not allow push')
+                else:
+                    # No error
+                    q.put(None)
+
+            except grpc.RpcError as e:
+                # str(e) is too verbose for errors reported to the user
+                q.put(e.details())
+
+            except Exception as e:               # pylint: disable=broad-except
+                # Whatever happens, we need to return it to the calling process
+                #
+                q.put(str(e))
+
+        p = multiprocessing.Process(target=__check_remote)
+
+        try:
+            # Keep SIGINT blocked in the child process
+            with _signals.blocked([signal.SIGINT], ignore=False):
+                p.start()
+
+            error = q.get()
+            p.join()
+        except KeyboardInterrupt:
+            utils._kill_process_tree(p.pid)
+            raise
+
+        return error
+
+    # verify_digest_on_remote():
+    #
+    # Check whether the object is already on the server in which case
+    # there is no need to upload it.
+    #
+    # Args:
+    #     digest (Digest): The object digest.
+    #
+    def verify_digest_on_remote(self, digest):
+        self.init()
+
+        request = remote_execution_pb2.FindMissingBlobsRequest()
+        request.blob_digests.extend([digest])
+
+        response = self.cas.FindMissingBlobs(request)
+        if digest in response.missing_blob_digests:
+            return False
+
+        return True
+
+    # push_message():
+    #
+    # Push the given protobuf message to a remote.
+    #
+    # Args:
+    #     message (Message): A protobuf message to push.
+    #
+    # Raises:
+    #     (CASRemoteError): if there was an error
+    #
+    def push_message(self, message):
+
+        message_buffer = message.SerializeToString()
+        message_digest = utils._message_digest(message_buffer)
+
+        self.init()
+
+        with io.BytesIO(message_buffer) as b:
+            self._send_blob(message_digest, b)
+
+        return message_digest
+
+    ################################################
+    #             Local Private Methods            #
+    ################################################
+    def _fetch_blob(self, digest, stream):
+        resource_name = '/'.join(['blobs', digest.hash, str(digest.size_bytes)])
+        request = bytestream_pb2.ReadRequest()
+        request.resource_name = resource_name
+        request.read_offset = 0
+        for response in self.bytestream.Read(request):
+            stream.write(response.data)
+        stream.flush()
+
+        assert digest.size_bytes == os.fstat(stream.fileno()).st_size
+
+    def _send_blob(self, digest, stream, u_uid=uuid.uuid4()):
+        resource_name = '/'.join(['uploads', str(u_uid), 'blobs',
+                                  digest.hash, str(digest.size_bytes)])
+
+        def request_stream(resname, instream):
+            offset = 0
+            finished = False
+            remaining = digest.size_bytes
+            while not finished:
+                chunk_size = min(remaining, _MAX_PAYLOAD_BYTES)
+                remaining -= chunk_size
+
+                request = bytestream_pb2.WriteRequest()
+                request.write_offset = offset
+                # max. _MAX_PAYLOAD_BYTES chunks
+                request.data = instream.read(chunk_size)
+                request.resource_name = resname
+                request.finish_write = remaining <= 0
+
+                yield request
+
+                offset += chunk_size
+                finished = request.finish_write
+
+        response = self.bytestream.Write(request_stream(resource_name, stream))
+
+        assert response.committed_size == digest.size_bytes
+
 
 # Represents a batch of blobs queued for fetching.
 #
diff --git a/buildstream/sandbox/_sandboxremote.py b/buildstream/sandbox/_sandboxremote.py
index 8b4c87c..6a1f6f2 100644
--- a/buildstream/sandbox/_sandboxremote.py
+++ b/buildstream/sandbox/_sandboxremote.py
@@ -348,17 +348,17 @@ class SandboxRemote(Sandbox):
             except grpc.RpcError as e:
                 raise SandboxError("Failed to push source directory to remote: {}".format(e)) from e
 
-            if not cascache.verify_digest_on_remote(casremote, upload_vdir.ref):
+            if not casremote.verify_digest_on_remote(upload_vdir.ref):
                 raise SandboxError("Failed to verify that source has been pushed to the remote artifact cache.")
 
             # Push command and action
             try:
-                cascache.push_message(casremote, command_proto)
+                casremote.push_message(command_proto)
             except grpc.RpcError as e:
                 raise SandboxError("Failed to push command to remote: {}".format(e))
 
             try:
-                cascache.push_message(casremote, action)
+                casremote.push_message(action)
             except grpc.RpcError as e:
                 raise SandboxError("Failed to push action to remote: {}".format(e))
 


[buildstream] 04/06: tmpdir: add tmpdir to context for CASRemote

Posted by gi...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

github-bot pushed a commit to branch raoul/802-refactor-artifactcache
in repository https://gitbox.apache.org/repos/asf/buildstream.git

commit 9944dddbb7b2f93ec9af4442bf825c6723ddb761
Author: Raoul Hidalgo Charman <ra...@codethink.co.uk>
AuthorDate: Fri Dec 14 11:20:40 2018 +0000

    tmpdir: add tmpdir to context for CASRemote
    
    As CASRemote is used in other places, such as SandboxRemote, it makes sense
    for it to have its own tmpdir. This directory is passed to CASRemote when it
    is initialized.

    The tmpdir currently lives under the artifactdir, but should be moved once
    the artifactdir and builddir options are deprecated in favour of a single
    top-level directory containing both.
    
    Part of #802
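
A minimal sketch (editorial, not from the commit) of how the new argument is
threaded through; the artifact directory path is hypothetical, and in
BuildStream itself the tmpdir comes from the Context:

    import os

    from buildstream._cas.casremote import CASRemote, CASRemoteSpec

    artifactdir = os.path.expanduser('~/.cache/buildstream/artifacts')
    tmpdir = os.path.join(artifactdir, 'tmp')   # mirrors the Context change: tmpdir lives under artifactdir

    spec = CASRemoteSpec('https://cache.example.com:11002', push=True)
    remote = CASRemote(spec, tmpdir)            # CASRemote creates the directory on demand
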
---
 buildstream/_artifactcache.py         |  4 ++--
 buildstream/_cas/casremote.py         | 11 ++++++++---
 buildstream/_context.py               |  7 +++++--
 buildstream/sandbox/_sandboxremote.py |  7 ++++---
 tests/testutils/runcli.py             |  2 +-
 5 files changed, 20 insertions(+), 11 deletions(-)

diff --git a/buildstream/_artifactcache.py b/buildstream/_artifactcache.py
index cdbf2d9..dd5b4b5 100644
--- a/buildstream/_artifactcache.py
+++ b/buildstream/_artifactcache.py
@@ -374,7 +374,7 @@ class ArtifactCache():
         q = multiprocessing.Queue()
         for remote_spec in remote_specs:
 
-            error = CASRemote.check_remote(remote_spec, q)
+            error = CASRemote.check_remote(remote_spec, self.context.tmpdir, q)
 
             if error and on_failure:
                 on_failure(remote_spec.url, error)
@@ -385,7 +385,7 @@ class ArtifactCache():
                 if remote_spec.push:
                     self._has_push_remotes = True
 
-                remotes[remote_spec.url] = CASRemote(remote_spec)
+                remotes[remote_spec.url] = CASRemote(remote_spec, self.context.tmpdir)
 
         for project in self.context.get_projects():
             remote_specs = self.global_remote_specs
diff --git a/buildstream/_cas/casremote.py b/buildstream/_cas/casremote.py
index 56ba4c5..f7af253 100644
--- a/buildstream/_cas/casremote.py
+++ b/buildstream/_cas/casremote.py
@@ -79,7 +79,7 @@ class BlobNotFound(CASRemoteError):
 # Represents a single remote CAS cache.
 #
 class CASRemote():
-    def __init__(self, spec):
+    def __init__(self, spec, tmpdir):
         self.spec = spec
         self._initialized = False
         self.channel = None
@@ -91,6 +91,11 @@ class CASRemote():
         self.capabilities = None
         self.max_batch_total_size_bytes = None
 
+        # Need str because Python 3.5 and lower don't deal with path-like
+        # objects here.
+        self.tmpdir = str(tmpdir)
+        os.makedirs(self.tmpdir, exist_ok=True)
+
     def init(self):
         if not self._initialized:
             url = urlparse(self.spec.url)
@@ -172,11 +177,11 @@ class CASRemote():
     # in the main BuildStream process
     # See https://github.com/grpc/grpc/blob/master/doc/fork_support.md for details
     @classmethod
-    def check_remote(cls, remote_spec, q):
+    def check_remote(cls, remote_spec, tmpdir, q):
 
         def __check_remote():
             try:
-                remote = cls(remote_spec)
+                remote = cls(remote_spec, tmpdir)
                 remote.init()
 
                 request = buildstream_pb2.StatusRequest()
diff --git a/buildstream/_context.py b/buildstream/_context.py
index 324b455..f81a6b0 100644
--- a/buildstream/_context.py
+++ b/buildstream/_context.py
@@ -191,10 +191,13 @@ class Context():
         _yaml.node_validate(defaults, [
             'sourcedir', 'builddir', 'artifactdir', 'logdir',
             'scheduler', 'artifacts', 'logging', 'projects',
-            'cache', 'prompt', 'workspacedir', 'remote-execution'
+            'cache', 'prompt', 'workspacedir', 'remote-execution',
         ])
 
-        for directory in ['sourcedir', 'builddir', 'artifactdir', 'logdir', 'workspacedir']:
+        defaults['tmpdir'] = os.path.join(defaults['artifactdir'], 'tmp')
+
+        for directory in ['sourcedir', 'builddir', 'artifactdir', 'logdir',
+                          'tmpdir', 'workspacedir']:
             # Allow the ~ tilde expansion and any environment variables in
             # path specification in the config files.
             #
diff --git a/buildstream/sandbox/_sandboxremote.py b/buildstream/sandbox/_sandboxremote.py
index 6a1f6f2..8c21041 100644
--- a/buildstream/sandbox/_sandboxremote.py
+++ b/buildstream/sandbox/_sandboxremote.py
@@ -279,7 +279,7 @@ class SandboxRemote(Sandbox):
 
         context = self._get_context()
         cascache = context.get_cascache()
-        casremote = CASRemote(self.storage_remote_spec)
+        casremote = CASRemote(self.storage_remote_spec, context.tmpdir)
 
         # Now do a pull to ensure we have the necessary parts.
         dir_digest = cascache.pull_tree(casremote, tree_digest)
@@ -306,8 +306,9 @@ class SandboxRemote(Sandbox):
 
     def _run(self, command, flags, *, cwd, env):
         # set up virtual directory
+        context = self._get_context()
         upload_vdir = self.get_virtual_directory()
-        cascache = self._get_context().get_cascache()
+        cascache = context.get_cascache()
         if isinstance(upload_vdir, FileBasedDirectory):
             # Make a new temporary directory to put source in
             upload_vdir = CasBasedDirectory(cascache, ref=None)
@@ -340,7 +341,7 @@ class SandboxRemote(Sandbox):
         action_result = self._check_action_cache(action_digest)
 
         if not action_result:
-            casremote = CASRemote(self.storage_remote_spec)
+            casremote = CASRemote(self.storage_remote_spec, context.tmpdir)
 
             # Now, push that key (without necessarily needing a ref) to the remote.
             try:
diff --git a/tests/testutils/runcli.py b/tests/testutils/runcli.py
index 0c8e962..67a874e 100644
--- a/tests/testutils/runcli.py
+++ b/tests/testutils/runcli.py
@@ -526,7 +526,7 @@ def cli_integration(tmpdir, integration_cache):
     # to avoid downloading the huge base-sdk repeatedly
     fixture.configure({
         'sourcedir': os.path.join(integration_cache, 'sources'),
-        'artifactdir': os.path.join(integration_cache, 'artifacts')
+        'artifactdir': os.path.join(integration_cache, 'artifacts'),
     })
 
     return fixture


[buildstream] 06/06: artifactcache: implement new push methods

Posted by gi...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

github-bot pushed a commit to branch raoul/802-refactor-artifactcache
in repository https://gitbox.apache.org/repos/asf/buildstream.git

commit 7329ef562a3df0d81cfeea20fbfb5cbba366b91a
Author: Raoul Hidalgo Charman <ra...@codethink.co.uk>
AuthorDate: Thu Jan 3 12:21:58 2019 +0000

    artifactcache: implement new push methods
    
    Similar to the pull methods, this implements a yield_directory_digests method
    that iterates over the blobs in the local CAS, with upload_blob sending blobs
    to a remote and batching them where appropriate.
    
    Part of #802
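
The resulting push flow, as wired up in buildstream/_cas/transfer.py later in
this commit, looks roughly like the following editorial sketch; cascache,
casremote and root_digest are assumed to be set up already:

    # The local CAS yields every digest reachable from the directory, the remote
    # reports which of those it is missing, and only the missing blobs are
    # uploaded, batched where the server supports it.
    required_blobs = cascache.yield_directory_digests(root_digest)
    missing_blobs = casremote.find_missing_blobs(required_blobs)

    for digest in missing_blobs.values():
        casremote.upload_blob(digest, cascache.objpath(digest))

    # Flush whatever is still queued in the current update batch.
    casremote.send_update_batch()
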
---
 buildstream/_artifactcache.py         |  41 ++++++++--
 buildstream/_cas/cascache.py          | 149 +++++-----------------------------
 buildstream/_cas/casremote.py         | 102 ++++++++++++++++++++++-
 buildstream/_cas/transfer.py          |  11 +++
 buildstream/sandbox/_sandboxremote.py |   4 +-
 5 files changed, 166 insertions(+), 141 deletions(-)

diff --git a/buildstream/_artifactcache.py b/buildstream/_artifactcache.py
index 21db707..4628041 100644
--- a/buildstream/_artifactcache.py
+++ b/buildstream/_artifactcache.py
@@ -29,7 +29,7 @@ from . import utils
 from . import _yaml
 
 from ._cas import BlobNotFound, CASRemote, CASRemoteSpec
-from ._cas.transfer import cas_directory_download, cas_tree_download
+from ._cas.transfer import cas_directory_upload, cas_directory_download, cas_tree_download
 
 
 CACHE_SIZE_FILE = "cache_size"
@@ -608,16 +608,41 @@ class ArtifactCache():
 
         for remote in push_remotes:
             remote.init()
+            skipped_remote = True
             display_key = element._get_brief_display_key()
             element.status("Pushing artifact {} -> {}".format(display_key, remote.spec.url))
 
-            if self.cas.push(refs, remote):
-                element.info("Pushed artifact {} -> {}".format(display_key, remote.spec.url))
+            try:
+                for ref in refs:
+                    # Check whether ref is already on the server in which case
+                    # there is no need to push the ref
+                    root_digest = self.cas.resolve_ref(ref)
+                    response = remote.get_reference(ref)
+                    if (response is not None and
+                            response.hash == root_digest.hash and
+                            response.size_bytes == root_digest.size_bytes):
+                        element.info("Remote ({}) already has {} cached".format(
+                            remote.spec.url, element._get_brief_display_key()))
+                        continue
+
+                    # upload blobs
+                    cas_directory_upload(self.cas, remote, root_digest)
+                    remote.update_reference(ref, root_digest)
+
+                    skipped_remote = False
+
+            except CASError as e:
+                if str(e.reason) == "StatusCode.RESOURCE_EXHAUSTED":
+                    element.warn("Failed to push element to {}: Resource exhausted"
+                                 .format(remote.spec.url))
+                    continue
+                else:
+                    raise ArtifactError("Failed to push refs {}: {}".format(refs, e),
+                                        temporary=True) from e
+
+            if skipped_remote is False:
                 pushed = True
-            else:
-                element.info("Remote ({}) already has {} cached".format(
-                    remote.spec.url, element._get_brief_display_key()
-                ))
+                element.info("Pushed artifact {} -> {}".format(display_key, remote.spec.url))
 
         return pushed
 
@@ -722,7 +747,7 @@ class ArtifactCache():
             return
 
         for remote in push_remotes:
-            self.cas.push_directory(remote, directory)
+            cas_directory_upload(self.cas, remote, directory.ref)
 
     # push_message():
     #
diff --git a/buildstream/_cas/cascache.py b/buildstream/_cas/cascache.py
index e3b0332..7ea46a0 100644
--- a/buildstream/_cas/cascache.py
+++ b/buildstream/_cas/cascache.py
@@ -18,23 +18,16 @@
 #        Jürg Billeter <ju...@codethink.co.uk>
 
 import hashlib
-import itertools
 import os
 import stat
 import tempfile
-import uuid
 import contextlib
 
-import grpc
-
 from .._protos.build.bazel.remote.execution.v2 import remote_execution_pb2
-from .._protos.buildstream.v2 import buildstream_pb2
 
 from .. import utils
 from .._exceptions import CASCacheError
 
-from .casremote import _CASBatchUpdate
-
 
 # A CASCache manages a CAS repository as specified in the Remote Execution API.
 #
@@ -196,73 +189,6 @@ class CASCache():
 
         self.set_ref(newref, tree)
 
-    # push():
-    #
-    # Push committed refs to remote repository.
-    #
-    # Args:
-    #     refs (list): The refs to push
-    #     remote (CASRemote): The remote to push to
-    #
-    # Returns:
-    #   (bool): True if any remote was updated, False if no pushes were required
-    #
-    # Raises:
-    #   (CASCacheError): if there was an error
-    #
-    def push(self, refs, remote):
-        skipped_remote = True
-        try:
-            for ref in refs:
-                tree = self.resolve_ref(ref)
-
-                # Check whether ref is already on the server in which case
-                # there is no need to push the ref
-                try:
-                    request = buildstream_pb2.GetReferenceRequest(instance_name=remote.spec.instance_name)
-                    request.key = ref
-                    response = remote.ref_storage.GetReference(request)
-
-                    if response.digest.hash == tree.hash and response.digest.size_bytes == tree.size_bytes:
-                        # ref is already on the server with the same tree
-                        continue
-
-                except grpc.RpcError as e:
-                    if e.code() != grpc.StatusCode.NOT_FOUND:
-                        # Intentionally re-raise RpcError for outer except block.
-                        raise
-
-                self._send_directory(remote, tree)
-
-                request = buildstream_pb2.UpdateReferenceRequest(instance_name=remote.spec.instance_name)
-                request.keys.append(ref)
-                request.digest.hash = tree.hash
-                request.digest.size_bytes = tree.size_bytes
-                remote.ref_storage.UpdateReference(request)
-
-                skipped_remote = False
-        except grpc.RpcError as e:
-            if e.code() != grpc.StatusCode.RESOURCE_EXHAUSTED:
-                raise CASCacheError("Failed to push ref {}: {}".format(refs, e), temporary=True) from e
-
-        return not skipped_remote
-
-    # push_directory():
-    #
-    # Push the given virtual directory to a remote.
-    #
-    # Args:
-    #     remote (CASRemote): The remote to push to
-    #     directory (Directory): A virtual directory object to push.
-    #
-    # Raises:
-    #     (CASCacheError): if there was an error
-    #
-    def push_directory(self, remote, directory):
-        remote.init()
-
-        self._send_directory(remote, directory.ref)
-
     # objpath():
     #
     # Return the path of an object based on its digest.
@@ -534,6 +460,27 @@ class CASCache():
         else:
             return None
 
+    def yield_directory_digests(self, directory_digest):
+        # parse directory, and recursively add blobs
+        d = remote_execution_pb2.Digest()
+        d.hash = directory_digest.hash
+        d.size_bytes = directory_digest.size_bytes
+        yield d
+
+        directory = remote_execution_pb2.Directory()
+
+        with open(self.objpath(directory_digest), 'rb') as f:
+            directory.ParseFromString(f.read())
+
+        for filenode in directory.files:
+            d = remote_execution_pb2.Digest()
+            d.hash = filenode.digest.hash
+            d.size_bytes = filenode.digest.size_bytes
+            yield d
+
+        for dirnode in directory.directories:
+            yield from self.yield_directory_digests(dirnode.digest)
+
     ################################################
     #             Local Private Methods            #
     ################################################
@@ -722,57 +669,3 @@ class CASCache():
 
         for dirnode in directory.directories:
             yield from self._required_blobs(dirnode.digest)
-
-    def _send_directory(self, remote, digest, u_uid=uuid.uuid4()):
-        required_blobs = self._required_blobs(digest)
-
-        missing_blobs = dict()
-        # Limit size of FindMissingBlobs request
-        for required_blobs_group in _grouper(required_blobs, 512):
-            request = remote_execution_pb2.FindMissingBlobsRequest(instance_name=remote.spec.instance_name)
-
-            for required_digest in required_blobs_group:
-                d = request.blob_digests.add()
-                d.hash = required_digest.hash
-                d.size_bytes = required_digest.size_bytes
-
-            response = remote.cas.FindMissingBlobs(request)
-            for missing_digest in response.missing_blob_digests:
-                d = remote_execution_pb2.Digest()
-                d.hash = missing_digest.hash
-                d.size_bytes = missing_digest.size_bytes
-                missing_blobs[d.hash] = d
-
-        # Upload any blobs missing on the server
-        self._send_blobs(remote, missing_blobs.values(), u_uid)
-
-    def _send_blobs(self, remote, digests, u_uid=uuid.uuid4()):
-        batch = _CASBatchUpdate(remote)
-
-        for digest in digests:
-            with open(self.objpath(digest), 'rb') as f:
-                assert os.fstat(f.fileno()).st_size == digest.size_bytes
-
-                if (digest.size_bytes >= remote.max_batch_total_size_bytes or
-                        not remote.batch_update_supported):
-                    # Too large for batch request, upload in independent request.
-                    remote._send_blob(digest, f, u_uid=u_uid)
-                else:
-                    if not batch.add(digest, f):
-                        # Not enough space left in batch request.
-                        # Complete pending batch first.
-                        batch.send()
-                        batch = _CASBatchUpdate(remote)
-                        batch.add(digest, f)
-
-        # Send final batch
-        batch.send()
-
-
-def _grouper(iterable, n):
-    while True:
-        try:
-            current = next(iterable)
-        except StopIteration:
-            return
-        yield itertools.chain([current], itertools.islice(iterable, n - 1))
diff --git a/buildstream/_cas/casremote.py b/buildstream/_cas/casremote.py
index 0e75b09..8435230 100644
--- a/buildstream/_cas/casremote.py
+++ b/buildstream/_cas/casremote.py
@@ -1,5 +1,6 @@
 from collections import namedtuple
 import io
+import itertools
 import os
 import multiprocessing
 import signal
@@ -288,6 +289,18 @@ class CASRemote():
             else:
                 return None
 
+    # update_reference():
+    #
+    # Args:
+    #    ref (str): Reference to update
+    #    digest (Digest): New digest to update ref with
+    def update_reference(self, ref, digest):
+        request = buildstream_pb2.UpdateReferenceRequest()
+        request.keys.append(ref)
+        request.digest.hash = digest.hash
+        request.digest.size_bytes = digest.size_bytes
+        self.ref_storage.UpdateReference(request)
+
     def get_tree_blob(self, tree_digest):
         self.init()
         f = tempfile.NamedTemporaryFile(dir=self.tmpdir)
@@ -397,6 +410,68 @@ class CASRemote():
         while self.__tmp_downloads:
             yield self.__tmp_downloads.pop()
 
+    # upload_blob():
+    #
+    # Push blobs given an iterator over blob files
+    #
+    # Args:
+    #    digest (Digest): digest we want to upload
+    #    blob_file (str): Name of file location
+    #    u_uid (str): Used to identify to the bytestream service
+    #
+    def upload_blob(self, digest, blob_file, u_uid=uuid.uuid4()):
+        with open(blob_file, 'rb') as f:
+            assert os.fstat(f.fileno()).st_size == digest.size_bytes
+
+            if (digest.size_bytes >= self.max_batch_total_size_bytes or
+                    not self.batch_update_supported):
+                # Too large for batch request, upload in independent request.
+                self._send_blob(digest, f, u_uid=u_uid)
+            else:
+                if self.__batch_update.add(digest, f) is False:
+                    self.__batch_update.send()
+                    self.__batch_update = _CASBatchUpdate(self)
+                    self.__batch_update.add(digest, f)
+
+    # send_update_batch():
+    #
+    # Sends anything left in the update batch
+    #
+    def send_update_batch(self):
+        # make sure everything is sent
+        self.__batch_update.send()
+        self.__batch_update = _CASBatchUpdate(self)
+
+    # find_missing_blobs()
+    #
+    # Does FindMissingBlobs request to remote
+    #
+    # Args:
+    #    required_blobs ([Digest]): list of blobs required
+    #
+    # Returns:
+    #    (Dict(Digest)): missing blobs
+    def find_missing_blobs(self, required_blobs):
+        self.init()
+        missing_blobs = dict()
+        # Limit size of FindMissingBlobs request
+        for required_blobs_group in _grouper(required_blobs, 512):
+            request = remote_execution_pb2.FindMissingBlobsRequest()
+
+            for required_digest in required_blobs_group:
+                d = request.blob_digests.add()
+                d.hash = required_digest.hash
+                d.size_bytes = required_digest.size_bytes
+
+            response = self.cas.FindMissingBlobs(request)
+            for missing_digest in response.missing_blob_digests:
+                d = remote_execution_pb2.Digest()
+                d.hash = missing_digest.hash
+                d.size_bytes = missing_digest.size_bytes
+                missing_blobs[d.hash] = d
+
+        return missing_blobs
+
     ################################################
     #             Local Private Methods            #
     ################################################
@@ -435,7 +510,10 @@ class CASRemote():
                 offset += chunk_size
                 finished = request.finish_write
 
-        response = self.bytestream.Write(request_stream(resource_name, stream))
+        try:
+            response = self.bytestream.Write(request_stream(resource_name, stream))
+        except grpc.RpcError as e:
+            raise CASRemoteError("Failed to upload blob: {}".format(e), reason=e.code())
 
         assert response.committed_size == digest.size_bytes
 
@@ -449,6 +527,15 @@ class CASRemote():
         self.__batch_read = _CASBatchRead(self)
 
 
+def _grouper(iterable, n):
+    while True:
+        try:
+            current = next(iterable)
+        except StopIteration:
+            return
+        yield itertools.chain([current], itertools.islice(iterable, n - 1))
+
+
 # Represents a batch of blobs queued for fetching.
 #
 class _CASBatchRead():
@@ -480,7 +567,11 @@ class _CASBatchRead():
         if not self._request.digests:
             return
 
-        batch_response = self._remote.cas.BatchReadBlobs(self._request)
+        try:
+            batch_response = self._remote.cas.BatchReadBlobs(self._request)
+        except grpc.RpcError as e:
+            raise CASRemoteError("Failed to read blob batch: {}".format(e),
+                                 reason=e.code()) from e
 
         for response in batch_response.responses:
             if response.status.code == code_pb2.NOT_FOUND:
@@ -528,7 +619,12 @@ class _CASBatchUpdate():
         if not self._request.requests:
             return
 
-        batch_response = self._remote.cas.BatchUpdateBlobs(self._request)
+        # Want to raise a CASRemoteError if the batch upload fails
+        try:
+            batch_response = self._remote.cas.BatchUpdateBlobs(self._request)
+        except grpc.RpcError as e:
+            raise CASRemoteError("Failed to upload blob batch: {}".format(e),
+                                 reason=e.code()) from e
 
         for response in batch_response.responses:
             if response.status.code != code_pb2.OK:
diff --git a/buildstream/_cas/transfer.py b/buildstream/_cas/transfer.py
index 5eaaf09..c1293c0 100644
--- a/buildstream/_cas/transfer.py
+++ b/buildstream/_cas/transfer.py
@@ -49,3 +49,14 @@ def cas_tree_download(caslocal, casremote, tree_digest):
 
     # get root digest from tree and return that
     return _message_digest(tree.root.SerializeToString())
+
+
+def cas_directory_upload(caslocal, casremote, root_digest):
+    required_blobs = caslocal.yield_directory_digests(root_digest)
+    missing_blobs = casremote.find_missing_blobs(required_blobs)
+    for blob in missing_blobs.values():
+        blob_file = caslocal.objpath(blob)
+        casremote.upload_blob(blob, blob_file)
+
+    # send remaining blobs
+    casremote.send_update_batch()
diff --git a/buildstream/sandbox/_sandboxremote.py b/buildstream/sandbox/_sandboxremote.py
index bea1754..baa5aaa 100644
--- a/buildstream/sandbox/_sandboxremote.py
+++ b/buildstream/sandbox/_sandboxremote.py
@@ -39,7 +39,7 @@ from .._exceptions import SandboxError
 from .. import _yaml
 from .._protos.google.longrunning import operations_pb2, operations_pb2_grpc
 from .._cas import CASRemote, CASRemoteSpec
-from .._cas.transfer import cas_tree_download
+from .._cas.transfer import cas_tree_download, cas_directory_upload
 
 
 class RemoteExecutionSpec(namedtuple('RemoteExecutionSpec', 'exec_service storage_service action_service')):
@@ -345,7 +345,7 @@ class SandboxRemote(Sandbox):
 
             # Now, push that key (without necessarily needing a ref) to the remote.
             try:
-                cascache.push_directory(casremote, upload_vdir)
+                cas_directory_upload(cascache, casremote, upload_vdir.ref)
             except grpc.RpcError as e:
                 raise SandboxError("Failed to push source directory to remote: {}".format(e)) from e
 


[buildstream] 02/06: casremote.py: Move remote CAS classes into their own file

Posted by gi...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

github-bot pushed a commit to branch raoul/802-refactor-artifactcache
in repository https://gitbox.apache.org/repos/asf/buildstream.git

commit fe01405b4d7160a3a1b4ba20a410e5fe0808fa54
Author: Raoul Hidalgo Charman <ra...@codethink.co.uk>
AuthorDate: Fri Dec 7 17:51:20 2018 +0000

    casremote.py: Move remote CAS classes into their own file
    
    Part of #802
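
A short editorial note on the resulting layout: the package __init__ re-exports
both halves, so callers can keep importing from the package while the classes
now live in separate modules.

    # Unchanged for callers: the package re-exports everything it did before.
    from buildstream._cas import CASCache, CASRemote, CASRemoteSpec

    # New module layout underneath:
    #     buildstream/_cas/cascache.py   -> CASCache
    #     buildstream/_cas/casremote.py  -> CASRemote, CASRemoteSpec
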
---
 buildstream/_cas/__init__.py  |   3 +-
 buildstream/_cas/cascache.py  | 275 +++---------------------------------------
 buildstream/_cas/casremote.py | 247 +++++++++++++++++++++++++++++++++++++
 buildstream/_exceptions.py    |  15 +++
 4 files changed, 283 insertions(+), 257 deletions(-)

diff --git a/buildstream/_cas/__init__.py b/buildstream/_cas/__init__.py
index 7386109..a88e413 100644
--- a/buildstream/_cas/__init__.py
+++ b/buildstream/_cas/__init__.py
@@ -17,4 +17,5 @@
 #  Authors:
 #        Tristan Van Berkom <tr...@codethink.co.uk>
 
-from .cascache import CASCache, CASRemote, CASRemoteSpec
+from .cascache import CASCache
+from .casremote import CASRemote, CASRemoteSpec
diff --git a/buildstream/_cas/cascache.py b/buildstream/_cas/cascache.py
index 482d400..5f62e61 100644
--- a/buildstream/_cas/cascache.py
+++ b/buildstream/_cas/cascache.py
@@ -17,7 +17,6 @@
 #  Authors:
 #        Jürg Billeter <ju...@codethink.co.uk>
 
-from collections import namedtuple
 import hashlib
 import itertools
 import io
@@ -26,76 +25,17 @@ import stat
 import tempfile
 import uuid
 import contextlib
-from urllib.parse import urlparse
 
 import grpc
 
-from .._protos.google.rpc import code_pb2
-from .._protos.google.bytestream import bytestream_pb2, bytestream_pb2_grpc
-from .._protos.build.bazel.remote.execution.v2 import remote_execution_pb2, remote_execution_pb2_grpc
-from .._protos.buildstream.v2 import buildstream_pb2, buildstream_pb2_grpc
+from .._protos.google.bytestream import bytestream_pb2
+from .._protos.build.bazel.remote.execution.v2 import remote_execution_pb2
+from .._protos.buildstream.v2 import buildstream_pb2
 
 from .. import utils
-from .._exceptions import CASError, LoadError, LoadErrorReason
-from .. import _yaml
+from .._exceptions import CASCacheError
 
-
-# The default limit for gRPC messages is 4 MiB.
-# Limit payload to 1 MiB to leave sufficient headroom for metadata.
-_MAX_PAYLOAD_BYTES = 1024 * 1024
-
-
-class CASRemoteSpec(namedtuple('CASRemoteSpec', 'url push server_cert client_key client_cert instance_name')):
-
-    # _new_from_config_node
-    #
-    # Creates an CASRemoteSpec() from a YAML loaded node
-    #
-    @staticmethod
-    def _new_from_config_node(spec_node, basedir=None):
-        _yaml.node_validate(spec_node, ['url', 'push', 'server-cert', 'client-key', 'client-cert', 'instance-name'])
-        url = _yaml.node_get(spec_node, str, 'url')
-        push = _yaml.node_get(spec_node, bool, 'push', default_value=False)
-        if not url:
-            provenance = _yaml.node_get_provenance(spec_node, 'url')
-            raise LoadError(LoadErrorReason.INVALID_DATA,
-                            "{}: empty artifact cache URL".format(provenance))
-
-        instance_name = _yaml.node_get(spec_node, str, 'instance-name', default_value=None)
-
-        server_cert = _yaml.node_get(spec_node, str, 'server-cert', default_value=None)
-        if server_cert and basedir:
-            server_cert = os.path.join(basedir, server_cert)
-
-        client_key = _yaml.node_get(spec_node, str, 'client-key', default_value=None)
-        if client_key and basedir:
-            client_key = os.path.join(basedir, client_key)
-
-        client_cert = _yaml.node_get(spec_node, str, 'client-cert', default_value=None)
-        if client_cert and basedir:
-            client_cert = os.path.join(basedir, client_cert)
-
-        if client_key and not client_cert:
-            provenance = _yaml.node_get_provenance(spec_node, 'client-key')
-            raise LoadError(LoadErrorReason.INVALID_DATA,
-                            "{}: 'client-key' was specified without 'client-cert'".format(provenance))
-
-        if client_cert and not client_key:
-            provenance = _yaml.node_get_provenance(spec_node, 'client-cert')
-            raise LoadError(LoadErrorReason.INVALID_DATA,
-                            "{}: 'client-cert' was specified without 'client-key'".format(provenance))
-
-        return CASRemoteSpec(url, push, server_cert, client_key, client_cert, instance_name)
-
-
-CASRemoteSpec.__new__.__defaults__ = (None, None, None, None)
-
-
-class BlobNotFound(CASError):
-
-    def __init__(self, blob, msg):
-        self.blob = blob
-        super().__init__(msg)
+from .casremote import CASRemote, BlobNotFound, _CASBatchRead, _CASBatchUpdate, _MAX_PAYLOAD_BYTES
 
 
 # A CASCache manages a CAS repository as specified in the Remote Execution API.
@@ -120,7 +60,7 @@ class CASCache():
         headdir = os.path.join(self.casdir, 'refs', 'heads')
         objdir = os.path.join(self.casdir, 'objects')
         if not (os.path.isdir(headdir) and os.path.isdir(objdir)):
-            raise CASError("CAS repository check failed for '{}'".format(self.casdir))
+            raise CASCacheError("CAS repository check failed for '{}'".format(self.casdir))
 
     # contains():
     #
@@ -169,7 +109,7 @@ class CASCache():
     #     subdir (str): Optional specific dir to extract
     #
     # Raises:
-    #     CASError: In cases there was an OSError, or if the ref did not exist.
+    #     CASCacheError: In cases there was an OSError, or if the ref did not exist.
     #
     # Returns: path to extracted directory
     #
@@ -201,7 +141,7 @@ class CASCache():
                 # Another process beat us to rename
                 pass
             except OSError as e:
-                raise CASError("Failed to extract directory for ref '{}': {}".format(ref, e)) from e
+                raise CASCacheError("Failed to extract directory for ref '{}': {}".format(ref, e)) from e
 
         return originaldest
 
@@ -306,7 +246,7 @@ class CASCache():
             return True
         except grpc.RpcError as e:
             if e.code() != grpc.StatusCode.NOT_FOUND:
-                raise CASError("Failed to pull ref {}: {}".format(ref, e)) from e
+                raise CASCacheError("Failed to pull ref {}: {}".format(ref, e)) from e
             else:
                 return False
         except BlobNotFound as e:
@@ -360,7 +300,7 @@ class CASCache():
     #   (bool): True if any remote was updated, False if no pushes were required
     #
     # Raises:
-    #   (CASError): if there was an error
+    #   (CASCacheError): if there was an error
     #
     def push(self, refs, remote):
         skipped_remote = True
@@ -395,7 +335,7 @@ class CASCache():
                 skipped_remote = False
         except grpc.RpcError as e:
             if e.code() != grpc.StatusCode.RESOURCE_EXHAUSTED:
-                raise CASError("Failed to push ref {}: {}".format(refs, e), temporary=True) from e
+                raise CASCacheError("Failed to push ref {}: {}".format(refs, e), temporary=True) from e
 
         return not skipped_remote
 
@@ -408,7 +348,7 @@ class CASCache():
     #     directory (Directory): A virtual directory object to push.
     #
     # Raises:
-    #     (CASError): if there was an error
+    #     (CASCacheError): if there was an error
     #
     def push_directory(self, remote, directory):
         remote.init()
@@ -424,7 +364,7 @@ class CASCache():
     #     message (Message): A protobuf message to push.
     #
     # Raises:
-    #     (CASError): if there was an error
+    #     (CASCacheError): if there was an error
     #
     def push_message(self, remote, message):
 
@@ -531,7 +471,7 @@ class CASCache():
             pass
 
         except OSError as e:
-            raise CASError("Failed to hash object: {}".format(e)) from e
+            raise CASCacheError("Failed to hash object: {}".format(e)) from e
 
         return digest
 
@@ -572,7 +512,7 @@ class CASCache():
                 return digest
 
         except FileNotFoundError as e:
-            raise CASError("Attempt to access unavailable ref: {}".format(e)) from e
+            raise CASCacheError("Attempt to access unavailable ref: {}".format(e)) from e
 
     # update_mtime()
     #
@@ -585,7 +525,7 @@ class CASCache():
         try:
             os.utime(self._refpath(ref))
         except FileNotFoundError as e:
-            raise CASError("Attempt to access unavailable ref: {}".format(e)) from e
+            raise CASCacheError("Attempt to access unavailable ref: {}".format(e)) from e
 
     # calculate_cache_size()
     #
@@ -676,7 +616,7 @@ class CASCache():
         # Remove cache ref
         refpath = self._refpath(ref)
         if not os.path.exists(refpath):
-            raise CASError("Could not find ref '{}'".format(ref))
+            raise CASCacheError("Could not find ref '{}'".format(ref))
 
         os.unlink(refpath)
 
@@ -792,7 +732,7 @@ class CASCache():
                 # The process serving the socket can't be cached anyway
                 pass
             else:
-                raise CASError("Unsupported file type for {}".format(full_path))
+                raise CASCacheError("Unsupported file type for {}".format(full_path))
 
         return self.add_object(digest=dir_digest,
                                buffer=directory.SerializeToString())
@@ -811,7 +751,7 @@ class CASCache():
             if dirnode.name == name:
                 return dirnode.digest
 
-        raise CASError("Subdirectory {} not found".format(name))
+        raise CASCacheError("Subdirectory {} not found".format(name))
 
     def _diff_trees(self, tree_a, tree_b, *, added, removed, modified, path=""):
         dir_a = remote_execution_pb2.Directory()
@@ -1150,183 +1090,6 @@ class CASCache():
         batch.send()
 
 
-# Represents a single remote CAS cache.
-#
-class CASRemote():
-    def __init__(self, spec):
-        self.spec = spec
-        self._initialized = False
-        self.channel = None
-        self.bytestream = None
-        self.cas = None
-        self.ref_storage = None
-        self.batch_update_supported = None
-        self.batch_read_supported = None
-        self.capabilities = None
-        self.max_batch_total_size_bytes = None
-
-    def init(self):
-        if not self._initialized:
-            url = urlparse(self.spec.url)
-            if url.scheme == 'http':
-                port = url.port or 80
-                self.channel = grpc.insecure_channel('{}:{}'.format(url.hostname, port))
-            elif url.scheme == 'https':
-                port = url.port or 443
-
-                if self.spec.server_cert:
-                    with open(self.spec.server_cert, 'rb') as f:
-                        server_cert_bytes = f.read()
-                else:
-                    server_cert_bytes = None
-
-                if self.spec.client_key:
-                    with open(self.spec.client_key, 'rb') as f:
-                        client_key_bytes = f.read()
-                else:
-                    client_key_bytes = None
-
-                if self.spec.client_cert:
-                    with open(self.spec.client_cert, 'rb') as f:
-                        client_cert_bytes = f.read()
-                else:
-                    client_cert_bytes = None
-
-                credentials = grpc.ssl_channel_credentials(root_certificates=server_cert_bytes,
-                                                           private_key=client_key_bytes,
-                                                           certificate_chain=client_cert_bytes)
-                self.channel = grpc.secure_channel('{}:{}'.format(url.hostname, port), credentials)
-            else:
-                raise CASError("Unsupported URL: {}".format(self.spec.url))
-
-            self.bytestream = bytestream_pb2_grpc.ByteStreamStub(self.channel)
-            self.cas = remote_execution_pb2_grpc.ContentAddressableStorageStub(self.channel)
-            self.capabilities = remote_execution_pb2_grpc.CapabilitiesStub(self.channel)
-            self.ref_storage = buildstream_pb2_grpc.ReferenceStorageStub(self.channel)
-
-            self.max_batch_total_size_bytes = _MAX_PAYLOAD_BYTES
-            try:
-                request = remote_execution_pb2.GetCapabilitiesRequest(instance_name=self.spec.instance_name)
-                response = self.capabilities.GetCapabilities(request)
-                server_max_batch_total_size_bytes = response.cache_capabilities.max_batch_total_size_bytes
-                if 0 < server_max_batch_total_size_bytes < self.max_batch_total_size_bytes:
-                    self.max_batch_total_size_bytes = server_max_batch_total_size_bytes
-            except grpc.RpcError as e:
-                # Simply use the defaults for servers that don't implement GetCapabilities()
-                if e.code() != grpc.StatusCode.UNIMPLEMENTED:
-                    raise
-
-            # Check whether the server supports BatchReadBlobs()
-            self.batch_read_supported = False
-            try:
-                request = remote_execution_pb2.BatchReadBlobsRequest(instance_name=self.spec.instance_name)
-                response = self.cas.BatchReadBlobs(request)
-                self.batch_read_supported = True
-            except grpc.RpcError as e:
-                if e.code() != grpc.StatusCode.UNIMPLEMENTED:
-                    raise
-
-            # Check whether the server supports BatchUpdateBlobs()
-            self.batch_update_supported = False
-            try:
-                request = remote_execution_pb2.BatchUpdateBlobsRequest(instance_name=self.spec.instance_name)
-                response = self.cas.BatchUpdateBlobs(request)
-                self.batch_update_supported = True
-            except grpc.RpcError as e:
-                if (e.code() != grpc.StatusCode.UNIMPLEMENTED and
-                        e.code() != grpc.StatusCode.PERMISSION_DENIED):
-                    raise
-
-            self._initialized = True
-
-
-# Represents a batch of blobs queued for fetching.
-#
-class _CASBatchRead():
-    def __init__(self, remote):
-        self._remote = remote
-        self._max_total_size_bytes = remote.max_batch_total_size_bytes
-        self._request = remote_execution_pb2.BatchReadBlobsRequest(instance_name=remote.spec.instance_name)
-        self._size = 0
-        self._sent = False
-
-    def add(self, digest):
-        assert not self._sent
-
-        new_batch_size = self._size + digest.size_bytes
-        if new_batch_size > self._max_total_size_bytes:
-            # Not enough space left in current batch
-            return False
-
-        request_digest = self._request.digests.add()
-        request_digest.hash = digest.hash
-        request_digest.size_bytes = digest.size_bytes
-        self._size = new_batch_size
-        return True
-
-    def send(self):
-        assert not self._sent
-        self._sent = True
-
-        if not self._request.digests:
-            return
-
-        batch_response = self._remote.cas.BatchReadBlobs(self._request)
-
-        for response in batch_response.responses:
-            if response.status.code == code_pb2.NOT_FOUND:
-                raise BlobNotFound(response.digest.hash, "Failed to download blob {}: {}".format(
-                    response.digest.hash, response.status.code))
-            if response.status.code != code_pb2.OK:
-                raise CASError("Failed to download blob {}: {}".format(
-                    response.digest.hash, response.status.code))
-            if response.digest.size_bytes != len(response.data):
-                raise CASError("Failed to download blob {}: expected {} bytes, received {} bytes".format(
-                    response.digest.hash, response.digest.size_bytes, len(response.data)))
-
-            yield (response.digest, response.data)
-
-
-# Represents a batch of blobs queued for upload.
-#
-class _CASBatchUpdate():
-    def __init__(self, remote):
-        self._remote = remote
-        self._max_total_size_bytes = remote.max_batch_total_size_bytes
-        self._request = remote_execution_pb2.BatchUpdateBlobsRequest(instance_name=remote.spec.instance_name)
-        self._size = 0
-        self._sent = False
-
-    def add(self, digest, stream):
-        assert not self._sent
-
-        new_batch_size = self._size + digest.size_bytes
-        if new_batch_size > self._max_total_size_bytes:
-            # Not enough space left in current batch
-            return False
-
-        blob_request = self._request.requests.add()
-        blob_request.digest.hash = digest.hash
-        blob_request.digest.size_bytes = digest.size_bytes
-        blob_request.data = stream.read(digest.size_bytes)
-        self._size = new_batch_size
-        return True
-
-    def send(self):
-        assert not self._sent
-        self._sent = True
-
-        if not self._request.requests:
-            return
-
-        batch_response = self._remote.cas.BatchUpdateBlobs(self._request)
-
-        for response in batch_response.responses:
-            if response.status.code != code_pb2.OK:
-                raise CASError("Failed to upload blob {}: {}".format(
-                    response.digest.hash, response.status.code))
-
-
 def _grouper(iterable, n):
     while True:
         try:
diff --git a/buildstream/_cas/casremote.py b/buildstream/_cas/casremote.py
new file mode 100644
index 0000000..59eb7e3
--- /dev/null
+++ b/buildstream/_cas/casremote.py
@@ -0,0 +1,247 @@
+from collections import namedtuple
+import os
+from urllib.parse import urlparse
+
+import grpc
+
+from .. import _yaml
+from .._protos.google.rpc import code_pb2
+from .._protos.google.bytestream import bytestream_pb2_grpc
+from .._protos.build.bazel.remote.execution.v2 import remote_execution_pb2, remote_execution_pb2_grpc
+from .._protos.buildstream.v2 import buildstream_pb2_grpc
+
+from .._exceptions import CASRemoteError, LoadError, LoadErrorReason
+
+# The default limit for gRPC messages is 4 MiB.
+# Limit payload to 1 MiB to leave sufficient headroom for metadata.
+_MAX_PAYLOAD_BYTES = 1024 * 1024
+
+
+class CASRemoteSpec(namedtuple('CASRemoteSpec', 'url push server_cert client_key client_cert instance_name')):
+
+    # _new_from_config_node
+    #
+    # Creates a CASRemoteSpec() from a YAML loaded node
+    #
+    @staticmethod
+    def _new_from_config_node(spec_node, basedir=None):
+        _yaml.node_validate(spec_node, ['url', 'push', 'server-cert', 'client-key', 'client-cert', 'instance_name'])
+        url = _yaml.node_get(spec_node, str, 'url')
+        push = _yaml.node_get(spec_node, bool, 'push', default_value=False)
+        if not url:
+            provenance = _yaml.node_get_provenance(spec_node, 'url')
+            raise LoadError(LoadErrorReason.INVALID_DATA,
+                            "{}: empty artifact cache URL".format(provenance))
+
+        instance_name = _yaml.node_get(spec_node, str, 'instance_name', default_value=None)
+
+        server_cert = _yaml.node_get(spec_node, str, 'server-cert', default_value=None)
+        if server_cert and basedir:
+            server_cert = os.path.join(basedir, server_cert)
+
+        client_key = _yaml.node_get(spec_node, str, 'client-key', default_value=None)
+        if client_key and basedir:
+            client_key = os.path.join(basedir, client_key)
+
+        client_cert = _yaml.node_get(spec_node, str, 'client-cert', default_value=None)
+        if client_cert and basedir:
+            client_cert = os.path.join(basedir, client_cert)
+
+        if client_key and not client_cert:
+            provenance = _yaml.node_get_provenance(spec_node, 'client-key')
+            raise LoadError(LoadErrorReason.INVALID_DATA,
+                            "{}: 'client-key' was specified without 'client-cert'".format(provenance))
+
+        if client_cert and not client_key:
+            provenance = _yaml.node_get_provenance(spec_node, 'client-cert')
+            raise LoadError(LoadErrorReason.INVALID_DATA,
+                            "{}: 'client-cert' was specified without 'client-key'".format(provenance))
+
+        return CASRemoteSpec(url, push, server_cert, client_key, client_cert, instance_name)
+
+
+CASRemoteSpec.__new__.__defaults__ = (None, None, None, None)
+
+
+class BlobNotFound(CASRemoteError):
+
+    def __init__(self, blob, msg):
+        self.blob = blob
+        super().__init__(msg)
+
+
+# Represents a single remote CAS cache.
+#
+class CASRemote():
+    def __init__(self, spec):
+        self.spec = spec
+        self._initialized = False
+        self.channel = None
+        self.bytestream = None
+        self.cas = None
+        self.ref_storage = None
+        self.batch_update_supported = None
+        self.batch_read_supported = None
+        self.capabilities = None
+        self.max_batch_total_size_bytes = None
+
+    def init(self):
+        if not self._initialized:
+            url = urlparse(self.spec.url)
+            if url.scheme == 'http':
+                port = url.port or 80
+                self.channel = grpc.insecure_channel('{}:{}'.format(url.hostname, port))
+            elif url.scheme == 'https':
+                port = url.port or 443
+
+                if self.spec.server_cert:
+                    with open(self.spec.server_cert, 'rb') as f:
+                        server_cert_bytes = f.read()
+                else:
+                    server_cert_bytes = None
+
+                if self.spec.client_key:
+                    with open(self.spec.client_key, 'rb') as f:
+                        client_key_bytes = f.read()
+                else:
+                    client_key_bytes = None
+
+                if self.spec.client_cert:
+                    with open(self.spec.client_cert, 'rb') as f:
+                        client_cert_bytes = f.read()
+                else:
+                    client_cert_bytes = None
+
+                credentials = grpc.ssl_channel_credentials(root_certificates=server_cert_bytes,
+                                                           private_key=client_key_bytes,
+                                                           certificate_chain=client_cert_bytes)
+                self.channel = grpc.secure_channel('{}:{}'.format(url.hostname, port), credentials)
+            else:
+                raise CASRemoteError("Unsupported URL: {}".format(self.spec.url))
+
+            self.bytestream = bytestream_pb2_grpc.ByteStreamStub(self.channel)
+            self.cas = remote_execution_pb2_grpc.ContentAddressableStorageStub(self.channel)
+            self.capabilities = remote_execution_pb2_grpc.CapabilitiesStub(self.channel)
+            self.ref_storage = buildstream_pb2_grpc.ReferenceStorageStub(self.channel)
+
+            self.max_batch_total_size_bytes = _MAX_PAYLOAD_BYTES
+            try:
+                request = remote_execution_pb2.GetCapabilitiesRequest()
+                response = self.capabilities.GetCapabilities(request)
+                server_max_batch_total_size_bytes = response.cache_capabilities.max_batch_total_size_bytes
+                if 0 < server_max_batch_total_size_bytes < self.max_batch_total_size_bytes:
+                    self.max_batch_total_size_bytes = server_max_batch_total_size_bytes
+            except grpc.RpcError as e:
+                # Simply use the defaults for servers that don't implement GetCapabilities()
+                if e.code() != grpc.StatusCode.UNIMPLEMENTED:
+                    raise
+
+            # Check whether the server supports BatchReadBlobs()
+            self.batch_read_supported = False
+            try:
+                request = remote_execution_pb2.BatchReadBlobsRequest()
+                response = self.cas.BatchReadBlobs(request)
+                self.batch_read_supported = True
+            except grpc.RpcError as e:
+                if e.code() != grpc.StatusCode.UNIMPLEMENTED:
+                    raise
+
+            # Check whether the server supports BatchUpdateBlobs()
+            self.batch_update_supported = False
+            try:
+                request = remote_execution_pb2.BatchUpdateBlobsRequest()
+                response = self.cas.BatchUpdateBlobs(request)
+                self.batch_update_supported = True
+            except grpc.RpcError as e:
+                if (e.code() != grpc.StatusCode.UNIMPLEMENTED and
+                        e.code() != grpc.StatusCode.PERMISSION_DENIED):
+                    raise
+
+            self._initialized = True
+
+
+# Represents a batch of blobs queued for fetching.
+#
+class _CASBatchRead():
+    def __init__(self, remote):
+        self._remote = remote
+        self._max_total_size_bytes = remote.max_batch_total_size_bytes
+        self._request = remote_execution_pb2.BatchReadBlobsRequest()
+        self._size = 0
+        self._sent = False
+
+    def add(self, digest):
+        assert not self._sent
+
+        new_batch_size = self._size + digest.size_bytes
+        if new_batch_size > self._max_total_size_bytes:
+            # Not enough space left in current batch
+            return False
+
+        request_digest = self._request.digests.add()
+        request_digest.hash = digest.hash
+        request_digest.size_bytes = digest.size_bytes
+        self._size = new_batch_size
+        return True
+
+    def send(self):
+        assert not self._sent
+        self._sent = True
+
+        if not self._request.digests:
+            return
+
+        batch_response = self._remote.cas.BatchReadBlobs(self._request)
+
+        for response in batch_response.responses:
+            if response.status.code == code_pb2.NOT_FOUND:
+                raise BlobNotFound(response.digest.hash, "Failed to download blob {}: {}".format(
+                    response.digest.hash, response.status.code))
+            if response.status.code != code_pb2.OK:
+                raise CASRemoteError("Failed to download blob {}: {}".format(
+                    response.digest.hash, response.status.code))
+            if response.digest.size_bytes != len(response.data):
+                raise CASRemoteError("Failed to download blob {}: expected {} bytes, received {} bytes".format(
+                    response.digest.hash, response.digest.size_bytes, len(response.data)))
+
+            yield (response.digest, response.data)
+
+
+# Represents a batch of blobs queued for upload.
+#
+class _CASBatchUpdate():
+    def __init__(self, remote):
+        self._remote = remote
+        self._max_total_size_bytes = remote.max_batch_total_size_bytes
+        self._request = remote_execution_pb2.BatchUpdateBlobsRequest()
+        self._size = 0
+        self._sent = False
+
+    def add(self, digest, stream):
+        assert not self._sent
+
+        new_batch_size = self._size + digest.size_bytes
+        if new_batch_size > self._max_total_size_bytes:
+            # Not enough space left in current batch
+            return False
+
+        blob_request = self._request.requests.add()
+        blob_request.digest.hash = digest.hash
+        blob_request.digest.size_bytes = digest.size_bytes
+        blob_request.data = stream.read(digest.size_bytes)
+        self._size = new_batch_size
+        return True
+
+    def send(self):
+        assert not self._sent
+        self._sent = True
+
+        if not self._request.requests:
+            return
+
+        batch_response = self._remote.cas.BatchUpdateBlobs(self._request)
+
+        for response in batch_response.responses:
+            if response.status.code != code_pb2.OK:
+                raise CASRemoteError("Failed to upload blob {}: {}".format(
+                    response.digest.hash, response.status.code))
diff --git a/buildstream/_exceptions.py b/buildstream/_exceptions.py
index ba0b9fa..ea5ea62 100644
--- a/buildstream/_exceptions.py
+++ b/buildstream/_exceptions.py
@@ -284,6 +284,21 @@ class CASError(BstError):
         super().__init__(message, detail=detail, domain=ErrorDomain.CAS, reason=reason, temporary=True)
 
 
+# CASRemoteError
+#
+# Raised when errors are encountered in the remote CAS
+class CASRemoteError(CASError):
+    pass
+
+
+# CASCacheError
+#
+# Raised when errors are encountered in the local CASCache
+#
+class CASCacheError(CASError):
+    pass
+
+
 # PipelineError
 #
 # Raised from pipeline operations
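
A minimal sketch (not part of the patch) of how the batch read API introduced
in casremote.py above can be driven, using only names added by this commit and
assuming remote is an already constructed CASRemote:

    from buildstream._cas.casremote import _CASBatchRead

    def fetch_blobs(remote, digests):
        # Batch small blobs up to the negotiated payload limit, sending a
        # request whenever the current batch fills up, and yield the
        # resulting (digest, data) pairs for the caller to store.
        remote.init()
        batch = _CASBatchRead(remote)
        for digest in digests:
            if (not remote.batch_read_supported or
                    digest.size_bytes > remote.max_batch_total_size_bytes):
                # Too large for a batch request; a real caller would fall
                # back to a ByteStream download here.
                continue
            if not batch.add(digest):
                # Current batch is full: send it and start a new one
                yield from batch.send()
                batch = _CASBatchRead(remote)
                batch.add(digest)
        yield from batch.send()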


[buildstream] 05/06: artifactcache: Move pull logic into CASRemote

Posted by gi...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

github-bot pushed a commit to branch raoul/802-refactor-artifactcache
in repository https://gitbox.apache.org/repos/asf/buildstream.git

commit 15db919f6fafd0c2b0abbddde78a084d5b5379a5
Author: Raoul Hidalgo Charman <ra...@codethink.co.uk>
AuthorDate: Thu Dec 13 12:04:26 2018 +0000

    artifactcache: Move pull logic into CASRemote
    
    Separates the pull logic into a remote/local API, so that the artifact cache
    iterates over blob digests, checks whether it has them locally, and then
    requests them if not. The request command allows batching of blobs where
    appropriate.
    
    Tests have been updated to ensure the correct tmpdir is set up in process
    wrappers, otherwise invalid cross-link errors occur in CI. Additional asserts
    have been added to check that the temporary directories are cleared by the
    end of a pull.
    
    Part of #802
---
 buildstream/_artifactcache.py         |  54 +++++---
 buildstream/_cas/__init__.py          |   2 +-
 buildstream/_cas/cascache.py          | 224 ++--------------------------------
 buildstream/_cas/casremote.py         | 147 ++++++++++++++++++++++
 buildstream/_cas/transfer.py          |  51 ++++++++
 buildstream/sandbox/_sandboxremote.py |   4 +-
 conftest.py                           |   7 ++
 tests/artifactcache/pull.py           |  26 +++-
 tests/artifactcache/push.py           |  28 +++--
 tests/integration/pullbuildtrees.py   |   6 +
 10 files changed, 299 insertions(+), 250 deletions(-)
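
As a sketch of the remote/local split described in the message above (names are
taken from this commit; the helper that actually implements this loop is
cas_directory_download in buildstream/_cas/transfer.py below):

    def pull_directory(cas, remote, root_digest):
        # Walk the digests the remote exposes for the directory tree, skip
        # anything already present in the local CAS, request the rest
        # (batched where the server supports it) and add the downloaded
        # blobs to the local store.
        for digest in remote.yield_directory_digests(root_digest):
            if cas.check_blob(digest):
                continue
            remote.request_blob(digest)
            for blob_file in remote.get_blobs():
                cas.add_object(path=blob_file.name, link_directly=True)

        # Flush any outstanding batch read request
        for blob_file in remote.get_blobs(complete_batch=True):
            cas.add_object(path=blob_file.name, link_directly=True)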

diff --git a/buildstream/_artifactcache.py b/buildstream/_artifactcache.py
index dd5b4b5..21db707 100644
--- a/buildstream/_artifactcache.py
+++ b/buildstream/_artifactcache.py
@@ -28,7 +28,8 @@ from ._message import Message, MessageType
 from . import utils
 from . import _yaml
 
-from ._cas import CASRemote, CASRemoteSpec
+from ._cas import BlobNotFound, CASRemote, CASRemoteSpec
+from ._cas.transfer import cas_directory_download, cas_tree_download
 
 
 CACHE_SIZE_FILE = "cache_size"
@@ -644,19 +645,31 @@ class ArtifactCache():
                 display_key = element._get_brief_display_key()
                 element.status("Pulling artifact {} <- {}".format(display_key, remote.spec.url))
 
-                if self.cas.pull(ref, remote, progress=progress, subdir=subdir, excluded_subdirs=excluded_subdirs):
-                    element.info("Pulled artifact {} <- {}".format(display_key, remote.spec.url))
-                    if subdir:
-                        # Attempt to extract subdir into artifact extract dir if it already exists
-                        # without containing the subdir. If the respective artifact extract dir does not
-                        # exist a complete extraction will complete.
-                        self.extract(element, key, subdir)
-                    # no need to pull from additional remotes
-                    return True
-                else:
+                root_digest = remote.get_reference(ref)
+
+                if not root_digest:
                     element.info("Remote ({}) does not have {} cached".format(
-                        remote.spec.url, element._get_brief_display_key()
-                    ))
+                        remote.spec.url, element._get_brief_display_key()))
+                    continue
+
+                try:
+                    cas_directory_download(self.cas, remote, root_digest, excluded_subdirs)
+                except BlobNotFound:
+                    element.info("Remote ({}) is missing blobs for {}".format(
+                        remote.spec.url, element._get_brief_display_key()))
+                    continue
+
+                self.cas.set_ref(ref, root_digest)
+
+                if subdir:
+                    # Attempt to extract subdir into artifact extract dir if it already exists
+                    # without containing the subdir. If the respective artifact extract dir does not
+                    # exist a complete extraction will complete.
+                    self.extract(element, key, subdir)
+
+                element.info("Pulled artifact {} <- {}".format(display_key, remote.spec.url))
+                # no need to pull from additional remotes
+                return True
 
             except CASError as e:
                 raise ArtifactError("Failed to pull artifact {}: {}".format(
@@ -671,15 +684,16 @@ class ArtifactCache():
     #
     # Args:
     #     project (Project): The current project
-    #     digest (Digest): The digest of the tree
+    #     tree_digest (Digest): The digest of the tree
     #
-    def pull_tree(self, project, digest):
+    def pull_tree(self, project, tree_digest):
         for remote in self._remotes[project]:
-            digest = self.cas.pull_tree(remote, digest)
-
-            if digest:
-                # no need to pull from additional remotes
-                return digest
+            try:
+                root_digest = cas_tree_download(self.cas, remote, tree_digest)
+            except BlobNotFound:
+                continue
+            else:
+                return root_digest
 
         return None
 
diff --git a/buildstream/_cas/__init__.py b/buildstream/_cas/__init__.py
index a88e413..20c0279 100644
--- a/buildstream/_cas/__init__.py
+++ b/buildstream/_cas/__init__.py
@@ -18,4 +18,4 @@
 #        Tristan Van Berkom <tr...@codethink.co.uk>
 
 from .cascache import CASCache
-from .casremote import CASRemote, CASRemoteSpec
+from .casremote import CASRemote, CASRemoteSpec, BlobNotFound
diff --git a/buildstream/_cas/cascache.py b/buildstream/_cas/cascache.py
index adbd34c..e3b0332 100644
--- a/buildstream/_cas/cascache.py
+++ b/buildstream/_cas/cascache.py
@@ -33,7 +33,7 @@ from .._protos.buildstream.v2 import buildstream_pb2
 from .. import utils
 from .._exceptions import CASCacheError
 
-from .casremote import BlobNotFound, _CASBatchRead, _CASBatchUpdate
+from .casremote import _CASBatchUpdate
 
 
 # A CASCache manages a CAS repository as specified in the Remote Execution API.
@@ -183,73 +183,6 @@ class CASCache():
 
         return modified, removed, added
 
-    # pull():
-    #
-    # Pull a ref from a remote repository.
-    #
-    # Args:
-    #     ref (str): The ref to pull
-    #     remote (CASRemote): The remote repository to pull from
-    #     progress (callable): The progress callback, if any
-    #     subdir (str): The optional specific subdir to pull
-    #     excluded_subdirs (list): The optional list of subdirs to not pull
-    #
-    # Returns:
-    #   (bool): True if pull was successful, False if ref was not available
-    #
-    def pull(self, ref, remote, *, progress=None, subdir=None, excluded_subdirs=None):
-        try:
-            remote.init()
-
-            request = buildstream_pb2.GetReferenceRequest(instance_name=remote.spec.instance_name)
-            request.key = ref
-            response = remote.ref_storage.GetReference(request)
-
-            tree = remote_execution_pb2.Digest()
-            tree.hash = response.digest.hash
-            tree.size_bytes = response.digest.size_bytes
-
-            # Check if the element artifact is present, if so just fetch the subdir.
-            if subdir and os.path.exists(self.objpath(tree)):
-                self._fetch_subdir(remote, tree, subdir)
-            else:
-                # Fetch artifact, excluded_subdirs determined in pullqueue
-                self._fetch_directory(remote, tree, excluded_subdirs=excluded_subdirs)
-
-            self.set_ref(ref, tree)
-
-            return True
-        except grpc.RpcError as e:
-            if e.code() != grpc.StatusCode.NOT_FOUND:
-                raise CASCacheError("Failed to pull ref {}: {}".format(ref, e)) from e
-            else:
-                return False
-        except BlobNotFound as e:
-            return False
-
-    # pull_tree():
-    #
-    # Pull a single Tree rather than a ref.
-    # Does not update local refs.
-    #
-    # Args:
-    #     remote (CASRemote): The remote to pull from
-    #     digest (Digest): The digest of the tree
-    #
-    def pull_tree(self, remote, digest):
-        try:
-            remote.init()
-
-            digest = self._fetch_tree(remote, digest)
-
-            return digest
-
-        except grpc.RpcError as e:
-            if e.code() != grpc.StatusCode.NOT_FOUND:
-                raise
-
-        return None
-
     # link_ref():
     #
     # Add an alias for an existing ref.
@@ -591,6 +524,16 @@ class CASCache():
         reachable = set()
         self._reachable_refs_dir(reachable, tree, update_mtime=True)
 
+    # Check to see if a blob is in the local CAS
+    # return None if not
+    def check_blob(self, digest):
+        objpath = self.objpath(digest)
+        if os.path.exists(objpath):
+            # already in local repository
+            return objpath
+        else:
+            return None
+
     ################################################
     #             Local Private Methods            #
     ################################################
@@ -780,151 +723,6 @@ class CASCache():
         for dirnode in directory.directories:
             yield from self._required_blobs(dirnode.digest)
 
-    # _ensure_blob():
-    #
-    # Fetch and add blob if it's not already local.
-    #
-    # Args:
-    #     remote (Remote): The remote to use.
-    #     digest (Digest): Digest object for the blob to fetch.
-    #
-    # Returns:
-    #     (str): The path of the object
-    #
-    def _ensure_blob(self, remote, digest):
-        objpath = self.objpath(digest)
-        if os.path.exists(objpath):
-            # already in local repository
-            return objpath
-
-        with tempfile.NamedTemporaryFile(dir=self.tmpdir) as f:
-            remote._fetch_blob(digest, f)
-
-            added_digest = self.add_object(path=f.name, link_directly=True)
-            assert added_digest.hash == digest.hash
-
-        return objpath
-
-    def _batch_download_complete(self, batch):
-        for digest, data in batch.send():
-            with tempfile.NamedTemporaryFile(dir=self.tmpdir) as f:
-                f.write(data)
-                f.flush()
-
-                added_digest = self.add_object(path=f.name, link_directly=True)
-                assert added_digest.hash == digest.hash
-
-    # Helper function for _fetch_directory().
-    def _fetch_directory_batch(self, remote, batch, fetch_queue, fetch_next_queue):
-        self._batch_download_complete(batch)
-
-        # All previously scheduled directories are now locally available,
-        # move them to the processing queue.
-        fetch_queue.extend(fetch_next_queue)
-        fetch_next_queue.clear()
-        return _CASBatchRead(remote)
-
-    # Helper function for _fetch_directory().
-    def _fetch_directory_node(self, remote, digest, batch, fetch_queue, fetch_next_queue, *, recursive=False):
-        in_local_cache = os.path.exists(self.objpath(digest))
-
-        if in_local_cache:
-            # Skip download, already in local cache.
-            pass
-        elif (digest.size_bytes >= remote.max_batch_total_size_bytes or
-              not remote.batch_read_supported):
-            # Too large for batch request, download in independent request.
-            self._ensure_blob(remote, digest)
-            in_local_cache = True
-        else:
-            if not batch.add(digest):
-                # Not enough space left in batch request.
-                # Complete pending batch first.
-                batch = self._fetch_directory_batch(remote, batch, fetch_queue, fetch_next_queue)
-                batch.add(digest)
-
-        if recursive:
-            if in_local_cache:
-                # Add directory to processing queue.
-                fetch_queue.append(digest)
-            else:
-                # Directory will be available after completing pending batch.
-                # Add directory to deferred processing queue.
-                fetch_next_queue.append(digest)
-
-        return batch
-
-    # _fetch_directory():
-    #
-    # Fetches remote directory and adds it to content addressable store.
-    #
-    # Fetches files, symbolic links and recursively other directories in
-    # the remote directory and adds them to the content addressable
-    # store.
-    #
-    # Args:
-    #     remote (Remote): The remote to use.
-    #     dir_digest (Digest): Digest object for the directory to fetch.
-    #     excluded_subdirs (list): The optional list of subdirs to not fetch
-    #
-    def _fetch_directory(self, remote, dir_digest, *, excluded_subdirs=None):
-        fetch_queue = [dir_digest]
-        fetch_next_queue = []
-        batch = _CASBatchRead(remote)
-        if not excluded_subdirs:
-            excluded_subdirs = []
-
-        while len(fetch_queue) + len(fetch_next_queue) > 0:
-            if not fetch_queue:
-                batch = self._fetch_directory_batch(remote, batch, fetch_queue, fetch_next_queue)
-
-            dir_digest = fetch_queue.pop(0)
-
-            objpath = self._ensure_blob(remote, dir_digest)
-
-            directory = remote_execution_pb2.Directory()
-            with open(objpath, 'rb') as f:
-                directory.ParseFromString(f.read())
-
-            for dirnode in directory.directories:
-                if dirnode.name not in excluded_subdirs:
-                    batch = self._fetch_directory_node(remote, dirnode.digest, batch,
-                                                       fetch_queue, fetch_next_queue, recursive=True)
-
-            for filenode in directory.files:
-                batch = self._fetch_directory_node(remote, filenode.digest, batch,
-                                                   fetch_queue, fetch_next_queue)
-
-        # Fetch final batch
-        self._fetch_directory_batch(remote, batch, fetch_queue, fetch_next_queue)
-
-    def _fetch_subdir(self, remote, tree, subdir):
-        subdirdigest = self._get_subdir(tree, subdir)
-        self._fetch_directory(remote, subdirdigest)
-
-    def _fetch_tree(self, remote, digest):
-        # download but do not store the Tree object
-        with tempfile.NamedTemporaryFile(dir=self.tmpdir) as out:
-            remote._fetch_blob(digest, out)
-
-            tree = remote_execution_pb2.Tree()
-
-            with open(out.name, 'rb') as f:
-                tree.ParseFromString(f.read())
-
-            tree.children.extend([tree.root])
-            for directory in tree.children:
-                for filenode in directory.files:
-                    self._ensure_blob(remote, filenode.digest)
-
-                # place directory blob only in final location when we've downloaded
-                # all referenced blobs to avoid dangling references in the repository
-                dirbuffer = directory.SerializeToString()
-                dirdigest = self.add_object(buffer=dirbuffer)
-                assert dirdigest.size_bytes == len(dirbuffer)
-
-        return dirdigest
-
     def _send_directory(self, remote, digest, u_uid=uuid.uuid4()):
         required_blobs = self._required_blobs(digest)
 
diff --git a/buildstream/_cas/casremote.py b/buildstream/_cas/casremote.py
index f7af253..0e75b09 100644
--- a/buildstream/_cas/casremote.py
+++ b/buildstream/_cas/casremote.py
@@ -3,6 +3,7 @@ import io
 import os
 import multiprocessing
 import signal
+import tempfile
 from urllib.parse import urlparse
 import uuid
 
@@ -96,6 +97,11 @@ class CASRemote():
         self.tmpdir = str(tmpdir)
         os.makedirs(self.tmpdir, exist_ok=True)
 
+        self.__tmp_downloads = []  # files in the tmpdir waiting to be added to local caches
+
+        self.__batch_read = None
+        self.__batch_update = None
+
     def init(self):
         if not self._initialized:
             url = urlparse(self.spec.url)
@@ -153,6 +159,7 @@ class CASRemote():
                 request = remote_execution_pb2.BatchReadBlobsRequest()
                 response = self.cas.BatchReadBlobs(request)
                 self.batch_read_supported = True
+                self.__batch_read = _CASBatchRead(self)
             except grpc.RpcError as e:
                 if e.code() != grpc.StatusCode.UNIMPLEMENTED:
                     raise
@@ -163,6 +170,7 @@ class CASRemote():
                 request = remote_execution_pb2.BatchUpdateBlobsRequest()
                 response = self.cas.BatchUpdateBlobs(request)
                 self.batch_update_supported = True
+                self.__batch_update = _CASBatchUpdate(self)
             except grpc.RpcError as e:
                 if (e.code() != grpc.StatusCode.UNIMPLEMENTED and
                         e.code() != grpc.StatusCode.PERMISSION_DENIED):
@@ -259,6 +267,136 @@ class CASRemote():
 
         return message_digest
 
+    # get_reference():
+    #
+    # Args:
+    #    ref (str): The ref to request
+    #
+    # Returns:
+    #    (digest): digest of ref, None if not found
+    #
+    def get_reference(self, ref):
+        try:
+            self.init()
+
+            request = buildstream_pb2.GetReferenceRequest()
+            request.key = ref
+            return self.ref_storage.GetReference(request).digest
+        except grpc.RpcError as e:
+            if e.code() != grpc.StatusCode.NOT_FOUND:
+                raise CASRemoteError("Failed to find ref {}: {}".format(ref, e)) from e
+            else:
+                return None
+
+    def get_tree_blob(self, tree_digest):
+        self.init()
+        f = tempfile.NamedTemporaryFile(dir=self.tmpdir)
+        self._fetch_blob(tree_digest, f)
+
+        tree = remote_execution_pb2.Tree()
+        with open(f.name, 'rb') as tmp:
+            tree.ParseFromString(tmp.read())
+
+        return tree
+
+    # yield_directory_digests():
+    #
+    # Recursively iterates over digests for files, symbolic links and other
+    # directories starting from a root digest
+    #
+    # Args:
+    #     root_digest (digest): The root_digest to get a tree of
+    #     progress (callable): The progress callback, if any
+    #     subdir (str): The optional specific subdir to pull
+    #     excluded_subdirs (list): The optional list of subdirs to not pull
+    #
+    # Returns:
+    #     (iter digests): recursively iterates over digests contained in root directory
+    #
+    def yield_directory_digests(self, root_digest, *, progress=None,
+                                subdir=None, excluded_subdirs=None):
+        self.init()
+
+        # Fetch artifact, excluded_subdirs determined in pullqueue
+        if excluded_subdirs is None:
+            excluded_subdirs = []
+
+        # get directory blob
+        f = tempfile.NamedTemporaryFile(dir=self.tmpdir)
+        self._fetch_blob(root_digest, f)
+
+        directory = remote_execution_pb2.Directory()
+        with open(f.name, 'rb') as tmp:
+            directory.ParseFromString(tmp.read())
+
+        yield root_digest
+        for filenode in directory.files:
+            yield filenode.digest
+
+        for dirnode in directory.directories:
+            if dirnode.name not in excluded_subdirs:
+                yield from self.yield_directory_digests(dirnode.digest)
+
+    # yield_tree_digests():
+    #
+    # Iterates over the file digests referenced by a Tree message, queueing
+    # the directory blobs themselves for addition to the local cache.
+    #
+    # Args:
+    #     tree (Tree): the Tree message to iterate over
+    #
+    # Returns:
+    #     (iter digests): iterates over digests in tree message
+    def yield_tree_digests(self, tree):
+        self.init()
+
+        tree.children.extend([tree.root])
+        for directory in tree.children:
+            for filenode in directory.files:
+                yield filenode.digest
+
+            # add the directory to downloaded tmp files to be added
+            f = tempfile.NamedTemporaryFile(dir=self.tmpdir)
+            f.write(directory.SerializeToString())
+            f.flush()
+            self.__tmp_downloads.append(f)
+
+    # request_blob():
+    #
+    # Request blob, triggering download depending via bytestream or cas
+    # BatchReadBlobs depending on size.
+    #
+    # Args:
+    #    digest (Digest): digest of the requested blob
+    #
+    def request_blob(self, digest):
+        if (not self.batch_read_supported or
+                digest.size_bytes > self.max_batch_total_size_bytes):
+            f = tempfile.NamedTemporaryFile(dir=self.tmpdir)
+            self._fetch_blob(digest, f)
+            self.__tmp_downloads.append(f)
+        elif self.__batch_read.add(digest) is False:
+            self._download_batch()
+            self.__batch_read.add(digest)
+
+    # get_blobs():
+    #
+    # Yield over downloaded blobs in the tmp file locations, causing the files
+    # to be deleted once they go out of scope.
+    #
+    # Args:
+    #    complete_batch (bool): download any outstanding batch read request
+    #
+    # Returns:
+    #    iterator over NamedTemporaryFile
+    def get_blobs(self, complete_batch=False):
+        # Send read batch request and download
+        if (complete_batch is True and
+                self.batch_read_supported is True):
+            self._download_batch()
+
+        while self.__tmp_downloads:
+            yield self.__tmp_downloads.pop()
+
     ################################################
     #             Local Private Methods            #
     ################################################
@@ -301,6 +439,15 @@ class CASRemote():
 
         assert response.committed_size == digest.size_bytes
 
+    def _download_batch(self):
+        for _, data in self.__batch_read.send():
+            f = tempfile.NamedTemporaryFile(dir=self.tmpdir)
+            f.write(data)
+            f.flush()
+            self.__tmp_downloads.append(f)
+
+        self.__batch_read = _CASBatchRead(self)
+
 
 # Represents a batch of blobs queued for fetching.
 #
diff --git a/buildstream/_cas/transfer.py b/buildstream/_cas/transfer.py
new file mode 100644
index 0000000..5eaaf09
--- /dev/null
+++ b/buildstream/_cas/transfer.py
@@ -0,0 +1,51 @@
+#
+#  Copyright (C) 2017-2018 Codethink Limited
+#
+#  This program is free software; you can redistribute it and/or
+#  modify it under the terms of the GNU Lesser General Public
+#  License as published by the Free Software Foundation; either
+#  version 2 of the License, or (at your option) any later version.
+#
+#  This library is distributed in the hope that it will be useful,
+#  but WITHOUT ANY WARRANTY; without even the implied warranty of
+#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.	 See the GNU
+#  Lesser General Public License for more details.
+#
+#  You should have received a copy of the GNU Lesser General Public
+#  License along with this library. If not, see <http://www.gnu.org/licenses/>.
+#
+#  Authors:
+#        Raoul Hidalgo Charman <ra...@codethink.co.uk>
+
+from ..utils import _message_digest
+
+
+def cas_directory_download(caslocal, casremote, root_digest, excluded_subdirs):
+    for blob_digest in casremote.yield_directory_digests(
+            root_digest, excluded_subdirs=excluded_subdirs):
+        if caslocal.check_blob(blob_digest):
+            continue
+        casremote.request_blob(blob_digest)
+        for blob_file in casremote.get_blobs():
+            caslocal.add_object(path=blob_file.name, link_directly=True)
+
+    # Request final CAS batch
+    for blob_file in casremote.get_blobs(complete_batch=True):
+        caslocal.add_object(path=blob_file.name, link_directly=True)
+
+
+def cas_tree_download(caslocal, casremote, tree_digest):
+    tree = casremote.get_tree_blob(tree_digest)
+    for blob_digest in casremote.yield_tree_digests(tree):
+        if caslocal.check_blob(blob_digest):
+            continue
+        casremote.request_blob(blob_digest)
+        for blob_file in casremote.get_blobs():
+            caslocal.add_object(path=blob_file.name, link_directly=True)
+
+    # Get the last batch
+    for blob_file in casremote.get_blobs(complete_batch=True):
+        caslocal.add_object(path=blob_file.name, link_directly=True)
+
+    # get root digest from tree and return that
+    return _message_digest(tree.root.SerializeToString())
diff --git a/buildstream/sandbox/_sandboxremote.py b/buildstream/sandbox/_sandboxremote.py
index 8c21041..bea1754 100644
--- a/buildstream/sandbox/_sandboxremote.py
+++ b/buildstream/sandbox/_sandboxremote.py
@@ -39,6 +39,7 @@ from .._exceptions import SandboxError
 from .. import _yaml
 from .._protos.google.longrunning import operations_pb2, operations_pb2_grpc
 from .._cas import CASRemote, CASRemoteSpec
+from .._cas.transfer import cas_tree_download
 
 
 class RemoteExecutionSpec(namedtuple('RemoteExecutionSpec', 'exec_service storage_service action_service')):
@@ -281,8 +282,7 @@ class SandboxRemote(Sandbox):
         cascache = context.get_cascache()
         casremote = CASRemote(self.storage_remote_spec, context.tmpdir)
 
-        # Now do a pull to ensure we have the necessary parts.
-        dir_digest = cascache.pull_tree(casremote, tree_digest)
+        dir_digest = cas_tree_download(cascache, casremote, tree_digest)
         if dir_digest is None or not dir_digest.hash or not dir_digest.size_bytes:
             raise SandboxError("Output directory structure pulling from remote failed.")
 
diff --git a/conftest.py b/conftest.py
index f3c09a5..9fb7f12 100755
--- a/conftest.py
+++ b/conftest.py
@@ -46,6 +46,13 @@ def integration_cache(request):
     else:
         cache_dir = os.path.abspath('./integration-cache')
 
+    # Clean up the tmp dir, should be empty but something in CI tests is
+    # leaving files here
+    try:
+        shutil.rmtree(os.path.join(cache_dir, 'tmp'))
+    except FileNotFoundError:
+        pass
+
     yield cache_dir
 
     # Clean up the artifacts after each test run - we only want to
diff --git a/tests/artifactcache/pull.py b/tests/artifactcache/pull.py
index 4c332bf..15d5c67 100644
--- a/tests/artifactcache/pull.py
+++ b/tests/artifactcache/pull.py
@@ -110,7 +110,7 @@ def test_pull(cli, tmpdir, datafiles):
         # See https://github.com/grpc/grpc/blob/master/doc/fork_support.md for details
         process = multiprocessing.Process(target=_queue_wrapper,
                                           args=(_test_pull, queue, user_config_file, project_dir,
-                                                artifact_dir, 'target.bst', element_key))
+                                                artifact_dir, tmpdir, 'target.bst', element_key))
 
         try:
             # Keep SIGINT blocked in the child process
@@ -126,14 +126,18 @@ def test_pull(cli, tmpdir, datafiles):
         assert not error
         assert cas.contains(element, element_key)
 
+        # Check that the tmp dir is cleared out
+        assert os.listdir(os.path.join(str(tmpdir), 'cache', 'tmp')) == []
 
-def _test_pull(user_config_file, project_dir, artifact_dir,
+
+def _test_pull(user_config_file, project_dir, artifact_dir, tmpdir,
                element_name, element_key, queue):
     # Fake minimal context
     context = Context()
     context.load(config=user_config_file)
     context.artifactdir = artifact_dir
     context.set_message_handler(message_handler)
+    context.tmpdir = os.path.join(str(tmpdir), 'cache', 'tmp')
 
     # Load the project manually
     project = Project(project_dir, context)
@@ -218,7 +222,7 @@ def test_pull_tree(cli, tmpdir, datafiles):
         # See https://github.com/grpc/grpc/blob/master/doc/fork_support.md for details
         process = multiprocessing.Process(target=_queue_wrapper,
                                           args=(_test_push_tree, queue, user_config_file, project_dir,
-                                                artifact_dir, artifact_digest))
+                                                artifact_dir, tmpdir, artifact_digest))
 
         try:
             # Keep SIGINT blocked in the child process
@@ -239,6 +243,9 @@ def test_pull_tree(cli, tmpdir, datafiles):
         # Assert that we are not cached locally anymore
         assert cli.get_element_state(project_dir, 'target.bst') != 'cached'
 
+        # Check that the tmp dir is cleared out
+        assert os.listdir(os.path.join(str(tmpdir), 'cache', 'tmp')) == []
+
         tree_digest = remote_execution_pb2.Digest(hash=tree_hash,
                                                   size_bytes=tree_size)
 
@@ -246,7 +253,7 @@ def test_pull_tree(cli, tmpdir, datafiles):
         # Use subprocess to avoid creation of gRPC threads in main BuildStream process
         process = multiprocessing.Process(target=_queue_wrapper,
                                           args=(_test_pull_tree, queue, user_config_file, project_dir,
-                                                artifact_dir, tree_digest))
+                                                artifact_dir, tmpdir, tree_digest))
 
         try:
             # Keep SIGINT blocked in the child process
@@ -267,13 +274,18 @@ def test_pull_tree(cli, tmpdir, datafiles):
         # Ensure the entire Tree structure has been pulled
         assert os.path.exists(cas.objpath(directory_digest))
 
+        # Check that the tmp dir is cleared out
+        assert os.listdir(os.path.join(str(tmpdir), 'cache', 'tmp')) == []
+
 
-def _test_push_tree(user_config_file, project_dir, artifact_dir, artifact_digest, queue):
+def _test_push_tree(user_config_file, project_dir, artifact_dir, tmpdir,
+                    artifact_digest, queue):
     # Fake minimal context
     context = Context()
     context.load(config=user_config_file)
     context.artifactdir = artifact_dir
     context.set_message_handler(message_handler)
+    context.tmpdir = os.path.join(str(tmpdir), 'cache', 'tmp')
 
     # Load the project manually
     project = Project(project_dir, context)
@@ -304,12 +316,14 @@ def _test_push_tree(user_config_file, project_dir, artifact_dir, artifact_digest
         queue.put("No remote configured")
 
 
-def _test_pull_tree(user_config_file, project_dir, artifact_dir, artifact_digest, queue):
+def _test_pull_tree(user_config_file, project_dir, artifact_dir, tmpdir,
+                    artifact_digest, queue):
     # Fake minimal context
     context = Context()
     context.load(config=user_config_file)
     context.artifactdir = artifact_dir
     context.set_message_handler(message_handler)
+    context.tmpdir = os.path.join(str(tmpdir), 'cache', 'tmp')
 
     # Load the project manually
     project = Project(project_dir, context)
diff --git a/tests/artifactcache/push.py b/tests/artifactcache/push.py
index 116fa78..f97a231 100644
--- a/tests/artifactcache/push.py
+++ b/tests/artifactcache/push.py
@@ -89,7 +89,7 @@ def test_push(cli, tmpdir, datafiles):
         # See https://github.com/grpc/grpc/blob/master/doc/fork_support.md for details
         process = multiprocessing.Process(target=_queue_wrapper,
                                           args=(_test_push, queue, user_config_file, project_dir,
-                                                artifact_dir, 'target.bst', element_key))
+                                                artifact_dir, tmpdir, 'target.bst', element_key))
 
         try:
             # Keep SIGINT blocked in the child process
@@ -105,14 +105,18 @@ def test_push(cli, tmpdir, datafiles):
         assert not error
         assert share.has_artifact('test', 'target.bst', element_key)
 
+        # Check tmpdir for downloads is cleared
+        assert os.listdir(os.path.join(str(tmpdir), 'cache', 'tmp')) == []
 
-def _test_push(user_config_file, project_dir, artifact_dir,
+
+def _test_push(user_config_file, project_dir, artifact_dir, tmpdir,
                element_name, element_key, queue):
     # Fake minimal context
     context = Context()
     context.load(config=user_config_file)
     context.artifactdir = artifact_dir
     context.set_message_handler(message_handler)
+    context.tmpdir = os.path.join(str(tmpdir), 'cache', 'tmp')
 
     # Load the project manually
     project = Project(project_dir, context)
@@ -196,9 +200,10 @@ def test_push_directory(cli, tmpdir, datafiles):
         queue = multiprocessing.Queue()
         # Use subprocess to avoid creation of gRPC threads in main BuildStream process
         # See https://github.com/grpc/grpc/blob/master/doc/fork_support.md for details
-        process = multiprocessing.Process(target=_queue_wrapper,
-                                          args=(_test_push_directory, queue, user_config_file,
-                                                project_dir, artifact_dir, artifact_digest))
+        process = multiprocessing.Process(
+            target=_queue_wrapper,
+            args=(_test_push_directory, queue, user_config_file, project_dir,
+                  artifact_dir, tmpdir, artifact_digest))
 
         try:
             # Keep SIGINT blocked in the child process
@@ -215,13 +220,17 @@ def test_push_directory(cli, tmpdir, datafiles):
         assert artifact_digest.hash == directory_hash
         assert share.has_object(artifact_digest)
 
+        assert os.listdir(os.path.join(str(tmpdir), 'cache', 'tmp')) == []
 
-def _test_push_directory(user_config_file, project_dir, artifact_dir, artifact_digest, queue):
+
+def _test_push_directory(user_config_file, project_dir, artifact_dir, tmpdir,
+                         artifact_digest, queue):
     # Fake minimal context
     context = Context()
     context.load(config=user_config_file)
     context.artifactdir = artifact_dir
     context.set_message_handler(message_handler)
+    context.tmpdir = os.path.join(str(tmpdir), 'cache', 'tmp')
 
     # Load the project manually
     project = Project(project_dir, context)
@@ -273,7 +282,7 @@ def test_push_message(cli, tmpdir, datafiles):
         # See https://github.com/grpc/grpc/blob/master/doc/fork_support.md for details
         process = multiprocessing.Process(target=_queue_wrapper,
                                           args=(_test_push_message, queue, user_config_file,
-                                                project_dir, artifact_dir))
+                                                project_dir, artifact_dir, tmpdir))
 
         try:
             # Keep SIGINT blocked in the child process
@@ -291,13 +300,16 @@ def test_push_message(cli, tmpdir, datafiles):
                                                      size_bytes=message_size)
         assert share.has_object(message_digest)
 
+        assert os.listdir(os.path.join(str(tmpdir), 'cache', 'tmp')) == []
+
 
-def _test_push_message(user_config_file, project_dir, artifact_dir, queue):
+def _test_push_message(user_config_file, project_dir, artifact_dir, tmpdir, queue):
     # Fake minimal context
     context = Context()
     context.load(config=user_config_file)
     context.artifactdir = artifact_dir
     context.set_message_handler(message_handler)
+    context.tmpdir = os.path.join(str(tmpdir), 'cache', 'tmp')
 
     # Load the project manually
     project = Project(project_dir, context)
diff --git a/tests/integration/pullbuildtrees.py b/tests/integration/pullbuildtrees.py
index f6fc712..de13a3d 100644
--- a/tests/integration/pullbuildtrees.py
+++ b/tests/integration/pullbuildtrees.py
@@ -77,6 +77,8 @@ def test_pullbuildtrees(cli, tmpdir, datafiles, integration_cache):
         result = cli.run(project=project, args=['--pull-buildtrees', 'pull', element_name])
         assert element_name in result.get_pulled_elements()
         assert os.path.isdir(buildtreedir)
+        # Check tmpdir for downloads is cleared
+        assert os.listdir(os.path.join(str(tmpdir), 'artifacts', 'tmp')) == []
         default_state(cli, tmpdir, share1)
 
         # Pull artifact with pullbuildtrees set in user config, then assert
@@ -89,6 +91,8 @@ def test_pullbuildtrees(cli, tmpdir, datafiles, integration_cache):
         assert element_name not in result.get_pulled_elements()
         result = cli.run(project=project, args=['--pull-buildtrees', 'pull', element_name])
         assert element_name not in result.get_pulled_elements()
+        # Check tmpdir for downloads is cleared
+        assert os.listdir(os.path.join(str(tmpdir), 'artifacts', 'tmp')) == []
         default_state(cli, tmpdir, share1)
 
         # Pull artifact with default config and buildtrees cli flag set, then assert
@@ -99,6 +103,8 @@ def test_pullbuildtrees(cli, tmpdir, datafiles, integration_cache):
         cli.configure({'cache': {'pull-buildtrees': True}})
         result = cli.run(project=project, args=['pull', element_name])
         assert element_name not in result.get_pulled_elements()
+        # Check tmpdir for downloads is cleared
+        assert os.listdir(os.path.join(str(tmpdir), 'artifacts', 'tmp')) == []
         default_state(cli, tmpdir, share1)
 
         # Assert that a partial build element (not containing a populated buildtree dir)


[buildstream] 01/06: _cas: Rename artifactcache folder and move that to a root module

Posted by gi...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

github-bot pushed a commit to branch raoul/802-refactor-artifactcache
in repository https://gitbox.apache.org/repos/asf/buildstream.git

commit c69d12f9aa89b15e2ea489419907cc66c4b70450
Author: Raoul Hidalgo Charman <ra...@codethink.co.uk>
AuthorDate: Fri Dec 7 17:29:21 2018 +0000

    _cas: Rename artifactcache folder and move that to a root module
    
    Other components will start to rely on the cas modules rather than the
    artifact cache modules, so the code should be organized to reflect this.
    
    All relevant imports have been changed.
    
    Part of #802
---
 .../artifactcache.py => _artifactcache.py}               | 16 ++++++++--------
 buildstream/{_artifactcache => _cas}/__init__.py         |  2 +-
 buildstream/{_artifactcache => _cas}/cascache.py         |  0
 buildstream/{_artifactcache => _cas}/casserver.py        |  0
 buildstream/_context.py                                  |  2 +-
 buildstream/sandbox/_sandboxremote.py                    |  2 +-
 doc/source/using_configuring_artifact_server.rst         |  2 +-
 tests/artifactcache/config.py                            |  3 +--
 tests/artifactcache/expiry.py                            |  4 ++--
 tests/sandboxes/storage-tests.py                         |  2 +-
 tests/storage/virtual_directory_import.py                |  2 +-
 tests/testutils/artifactshare.py                         |  4 ++--
 tests/utils/misc.py                                      |  2 +-
 13 files changed, 20 insertions(+), 21 deletions(-)
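
In practice the rename boils down to an import path change for callers, along
these lines (illustrative; the concrete call sites are updated in the diffs
below):

    # Before this commit
    from buildstream._artifactcache.cascache import CASCache, CASRemote, CASRemoteSpec

    # After this commit
    from buildstream._cas import CASCache, CASRemote, CASRemoteSpec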

diff --git a/buildstream/_artifactcache/artifactcache.py b/buildstream/_artifactcache.py
similarity index 99%
rename from buildstream/_artifactcache/artifactcache.py
rename to buildstream/_artifactcache.py
index b4b8df3..1b2b55d 100644
--- a/buildstream/_artifactcache/artifactcache.py
+++ b/buildstream/_artifactcache.py
@@ -23,14 +23,14 @@ import signal
 import string
 from collections.abc import Mapping
 
-from ..types import _KeyStrength
-from .._exceptions import ArtifactError, CASError, LoadError, LoadErrorReason
-from .._message import Message, MessageType
-from .. import _signals
-from .. import utils
-from .. import _yaml
-
-from .cascache import CASRemote, CASRemoteSpec
+from .types import _KeyStrength
+from ._exceptions import ArtifactError, CASError, LoadError, LoadErrorReason
+from ._message import Message, MessageType
+from . import _signals
+from . import utils
+from . import _yaml
+
+from ._cas import CASRemote, CASRemoteSpec
 
 
 CACHE_SIZE_FILE = "cache_size"
diff --git a/buildstream/_artifactcache/__init__.py b/buildstream/_cas/__init__.py
similarity index 91%
rename from buildstream/_artifactcache/__init__.py
rename to buildstream/_cas/__init__.py
index fad483a..7386109 100644
--- a/buildstream/_artifactcache/__init__.py
+++ b/buildstream/_cas/__init__.py
@@ -17,4 +17,4 @@
 #  Authors:
 #        Tristan Van Berkom <tr...@codethink.co.uk>
 
-from .artifactcache import ArtifactCache, ArtifactCacheSpec, CACHE_SIZE_FILE
+from .cascache import CASCache, CASRemote, CASRemoteSpec
diff --git a/buildstream/_artifactcache/cascache.py b/buildstream/_cas/cascache.py
similarity index 100%
rename from buildstream/_artifactcache/cascache.py
rename to buildstream/_cas/cascache.py
diff --git a/buildstream/_artifactcache/casserver.py b/buildstream/_cas/casserver.py
similarity index 100%
rename from buildstream/_artifactcache/casserver.py
rename to buildstream/_cas/casserver.py
diff --git a/buildstream/_context.py b/buildstream/_context.py
index c62755c..324b455 100644
--- a/buildstream/_context.py
+++ b/buildstream/_context.py
@@ -31,7 +31,7 @@ from ._exceptions import LoadError, LoadErrorReason, BstError
 from ._message import Message, MessageType
 from ._profile import Topics, profile_start, profile_end
 from ._artifactcache import ArtifactCache
-from ._artifactcache.cascache import CASCache
+from ._cas import CASCache
 from ._workspaces import Workspaces, WorkspaceProjectCache, WORKSPACE_PROJECT_FILE
 from .plugin import _plugin_lookup
 from .sandbox import SandboxRemote
diff --git a/buildstream/sandbox/_sandboxremote.py b/buildstream/sandbox/_sandboxremote.py
index a842f08..8b4c87c 100644
--- a/buildstream/sandbox/_sandboxremote.py
+++ b/buildstream/sandbox/_sandboxremote.py
@@ -38,7 +38,7 @@ from .._protos.google.rpc import code_pb2
 from .._exceptions import SandboxError
 from .. import _yaml
 from .._protos.google.longrunning import operations_pb2, operations_pb2_grpc
-from .._artifactcache.cascache import CASRemote, CASRemoteSpec
+from .._cas import CASRemote, CASRemoteSpec
 
 
 class RemoteExecutionSpec(namedtuple('RemoteExecutionSpec', 'exec_service storage_service action_service')):
diff --git a/doc/source/using_configuring_artifact_server.rst b/doc/source/using_configuring_artifact_server.rst
index cc4880e..bcf7d0e 100644
--- a/doc/source/using_configuring_artifact_server.rst
+++ b/doc/source/using_configuring_artifact_server.rst
@@ -94,7 +94,7 @@ requiring BuildStream's more exigent dependencies by setting the
 Command reference
 ~~~~~~~~~~~~~~~~~
 
-.. click:: buildstream._artifactcache.casserver:server_main
+.. click:: buildstream._cas.casserver:server_main
    :prog: bst-artifact-server
 
 
diff --git a/tests/artifactcache/config.py b/tests/artifactcache/config.py
index df40d10..8c8c4b4 100644
--- a/tests/artifactcache/config.py
+++ b/tests/artifactcache/config.py
@@ -3,8 +3,7 @@ import pytest
 import itertools
 import os
 
-from buildstream._artifactcache import ArtifactCacheSpec
-from buildstream._artifactcache.artifactcache import _configured_remote_artifact_cache_specs
+from buildstream._artifactcache import ArtifactCacheSpec, _configured_remote_artifact_cache_specs
 from buildstream._context import Context
 from buildstream._project import Project
 from buildstream.utils import _deduplicate
diff --git a/tests/artifactcache/expiry.py b/tests/artifactcache/expiry.py
index 05cbe32..e739283 100644
--- a/tests/artifactcache/expiry.py
+++ b/tests/artifactcache/expiry.py
@@ -342,13 +342,13 @@ def test_invalid_cache_quota(cli, datafiles, tmpdir, quota, success):
         total_space = 10000
 
     volume_space_patch = mock.patch(
-        "buildstream._artifactcache.artifactcache.ArtifactCache._get_volume_space_info_for",
+        "buildstream._artifactcache.ArtifactCache._get_volume_space_info_for",
         autospec=True,
         return_value=(free_space, total_space),
     )
 
     cache_size_patch = mock.patch(
-        "buildstream._artifactcache.artifactcache.ArtifactCache.get_cache_size",
+        "buildstream._artifactcache.ArtifactCache.get_cache_size",
         autospec=True,
         return_value=0,
     )
diff --git a/tests/sandboxes/storage-tests.py b/tests/sandboxes/storage-tests.py
index e646a62..3871677 100644
--- a/tests/sandboxes/storage-tests.py
+++ b/tests/sandboxes/storage-tests.py
@@ -3,7 +3,7 @@ import pytest
 
 from buildstream._exceptions import ErrorDomain
 
-from buildstream._artifactcache.cascache import CASCache
+from buildstream._cas import CASCache
 from buildstream.storage._casbaseddirectory import CasBasedDirectory
 from buildstream.storage._filebaseddirectory import FileBasedDirectory
 
diff --git a/tests/storage/virtual_directory_import.py b/tests/storage/virtual_directory_import.py
index fa40b17..0bb47e3 100644
--- a/tests/storage/virtual_directory_import.py
+++ b/tests/storage/virtual_directory_import.py
@@ -8,7 +8,7 @@ from tests.testutils import cli
 from buildstream.storage._casbaseddirectory import CasBasedDirectory
 from buildstream.storage._filebaseddirectory import FileBasedDirectory
 from buildstream._artifactcache import ArtifactCache
-from buildstream._artifactcache.cascache import CASCache
+from buildstream._cas import CASCache
 from buildstream import utils
 
 
diff --git a/tests/testutils/artifactshare.py b/tests/testutils/artifactshare.py
index 38c54a9..6b03d8d 100644
--- a/tests/testutils/artifactshare.py
+++ b/tests/testutils/artifactshare.py
@@ -11,8 +11,8 @@ from multiprocessing import Process, Queue
 import pytest_cov
 
 from buildstream import _yaml
-from buildstream._artifactcache.cascache import CASCache
-from buildstream._artifactcache.casserver import create_server
+from buildstream._cas import CASCache
+from buildstream._cas.casserver import create_server
 from buildstream._exceptions import CASError
 from buildstream._protos.build.bazel.remote.execution.v2 import remote_execution_pb2
 
diff --git a/tests/utils/misc.py b/tests/utils/misc.py
index 4ab29ad..a34d3cd 100644
--- a/tests/utils/misc.py
+++ b/tests/utils/misc.py
@@ -23,7 +23,7 @@ def test_parse_size_over_1024T(cli, tmpdir):
     _yaml.dump({'name': 'main'}, str(project.join("project.conf")))
 
     volume_space_patch = mock.patch(
-        "buildstream._artifactcache.artifactcache.ArtifactCache._get_volume_space_info_for",
+        "buildstream._artifactcache.ArtifactCache._get_volume_space_info_for",
         autospec=True,
         return_value=(1025 * TiB, 1025 * TiB)
     )