001/** 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017 018package org.apache.activemq.transport.nio; 019 020import java.io.DataInputStream; 021import java.io.DataOutputStream; 022import java.io.EOFException; 023import java.io.IOException; 024import java.net.Socket; 025import java.net.URI; 026import java.net.UnknownHostException; 027import java.nio.ByteBuffer; 028import java.security.cert.X509Certificate; 029 030import javax.net.SocketFactory; 031import javax.net.ssl.SSLContext; 032import javax.net.ssl.SSLEngine; 033import javax.net.ssl.SSLEngineResult; 034import javax.net.ssl.SSLPeerUnverifiedException; 035import javax.net.ssl.SSLSession; 036 037import org.apache.activemq.command.ConnectionInfo; 038import org.apache.activemq.openwire.OpenWireFormat; 039import org.apache.activemq.thread.TaskRunnerFactory; 040import org.apache.activemq.util.IOExceptionSupport; 041import org.apache.activemq.util.ServiceStopper; 042import org.apache.activemq.wireformat.WireFormat; 043import org.slf4j.Logger; 044import org.slf4j.LoggerFactory; 045 046public class NIOSSLTransport extends NIOTransport { 047 048 private static final Logger LOG = LoggerFactory.getLogger(NIOSSLTransport.class); 049 050 protected boolean needClientAuth; 051 protected boolean wantClientAuth; 052 protected String[] enabledCipherSuites; 053 protected String[] enabledProtocols; 054 055 protected SSLContext sslContext; 056 protected SSLEngine sslEngine; 057 protected SSLSession sslSession; 058 059 protected volatile boolean handshakeInProgress = false; 060 protected SSLEngineResult.Status status = null; 061 protected SSLEngineResult.HandshakeStatus handshakeStatus = null; 062 protected TaskRunnerFactory taskRunnerFactory; 063 064 public NIOSSLTransport(WireFormat wireFormat, SocketFactory socketFactory, URI remoteLocation, URI localLocation) throws UnknownHostException, IOException { 065 super(wireFormat, socketFactory, remoteLocation, localLocation); 066 } 067 068 public NIOSSLTransport(WireFormat wireFormat, Socket socket) throws IOException { 069 super(wireFormat, socket); 070 } 071 072 public void setSslContext(SSLContext sslContext) { 073 this.sslContext = sslContext; 074 } 075 076 @Override 077 protected void initializeStreams() throws IOException { 078 NIOOutputStream outputStream = null; 079 try { 080 channel = socket.getChannel(); 081 channel.configureBlocking(false); 082 083 if (sslContext == null) { 084 sslContext = SSLContext.getDefault(); 085 } 086 087 String remoteHost = null; 088 int remotePort = -1; 089 090 try { 091 URI remoteAddress = new URI(this.getRemoteAddress()); 092 remoteHost = remoteAddress.getHost(); 093 remotePort = remoteAddress.getPort(); 094 } catch (Exception e) { 095 } 096 097 // initialize engine, the initial sslSession we get will need to be 098 // updated once the ssl handshake process is completed. 099 if (remoteHost != null && remotePort != -1) { 100 sslEngine = sslContext.createSSLEngine(remoteHost, remotePort); 101 } else { 102 sslEngine = sslContext.createSSLEngine(); 103 } 104 105 sslEngine.setUseClientMode(false); 106 if (enabledCipherSuites != null) { 107 sslEngine.setEnabledCipherSuites(enabledCipherSuites); 108 } 109 110 if (enabledProtocols != null) { 111 sslEngine.setEnabledProtocols(enabledProtocols); 112 } 113 114 if (wantClientAuth) { 115 sslEngine.setWantClientAuth(wantClientAuth); 116 } 117 118 if (needClientAuth) { 119 sslEngine.setNeedClientAuth(needClientAuth); 120 } 121 122 sslSession = sslEngine.getSession(); 123 124 inputBuffer = ByteBuffer.allocate(sslSession.getPacketBufferSize()); 125 inputBuffer.clear(); 126 127 outputStream = new NIOOutputStream(channel); 128 outputStream.setEngine(sslEngine); 129 this.dataOut = new DataOutputStream(outputStream); 130 this.buffOut = outputStream; 131 sslEngine.beginHandshake(); 132 handshakeStatus = sslEngine.getHandshakeStatus(); 133 doHandshake(); 134 } catch (Exception e) { 135 try { 136 if(outputStream != null) { 137 outputStream.close(); 138 } 139 super.closeStreams(); 140 } catch (Exception ex) {} 141 throw new IOException(e); 142 } 143 } 144 145 protected void finishHandshake() throws Exception { 146 if (handshakeInProgress) { 147 handshakeInProgress = false; 148 nextFrameSize = -1; 149 150 // Once handshake completes we need to ask for the now real sslSession 151 // otherwise the session would return 'SSL_NULL_WITH_NULL_NULL' for the 152 // cipher suite. 153 sslSession = sslEngine.getSession(); 154 155 // listen for events telling us when the socket is readable. 156 selection = SelectorManager.getInstance().register(channel, new SelectorManager.Listener() { 157 @Override 158 public void onSelect(SelectorSelection selection) { 159 serviceRead(); 160 } 161 162 @Override 163 public void onError(SelectorSelection selection, Throwable error) { 164 if (error instanceof IOException) { 165 onException((IOException) error); 166 } else { 167 onException(IOExceptionSupport.create(error)); 168 } 169 } 170 }); 171 } 172 } 173 174 @Override 175 protected void serviceRead() { 176 try { 177 if (handshakeInProgress) { 178 doHandshake(); 179 } 180 181 ByteBuffer plain = ByteBuffer.allocate(sslSession.getApplicationBufferSize()); 182 plain.position(plain.limit()); 183 184 while (true) { 185 if (!plain.hasRemaining()) { 186 187 int readCount = secureRead(plain); 188 189 if (readCount == 0) { 190 break; 191 } 192 193 // channel is closed, cleanup 194 if (readCount == -1) { 195 onException(new EOFException()); 196 selection.close(); 197 break; 198 } 199 200 receiveCounter += readCount; 201 } 202 203 if (status == SSLEngineResult.Status.OK && handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_UNWRAP) { 204 processCommand(plain); 205 } 206 } 207 } catch (IOException e) { 208 onException(e); 209 } catch (Throwable e) { 210 onException(IOExceptionSupport.create(e)); 211 } 212 } 213 214 protected void processCommand(ByteBuffer plain) throws Exception { 215 216 // Are we waiting for the next Command or are we building on the current one 217 if (nextFrameSize == -1) { 218 219 // We can get small packets that don't give us enough for the frame size 220 // so allocate enough for the initial size value and 221 if (plain.remaining() < Integer.SIZE) { 222 if (currentBuffer == null) { 223 currentBuffer = ByteBuffer.allocate(4); 224 } 225 226 // Go until we fill the integer sized current buffer. 227 while (currentBuffer.hasRemaining() && plain.hasRemaining()) { 228 currentBuffer.put(plain.get()); 229 } 230 231 // Didn't we get enough yet to figure out next frame size. 232 if (currentBuffer.hasRemaining()) { 233 return; 234 } else { 235 currentBuffer.flip(); 236 nextFrameSize = currentBuffer.getInt(); 237 } 238 239 } else { 240 241 // Either we are completing a previous read of the next frame size or its 242 // fully contained in plain already. 243 if (currentBuffer != null) { 244 245 // Finish the frame size integer read and get from the current buffer. 246 while (currentBuffer.hasRemaining()) { 247 currentBuffer.put(plain.get()); 248 } 249 250 currentBuffer.flip(); 251 nextFrameSize = currentBuffer.getInt(); 252 253 } else { 254 nextFrameSize = plain.getInt(); 255 } 256 } 257 258 if (wireFormat instanceof OpenWireFormat) { 259 long maxFrameSize = ((OpenWireFormat) wireFormat).getMaxFrameSize(); 260 if (nextFrameSize > maxFrameSize) { 261 throw new IOException("Frame size of " + (nextFrameSize / (1024 * 1024)) + 262 " MB larger than max allowed " + (maxFrameSize / (1024 * 1024)) + " MB"); 263 } 264 } 265 266 // now we got the data, lets reallocate and store the size for the marshaler. 267 // if there's more data in plain, then the next call will start processing it. 268 currentBuffer = ByteBuffer.allocate(nextFrameSize + 4); 269 currentBuffer.putInt(nextFrameSize); 270 271 } else { 272 273 // If its all in one read then we can just take it all, otherwise take only 274 // the current frame size and the next iteration starts a new command. 275 if (currentBuffer.remaining() >= plain.remaining()) { 276 currentBuffer.put(plain); 277 } else { 278 byte[] fill = new byte[currentBuffer.remaining()]; 279 plain.get(fill); 280 currentBuffer.put(fill); 281 } 282 283 // Either we have enough data for a new command or we have to wait for some more. 284 if (currentBuffer.hasRemaining()) { 285 return; 286 } else { 287 currentBuffer.flip(); 288 Object command = wireFormat.unmarshal(new DataInputStream(new NIOInputStream(currentBuffer))); 289 doConsume(command); 290 nextFrameSize = -1; 291 currentBuffer = null; 292 } 293 } 294 } 295 296 protected int secureRead(ByteBuffer plain) throws Exception { 297 298 if (!(inputBuffer.position() != 0 && inputBuffer.hasRemaining()) || status == SSLEngineResult.Status.BUFFER_UNDERFLOW) { 299 int bytesRead = channel.read(inputBuffer); 300 301 if (bytesRead == 0) { 302 return 0; 303 } 304 305 if (bytesRead == -1) { 306 sslEngine.closeInbound(); 307 if (inputBuffer.position() == 0 || status == SSLEngineResult.Status.BUFFER_UNDERFLOW) { 308 return -1; 309 } 310 } 311 } 312 313 plain.clear(); 314 315 inputBuffer.flip(); 316 SSLEngineResult res; 317 do { 318 res = sslEngine.unwrap(inputBuffer, plain); 319 } while (res.getStatus() == SSLEngineResult.Status.OK && res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_UNWRAP 320 && res.bytesProduced() == 0); 321 322 if (res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.FINISHED) { 323 finishHandshake(); 324 } 325 326 status = res.getStatus(); 327 handshakeStatus = res.getHandshakeStatus(); 328 329 // TODO deal with BUFFER_OVERFLOW 330 331 if (status == SSLEngineResult.Status.CLOSED) { 332 sslEngine.closeInbound(); 333 return -1; 334 } 335 336 inputBuffer.compact(); 337 plain.flip(); 338 339 return plain.remaining(); 340 } 341 342 protected void doHandshake() throws Exception { 343 handshakeInProgress = true; 344 while (true) { 345 switch (sslEngine.getHandshakeStatus()) { 346 case NEED_UNWRAP: 347 secureRead(ByteBuffer.allocate(sslSession.getApplicationBufferSize())); 348 break; 349 case NEED_TASK: 350 Runnable task; 351 while ((task = sslEngine.getDelegatedTask()) != null) { 352 taskRunnerFactory.execute(task); 353 } 354 break; 355 case NEED_WRAP: 356 ((NIOOutputStream) buffOut).write(ByteBuffer.allocate(0)); 357 break; 358 case FINISHED: 359 case NOT_HANDSHAKING: 360 finishHandshake(); 361 return; 362 } 363 } 364 } 365 366 @Override 367 protected void doStart() throws Exception { 368 taskRunnerFactory = new TaskRunnerFactory("ActiveMQ NIOSSLTransport Task"); 369 // no need to init as we can delay that until demand (eg in doHandshake) 370 super.doStart(); 371 } 372 373 @Override 374 protected void doStop(ServiceStopper stopper) throws Exception { 375 if (taskRunnerFactory != null) { 376 taskRunnerFactory.shutdownNow(); 377 taskRunnerFactory = null; 378 } 379 if (channel != null) { 380 channel.close(); 381 channel = null; 382 } 383 super.doStop(stopper); 384 } 385 386 /** 387 * Overriding in order to add the client's certificates to ConnectionInfo Commands. 388 * 389 * @param command 390 * The Command coming in. 391 */ 392 @Override 393 public void doConsume(Object command) { 394 if (command instanceof ConnectionInfo) { 395 ConnectionInfo connectionInfo = (ConnectionInfo) command; 396 connectionInfo.setTransportContext(getPeerCertificates()); 397 } 398 super.doConsume(command); 399 } 400 401 /** 402 * @return peer certificate chain associated with the ssl socket 403 */ 404 public X509Certificate[] getPeerCertificates() { 405 406 X509Certificate[] clientCertChain = null; 407 try { 408 if (sslEngine.getSession() != null) { 409 clientCertChain = (X509Certificate[]) sslEngine.getSession().getPeerCertificates(); 410 } 411 } catch (SSLPeerUnverifiedException e) { 412 if (LOG.isTraceEnabled()) { 413 LOG.trace("Failed to get peer certificates.", e); 414 } 415 } 416 417 return clientCertChain; 418 } 419 420 public boolean isNeedClientAuth() { 421 return needClientAuth; 422 } 423 424 public void setNeedClientAuth(boolean needClientAuth) { 425 this.needClientAuth = needClientAuth; 426 } 427 428 public boolean isWantClientAuth() { 429 return wantClientAuth; 430 } 431 432 public void setWantClientAuth(boolean wantClientAuth) { 433 this.wantClientAuth = wantClientAuth; 434 } 435 436 public String[] getEnabledCipherSuites() { 437 return enabledCipherSuites; 438 } 439 440 public void setEnabledCipherSuites(String[] enabledCipherSuites) { 441 this.enabledCipherSuites = enabledCipherSuites; 442 } 443 444 public String[] getEnabledProtocols() { 445 return enabledProtocols; 446 } 447 448 public void setEnabledProtocols(String[] enabledProtocols) { 449 this.enabledProtocols = enabledProtocols; 450 } 451}