mirror of https://github.com/OpenIdentityPlatform/OpenDJ.git

Matthew Swift
10.11.2012 8b8e1bc71d1d452998f7d92f3e1ec34e5439c880
Preparation work for OPENDJ-420: Rare SSLExceptions while handling LDAPS connections and big LDAP searches

Reformat and clean up code.
1 files modified
416 ■■■■ changed files
opends/src/server/org/opends/server/extensions/TLSByteChannel.java 416 ●●●● patch | view | raw | blame | history
opends/src/server/org/opends/server/extensions/TLSByteChannel.java
@@ -53,49 +53,86 @@
/**
 * A class that provides a TLS byte channel implementation.
 */
public class TLSByteChannel implements ByteChannel, ConnectionSecurityProvider
public class TLSByteChannel implements ConnectionSecurityProvider
{
  /**
   * Private implementation.
   */
  private final class ByteChannelImpl implements ByteChannel
  {
    /**
     * {@inheritDoc}
     */
    public int read(ByteBuffer dst) throws IOException
    {
      // TODO Auto-generated method stub
      return 0;
    }
    /**
     * {@inheritDoc}
     */
    public boolean isOpen()
    {
      // TODO Auto-generated method stub
      return false;
    }
    /**
     * {@inheritDoc}
     */
    public void close() throws IOException
    {
      // TODO Auto-generated method stub
    }
    /**
     * {@inheritDoc}
     */
    public int write(ByteBuffer src) throws IOException
    {
      // TODO Auto-generated method stub
      return 0;
    }
  }
  private static final DebugTracer TRACER = getTracer();
  private final ByteChannel socketChannel;
  private final SSLEngine sslEngine;
  private final ByteChannelImpl pimpl = new ByteChannelImpl();
  // read copy to buffer
  private final ByteBuffer appData;
  // read encrypted
  private final ByteBuffer appNetData;
  // Write encrypted
  private final ByteBuffer netData;
  private final ByteBuffer tempData;
  private final int sslBufferSize;
  private final int appBufSize;
  private boolean reading = false;
  // Map of cipher phrases to effective key size (bits). Taken from the
  // following RFCs: 5289, 4346, 3268,4132 and 4162.
  private static final Map<String, Integer> cipherMap;
  private static final Map<String, Integer> CIPHER_MAP;
  static
  {
    cipherMap = new LinkedHashMap<String, Integer>();
    cipherMap.put("_WITH_AES_256_CBC_", new Integer(256));
    cipherMap.put("_WITH_CAMELLIA_256_CBC_", new Integer(256));
    cipherMap.put("_WITH_AES_256_GCM_", new Integer(256));
    cipherMap.put("_WITH_3DES_EDE_CBC_", new Integer(112));
    cipherMap.put("_WITH_AES_128_GCM_", new Integer(128));
    cipherMap.put("_WITH_SEED_CBC_", new Integer(128));
    cipherMap.put("_WITH_CAMELLIA_128_CBC_", new Integer(128));
    cipherMap.put("_WITH_AES_128_CBC_", new Integer(128));
    cipherMap.put("_WITH_IDEA_CBC_", new Integer(128));
    cipherMap.put("_WITH_DES_CBC_", new Integer(56));
    cipherMap.put("_WITH_RC2_CBC_40_", new Integer(40));
    cipherMap.put("_WITH_RC4_40_", new Integer(40));
    cipherMap.put("_WITH_DES40_CBC_", new Integer(40));
    cipherMap.put("_WITH_NULL_", new Integer(0));
    CIPHER_MAP = new LinkedHashMap<String, Integer>();
    CIPHER_MAP.put("_WITH_AES_256_CBC_", new Integer(256));
    CIPHER_MAP.put("_WITH_CAMELLIA_256_CBC_", new Integer(256));
    CIPHER_MAP.put("_WITH_AES_256_GCM_", new Integer(256));
    CIPHER_MAP.put("_WITH_3DES_EDE_CBC_", new Integer(112));
    CIPHER_MAP.put("_WITH_AES_128_GCM_", new Integer(128));
    CIPHER_MAP.put("_WITH_SEED_CBC_", new Integer(128));
    CIPHER_MAP.put("_WITH_CAMELLIA_128_CBC_", new Integer(128));
    CIPHER_MAP.put("_WITH_AES_128_CBC_", new Integer(128));
    CIPHER_MAP.put("_WITH_IDEA_CBC_", new Integer(128));
    CIPHER_MAP.put("_WITH_DES_CBC_", new Integer(56));
    CIPHER_MAP.put("_WITH_RC2_CBC_40_", new Integer(40));
    CIPHER_MAP.put("_WITH_RC4_40_", new Integer(40));
    CIPHER_MAP.put("_WITH_DES40_CBC_", new Integer(40));
    CIPHER_MAP.put("_WITH_NULL_", new Integer(0));
  };
@@ -137,9 +174,11 @@
    // avoid blocking new connections. Just remove for now to prevent
    // potential DoS attacks. SSL sessions will not be reused and some
    // cipher suites (such as Kerberos) will not work.
    // String hostName = socketChannel.socket().getInetAddress().getHostName();
    // int port = socketChannel.socket().getPort();
    // sslEngine = sslContext.createSSLEngine(hostName, port);
    sslEngine = sslContext.createSSLEngine();
    sslEngine.setUseClientMode(false);
    final Set<String> protocols = config.getSSLProtocol();
@@ -170,44 +209,6 @@
      sslEngine.setWantClientAuth(true);
      break;
    }
    final SSLSession sslSession = sslEngine.getSession();
    sslBufferSize = sslSession.getPacketBufferSize();
    appBufSize = sslSession.getApplicationBufferSize();
    appNetData = ByteBuffer.allocate(sslBufferSize);
    netData = ByteBuffer.allocate(sslBufferSize);
    appData = ByteBuffer.allocate(sslSession.getApplicationBufferSize());
    tempData = ByteBuffer.allocate(sslSession.getApplicationBufferSize());
  }
  /**
   * {@inheritDoc}
   */
  public synchronized void close() throws IOException
  {
    sslEngine.closeInbound();
    sslEngine.closeOutbound();
    final SSLEngineResult.HandshakeStatus hsStatus = sslEngine
        .getHandshakeStatus();
    if (hsStatus != SSLEngineResult.HandshakeStatus.FINISHED
        && hsStatus != SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING)
    {
      doHandshakeWrite(hsStatus);
    }
  }
  /**
   * {@inheritDoc}
   */
  public int getAppBufSize()
  {
    return appBufSize;
  }
@@ -250,7 +251,7 @@
  {
    int cipherKeySSF = 0;
    final String cipherString = sslEngine.getSession().getCipherSuite();
    for (final Map.Entry<String, Integer> mapEntry : cipherMap.entrySet())
    for (final Map.Entry<String, Integer> mapEntry : CIPHER_MAP.entrySet())
    {
      if (cipherString.indexOf(mapEntry.getKey()) >= 0)
      {
@@ -266,20 +267,6 @@
  /**
   * {@inheritDoc}
   */
  public boolean isOpen()
  {
    if (sslEngine.isInboundDone() || sslEngine.isOutboundDone())
    {
      return false;
    }
    return true;
  }
  /**
   * {@inheritDoc}
   */
  public boolean isSecure()
  {
    return true;
@@ -290,264 +277,9 @@
  /**
   * {@inheritDoc}
   */
  public synchronized int read(final ByteBuffer clearBuffer) throws IOException
  {
    SSLEngineResult.HandshakeStatus hsStatus;
    if (!reading)
    {
      appNetData.clear();
    }
    else
    {
      reading = false;
    }
    if (!socketChannel.isOpen())
    {
      return -1;
    }
    if (sslEngine.isInboundDone())
    {
      return -1;
    }
    do
    {
      final int wrappedBytes = socketChannel.read(appNetData);
      appNetData.flip();
      if (wrappedBytes == -1)
      {
        return -1;
      }
      hsStatus = sslEngine.getHandshakeStatus();
      if (hsStatus == SSLEngineResult.HandshakeStatus.NEED_TASK
          || hsStatus == SSLEngineResult.HandshakeStatus.NEED_WRAP)
      {
        doHandshakeRead(hsStatus);
      }
      if (wrappedBytes == 0)
      {
        return 0;
      }
      while (appNetData.hasRemaining())
      {
        appData.clear();
        final SSLEngineResult res = sslEngine.unwrap(appNetData, appData);
        appData.flip();
        if (res.getStatus() == SSLEngineResult.Status.BUFFER_UNDERFLOW)
        {
          appNetData.compact();
          reading = true;
          break;
        }
        else if (res.getStatus() != SSLEngineResult.Status.OK)
        {
          return -1;
        }
        hsStatus = sslEngine.getHandshakeStatus();
        if (hsStatus == SSLEngineResult.HandshakeStatus.NEED_TASK
            || hsStatus == SSLEngineResult.HandshakeStatus.NEED_WRAP)
        {
          doHandshakeOp(hsStatus);
        }
        clearBuffer.put(appData);
      }
      hsStatus = sslEngine.getHandshakeStatus();
    }
    while (hsStatus == SSLEngineResult.HandshakeStatus.NEED_TASK
        || hsStatus == SSLEngineResult.HandshakeStatus.NEED_WRAP);
    return clearBuffer.position();
  }
  /**
   * {@inheritDoc}
   */
  public ByteChannel wrapChannel(final ByteChannel channel)
  {
    return this;
    return pimpl;
  }
  /**
   * {@inheritDoc}
   */
  public synchronized int write(final ByteBuffer clearData) throws IOException
  {
    if (!socketChannel.isOpen() || sslEngine.isOutboundDone())
    {
      throw new ClosedChannelException();
    }
    final int originalPosition = clearData.position();
    final int originalLimit = clearData.limit();
    final int length = originalLimit - originalPosition;
    if (length > sslBufferSize)
    {
      int pos = originalPosition;
      int lim = originalPosition + sslBufferSize;
      while (pos < originalLimit)
      {
        clearData.position(pos);
        clearData.limit(lim);
        writeInternal(clearData);
        pos = lim;
        lim = Math.min(originalLimit, pos + sslBufferSize);
      }
      return length;
    }
    else
    {
      return writeInternal(clearData);
    }
  }
  private void doHandshakeOp(SSLEngineResult.HandshakeStatus hsStatus)
      throws IOException
  {
    SSLEngineResult res;
    switch (hsStatus)
    {
    case NEED_TASK:
      hsStatus = doTasks();
      break;
    case NEED_WRAP:
      tempData.clear();
      netData.clear();
      res = sslEngine.wrap(tempData, netData);
      hsStatus = res.getHandshakeStatus();
      netData.flip();
      while (netData.hasRemaining())
      {
        socketChannel.write(netData);
      }
      hsStatus = sslEngine.getHandshakeStatus();
      return;
    default:
      return;
    }
  }
  private void doHandshakeRead(SSLEngineResult.HandshakeStatus hsStatus)
      throws IOException
  {
    do
    {
      doHandshakeOp(hsStatus);
      hsStatus = sslEngine.getHandshakeStatus();
    }
    while (hsStatus == SSLEngineResult.HandshakeStatus.NEED_WRAP
        || hsStatus == SSLEngineResult.HandshakeStatus.NEED_TASK);
  }
  private void doHandshakeUnwrap() throws IOException
  {
    netData.clear();
    tempData.clear();
    final int bytesRead = socketChannel.read(netData);
    if (bytesRead <= 0)
    {
      throw new ClosedChannelException();
    }
    else
    {
      sslEngine.unwrap(netData, tempData);
    }
  }
  private void doHandshakeWrite(SSLEngineResult.HandshakeStatus hsStatus)
      throws IOException
  {
    do
    {
      if (hsStatus == SSLEngineResult.HandshakeStatus.NEED_UNWRAP)
      {
        doHandshakeUnwrap();
      }
      else
      {
        doHandshakeOp(hsStatus);
      }
      hsStatus = sslEngine.getHandshakeStatus();
    }
    while (hsStatus == SSLEngineResult.HandshakeStatus.NEED_WRAP
        || hsStatus == SSLEngineResult.HandshakeStatus.NEED_TASK
        || hsStatus == SSLEngineResult.HandshakeStatus.NEED_UNWRAP);
  }
  private SSLEngineResult.HandshakeStatus doTasks()
  {
    Runnable task;
    while ((task = sslEngine.getDelegatedTask()) != null)
    {
      task.run();
    }
    return sslEngine.getHandshakeStatus();
  }
  private int writeInternal(final ByteBuffer clearData) throws IOException
  {
    int totBytesSent = 0;
    SSLEngineResult.HandshakeStatus hsStatus;
    hsStatus = sslEngine.getHandshakeStatus();
    if (hsStatus == SSLEngineResult.HandshakeStatus.NEED_TASK
        || hsStatus == SSLEngineResult.HandshakeStatus.NEED_WRAP
        || hsStatus == SSLEngineResult.HandshakeStatus.NEED_UNWRAP)
    {
      doHandshakeWrite(hsStatus);
    }
    while (clearData.hasRemaining())
    {
      netData.clear();
      final SSLEngineResult res = sslEngine.wrap(clearData, netData);
      netData.flip();
      if (netData.remaining() == 0)
      {
        // wrap didn't produce any data from our clear buffer.
        // Throw exception to prevent looping.
        throw new SSLException("SSLEngine.wrap produced 0 bytes");
      }
      if (res.getStatus() != SSLEngineResult.Status.OK)
      {
        throw new ClosedChannelException();
      }
      if (hsStatus == SSLEngineResult.HandshakeStatus.NEED_TASK
          || hsStatus == SSLEngineResult.HandshakeStatus.NEED_WRAP
          || hsStatus == SSLEngineResult.HandshakeStatus.NEED_UNWRAP)
      {
        doHandshakeWrite(hsStatus);
      }
      totBytesSent += socketChannel.write(netData);
    }
    return totBytesSent;
  }
}