You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by li...@apache.org on 2023/06/21 14:06:31 UTC

[arrow-adbc] branch main updated: refactor(csharp): cleanup load of imported drivers (#818)

This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new 2f7fbdc6 refactor(csharp): cleanup load of imported drivers (#818)
2f7fbdc6 is described below

commit 2f7fbdc67a1c062484f5ccc42664d8e1d37a1864
Author: Curt Hagenlocher <cu...@hagenlocher.org>
AuthorDate: Wed Jun 21 07:06:23 2023 -0700

    refactor(csharp): cleanup load of imported drivers (#818)
    
    Resolves #753
---
 .../src/Apache.Arrow.Adbc/Apache.Arrow.Adbc.csproj |   6 +-
 .../src/Apache.Arrow.Adbc/C/CAdbcDriverImporter.cs | 144 ++++++++-------------
 csharp/src/Apache.Arrow.Adbc/C/NativeLibrary.cs    | 102 +++++++++++++++
 3 files changed, 158 insertions(+), 94 deletions(-)

diff --git a/csharp/src/Apache.Arrow.Adbc/Apache.Arrow.Adbc.csproj b/csharp/src/Apache.Arrow.Adbc/Apache.Arrow.Adbc.csproj
index 9276cc1b..c49113e7 100644
--- a/csharp/src/Apache.Arrow.Adbc/Apache.Arrow.Adbc.csproj
+++ b/csharp/src/Apache.Arrow.Adbc/Apache.Arrow.Adbc.csproj
@@ -1,4 +1,4 @@
-<Project Sdk="Microsoft.NET.Sdk">
+<Project Sdk="Microsoft.NET.Sdk">
 
   <PropertyGroup>
     <TargetFrameworks>netstandard2.0;net6.0</TargetFrameworks>
@@ -9,4 +9,8 @@
     <ProjectReference Include="..\arrow\csharp\src\Apache.Arrow\Apache.Arrow.csproj" />
   </ItemGroup>
 
+  <ItemGroup Condition="$([MSBuild]::IsTargetFrameworkCompatible($(TargetFramework), 'net5.0'))">
+    <Compile Remove="C\NativeLibrary.cs" />
+  </ItemGroup>
+
 </Project>
diff --git a/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverImporter.cs b/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverImporter.cs
index c13121f3..792873e0 100644
--- a/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverImporter.cs
+++ b/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverImporter.cs
@@ -17,9 +17,9 @@
 
 using System;
 using System.Collections.Generic;
+using System.IO;
 using System.Runtime.InteropServices;
 using Apache.Arrow.C;
-using Microsoft.Win32.SafeHandles;
 
 #if NETSTANDARD
 using Apache.Arrow.Adbc.Extensions;
@@ -27,117 +27,62 @@ using Apache.Arrow.Adbc.Extensions;
 
 namespace Apache.Arrow.Adbc.C
 {
-    internal delegate byte AdbcDriverInit(int version, ref CAdbcDriver driver, ref CAdbcError error);
+    internal delegate AdbcStatusCode AdbcDriverInit(int version, ref CAdbcDriver driver, ref CAdbcError error);
 
     /// <summary>
-    /// Class for working with loading drivers from files
+    /// Class for working with imported drivers from files
     /// </summary>
     public static class CAdbcDriverImporter
     {
         private const string driverInit = "AdbcDriverInit";
         private const int ADBC_VERSION_1_0_0 = 1000000;
 
-        class NativeDriver
-        {
-            public SafeHandle driverHandle;
-            public CAdbcDriver driver;
-        }
-
         /// <summary>
-        /// Class used for Mac interoperability
+        /// Loads an <see cref="AdbcDriver"/> from the file system.
         /// </summary>
-        static class MacInterop
+        /// <param name="file">The path to the driver to load</param>
+        /// <param name="entryPoint">The name of the entry point. If not provided, the name AdbcDriverInit will be used.</param>
+        public static AdbcDriver Load(string file, string entryPoint = null)
         {
-            private const string libdl = "libdl.dylib";
-            private const int RTLD_NOW = 2;
-
-            [DllImport(libdl)]
-            static extern SafeLibraryHandle dlopen(string fileName, int flags);
-
-            [DllImport(libdl)]
-            static extern IntPtr dlsym(SafeHandle libraryHandle, string symbol);
-
-            [DllImport(libdl)]
-            static extern int dlclose(IntPtr handle);
-
-            sealed class SafeLibraryHandle : SafeHandleZeroOrMinusOneIsInvalid
+            if (file == null)
             {
-                SafeLibraryHandle() : base(true) { }
-
-                protected override bool ReleaseHandle()
-                {
-                    return dlclose(handle) == 0;
-                }
+                throw new ArgumentNullException(nameof(file));
             }
 
-            public static NativeDriver GetDriver(string file)
+            if (!File.Exists(file))
             {
-                SafeHandle library = dlopen(file, RTLD_NOW);
-                IntPtr symbol = dlsym(library, "AdbcDriverInit");
-                AdbcDriverInit init = Marshal.GetDelegateForFunctionPointer<AdbcDriverInit>(symbol);
-                CAdbcDriver driver = new CAdbcDriver();
-                CAdbcError error = new CAdbcError();
-                byte result = init(ADBC_VERSION_1_0_0, ref driver, ref error);
-                return new NativeDriver { driverHandle = library, driver = driver };
+                throw new ArgumentException("file does not exist", nameof(file));
             }
-        }
-
-        /// <summary>
-        /// Class used for Windows interoperability
-        /// </summary>
-        static class WindowsInterop
-        {
-            private const string kernel32 = "kernel32.dll";
-            private const int LOAD_LIBRARY_SEARCH_DEFAULT_DIRS = 0x1000;
-            private const int LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR = 0x0100;
-
-            [DllImport(kernel32)]
-            [return: MarshalAs(UnmanagedType.Bool)]
-            static extern bool FreeLibrary(IntPtr libraryHandle);
-
-            [DllImport(kernel32, CharSet = CharSet.Ansi, BestFitMapping = false, ThrowOnUnmappableChar = true)]
-            static extern IntPtr GetProcAddress(SafeHandle libraryHandle, string functionName);
 
-            [DllImport(kernel32, CharSet = CharSet.Unicode, SetLastError = true)]
-            static extern SafeLibraryHandle LoadLibraryEx(string fileName, IntPtr hFile, uint flags);
-
-            sealed class SafeLibraryHandle : SafeHandleZeroOrMinusOneIsInvalid
+            IntPtr library = NativeLibrary.Load(file);
+            if (library == IntPtr.Zero)
             {
-                SafeLibraryHandle() : base(true) { }
+                throw new ArgumentException("unable to load library", nameof(file));
+            }
 
-                protected override bool ReleaseHandle()
+            try
+            {
+                entryPoint = entryPoint ?? driverInit;
+                IntPtr export = NativeLibrary.GetExport(library, entryPoint);
+                if (export == IntPtr.Zero)
                 {
-                    return FreeLibrary(handle);
+                    // Do not free here: `library` is still nonzero, so the finally block releases it exactly once.
+                    throw new ArgumentException($"Unable to find {entryPoint} export", nameof(file));
                 }
-            }
 
-            public static NativeDriver GetDriver(string file)
-            {
-                SafeHandle library = LoadLibraryEx(file, IntPtr.Zero, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR);
-                IntPtr symbol = GetProcAddress(library, "AdbcDriverInit");
-                AdbcDriverInit init = Marshal.GetDelegateForFunctionPointer<AdbcDriverInit>(symbol);
+                AdbcDriverInit init = Marshal.GetDelegateForFunctionPointer<AdbcDriverInit>(export);
                 CAdbcDriver driver = new CAdbcDriver();
-                CAdbcError error = new CAdbcError();
-                byte result = init(ADBC_VERSION_1_0_0, ref driver, ref error);
-                return new NativeDriver { /* driverHandle = library, */ driver = driver };
-            }
-        }
-
-        /// <summary>
-        /// Loads an <see cref="AdbcDriver"/> from the file system.
-        /// </summary>
-        /// <param name="file">
-        /// The path to the file.
-        /// </param>
-        public static AdbcDriver Load(string file)
-        {
-            if (file[0] == '/')
-            {
-                return new ImportedAdbcDriver(MacInterop.GetDriver(file).driver);
+                using (CallHelper caller = new CallHelper())
+                {
+                    caller.Call(init, ADBC_VERSION_1_0_0, ref driver);
+                    ImportedAdbcDriver result = new ImportedAdbcDriver(library, driver);
+                    library = IntPtr.Zero;
+                    return result;
+                }
             }
-            else
+            finally
             {
-                return new ImportedAdbcDriver(WindowsInterop.GetDriver(file).driver);
+                if (library != IntPtr.Zero) { NativeLibrary.Free(library); }
             }
         }
 
@@ -146,10 +91,12 @@ namespace Apache.Arrow.Adbc.C
         /// </summary>
         sealed class ImportedAdbcDriver : AdbcDriver
         {
+            private IntPtr _library;
             private CAdbcDriver _nativeDriver;
 
-            public ImportedAdbcDriver(CAdbcDriver nativeDriver)
+            public ImportedAdbcDriver(IntPtr library, CAdbcDriver nativeDriver)
             {
+                _library = library;
                 _nativeDriver = nativeDriver;
             }
 
@@ -196,6 +143,10 @@ namespace Apache.Arrow.Adbc.C
                             _nativeDriver.release = null;
                         }
                     }
+
+                    NativeLibrary.Free(_library);
+                    _library = IntPtr.Zero;
+
                     base.Dispose();
                 }
             }
@@ -221,6 +172,8 @@ namespace Apache.Arrow.Adbc.C
 
                 using (CallHelper caller = new CallHelper())
                 {
+                    caller.Call(_nativeDriver.ConnectionNew, ref nativeConnection);
+
                     if (options != null)
                     {
                         foreach (KeyValuePair<string, string> pair in options)
@@ -313,14 +266,14 @@ namespace Apache.Arrow.Adbc.C
         /// <summary>
         /// Assists with UTF8/string marshalling
         /// </summary>
-        struct Utf8Helper : IDisposable
+        private struct Utf8Helper : IDisposable
         {
             private IntPtr _s;
 
             public Utf8Helper(string s)
             {
 #if NETSTANDARD
-                    _s = MarshalExtensions.StringToCoTaskMemUTF8(s);
+                _s = MarshalExtensions.StringToCoTaskMemUTF8(s);
 #else
                 _s = Marshal.StringToCoTaskMemUTF8(s);
 #endif
@@ -333,10 +286,15 @@ namespace Apache.Arrow.Adbc.C
         /// <summary>
         /// Assists with delegate calls and handling error codes
         /// </summary>
-        struct CallHelper : IDisposable
+        private struct CallHelper : IDisposable
         {
             private CAdbcError _error;
 
+            public unsafe void Call(AdbcDriverInit init, int version, ref CAdbcDriver driver)
+            {
+                TranslateCode(init(version, ref driver, ref this._error));
+            }
+
             public unsafe void Call(delegate* unmanaged[Stdcall]<CAdbcDriver*, CAdbcError*, AdbcStatusCode> fn, ref CAdbcDriver nativeDriver)
             {
                 fixed (CAdbcDriver* driver = &nativeDriver)
@@ -467,7 +425,7 @@ namespace Apache.Arrow.Adbc.C
                 }
             }
 
-            internal unsafe void TranslateCode(AdbcStatusCode statusCode)
+            private unsafe void TranslateCode(AdbcStatusCode statusCode)
             {
                 if (statusCode != AdbcStatusCode.Success)
                 {
@@ -475,7 +433,7 @@ namespace Apache.Arrow.Adbc.C
                     if ((IntPtr)_error.message != IntPtr.Zero)
                     {
 #if NETSTANDARD
-                            message = MarshalExtensions.PtrToStringUTF8((IntPtr)_error.message);
+                        message = MarshalExtensions.PtrToStringUTF8((IntPtr)_error.message);
 #else
                         message = Marshal.PtrToStringUTF8((IntPtr)_error.message);
 #endif
diff --git a/csharp/src/Apache.Arrow.Adbc/C/NativeLibrary.cs b/csharp/src/Apache.Arrow.Adbc/C/NativeLibrary.cs
new file mode 100644
index 00000000..d15e665b
--- /dev/null
+++ b/csharp/src/Apache.Arrow.Adbc/C/NativeLibrary.cs
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+using System;
+using System.Runtime.InteropServices;
+
+namespace Apache.Arrow.Adbc.C
+{
+    internal static class NativeLibrary
+    {
+        static readonly Loader _loader;
+
+        static NativeLibrary()
+        {
+            if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
+            {
+                _loader = new Windows();
+            }
+            else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX))
+            {
+                _loader = new OSX();
+            }
+            else
+            {
+                _loader = new Unsupported();
+            }
+        }
+
+        private interface Loader
+        {
+            IntPtr Load(string libraryPath);
+            IntPtr GetExport(IntPtr handle, string name);
+            void Free(IntPtr handle);
+        }
+
+        sealed private class OSX : Loader
+        {
+            private const string libdl = "libdl.dylib";
+            private const int RTLD_NOW = 2;
+
+            [DllImport(libdl)]
+            static extern IntPtr dlopen(string fileName, int flags);
+
+            [DllImport(libdl)]
+            static extern IntPtr dlsym(IntPtr libraryHandle, string symbol);
+
+            [DllImport(libdl)]
+            static extern int dlclose(IntPtr handle);
+
+            public IntPtr Load(string libraryPath) => dlopen(libraryPath, RTLD_NOW);
+            public IntPtr GetExport(IntPtr handle, string name) => dlsym(handle, name);
+            public void Free(IntPtr handle) => dlclose(handle);
+        }
+
+        sealed private class Windows : Loader
+        {
+            private const string kernel32 = "kernel32.dll";
+            private const int LOAD_LIBRARY_SEARCH_DEFAULT_DIRS = 0x1000;
+            private const int LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR = 0x0100;
+
+            [DllImport(kernel32)]
+            [return: MarshalAs(UnmanagedType.Bool)]
+            static extern bool FreeLibrary(IntPtr libraryHandle);
+
+            [DllImport(kernel32, CharSet = CharSet.Ansi, BestFitMapping = false, ThrowOnUnmappableChar = true)]
+            static extern IntPtr GetProcAddress(IntPtr libraryHandle, string functionName);
+
+            [DllImport(kernel32, CharSet = CharSet.Unicode, SetLastError = true)]
+            static extern IntPtr LoadLibraryEx(string fileName, IntPtr hFile, uint flags);
+
+            public IntPtr Load(string fileName) =>
+                LoadLibraryEx(fileName, IntPtr.Zero, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR);
+            public IntPtr GetExport(IntPtr handle, string name) => GetProcAddress(handle, name);
+            public void Free(IntPtr handle) => FreeLibrary(handle);
+        }
+
+        sealed private class Unsupported : Loader
+        {
+            public void Free(IntPtr handle) => throw new NotSupportedException();
+            public IntPtr GetExport(IntPtr handle, string name) => throw new NotSupportedException();
+            public IntPtr Load(string libraryPath) => throw new NotSupportedException();
+        }
+
+        public static IntPtr Load(string fileName) => _loader.Load(fileName);
+        public static IntPtr GetExport(IntPtr handle, string name) => _loader.GetExport(handle, name);
+        public static void Free(IntPtr handle) => _loader.Free(handle);
+    }
+}