You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@rocketmq.apache.org by ti...@apache.org on 2022/07/01 10:15:07 UTC

[rocketmq-mqtt] branch main updated: support tls tcp server

This is an automated email from the ASF dual-hosted git repository.

tianliuliu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/rocketmq-mqtt.git


The following commit(s) were added to refs/heads/main by this push:
     new c8922c8  support tls tcp server
     new 0d8ecf3  Merge pull request #115 from ChangingFond/tls-server
c8922c8 is described below

commit c8922c8a90a5b21ab34c33ba9a8e1d8d01fbbc61
Author: ChangingFond <78...@qq.com>
AuthorDate: Sat Jun 11 19:37:14 2022 +0800

    support tls tcp server
---
 .../rocketmq/mqtt/cs/config/ConnectConf.java       | 15 ++++
 .../rocketmq/mqtt/cs/protocol/ssl/SslFactory.java  | 83 ++++++++++++++++++++++
 .../rocketmq/mqtt/cs/starter/MqttServer.java       | 36 +++++++++-
 3 files changed, 132 insertions(+), 2 deletions(-)

diff --git a/mqtt-cs/src/main/java/org/apache/rocketmq/mqtt/cs/config/ConnectConf.java b/mqtt-cs/src/main/java/org/apache/rocketmq/mqtt/cs/config/ConnectConf.java
index 72330e8..c15b2a3 100644
--- a/mqtt-cs/src/main/java/org/apache/rocketmq/mqtt/cs/config/ConnectConf.java
+++ b/mqtt-cs/src/main/java/org/apache/rocketmq/mqtt/cs/config/ConnectConf.java
@@ -33,7 +33,10 @@ public class ConnectConf {
     private int nettySelectThreadNum = 1;
     private int nettyWorkerThreadNum = Runtime.getRuntime().availableProcessors() * 2;
     private int mqttPort = 1883;
+    private int mqttTlsPort = 8883;
     private int mqttWsPort = 8888;
+    private boolean enableTlsSever = false;
+    private boolean needClientAuth = false;
     private int maxPacketSizeInByte = 64 * 1024;
     private int highWater = 256 * 1024;
     private int lowWater = 16 * 1024;
@@ -83,6 +86,10 @@ public class ConnectConf {
         return mqttPort;
     }
 
+    public int getMqttTlsPort() {
+        return mqttTlsPort;
+    }
+
     public void setMqttPort(int mqttPort) {
         this.mqttPort = mqttPort;
     }
@@ -91,6 +98,14 @@ public class ConnectConf {
         return mqttWsPort;
     }
 
+    public boolean isEnableTlsSever() {
+        return enableTlsSever;
+    }
+
+    public boolean isNeedClientAuth() {
+        return needClientAuth;
+    }
+
     public void setMqttWsPort(int mqttWsPort) {
         this.mqttWsPort = mqttWsPort;
     }
diff --git a/mqtt-cs/src/main/java/org/apache/rocketmq/mqtt/cs/protocol/ssl/SslFactory.java b/mqtt-cs/src/main/java/org/apache/rocketmq/mqtt/cs/protocol/ssl/SslFactory.java
new file mode 100644
index 0000000..c869acf
--- /dev/null
+++ b/mqtt-cs/src/main/java/org/apache/rocketmq/mqtt/cs/protocol/ssl/SslFactory.java
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.rocketmq.mqtt.cs.protocol.ssl;
+
+import io.netty.channel.socket.SocketChannel;
+import io.netty.handler.ssl.ClientAuth;
+import io.netty.handler.ssl.SslContext;
+import io.netty.handler.ssl.SslContextBuilder;
+import io.netty.handler.ssl.SslProvider;
+import org.apache.rocketmq.mqtt.cs.config.ConnectConf;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.springframework.core.io.ClassPathResource;
+import org.springframework.stereotype.Component;
+
+import java.io.IOException;
+import java.io.InputStream;
+
+import javax.annotation.PostConstruct;
+import javax.annotation.Resource;
+import javax.net.ssl.SSLEngine;
+
+@Component
+public class SslFactory {
+
+    private static final Logger LOG = LoggerFactory.getLogger(SslFactory.class);
+
+    private static final String CERT_FILE_NAME = "mqtt.crt";
+    private static final String KEY_FILE_NAME = "mqtt.key";
+
+    @Resource
+    private ConnectConf connectConf;
+
+    private SslContext sslContext;
+
+    @PostConstruct
+    private void initSslContext() {
+        if (!connectConf.isEnableTlsSever()) {
+            return;
+        }
+
+        try {
+            InputStream certStream = new ClassPathResource(CERT_FILE_NAME).getInputStream();
+            InputStream keyStream = new ClassPathResource(KEY_FILE_NAME).getInputStream();
+            SslContextBuilder contextBuilder = SslContextBuilder.forServer(certStream, keyStream);
+            contextBuilder.clientAuth(ClientAuth.OPTIONAL);
+            contextBuilder.sslProvider(SslProvider.JDK);
+            if (connectConf.isNeedClientAuth()) {
+                LOG.info("client tls authentication is required.");
+                contextBuilder.clientAuth(ClientAuth.REQUIRE);
+                contextBuilder.trustManager(certStream);
+            }
+            sslContext = contextBuilder.build();
+        } catch (IOException e) {
+            throw new RuntimeException("failed to initialize ssl context.", e);
+        }
+    }
+
+    public SSLEngine buildSslEngine(SocketChannel ch) {
+        SSLEngine sslEngine = sslContext.newEngine(ch.alloc());
+        sslEngine.setEnabledCipherSuites(sslEngine.getSupportedCipherSuites());
+        sslEngine.setUseClientMode(false);
+        sslEngine.setNeedClientAuth(connectConf.isNeedClientAuth());
+        return sslEngine;
+    }
+
+}
+
diff --git a/mqtt-cs/src/main/java/org/apache/rocketmq/mqtt/cs/starter/MqttServer.java b/mqtt-cs/src/main/java/org/apache/rocketmq/mqtt/cs/starter/MqttServer.java
index 9022940..911828e 100644
--- a/mqtt-cs/src/main/java/org/apache/rocketmq/mqtt/cs/starter/MqttServer.java
+++ b/mqtt-cs/src/main/java/org/apache/rocketmq/mqtt/cs/starter/MqttServer.java
@@ -30,11 +30,12 @@ import io.netty.handler.codec.http.HttpObjectAggregator;
 import io.netty.handler.codec.http.HttpServerCodec;
 import io.netty.handler.codec.mqtt.MqttDecoder;
 import io.netty.handler.codec.mqtt.MqttEncoder;
+import io.netty.handler.ssl.SslHandler;
 import io.netty.handler.stream.ChunkedWriteHandler;
-import org.apache.rocketmq.mqtt.cs.channel.ChannelManager;
 import org.apache.rocketmq.mqtt.cs.channel.ConnectHandler;
 import org.apache.rocketmq.mqtt.cs.config.ConnectConf;
 import org.apache.rocketmq.mqtt.cs.protocol.mqtt.MqttPacketDispatcher;
+import org.apache.rocketmq.mqtt.cs.protocol.ssl.SslFactory;
 import org.apache.rocketmq.mqtt.cs.protocol.ws.WebSocketServerHandler;
 import org.apache.rocketmq.mqtt.cs.protocol.ws.WebSocketEncoder;
 import org.slf4j.Logger;
@@ -51,6 +52,7 @@ public class MqttServer {
 
     private ServerBootstrap serverBootstrap = new ServerBootstrap();
     private ServerBootstrap wsServerBootstrap = new ServerBootstrap();
+    private ServerBootstrap tlsServerBootstrap = new ServerBootstrap();
 
     @Resource
     private ConnectHandler connectHandler;
@@ -65,12 +67,13 @@ public class MqttServer {
     private WebSocketServerHandler webSocketServerHandler;
 
     @Resource
-    private ChannelManager channelManager;
+    private SslFactory sslFactory;
 
     @PostConstruct
     public void init() throws Exception {
         start();
         startWs();
+        startTls();
     }
 
     private void start() {
@@ -97,6 +100,35 @@ public class MqttServer {
         logger.warn("start mqtt server , port:{}", port);
     }
 
+    private void startTls() {
+        if (!connectConf.isEnableTlsSever()) {
+            return;
+        }
+
+        int tlsPort = connectConf.getMqttTlsPort();
+        tlsServerBootstrap
+            .group(new NioEventLoopGroup(connectConf.getNettySelectThreadNum()), new NioEventLoopGroup(connectConf.getNettyWorkerThreadNum()))
+            .channel(NioServerSocketChannel.class)
+            .option(ChannelOption.SO_BACKLOG, 8 * 1024)
+            .childOption(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT)
+            .childOption(ChannelOption.WRITE_BUFFER_WATER_MARK,new WriteBufferWaterMark(connectConf.getLowWater(), connectConf.getHighWater()))
+            .childOption(ChannelOption.TCP_NODELAY, true)
+            .localAddress(new InetSocketAddress(tlsPort))
+            .childHandler(new ChannelInitializer<SocketChannel>() {
+                @Override
+                public void initChannel(SocketChannel ch) throws Exception {
+                    ChannelPipeline pipeline = ch.pipeline();
+                    pipeline.addLast("sslHandler", new SslHandler(sslFactory.buildSslEngine(ch)));
+                    pipeline.addLast("connectHandler", connectHandler);
+                    pipeline.addLast("decoder", new MqttDecoder(connectConf.getMaxPacketSizeInByte()));
+                    pipeline.addLast("encoder", MqttEncoder.INSTANCE);
+                    pipeline.addLast("dispatcher", mqttPacketDispatcher);
+                }
+            });
+        tlsServerBootstrap.bind();
+        logger.warn("start mqtt tls server , port:{}", tlsPort);
+    }
+
     private void startWs() {
         int port = connectConf.getMqttWsPort();
         wsServerBootstrap