View Javadoc
1   /*
2    * Copyright 2016 The Lannister Project
3    * 
4    * Licensed under the Apache License, Version 2.0 (the "License");
5    * you may not use this file except in compliance with the License.
6    * You may obtain a copy of the License at
7    * 
8    *     http://www.apache.org/licenses/LICENSE-2.0
9    * 
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS,
12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13   * See the License for the specific language governing permissions and
14   * limitations under the License.
15   */
16  
17  package net.anyflow.lannister.client;
18  
19  import java.net.URI;
20  import java.net.URISyntaxException;
21  
22  import javax.net.ssl.TrustManagerFactory;
23  
24  import io.netty.bootstrap.Bootstrap;
25  import io.netty.channel.Channel;
26  import io.netty.channel.ChannelFuture;
27  import io.netty.channel.ChannelFutureListener;
28  import io.netty.channel.ChannelInitializer;
29  import io.netty.channel.EventLoopGroup;
30  import io.netty.channel.epoll.EpollEventLoopGroup;
31  import io.netty.channel.epoll.EpollSocketChannel;
32  import io.netty.channel.nio.NioEventLoopGroup;
33  import io.netty.channel.socket.SocketChannel;
34  import io.netty.channel.socket.nio.NioSocketChannel;
35  import io.netty.handler.codec.mqtt.MqttConnAckMessage;
36  import io.netty.handler.codec.mqtt.MqttConnectReturnCode;
37  import io.netty.handler.codec.mqtt.MqttDecoder;
38  import io.netty.handler.codec.mqtt.MqttEncoder;
39  import io.netty.handler.codec.mqtt.MqttMessage;
40  import io.netty.handler.codec.mqtt.MqttTopicSubscription;
41  import io.netty.handler.ssl.SslContext;
42  import io.netty.handler.ssl.SslContextBuilder;
43  import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
44  import io.netty.util.concurrent.DefaultThreadFactory;
45  import net.anyflow.lannister.Literals;
46  import net.anyflow.lannister.Settings;
47  import net.anyflow.lannister.message.ConnectOptions;
48  import net.anyflow.lannister.message.Message;
49  import net.anyflow.lannister.packetreceiver.MqttMessageFactory;
50  
51  public class MqttClient {
52  	private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(MqttClient.class);
53  
54  	private final Bootstrap bootstrap;
55  	private final TrustManagerFactory trustManagerFactory;
56  	private final SharedObject sharedObject;
57  
58  	private Channel channel;
59  	private EventLoopGroup group;
60  	private MessageReceiver receiver;
61  
62  	private Integer currentMessageId;
63  
64  	private URI uri;
65  	private ConnectOptions options;
66  
67  	public MqttClient(String uri) throws URISyntaxException {
68  		this(uri, false);
69  	}
70  
71  	public MqttClient(String uri, boolean useInsecureTrustManagerFactory) throws URISyntaxException {
72  		this.bootstrap = new Bootstrap();
73  		this.uri = new URI(uri);
74  		this.trustManagerFactory = useInsecureTrustManagerFactory ? InsecureTrustManagerFactory.INSTANCE : null;
75  		this.sharedObject = new SharedObject();
76  		this.options = new ConnectOptions();
77  		this.currentMessageId = 0;
78  	}
79  
80  	public MqttConnectReturnCode connect() throws InterruptedException {
81  
82  		Class<? extends SocketChannel> socketChannelClass;
83  
84  		if (Literals.NETTY_EPOLL.equals(Settings.INSTANCE.nettyTransportMode())) {
85  			group = new EpollEventLoopGroup(1, new DefaultThreadFactory("client"));
86  			socketChannelClass = EpollSocketChannel.class;
87  		}
88  		else {
89  			group = new NioEventLoopGroup(1, new DefaultThreadFactory("client"));
90  			socketChannelClass = NioSocketChannel.class;
91  		}
92  
93  		bootstrap.group(group).channel(socketChannelClass).handler(new ChannelInitializer<SocketChannel>() {
94  			@Override
95  			protected void initChannel(SocketChannel ch) throws Exception {
96  				if ("mqtts".equalsIgnoreCase(uri.getScheme())) {
97  					SslContext sslCtx = SslContextBuilder.forClient().trustManager(trustManagerFactory).build();
98  
99  					ch.pipeline().addLast(sslCtx.newHandler(ch.alloc(), uri.getHost(), uri.getPort()));
100 				}
101 
102 				ch.pipeline().addLast(MqttDecoder.class.getName(), new MqttDecoder());
103 				ch.pipeline().addLast(MqttEncoder.class.getName(), MqttEncoder.INSTANCE);
104 				ch.pipeline().addLast(MqttPacketReceiver.class.getName(),
105 						new MqttPacketReceiver(MqttClient.this, receiver, sharedObject));
106 			}
107 		});
108 
109 		channel = bootstrap.connect(uri.getHost(), uri.getPort()).sync().channel();
110 
111 		normalizeMessage(options.will());
112 		send(MqttMessageFactory.connect(options));
113 
114 		synchronized (sharedObject.locker()) {
115 			int timeout = Settings.INSTANCE.getInt("mqttclient.responseTimeoutSeconds", 15);
116 
117 			sharedObject.locker().wait(timeout * 1000);
118 		}
119 		if (sharedObject.receivedMessage() == null) { return null; }
120 
121 		return ((MqttConnAckMessage) sharedObject.receivedMessage()).variableHeader().connectReturnCode();
122 	}
123 
124 	public boolean isConnected() {
125 		return channel != null && channel.isActive();
126 	}
127 
128 	public void disconnect(boolean sendDisconnect) {
129 		if (!isConnected()) { return; }
130 
131 		if (sendDisconnect) {
132 			send(MqttMessageFactory.disconnect());
133 		}
134 
135 		channel.disconnect().addListener(ChannelFutureListener.CLOSE);
136 		group.shutdownGracefully();
137 
138 		channel = null;
139 		group = null;
140 	}
141 
142 	protected ChannelFuture send(MqttMessage message) {
143 		if (!isConnected()) {
144 			logger.error("Channel is not active");
145 			return null;
146 		}
147 
148 		return channel.writeAndFlush(message);
149 	}
150 
151 	public MqttClient receiver(MessageReceiver receiver) {
152 		this.receiver = receiver;
153 
154 		return this;
155 	}
156 
157 	public MqttClient connectOptions(ConnectOptions connectOptions) {
158 		this.options = connectOptions;
159 
160 		return this;
161 	}
162 
163 	public void publish(Message message) {
164 		normalizeMessage(message);
165 		send(MqttMessageFactory.publish(message, false));
166 	}
167 
168 	public void subscribe(MqttTopicSubscription... topicSubscriptions) throws InterruptedException {
169 		send(MqttMessageFactory.subscribe(nextMessageId(), topicSubscriptions));
170 
171 		// TODO error handling,store subscription
172 	}
173 
174 	public int nextMessageId() {
175 		currentMessageId = currentMessageId + 1;
176 
177 		if (currentMessageId > Message.MAX_MESSAGE_ID_NUM) {
178 			currentMessageId = Message.MIN_MESSAGE_ID_NUM;
179 		}
180 
181 		return currentMessageId;
182 	}
183 
184 	private void normalizeMessage(Message message) {
185 		if (message == null) { return; }
186 
187 		message.id(nextMessageId());
188 		message.publisherId(this.options.clientId());
189 	}
190 
191 	public String clientId() {
192 		return options.clientId();
193 	}
194 }