mirror of https://github.com/OpenIdentityPlatform/OpenDJ.git

Matthew Swift
14.09.2012 1dfff197eadcf24823d7915e6eead2a850f679f9
opends/src/server/org/opends/server/extensions/SASLByteChannel.java
@@ -35,9 +35,6 @@
import java.nio.channels.ByteChannel;
import java.security.cert.Certificate;
import javax.security.sasl.Sasl;
import javax.security.sasl.SaslException;
import org.opends.server.api.ClientConnection;
@@ -46,10 +43,229 @@
 * This class implements a SASL byte channel that can be used during
 * confidentiality and integrity.
 */
public class SASLByteChannel implements ByteChannel, ConnectionSecurityProvider
public final class SASLByteChannel implements ConnectionSecurityProvider
{
  /**
   * Private implementation.
   */
  private final class ByteChannelImpl implements ByteChannel
  {
    /**
     * {@inheritDoc}
     */
    @Override
    public void close() throws IOException
    {
      synchronized (readLock)
      {
        synchronized (writeLock)
        {
          saslContext.dispose();
          channel.close();
        }
      }
    }
    /**
     * {@inheritDoc}
     */
    @Override
    public boolean isOpen()
    {
      return saslContext != null;
    }
    /**
     * {@inheritDoc}
     */
    @Override
    public int read(final ByteBuffer unwrappedData) throws IOException
    {
      synchronized (readLock)
      {
        // Only read and unwrap new data if needed.
        if (!recvUnwrappedBuffer.hasRemaining())
        {
          final int read = doRecvAndUnwrap();
          if (read <= 0)
          {
            // No data read or end of stream.
            return read;
          }
        }
        // Copy available data.
        final int startPos = unwrappedData.position();
        if (recvUnwrappedBuffer.remaining() > unwrappedData.remaining())
        {
          // Unwrapped data does not fit in client buffer so copy one byte at a
          // time: it's annoying that there is no easy way to do this with
          // ByteBuffers.
          while (unwrappedData.hasRemaining())
          {
            unwrappedData.put(recvUnwrappedBuffer.get());
          }
        }
        else
        {
          // Unwrapped data fits client buffer so block copy.
          unwrappedData.put(recvUnwrappedBuffer);
        }
        return unwrappedData.position() - startPos;
      }
    }
    /**
     * {@inheritDoc}
     */
    @Override
    public int write(final ByteBuffer unwrappedData) throws IOException
    {
      // This method will block until the entire message is sent.
      final int bytesWritten = unwrappedData.remaining();
      // Synchronized in order to prevent interleaving and reordering.
      synchronized (writeLock)
      {
        // Write data in sendBufferSize segments.
        while (unwrappedData.hasRemaining())
        {
          final int remaining = unwrappedData.remaining();
          final int wrapSize = (remaining < sendUnwrappedBufferSize) ? remaining
              : sendUnwrappedBufferSize;
          final byte[] wrappedDataBytes;
          if (unwrappedData.hasArray())
          {
            // Avoid extra copy if ByteBuffer is array based.
            wrappedDataBytes = saslContext.wrap(unwrappedData.array(),
                unwrappedData.arrayOffset(), wrapSize);
            unwrappedData.position(unwrappedData.position() + wrapSize);
          }
          else
          {
            // Non-array based ByteBuffer, so copy.
            unwrappedData.get(sendUnwrappedBytes, 0, wrapSize);
            wrappedDataBytes = saslContext
                .wrap(sendUnwrappedBytes, 0, wrapSize);
            unwrappedData.position(unwrappedData.position() + wrapSize);
          }
          // Encode SASL packet: 4 byte length + wrapped data.
          if (sendWrappedBuffer.capacity() < wrappedDataBytes.length + 4)
          {
            // Resize the send buffer.
            sendWrappedBuffer = ByteBuffer
                .allocate(wrappedDataBytes.length + 4);
          }
          sendWrappedBuffer.clear();
          sendWrappedBuffer.putInt(wrappedDataBytes.length);
          sendWrappedBuffer.put(wrappedDataBytes);
          sendWrappedBuffer.flip();
          // Write the SASL packet: our IO stack will block until all the data
          // is written.
          channel.write(sendWrappedBuffer);
        }
      }
      return bytesWritten;
    }
    // Attempt to read and unwrap the next SASL packet.
    private int doRecvAndUnwrap() throws IOException
    {
      // Read the encoded packet length first.
      if (recvWrappedLength < 0)
      {
        final int read = channel.read(recvWrappedLengthBuffer);
        if (read <= 0)
        {
          // No data read or end of stream.
          return read;
        }
        if (recvWrappedLengthBuffer.hasRemaining())
        {
          // Unable to read the length, so no data available yet.
          return 0;
        }
        // Decode the length and reset the length buffer.
        recvWrappedLengthBuffer.flip();
        recvWrappedLength = recvWrappedLengthBuffer.getInt();
        recvWrappedLengthBuffer.clear();
        // Check that the length is valid.
        if (recvWrappedLength > recvWrappedBufferMaximumSize)
        {
          throw new IOException(
              "Client sent a SASL packet specifying a length "
                  + recvWrappedLength
                  + " which exceeds the negotiated limit of "
                  + recvWrappedBufferMaximumSize);
        }
        if (recvWrappedLength < 0)
        {
          throw new IOException(
              "Client sent a SASL packet specifying a negative length "
                  + recvWrappedLength);
        }
        // Prepare the recv buffer for reading.
        recvWrappedBuffer.clear();
        recvWrappedBuffer.limit(recvWrappedLength);
      }
      // Read wrapped data.
      final int read = channel.read(recvWrappedBuffer);
      if (read <= 0)
      {
        // No data read or end of stream.
        return read;
      }
      if (recvWrappedBuffer.hasRemaining())
      {
        // Unable to read the full packet, so no data available yet.
        return 0;
      }
      // The complete packet has been read, so unwrap it.
      recvWrappedBuffer.flip();
      final byte[] unwrappedDataBytes = saslContext.unwrap(
          recvWrappedBuffer.array(), 0, recvWrappedLength);
      recvWrappedLength = -1;
      if (recvUnwrappedBuffer.capacity() < unwrappedDataBytes.length)
      {
        // Resize the recv buffer (this shouldn't ever happen).
        recvUnwrappedBuffer = ByteBuffer.allocate(unwrappedDataBytes.length);
      }
      recvUnwrappedBuffer.clear();
      recvUnwrappedBuffer.put(unwrappedDataBytes);
      recvUnwrappedBuffer.flip();
      return recvUnwrappedBuffer.remaining();
    }
  }
  /**
   * Return a SASL byte channel instance created using the specified parameters.
   *
   * @param c
@@ -68,31 +284,23 @@
  // The SASL context associated with the provider
  private SASLContext saslContext;
  // The byte channel associated with this provider.
  private final RedirectingByteChannel channel;
  // The number of bytes in the length buffer.
  private static final int lengthSize = 4;
  // Length of the buffer.
  private int bufLength;
  // The SASL mechanism name.
  private final String name;
  private final ByteChannel channel;
  private final ByteChannelImpl pimpl = new ByteChannelImpl();
  private final SASLContext saslContext;
  // Buffers used in reading and decoding (unwrap)
  private final ByteBuffer readBuffer, decodeBuffer;
  private ByteBuffer recvUnwrappedBuffer;
  private final ByteBuffer recvWrappedBuffer;
  private final int recvWrappedBufferMaximumSize;
  private int recvWrappedLength = -1;
  private final ByteBuffer recvWrappedLengthBuffer = ByteBuffer.allocate(4);
  // How many bytes of the subsequent buffer is needed to complete a partially
  // read buffer.
  private int neededBytes = 0;
  private final int sendUnwrappedBufferSize;
  private final byte[] sendUnwrappedBytes;
  private ByteBuffer sendWrappedBuffer;
  // Used to not reset the buffer length size because the first 4 bytes of a
  // buffer are not size bytes.
  private boolean reading = false;
  private final Object readLock = new Object();
  private final Object writeLock = new Object();
@@ -112,10 +320,16 @@
  {
    this.name = name;
    this.saslContext = saslContext;
    this.channel = connection.getChannel();
    this.readBuffer = ByteBuffer.allocate(connection.getAppBufferSize());
    this.decodeBuffer = ByteBuffer.allocate(connection.getAppBufferSize()
        + lengthSize);
    channel = connection.getChannel();
    recvWrappedBufferMaximumSize = saslContext.getMaxReceiveBufferSize();
    sendUnwrappedBufferSize = saslContext.getMaxRawSendBufferSize();
    recvWrappedBuffer = ByteBuffer.allocate(recvWrappedBufferMaximumSize);
    recvUnwrappedBuffer = ByteBuffer.allocate(recvWrappedBufferMaximumSize);
    recvUnwrappedBuffer.flip(); // Initially nothing has been received.
    sendUnwrappedBytes = new byte[sendUnwrappedBufferSize];
    sendWrappedBuffer = ByteBuffer.allocate(sendUnwrappedBufferSize + 64);
  }
@@ -124,21 +338,9 @@
   * {@inheritDoc}
   */
  @Override
  public synchronized void close() throws IOException
  public ByteChannel getChannel()
  {
    saslContext.dispose();
    saslContext = null;
  }
  /**
   * {@inheritDoc}
   */
  @Override
  public int getAppBufSize()
  {
    return saslContext.getBufSize(Sasl.MAX_BUFFER);
    return pimpl;
  }
@@ -180,300 +382,9 @@
   * {@inheritDoc}
   */
  @Override
  public boolean isOpen()
  {
    return saslContext != null;
  }
  /**
   * {@inheritDoc}
   */
  @Override
  public boolean isSecure()
  {
    return true;
  }
  /**
   * {@inheritDoc}
   */
  @Override
  public synchronized int read(final ByteBuffer clearDst) throws IOException
  {
    int bytesToRead = lengthSize;
    if (reading)
    {
      bytesToRead = neededBytes;
    }
    final int readResult = readAll(readBuffer, bytesToRead);
    if (readResult == -1)
    {
      return -1;
    }
    // The previous buffer read was not complete, the current
    // buffer completes it.
    if (neededBytes > 0 && readResult > 0)
    {
      return (processPartial(readResult, clearDst));
    }
    if (readResult == 0 && !reading)
    {
      return 0;
    }
    if (!reading)
    {
      bufLength = getBufLength(readBuffer);
    }
    reading = false;
    // The buffer length is greater than what is there, save what is there,
    // figure out how much more is needed and return.
    if (bufLength > readBuffer.position())
    {
      neededBytes = bufLength - readBuffer.position() + lengthSize;
      readBuffer.flip();
      decodeBuffer.put(readBuffer);
      readBuffer.clear();
      return 0;
    }
    else
    {
      readBuffer.flip();
      decodeBuffer.put(readBuffer);
      final byte[] inBytes = decodeBuffer.array();
      final byte[] clearBytes = saslContext.unwrap(inBytes, lengthSize,
          bufLength);
      decodeBuffer.clear();
      clearDst.put(clearBytes);
      readBuffer.clear();
      return clearDst.position();
    }
  }
  /**
   * {@inheritDoc}
   */
  @Override
  public ByteChannel wrapChannel(final ByteChannel channel)
  {
    return this;
  }
  /**
   * {@inheritDoc}
   */
  @Override
  public synchronized int write(final ByteBuffer clearSrc) throws IOException
  {
    final int sendBufSize = getAppBufSize();
    final int srcLen = clearSrc.remaining();
    final ByteBuffer sendBuffer = ByteBuffer.allocate(sendBufSize);
    if (srcLen > sendBufSize)
    {
      final int oldPos = clearSrc.position();
      int curPos = oldPos;
      int curLimit = oldPos + sendBufSize;
      while (curPos < srcLen)
      {
        clearSrc.position(curPos);
        clearSrc.limit(curLimit);
        sendBuffer.put(clearSrc);
        writeChannel(wrap(sendBuffer.array(), clearSrc.remaining()));
        curPos = curLimit;
        curLimit = Math.min(srcLen, curPos + sendBufSize);
      }
      return srcLen;
    }
    else
    {
      sendBuffer.put(clearSrc);
      return writeChannel(wrap(sendBuffer.array(), srcLen));
    }
  }
  /**
   * Return the clear buffer length as determined by processing the first 4
   * bytes of the specified buffer.
   *
   * @param byteBuf
   *          The buffer to examine the first 4 bytes of.
   * @return The size of the clear buffer.
   */
  private int getBufLength(final ByteBuffer byteBuf)
  {
    int answer = 0;
    for (int i = 0; i < lengthSize; i++)
    {
      final byte b = byteBuf.get(i);
      answer <<= 8;
      answer |= (b & 0xff);
    }
    return answer;
  }
  /**
   * Finish processing a previous, partially read buffer using some, or, all of
   * the bytes of the current buffer.
   */
  private int processPartial(final int readResult, final ByteBuffer clearDst)
      throws IOException
  {
    readBuffer.flip();
    // Use all of the bytes of the current buffer and read some more.
    if (neededBytes > readResult)
    {
      neededBytes -= readResult;
      decodeBuffer.put(readBuffer);
      readBuffer.clear();
      reading = false;
      return 0;
    }
    // Use a portion of the current buffer.
    for (; neededBytes > 0; neededBytes--)
    {
      decodeBuffer.put(readBuffer.get());
    }
    // Unwrap the now completed buffer.
    final byte[] inBytes = decodeBuffer.array();
    final byte[] clearBytes = saslContext
        .unwrap(inBytes, lengthSize, bufLength);
    clearDst.put(clearBytes);
    decodeBuffer.clear();
    readBuffer.compact();
    // If the read buffer has bytes, these are a new buffer. Reset the
    // buffer length to the new value.
    if (readBuffer.position() != 0)
    {
      bufLength = getBufLength(readBuffer);
      reading = true;
    }
    else
    {
      reading = false;
    }
    return clearDst.position();
  }
  /**
   * Read from the socket channel into the specified byte buffer at least the
   * number of bytes specified in the total parameter.
   *
   * @param byteBuf
   *          The byte buffer to put the bytes in.
   * @param total
   *          The total number of bytes to read from the socket channel.
   * @return The number of bytes read, 0 or -1.
   * @throws IOException
   *           If an error occurred reading the socket channel.
   */
  private int readAll(final ByteBuffer byteBuf, int total) throws IOException
  {
    while (channel.isOpen() && total > 0)
    {
      final int count = channel.read(byteBuf);
      if (count == -1)
      {
        return -1;
      }
      if (count == 0)
      {
        return 0;
      }
      total -= count;
    }
    if (total > 0)
    {
      return -1;
    }
    else
    {
      return byteBuf.position();
    }
  }
  /**
   * Creates a buffer suitable to send to the client using the specified clear
   * byte array and length of the bytes to wrap.
   *
   * @param clearBytes
   *          The clear byte array to send to the client.
   * @param len
   *          The length of the bytes to wrap in the byte array.
   * @throws SaslException
   *           If the wrap of the bytes fails.
   */
  private ByteBuffer wrap(final byte[] clearBytes, final int len)
      throws SaslException
  {
    final byte[] wrapBytes = saslContext.wrap(clearBytes, 0, len);
    final byte[] outBytes = new byte[wrapBytes.length + lengthSize];
    writeBufLen(outBytes, wrapBytes.length);
    System.arraycopy(wrapBytes, 0, outBytes, lengthSize, wrapBytes.length);
    return ByteBuffer.wrap(outBytes);
  }
  /**
   * Writes the specified len parameter into the buffer in a form that can be
   * sent over a network to the client.
   *
   * @param buf
   *          The buffer to hold the length bytes.
   * @param len
   *          The length to encode.
   */
  private void writeBufLen(final byte[] buf, int len)
  {
    for (int i = 3; i >= 0; i--)
    {
      buf[i] = (byte) (len & 0xff);
      len >>>= 8;
    }
  }
  /**
   * Write the specified byte buffer to the socket channel.
   *
   * @param buffer
   *          The byte buffer to write to the socket channel.
   * @return {@code true} if the byte buffer was successfully written to the
   *         socket channel, or, {@code false} if not.
   */
  private int writeChannel(final ByteBuffer buffer) throws IOException
  {
    return channel.write(buffer);
  }
}