You are viewing a plain-text version of this content. The link to the canonical version ("here") was not preserved in this plain-text copy.
Posted to commits@drill.apache.org by cg...@apache.org on 2020/10/29 12:01:09 UTC

[drill] branch master updated: DRILL-7795: Add support for overriding hostname to C++ client

This is an automated email from the ASF dual-hosted git repository.

cgivre pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/drill.git


The following commit(s) were added to refs/heads/master by this push:
     new 264ac06  DRILL-7795: Add support for overriding hostname to C++ client
264ac06 is described below

commit 264ac06831041498e51af9bc296ede9f0834b3f8
Author: James Duong <jd...@dremio.com>
AuthorDate: Fri Oct 9 14:47:48 2020 -0700

    DRILL-7795: Add support for overriding hostname to C++ client
    
    * Add the hostnameOverride property to specify the expected hostname in the SSL certificate.
    * Set the TLS SNI property to the host field even if the hostname wasn't overridden.
    * Update querySubmitter example to include this new property.
    * Change querySubmitter so that the number of options is detected from the options array
    rather than being a constant that requires manual updates.
---
 contrib/native/client/example/querySubmitter.cpp   | 15 +++++---
 contrib/native/client/src/clientlib/channel.cpp    | 22 +++++++----
 contrib/native/client/src/clientlib/channel.hpp    | 44 ++++++++++++++++++++--
 contrib/native/client/src/clientlib/errmsgs.cpp    |  1 +
 contrib/native/client/src/clientlib/errmsgs.hpp    |  7 +++-
 .../native/client/src/clientlib/userProperties.cpp |  1 +
 contrib/native/client/src/include/drill/common.hpp |  1 +
 7 files changed, 72 insertions(+), 19 deletions(-)

diff --git a/contrib/native/client/example/querySubmitter.cpp b/contrib/native/client/example/querySubmitter.cpp
index 1ca7668..7c66613 100644
--- a/contrib/native/client/example/querySubmitter.cpp
+++ b/contrib/native/client/example/querySubmitter.cpp
@@ -23,8 +23,6 @@
 #include <boost/algorithm/string/join.hpp>
 #include "drill/drillc.hpp"
 
-int nOptions=27;
-
 struct Option{
     char name[32];
     char desc[128];
@@ -56,7 +54,8 @@ struct Option{
     {"disableCertVerification", "disable certificate verification", false},
     {"useSystemTrustStore", "[Windows only]. Use the system truststore.", false},
     {"CustomSSLCtxOptions", "The custom SSL CTX Options", false},
-    {"supportComplexTypes", "Toggle for supporting complex types", false}
+    {"supportComplexTypes", "Toggle for supporting complex types", false},
+    {"hostnameOverride", "Override the SSL server hostname", false}
 };
 
 std::map<std::string, std::string> qsOptionValues;
@@ -165,7 +164,7 @@ void print(const Drill::FieldMetadata* pFieldMetadata, void* buf, size_t sz){
 
 void printUsage(){
     std::cerr<<"Usage: querySubmitter ";
-    for(int j=0; j<nOptions ;j++){
+    for(int j=0; j<sizeof(qsOptions)/sizeof(qsOptions[0]) ;j++){
         std::cerr<< " "<< qsOptions[j].name <<"="  << "[" <<qsOptions[j].desc <<"]" ;
     }
     std::cerr<<std::endl;
@@ -179,7 +178,7 @@ int parseArgs(int argc, char* argv[]){
         char*v=strtok(NULL, "");
 
         bool found=false;
-        for(int j=0; j<nOptions ;j++){
+        for(int j=0; j<sizeof(qsOptions)/sizeof(qsOptions[0]) ;j++){
             if(!strcmp(qsOptions[j].name, o)){
                 found=true; break;
             }
@@ -196,7 +195,7 @@ int parseArgs(int argc, char* argv[]){
         qsOptionValues[o]=v;
     }
 
-    for(int j=0; j<nOptions ;j++){
+    for(int j=0; j<sizeof(qsOptions)/sizeof(qsOptions[0]) ;j++){
         if(qsOptions[j].required ){
             if(qsOptionValues.find(qsOptions[j].name) == qsOptionValues.end()){
                 std::cerr<< ""<< qsOptions[j].name << " [" <<qsOptions[j].desc <<"] " << "is required." << std::endl;
@@ -318,6 +317,7 @@ int main(int argc, char* argv[]) {
         std::string useSystemTrustStore = qsOptionValues["useSystemTrustStore"];
         std::string customSSLOptions = qsOptionValues["CustomSSLCtxOptions"];
         std::string supportComplexTypes = qsOptionValues["supportComplexTypes"];
+        std::string hostnameOverride = qsOptionValues["hostnameOverride"];
 
         Drill::QueryType type;
 
@@ -426,6 +426,9 @@ int main(int argc, char* argv[]) {
         if (supportComplexTypes.length() > 0){
             props.setProperty(USERPROP_SUPPORT_COMPLEX_TYPES, supportComplexTypes);
         }
+        if (hostnameOverride.length() > 0) {
+            props.setProperty(USERPROP_HOSTNAME_OVERRIDE, hostnameOverride);
+        }
 
         if(client.connect(connectStr.c_str(), &props)!=Drill::CONN_SUCCESS){
             std::cerr<< "Failed to connect with error: "<< client.getError() << " (Using:"<<connectStr<<")"<<std::endl;
diff --git a/contrib/native/client/src/clientlib/channel.cpp b/contrib/native/client/src/clientlib/channel.cpp
index 5e96388..53c9f0a 100644
--- a/contrib/native/client/src/clientlib/channel.cpp
+++ b/contrib/native/client/src/clientlib/channel.cpp
@@ -210,6 +210,9 @@ ChannelContext* ChannelFactory::getChannelContext(channelType_t t, DrillUserProp
                 verifyMode = boost::asio::ssl::context::verify_none;
             }
 
+            std::string hostnameOverride;
+            props->getProp(USERPROP_HOSTNAME_OVERRIDE, hostnameOverride);
+
             long customSSLCtxOptions = SSLChannelContext::ApplyMinTLSRestriction(protocol);
             std::string sslOptions;
             props->getProp(USERPROP_CUSTOM_SSLCTXOPTIONS, sslOptions);
@@ -222,7 +225,7 @@ ChannelContext* ChannelFactory::getChannelContext(channelType_t t, DrillUserProp
                  }
             }
 
-            pChannelContext = new SSLChannelContext(props, tlsVersion, verifyMode, customSSLCtxOptions);
+            pChannelContext = new SSLChannelContext(props, tlsVersion, verifyMode, hostnameOverride, customSSLCtxOptions);
         }
             break;
 #endif
@@ -276,20 +279,23 @@ connectionStatus_t Channel::init(){
 }
 
 connectionStatus_t Channel::connect(){
-    connectionStatus_t ret=CONN_FAILURE;
-    if(this->m_state==CHANNEL_INITIALIZED){
-        ret=m_pEndpoint->getDrillbitEndpoint();
-        if(ret==CONN_SUCCESS){
+    connectionStatus_t ret = CONN_FAILURE;
+    if (this->m_state == CHANNEL_INITIALIZED){
+        ret = m_pEndpoint->getDrillbitEndpoint();
+        if (ret == CONN_SUCCESS){
             DRILL_LOG(LOG_TRACE) << "Connecting to drillbit: " 
                 << m_pEndpoint->getHost() 
                 << ":" << m_pEndpoint->getPort() 
                 << "." << std::endl;
-            ret=this->connectInternal();
-        }else{
+            ret = this->setSocketInformation();
+            if (ret == CONN_SUCCESS) {
+                ret = this->connectInternal();
+            }
+        } else {
             handleError(ret, m_pEndpoint->getError()->msg);
         }
     }
-    this->m_state=(ret==CONN_SUCCESS)?CHANNEL_CONNECTED:this->m_state;
+    this->m_state = (ret == CONN_SUCCESS) ? CHANNEL_CONNECTED : this->m_state;
     return ret;
 }
 
diff --git a/contrib/native/client/src/clientlib/channel.hpp b/contrib/native/client/src/clientlib/channel.hpp
index 7d4ad60..f7dfe3e 100644
--- a/contrib/native/client/src/clientlib/channel.hpp
+++ b/contrib/native/client/src/clientlib/channel.hpp
@@ -94,7 +94,7 @@ class UserProperties;
             /// @brief Applies Minimum TLS protocol restrictions. 
             ///         tlsv11+ means restrict to TLS version 1.1 and higher.
             ///         tlsv12+ means restrict to TLS version 1.2 and higher.
-            ///  Please note that SSL_OP_NO_TLSv tags are depreecated in openSSL 1.1.0.
+            ///  Please note that SSL_OP_NO_TLSv tags are deprecated in openSSL 1.1.0.
             /// 
             /// @param in_ver               The protocol version.
             /// 
@@ -113,9 +113,11 @@ class UserProperties;
         SSLChannelContext(DrillUserProperties *props,
                           boost::asio::ssl::context::method tlsVersion,
                           boost::asio::ssl::verify_mode verifyMode,
+                          const std::string& hostnameOverride,
                           const long customSSLCtxOptions = 0) :
                     ChannelContext(props),
                     m_SSLContext(tlsVersion),
+                    m_hostnameOverride(hostnameOverride),
                     m_certHostnameVerificationStatus(true) 
             {
                 m_SSLContext.set_default_verify_paths();
@@ -142,9 +144,17 @@ class UserProperties;
             /// @param in_result                The host name verification status.
             void SetCertHostnameVerificationStatus(bool in_result) { m_certHostnameVerificationStatus = in_result; }
 
+            /// @brief Returns the overridden hostname used for certificate verification
+            ///
+            /// @return the hostname override, or empty if the hostname should not be overridden.
+            const std::string& GetHostnameOverride() { return m_hostnameOverride; }
+
         private:
             boost::asio::ssl::context m_SSLContext;
 
+            // The hostname to verify. Unused if empty.
+            std::string m_hostnameOverride;
+
             // The flag to indicate the host name verification result.
             bool m_certHostnameVerificationStatus;
     };
@@ -210,6 +220,10 @@ class UserProperties;
                 return handleError(CONN_HANDSHAKE_FAILED, in_err.what());
             }
 
+            virtual connectionStatus_t setSocketInformation() {
+                return CONN_SUCCESS;
+            }
+
             boost::asio::io_service& m_ioService;
             boost::asio::io_service m_ioServiceFallback; // used if m_ioService is not provided
             AsioStreamSocket* m_pSocket;
@@ -291,6 +305,21 @@ class UserProperties;
                         getMessage(ERR_CONN_SSL_GENERAL, in_err.what()));
                 }
             }
+
+            connectionStatus_t setSocketInformation() {
+                const char* sniProperty;
+                SSLChannelContext_t& context = *((SSLChannelContext_t *)m_pContext);
+                if (!context.GetHostnameOverride().empty()){
+                    sniProperty = context.GetHostnameOverride().c_str();
+                }
+                else{
+                    sniProperty = m_pEndpoint->getHost().c_str();
+                }
+                if (!SSL_set_tlsext_host_name(((SslSocket *)m_pSocket)->getSocketStream().native_handle(), sniProperty)) {
+                    return handleError(CONN_SSLERROR, getMessage(ERR_CONN_SSL_SNI, sniProperty, ERR_func_error_string(ERR_get_error())));
+                }
+                return CONN_SUCCESS;
+            }
 #endif
     };
 
@@ -332,9 +361,16 @@ class UserProperties;
                 // Gets the channel context.
                 SSLChannelContext_t* context = (SSLChannelContext_t*)(m_channel->getChannelContext());
 
-                // Retrieve the host before we perform Host name verification.
-                // This is because host with ZK mode is selected after the connect() function is called.
-                boost::asio::ssl::rfc2818_verification verifier(m_channel->getEndpoint()->getHost().c_str());
+                const char* hostname;
+                if (context->GetHostnameOverride().empty()) {
+                    // Retrieve the host before we perform Host name verification.
+                    // This is because host with ZK mode is selected after the connect() function is called.
+                    hostname = m_channel->getEndpoint()->getHost().c_str();
+                } else {
+                    hostname = context->GetHostnameOverride().c_str();
+                }
+
+                boost::asio::ssl::rfc2818_verification verifier(hostname);
 
                 // Perform verification.
                 bool verified = verifier(in_preverified, in_ctx);
diff --git a/contrib/native/client/src/clientlib/errmsgs.cpp b/contrib/native/client/src/clientlib/errmsgs.cpp
index 5ab8d8e..f3daa4d 100644
--- a/contrib/native/client/src/clientlib/errmsgs.cpp
+++ b/contrib/native/client/src/clientlib/errmsgs.cpp
@@ -61,6 +61,7 @@ static Drill::ErrorMessages errorMessages[]={
     {ERR_CONN_SSL_CN, ERR_CATEGORY_CONN, 0, "SSL certificate host name verification failure. [Details: %s]" },
     {ERR_CONN_SSL_CERTVERIFY, ERR_CATEGORY_CONN, 0, "SSL certificate verification failed. [Details: %s]"},
     {ERR_CONN_SSL_PROTOVER, ERR_CATEGORY_CONN, 0, "Unsupported TLS protocol version. [Details: %s]" },
+    {ERR_CONN_SSL_SNI, ERR_CATEGORY_CONN, 0, "Failed to set TLS SNI. Host: %s [Details: %s]"},
     {ERR_QRY_OUTOFMEM, ERR_CATEGORY_QRY, 0, "Out of memory."},
     {ERR_QRY_COMMERR, ERR_CATEGORY_QRY, 0, "Communication error. %s"},
     {ERR_QRY_INVREADLEN, ERR_CATEGORY_QRY, 0, "Internal Error: Received a message with an invalid read length."},
diff --git a/contrib/native/client/src/clientlib/errmsgs.hpp b/contrib/native/client/src/clientlib/errmsgs.hpp
index 7230611..465adb5 100644
--- a/contrib/native/client/src/clientlib/errmsgs.hpp
+++ b/contrib/native/client/src/clientlib/errmsgs.hpp
@@ -59,7 +59,10 @@ namespace Drill{
 #define ERR_CONN_SSL_CN         DRILL_ERR_START+27
 #define ERR_CONN_SSL_CERTVERIFY DRILL_ERR_START+28
 #define ERR_CONN_SSL_PROTOVER   DRILL_ERR_START+29
-#define ERR_CONN_MAX            DRILL_ERR_START+29
+#define ERR_CONN_SSL_SNI        DRILL_ERR_START+30
+
+// This should be the same as the largest ERR_CONN_* code.
+#define ERR_CONN_MAX            DRILL_ERR_START+30
 
 #define ERR_QRY_OUTOFMEM    ERR_CONN_MAX+1
 #define ERR_QRY_COMMERR     ERR_CONN_MAX+2
@@ -81,6 +84,8 @@ namespace Drill{
 #define ERR_QRY_18          ERR_CONN_MAX+18
 #define ERR_QRY_19          ERR_CONN_MAX+19
 #define ERR_QRY_20          ERR_CONN_MAX+20
+
+// This should be the same as the largest ERR_QRY_* code.
 #define ERR_QRY_MAX         ERR_QRY_20
 
     // Use only Plain Old Data types in this struc. We will declare
diff --git a/contrib/native/client/src/clientlib/userProperties.cpp b/contrib/native/client/src/clientlib/userProperties.cpp
index 0ad8af1..d5d96ea 100644
--- a/contrib/native/client/src/clientlib/userProperties.cpp
+++ b/contrib/native/client/src/clientlib/userProperties.cpp
@@ -32,6 +32,7 @@ const std::map<std::string, uint32_t>  DrillUserProperties::USER_PROPERTIES=boos
     ( USERPROP_USESSL,      USERPROP_FLAGS_BOOLEAN|USERPROP_FLAGS_SSLPROP)
     ( USERPROP_TLSPROTOCOL,      USERPROP_FLAGS_STRING|USERPROP_FLAGS_SSLPROP)
     ( USERPROP_CERTFILEPATH,    USERPROP_FLAGS_STRING|USERPROP_FLAGS_SSLPROP|USERPROP_FLAGS_FILEPATH)
+    ( USERPROP_HOSTNAME_OVERRIDE,    USERPROP_FLAGS_STRING|USERPROP_FLAGS_SSLPROP)
     ( USERPROP_DISABLE_HOSTVERIFICATION,    USERPROP_FLAGS_BOOLEAN|USERPROP_FLAGS_SSLPROP)
     ( USERPROP_DISABLE_CERTVERIFICATION,    USERPROP_FLAGS_BOOLEAN|USERPROP_FLAGS_SSLPROP)
     ( USERPROP_USESYSTEMTRUSTSTORE,    USERPROP_FLAGS_BOOLEAN|USERPROP_FLAGS_SSLPROP)
diff --git a/contrib/native/client/src/include/drill/common.hpp b/contrib/native/client/src/include/drill/common.hpp
index 9f57446..8bd15f4 100644
--- a/contrib/native/client/src/include/drill/common.hpp
+++ b/contrib/native/client/src/include/drill/common.hpp
@@ -180,6 +180,7 @@ typedef enum{
 // #define USERPROP_CERTPASSWORD "certPassword" // Password for certificate file. 
 #define USERPROP_DISABLE_HOSTVERIFICATION "disableHostVerification"
 #define USERPROP_DISABLE_CERTVERIFICATION "disableCertVerification"
+#define USERPROP_HOSTNAME_OVERRIDE "hostnameOverride" //The hostname to verify in the SSL Certificate.
 #define USERPROP_USESYSTEMTRUSTSTORE "useSystemTrustStore" //Windows only, use the system trust store
 #define USERPROP_IMPERSONATION_TARGET "impersonation_target"
 #define USERPROP_AUTH_MECHANISM "auth"