Posted to commits@buildstream.apache.org by tv...@apache.org on 2021/02/04 08:07:05 UTC

[buildstream] 29/41: _artifactcache/cascache.py: Add remote cache support

This is an automated email from the ASF dual-hosted git repository.

tvb pushed a commit to branch jmac/googlecas_and_virtual_directories_1
in repository https://gitbox.apache.org/repos/asf/buildstream.git

commit cf3c1c4a64e9b8a1b0c7c65926595a6258e5af10
Author: Jürg Billeter <j...@bitron.ch>
AuthorDate: Thu Mar 15 10:13:14 2018 +0100

    _artifactcache/cascache.py: Add remote cache support
---
 buildstream/_artifactcache/artifactcache.py |  31 ++-
 buildstream/_artifactcache/cascache.py      | 369 +++++++++++++++++++++++++++-
 buildstream/_project.py                     |   2 +-
 3 files changed, 391 insertions(+), 11 deletions(-)

diff --git a/buildstream/_artifactcache/artifactcache.py b/buildstream/_artifactcache/artifactcache.py
index 7260915..1a0d14f 100644
--- a/buildstream/_artifactcache/artifactcache.py
+++ b/buildstream/_artifactcache/artifactcache.py
@@ -36,22 +36,38 @@ from .. import _yaml
 #     push (bool): Whether we should attempt to push artifacts to this cache,
 #                  in addition to pulling from it.
 #
-class ArtifactCacheSpec(namedtuple('ArtifactCacheSpec', 'url push')):
+class ArtifactCacheSpec(namedtuple('ArtifactCacheSpec', 'url push server_cert client_key client_cert')):
 
     # _new_from_config_node
     #
     # Creates an ArtifactCacheSpec() from a YAML loaded node
     #
     @staticmethod
-    def _new_from_config_node(spec_node):
-        _yaml.node_validate(spec_node, ['url', 'push'])
+    def _new_from_config_node(spec_node, basedir=None):
+        _yaml.node_validate(spec_node, ['url', 'push', 'server-cert', 'client-key', 'client-cert'])
         url = _yaml.node_get(spec_node, str, 'url')
         push = _yaml.node_get(spec_node, bool, 'push', default_value=False)
         if not url:
             provenance = _yaml.node_get_provenance(spec_node)
             raise LoadError(LoadErrorReason.INVALID_DATA,
                             "{}: empty artifact cache URL".format(provenance))
-        return ArtifactCacheSpec(url, push)
+
+        server_cert = _yaml.node_get(spec_node, str, 'server-cert', default_value=None)
+        if server_cert and basedir:
+            server_cert = os.path.join(basedir, server_cert)
+
+        client_key = _yaml.node_get(spec_node, str, 'client-key', default_value=None)
+        if client_key and basedir:
+            client_key = os.path.join(basedir, client_key)
+
+        client_cert = _yaml.node_get(spec_node, str, 'client-cert', default_value=None)
+        if client_cert and basedir:
+            client_cert = os.path.join(basedir, client_cert)
+
+        return ArtifactCacheSpec(url, push, server_cert, client_key, client_cert)
+
+
+ArtifactCacheSpec.__new__.__defaults__ = (None, None, None)
 
 
 # An ArtifactCache manages artifacts.
@@ -139,6 +155,7 @@ class ArtifactCache():
     #
     # Args:
     #   config_node (dict): The config block, which may contain the 'artifacts' key
+    #   basedir (str): The base directory for relative paths
     #
     # Returns:
     #   A list of ArtifactCacheSpec instances.
@@ -147,15 +164,15 @@ class ArtifactCache():
     #   LoadError, if the config block contains invalid keys.
     #
     @staticmethod
-    def specs_from_config_node(config_node):
+    def specs_from_config_node(config_node, basedir=None):
         cache_specs = []
 
         artifacts = config_node.get('artifacts', [])
         if isinstance(artifacts, Mapping):
-            cache_specs.append(ArtifactCacheSpec._new_from_config_node(artifacts))
+            cache_specs.append(ArtifactCacheSpec._new_from_config_node(artifacts, basedir))
         elif isinstance(artifacts, list):
             for spec_node in artifacts:
-                cache_specs.append(ArtifactCacheSpec._new_from_config_node(spec_node))
+                cache_specs.append(ArtifactCacheSpec._new_from_config_node(spec_node, basedir))
         else:
             provenance = _yaml.node_get_provenance(config_node, key='artifacts')
             raise _yaml.LoadError(_yaml.LoadErrorReason.INVALID_DATA,
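
The three TLS-related fields added above are optional: the
ArtifactCacheSpec.__new__.__defaults__ = (None, None, None) assignment makes
server_cert, client_key and client_cert default to None, so configurations
that only set 'url' and 'push' keep working unchanged, while the new
'server-cert', 'client-key' and 'client-cert' keys (joined with the new
basedir argument when relative) populate the extra fields. A minimal sketch
of the resulting tuple behaviour follows; the URL and file paths are
placeholders, not BuildStream defaults:

    from collections import namedtuple

    # Same shape as the extended spec: the last three fields default to None.
    ArtifactCacheSpec = namedtuple('ArtifactCacheSpec',
                                   'url push server_cert client_key client_cert')
    ArtifactCacheSpec.__new__.__defaults__ = (None, None, None)

    # Plain pull/push spec, no TLS material configured.
    spec = ArtifactCacheSpec('https://cache.example.com', True)
    assert (spec.server_cert, spec.client_key, spec.client_cert) == (None, None, None)

    # Spec with certificates; in the patch, relative paths are joined with
    # the project directory (basedir) before being stored here.
    tls_spec = ArtifactCacheSpec('https://cache.example.com', True,
                                 server_cert='certs/server.crt',
                                 client_key='certs/client.key',
                                 client_cert='certs/client.crt')
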
diff --git a/buildstream/_artifactcache/cascache.py b/buildstream/_artifactcache/cascache.py
index 5ff0455..880d93b 100644
--- a/buildstream/_artifactcache/cascache.py
+++ b/buildstream/_artifactcache/cascache.py
@@ -19,13 +19,21 @@
 #        Jürg Billeter <ju...@codethink.co.uk>
 
 import hashlib
+import itertools
+import multiprocessing
 import os
+import signal
 import stat
 import tempfile
+from urllib.parse import urlparse
 
-from google.devtools.remoteexecution.v1test import remote_execution_pb2
+import grpc
 
-from .. import utils
+from google.bytestream import bytestream_pb2, bytestream_pb2_grpc
+from google.devtools.remoteexecution.v1test import remote_execution_pb2, remote_execution_pb2_grpc
+from buildstream import buildstream_pb2, buildstream_pb2_grpc
+
+from .. import _signals, utils
 from .._exceptions import ArtifactError
 
 from . import ArtifactCache
@@ -36,15 +44,28 @@ from . import ArtifactCache
 #
 # Args:
 #     context (Context): The BuildStream context
+#     enable_push (bool): Whether pushing is allowed by the platform
+#
+# Pushing is explicitly disabled by the platform in some cases,
+# like when we are falling back to functioning without using
+# user namespaces.
 #
 class CASCache(ArtifactCache):
 
-    def __init__(self, context):
+    def __init__(self, context, *, enable_push=True):
         super().__init__(context)
 
         self.casdir = os.path.join(context.artifactdir, 'cas')
         os.makedirs(os.path.join(self.casdir, 'tmp'), exist_ok=True)
 
+        self._enable_push = enable_push
+
+        # Per-project list of _CASRemote instances.
+        self._remotes = {}
+
+        self._has_fetch_remotes = False
+        self._has_push_remotes = False
+
     ################################################
     #     Implementation of abstract methods       #
     ################################################
@@ -115,6 +136,205 @@ class CASCache(ArtifactCache):
 
         return modified, removed, added
 
+    def initialize_remotes(self, *, on_failure=None):
+        remote_specs = self.global_remote_specs
+
+        for project in self.project_remote_specs:
+            remote_specs += self.project_remote_specs[project]
+
+        remote_specs = list(utils._deduplicate(remote_specs))
+
+        remotes = {}
+        q = multiprocessing.Queue()
+        for remote_spec in remote_specs:
+            # Use subprocess to avoid creation of gRPC threads in main BuildStream process
+            p = multiprocessing.Process(target=self._initialize_remote, args=(remote_spec, q))
+
+            try:
+                # Keep SIGINT blocked in the child process
+                with _signals.blocked([signal.SIGINT], ignore=False):
+                    p.start()
+
+                error = q.get()
+                p.join()
+            except KeyboardInterrupt:
+                utils._kill_process_tree(p.pid)
+                raise
+
+            if error and on_failure:
+                on_failure(remote_spec.url, error)
+            elif error:
+                raise ArtifactError(error)
+            else:
+                self._has_fetch_remotes = True
+                if remote_spec.push:
+                    self._has_push_remotes = True
+
+                remotes[remote_spec.url] = _CASRemote(remote_spec)
+
+        for project in self.context.get_projects():
+            remote_specs = self.global_remote_specs
+            if project in self.project_remote_specs:
+                remote_specs = list(utils._deduplicate(remote_specs + self.project_remote_specs[project]))
+
+            project_remotes = []
+
+            for remote_spec in remote_specs:
+                # Errors are already handled in the loop above,
+                # skip unreachable remotes here.
+                if remote_spec.url not in remotes:
+                    continue
+
+                remote = remotes[remote_spec.url]
+                project_remotes.append(remote)
+
+            self._remotes[project] = project_remotes
+
+    def has_fetch_remotes(self, *, element=None):
+        if not self._has_fetch_remotes:
+            # No project has fetch remotes
+            return False
+        elif element is None:
+            # At least one (sub)project has fetch remotes
+            return True
+        else:
+            # Check whether the specified element's project has fetch remotes
+            remotes_for_project = self._remotes[element._get_project()]
+            return bool(remotes_for_project)
+
+    def has_push_remotes(self, *, element=None):
+        if not self._has_push_remotes or not self._enable_push:
+            # No project has push remotes
+            return False
+        elif element is None:
+            # At least one (sub)project has push remotes
+            return True
+        else:
+            # Check whether the specified element's project has push remotes
+            remotes_for_project = self._remotes[element._get_project()]
+            return any(remote.spec.push for remote in remotes_for_project)
+
+    def pull(self, element, key, *, progress=None):
+        ref = self.get_artifact_fullname(element, key)
+
+        project = element._get_project()
+
+        for remote in self._remotes[project]:
+            try:
+                remote.init()
+
+                request = buildstream_pb2.GetArtifactRequest()
+                request.key = ref
+                response = remote.artifact_cache.GetArtifact(request)
+
+                tree = remote_execution_pb2.Digest()
+                tree.hash = response.artifact.hash
+                tree.size_bytes = response.artifact.size_bytes
+
+                self._fetch_directory(remote, tree)
+
+                self.set_ref(ref, tree)
+
+                # no need to pull from additional remotes
+                return True
+
+            except grpc.RpcError as e:
+                if e.code() != grpc.StatusCode.NOT_FOUND:
+                    raise
+
+        return False
+
+    def link_key(self, element, oldkey, newkey):
+        oldref = self.get_artifact_fullname(element, oldkey)
+        newref = self.get_artifact_fullname(element, newkey)
+
+        tree = self.resolve_ref(oldref)
+
+        self.set_ref(newref, tree)
+
+    def push(self, element, keys):
+        refs = [self.get_artifact_fullname(element, key) for key in keys]
+
+        project = element._get_project()
+
+        push_remotes = [r for r in self._remotes[project] if r.spec.push]
+
+        pushed = False
+
+        for remote in push_remotes:
+            remote.init()
+
+            for ref in refs:
+                tree = self.resolve_ref(ref)
+
+                # Check whether ref is already on the server in which case
+                # there is no need to push the artifact
+                try:
+                    request = buildstream_pb2.GetArtifactRequest()
+                    request.key = ref
+                    response = remote.artifact_cache.GetArtifact(request)
+
+                    if response.artifact.hash == tree.hash and response.artifact.size_bytes == tree.size_bytes:
+                        # ref is already on the server with the same tree
+                        continue
+
+                except grpc.RpcError as e:
+                    if e.code() != grpc.StatusCode.NOT_FOUND:
+                        raise
+
+                missing_blobs = {}
+                required_blobs = self._required_blobs(tree)
+
+                # Limit size of FindMissingBlobs request
+                for required_blobs_group in _grouper(required_blobs, 512):
+                    request = remote_execution_pb2.FindMissingBlobsRequest()
+
+                    for required_digest in required_blobs_group:
+                        d = request.blob_digests.add()
+                        d.hash = required_digest.hash
+                        d.size_bytes = required_digest.size_bytes
+
+                    response = remote.cas.FindMissingBlobs(request)
+                    for digest in response.missing_blob_digests:
+                        d = remote_execution_pb2.Digest()
+                        d.hash = digest.hash
+                        d.size_bytes = digest.size_bytes
+                        missing_blobs[d.hash] = d
+
+                # Upload any blobs missing on the server
+                for digest in missing_blobs.values():
+                    def request_stream():
+                        resource_name = os.path.join(digest.hash, str(digest.size_bytes))
+                        with open(self.objpath(digest), 'rb') as f:
+                            assert os.fstat(f.fileno()).st_size == digest.size_bytes
+                            offset = 0
+                            finished = False
+                            remaining = digest.size_bytes
+                            while not finished:
+                                chunk_size = min(remaining, 64 * 1024)
+                                remaining -= chunk_size
+
+                                request = bytestream_pb2.WriteRequest()
+                                request.write_offset = offset
+                                # max. 64 KiB chunks
+                                request.data = f.read(chunk_size)
+                                request.resource_name = resource_name
+                                request.finish_write = remaining <= 0
+                                yield request
+                                offset += chunk_size
+                                finished = request.finish_write
+                    response = remote.bytestream.Write(request_stream())
+
+                request = buildstream_pb2.UpdateArtifactRequest()
+                request.keys.append(ref)
+                request.artifact.hash = tree.hash
+                request.artifact.size_bytes = tree.size_bytes
+                remote.artifact_cache.UpdateArtifact(request)
+
+                pushed = True
+
+        return pushed
+
     ################################################
     #                API Private Methods           #
     ################################################
@@ -344,3 +564,146 @@ class CASCache(ArtifactCache):
                                      path=os.path.join(path, dir_a.directories[a].name))
                 a += 1
                 b += 1
+
+    def _initialize_remote(self, remote_spec, q):
+        try:
+            remote = _CASRemote(remote_spec)
+            remote.init()
+
+            request = buildstream_pb2.StatusRequest()
+            response = remote.artifact_cache.Status(request)
+
+            if remote_spec.push and not response.allow_updates:
+                q.put('Artifact server does not allow push')
+            else:
+                # No error
+                q.put(None)
+
+        except Exception as e:               # pylint: disable=broad-except
+            # Whatever happens, we need to return it to the calling process
+            #
+            q.put(str(e))
+
+    def _required_blobs(self, tree):
+        # parse directory, and recursively add blobs
+        d = remote_execution_pb2.Digest()
+        d.hash = tree.hash
+        d.size_bytes = tree.size_bytes
+        yield d
+
+        directory = remote_execution_pb2.Directory()
+
+        with open(self.objpath(tree), 'rb') as f:
+            directory.ParseFromString(f.read())
+
+        for filenode in directory.files:
+            d = remote_execution_pb2.Digest()
+            d.hash = filenode.digest.hash
+            d.size_bytes = filenode.digest.size_bytes
+            yield d
+
+        for dirnode in directory.directories:
+            yield from self._required_blobs(dirnode.digest)
+
+    def _fetch_blob(self, remote, digest, out):
+        resource_name = os.path.join(digest.hash, str(digest.size_bytes))
+        request = bytestream_pb2.ReadRequest()
+        request.resource_name = resource_name
+        request.read_offset = 0
+        for response in remote.bytestream.Read(request):
+            out.write(response.data)
+
+        out.flush()
+        assert digest.size_bytes == os.fstat(out.fileno()).st_size
+
+    def _fetch_directory(self, remote, tree):
+        objpath = self.objpath(tree)
+        if os.path.exists(objpath):
+            # already in local cache
+            return
+
+        with tempfile.NamedTemporaryFile(dir=os.path.join(self.casdir, 'tmp')) as out:
+            self._fetch_blob(remote, tree, out)
+
+            directory = remote_execution_pb2.Directory()
+
+            with open(out.name, 'rb') as f:
+                directory.ParseFromString(f.read())
+
+            for filenode in directory.files:
+                fileobjpath = self.objpath(filenode.digest)
+                if os.path.exists(fileobjpath):
+                    # already in local cache
+                    continue
+
+                with tempfile.NamedTemporaryFile(dir=os.path.join(self.casdir, 'tmp')) as f:
+                    self._fetch_blob(remote, filenode.digest, f)
+
+                    digest = self.add_object(path=f.name)
+                    assert digest.hash == filenode.digest.hash
+
+            for dirnode in directory.directories:
+                self._fetch_directory(remote, dirnode.digest)
+
+            # place directory blob only in final location when we've downloaded
+            # all referenced blobs to avoid dangling references in the repository
+            digest = self.add_object(path=out.name)
+            assert digest.hash == tree.hash
+
+
+# Represents a single remote CAS cache.
+#
+class _CASRemote():
+    def __init__(self, spec):
+        self.spec = spec
+        self._initialized = False
+        self.channel = None
+        self.bytestream = None
+        self.cas = None
+        self.artifact_cache = None
+
+    def init(self):
+        if not self._initialized:
+            url = urlparse(self.spec.url)
+            if url.scheme == 'http':
+                port = url.port or 80
+                self.channel = grpc.insecure_channel('{}:{}'.format(url.hostname, port))
+            elif url.scheme == 'https':
+                port = url.port or 443
+
+                if self.spec.server_cert:
+                    with open(self.spec.server_cert, 'rb') as f:
+                        server_cert_bytes = f.read()
+                else:
+                    server_cert_bytes = None
+
+                if self.spec.client_key:
+                    with open(self.spec.client_key, 'rb') as f:
+                        client_key_bytes = f.read()
+                else:
+                    client_key_bytes = None
+
+                if self.spec.client_cert:
+                    with open(self.spec.client_cert, 'rb') as f:
+                        client_cert_bytes = f.read()
+                else:
+                    client_cert_bytes = None
+
+                credentials = grpc.ssl_channel_credentials(root_certificates=server_cert_bytes,
+                                                           private_key=client_key_bytes,
+                                                           certificate_chain=client_cert_bytes)
+                self.channel = grpc.secure_channel('{}:{}'.format(url.hostname, port), credentials)
+            else:
+                raise ArtifactError("Unsupported URL: {}".format(self.spec.url))
+
+            self.bytestream = bytestream_pb2_grpc.ByteStreamStub(self.channel)
+            self.cas = remote_execution_pb2_grpc.ContentAddressableStorageStub(self.channel)
+            self.artifact_cache = buildstream_pb2_grpc.ArtifactCacheStub(self.channel)
+
+            self._initialized = True
+
+
+def _grouper(iterable, n):
+    # pylint: disable=stop-iteration-return
+    while True:
+        yield itertools.chain([next(iterable)], itertools.islice(iterable, n - 1))
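
initialize_remotes() above contacts each configured remote from a child
process, so that gRPC never creates threads inside the main BuildStream
process, and the child reports either an error string or None back over a
multiprocessing.Queue. A stripped-down sketch of that pattern is below; the
probe body is a dummy stand-in for the real _CASRemote()/Status() check, and
the real code additionally keeps SIGINT blocked in the child and kills the
process tree on KeyboardInterrupt:

    import multiprocessing

    def _probe(url, queue):
        # Stand-in for _initialize_remote(): try to reach the remote and
        # report either an error string or None (success) on the queue.
        try:
            if not url.startswith(('http://', 'https://')):
                raise ValueError("Unsupported URL: {}".format(url))
            queue.put(None)
        except Exception as e:
            # Whatever happens, hand it back to the parent process
            queue.put(str(e))

    def check_remote(url):
        q = multiprocessing.Queue()
        p = multiprocessing.Process(target=_probe, args=(url, q))
        p.start()
        error = q.get()      # blocks until the child has reported
        p.join()
        return error         # None means the remote is usable

    if __name__ == '__main__':
        # 'https://cache.example.com' is a placeholder URL.
        print(check_remote('https://cache.example.com'))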
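
push() above caps each FindMissingBlobs request at 512 digests by running the
_required_blobs() generator through _grouper(). The same batching idea can be
expressed without gRPC, and without relying on a StopIteration escaping a
generator as _grouper() does (the reason for its pylint comment); this is a
minimal sketch with fake string digests, not the helper used in the patch:

    import itertools

    def batched(iterable, n):
        # Yield lists of at most n items from any iterable.
        it = iter(iterable)
        while True:
            group = list(itertools.islice(it, n))
            if not group:
                return
            yield group

    # Fake digests; the real requests carry remote_execution_pb2.Digest messages.
    digests = ('digest-{}'.format(i) for i in range(1300))
    for group in batched(digests, 512):
        # One FindMissingBlobsRequest would be built per group here.
        print(len(group))    # prints 512, 512, 276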
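
Missing blobs are then streamed to the server as a sequence of 64 KiB
WriteRequest chunks, with finish_write set on the last one. The sketch below
mirrors that chunking loop but yields plain tuples instead of
bytestream_pb2.WriteRequest messages, so it runs without the generated protos;
the function name and tuple layout are illustrative only:

    def chunked_writes(path, size_bytes, chunk_size=64 * 1024):
        # Yield (offset, data, finish_write) triples, matching the fields
        # set on each WriteRequest in the request_stream() generator above.
        with open(path, 'rb') as f:
            offset = 0
            remaining = size_bytes
            finished = False
            while not finished:
                length = min(remaining, chunk_size)
                data = f.read(length)
                remaining -= length
                finished = remaining <= 0
                yield offset, data, finished
                offset += length

    # Usage (paths and sizes come from the local CAS in the patch):
    #
    #     for offset, data, finish in chunked_writes(objpath, digest.size_bytes):
    #         ...build and send one WriteRequest per tuple...
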
diff --git a/buildstream/_project.py b/buildstream/_project.py
index 5344e95..87f14ee 100644
--- a/buildstream/_project.py
+++ b/buildstream/_project.py
@@ -296,7 +296,7 @@ class Project():
         #
 
         # Load artifacts pull/push configuration for this project
-        self.artifact_cache_specs = ArtifactCache.specs_from_config_node(config)
+        self.artifact_cache_specs = ArtifactCache.specs_from_config_node(config, self.directory)
 
         # Workspace configurations
         self.workspaces = Workspaces(self)