MqttClient.java

/*
 * Copyright 2016 The Lannister Project
 * 
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * 
 *     http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package net.anyflow.lannister.client;

import java.net.URI;
import java.net.URISyntaxException;

import javax.net.ssl.TrustManagerFactory;

import io.netty.bootstrap.Bootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.epoll.EpollEventLoopGroup;
import io.netty.channel.epoll.EpollSocketChannel;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.mqtt.MqttConnAckMessage;
import io.netty.handler.codec.mqtt.MqttConnectReturnCode;
import io.netty.handler.codec.mqtt.MqttDecoder;
import io.netty.handler.codec.mqtt.MqttEncoder;
import io.netty.handler.codec.mqtt.MqttMessage;
import io.netty.handler.codec.mqtt.MqttTopicSubscription;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
import io.netty.util.concurrent.DefaultThreadFactory;
import net.anyflow.lannister.Literals;
import net.anyflow.lannister.Settings;
import net.anyflow.lannister.message.ConnectOptions;
import net.anyflow.lannister.message.Message;
import net.anyflow.lannister.packetreceiver.MqttMessageFactory;

/**
 * Small blocking-style MQTT client built on a Netty {@link Bootstrap}.
 * <p>
 * The target broker is given as a URI; the {@code mqtts} scheme enables TLS. A
 * connection is opened with {@link #connect()} and torn down with
 * {@link #disconnect(boolean)}. Each call to {@link #connect()} creates its own
 * event loop group, which is released again on disconnect or on a failed
 * connection attempt.
 * <p>
 * This class performs no synchronization of its own state (channel, message id
 * counter); callers sharing an instance between threads must coordinate
 * externally.
 */
public class MqttClient {
	private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(MqttClient.class);

	/** URI scheme that selects a TLS-protected connection. */
	private static final String SECURE_SCHEME = "mqtts";
	/** IANA-registered MQTT port, used when the URI carries no explicit port. */
	private static final int DEFAULT_PORT = 1883;
	/** IANA-registered MQTT-over-TLS port, used when the URI carries no explicit port. */
	private static final int DEFAULT_SECURE_PORT = 8883;

	private final TrustManagerFactory trustManagerFactory;
	private final SharedObject sharedObject;
	private final URI uri;

	private Channel channel;
	private EventLoopGroup group;
	private MessageReceiver receiver;

	private int currentMessageId;

	private ConnectOptions options;

	/**
	 * Creates a client for the given broker URI using the default trust manager.
	 *
	 * @param uri
	 *            broker address, e.g. {@code mqtt://host:1883} or {@code mqtts://host:8883}
	 * @throws URISyntaxException
	 *             if {@code uri} cannot be parsed
	 */
	public MqttClient(String uri) throws URISyntaxException {
		this(uri, false);
	}

	/**
	 * Creates a client for the given broker URI.
	 *
	 * @param uri
	 *            broker address, e.g. {@code mqtt://host:1883} or {@code mqtts://host:8883}
	 * @param useInsecureTrustManagerFactory
	 *            when {@code true}, TLS connections use Netty's
	 *            {@link InsecureTrustManagerFactory}; otherwise no trust manager
	 *            factory is configured explicitly
	 * @throws URISyntaxException
	 *             if {@code uri} cannot be parsed
	 */
	public MqttClient(String uri, boolean useInsecureTrustManagerFactory) throws URISyntaxException {
		this.uri = new URI(uri);
		this.trustManagerFactory = useInsecureTrustManagerFactory ? InsecureTrustManagerFactory.INSTANCE : null;
		this.sharedObject = new SharedObject();
		this.options = new ConnectOptions();
		this.currentMessageId = 0;
	}

	/**
	 * Opens the transport connection, sends CONNECT and waits for the broker's
	 * response.
	 * <p>
	 * Any channel or event loop group left over from an earlier connection is
	 * released first, so the client can be reconnected. If the transport connection
	 * cannot be established the newly created event loop group is shut down before
	 * the failure propagates.
	 * <p>
	 * The wait is bounded by the {@code mqttclient.responseTimeoutSeconds} setting
	 * (default 15 seconds). When no response has been recorded the method returns
	 * {@code null}; the channel is left open in that case, as before.
	 *
	 * @return the return code of the received CONNACK, or {@code null} if no
	 *         response was recorded
	 * @throws InterruptedException
	 *             if the calling thread is interrupted while connecting or waiting
	 */
	public MqttConnectReturnCode connect() throws InterruptedException {
		// Drop whatever an earlier connection left behind so its event loop thread is not leaked.
		releaseResources();

		Class<? extends SocketChannel> socketChannelClass;

		if (Literals.NETTY_EPOLL.equals(Settings.INSTANCE.nettyTransportMode())) {
			group = new EpollEventLoopGroup(1, new DefaultThreadFactory("client"));
			socketChannelClass = EpollSocketChannel.class;
		}
		else {
			group = new NioEventLoopGroup(1, new DefaultThreadFactory("client"));
			socketChannelClass = NioSocketChannel.class;
		}

		final int port = port();

		// A Bootstrap accepts its group only once, so a fresh one is built per connection attempt.
		Bootstrap bootstrap = new Bootstrap();
		bootstrap.group(group).channel(socketChannelClass).handler(new ChannelInitializer<SocketChannel>() {
			@Override
			protected void initChannel(SocketChannel ch) throws Exception {
				if (isSecure()) {
					SslContext sslCtx = SslContextBuilder.forClient().trustManager(trustManagerFactory).build();

					ch.pipeline().addLast(sslCtx.newHandler(ch.alloc(), uri.getHost(), port));
				}

				ch.pipeline().addLast(MqttDecoder.class.getName(), new MqttDecoder());
				ch.pipeline().addLast(MqttEncoder.class.getName(), MqttEncoder.INSTANCE);
				ch.pipeline().addLast(MqttPacketReceiver.class.getName(),
						new MqttPacketReceiver(MqttClient.this, receiver, sharedObject));
			}
		});

		boolean established = false;
		try {
			channel = bootstrap.connect(uri.getHost(), port).sync().channel();
			established = true;
		}
		finally {
			// sync() rethrows connection failures and interruptions; do not leave the group running then.
			if (!established) {
				releaseResources();
			}
		}

		int timeout = Settings.INSTANCE.getInt("mqttclient.responseTimeoutSeconds", 15);

		synchronized (sharedObject.locker()) {
			// CONNECT is sent while the monitor is held: a notification on this monitor can then
			// only be issued once wait() has released it, so a fast response cannot be missed.
			// NOTE(review): assumes the packet receiver notifies on sharedObject.locker() - confirm.
			normalizeMessage(options.will());
			send(MqttMessageFactory.connect(options));

			sharedObject.locker().wait(timeout * 1000L);
		}
		if (sharedObject.receivedMessage() == null) { return null; }

		return ((MqttConnAckMessage) sharedObject.receivedMessage()).variableHeader().connectReturnCode();
	}

	/**
	 * @return {@code true} if a channel exists and is active
	 */
	public boolean isConnected() {
		return channel != null && channel.isActive();
	}

	/**
	 * Closes the connection and releases the event loop group.
	 * <p>
	 * The event loop group is shut down even when the channel is already inactive
	 * (for example after the remote side closed it), so its thread is not leaked.
	 *
	 * @param sendDisconnect
	 *            whether to send an MQTT DISCONNECT before closing; only done while
	 *            the channel is still active
	 */
	public void disconnect(boolean sendDisconnect) {
		if (isConnected()) {
			if (sendDisconnect) {
				send(MqttMessageFactory.disconnect());
			}

			channel.disconnect().addListener(ChannelFutureListener.CLOSE);
			channel = null;
		}

		releaseResources();
	}

	/**
	 * Writes and flushes a message on the current channel.
	 *
	 * @param message
	 *            the MQTT message to send
	 * @return the write future, or {@code null} (after logging an error) when the
	 *         channel is not active
	 */
	protected ChannelFuture send(MqttMessage message) {
		if (!isConnected()) {
			logger.error("Channel is not active");
			return null;
		}

		return channel.writeAndFlush(message);
	}

	/**
	 * Sets the receiver handed to the packet handler of channels created by
	 * subsequent {@link #connect()} calls.
	 *
	 * @param receiver
	 *            the message receiver
	 * @return this client, for chaining
	 */
	public MqttClient receiver(MessageReceiver receiver) {
		this.receiver = receiver;

		return this;
	}

	/**
	 * Replaces the options used for CONNECT and for stamping outgoing messages.
	 *
	 * @param connectOptions
	 *            the new connect options
	 * @return this client, for chaining
	 */
	public MqttClient connectOptions(ConnectOptions connectOptions) {
		this.options = connectOptions;

		return this;
	}

	/**
	 * Assigns the next message id and this client's id to the message, then sends
	 * it as a PUBLISH. Nothing is sent when the channel is not active.
	 *
	 * @param message
	 *            the message to publish
	 */
	public void publish(Message message) {
		normalizeMessage(message);
		send(MqttMessageFactory.publish(message, false));
	}

	/**
	 * Sends a SUBSCRIBE for the given topic subscriptions using the next message id.
	 *
	 * @param topicSubscriptions
	 *            the subscriptions to request
	 * @throws InterruptedException
	 *             declared for interface compatibility; not thrown by the current
	 *             implementation
	 */
	public void subscribe(MqttTopicSubscription... topicSubscriptions) throws InterruptedException {
		send(MqttMessageFactory.subscribe(nextMessageId(), topicSubscriptions));

		// TODO: error handling; store the subscription
	}

	/**
	 * Advances the message id counter, wrapping to
	 * {@code Message.MIN_MESSAGE_ID_NUM} once it exceeds
	 * {@code Message.MAX_MESSAGE_ID_NUM}.
	 *
	 * @return the new message id
	 */
	public int nextMessageId() {
		currentMessageId = currentMessageId + 1;

		if (currentMessageId > Message.MAX_MESSAGE_ID_NUM) {
			currentMessageId = Message.MIN_MESSAGE_ID_NUM;
		}

		return currentMessageId;
	}

	/**
	 * @return the client id taken from the current connect options
	 */
	public String clientId() {
		return options.clientId();
	}

	/** Stamps a message with the next message id and this client's id; {@code null} is ignored. */
	private void normalizeMessage(Message message) {
		if (message == null) { return; }

		message.id(nextMessageId());
		message.publisherId(this.options.clientId());
	}

	/** Tells whether the broker URI selects the TLS scheme. */
	private boolean isSecure() {
		return SECURE_SCHEME.equalsIgnoreCase(uri.getScheme());
	}

	/** Returns the URI's port, or the standard MQTT port for its scheme when the URI has none. */
	private int port() {
		if (uri.getPort() >= 0) { return uri.getPort(); }

		return isSecure() ? DEFAULT_SECURE_PORT : DEFAULT_PORT;
	}

	/** Closes the channel and shuts down the event loop group, if present, and clears both fields. */
	private void releaseResources() {
		if (channel != null) {
			channel.close();
			channel = null;
		}

		if (group != null) {
			group.shutdownGracefully();
			group = null;
		}
	}
}