From 88be99a38d4f02a6227ef5a2b514f77c6f28e524 Mon Sep 17 00:00:00 2001
From: Matthew Swift <matthew.swift@forgerock.com>
Date: Wed, 08 Feb 2012 10:59:12 +0000
Subject: [PATCH] Preparation work for OPENDJ-420: Rare SSLExceptions while handling LDAPS connections and big LDAP searches

---
 opends/src/server/org/opends/server/extensions/SASLByteChannel.java |  708 +++++++++++++++++++++++++++++++++-------------------------
 1 files changed, 403 insertions(+), 305 deletions(-)

diff --git a/opends/src/server/org/opends/server/extensions/SASLByteChannel.java b/opends/src/server/org/opends/server/extensions/SASLByteChannel.java
index f7749d8..03ab350 100644
--- a/opends/src/server/org/opends/server/extensions/SASLByteChannel.java
+++ b/opends/src/server/org/opends/server/extensions/SASLByteChannel.java
@@ -28,354 +28,452 @@
 
 package org.opends.server.extensions;
 
-import java.nio.channels.ByteChannel;
-import java.security.cert.Certificate;
+
+
 import java.io.IOException;
 import java.nio.ByteBuffer;
+import java.nio.channels.ByteChannel;
+import java.security.cert.Certificate;
+
 import javax.security.sasl.Sasl;
+import javax.security.sasl.SaslException;
+
 import org.opends.server.api.ClientConnection;
 
+
+
 /**
  * This class implements a SASL byte channel that can be used during
  * confidentiality and integrity.
- *
  */
-public class
-SASLByteChannel implements ByteChannel, ConnectionSecurityProvider {
+public class SASLByteChannel implements ByteChannel, ConnectionSecurityProvider
+{
 
-    // The client connection associated with this provider.
-    private ClientConnection connection;
+  /**
+   * Return a SASL byte channel instance created using the specified parameters.
+   *
+   * @param c
+   *          A client connection associated with the instance.
+   * @param name
+   *          The name of the instance (SASL mechanism name).
+   * @param context
+   *          A SASL context associated with the instance.
+   * @return A SASL byte channel.
+   */
+  public static SASLByteChannel getSASLByteChannel(final ClientConnection c,
+      final String name, final SASLContext context)
+  {
+    return new SASLByteChannel(c, name, context);
+  }
 
-    // The SASL context associated with the provider
-    private SASLContext saslContext;
 
-    // The byte channel associated with this provider.
-    private RedirectingByteChannel channel;
 
-    // The number of bytes in the length buffer.
-    private static final int lengthSize = 4;
+  // The SASL context associated with the provider
+  private SASLContext saslContext;
 
-    //Length of the buffer.
-    private int bufLength;
+  // The byte channel associated with this provider.
+  private final RedirectingByteChannel channel;
 
-    // The SASL mechanism name.
-    private String name;
+  // The number of bytes in the length buffer.
+  private static final int lengthSize = 4;
 
-    //Buffers used in reading and decoding (unwrap)
-    private ByteBuffer readBuffer, decodeBuffer;
+  // Length of the buffer.
+  private int bufLength;
 
-    //How many bytes of the subsequent buffer is needed to complete a partially
-    //read buffer.
-    private int neededBytes = 0;
+  // The SASL mechanism name.
+  private final String name;
 
-    //Used to not reset the buffer length size because the first 4 bytes of a
-    //buffer are not size bytes.
-    private boolean reading = false;
+  // Buffers used in reading and decoding (unwrap)
+  private final ByteBuffer readBuffer, decodeBuffer;
 
-    /**
-     * Create a SASL byte channel with the specified parameters
-     * that is capable of processing a confidentiality/integrity SASL
-     * connection.
-     *
-     * @param connection
-     *          The client connection to read/write the bytes.
-     * @param name
-     *          The SASL mechanism name.
-     * @param saslContext
-     *          The SASL context to process the data through.
-     */
-    private SASLByteChannel(ClientConnection connection, String name,
-        SASLContext saslContext) {
-      this.connection = connection;
-      this.name = name;
-      this.saslContext = saslContext;
-      this.channel = connection.getChannel();
-      this.readBuffer = ByteBuffer.allocate(connection.getAppBufferSize());
-      this.decodeBuffer =
-                ByteBuffer.allocate(connection.getAppBufferSize() + lengthSize);
+  // How many bytes of the subsequent buffer is needed to complete a partially
+  // read buffer.
+  private int neededBytes = 0;
+
+  // Used to not reset the buffer length size because the first 4 bytes of a
+  // buffer are not size bytes.
+  private boolean reading = false;
+
+
+
+  /**
+   * Create a SASL byte channel with the specified parameters that is capable of
+   * processing a confidentiality/integrity SASL connection.
+   *
+   * @param connection
+   *          The client connection to read/write the bytes.
+   * @param name
+   *          The SASL mechanism name.
+   * @param saslContext
+   *          The SASL context to process the data through.
+   */
+  private SASLByteChannel(final ClientConnection connection, final String name,
+      final SASLContext saslContext)
+  {
+    this.name = name;
+    this.saslContext = saslContext;
+    this.channel = connection.getChannel();
+    this.readBuffer = ByteBuffer.allocate(connection.getAppBufferSize());
+    this.decodeBuffer = ByteBuffer.allocate(connection.getAppBufferSize()
+        + lengthSize);
+  }
+
+
+
+  /**
+   * {@inheritDoc}
+   */
+  @Override
+  public synchronized void close() throws IOException
+  {
+    saslContext.dispose();
+    saslContext = null;
+  }
+
+
+
+  /**
+   * {@inheritDoc}
+   */
+  @Override
+  public int getAppBufSize()
+  {
+    return saslContext.getBufSize(Sasl.MAX_BUFFER);
+  }
+
+
+
+  /**
+   * {@inheritDoc}
+   */
+  @Override
+  public Certificate[] getClientCertificateChain()
+  {
+    return new Certificate[0];
+  }
+
+
+
+  /**
+   * {@inheritDoc}
+   */
+  @Override
+  public String getName()
+  {
+    return name;
+  }
+
+
+
+  /**
+   * {@inheritDoc}
+   */
+  @Override
+  public int getSSF()
+  {
+    return saslContext.getSSF();
+  }
+
+
+
+  /**
+   * {@inheritDoc}
+   */
+  @Override
+  public boolean isOpen()
+  {
+    return saslContext != null;
+  }
+
+
+
+  /**
+   * {@inheritDoc}
+   */
+  @Override
+  public boolean isSecure()
+  {
+    return true;
+  }
+
+
+
+  /**
+   * {@inheritDoc}
+   */
+  @Override
+  public synchronized int read(final ByteBuffer clearDst) throws IOException
+  {
+    int bytesToRead = lengthSize;
+    if (reading)
+    {
+      bytesToRead = neededBytes;
     }
 
-    /**
-     * Return a SASL byte channel instance created using the specified
-     * parameters.
-     *
-     * @param c A client connection associated with the instance.
-     * @param name The name of the instance (SASL mechanism name).
-     * @param context A SASL context associated with the instance.
-     * @return A SASL byte channel.
-     */
-    public static SASLByteChannel
-    getSASLByteChannel(ClientConnection c, String name,
-                          SASLContext context) {
-          return new SASLByteChannel(c, name, context);
+    final int readResult = readAll(readBuffer, bytesToRead);
+    if (readResult == -1)
+    {
+      return -1;
     }
 
-    /**
-     * Finish processing a previous, partially read buffer using some, or, all
-     * of the bytes of the current buffer.
-     *
-     */
-    private int processPartial(int readResult, ByteBuffer clearDst)
-    throws IOException {
+    // The previous buffer read was not complete, the current
+    // buffer completes it.
+    if (neededBytes > 0 && readResult > 0)
+    {
+      return (processPartial(readResult, clearDst));
+    }
+
+    if (readResult == 0 && !reading)
+    {
+      return 0;
+    }
+
+    if (!reading)
+    {
+      bufLength = getBufLength(readBuffer);
+    }
+
+    reading = false;
+
+    // The buffer length is greater than what is there, save what is there,
+    // figure out how much more is needed and return.
+    if (bufLength > readBuffer.position())
+    {
+      neededBytes = bufLength - readBuffer.position() + lengthSize;
       readBuffer.flip();
-      //Use all of the bytes of the current buffer and read some more.
-      if(neededBytes > readResult) {
-        neededBytes -= readResult;
-        decodeBuffer.put(readBuffer);
-        readBuffer.clear();
-        reading = false;
-        return 0;
-      }
-      //Use a portion of the current buffer.
-      for(;neededBytes > 0;neededBytes--) {
-        decodeBuffer.put(readBuffer.get());
-      }
-      //Unwrap the now completed buffer.
-      byte[] inBytes = decodeBuffer.array();
-      byte[]clearBytes = saslContext.unwrap(inBytes, lengthSize, bufLength);
-      clearDst.put(clearBytes);
+      decodeBuffer.put(readBuffer);
+      readBuffer.clear();
+      return 0;
+    }
+    else
+    {
+      readBuffer.flip();
+      decodeBuffer.put(readBuffer);
+      final byte[] inBytes = decodeBuffer.array();
+      final byte[] clearBytes = saslContext.unwrap(inBytes, lengthSize,
+          bufLength);
       decodeBuffer.clear();
-      readBuffer.compact();
-      //If the read buffer has bytes, these are a new buffer. Reset the
-      //buffer length to the new value.
-      if(readBuffer.position() != 0) {
-        bufLength = getBufLength(readBuffer);
-        reading = true;
-      } else
-        reading=false;
+      clearDst.put(clearBytes);
+      readBuffer.clear();
       return clearDst.position();
     }
+  }
 
-    /**
-     * Read from the socket channel into the specified byte buffer at least
-     * the number of bytes specified in the total parameter.
-     *
-     * @param byteBuf
-     *          The byte buffer to put the bytes in.
-     * @param total
-     *          The total number of bytes to read from the socket
-     *          channel.
-     * @return The number of bytes read, 0 or -1.
-     * @throws IOException
-     *           If an error occurred reading the socket channel.
-     */
-    private int readAll(ByteBuffer byteBuf, int total) throws IOException
-    {
-      while (channel.isOpen() && total > 0) {
-        int count = channel.read(byteBuf);
-        if (count == -1) return -1;
-        if (count == 0) return 0;
-        total -= count;
-      }
-      if (total > 0)
-        return -1;
-      else
-        return byteBuf.position();
-    }
 
-    /**
-     * Return the clear buffer length as determined by processing the
-     * first 4 bytes of the specified buffer.
-     *
-     * @param byteBuf
-     *          The buffer to examine the first 4 bytes of.
-     * @return The size of the clear buffer.
-     */
-    private int getBufLength(ByteBuffer byteBuf)
-    {
-      int answer = 0;
 
-      for (int i = 0; i < lengthSize; i++)
-      {
-        byte b = byteBuf.get(i);
-        answer <<= 8;
-        answer |= ((int) b & 0xff);
-      }
-      return answer;
-    }
-
-    /**
-     * {@inheritDoc}
-     */
+  /**
+   * {@inheritDoc}
+   */
   @Override
-    public synchronized int read(ByteBuffer clearDst) throws IOException {
-      int bytesToRead = lengthSize;
-      if(reading)
-        bytesToRead = neededBytes;
-      int readResult = readAll(readBuffer, bytesToRead);
-      if (readResult == -1)
-        return -1;
-      //The previous buffer read was not complete, the current
-      //buffer completes it.
-      if(neededBytes > 0 && readResult > 0)
-          return(processPartial(readResult, clearDst));
-      if(readResult == 0 && !reading) return 0;
-      if(!reading) {
-        bufLength = getBufLength(readBuffer);
+  public ByteChannel wrapChannel(final ByteChannel channel)
+  {
+    return this;
+  }
+
+
+
+  /**
+   * {@inheritDoc}
+   */
+  @Override
+  public synchronized int write(final ByteBuffer clearSrc) throws IOException
+  {
+    final int sendBufSize = getAppBufSize();
+    final int srcLen = clearSrc.remaining();
+    final ByteBuffer sendBuffer = ByteBuffer.allocate(sendBufSize);
+
+    if (srcLen > sendBufSize)
+    {
+      final int oldPos = clearSrc.position();
+      int curPos = oldPos;
+      int curLimit = oldPos + sendBufSize;
+
+      while (curPos < srcLen)
+      {
+        clearSrc.position(curPos);
+        clearSrc.limit(curLimit);
+        sendBuffer.put(clearSrc);
+        writeChannel(wrap(sendBuffer.array(), clearSrc.remaining()));
+        curPos = curLimit;
+        curLimit = Math.min(srcLen, curPos + sendBufSize);
       }
-      reading=false;
-      //The buffer length is greater than what is there, save what is there,
-      //figure out how much more is needed and return.
-      if(bufLength > readBuffer.position()) {
-        neededBytes = bufLength - readBuffer.position() + lengthSize;
-        readBuffer.flip();
-        decodeBuffer.put(readBuffer);
-        readBuffer.clear();
+      return srcLen;
+    }
+    else
+    {
+      sendBuffer.put(clearSrc);
+      return writeChannel(wrap(sendBuffer.array(), srcLen));
+    }
+  }
+
+
+
+  /**
+   * Return the clear buffer length as determined by processing the first 4
+   * bytes of the specified buffer.
+   *
+   * @param byteBuf
+   *          The buffer to examine the first 4 bytes of.
+   * @return The size of the clear buffer.
+   */
+  private int getBufLength(final ByteBuffer byteBuf)
+  {
+    int answer = 0;
+    for (int i = 0; i < lengthSize; i++)
+    {
+      final byte b = byteBuf.get(i);
+      answer <<= 8;
+      answer |= (b & 0xff);
+    }
+    return answer;
+  }
+
+
+
+  /**
+   * Finish processing a previous, partially read buffer using some, or, all of
+   * the bytes of the current buffer.
+   */
+  private int processPartial(final int readResult, final ByteBuffer clearDst)
+      throws IOException
+  {
+    readBuffer.flip();
+
+    // Use all of the bytes of the current buffer and read some more.
+    if (neededBytes > readResult)
+    {
+      neededBytes -= readResult;
+      decodeBuffer.put(readBuffer);
+      readBuffer.clear();
+      reading = false;
+      return 0;
+    }
+
+    // Use a portion of the current buffer.
+    for (; neededBytes > 0; neededBytes--)
+    {
+      decodeBuffer.put(readBuffer.get());
+    }
+
+    // Unwrap the now completed buffer.
+    final byte[] inBytes = decodeBuffer.array();
+    final byte[] clearBytes = saslContext
+        .unwrap(inBytes, lengthSize, bufLength);
+    clearDst.put(clearBytes);
+    decodeBuffer.clear();
+    readBuffer.compact();
+
+    // If the read buffer has bytes, these are a new buffer. Reset the
+    // buffer length to the new value.
+    if (readBuffer.position() != 0)
+    {
+      bufLength = getBufLength(readBuffer);
+      reading = true;
+    }
+    else
+    {
+      reading = false;
+    }
+    return clearDst.position();
+  }
+
+
+
+  /**
+   * Read from the socket channel into the specified byte buffer at least the
+   * number of bytes specified in the total parameter.
+   *
+   * @param byteBuf
+   *          The byte buffer to put the bytes in.
+   * @param total
+   *          The total number of bytes to read from the socket channel.
+   * @return The number of bytes read, 0 or -1.
+   * @throws IOException
+   *           If an error occurred reading the socket channel.
+   */
+  private int readAll(final ByteBuffer byteBuf, int total) throws IOException
+  {
+    while (channel.isOpen() && total > 0)
+    {
+      final int count = channel.read(byteBuf);
+      if (count == -1)
+      {
+        return -1;
+      }
+      if (count == 0)
+      {
         return 0;
-      } else {
-        readBuffer.flip();
-        decodeBuffer.put(readBuffer);
-        byte[] inBytes = decodeBuffer.array();
-        byte[]clearBytes = saslContext.unwrap(inBytes, lengthSize, bufLength);
-        decodeBuffer.clear();
-        clearDst.put(clearBytes);
-        readBuffer.clear();
       }
-      return clearDst.position();
+      total -= count;
     }
-
-    /**
-     * Writes the specified len parameter into the buffer in a form that
-     * can be sent over a network to the client.
-     *
-     * @param buf
-     *          The buffer to hold the length bytes.
-     * @param len
-     *          The length to encode.
-     */
-    private void writeBufLen(byte[] buf, int len)
+    if (total > 0)
     {
-      for (int i = 3; i >= 0; i--)
-      {
-        buf[i] = (byte) (len & 0xff);
-        len >>>= 8;
-      }
+      return -1;
     }
-
-    /**
-     * Creates a buffer suitable to send to the client using the
-     * specified clear byte array  and length of the bytes to
-     * wrap.
-     *
-     * @param clearBytes
-     *          The clear byte array to send to the client.
-     * @param len
-     *          The length of the bytes to wrap in the byte array.
-     * @throws SaslException
-     *           If the wrap of the bytes fails.
-     */
-    private ByteBuffer wrap(byte[] clearBytes, int len) throws IOException {
-      byte[] wrapBytes = saslContext.wrap(clearBytes, 0, len);
-      byte[] outBytes = new byte[wrapBytes.length + lengthSize];
-      writeBufLen(outBytes, wrapBytes.length);
-      System.arraycopy(wrapBytes, 0, outBytes, lengthSize, wrapBytes.length);
-      return ByteBuffer.wrap(outBytes);
+    else
+    {
+      return byteBuf.position();
     }
+  }
 
 
-    /**
-     * {@inheritDoc}
-     */
-  @Override
-    public synchronized int write(ByteBuffer clearSrc) throws IOException {
-        int sendBufSize = getAppBufSize();
-        int srcLen = clearSrc.remaining();
-        ByteBuffer sendBuffer = ByteBuffer.allocate(sendBufSize);
-        if (srcLen > sendBufSize) {
-            int oldPos = clearSrc.position();
-            int curPos = oldPos;
-            int curLimit = oldPos + sendBufSize;
-            while (curPos < srcLen) {
-                clearSrc.position(curPos);
-                clearSrc.limit(curLimit);
-                sendBuffer.put(clearSrc);
-                writeChannel(wrap(sendBuffer.array(), clearSrc.remaining()));
-                curPos = curLimit;
-                curLimit = Math.min(srcLen, curPos + sendBufSize);
-            }
-            return srcLen;
-        } else {
-            sendBuffer.put(clearSrc);
-            return writeChannel(wrap(sendBuffer.array() ,srcLen));
-        }
+
+  /**
+   * Creates a buffer suitable to send to the client using the specified clear
+   * byte array and length of the bytes to wrap.
+   *
+   * @param clearBytes
+   *          The clear byte array to send to the client.
+   * @param len
+   *          The length of the bytes to wrap in the byte array.
+   * @throws SaslException
+   *           If the wrap of the bytes fails.
+   */
+  private ByteBuffer wrap(final byte[] clearBytes, final int len)
+      throws SaslException
+  {
+    final byte[] wrapBytes = saslContext.wrap(clearBytes, 0, len);
+    final byte[] outBytes = new byte[wrapBytes.length + lengthSize];
+
+    writeBufLen(outBytes, wrapBytes.length);
+    System.arraycopy(wrapBytes, 0, outBytes, lengthSize, wrapBytes.length);
+
+    return ByteBuffer.wrap(outBytes);
+  }
+
+
+
+  /**
+   * Writes the specified len parameter into the buffer in a form that can be
+   * sent over a network to the client.
+   *
+   * @param buf
+   *          The buffer to hold the length bytes.
+   * @param len
+   *          The length to encode.
+   */
+  private void writeBufLen(final byte[] buf, int len)
+  {
+    for (int i = 3; i >= 0; i--)
+    {
+      buf[i] = (byte) (len & 0xff);
+      len >>>= 8;
     }
+  }
 
 
-    /**
-     * Write the specified byte buffer to the socket channel.
-     *
-     * @param buffer
-     *          The byte buffer to write to the socket channel.
-     * @return {@code true} if the byte buffer was successfully written
-     *         to the socket channel, or, {@code false} if not.
-     */
-    private int writeChannel(ByteBuffer buffer) throws IOException {
-        return channel.write(buffer);
-      }
 
-    /**
-     * {@inheritDoc}
-     */
-  @Override
-    public synchronized void close() throws IOException {
-        saslContext.dispose();
-        saslContext=null;
-    }
-
-    /**
-     * {@inheritDoc}
-     */
-  @Override
-    public boolean isOpen() {
-        return saslContext != null;
-    }
-
-    /**
-     * {@inheritDoc}
-     */
-  @Override
-    public int getAppBufSize() {
-        return saslContext.getBufSize(Sasl.MAX_BUFFER);
-    }
-
-    /**
-     * {@inheritDoc}
-     */
-  @Override
-    public Certificate[] getClientCertificateChain() {
-        return new Certificate[0];
-    }
-
-    /**
-     * {@inheritDoc}
-     */
-  @Override
-    public int getSSF() {
-        return saslContext.getSSF();
-    }
-
-    /**
-     * {@inheritDoc}
-     */
-  @Override
-    public ByteChannel wrapChannel(ByteChannel channel) {
-        return this;
-    }
-
-    /**
-     * {@inheritDoc}
-     */
-  @Override
-    public String getName() {
-        return name;
-    }
-
-    /**
-     * {@inheritDoc}
-     */
-  @Override
-    public boolean isSecure() {
-        return true;
-    }
+  /**
+   * Write the specified byte buffer to the socket channel.
+   *
+   * @param buffer
+   *          The byte buffer to write to the socket channel.
+   * @return {@code true} if the byte buffer was successfully written to the
+   *         socket channel, or, {@code false} if not.
+   */
+  private int writeChannel(final ByteBuffer buffer) throws IOException
+  {
+    return channel.write(buffer);
+  }
 
 }

--
Gitblit v1.10.0