From f3f885952ca84422e710758abaf64c65705195a0 Mon Sep 17 00:00:00 2001
From: Matthew Swift <matthew.swift@forgerock.com>
Date: Fri, 10 Feb 2012 11:30:15 +0000
Subject: [PATCH] Preparation work for OPENDJ-420: Rare SSLExceptions while handling LDAPS connections and big LDAP searches

---
 opendj-sdk/opends/src/server/org/opends/server/extensions/TLSByteChannel.java |  422 +++++++++++++++++++++++++++++++++++++++++++---------
 1 files changed, 345 insertions(+), 77 deletions(-)

diff --git a/opendj-sdk/opends/src/server/org/opends/server/extensions/TLSByteChannel.java b/opendj-sdk/opends/src/server/org/opends/server/extensions/TLSByteChannel.java
index 53741c8..ad2b96e 100644
--- a/opendj-sdk/opends/src/server/org/opends/server/extensions/TLSByteChannel.java
+++ b/opendj-sdk/opends/src/server/org/opends/server/extensions/TLSByteChannel.java
@@ -53,86 +53,49 @@
 /**
  * A class that provides a TLS byte channel implementation.
  */
-public class TLSByteChannel implements ConnectionSecurityProvider
+public class TLSByteChannel implements ByteChannel, ConnectionSecurityProvider
 {
-  /**
-   * Private implementation.
-   */
-  private final class ByteChannelImpl implements ByteChannel
-  {
-
-    /**
-     * {@inheritDoc}
-     */
-    public int read(ByteBuffer dst) throws IOException
-    {
-      // TODO Auto-generated method stub
-      return 0;
-    }
-
-
-
-    /**
-     * {@inheritDoc}
-     */
-    public boolean isOpen()
-    {
-      // TODO Auto-generated method stub
-      return false;
-    }
-
-
-
-    /**
-     * {@inheritDoc}
-     */
-    public void close() throws IOException
-    {
-      // TODO Auto-generated method stub
-
-    }
-
-
-
-    /**
-     * {@inheritDoc}
-     */
-    public int write(ByteBuffer src) throws IOException
-    {
-      // TODO Auto-generated method stub
-      return 0;
-    }
-
-  }
-
-
-
   private static final DebugTracer TRACER = getTracer();
-  private final ByteChannel socketChannel;
-  private final SSLEngine sslEngine;
-  private final ByteChannelImpl pimpl = new ByteChannelImpl();
 
+  private final ByteChannel socketChannel;
+
+  private final SSLEngine sslEngine;
+
+  // read copy to buffer
+  private final ByteBuffer appData;
+
+  // read encrypted
+  private final ByteBuffer appNetData;
+
+  // Write encrypted
+  private final ByteBuffer netData;
+  private final ByteBuffer tempData;
+
+  private final int sslBufferSize;
+  private final int appBufSize;
+
+  private boolean reading = false;
 
   // Map of cipher phrases to effective key size (bits). Taken from the
   // following RFCs: 5289, 4346, 3268,4132 and 4162.
-  private static final Map<String, Integer> CIPHER_MAP;
+  private static final Map<String, Integer> cipherMap;
   static
   {
-    CIPHER_MAP = new LinkedHashMap<String, Integer>();
-    CIPHER_MAP.put("_WITH_AES_256_CBC_", new Integer(256));
-    CIPHER_MAP.put("_WITH_CAMELLIA_256_CBC_", new Integer(256));
-    CIPHER_MAP.put("_WITH_AES_256_GCM_", new Integer(256));
-    CIPHER_MAP.put("_WITH_3DES_EDE_CBC_", new Integer(112));
-    CIPHER_MAP.put("_WITH_AES_128_GCM_", new Integer(128));
-    CIPHER_MAP.put("_WITH_SEED_CBC_", new Integer(128));
-    CIPHER_MAP.put("_WITH_CAMELLIA_128_CBC_", new Integer(128));
-    CIPHER_MAP.put("_WITH_AES_128_CBC_", new Integer(128));
-    CIPHER_MAP.put("_WITH_IDEA_CBC_", new Integer(128));
-    CIPHER_MAP.put("_WITH_DES_CBC_", new Integer(56));
-    CIPHER_MAP.put("_WITH_RC2_CBC_40_", new Integer(40));
-    CIPHER_MAP.put("_WITH_RC4_40_", new Integer(40));
-    CIPHER_MAP.put("_WITH_DES40_CBC_", new Integer(40));
-    CIPHER_MAP.put("_WITH_NULL_", new Integer(0));
+    cipherMap = new LinkedHashMap<String, Integer>();
+    cipherMap.put("_WITH_AES_256_CBC_", new Integer(256));
+    cipherMap.put("_WITH_CAMELLIA_256_CBC_", new Integer(256));
+    cipherMap.put("_WITH_AES_256_GCM_", new Integer(256));
+    cipherMap.put("_WITH_3DES_EDE_CBC_", new Integer(112));
+    cipherMap.put("_WITH_AES_128_GCM_", new Integer(128));
+    cipherMap.put("_WITH_SEED_CBC_", new Integer(128));
+    cipherMap.put("_WITH_CAMELLIA_128_CBC_", new Integer(128));
+    cipherMap.put("_WITH_AES_128_CBC_", new Integer(128));
+    cipherMap.put("_WITH_IDEA_CBC_", new Integer(128));
+    cipherMap.put("_WITH_DES_CBC_", new Integer(56));
+    cipherMap.put("_WITH_RC2_CBC_40_", new Integer(40));
+    cipherMap.put("_WITH_RC4_40_", new Integer(40));
+    cipherMap.put("_WITH_DES40_CBC_", new Integer(40));
+    cipherMap.put("_WITH_NULL_", new Integer(0));
   };
 
 
@@ -174,11 +137,9 @@
     // avoid blocking new connections. Just remove for now to prevent
     // potential DoS attacks. SSL sessions will not be reused and some
     // cipher suites (such as Kerberos) will not work.
-
     // String hostName = socketChannel.socket().getInetAddress().getHostName();
     // int port = socketChannel.socket().getPort();
     // sslEngine = sslContext.createSSLEngine(hostName, port);
-
     sslEngine = sslContext.createSSLEngine();
     sslEngine.setUseClientMode(false);
     final Set<String> protocols = config.getSSLProtocol();
@@ -209,6 +170,44 @@
       sslEngine.setWantClientAuth(true);
       break;
     }
+
+    final SSLSession sslSession = sslEngine.getSession();
+    sslBufferSize = sslSession.getPacketBufferSize();
+    appBufSize = sslSession.getApplicationBufferSize();
+
+    appNetData = ByteBuffer.allocate(sslBufferSize);
+    netData = ByteBuffer.allocate(sslBufferSize);
+
+    appData = ByteBuffer.allocate(sslSession.getApplicationBufferSize());
+    tempData = ByteBuffer.allocate(sslSession.getApplicationBufferSize());
+  }
+
+
+
+  /**
+   * {@inheritDoc}
+   */
+  public synchronized void close() throws IOException
+  {
+    sslEngine.closeInbound();
+    sslEngine.closeOutbound();
+    final SSLEngineResult.HandshakeStatus hsStatus = sslEngine
+        .getHandshakeStatus();
+    if (hsStatus != SSLEngineResult.HandshakeStatus.FINISHED
+        && hsStatus != SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING)
+    {
+      doHandshakeWrite(hsStatus);
+    }
+  }
+
+
+
+  /**
+   * {@inheritDoc}
+   */
+  public int getAppBufSize()
+  {
+    return appBufSize;
   }
 
 
@@ -251,7 +250,7 @@
   {
     int cipherKeySSF = 0;
     final String cipherString = sslEngine.getSession().getCipherSuite();
-    for (final Map.Entry<String, Integer> mapEntry : CIPHER_MAP.entrySet())
+    for (final Map.Entry<String, Integer> mapEntry : cipherMap.entrySet())
     {
       if (cipherString.indexOf(mapEntry.getKey()) >= 0)
       {
@@ -267,6 +266,20 @@
   /**
    * {@inheritDoc}
    */
+  public boolean isOpen()
+  {
+    if (sslEngine.isInboundDone() || sslEngine.isOutboundDone())
+    {
+      return false;
+    }
+    return true;
+  }
+
+
+
+  /**
+   * {@inheritDoc}
+   */
   public boolean isSecure()
   {
     return true;
@@ -277,9 +290,264 @@
   /**
    * {@inheritDoc}
    */
-  public ByteChannel wrapChannel(final ByteChannel channel)
+  public synchronized int read(final ByteBuffer clearBuffer) throws IOException
   {
-    return pimpl;
+    SSLEngineResult.HandshakeStatus hsStatus;
+    if (!reading)
+    {
+      appNetData.clear();
+    }
+    else
+    {
+      reading = false;
+    }
+
+    if (!socketChannel.isOpen())
+    {
+      return -1;
+    }
+
+    if (sslEngine.isInboundDone())
+    {
+      return -1;
+    }
+
+    do
+    {
+      final int wrappedBytes = socketChannel.read(appNetData);
+      appNetData.flip();
+
+      if (wrappedBytes == -1)
+      {
+        return -1;
+      }
+
+      hsStatus = sslEngine.getHandshakeStatus();
+
+      if (hsStatus == SSLEngineResult.HandshakeStatus.NEED_TASK
+          || hsStatus == SSLEngineResult.HandshakeStatus.NEED_WRAP)
+      {
+        doHandshakeRead(hsStatus);
+      }
+
+      if (wrappedBytes == 0)
+      {
+        return 0;
+      }
+
+      while (appNetData.hasRemaining())
+      {
+        appData.clear();
+        final SSLEngineResult res = sslEngine.unwrap(appNetData, appData);
+        appData.flip();
+
+        if (res.getStatus() == SSLEngineResult.Status.BUFFER_UNDERFLOW)
+        {
+          appNetData.compact();
+          reading = true;
+          break;
+        }
+        else if (res.getStatus() != SSLEngineResult.Status.OK)
+        {
+          return -1;
+        }
+
+        hsStatus = sslEngine.getHandshakeStatus();
+        if (hsStatus == SSLEngineResult.HandshakeStatus.NEED_TASK
+            || hsStatus == SSLEngineResult.HandshakeStatus.NEED_WRAP)
+        {
+          doHandshakeOp(hsStatus);
+        }
+        clearBuffer.put(appData);
+      }
+      hsStatus = sslEngine.getHandshakeStatus();
+    }
+    while (hsStatus == SSLEngineResult.HandshakeStatus.NEED_TASK
+        || hsStatus == SSLEngineResult.HandshakeStatus.NEED_WRAP);
+
+    return clearBuffer.position();
   }
 
+
+
+  /**
+   * {@inheritDoc}
+   */
+  public ByteChannel wrapChannel(final ByteChannel channel)
+  {
+    return this;
+  }
+
+
+
+  /**
+   * {@inheritDoc}
+   */
+  public synchronized int write(final ByteBuffer clearData) throws IOException
+  {
+    if (!socketChannel.isOpen() || sslEngine.isOutboundDone())
+    {
+      throw new ClosedChannelException();
+    }
+
+    final int originalPosition = clearData.position();
+    final int originalLimit = clearData.limit();
+    final int length = originalLimit - originalPosition;
+
+    if (length > sslBufferSize)
+    {
+      int pos = originalPosition;
+      int lim = originalPosition + sslBufferSize;
+      while (pos < originalLimit)
+      {
+        clearData.position(pos);
+        clearData.limit(lim);
+        writeInternal(clearData);
+        pos = lim;
+        lim = Math.min(originalLimit, pos + sslBufferSize);
+      }
+      return length;
+    }
+    else
+    {
+      return writeInternal(clearData);
+    }
+  }
+
+
+
+  private void doHandshakeOp(SSLEngineResult.HandshakeStatus hsStatus)
+      throws IOException
+  {
+    SSLEngineResult res;
+    switch (hsStatus)
+    {
+    case NEED_TASK:
+      hsStatus = doTasks();
+      break;
+    case NEED_WRAP:
+      tempData.clear();
+      netData.clear();
+      res = sslEngine.wrap(tempData, netData);
+      hsStatus = res.getHandshakeStatus();
+      netData.flip();
+      while (netData.hasRemaining())
+      {
+        socketChannel.write(netData);
+      }
+      hsStatus = sslEngine.getHandshakeStatus();
+      return;
+    default:
+      return;
+    }
+  }
+
+
+
+  private void doHandshakeRead(SSLEngineResult.HandshakeStatus hsStatus)
+      throws IOException
+  {
+    do
+    {
+      doHandshakeOp(hsStatus);
+      hsStatus = sslEngine.getHandshakeStatus();
+    }
+    while (hsStatus == SSLEngineResult.HandshakeStatus.NEED_WRAP
+        || hsStatus == SSLEngineResult.HandshakeStatus.NEED_TASK);
+  }
+
+
+
+  private void doHandshakeUnwrap() throws IOException
+  {
+    netData.clear();
+    tempData.clear();
+    final int bytesRead = socketChannel.read(netData);
+
+    if (bytesRead <= 0)
+    {
+      throw new ClosedChannelException();
+    }
+    else
+    {
+      sslEngine.unwrap(netData, tempData);
+    }
+  }
+
+
+
+  private void doHandshakeWrite(SSLEngineResult.HandshakeStatus hsStatus)
+      throws IOException
+  {
+    do
+    {
+      if (hsStatus == SSLEngineResult.HandshakeStatus.NEED_UNWRAP)
+      {
+        doHandshakeUnwrap();
+      }
+      else
+      {
+        doHandshakeOp(hsStatus);
+      }
+      hsStatus = sslEngine.getHandshakeStatus();
+    }
+    while (hsStatus == SSLEngineResult.HandshakeStatus.NEED_WRAP
+        || hsStatus == SSLEngineResult.HandshakeStatus.NEED_TASK
+        || hsStatus == SSLEngineResult.HandshakeStatus.NEED_UNWRAP);
+  }
+
+
+
+  private SSLEngineResult.HandshakeStatus doTasks()
+  {
+    Runnable task;
+    while ((task = sslEngine.getDelegatedTask()) != null)
+    {
+      task.run();
+    }
+    return sslEngine.getHandshakeStatus();
+  }
+
+
+
+  private int writeInternal(final ByteBuffer clearData) throws IOException
+  {
+    int totBytesSent = 0;
+    SSLEngineResult.HandshakeStatus hsStatus;
+    hsStatus = sslEngine.getHandshakeStatus();
+
+    if (hsStatus == SSLEngineResult.HandshakeStatus.NEED_TASK
+        || hsStatus == SSLEngineResult.HandshakeStatus.NEED_WRAP
+        || hsStatus == SSLEngineResult.HandshakeStatus.NEED_UNWRAP)
+    {
+      doHandshakeWrite(hsStatus);
+    }
+
+    while (clearData.hasRemaining())
+    {
+      netData.clear();
+      final SSLEngineResult res = sslEngine.wrap(clearData, netData);
+      netData.flip();
+      if (netData.remaining() == 0)
+      {
+        // wrap didn't produce any data from our clear buffer.
+        // Throw exception to prevent looping.
+        throw new SSLException("SSLEngine.wrap produced 0 bytes");
+      }
+
+      if (res.getStatus() != SSLEngineResult.Status.OK)
+      {
+        throw new ClosedChannelException();
+      }
+
+      if (hsStatus == SSLEngineResult.HandshakeStatus.NEED_TASK
+          || hsStatus == SSLEngineResult.HandshakeStatus.NEED_WRAP
+          || hsStatus == SSLEngineResult.HandshakeStatus.NEED_UNWRAP)
+      {
+        doHandshakeWrite(hsStatus);
+      }
+      totBytesSent += socketChannel.write(netData);
+    }
+    return totBytesSent;
+  }
 }

--
Gitblit v1.10.0