001/**
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.activemq.transport.nio;
019
020import java.io.DataInputStream;
021import java.io.DataOutputStream;
022import java.io.EOFException;
023import java.io.IOException;
024import java.net.Socket;
025import java.net.URI;
026import java.net.UnknownHostException;
027import java.nio.ByteBuffer;
028import java.security.cert.X509Certificate;
029
030import javax.net.SocketFactory;
031import javax.net.ssl.SSLContext;
032import javax.net.ssl.SSLEngine;
033import javax.net.ssl.SSLEngineResult;
034import javax.net.ssl.SSLPeerUnverifiedException;
035import javax.net.ssl.SSLSession;
036
037import org.apache.activemq.command.ConnectionInfo;
038import org.apache.activemq.openwire.OpenWireFormat;
039import org.apache.activemq.thread.TaskRunnerFactory;
040import org.apache.activemq.util.IOExceptionSupport;
041import org.apache.activemq.util.ServiceStopper;
042import org.apache.activemq.wireformat.WireFormat;
043import org.slf4j.Logger;
044import org.slf4j.LoggerFactory;
045
046public class NIOSSLTransport extends NIOTransport {
047
048    private static final Logger LOG = LoggerFactory.getLogger(NIOSSLTransport.class);
049
050    protected boolean needClientAuth;
051    protected boolean wantClientAuth;
052    protected String[] enabledCipherSuites;
053    protected String[] enabledProtocols;
054
055    protected SSLContext sslContext;
056    protected SSLEngine sslEngine;
057    protected SSLSession sslSession;
058
059    protected volatile boolean handshakeInProgress = false;
060    protected SSLEngineResult.Status status = null;
061    protected SSLEngineResult.HandshakeStatus handshakeStatus = null;
062    protected TaskRunnerFactory taskRunnerFactory;
063
064    public NIOSSLTransport(WireFormat wireFormat, SocketFactory socketFactory, URI remoteLocation, URI localLocation) throws UnknownHostException, IOException {
065        super(wireFormat, socketFactory, remoteLocation, localLocation);
066    }
067
068    public NIOSSLTransport(WireFormat wireFormat, Socket socket) throws IOException {
069        super(wireFormat, socket);
070    }
071
072    public void setSslContext(SSLContext sslContext) {
073        this.sslContext = sslContext;
074    }
075
076    @Override
077    protected void initializeStreams() throws IOException {
078        NIOOutputStream outputStream = null;
079        try {
080            channel = socket.getChannel();
081            channel.configureBlocking(false);
082
083            if (sslContext == null) {
084                sslContext = SSLContext.getDefault();
085            }
086
087            String remoteHost = null;
088            int remotePort = -1;
089
090            try {
091                URI remoteAddress = new URI(this.getRemoteAddress());
092                remoteHost = remoteAddress.getHost();
093                remotePort = remoteAddress.getPort();
094            } catch (Exception e) {
095            }
096
097            // initialize engine, the initial sslSession we get will need to be
098            // updated once the ssl handshake process is completed.
099            if (remoteHost != null && remotePort != -1) {
100                sslEngine = sslContext.createSSLEngine(remoteHost, remotePort);
101            } else {
102                sslEngine = sslContext.createSSLEngine();
103            }
104
105            sslEngine.setUseClientMode(false);
106            if (enabledCipherSuites != null) {
107                sslEngine.setEnabledCipherSuites(enabledCipherSuites);
108            }
109
110            if (enabledProtocols != null) {
111                sslEngine.setEnabledProtocols(enabledProtocols);
112            }
113
114            if (wantClientAuth) {
115                sslEngine.setWantClientAuth(wantClientAuth);
116            }
117
118            if (needClientAuth) {
119                sslEngine.setNeedClientAuth(needClientAuth);
120            }
121
122            sslSession = sslEngine.getSession();
123
124            inputBuffer = ByteBuffer.allocate(sslSession.getPacketBufferSize());
125            inputBuffer.clear();
126
127            outputStream = new NIOOutputStream(channel);
128            outputStream.setEngine(sslEngine);
129            this.dataOut = new DataOutputStream(outputStream);
130            this.buffOut = outputStream;
131            sslEngine.beginHandshake();
132            handshakeStatus = sslEngine.getHandshakeStatus();
133            doHandshake();
134        } catch (Exception e) {
135            try {
136                if(outputStream != null) {
137                    outputStream.close();
138                }
139                super.closeStreams();
140            } catch (Exception ex) {}
141            throw new IOException(e);
142        }
143    }
144
145    protected void finishHandshake() throws Exception {
146        if (handshakeInProgress) {
147            handshakeInProgress = false;
148            nextFrameSize = -1;
149
150            // Once handshake completes we need to ask for the now real sslSession
151            // otherwise the session would return 'SSL_NULL_WITH_NULL_NULL' for the
152            // cipher suite.
153            sslSession = sslEngine.getSession();
154
155            // listen for events telling us when the socket is readable.
156            selection = SelectorManager.getInstance().register(channel, new SelectorManager.Listener() {
157                @Override
158                public void onSelect(SelectorSelection selection) {
159                    serviceRead();
160                }
161
162                @Override
163                public void onError(SelectorSelection selection, Throwable error) {
164                    if (error instanceof IOException) {
165                        onException((IOException) error);
166                    } else {
167                        onException(IOExceptionSupport.create(error));
168                    }
169                }
170            });
171        }
172    }
173
174    @Override
175    protected void serviceRead() {
176        try {
177            if (handshakeInProgress) {
178                doHandshake();
179            }
180
181            ByteBuffer plain = ByteBuffer.allocate(sslSession.getApplicationBufferSize());
182            plain.position(plain.limit());
183
184            while (true) {
185                if (!plain.hasRemaining()) {
186
187                    int readCount = secureRead(plain);
188
189                    if (readCount == 0) {
190                        break;
191                    }
192
193                    // channel is closed, cleanup
194                    if (readCount == -1) {
195                        onException(new EOFException());
196                        selection.close();
197                        break;
198                    }
199
200                    receiveCounter += readCount;
201                }
202
203                if (status == SSLEngineResult.Status.OK && handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_UNWRAP) {
204                    processCommand(plain);
205                }
206            }
207        } catch (IOException e) {
208            onException(e);
209        } catch (Throwable e) {
210            onException(IOExceptionSupport.create(e));
211        }
212    }
213
214    protected void processCommand(ByteBuffer plain) throws Exception {
215
216        // Are we waiting for the next Command or are we building on the current one
217        if (nextFrameSize == -1) {
218
219            // We can get small packets that don't give us enough for the frame size
220            // so allocate enough for the initial size value and
221            if (plain.remaining() < Integer.SIZE) {
222                if (currentBuffer == null) {
223                    currentBuffer = ByteBuffer.allocate(4);
224                }
225
226                // Go until we fill the integer sized current buffer.
227                while (currentBuffer.hasRemaining() && plain.hasRemaining()) {
228                    currentBuffer.put(plain.get());
229                }
230
231                // Didn't we get enough yet to figure out next frame size.
232                if (currentBuffer.hasRemaining()) {
233                    return;
234                } else {
235                    currentBuffer.flip();
236                    nextFrameSize = currentBuffer.getInt();
237                }
238
239            } else {
240
241                // Either we are completing a previous read of the next frame size or its
242                // fully contained in plain already.
243                if (currentBuffer != null) {
244
245                    // Finish the frame size integer read and get from the current buffer.
246                    while (currentBuffer.hasRemaining()) {
247                        currentBuffer.put(plain.get());
248                    }
249
250                    currentBuffer.flip();
251                    nextFrameSize = currentBuffer.getInt();
252
253                } else {
254                    nextFrameSize = plain.getInt();
255                }
256            }
257
258            if (wireFormat instanceof OpenWireFormat) {
259                long maxFrameSize = ((OpenWireFormat) wireFormat).getMaxFrameSize();
260                if (nextFrameSize > maxFrameSize) {
261                    throw new IOException("Frame size of " + (nextFrameSize / (1024 * 1024)) +
262                                          " MB larger than max allowed " + (maxFrameSize / (1024 * 1024)) + " MB");
263                }
264            }
265
266            // now we got the data, lets reallocate and store the size for the marshaler.
267            // if there's more data in plain, then the next call will start processing it.
268            currentBuffer = ByteBuffer.allocate(nextFrameSize + 4);
269            currentBuffer.putInt(nextFrameSize);
270
271        } else {
272
273            // If its all in one read then we can just take it all, otherwise take only
274            // the current frame size and the next iteration starts a new command.
275            if (currentBuffer.remaining() >= plain.remaining()) {
276                currentBuffer.put(plain);
277            } else {
278                byte[] fill = new byte[currentBuffer.remaining()];
279                plain.get(fill);
280                currentBuffer.put(fill);
281            }
282
283            // Either we have enough data for a new command or we have to wait for some more.
284            if (currentBuffer.hasRemaining()) {
285                return;
286            } else {
287                currentBuffer.flip();
288                Object command = wireFormat.unmarshal(new DataInputStream(new NIOInputStream(currentBuffer)));
289                doConsume(command);
290                nextFrameSize = -1;
291                currentBuffer = null;
292            }
293        }
294    }
295
296    protected int secureRead(ByteBuffer plain) throws Exception {
297
298        if (!(inputBuffer.position() != 0 && inputBuffer.hasRemaining()) || status == SSLEngineResult.Status.BUFFER_UNDERFLOW) {
299            int bytesRead = channel.read(inputBuffer);
300
301            if (bytesRead == 0) {
302                return 0;
303            }
304
305            if (bytesRead == -1) {
306                sslEngine.closeInbound();
307                if (inputBuffer.position() == 0 || status == SSLEngineResult.Status.BUFFER_UNDERFLOW) {
308                    return -1;
309                }
310            }
311        }
312
313        plain.clear();
314
315        inputBuffer.flip();
316        SSLEngineResult res;
317        do {
318            res = sslEngine.unwrap(inputBuffer, plain);
319        } while (res.getStatus() == SSLEngineResult.Status.OK && res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_UNWRAP
320                && res.bytesProduced() == 0);
321
322        if (res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.FINISHED) {
323            finishHandshake();
324        }
325
326        status = res.getStatus();
327        handshakeStatus = res.getHandshakeStatus();
328
329        // TODO deal with BUFFER_OVERFLOW
330
331        if (status == SSLEngineResult.Status.CLOSED) {
332            sslEngine.closeInbound();
333            return -1;
334        }
335
336        inputBuffer.compact();
337        plain.flip();
338
339        return plain.remaining();
340    }
341
342    protected void doHandshake() throws Exception {
343        handshakeInProgress = true;
344        while (true) {
345            switch (sslEngine.getHandshakeStatus()) {
346            case NEED_UNWRAP:
347                secureRead(ByteBuffer.allocate(sslSession.getApplicationBufferSize()));
348                break;
349            case NEED_TASK:
350                Runnable task;
351                while ((task = sslEngine.getDelegatedTask()) != null) {
352                    taskRunnerFactory.execute(task);
353                }
354                break;
355            case NEED_WRAP:
356                ((NIOOutputStream) buffOut).write(ByteBuffer.allocate(0));
357                break;
358            case FINISHED:
359            case NOT_HANDSHAKING:
360                finishHandshake();
361                return;
362            }
363        }
364    }
365
366    @Override
367    protected void doStart() throws Exception {
368        taskRunnerFactory = new TaskRunnerFactory("ActiveMQ NIOSSLTransport Task");
369        // no need to init as we can delay that until demand (eg in doHandshake)
370        super.doStart();
371    }
372
373    @Override
374    protected void doStop(ServiceStopper stopper) throws Exception {
375        if (taskRunnerFactory != null) {
376            taskRunnerFactory.shutdownNow();
377            taskRunnerFactory = null;
378        }
379        if (channel != null) {
380            channel.close();
381            channel = null;
382        }
383        super.doStop(stopper);
384    }
385
386    /**
387     * Overriding in order to add the client's certificates to ConnectionInfo Commands.
388     *
389     * @param command
390     *            The Command coming in.
391     */
392    @Override
393    public void doConsume(Object command) {
394        if (command instanceof ConnectionInfo) {
395            ConnectionInfo connectionInfo = (ConnectionInfo) command;
396            connectionInfo.setTransportContext(getPeerCertificates());
397        }
398        super.doConsume(command);
399    }
400
401    /**
402     * @return peer certificate chain associated with the ssl socket
403     */
404    public X509Certificate[] getPeerCertificates() {
405
406        X509Certificate[] clientCertChain = null;
407        try {
408            if (sslEngine.getSession() != null) {
409                clientCertChain = (X509Certificate[]) sslEngine.getSession().getPeerCertificates();
410            }
411        } catch (SSLPeerUnverifiedException e) {
412            if (LOG.isTraceEnabled()) {
413                LOG.trace("Failed to get peer certificates.", e);
414            }
415        }
416
417        return clientCertChain;
418    }
419
420    public boolean isNeedClientAuth() {
421        return needClientAuth;
422    }
423
424    public void setNeedClientAuth(boolean needClientAuth) {
425        this.needClientAuth = needClientAuth;
426    }
427
428    public boolean isWantClientAuth() {
429        return wantClientAuth;
430    }
431
432    public void setWantClientAuth(boolean wantClientAuth) {
433        this.wantClientAuth = wantClientAuth;
434    }
435
436    public String[] getEnabledCipherSuites() {
437        return enabledCipherSuites;
438    }
439
440    public void setEnabledCipherSuites(String[] enabledCipherSuites) {
441        this.enabledCipherSuites = enabledCipherSuites;
442    }
443
444    public String[] getEnabledProtocols() {
445        return enabledProtocols;
446    }
447
448    public void setEnabledProtocols(String[] enabledProtocols) {
449        this.enabledProtocols = enabledProtocols;
450    }
451}