001/** 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017 018package org.apache.activemq.transport.ws.jetty9; 019 020import java.io.IOException; 021import java.net.URI; 022import java.util.ArrayList; 023import java.util.Collections; 024import java.util.Comparator; 025import java.util.HashMap; 026import java.util.List; 027import java.util.Map; 028import java.util.concurrent.ConcurrentHashMap; 029 030import javax.servlet.ServletException; 031import javax.servlet.http.HttpServletRequest; 032import javax.servlet.http.HttpServletResponse; 033 034import org.apache.activemq.broker.BrokerService; 035import org.apache.activemq.broker.BrokerServiceAware; 036import org.apache.activemq.transport.Transport; 037import org.apache.activemq.transport.TransportAcceptListener; 038import org.apache.activemq.transport.TransportFactory; 039import org.apache.activemq.transport.util.HttpTransportUtils; 040import org.apache.activemq.transport.ws.WSTransportProxy; 041import org.eclipse.jetty.websocket.api.WebSocketListener; 042import org.eclipse.jetty.websocket.servlet.ServletUpgradeRequest; 043import org.eclipse.jetty.websocket.servlet.ServletUpgradeResponse; 044import 
org.eclipse.jetty.websocket.servlet.WebSocketCreator;
import org.eclipse.jetty.websocket.servlet.WebSocketServlet;
import org.eclipse.jetty.websocket.servlet.WebSocketServletFactory;

/**
 * Handles WebSocket connection upgrade requests and creates the web socket that
 * backs each accepted connection.
 * <p>
 * The protocol is chosen from the sub-protocols offered by the client:
 * <ul>
 * <li>MQTT and STOMP are served by dedicated sockets.</li>
 * <li>Any other sub-protocol is resolved through a {@link TransportFactory}.
 * If that fails, the connection is served as STOMP.</li>
 * <li>When no sub-protocol is offered, STOMP is used for legacy reasons.</li>
 * </ul>
 */
public class WSServlet extends WebSocketServlet implements BrokerServiceAware {

    private static final long serialVersionUID = -4716657876092884139L;

    /** Accepted STOMP sub-protocol names mapped to their selection priority (higher wins). */
    private static final Map<String, Integer> stompProtocols = new ConcurrentHashMap<>();
    /** Accepted MQTT sub-protocol names mapped to their selection priority (higher wins). */
    private static final Map<String, Integer> mqttProtocols = new ConcurrentHashMap<>();

    private TransportAcceptListener listener;

    private Map<String, Object> transportOptions;
    private BrokerService brokerService;

    private enum Protocol {
        MQTT, STOMP, UNKNOWN
    }

    static {
        stompProtocols.put("v12.stomp", 3);
        stompProtocols.put("v11.stomp", 2);
        stompProtocols.put("v10.stomp", 1);
        stompProtocols.put("stomp", 0);

        mqttProtocols.put("mqttv3.1", 1);
        mqttProtocols.put("mqtt", 0);
    }

    /**
     * Initializes the servlet and looks up the accept listener from the servlet context.
     * <p>
     * Newly created sockets are reported to the {@link TransportAcceptListener}, which is
     * read from the servlet context attribute {@code acceptListener}.
     *
     * @throws ServletException if the {@code acceptListener} attribute is missing
     */
    @Override
    public void init() throws ServletException {
        super.init();
        listener = (TransportAcceptListener) getServletContext().getAttribute("acceptListener");
        if (listener == null) {
            throw new ServletException("No such attribute 'acceptListener' available in the ServletContext");
        }
    }

    @Override
    protected void doGet(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException {
        // Intentionally return an empty response - AMQ-6491
    }

    /**
     * Installs the creator that builds a socket for each upgrade request.
     * <p>
     * The socket type is chosen from the requested sub-protocols. Each created socket is
     * announced to the accept listener.
     */
    @Override
    public void configure(WebSocketServletFactory factory) {
        factory.setCreator(new WebSocketCreator() {
            @Override
            public Object createWebSocket(ServletUpgradeRequest req, ServletUpgradeResponse resp) {
                WebSocketListener socket;

                switch (detectProtocol(req.getSubProtocols())) {
                    case MQTT:
                        MQTTSocket mqttSocket = new MQTTSocket(HttpTransportUtils.generateWsRemoteAddress(req.getHttpServletRequest()));
                        mqttSocket.setTransportOptions(copyTransportOptions());
                        mqttSocket.setPeerCertificates(req.getCertificates());
                        resp.setAcceptedSubProtocol(getAcceptedSubProtocol(mqttProtocols, req.getSubProtocols(), "mqtt"));
                        socket = mqttSocket;
                        break;
                    case UNKNOWN:
                        socket = findWSTransport(req, resp);
                        if (socket != null) {
                            break;
                        }
                        // No transport could be created for any offered sub-protocol:
                        // intentionally fall through and serve the connection as STOMP.
                    case STOMP:
                        StompSocket stompSocket = new StompSocket(HttpTransportUtils.generateWsRemoteAddress(req.getHttpServletRequest()));
                        stompSocket.setPeerCertificates(req.getCertificates());
                        resp.setAcceptedSubProtocol(getAcceptedSubProtocol(stompProtocols, req.getSubProtocols(), "stomp"));
                        socket = stompSocket;
                        break;
                    default:
                        socket = null;
                        listener.onAcceptError(new IOException("Unknown protocol requested"));
                        break;
                }

                if (socket != null) {
                    listener.onAccept((Transport) socket);
                }

                return socket;
            }
        });
    }

    /**
     * Classifies the sub-protocols offered by the client.
     * <p>
     * When several entries match, the last matching entry decides, as in the original
     * loop.
     *
     * @param subProtocols the offered sub-protocols; may be {@code null} or empty
     * @return the detected protocol:
     *         <ul>
     *         <li>{@code MQTT} for entries starting with {@code mqtt};</li>
     *         <li>{@code STOMP} for entries containing {@code stomp}, or when nothing is offered;</li>
     *         <li>{@code UNKNOWN} otherwise.</li>
     *         </ul>
     */
    private static Protocol detectProtocol(List<String> subProtocols) {
        if (subProtocols == null || subProtocols.isEmpty()) {
            // No sub-protocol requested: default to STOMP for legacy reasons.
            return Protocol.STOMP;
        }

        Protocol detected = Protocol.UNKNOWN;
        for (String subProtocol : subProtocols) {
            if (subProtocol == null) {
                continue;
            }
            if (subProtocol.startsWith("mqtt")) {
                detected = Protocol.MQTT;
            } else if (subProtocol.contains("stomp")) {
                detected = Protocol.STOMP;
            }
        }
        return detected;
    }

    /**
     * Returns a private, mutable copy of the configured transport options.
     * <p>
     * Each socket or transport gets its own copy. Changes one connection makes to its
     * options map therefore cannot affect other connections, including connections being
     * created concurrently.
     *
     * @return a new map; empty when no options were configured
     */
    private Map<String, Object> copyTransportOptions() {
        Map<String, Object> options = transportOptions;
        if (options == null) {
            return new HashMap<String, Object>();
        }
        return new HashMap<String, Object>(options);
    }

    /**
     * Tries each offered sub-protocol in order and returns a proxy for the first one for
     * which a transport can be created.
     * <p>
     * On success, the proxy's sub-protocol is set as the accepted sub-protocol on the
     * response.
     *
     * @param request  the upgrade request carrying the offered sub-protocols
     * @param response the upgrade response that receives the accepted sub-protocol
     * @return the proxy, or {@code null} when no offered sub-protocol yields a transport
     */
    private WebSocketListener findWSTransport(ServletUpgradeRequest request, ServletUpgradeResponse response) {
        for (String subProtocol : request.getSubProtocols()) {
            try {
                String remoteAddress = HttpTransportUtils.generateWsRemoteAddress(request.getHttpServletRequest(), subProtocol);
                URI remoteURI = new URI(remoteAddress);

                TransportFactory factory = TransportFactory.findTransportFactory(remoteURI);

                if (factory instanceof BrokerServiceAware) {
                    ((BrokerServiceAware) factory).setBrokerService(brokerService);
                }

                Transport transport = factory.doConnect(remoteURI);

                WSTransportProxy proxy = new WSTransportProxy(remoteAddress, transport);
                proxy.setPeerCertificates(request.getCertificates());
                proxy.setTransportOptions(copyTransportOptions());

                response.setAcceptedSubProtocol(proxy.getSubProtocol());

                // Stop at the first usable sub-protocol. Continuing would drop this transport
                // without stopping it and could overwrite the accepted sub-protocol.
                return proxy;
            } catch (Exception e) {
                // This sub-protocol is not usable; keep going and try any others present.
            }
        }

        return null;
    }

    /**
     * Picks the offered sub-protocol with the highest priority in {@code protocols}.
     * <p>
     * On equal priority the earlier offered entry wins.
     *
     * @param protocols       accepted sub-protocol names mapped to their priority (higher wins)
     * @param subProtocols    the sub-protocols offered by the client; may be {@code null}
     * @param defaultProtocol value returned when none of the offered sub-protocols is accepted
     * @return the highest-priority accepted sub-protocol, or {@code defaultProtocol}
     */
    private String getAcceptedSubProtocol(final Map<String, Integer> protocols, List<String> subProtocols, String defaultProtocol) {
        if (subProtocols == null) {
            return defaultProtocol;
        }

        String best = null;
        int bestPriority = Integer.MIN_VALUE;
        for (String subProtocol : subProtocols) {
            if (subProtocol == null) {
                // Check before the lookup: ConcurrentHashMap (used for the protocol tables)
                // throws NullPointerException on a null key.
                continue;
            }
            Integer priority = protocols.get(subProtocol);
            if (priority != null && (best == null || priority > bestPriority)) {
                best = subProtocol;
                bestPriority = priority;
            }
        }
        return best != null ? best : defaultProtocol;
    }

    /**
     * Sets the options applied to each created socket or transport.
     * <p>
     * Each connection receives its own copy of this map.
     *
     * @param transportOptions the options; may be {@code null} for none
     */
    public void setTransportOptions(Map<String, Object> transportOptions) {
        this.transportOptions = transportOptions;
    }

    /**
     * Sets the broker passed to {@link BrokerServiceAware} transport factories when
     * resolving sub-protocols.
     */
    @Override
    public void setBrokerService(BrokerService brokerService) {
        this.brokerService = brokerService;
    }
}