You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@buildstream.apache.org by gi...@apache.org on 2020/12/29 13:19:00 UTC

[buildstream] 04/06: Extract casd_channel logic to CASDConnection

This is an automated email from the ASF dual-hosted git repository.

github-bot pushed a commit to branch aevri/casdprocessmanager2
in repository https://gitbox.apache.org/repos/asf/buildstream.git

commit 84aab60bd8066439ed4971a23b21288d30729a87
Author: Angelos Evripiotis <je...@bloomberg.net>
AuthorDate: Mon Oct 14 13:53:00 2019 +0100

    Extract casd_channel logic to CASDConnection
    
    Encapsulate the management of a connection to CASD, so we can hide the
    details of how it happens. This will make it easier to port to Windows,
    as we will have to take a different approach there.
    
    Also make get_local_cas() public, since it is already used outside of
    the CASCache class.
---
 src/buildstream/_cas/cascache.py           | 98 ++++++++++++++----------------
 src/buildstream/_cas/casdprocessmanager.py | 83 ++++++++++++++++++++++++-
 src/buildstream/_cas/casremote.py          |  6 +-
 3 files changed, 130 insertions(+), 57 deletions(-)

diff --git a/src/buildstream/_cas/cascache.py b/src/buildstream/_cas/cascache.py
index aefc1b9..65359ff 100644
--- a/src/buildstream/_cas/cascache.py
+++ b/src/buildstream/_cas/cascache.py
@@ -31,14 +31,14 @@ import time
 import grpc
 
 from .._protos.google.rpc import code_pb2
-from .._protos.build.bazel.remote.execution.v2 import remote_execution_pb2, remote_execution_pb2_grpc
-from .._protos.build.buildgrid import local_cas_pb2, local_cas_pb2_grpc
+from .._protos.build.bazel.remote.execution.v2 import remote_execution_pb2
+from .._protos.build.buildgrid import local_cas_pb2
 
 from .. import _signals, utils
 from ..types import FastEnum
 from .._exceptions import CASCacheError
 
-from .casdprocessmanager import CASDProcessManager
+from .casdprocessmanager import CASDConnection, CASDProcessManager
 from .casremote import _CASBatchRead, _CASBatchUpdate
 
 _BUFFER_SIZE = 65536
@@ -74,9 +74,6 @@ class CASCache():
         os.makedirs(os.path.join(self.casdir, 'objects'), exist_ok=True)
         os.makedirs(self.tmpdir, exist_ok=True)
 
-        self._casd_channel = None
-        self._casd_cas = None
-        self._local_cas = None
         self._cache_usage_monitor = None
         self._cache_usage_monitor_forbidden = False
 
@@ -107,43 +104,12 @@ class CASCache():
 
         return state
 
-    def _init_casd(self):
-        assert self._casd_process_manager, "CASCache was instantiated without buildbox-casd"
-
-        if not self._casd_channel:
-            while not os.path.exists(self._casd_process_manager.socket_path):
-                # casd is not ready yet, try again after a 10ms delay,
-                # but don't wait for more than 15s
-                if time.time() > self._casd_process_manager.start_time + 15:
-                    raise CASCacheError("Timed out waiting for buildbox-casd to become ready")
-
-                time.sleep(0.01)
-
-            self._casd_channel = grpc.insecure_channel('unix:' + self._casd_process_manager.socket_path)
-            self._casd_cas = remote_execution_pb2_grpc.ContentAddressableStorageStub(self._casd_channel)
-            self._local_cas = local_cas_pb2_grpc.LocalContentAddressableStorageStub(self._casd_channel)
-
-            # Call GetCapabilities() to establish connection to casd
-            capabilities = remote_execution_pb2_grpc.CapabilitiesStub(self._casd_channel)
-            capabilities.GetCapabilities(remote_execution_pb2.GetCapabilitiesRequest())
-
-    # _get_cas():
-    #
-    # Return ContentAddressableStorage stub for buildbox-casd channel.
-    #
-    def _get_cas(self):
-        if not self._casd_cas:
-            self._init_casd()
-        return self._casd_cas
-
-    # _get_local_cas():
+    # get_local_cas():
     #
     # Return LocalCAS stub for buildbox-casd channel.
     #
-    def _get_local_cas(self):
-        if not self._local_cas:
-            self._init_casd()
-        return self._local_cas
+    def get_local_cas(self):
+        return self._casd_process_manager.get_connection().get_local_cas()
 
     # preflight():
     #
@@ -161,18 +127,17 @@ class CASCache():
     # against fork() with open gRPC channels.
     #
     def has_open_grpc_channels(self):
-        return bool(self._casd_channel)
+        if self._casd_process_manager:
+            return self._casd_process_manager.has_open_grpc_channels()
+        return False
 
     # close_grpc_channels():
     #
     # Close the casd channel if it exists
     #
     def close_grpc_channels(self):
-        if self._casd_channel:
-            self._local_cas = None
-            self._casd_cas = None
-            self._casd_channel.close()
-            self._casd_channel = None
+        if self._casd_process_manager:
+            self._casd_process_manager.close_grpc_channels()
 
     # release_resources():
     #
@@ -390,8 +355,7 @@ class CASCache():
 
             request.path.append(path)
 
-            local_cas = self._get_local_cas()
-
+            local_cas = self.get_local_cas()
             response = local_cas.CaptureFiles(request)
 
             if len(response.responses) != 1:
@@ -417,7 +381,7 @@ class CASCache():
     #     (Digest): The digest of the imported directory
     #
     def import_directory(self, path):
-        local_cas = self._get_local_cas()
+        local_cas = self.get_local_cas()
 
         request = local_cas_pb2.CaptureTreeRequest()
         request.path.append(path)
@@ -537,7 +501,7 @@ class CASCache():
     # Returns: List of missing Digest objects
     #
     def remote_missing_blobs(self, remote, blobs):
-        cas = self._get_cas()
+        cas = self._casd_process_manager.get_connection().get_cas()
         instance_name = remote.local_cas_instance_name
 
         missing_blobs = dict()
@@ -1032,7 +996,7 @@ class _CASCacheUsageMonitor:
 
         disk_usage = self._disk_usage
         disk_quota = self._disk_quota
-        local_cas = self.cas._get_local_cas()
+        local_cas = self.cas.get_local_cas()
 
         while True:
             try:
@@ -1071,5 +1035,35 @@ def _grouper(iterable, n):
 #
 class _LimitedCASDProcessManagerProxy:
     def __init__(self, casd_process_manager):
-        self.socket_path = casd_process_manager.socket_path
-        self.start_time = casd_process_manager.start_time
+        self._casd_connection = None
+        self._connection_string = casd_process_manager.connection_string
+        self._start_time = casd_process_manager.start_time
+        self._socket_path = casd_process_manager.socket_path
+
+    # get_connection():
+    #
+    # Return the CASDConnection to buildbox-casd, creating it on first use.
+    #
+    def get_connection(self):
+        if not self._casd_connection:
+            self._casd_connection = CASDConnection(
+                self._socket_path, self._connection_string, self._start_time)
+        return self._casd_connection
+
+    # has_open_grpc_channels():
+    #
+    # Return whether there are gRPC channel instances. This is used to safeguard
+    # against fork() with open gRPC channels.
+    #
+    def has_open_grpc_channels(self):
+        return bool(self._casd_connection)
+
+    # close_grpc_channels():
+    #
+    # Close the casd channel if it exists, and drop the connection so that
+    # has_open_grpc_channels() reports False and get_connection() reconnects.
+    #
+    def close_grpc_channels(self):
+        if self._casd_connection:
+            self._casd_connection.close()
+            self._casd_connection = None
diff --git a/src/buildstream/_cas/casdprocessmanager.py b/src/buildstream/_cas/casdprocessmanager.py
index 3a434ad..c096db1 100644
--- a/src/buildstream/_cas/casdprocessmanager.py
+++ b/src/buildstream/_cas/casdprocessmanager.py
@@ -25,7 +25,13 @@ import subprocess
 import tempfile
 import time
 
+import grpc
+
+from .._protos.build.bazel.remote.execution.v2 import remote_execution_pb2, remote_execution_pb2_grpc
+from .._protos.build.buildgrid import local_cas_pb2_grpc
+
 from .. import _signals, utils
+from .._exceptions import CASCacheError
 from .._message import Message, MessageType
 
 _CASD_MAX_LOGFILES = 10
@@ -47,13 +53,16 @@ class CASDProcessManager:
     def __init__(self, path, log_dir, log_level, cache_quota, protect_session_blobs):
         self._log_dir = log_dir
 
+        self._casd_connection = None
+
         # Place socket in global/user temporary directory to avoid hitting
         # the socket path length limit.
         self._socket_tempdir = tempfile.mkdtemp(prefix='buildstream')
         self.socket_path = os.path.join(self._socket_tempdir, 'casd.sock')
+        self.connection_string = "unix:" + self.socket_path
 
         casd_args = [utils.get_host_tool('buildbox-casd')]
-        casd_args.append('--bind=unix:' + self.socket_path)
+        casd_args.append('--bind=' + self.connection_string)
         casd_args.append('--log-level=' + log_level.value)
 
         if cache_quota is not None:
@@ -215,3 +224,88 @@ class CASDProcessManager:
         assert self._failure_callback is not None
         self._process.returncode = returncode
         self._failure_callback()
+
+    # get_connection():
+    #
+    # Return the CASDConnection to buildbox-casd, creating it on first use.
+    #
+    def get_connection(self):
+        if not self._casd_connection:
+            self._casd_connection = CASDConnection(
+                self.socket_path, self.connection_string, self.start_time)
+        return self._casd_connection
+
+    # has_open_grpc_channels():
+    #
+    # Return whether there are gRPC channel instances. This is used to safeguard
+    # against fork() with open gRPC channels.
+    #
+    def has_open_grpc_channels(self):
+        return bool(self._casd_connection)
+
+    # close_grpc_channels():
+    #
+    # Close the casd channel if it exists, and drop the connection so that
+    # has_open_grpc_channels() reports False and get_connection() reconnects.
+    #
+    def close_grpc_channels(self):
+        if self._casd_connection:
+            self._casd_connection.close()
+            self._casd_connection = None
+
+
+# CASDConnection()
+#
+# A gRPC connection to buildbox-casd, providing stubs for its
+# ContentAddressableStorage and LocalContentAddressableStorage services.
+#
+# Args:
+#    socket_path (str): Path of the casd socket, waited for before connecting
+#    connection_string (str): gRPC target for the channel, e.g. "unix:" + socket_path
+#    start_time (float): Reference time, comparable with time.time(); waiting
+#                        for the socket fails with CASCacheError 15s after it
+#
+class CASDConnection:
+    def __init__(self, socket_path, connection_string, start_time):
+        while not os.path.exists(socket_path):
+            # casd is not ready yet, try again after a 10ms delay,
+            # but don't wait for more than 15s
+            if time.time() > start_time + 15:
+                raise CASCacheError("Timed out waiting for buildbox-casd to become ready")
+
+            time.sleep(0.01)
+
+        self._casd_channel = grpc.insecure_channel(connection_string)
+        self._casd_cas = remote_execution_pb2_grpc.ContentAddressableStorageStub(self._casd_channel)
+        self._local_cas = local_cas_pb2_grpc.LocalContentAddressableStorageStub(self._casd_channel)
+
+        # Call GetCapabilities() to establish connection to casd
+        capabilities = remote_execution_pb2_grpc.CapabilitiesStub(self._casd_channel)
+        capabilities.GetCapabilities(remote_execution_pb2.GetCapabilitiesRequest())
+
+    # get_cas():
+    #
+    # Return ContentAddressableStorage stub for buildbox-casd channel.
+    #
+    def get_cas(self):
+        assert self._casd_channel is not None
+        return self._casd_cas
+
+    # get_local_cas():
+    #
+    # Return LocalCAS stub for buildbox-casd channel.
+    #
+    def get_local_cas(self):
+        assert self._casd_channel is not None
+        return self._local_cas
+
+    # close():
+    #
+    # Close the casd channel.
+    #
+    def close(self):
+        assert self._casd_channel is not None
+        self._local_cas = None
+        self._casd_cas = None
+        self._casd_channel.close()
+        self._casd_channel = None
diff --git a/src/buildstream/_cas/casremote.py b/src/buildstream/_cas/casremote.py
index a054b28..c89ea9f 100644
--- a/src/buildstream/_cas/casremote.py
+++ b/src/buildstream/_cas/casremote.py
@@ -55,7 +55,7 @@ class CASRemote(BaseRemote):
     # be called outside of init().
     #
     def _configure_protocols(self):
-        local_cas = self.cascache._get_local_cas()
+        local_cas = self.cascache.get_local_cas()
         request = local_cas_pb2.GetInstanceNameForRemoteRequest()
         request.url = self.spec.url
         if self.spec.instance_name:
@@ -115,7 +115,7 @@ class _CASBatchRead():
         if not self._requests:
             return
 
-        local_cas = self._remote.cascache._get_local_cas()
+        local_cas = self._remote.cascache.get_local_cas()
 
         for request in self._requests:
             batch_response = local_cas.FetchMissingBlobs(request)
@@ -163,7 +163,7 @@ class _CASBatchUpdate():
         if not self._requests:
             return
 
-        local_cas = self._remote.cascache._get_local_cas()
+        local_cas = self._remote.cascache.get_local_cas()
 
         for request in self._requests:
             batch_response = local_cas.UploadMissingBlobs(request)