From b45e7fb00a64d2fd8897a485def4296d03c39b55 Mon Sep 17 00:00:00 2001
From: dugan <dugan@localhost>
Date: Wed, 18 Feb 2009 14:19:40 +0000
Subject: [PATCH] Commit SASL Phase2 changes. Issue 3805. Unit tests to follow later.

---
 opends/src/server/org/opends/server/util/StaticUtils.java                         |   15 --
 opends/src/server/org/opends/server/api/ClientConnection.java                     |   46 +++++++++
 opends/src/server/org/opends/server/extensions/GSSAPISASLMechanismHandler.java    |   15 ++
 opends/src/server/org/opends/server/extensions/DigestMD5SASLMechanismHandler.java |   13 ++
 opends/src/server/org/opends/server/extensions/TLSByteChannel.java                |    5 
 opends/src/server/org/opends/server/extensions/SASLByteChannel.java               |  146 +++++++++++++++++++---------
 opends/src/server/org/opends/server/protocols/ldap/LDAPClientConnection.java      |   40 +++++++
 7 files changed, 213 insertions(+), 67 deletions(-)

diff --git a/opends/src/server/org/opends/server/api/ClientConnection.java b/opends/src/server/org/opends/server/api/ClientConnection.java
index 2dc2816..bbe1a9e 100644
--- a/opends/src/server/org/opends/server/api/ClientConnection.java
+++ b/opends/src/server/org/opends/server/api/ClientConnection.java
@@ -31,6 +31,7 @@
 import java.net.InetAddress;
 import java.nio.ByteBuffer;
 import java.nio.channels.Selector;
+import java.nio.channels.SocketChannel;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.HashSet;
@@ -45,6 +46,7 @@
 import org.opends.server.core.PluginConfigManager;
 import org.opends.server.core.SearchOperation;
 import org.opends.server.core.networkgroups.NetworkGroup;
+import org.opends.server.extensions.RedirectingByteChannel;
 import org.opends.server.loggers.debug.DebugTracer;
 import org.opends.server.types.Attribute;
 import org.opends.server.types.AttributeType;
@@ -1390,6 +1392,50 @@
   }
 
 
+  /**
+   * Return the lowest level channel associated with a connection.
+   * This is normally the channel associated with the socket
+   * channel.
+   *
+   * @return The lowest level channel associated with a connection.
+   */
+  public RedirectingByteChannel getChannel() {
+    // By default, return null, which indicates that there should
+    // be no channel.  Subclasses should override this if
+    // they want to support a channel.
+    return null;
+  }
+
+
+
+  /**
+   * Return the Socket channel associated with a connection.
+   *
+   * @return The Socket channel associated with a connection.
+   */
+  public SocketChannel getSocketChannel() {
+    // By default, return null, which indicates that there should
+    // be no socket channel.  Subclasses should override this if
+    // they want to support a socket channel.
+    return null;
+  }
+
+
+
+  /**
+   * Return the largest application buffer size that should be used
+   * for a connection.
+   *
+   * @return The application buffer size.
+   */
+  public int getAppBufferSize() {
+    // By default, return 0, which indicates that there should
+    // be no application buffer size.  Subclasses should override
+    //this if they want to support a application buffer size.
+    return 0;
+  }
+
+
 
   /**
    * Retrieves the size limit that will be enforced for searches
diff --git a/opends/src/server/org/opends/server/extensions/DigestMD5SASLMechanismHandler.java b/opends/src/server/org/opends/server/extensions/DigestMD5SASLMechanismHandler.java
index cd2de57..2c9a519 100644
--- a/opends/src/server/org/opends/server/extensions/DigestMD5SASLMechanismHandler.java
+++ b/opends/src/server/org/opends/server/extensions/DigestMD5SASLMechanismHandler.java
@@ -22,7 +22,7 @@
  * CDDL HEADER END
  *
  *
- *      Copyright 2006-2008 Sun Microsystems, Inc.
+ *      Copyright 2006-2009 Sun Microsystems, Inc.
  */
 package org.opends.server.extensions;
 
@@ -163,6 +163,17 @@
          (SASLContext) clientConn.getSASLAuthStateInfo();
       if(saslContext == null) {
           try {
+            //If the connection is secure already (i.e., TLS), then make the
+            //receive buffers sizes match.
+            if(clientConn.isSecure()) {
+              HashMap<String, String>secProps =
+                                      new HashMap<String,String>(saslProps);
+              int maxBuf = clientConn.getAppBufferSize();
+              secProps.put(Sasl.MAX_BUFFER, Integer.toString(maxBuf));
+              saslContext = SASLContext.createSASLContext(secProps,
+                                      serverFQDN, SASL_MECHANISM_DIGEST_MD5,
+                                      identityMapper);
+            } else
               saslContext = SASLContext.createSASLContext(saslProps, serverFQDN,
                             SASL_MECHANISM_DIGEST_MD5, identityMapper);
           } catch (SaslException ex) {
diff --git a/opends/src/server/org/opends/server/extensions/GSSAPISASLMechanismHandler.java b/opends/src/server/org/opends/server/extensions/GSSAPISASLMechanismHandler.java
index 2ec2a10..75b29c7 100644
--- a/opends/src/server/org/opends/server/extensions/GSSAPISASLMechanismHandler.java
+++ b/opends/src/server/org/opends/server/extensions/GSSAPISASLMechanismHandler.java
@@ -383,8 +383,19 @@
     {
       try
       {
-        saslContext = SASLContext.createSASLContext(saslProps, serverFQDN,
-            SASL_MECHANISM_GSSAPI, identityMapper);
+        //If the connection is secure already (i.e., TLS), then make the
+        //receive buffers sizes match.
+        if(clientConn.isSecure()) {
+          HashMap<String, String>secProps =
+                                  new HashMap<String,String>(saslProps);
+          int maxBuf = clientConn.getAppBufferSize();
+          secProps.put(Sasl.MAX_BUFFER, Integer.toString(maxBuf));
+          saslContext = SASLContext.createSASLContext(secProps, serverFQDN,
+                                  SASL_MECHANISM_GSSAPI, identityMapper);
+        } else {
+          saslContext = SASLContext.createSASLContext(saslProps, serverFQDN,
+                                  SASL_MECHANISM_GSSAPI, identityMapper);
+        }
       }
       catch (SaslException ex)
       {
diff --git a/opends/src/server/org/opends/server/extensions/SASLByteChannel.java b/opends/src/server/org/opends/server/extensions/SASLByteChannel.java
index fc5a959..0d8555f 100644
--- a/opends/src/server/org/opends/server/extensions/SASLByteChannel.java
+++ b/opends/src/server/org/opends/server/extensions/SASLByteChannel.java
@@ -22,22 +22,18 @@
  * CDDL HEADER END
  *
  *
- *      Copyright 2008 Sun Microsystems, Inc.
+ *      Copyright 2008-2009 Sun Microsystems, Inc.
  */
 
 package org.opends.server.extensions;
 
 import java.nio.channels.ByteChannel;
 import java.security.cert.Certificate;
-import static org.opends.server.loggers.debug.DebugLogger.*;
 import java.io.IOException;
 import java.nio.ByteBuffer;
 import java.nio.channels.ClosedChannelException;
-import java.nio.channels.SocketChannel;
 import javax.security.sasl.Sasl;
 import org.opends.server.api.ClientConnection;
-import org.opends.server.loggers.debug.DebugTracer;
-import org.opends.server.protocols.ldap.LDAPClientConnection;
 import org.opends.server.util.StaticUtils;
 
 /**
@@ -48,27 +44,34 @@
 public class
 SASLByteChannel implements ByteChannel, ConnectionSecurityProvider {
 
-    // The tracer object for the debug logger.
-    private static final DebugTracer TRACER = getTracer();
-
     // The client connection associated with this provider.
     private ClientConnection connection;
 
-    // The socket channel associated with this provider.
-    private SocketChannel sockChannel;
-
     // The SASL context associated with the provider
     private SASLContext saslContext;
 
+    // The byte channel associated with this provider.
+    private RedirectingByteChannel channel;
+
     // The number of bytes in the length buffer.
     private final int lengthSize = 4;
 
-    // A byte buffer used to hold the length of the clear buffer.
-    private ByteBuffer lengthBuf = ByteBuffer.allocate(lengthSize);
+    //Length of the buffer.
+    private int bufLength;
 
     // The SASL mechanism name.
     private String name;
 
+    //Buffers used in reading and decoding (unwrap)
+    private ByteBuffer readBuffer, decodeBuffer;
+
+    //How many bytes of the subsequent buffer is needed to complete a partially
+    //read buffer.
+    private int neededBytes = 0;
+
+    //Used to not reset the buffer length size because the first 4 bytes of a
+    //buffer are not size bytes.
+    private boolean reading = false;
 
     /**
      * Create a SASL byte channel with the specified parameters
@@ -87,7 +90,9 @@
       this.connection = connection;
       this.name = name;
       this.saslContext = saslContext;
-      this.sockChannel = ((LDAPClientConnection) connection).getSocketChannel();
+      this.channel = connection.getChannel();
+      this.readBuffer = ByteBuffer.allocate(connection.getAppBufferSize());
+      this.decodeBuffer = ByteBuffer.allocate(connection.getAppBufferSize());
     }
 
     /**
@@ -96,7 +101,7 @@
      *
      * @param c A client connection associated with the instance.
      * @param name The name of the instance (SASL mechanism name).
-     * @param context A SASL context associaetd with the instance.
+     * @param context A SASL context associated with the instance.
      * @return A SASL byte channel.
      */
     public static SASLByteChannel
@@ -106,8 +111,44 @@
     }
 
     /**
-     * Read from the socket channel into the specified byte buffer the
-     * number of bytes specified in the total parameter.
+     * Finish processing a previous, partially read buffer using some, or, all
+     * of the bytes of the current buffer.
+     *
+     */
+    private int processPartial(int readResult, ByteBuffer clearDst)
+    throws IOException {
+      readBuffer.flip();
+      //Use all of the bytes of the current buffer and read some more.
+      if(neededBytes > readResult) {
+        neededBytes -= readResult;
+        decodeBuffer.put(readBuffer);
+        readBuffer.clear();
+        reading = false;
+        return 0;
+      }
+      //Use a portion of the current buffer.
+      for(;neededBytes > 0;neededBytes--) {
+        decodeBuffer.put(readBuffer.get());
+      }
+      //Unwrap the now completed buffer.
+      byte[] inBytes = decodeBuffer.array();
+      byte[]clearBytes = saslContext.unwrap(inBytes, lengthSize, bufLength);
+      clearDst.put(clearBytes);
+      decodeBuffer.clear();
+      readBuffer.compact();
+      //If the read buffer has bytes, these are a new buffer. Reset the
+      //buffer length to the new value.
+      if(readBuffer.position() != 0) {
+        bufLength = getBufLength(readBuffer);
+        reading = true;
+      } else
+        reading=false;
+      return clearDst.position();
+    }
+
+    /**
+     * Read from the socket channel into the specified byte buffer at least
+     * the number of bytes specified in the total parameter.
      *
      * @param byteBuf
      *          The byte buffer to put the bytes in.
@@ -121,8 +162,8 @@
     private int readAll(ByteBuffer byteBuf, int total) throws IOException
     {
       int count = 0;
-      while (sockChannel.isOpen() && total > 0) {
-        count = sockChannel.read(byteBuf);
+      while (channel.isOpen() && total > 0) {
+        count = channel.read(byteBuf);
         if (count == -1) return -1;
         if (count == 0) return 0;
         total -= count;
@@ -144,43 +185,53 @@
     private int getBufLength(ByteBuffer byteBuf)
     {
       int answer = 0;
-      byte[] buf = byteBuf.array();
 
       for (int i = 0; i < lengthSize; i++)
       {
+        byte b = byteBuf.get(i);
         answer <<= 8;
-        answer |= ((int) buf[i] & 0xff);
+        answer |= ((int) b & 0xff);
       }
       return answer;
     }
 
-
     /**
      * {@inheritDoc}
      */
     public int read(ByteBuffer clearDst) throws IOException {
-        int recvBufSize = getAppBufSize();
-        if(recvBufSize > clearDst.capacity())
-            return -1;
-        lengthBuf.clear();
-        int readResult = readAll(lengthBuf, lengthSize);
-        if (readResult == -1)
-            return -1;
-        else if (readResult == 0) return 0;
-        int bufLength = getBufLength(lengthBuf);
-        if (bufLength > recvBufSize) //TODO SASLPhase2 add message
-            return -1;
-        ByteBuffer readBuf = ByteBuffer.allocate(bufLength);
-        readResult = readAll(readBuf, bufLength);
-        if (readResult == -1)
-            return -1;
-        else if (readResult == 0) return 0;
-        byte[] inBytes = readBuf.array();
-        byte[] clearBytes = saslContext.unwrap(inBytes, 0, inBytes.length);
-        for(int i = 0; i < clearBytes.length; i++) {
-            clearDst.put(clearBytes[i]);
-        }
-        return clearDst.remaining();
+      int bytesToRead = lengthSize;
+      if(reading)
+        bytesToRead = neededBytes;
+      int readResult = readAll(readBuffer, bytesToRead);
+      if (readResult == -1)
+        return -1;
+      //The previous buffer read was not complete, the current
+      //buffer completes it.
+      if(neededBytes > 0 && readResult > 0)
+          return(processPartial(readResult, clearDst));
+      if(readResult == 0 && !reading) return 0;
+      if(!reading) {
+        bufLength = getBufLength(readBuffer);
+      }
+      reading=false;
+      //The buffer length is greater than what is there, save what is there,
+      //figure out how much more is needed and return.
+      if(bufLength > readBuffer.position()) {
+        neededBytes = bufLength - readBuffer.position() + 4;
+        readBuffer.flip();
+        decodeBuffer.put(readBuffer);
+        readBuffer.clear();
+        return 0;
+      } else {
+        readBuffer.flip();
+        decodeBuffer.put(readBuffer);
+        byte[] inBytes = decodeBuffer.array();
+        byte[]clearBytes = saslContext.unwrap(inBytes, lengthSize, bufLength);
+        decodeBuffer.clear();
+        clearDst.put(clearBytes);
+        readBuffer.clear();
+      }
+      return clearDst.position();
     }
 
     /**
@@ -258,12 +309,11 @@
      *         to the socket channel, or, {@code false} if not.
      */
     private int writeChannel(ByteBuffer buffer) throws IOException {
-        int bytesWritten = sockChannel.write(buffer);
+        int bytesWritten = channel.write(buffer);
         if (bytesWritten < 0)
             throw new ClosedChannelException();
         else if (bytesWritten == 0) {
-            if(!StaticUtils.writeWithTimeout(
-                    connection, sockChannel, buffer))
+            if(!StaticUtils.writeWithTimeout(connection, buffer))
                 throw new ClosedChannelException();
         }
         return bytesWritten;
@@ -288,7 +338,7 @@
      * {@inheritDoc}
      */
     public int getAppBufSize() {
-        return saslContext.getBufSize(Sasl.RAW_SEND_SIZE) + lengthSize;
+        return saslContext.getBufSize(Sasl.MAX_BUFFER);
     }
 
     /**
diff --git a/opends/src/server/org/opends/server/extensions/TLSByteChannel.java b/opends/src/server/org/opends/server/extensions/TLSByteChannel.java
index ee7b92b..4437367 100644
--- a/opends/src/server/org/opends/server/extensions/TLSByteChannel.java
+++ b/opends/src/server/org/opends/server/extensions/TLSByteChannel.java
@@ -205,6 +205,8 @@
         SSLEngineResult.HandshakeStatus hsStatus;
         if(!reading)
           appNetData.clear();
+        else
+          reading = false;
         if(!socketChannel.isOpen())
             return -1;
         if(sslEngine.isInboundDone())
@@ -369,8 +371,7 @@
                     throw new ClosedChannelException();
                 else if (bytesWritten == 0) {
                     int bytesSent = netData.remaining();
-                    if(!StaticUtils.writeWithTimeout(
-                            connection, socketChannel, netData))
+                    if(!StaticUtils.writeWithTimeout(connection, netData))
                         throw new ClosedChannelException();
                     totBytesSent += bytesSent;
                 } else
diff --git a/opends/src/server/org/opends/server/protocols/ldap/LDAPClientConnection.java b/opends/src/server/org/opends/server/protocols/ldap/LDAPClientConnection.java
index a5b1d24..bf9cd1d 100644
--- a/opends/src/server/org/opends/server/protocols/ldap/LDAPClientConnection.java
+++ b/opends/src/server/org/opends/server/protocols/ldap/LDAPClientConnection.java
@@ -195,6 +195,8 @@
 
   private ASN1ByteChannelReader asn1Reader;
 
+  private static int APPLICATION_BUFFER_SIZE = 4096;
+
   private final RedirectingByteChannel saslChannel;
   private final RedirectingByteChannel tlsChannel;
   private ConnectionSecurityProvider activeProvider = null;
@@ -286,7 +288,7 @@
     saslChannel =
         RedirectingByteChannel.getRedirectingByteChannel(tlsChannel);
     this.asn1Reader =
-        ASN1.getReader(saslChannel, 4096, connectionHandler
+        ASN1.getReader(saslChannel, APPLICATION_BUFFER_SIZE, connectionHandler
             .getMaxRequestSize());
 
     connectionID = DirectoryServer.newConnectionAccepted(this);
@@ -349,6 +351,7 @@
    * @return The socket channel that can be used to communicate with the
    *         client.
    */
+  @Override
   public SocketChannel getSocketChannel()
   {
     return clientChannel;
@@ -775,7 +778,8 @@
         writerBuffer.writer = ASN1.getWriter(saslChannel, appBufSize);
       }
       else
-        writerBuffer.writer = ASN1.getWriter(saslChannel, 4096);
+        writerBuffer.writer =
+                          ASN1.getWriter(saslChannel, APPLICATION_BUFFER_SIZE);
       cachedBuffers.set(writerBuffer);
     }
     try
@@ -795,7 +799,7 @@
 
         if (keepStats)
         {
-          // TODO SASLPhase2 hard-coded for now, flush probably needs to
+          // TODO hard-coded for now, flush probably needs to
           // return how many bytes were flushed.
           statTracker.updateMessageWritten(message, 4096);
         }
@@ -2552,6 +2556,19 @@
 
 
   /**
+   * Retrieves the TLS redirecting byte channel used in a LDAP client
+   * connection.
+   *
+   * @return The TLS redirecting byte channel.
+   */
+   @Override
+   public RedirectingByteChannel getChannel() {
+     return this.tlsChannel;
+   }
+
+
+
+  /**
    * {@inheritDoc}
    */
   @Override
@@ -2565,6 +2582,23 @@
 
 
 
+  /**
+   * Retrieves the application buffer size used in a LDAP client connection.
+   * If a active security provider is being used, then the application buffer
+   * size of that provider is returned.
+   *
+   * @return The application buffer size.
+   */
+  @Override
+  public int getAppBufferSize() {
+    if(activeProvider != null)
+      return activeProvider.getAppBufSize();
+    else
+      return APPLICATION_BUFFER_SIZE;
+  }
+
+
+
   private void initializeOperationMonitors()
   {
     this.addMonitor = OperationMonitor.getOperationMonitor(ADD);
diff --git a/opends/src/server/org/opends/server/util/StaticUtils.java b/opends/src/server/org/opends/server/util/StaticUtils.java
index b2311aa..6280b11 100644
--- a/opends/src/server/org/opends/server/util/StaticUtils.java
+++ b/opends/src/server/org/opends/server/util/StaticUtils.java
@@ -4170,21 +4170,14 @@
    * problem). If possible, it will attempt to use the selector returned
    * by the {@code ClientConnection.getWriteSelector} method, but it is
    * capable of working even if that method returns {@code null}. <BR>
-   * <BR>
-   * Note that this method has been written in a generic manner so that
-   * other connection security providers can use it to send data to the
-   * client, provided that the given buffer contains the appropriate
-   * pre-encoded information. <BR>
-   * <BR>
-   * Also note that the original position and limit values will not be
+   *
+   * Note that the original position and limit values will not be
    * preserved, so if that is important to the caller, then it should
    * record them before calling this method and restore them after it
    * returns.
    *
    * @param clientConnection
    *          The client connection to which the data is to be written.
-   * @param socketChannel
-   *          The socket channel over which to write the data.
    * @param buffer
    *          The data to be written to the client.
    * @return <CODE>true</CODE> if all the data in the provided buffer was
@@ -4198,10 +4191,10 @@
    *           client. The caller will be responsible for catching this
    *           and terminating the client connection.
    */
-  public static boolean writeWithTimeout(
-      ClientConnection clientConnection, SocketChannel socketChannel,
+  public static boolean writeWithTimeout(ClientConnection clientConnection,
       ByteBuffer buffer) throws IOException
   {
+    SocketChannel socketChannel = clientConnection.getSocketChannel();
     long startTime = System.currentTimeMillis();
     long waitTime = clientConnection.getMaxBlockedWriteTimeLimit();
     if (waitTime <= 0)

--
Gitblit v1.10.0