001/**
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.activemq.transport.mqtt.strategy;
018
019import java.io.IOException;
020import java.util.Set;
021import java.util.concurrent.ConcurrentHashMap;
022import java.util.concurrent.ConcurrentMap;
023
024import org.apache.activemq.broker.BrokerService;
025import org.apache.activemq.broker.BrokerServiceAware;
026import org.apache.activemq.broker.ConnectionContext;
027import org.apache.activemq.broker.region.PrefetchSubscription;
028import org.apache.activemq.broker.region.RegionBroker;
029import org.apache.activemq.broker.region.Subscription;
030import org.apache.activemq.broker.region.TopicRegion;
031import org.apache.activemq.broker.region.virtual.VirtualTopicInterceptor;
032import org.apache.activemq.command.ActiveMQDestination;
033import org.apache.activemq.command.ActiveMQTopic;
034import org.apache.activemq.command.ConsumerId;
035import org.apache.activemq.command.ConsumerInfo;
036import org.apache.activemq.command.ExceptionResponse;
037import org.apache.activemq.command.RemoveInfo;
038import org.apache.activemq.command.Response;
039import org.apache.activemq.transport.mqtt.MQTTProtocolConverter;
040import org.apache.activemq.transport.mqtt.MQTTProtocolException;
041import org.apache.activemq.transport.mqtt.MQTTSubscription;
042import org.apache.activemq.transport.mqtt.ResponseHandler;
043import org.apache.activemq.util.LongSequenceGenerator;
044import org.fusesource.mqtt.client.QoS;
045import org.fusesource.mqtt.client.Topic;
046import org.slf4j.Logger;
047import org.slf4j.LoggerFactory;
048
049/**
050 * Abstract implementation of the {@link MQTTSubscriptionStrategy} interface providing
051 * the base functionality that is common to most implementations.
052 */
053public abstract class AbstractMQTTSubscriptionStrategy implements MQTTSubscriptionStrategy, BrokerServiceAware {
054
055    private static final Logger LOG = LoggerFactory.getLogger(AbstractMQTTSubscriptionStrategy.class);
056
057    private static final byte SUBSCRIBE_ERROR = (byte) 0x80;
058
059    protected MQTTProtocolConverter protocol;
060    protected BrokerService brokerService;
061
062    protected final ConcurrentMap<ConsumerId, MQTTSubscription> subscriptionsByConsumerId = new ConcurrentHashMap<ConsumerId, MQTTSubscription>();
063    protected final ConcurrentMap<String, MQTTSubscription> mqttSubscriptionByTopic = new ConcurrentHashMap<String, MQTTSubscription>();
064
065    protected final LongSequenceGenerator consumerIdGenerator = new LongSequenceGenerator();
066
067    @Override
068    public void initialize(MQTTProtocolConverter protocol) throws MQTTProtocolException {
069        setProtocolConverter(protocol);
070    }
071
072    @Override
073    public void setBrokerService(BrokerService brokerService) {
074        this.brokerService = brokerService;
075    }
076
077    @Override
078    public void setProtocolConverter(MQTTProtocolConverter parent) {
079        this.protocol = parent;
080    }
081
082    @Override
083    public MQTTProtocolConverter getProtocolConverter() {
084        return protocol;
085    }
086
087    @Override
088    public byte onSubscribe(final Topic topic) throws MQTTProtocolException {
089
090        final String destinationName = topic.name().toString();
091        final QoS requestedQoS = topic.qos();
092
093        final MQTTSubscription mqttSubscription = mqttSubscriptionByTopic.get(destinationName);
094        if (mqttSubscription != null) {
095            if (requestedQoS != mqttSubscription.getQoS()) {
096                // remove old subscription as the QoS has changed
097                onUnSubscribe(destinationName);
098            } else {
099                try {
100                    onReSubscribe(mqttSubscription);
101                } catch (IOException e) {
102                    throw new MQTTProtocolException("Failed to find subscription strategy", true, e);
103                }
104                return (byte) requestedQoS.ordinal();
105            }
106        }
107
108        try {
109            return onSubscribe(destinationName, requestedQoS);
110        } catch (IOException e) {
111            throw new MQTTProtocolException("Failed while intercepting subscribe", true, e);
112        }
113    }
114
115    @Override
116    public void onReSubscribe(MQTTSubscription mqttSubscription) throws MQTTProtocolException {
117        String topicName = mqttSubscription.getTopicName();
118
119        // get TopicRegion
120        RegionBroker regionBroker;
121        try {
122            regionBroker = (RegionBroker) brokerService.getBroker().getAdaptor(RegionBroker.class);
123        } catch (Exception e) {
124            throw new MQTTProtocolException("Error subscribing to " + topicName + ": " + e.getMessage(), false, e);
125        }
126        final TopicRegion topicRegion = (TopicRegion) regionBroker.getTopicRegion();
127
128        final ConsumerInfo consumerInfo = mqttSubscription.getConsumerInfo();
129        final ConsumerId consumerId = consumerInfo.getConsumerId();
130
131        // use actual client id used to create connection to lookup connection
132        // context
133        String connectionInfoClientId = protocol.getClientId();
134        // for zero-byte client ids we used connection id
135        if (connectionInfoClientId == null || connectionInfoClientId.isEmpty()) {
136            connectionInfoClientId = protocol.getConnectionId().toString();
137        }
138        final ConnectionContext connectionContext = regionBroker.getConnectionContext(connectionInfoClientId);
139
140        // get all matching Topics
141        final Set<org.apache.activemq.broker.region.Destination> matchingDestinations =
142            topicRegion.getDestinations(mqttSubscription.getDestination());
143        for (org.apache.activemq.broker.region.Destination dest : matchingDestinations) {
144
145            // recover retroactive messages for matching subscription
146            for (Subscription subscription : dest.getConsumers()) {
147                if (subscription.getConsumerInfo().getConsumerId().equals(consumerId)) {
148                    try {
149                        if (dest instanceof org.apache.activemq.broker.region.Topic) {
150                            ((org.apache.activemq.broker.region.Topic) dest).recoverRetroactiveMessages(connectionContext, subscription);
151                        } else if (dest instanceof VirtualTopicInterceptor) {
152                            ((VirtualTopicInterceptor) dest).getTopic().recoverRetroactiveMessages(connectionContext, subscription);
153                        }
154                        if (subscription instanceof PrefetchSubscription) {
155                            // request dispatch for prefetch subs
156                            PrefetchSubscription prefetchSubscription = (PrefetchSubscription) subscription;
157                            prefetchSubscription.dispatchPending();
158                        }
159                    } catch (Exception e) {
160                        throw new MQTTProtocolException("Error recovering retained messages for " + dest.getName() + ": " + e.getMessage(), false, e);
161                    }
162                    break;
163                }
164            }
165        }
166    }
167
168    @Override
169    public ActiveMQDestination onSend(String topicName) {
170        return new ActiveMQTopic(topicName);
171    }
172
173    @Override
174    public String onSend(ActiveMQDestination destination) {
175        return destination.getPhysicalName();
176    }
177
178    @Override
179    public boolean isControlTopic(ActiveMQDestination destination) {
180        return destination.getPhysicalName().startsWith("$");
181    }
182
183    @Override
184    public MQTTSubscription getSubscription(ConsumerId consumerId) {
185        return subscriptionsByConsumerId.get(consumerId);
186    }
187
188    protected ConsumerId getNextConsumerId() {
189        return new ConsumerId(protocol.getSessionId(), consumerIdGenerator.getNextSequenceId());
190    }
191
192    protected byte doSubscribe(ConsumerInfo consumerInfo, final String topicName, final QoS qoS) throws MQTTProtocolException {
193
194        MQTTSubscription mqttSubscription = new MQTTSubscription(protocol, topicName, qoS, consumerInfo);
195
196        // optimistic add to local maps first to be able to handle commands in onActiveMQCommand
197        subscriptionsByConsumerId.put(consumerInfo.getConsumerId(), mqttSubscription);
198        mqttSubscriptionByTopic.put(topicName, mqttSubscription);
199
200        final byte[] qos = {-1};
201        protocol.sendToActiveMQ(consumerInfo, new ResponseHandler() {
202            @Override
203            public void onResponse(MQTTProtocolConverter converter, Response response) throws IOException {
204                // validate subscription request
205                if (response.isException()) {
206                    final Throwable throwable = ((ExceptionResponse) response).getException();
207                    LOG.warn("Error subscribing to {}", topicName, throwable);
208                    // version 3.1 don't supports silent fail
209                    // version 3.1.1 send "error" qos
210                    if (protocol.version == MQTTProtocolConverter.V3_1_1) {
211                        qos[0] = SUBSCRIBE_ERROR;
212                    } else {
213                        qos[0] = (byte) qoS.ordinal();
214                    }
215                } else {
216                    qos[0] = (byte) qoS.ordinal();
217                }
218            }
219        });
220
221        if (qos[0] == SUBSCRIBE_ERROR) {
222            // remove from local maps if subscribe failed
223            subscriptionsByConsumerId.remove(consumerInfo.getConsumerId());
224            mqttSubscriptionByTopic.remove(topicName);
225        }
226
227        return qos[0];
228    }
229
230    public void doUnSubscribe(MQTTSubscription subscription) {
231        mqttSubscriptionByTopic.remove(subscription.getTopicName());
232        ConsumerInfo info = subscription.getConsumerInfo();
233        if (info != null) {
234            subscriptionsByConsumerId.remove(info.getConsumerId());
235
236            RemoveInfo removeInfo = info.createRemoveCommand();
237            protocol.sendToActiveMQ(removeInfo, new ResponseHandler() {
238                @Override
239                public void onResponse(MQTTProtocolConverter converter, Response response) throws IOException {
240                    // ignore failures..
241                }
242            });
243        }
244    }
245}