From 1dfff197eadcf24823d7915e6eead2a850f679f9 Mon Sep 17 00:00:00 2001
From: Matthew Swift <matthew.swift@forgerock.com>
Date: Tue, 14 Feb 2012 16:09:28 +0000
Subject: [PATCH] Fix OPENDJ-420: Rare SSLExceptions while handling LDAPS connections and big LDAP searches

---
 opends/src/server/org/opends/server/extensions/SASLByteChannel.java |  579 ++++++++++++++++++++++++---------------------------------
 1 files changed, 245 insertions(+), 334 deletions(-)

diff --git a/opends/src/server/org/opends/server/extensions/SASLByteChannel.java b/opends/src/server/org/opends/server/extensions/SASLByteChannel.java
index 03ab350..bf6ece9 100644
--- a/opends/src/server/org/opends/server/extensions/SASLByteChannel.java
+++ b/opends/src/server/org/opends/server/extensions/SASLByteChannel.java
@@ -35,9 +35,6 @@
 import java.nio.channels.ByteChannel;
 import java.security.cert.Certificate;
 
-import javax.security.sasl.Sasl;
-import javax.security.sasl.SaslException;
-
 import org.opends.server.api.ClientConnection;
 
 
@@ -46,10 +43,229 @@
  * This class implements a SASL byte channel that can be used during
  * confidentiality and integrity.
  */
-public class SASLByteChannel implements ByteChannel, ConnectionSecurityProvider
+public final class SASLByteChannel implements ConnectionSecurityProvider
 {
 
   /**
+   * Private implementation.
+   */
+  private final class ByteChannelImpl implements ByteChannel
+  {
+
+    /**
+     * {@inheritDoc}
+     */
+    @Override
+    public void close() throws IOException
+    {
+      synchronized (readLock)
+      {
+        synchronized (writeLock)
+        {
+          saslContext.dispose();
+          channel.close();
+        }
+      }
+    }
+
+
+
+    /**
+     * {@inheritDoc}
+     */
+    @Override
+    public boolean isOpen()
+    {
+      return saslContext != null;
+    }
+
+
+
+    /**
+     * {@inheritDoc}
+     */
+    @Override
+    public int read(final ByteBuffer unwrappedData) throws IOException
+    {
+      synchronized (readLock)
+      {
+        // Only read and unwrap new data if needed.
+        if (!recvUnwrappedBuffer.hasRemaining())
+        {
+          final int read = doRecvAndUnwrap();
+          if (read <= 0)
+          {
+            // No data read or end of stream.
+            return read;
+          }
+        }
+
+        // Copy available data.
+        final int startPos = unwrappedData.position();
+        if (recvUnwrappedBuffer.remaining() > unwrappedData.remaining())
+        {
+          // Unwrapped data does not fit in client buffer so copy one byte at a
+          // time: it's annoying that there is no easy way to do this with
+          // ByteBuffers.
+          while (unwrappedData.hasRemaining())
+          {
+            unwrappedData.put(recvUnwrappedBuffer.get());
+          }
+        }
+        else
+        {
+          // Unwrapped data fits client buffer so block copy.
+          unwrappedData.put(recvUnwrappedBuffer);
+        }
+        return unwrappedData.position() - startPos;
+      }
+    }
+
+
+
+    /**
+     * {@inheritDoc}
+     */
+    @Override
+    public int write(final ByteBuffer unwrappedData) throws IOException
+    {
+      // This method will block until the entire message is sent.
+      final int bytesWritten = unwrappedData.remaining();
+
+      // Synchronized in order to prevent interleaving and reordering.
+      synchronized (writeLock)
+      {
+        // Write data in sendBufferSize segments.
+        while (unwrappedData.hasRemaining())
+        {
+          final int remaining = unwrappedData.remaining();
+          final int wrapSize = (remaining < sendUnwrappedBufferSize) ? remaining
+              : sendUnwrappedBufferSize;
+
+          final byte[] wrappedDataBytes;
+          if (unwrappedData.hasArray())
+          {
+            // Avoid extra copy if ByteBuffer is array based.
+            wrappedDataBytes = saslContext.wrap(unwrappedData.array(),
+                unwrappedData.arrayOffset(), wrapSize);
+            unwrappedData.position(unwrappedData.position() + wrapSize);
+          }
+          else
+          {
+            // Non-array based ByteBuffer, so copy.
+            unwrappedData.get(sendUnwrappedBytes, 0, wrapSize);
+            wrappedDataBytes = saslContext
+                .wrap(sendUnwrappedBytes, 0, wrapSize);
+            unwrappedData.position(unwrappedData.position() + wrapSize);
+          }
+
+          // Encode SASL packet: 4 byte length + wrapped data.
+          if (sendWrappedBuffer.capacity() < wrappedDataBytes.length + 4)
+          {
+            // Resize the send buffer.
+            sendWrappedBuffer = ByteBuffer
+                .allocate(wrappedDataBytes.length + 4);
+          }
+          sendWrappedBuffer.clear();
+          sendWrappedBuffer.putInt(wrappedDataBytes.length);
+          sendWrappedBuffer.put(wrappedDataBytes);
+          sendWrappedBuffer.flip();
+
+          // Write the SASL packet: our IO stack will block until all the data
+          // is written.
+          channel.write(sendWrappedBuffer);
+        }
+      }
+
+      return bytesWritten;
+    }
+
+
+
+    // Attempt to read and unwrap the next SASL packet.
+    private int doRecvAndUnwrap() throws IOException
+    {
+      // Read the encoded packet length first.
+      if (recvWrappedLength < 0)
+      {
+        final int read = channel.read(recvWrappedLengthBuffer);
+        if (read <= 0)
+        {
+          // No data read or end of stream.
+          return read;
+        }
+
+        if (recvWrappedLengthBuffer.hasRemaining())
+        {
+          // Unable to read the length, so no data available yet.
+          return 0;
+        }
+
+        // Decode the length and reset the length buffer.
+        recvWrappedLengthBuffer.flip();
+        recvWrappedLength = recvWrappedLengthBuffer.getInt();
+        recvWrappedLengthBuffer.clear();
+
+        // Check that the length is valid.
+        if (recvWrappedLength > recvWrappedBufferMaximumSize)
+        {
+          throw new IOException(
+              "Client sent a SASL packet specifying a length "
+                  + recvWrappedLength
+                  + " which exceeds the negotiated limit of "
+                  + recvWrappedBufferMaximumSize);
+        }
+
+        if (recvWrappedLength < 0)
+        {
+          throw new IOException(
+              "Client sent a SASL packet specifying a negative length "
+                  + recvWrappedLength);
+        }
+
+        // Prepare the recv buffer for reading.
+        recvWrappedBuffer.clear();
+        recvWrappedBuffer.limit(recvWrappedLength);
+      }
+
+      // Read wrapped data.
+      final int read = channel.read(recvWrappedBuffer);
+      if (read <= 0)
+      {
+        // No data read or end of stream.
+        return read;
+      }
+
+      if (recvWrappedBuffer.hasRemaining())
+      {
+        // Unable to read the full packet, so no data available yet.
+        return 0;
+      }
+
+      // The complete packet has been read, so unwrap it.
+      recvWrappedBuffer.flip();
+      final byte[] unwrappedDataBytes = saslContext.unwrap(
+          recvWrappedBuffer.array(), 0, recvWrappedLength);
+      recvWrappedLength = -1;
+
+      if (recvUnwrappedBuffer.capacity() < unwrappedDataBytes.length)
+      {
+        // Resize the recv buffer (this shouldn't ever happen).
+        recvUnwrappedBuffer = ByteBuffer.allocate(unwrappedDataBytes.length);
+      }
+
+      recvUnwrappedBuffer.clear();
+      recvUnwrappedBuffer.put(unwrappedDataBytes);
+      recvUnwrappedBuffer.flip();
+
+      return recvUnwrappedBuffer.remaining();
+    }
+
+  }
+
+
+
+  /**
    * Return a SASL byte channel instance created using the specified parameters.
    *
    * @param c
@@ -68,31 +284,23 @@
 
 
 
-  // The SASL context associated with the provider
-  private SASLContext saslContext;
-
-  // The byte channel associated with this provider.
-  private final RedirectingByteChannel channel;
-
-  // The number of bytes in the length buffer.
-  private static final int lengthSize = 4;
-
-  // Length of the buffer.
-  private int bufLength;
-
-  // The SASL mechanism name.
   private final String name;
+  private final ByteChannel channel;
+  private final ByteChannelImpl pimpl = new ByteChannelImpl();
+  private final SASLContext saslContext;
 
-  // Buffers used in reading and decoding (unwrap)
-  private final ByteBuffer readBuffer, decodeBuffer;
+  private ByteBuffer recvUnwrappedBuffer;
+  private final ByteBuffer recvWrappedBuffer;
+  private final int recvWrappedBufferMaximumSize;
+  private int recvWrappedLength = -1;
+  private final ByteBuffer recvWrappedLengthBuffer = ByteBuffer.allocate(4);
 
-  // How many bytes of the subsequent buffer is needed to complete a partially
-  // read buffer.
-  private int neededBytes = 0;
+  private final int sendUnwrappedBufferSize;
+  private final byte[] sendUnwrappedBytes;
+  private ByteBuffer sendWrappedBuffer;
 
-  // Used to not reset the buffer length size because the first 4 bytes of a
-  // buffer are not size bytes.
-  private boolean reading = false;
+  private final Object readLock = new Object();
+  private final Object writeLock = new Object();
 
 
 
@@ -112,10 +320,16 @@
   {
     this.name = name;
     this.saslContext = saslContext;
-    this.channel = connection.getChannel();
-    this.readBuffer = ByteBuffer.allocate(connection.getAppBufferSize());
-    this.decodeBuffer = ByteBuffer.allocate(connection.getAppBufferSize()
-        + lengthSize);
+
+    channel = connection.getChannel();
+    recvWrappedBufferMaximumSize = saslContext.getMaxReceiveBufferSize();
+    sendUnwrappedBufferSize = saslContext.getMaxRawSendBufferSize();
+
+    recvWrappedBuffer = ByteBuffer.allocate(recvWrappedBufferMaximumSize);
+    recvUnwrappedBuffer = ByteBuffer.allocate(recvWrappedBufferMaximumSize);
+    recvUnwrappedBuffer.flip(); // Initially nothing has been received.
+    sendUnwrappedBytes = new byte[sendUnwrappedBufferSize];
+    sendWrappedBuffer = ByteBuffer.allocate(sendUnwrappedBufferSize + 64);
   }
 
 
@@ -124,21 +338,9 @@
    * {@inheritDoc}
    */
   @Override
-  public synchronized void close() throws IOException
+  public ByteChannel getChannel()
   {
-    saslContext.dispose();
-    saslContext = null;
-  }
-
-
-
-  /**
-   * {@inheritDoc}
-   */
-  @Override
-  public int getAppBufSize()
-  {
-    return saslContext.getBufSize(Sasl.MAX_BUFFER);
+    return pimpl;
   }
 
 
@@ -180,300 +382,9 @@
    * {@inheritDoc}
    */
   @Override
-  public boolean isOpen()
-  {
-    return saslContext != null;
-  }
-
-
-
-  /**
-   * {@inheritDoc}
-   */
-  @Override
   public boolean isSecure()
   {
     return true;
   }
 
-
-
-  /**
-   * {@inheritDoc}
-   */
-  @Override
-  public synchronized int read(final ByteBuffer clearDst) throws IOException
-  {
-    int bytesToRead = lengthSize;
-    if (reading)
-    {
-      bytesToRead = neededBytes;
-    }
-
-    final int readResult = readAll(readBuffer, bytesToRead);
-    if (readResult == -1)
-    {
-      return -1;
-    }
-
-    // The previous buffer read was not complete, the current
-    // buffer completes it.
-    if (neededBytes > 0 && readResult > 0)
-    {
-      return (processPartial(readResult, clearDst));
-    }
-
-    if (readResult == 0 && !reading)
-    {
-      return 0;
-    }
-
-    if (!reading)
-    {
-      bufLength = getBufLength(readBuffer);
-    }
-
-    reading = false;
-
-    // The buffer length is greater than what is there, save what is there,
-    // figure out how much more is needed and return.
-    if (bufLength > readBuffer.position())
-    {
-      neededBytes = bufLength - readBuffer.position() + lengthSize;
-      readBuffer.flip();
-      decodeBuffer.put(readBuffer);
-      readBuffer.clear();
-      return 0;
-    }
-    else
-    {
-      readBuffer.flip();
-      decodeBuffer.put(readBuffer);
-      final byte[] inBytes = decodeBuffer.array();
-      final byte[] clearBytes = saslContext.unwrap(inBytes, lengthSize,
-          bufLength);
-      decodeBuffer.clear();
-      clearDst.put(clearBytes);
-      readBuffer.clear();
-      return clearDst.position();
-    }
-  }
-
-
-
-  /**
-   * {@inheritDoc}
-   */
-  @Override
-  public ByteChannel wrapChannel(final ByteChannel channel)
-  {
-    return this;
-  }
-
-
-
-  /**
-   * {@inheritDoc}
-   */
-  @Override
-  public synchronized int write(final ByteBuffer clearSrc) throws IOException
-  {
-    final int sendBufSize = getAppBufSize();
-    final int srcLen = clearSrc.remaining();
-    final ByteBuffer sendBuffer = ByteBuffer.allocate(sendBufSize);
-
-    if (srcLen > sendBufSize)
-    {
-      final int oldPos = clearSrc.position();
-      int curPos = oldPos;
-      int curLimit = oldPos + sendBufSize;
-
-      while (curPos < srcLen)
-      {
-        clearSrc.position(curPos);
-        clearSrc.limit(curLimit);
-        sendBuffer.put(clearSrc);
-        writeChannel(wrap(sendBuffer.array(), clearSrc.remaining()));
-        curPos = curLimit;
-        curLimit = Math.min(srcLen, curPos + sendBufSize);
-      }
-      return srcLen;
-    }
-    else
-    {
-      sendBuffer.put(clearSrc);
-      return writeChannel(wrap(sendBuffer.array(), srcLen));
-    }
-  }
-
-
-
-  /**
-   * Return the clear buffer length as determined by processing the first 4
-   * bytes of the specified buffer.
-   *
-   * @param byteBuf
-   *          The buffer to examine the first 4 bytes of.
-   * @return The size of the clear buffer.
-   */
-  private int getBufLength(final ByteBuffer byteBuf)
-  {
-    int answer = 0;
-    for (int i = 0; i < lengthSize; i++)
-    {
-      final byte b = byteBuf.get(i);
-      answer <<= 8;
-      answer |= (b & 0xff);
-    }
-    return answer;
-  }
-
-
-
-  /**
-   * Finish processing a previous, partially read buffer using some, or, all of
-   * the bytes of the current buffer.
-   */
-  private int processPartial(final int readResult, final ByteBuffer clearDst)
-      throws IOException
-  {
-    readBuffer.flip();
-
-    // Use all of the bytes of the current buffer and read some more.
-    if (neededBytes > readResult)
-    {
-      neededBytes -= readResult;
-      decodeBuffer.put(readBuffer);
-      readBuffer.clear();
-      reading = false;
-      return 0;
-    }
-
-    // Use a portion of the current buffer.
-    for (; neededBytes > 0; neededBytes--)
-    {
-      decodeBuffer.put(readBuffer.get());
-    }
-
-    // Unwrap the now completed buffer.
-    final byte[] inBytes = decodeBuffer.array();
-    final byte[] clearBytes = saslContext
-        .unwrap(inBytes, lengthSize, bufLength);
-    clearDst.put(clearBytes);
-    decodeBuffer.clear();
-    readBuffer.compact();
-
-    // If the read buffer has bytes, these are a new buffer. Reset the
-    // buffer length to the new value.
-    if (readBuffer.position() != 0)
-    {
-      bufLength = getBufLength(readBuffer);
-      reading = true;
-    }
-    else
-    {
-      reading = false;
-    }
-    return clearDst.position();
-  }
-
-
-
-  /**
-   * Read from the socket channel into the specified byte buffer at least the
-   * number of bytes specified in the total parameter.
-   *
-   * @param byteBuf
-   *          The byte buffer to put the bytes in.
-   * @param total
-   *          The total number of bytes to read from the socket channel.
-   * @return The number of bytes read, 0 or -1.
-   * @throws IOException
-   *           If an error occurred reading the socket channel.
-   */
-  private int readAll(final ByteBuffer byteBuf, int total) throws IOException
-  {
-    while (channel.isOpen() && total > 0)
-    {
-      final int count = channel.read(byteBuf);
-      if (count == -1)
-      {
-        return -1;
-      }
-      if (count == 0)
-      {
-        return 0;
-      }
-      total -= count;
-    }
-    if (total > 0)
-    {
-      return -1;
-    }
-    else
-    {
-      return byteBuf.position();
-    }
-  }
-
-
-
-  /**
-   * Creates a buffer suitable to send to the client using the specified clear
-   * byte array and length of the bytes to wrap.
-   *
-   * @param clearBytes
-   *          The clear byte array to send to the client.
-   * @param len
-   *          The length of the bytes to wrap in the byte array.
-   * @throws SaslException
-   *           If the wrap of the bytes fails.
-   */
-  private ByteBuffer wrap(final byte[] clearBytes, final int len)
-      throws SaslException
-  {
-    final byte[] wrapBytes = saslContext.wrap(clearBytes, 0, len);
-    final byte[] outBytes = new byte[wrapBytes.length + lengthSize];
-
-    writeBufLen(outBytes, wrapBytes.length);
-    System.arraycopy(wrapBytes, 0, outBytes, lengthSize, wrapBytes.length);
-
-    return ByteBuffer.wrap(outBytes);
-  }
-
-
-
-  /**
-   * Writes the specified len parameter into the buffer in a form that can be
-   * sent over a network to the client.
-   *
-   * @param buf
-   *          The buffer to hold the length bytes.
-   * @param len
-   *          The length to encode.
-   */
-  private void writeBufLen(final byte[] buf, int len)
-  {
-    for (int i = 3; i >= 0; i--)
-    {
-      buf[i] = (byte) (len & 0xff);
-      len >>>= 8;
-    }
-  }
-
-
-
-  /**
-   * Write the specified byte buffer to the socket channel.
-   *
-   * @param buffer
-   *          The byte buffer to write to the socket channel.
-   * @return {@code true} if the byte buffer was successfully written to the
-   *         socket channel, or, {@code false} if not.
-   */
-  private int writeChannel(final ByteBuffer buffer) throws IOException
-  {
-    return channel.write(buffer);
-  }
-
 }

--
Gitblit v1.10.0