001/**
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.activemq.transport.nio;
019
020import java.io.DataInputStream;
021import java.io.DataOutputStream;
022import java.io.EOFException;
023import java.io.IOException;
024import java.net.Socket;
025import java.net.SocketTimeoutException;
026import java.net.URI;
027import java.net.UnknownHostException;
028import java.nio.ByteBuffer;
029import java.nio.channels.SelectionKey;
030import java.nio.channels.Selector;
031import java.security.cert.X509Certificate;
032import java.util.concurrent.CountDownLatch;
033
034import javax.net.SocketFactory;
035import javax.net.ssl.SSLContext;
036import javax.net.ssl.SSLEngine;
037import javax.net.ssl.SSLEngineResult;
038import javax.net.ssl.SSLEngineResult.HandshakeStatus;
039import javax.net.ssl.SSLParameters;
040import javax.net.ssl.SSLPeerUnverifiedException;
041import javax.net.ssl.SSLSession;
042
043import org.apache.activemq.MaxFrameSizeExceededException;
044import org.apache.activemq.command.ConnectionInfo;
045import org.apache.activemq.openwire.OpenWireFormat;
046import org.apache.activemq.thread.TaskRunnerFactory;
047import org.apache.activemq.util.IOExceptionSupport;
048import org.apache.activemq.util.ServiceStopper;
049import org.apache.activemq.wireformat.WireFormat;
050import org.slf4j.Logger;
051import org.slf4j.LoggerFactory;
052
053public class NIOSSLTransport extends NIOTransport {
054
055    private static final Logger LOG = LoggerFactory.getLogger(NIOSSLTransport.class);
056
057    protected boolean needClientAuth;
058    protected boolean wantClientAuth;
059    protected String[] enabledCipherSuites;
060    protected String[] enabledProtocols;
061    protected boolean verifyHostName = false;
062
063    protected SSLContext sslContext;
064    protected SSLEngine sslEngine;
065    protected SSLSession sslSession;
066
067    protected volatile boolean handshakeInProgress = false;
068    protected SSLEngineResult.Status status = null;
069    protected SSLEngineResult.HandshakeStatus handshakeStatus = null;
070    protected TaskRunnerFactory taskRunnerFactory;
071
072    public NIOSSLTransport(WireFormat wireFormat, SocketFactory socketFactory, URI remoteLocation, URI localLocation) throws UnknownHostException, IOException {
073        super(wireFormat, socketFactory, remoteLocation, localLocation);
074    }
075
076    public NIOSSLTransport(WireFormat wireFormat, Socket socket, SSLEngine engine, InitBuffer initBuffer,
077            ByteBuffer inputBuffer) throws IOException {
078        super(wireFormat, socket, initBuffer);
079        this.sslEngine = engine;
080        if (engine != null) {
081            this.sslSession = engine.getSession();
082        }
083        this.inputBuffer = inputBuffer;
084    }
085
086    public void setSslContext(SSLContext sslContext) {
087        this.sslContext = sslContext;
088    }
089
090    volatile boolean hasSslEngine = false;
091
092    @Override
093    protected void initializeStreams() throws IOException {
094        if (sslEngine != null) {
095            hasSslEngine = true;
096        }
097        NIOOutputStream outputStream = null;
098        try {
099            channel = socket.getChannel();
100            channel.configureBlocking(false);
101
102            if (sslContext == null) {
103                sslContext = SSLContext.getDefault();
104            }
105
106            String remoteHost = null;
107            int remotePort = -1;
108
109            try {
110                URI remoteAddress = new URI(this.getRemoteAddress());
111                remoteHost = remoteAddress.getHost();
112                remotePort = remoteAddress.getPort();
113            } catch (Exception e) {
114            }
115
116            // initialize engine, the initial sslSession we get will need to be
117            // updated once the ssl handshake process is completed.
118            if (!hasSslEngine) {
119                if (remoteHost != null && remotePort != -1) {
120                    sslEngine = sslContext.createSSLEngine(remoteHost, remotePort);
121                } else {
122                    sslEngine = sslContext.createSSLEngine();
123                }
124
125                if (verifyHostName) {
126                    SSLParameters sslParams = new SSLParameters();
127                    sslParams.setEndpointIdentificationAlgorithm("HTTPS");
128                    sslEngine.setSSLParameters(sslParams);
129                }
130
131                sslEngine.setUseClientMode(false);
132                if (enabledCipherSuites != null) {
133                    sslEngine.setEnabledCipherSuites(enabledCipherSuites);
134                }
135
136                if (enabledProtocols != null) {
137                    sslEngine.setEnabledProtocols(enabledProtocols);
138                }
139
140                if (wantClientAuth) {
141                    sslEngine.setWantClientAuth(wantClientAuth);
142                }
143
144                if (needClientAuth) {
145                    sslEngine.setNeedClientAuth(needClientAuth);
146                }
147
148                sslSession = sslEngine.getSession();
149
150                inputBuffer = ByteBuffer.allocate(sslSession.getPacketBufferSize());
151                inputBuffer.clear();
152            }
153
154            outputStream = new NIOOutputStream(channel);
155            outputStream.setEngine(sslEngine);
156            this.dataOut = new DataOutputStream(outputStream);
157            this.buffOut = outputStream;
158
159            //If the sslEngine was not passed in, then handshake
160            if (!hasSslEngine) {
161                sslEngine.beginHandshake();
162            }
163            handshakeStatus = sslEngine.getHandshakeStatus();
164            if (!hasSslEngine) {
165                doHandshake();
166            }
167
168            selection = SelectorManager.getInstance().register(channel, new SelectorManager.Listener() {
169                @Override
170                public void onSelect(SelectorSelection selection) {
171                    try {
172                        initialized.await();
173                    } catch (InterruptedException error) {
174                        onException(IOExceptionSupport.create(error));
175                    }
176                    serviceRead();
177                }
178
179                @Override
180                public void onError(SelectorSelection selection, Throwable error) {
181                    if (error instanceof IOException) {
182                        onException((IOException) error);
183                    } else {
184                        onException(IOExceptionSupport.create(error));
185                    }
186                }
187            });
188            doInit();
189
190        } catch (Exception e) {
191            try {
192                if(outputStream != null) {
193                    outputStream.close();
194                }
195                super.closeStreams();
196            } catch (Exception ex) {}
197            throw new IOException(e);
198        }
199    }
200
201    final protected CountDownLatch initialized = new CountDownLatch(1);
202
203    protected void doInit() throws Exception {
204        taskRunnerFactory.execute(new Runnable() {
205
206            @Override
207            public void run() {
208                //Need to start in new thread to let startup finish first
209                //We can trigger a read because we know the channel is ready since the SSL handshake
210                //already happened
211                serviceRead();
212                initialized.countDown();
213            }
214        });
215    }
216
217    //Only used for the auto transport to abort the openwire init method early if already initialized
218    boolean openWireInititialized = false;
219
220    protected void doOpenWireInit() throws Exception {
221        //Do this later to let wire format negotiation happen
222        if (initBuffer != null && !openWireInititialized && this.wireFormat instanceof OpenWireFormat) {
223            initBuffer.buffer.flip();
224            if (initBuffer.buffer.hasRemaining()) {
225                nextFrameSize = -1;
226                receiveCounter += initBuffer.readSize;
227                processCommand(initBuffer.buffer);
228                processCommand(initBuffer.buffer);
229                initBuffer.buffer.clear();
230                openWireInititialized = true;
231            }
232        }
233    }
234
235    protected void finishHandshake() throws Exception {
236        if (handshakeInProgress) {
237            handshakeInProgress = false;
238            nextFrameSize = -1;
239
240            // Once handshake completes we need to ask for the now real sslSession
241            // otherwise the session would return 'SSL_NULL_WITH_NULL_NULL' for the
242            // cipher suite.
243            sslSession = sslEngine.getSession();
244        }
245    }
246
247    @Override
248    public void serviceRead() {
249        try {
250            if (handshakeInProgress) {
251                doHandshake();
252            }
253
254            doOpenWireInit();
255
256            ByteBuffer plain = ByteBuffer.allocate(sslSession.getApplicationBufferSize());
257            plain.position(plain.limit());
258
259            while (true) {
260                //If the transport was already stopped then break
261                if (this.isStopped()) {
262                    return;
263                }
264
265                if (!plain.hasRemaining()) {
266
267                    int readCount = secureRead(plain);
268
269                    if (readCount == 0) {
270                        break;
271                    }
272
273                    // channel is closed, cleanup
274                    if (readCount == -1) {
275                        onException(new EOFException());
276                        selection.close();
277                        break;
278                    }
279
280                    receiveCounter += readCount;
281                }
282
283                if (status == SSLEngineResult.Status.OK && handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_UNWRAP) {
284                    processCommand(plain);
285                }
286            }
287        } catch (IOException e) {
288            onException(e);
289        } catch (Throwable e) {
290            onException(IOExceptionSupport.create(e));
291        }
292    }
293
294    protected void processCommand(ByteBuffer plain) throws Exception {
295
296        // Are we waiting for the next Command or are we building on the current one
297        if (nextFrameSize == -1) {
298
299            // We can get small packets that don't give us enough for the frame size
300            // so allocate enough for the initial size value and
301            if (plain.remaining() < Integer.SIZE) {
302                if (currentBuffer == null) {
303                    currentBuffer = ByteBuffer.allocate(4);
304                }
305
306                // Go until we fill the integer sized current buffer.
307                while (currentBuffer.hasRemaining() && plain.hasRemaining()) {
308                    currentBuffer.put(plain.get());
309                }
310
311                // Didn't we get enough yet to figure out next frame size.
312                if (currentBuffer.hasRemaining()) {
313                    return;
314                } else {
315                    currentBuffer.flip();
316                    nextFrameSize = currentBuffer.getInt();
317                }
318
319            } else {
320
321                // Either we are completing a previous read of the next frame size or its
322                // fully contained in plain already.
323                if (currentBuffer != null) {
324
325                    // Finish the frame size integer read and get from the current buffer.
326                    while (currentBuffer.hasRemaining()) {
327                        currentBuffer.put(plain.get());
328                    }
329
330                    currentBuffer.flip();
331                    nextFrameSize = currentBuffer.getInt();
332
333                } else {
334                    nextFrameSize = plain.getInt();
335                }
336            }
337
338            if (wireFormat instanceof OpenWireFormat) {
339                OpenWireFormat openWireFormat = (OpenWireFormat) wireFormat;
340                long maxFrameSize = openWireFormat.getMaxFrameSize();
341
342                if (openWireFormat.isMaxFrameSizeEnabled() && nextFrameSize > maxFrameSize) {
343                    throw new MaxFrameSizeExceededException("Frame size of " + (nextFrameSize / (1024 * 1024)) +
344                                          " MB larger than max allowed " + (maxFrameSize / (1024 * 1024)) + " MB");
345                }
346            }
347
348            // now we got the data, lets reallocate and store the size for the marshaler.
349            // if there's more data in plain, then the next call will start processing it.
350            currentBuffer = ByteBuffer.allocate(nextFrameSize + 4);
351            currentBuffer.putInt(nextFrameSize);
352
353        } else {
354            // If its all in one read then we can just take it all, otherwise take only
355            // the current frame size and the next iteration starts a new command.
356            if (currentBuffer != null) {
357                if (currentBuffer.remaining() >= plain.remaining()) {
358                    currentBuffer.put(plain);
359                } else {
360                    byte[] fill = new byte[currentBuffer.remaining()];
361                    plain.get(fill);
362                    currentBuffer.put(fill);
363                }
364
365                // Either we have enough data for a new command or we have to wait for some more.
366                if (currentBuffer.hasRemaining()) {
367                    return;
368                } else {
369                    currentBuffer.flip();
370                    Object command = wireFormat.unmarshal(new DataInputStream(new NIOInputStream(currentBuffer)));
371                    doConsume(command);
372                    nextFrameSize = -1;
373                    currentBuffer = null;
374               }
375            }
376        }
377    }
378
379    //Prevent concurrent access while reading from the channel
380    protected synchronized int secureRead(ByteBuffer plain) throws Exception {
381
382        if (!(inputBuffer.position() != 0 && inputBuffer.hasRemaining()) || status == SSLEngineResult.Status.BUFFER_UNDERFLOW) {
383            int bytesRead = channel.read(inputBuffer);
384
385            if (bytesRead == 0 && !(sslEngine.getHandshakeStatus().equals(SSLEngineResult.HandshakeStatus.NEED_UNWRAP))) {
386                return 0;
387            }
388
389            if (bytesRead == -1) {
390                sslEngine.closeInbound();
391                if (inputBuffer.position() == 0 || status == SSLEngineResult.Status.BUFFER_UNDERFLOW) {
392                    return -1;
393                }
394            }
395        }
396
397        plain.clear();
398
399        inputBuffer.flip();
400        SSLEngineResult res;
401        do {
402            res = sslEngine.unwrap(inputBuffer, plain);
403        } while (res.getStatus() == SSLEngineResult.Status.OK && res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_UNWRAP
404                && res.bytesProduced() == 0);
405
406        if (res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.FINISHED) {
407            finishHandshake();
408        }
409
410        status = res.getStatus();
411        handshakeStatus = res.getHandshakeStatus();
412
413        // TODO deal with BUFFER_OVERFLOW
414
415        if (status == SSLEngineResult.Status.CLOSED) {
416            sslEngine.closeInbound();
417            return -1;
418        }
419
420        inputBuffer.compact();
421        plain.flip();
422
423        return plain.remaining();
424    }
425
426    protected void doHandshake() throws Exception {
427        handshakeInProgress = true;
428        Selector selector = null;
429        SelectionKey key = null;
430        boolean readable = true;
431        try {
432            while (true) {
433                HandshakeStatus handshakeStatus = sslEngine.getHandshakeStatus();
434                switch (handshakeStatus) {
435                    case NEED_UNWRAP:
436                        if (readable) {
437                            secureRead(ByteBuffer.allocate(sslSession.getApplicationBufferSize()));
438                        }
439                        if (this.status == SSLEngineResult.Status.BUFFER_UNDERFLOW) {
440                            long now = System.currentTimeMillis();
441                            if (selector == null) {
442                                selector = Selector.open();
443                                key = channel.register(selector, SelectionKey.OP_READ);
444                            } else {
445                                key.interestOps(SelectionKey.OP_READ);
446                            }
447                            int keyCount = selector.select(this.getSoTimeout());
448                            if (keyCount == 0 && this.getSoTimeout() > 0 && ((System.currentTimeMillis() - now) >= this.getSoTimeout())) {
449                                throw new SocketTimeoutException("Timeout during handshake");
450                            }
451                            readable = key.isReadable();
452                        }
453                        break;
454                    case NEED_TASK:
455                        Runnable task;
456                        while ((task = sslEngine.getDelegatedTask()) != null) {
457                            task.run();
458                        }
459                        break;
460                    case NEED_WRAP:
461                        ((NIOOutputStream) buffOut).write(ByteBuffer.allocate(0));
462                        break;
463                    case FINISHED:
464                    case NOT_HANDSHAKING:
465                        finishHandshake();
466                        return;
467                }
468            }
469        } finally {
470            if (key!=null) try {key.cancel();} catch (Exception ignore) {}
471            if (selector!=null) try {selector.close();} catch (Exception ignore) {}
472        }
473    }
474
475    @Override
476    protected void doStart() throws Exception {
477        taskRunnerFactory = new TaskRunnerFactory("ActiveMQ NIOSSLTransport Task");
478        // no need to init as we can delay that until demand (eg in doHandshake)
479        super.doStart();
480    }
481
482    @Override
483    protected void doStop(ServiceStopper stopper) throws Exception {
484        initialized.countDown();
485
486        if (taskRunnerFactory != null) {
487            taskRunnerFactory.shutdownNow();
488            taskRunnerFactory = null;
489        }
490        if (channel != null) {
491            channel.close();
492            channel = null;
493        }
494        super.doStop(stopper);
495    }
496
497    /**
498     * Overriding in order to add the client's certificates to ConnectionInfo Commands.
499     *
500     * @param command
501     *            The Command coming in.
502     */
503    @Override
504    public void doConsume(Object command) {
505        if (command instanceof ConnectionInfo) {
506            ConnectionInfo connectionInfo = (ConnectionInfo) command;
507            connectionInfo.setTransportContext(getPeerCertificates());
508        }
509        super.doConsume(command);
510    }
511
512    /**
513     * @return peer certificate chain associated with the ssl socket
514     */
515    @Override
516    public X509Certificate[] getPeerCertificates() {
517
518        X509Certificate[] clientCertChain = null;
519        try {
520            if (sslEngine.getSession() != null) {
521                clientCertChain = (X509Certificate[]) sslEngine.getSession().getPeerCertificates();
522            }
523        } catch (SSLPeerUnverifiedException e) {
524            if (LOG.isTraceEnabled()) {
525                LOG.trace("Failed to get peer certificates.", e);
526            }
527        }
528
529        return clientCertChain;
530    }
531
532    public boolean isNeedClientAuth() {
533        return needClientAuth;
534    }
535
536    public void setNeedClientAuth(boolean needClientAuth) {
537        this.needClientAuth = needClientAuth;
538    }
539
540    public boolean isWantClientAuth() {
541        return wantClientAuth;
542    }
543
544    public void setWantClientAuth(boolean wantClientAuth) {
545        this.wantClientAuth = wantClientAuth;
546    }
547
548    public String[] getEnabledCipherSuites() {
549        return enabledCipherSuites;
550    }
551
552    public void setEnabledCipherSuites(String[] enabledCipherSuites) {
553        this.enabledCipherSuites = enabledCipherSuites;
554    }
555
556    public String[] getEnabledProtocols() {
557        return enabledProtocols;
558    }
559
560    public void setEnabledProtocols(String[] enabledProtocols) {
561        this.enabledProtocols = enabledProtocols;
562    }
563
564    public boolean isVerifyHostName() {
565        return verifyHostName;
566    }
567
568    public void setVerifyHostName(boolean verifyHostName) {
569        this.verifyHostName = verifyHostName;
570    }
571}