From ab5716f792d7eab2fb150fc2b6a72c346466486e Mon Sep 17 00:00:00 2001
From: Yannick Lecaillez <yannick.lecaillez@forgerock.com>
Date: Fri, 25 Nov 2016 16:41:56 +0000
Subject: [PATCH] * enableSasl() now returns a boolean (like enableTLS()) * Use connection's attribute to store the SaslServer (like enableTLS()) to ensure thread-safe visibility

---
 opendj-core/clirr-ignored-api-changes.xml                                       |    2 +-
 opendj-core/src/main/java/org/forgerock/opendj/ldap/LDAPClientContext.java      |    7 +++++--
 opendj-grizzly/src/main/java/org/forgerock/opendj/grizzly/SaslFilter.java       |   26 ++++++++++++++++++++------
 opendj-grizzly/src/main/java/org/forgerock/opendj/grizzly/LDAPServerFilter.java |   13 +++++++------
 4 files changed, 33 insertions(+), 15 deletions(-)

diff --git a/opendj-core/clirr-ignored-api-changes.xml b/opendj-core/clirr-ignored-api-changes.xml
index b19d55f..4466c16 100644
--- a/opendj-core/clirr-ignored-api-changes.xml
+++ b/opendj-core/clirr-ignored-api-changes.xml
@@ -148,7 +148,7 @@
   <difference>
     <className>org/forgerock/opendj/ldap/LDAPClientContext</className>
     <differenceType>7012</differenceType>
-    <method>void enableSASL(javax.security.sasl.SaslServer)</method>
+    <method>boolean enableSASL(javax.security.sasl.SaslServer)</method>
     <justification>Simplify management of security layer</justification>
   </difference>
   <difference>
diff --git a/opendj-core/src/main/java/org/forgerock/opendj/ldap/LDAPClientContext.java b/opendj-core/src/main/java/org/forgerock/opendj/ldap/LDAPClientContext.java
index 3fec4e2..c1511c3 100644
--- a/opendj-core/src/main/java/org/forgerock/opendj/ldap/LDAPClientContext.java
+++ b/opendj-core/src/main/java/org/forgerock/opendj/ldap/LDAPClientContext.java
@@ -156,7 +156,10 @@
      * Installs the SASL security layer on the underlying connection.
      *
      * @param saslServer
-     *            The {@code SaslServer} which should be used to secure the conneciton.
+     *            The {@code SaslServer} which should be used to secure the connection.
+     * @return {@code true} if the SASL filter has been enabled, {@code false} if it was already enabled.
+     * @throws NullPointerException
+     *             if saslServer is null
      */
-    void enableSASL(SaslServer saslServer);
+    boolean enableSASL(SaslServer saslServer);
 }
diff --git a/opendj-grizzly/src/main/java/org/forgerock/opendj/grizzly/LDAPServerFilter.java b/opendj-grizzly/src/main/java/org/forgerock/opendj/grizzly/LDAPServerFilter.java
index 5bb9307..55c95a3 100644
--- a/opendj-grizzly/src/main/java/org/forgerock/opendj/grizzly/LDAPServerFilter.java
+++ b/opendj-grizzly/src/main/java/org/forgerock/opendj/grizzly/LDAPServerFilter.java
@@ -307,7 +307,6 @@
         private final Connection<?> connection;
         private volatile boolean isClosed;
         private final List<LDAPClientContextEventListener> connectionEventListeners = new LinkedList<>();
-        private SaslServer saslServer;
         private GrizzlyBackpressureSubscription downstream;
 
         private ClientConnectionImpl(final Connection<?> connection) {
@@ -386,21 +385,22 @@
         }
 
         @Override
-        public void enableSASL(final SaslServer saslServer) {
+        public boolean enableSASL(final SaslServer saslServer) {
             Reject.ifNull(saslServer, "saslServer must not be null");
             synchronized (this) {
                 if (filterExists(SaslFilter.class)) {
                     // FIXME: The current saslServer must be replaced with the new one
-                    return;
+                    return false;
                 }
-                this.saslServer = saslServer;
-                installFilter(new SaslFilter(saslServer));
+                SaslFilter.setSaslServer(connection, saslServer);
+                installFilter(new SaslFilter());
+                return true;
             }
         }
 
         @Override
         public SaslServer getSASLServer() {
-            return saslServer;
+            return SaslFilter.getSaslServer(connection);
         }
 
         @Override
@@ -432,6 +432,7 @@
         }
 
         private int getSaslSecurityStrengthFactor() {
+            final SaslServer saslServer = getSASLServer();
             if (saslServer == null) {
                 return 0;
             }
diff --git a/opendj-grizzly/src/main/java/org/forgerock/opendj/grizzly/SaslFilter.java b/opendj-grizzly/src/main/java/org/forgerock/opendj/grizzly/SaslFilter.java
index 23ad5a0..d32e01c 100644
--- a/opendj-grizzly/src/main/java/org/forgerock/opendj/grizzly/SaslFilter.java
+++ b/opendj-grizzly/src/main/java/org/forgerock/opendj/grizzly/SaslFilter.java
@@ -21,6 +21,9 @@
 import javax.security.sasl.SaslServer;
 
 import org.glassfish.grizzly.Buffer;
+import org.glassfish.grizzly.Grizzly;
+import org.glassfish.grizzly.attributes.Attribute;
+import org.glassfish.grizzly.attributes.AttributeStorage;
 import org.glassfish.grizzly.filterchain.BaseFilter;
 import org.glassfish.grizzly.filterchain.FilterChainContext;
 import org.glassfish.grizzly.filterchain.NextAction;
@@ -30,21 +33,30 @@
 
 final class SaslFilter extends BaseFilter {
 
+    private static final Attribute<SaslServer> SASL_SERVER_ATTR = Grizzly.DEFAULT_ATTRIBUTE_BUILDER
+            .createAttribute(SaslFilter.class + ".sasl-server");
+
+    static void setSaslServer(final AttributeStorage storage, final SaslServer server) {
+        SASL_SERVER_ATTR.set(storage, server);
+    }
+
+    static SaslServer getSaslServer(final AttributeStorage storage) {
+        return SASL_SERVER_ATTR.get(storage);
+    }
+
     /** Used to check if negotiated QOP is confidentiality or integrity. */
     static final String SASL_AUTH_CONFIDENTIALITY = "auth-conf";
 
     static final String SASL_AUTH_INTEGRITY = "auth-int";
 
     private static final int INT_SIZE = 4;
-    private final SaslServer saslServer;
     private final boolean enableAfterNextMessage;
 
-    SaslFilter(final SaslServer saslServer) {
-        this(saslServer, true);
+    SaslFilter() {
+        this(true);
     }
 
-    private SaslFilter(final SaslServer saslServer, final boolean enableAfterNextMessage) {
-        this.saslServer = saslServer;
+    private SaslFilter(final boolean enableAfterNextMessage) {
         this.enableAfterNextMessage = enableAfterNextMessage;
     }
 
@@ -69,6 +81,7 @@
     }
 
     private Buffer unwrap(final FilterChainContext ctx, final Buffer buffer, final int length) throws SaslException {
+        final SaslServer saslServer = getSaslServer(ctx.getConnection());
         if (buffer.hasArray()) {
             return Buffers.wrap(ctx.getMemoryManager(),
                     saslServer.unwrap(buffer.array(), buffer.arrayOffset() + buffer.position(), length));
@@ -93,7 +106,7 @@
     @Override
     public NextAction handleWrite(final FilterChainContext ctx) throws IOException {
         if (enableAfterNextMessage) {
-            ctx.getFilterChain().set(ctx.getFilterIdx(), new SaslFilter(saslServer, false));
+            ctx.getFilterChain().set(ctx.getFilterIdx(), new SaslFilter(false));
             return ctx.getInvokeAction();
         }
         final Buffer message = ctx.getMessage();
@@ -103,6 +116,7 @@
     }
 
     private Buffer wrap(final FilterChainContext ctx, final Buffer buffer) throws SaslException {
+        final SaslServer saslServer = getSaslServer(ctx.getConnection());
         final Buffer contentBuffer;
         if (buffer.hasArray()) {
             contentBuffer = Buffers.wrap(ctx.getMemoryManager(),

--
Gitblit v1.10.0