001/**
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.activemq.transport.ws.jetty9;
019
020import java.io.IOException;
021import java.net.URI;
022import java.util.ArrayList;
023import java.util.Collections;
024import java.util.Comparator;
025import java.util.HashMap;
026import java.util.List;
027import java.util.Map;
028import java.util.concurrent.ConcurrentHashMap;
029
030import javax.servlet.ServletException;
031import javax.servlet.http.HttpServletRequest;
032import javax.servlet.http.HttpServletResponse;
033
034import org.apache.activemq.broker.BrokerService;
035import org.apache.activemq.broker.BrokerServiceAware;
036import org.apache.activemq.transport.Transport;
037import org.apache.activemq.transport.TransportAcceptListener;
038import org.apache.activemq.transport.TransportFactory;
039import org.apache.activemq.transport.util.HttpTransportUtils;
040import org.apache.activemq.transport.ws.WSTransportProxy;
041import org.eclipse.jetty.websocket.api.WebSocketListener;
042import org.eclipse.jetty.websocket.servlet.ServletUpgradeRequest;
043import org.eclipse.jetty.websocket.servlet.ServletUpgradeResponse;
044import org.eclipse.jetty.websocket.servlet.WebSocketCreator;
045import org.eclipse.jetty.websocket.servlet.WebSocketServlet;
046import org.eclipse.jetty.websocket.servlet.WebSocketServletFactory;
047
048/**
049 * Handle connection upgrade requests and creates web sockets
050 */
051public class WSServlet extends WebSocketServlet implements BrokerServiceAware {
052
053    private static final long serialVersionUID = -4716657876092884139L;
054
055    private TransportAcceptListener listener;
056
057    private final static Map<String, Integer> stompProtocols = new ConcurrentHashMap<>();
058    private final static Map<String, Integer> mqttProtocols = new ConcurrentHashMap<>();
059
060    private Map<String, Object> transportOptions;
061    private BrokerService brokerService;
062
063    private enum Protocol {
064        MQTT, STOMP, UNKNOWN
065    }
066
067    static {
068        stompProtocols.put("v12.stomp", 3);
069        stompProtocols.put("v11.stomp", 2);
070        stompProtocols.put("v10.stomp", 1);
071        stompProtocols.put("stomp", 0);
072
073        mqttProtocols.put("mqttv3.1", 1);
074        mqttProtocols.put("mqtt", 0);
075    }
076
077    @Override
078    public void init() throws ServletException {
079        super.init();
080        listener = (TransportAcceptListener) getServletContext().getAttribute("acceptListener");
081        if (listener == null) {
082            throw new ServletException("No such attribute 'acceptListener' available in the ServletContext");
083        }
084    }
085
086    @Override
087    protected void doGet(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException {
088        //return empty response - AMQ-6491
089    }
090
091    @Override
092    public void configure(WebSocketServletFactory factory) {
093        factory.setCreator(new WebSocketCreator() {
094            @Override
095            public Object createWebSocket(ServletUpgradeRequest req, ServletUpgradeResponse resp) {
096                WebSocketListener socket;
097                Protocol requestedProtocol = Protocol.UNKNOWN;
098
099                // When no sub-protocol is requested we default to STOMP for legacy reasons.
100                if (!req.getSubProtocols().isEmpty()) {
101                    for (String subProtocol : req.getSubProtocols()) {
102                        if (subProtocol.startsWith("mqtt")) {
103                            requestedProtocol = Protocol.MQTT;
104                        } else if (subProtocol.contains("stomp")) {
105                            requestedProtocol = Protocol.STOMP;
106                        }
107                    }
108                } else {
109                    requestedProtocol = Protocol.STOMP;
110                }
111
112                switch (requestedProtocol) {
113                    case MQTT:
114                        socket = new MQTTSocket(HttpTransportUtils.generateWsRemoteAddress(req.getHttpServletRequest()));
115                        ((MQTTSocket) socket).setTransportOptions(new HashMap<>(transportOptions));
116                        ((MQTTSocket) socket).setPeerCertificates(req.getCertificates());
117                        resp.setAcceptedSubProtocol(getAcceptedSubProtocol(mqttProtocols, req.getSubProtocols(), "mqtt"));
118                        break;
119                    case UNKNOWN:
120                        socket = findWSTransport(req, resp);
121                        if (socket != null) {
122                            break;
123                        }
124                    case STOMP:
125                        socket = new StompSocket(HttpTransportUtils.generateWsRemoteAddress(req.getHttpServletRequest()));
126                        ((StompSocket) socket).setPeerCertificates(req.getCertificates());
127                        resp.setAcceptedSubProtocol(getAcceptedSubProtocol(stompProtocols, req.getSubProtocols(), "stomp"));
128                        break;
129                    default:
130                        socket = null;
131                        listener.onAcceptError(new IOException("Unknown protocol requested"));
132                        break;
133                }
134
135                if (socket != null) {
136                    listener.onAccept((Transport) socket);
137                }
138
139                return socket;
140            }
141        });
142    }
143
144    private WebSocketListener findWSTransport(ServletUpgradeRequest request, ServletUpgradeResponse response) {
145        WSTransportProxy proxy = null;
146
147        for (String subProtocol : request.getSubProtocols()) {
148            try {
149                String remoteAddress = HttpTransportUtils.generateWsRemoteAddress(request.getHttpServletRequest(), subProtocol);
150                URI remoteURI = new URI(remoteAddress);
151
152                TransportFactory factory = TransportFactory.findTransportFactory(remoteURI);
153
154                if (factory instanceof BrokerServiceAware) {
155                    ((BrokerServiceAware) factory).setBrokerService(brokerService);
156                }
157
158                Transport transport = factory.doConnect(remoteURI);
159
160                proxy = new WSTransportProxy(remoteAddress, transport);
161                proxy.setPeerCertificates(request.getCertificates());
162                proxy.setTransportOptions(transportOptions);
163
164                response.setAcceptedSubProtocol(proxy.getSubProtocol());
165            } catch (Exception e) {
166                proxy = null;
167
168                // Keep going and try any other sub-protocols present.
169                continue;
170            }
171        }
172
173        return proxy;
174    }
175
176    private String getAcceptedSubProtocol(final Map<String, Integer> protocols, List<String> subProtocols, String defaultProtocol) {
177        List<SubProtocol> matchedProtocols = new ArrayList<>();
178        if (subProtocols != null && subProtocols.size() > 0) {
179            // detect which subprotocols match accepted protocols and add to the
180            // list
181            for (String subProtocol : subProtocols) {
182                Integer priority = protocols.get(subProtocol);
183                if (subProtocol != null && priority != null) {
184                    // only insert if both subProtocol and priority are not null
185                    matchedProtocols.add(new SubProtocol(subProtocol, priority));
186                }
187            }
188            // sort the list by priority
189            if (matchedProtocols.size() > 0) {
190                Collections.sort(matchedProtocols, new Comparator<SubProtocol>() {
191                    @Override
192                    public int compare(SubProtocol s1, SubProtocol s2) {
193                        return s2.priority.compareTo(s1.priority);
194                    }
195                });
196                return matchedProtocols.get(0).protocol;
197            }
198        }
199        return defaultProtocol;
200    }
201
202    private class SubProtocol {
203        private String protocol;
204        private Integer priority;
205
206        public SubProtocol(String protocol, Integer priority) {
207            this.protocol = protocol;
208            this.priority = priority;
209        }
210    }
211
212    public void setTransportOptions(Map<String, Object> transportOptions) {
213        this.transportOptions = transportOptions;
214    }
215
216    @Override
217    public void setBrokerService(BrokerService brokerService) {
218        this.brokerService = brokerService;
219    }
220}