From 20b30000c1266606a8cdae50a79982f415f11286 Mon Sep 17 00:00:00 2001
From: Ludovic Poitou <ludovic.poitou@forgerock.com>
Date: Wed, 22 Dec 2010 12:06:31 +0000
Subject: [PATCH] Ensure that correct Grizzly MemoryManager is used for SASL and ASN1 filters.

---
 opendj-sdk/sdk/src/com/sun/opends/sdk/tools/AuthRate.java |  188 +++++++++++++++++++++++++----------------------
 1 files changed, 100 insertions(+), 88 deletions(-)

diff --git a/opendj-sdk/sdk/src/com/sun/opends/sdk/tools/AuthRate.java b/opendj-sdk/sdk/src/com/sun/opends/sdk/tools/AuthRate.java
index e8b8f99..fa50545 100644
--- a/opendj-sdk/sdk/src/com/sun/opends/sdk/tools/AuthRate.java
+++ b/opendj-sdk/sdk/src/com/sun/opends/sdk/tools/AuthRate.java
@@ -27,13 +27,11 @@
 
 package com.sun.opends.sdk.tools;
 
-import com.sun.opends.sdk.util.RecursiveFutureResult;
 
-import org.glassfish.grizzly.TransportFactory;
-import org.opends.sdk.*;
-import org.opends.sdk.requests.*;
-import org.opends.sdk.responses.BindResult;
-import org.opends.sdk.responses.SearchResultEntry;
+
+import static com.sun.opends.sdk.messages.Messages.*;
+import static com.sun.opends.sdk.tools.ToolConstants.*;
+import static com.sun.opends.sdk.tools.Utils.filterExitCode;
 
 import java.io.InputStream;
 import java.io.OutputStream;
@@ -44,28 +42,31 @@
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicLong;
 
-import static com.sun.opends.sdk.messages.Messages.*;
-import static com.sun.opends.sdk.tools.ToolConstants.*;
-import static com.sun.opends.sdk.tools.Utils.filterExitCode;
+import org.glassfish.grizzly.TransportFactory;
+import org.opends.sdk.*;
+import org.opends.sdk.requests.*;
+import org.opends.sdk.responses.BindResult;
+import org.opends.sdk.responses.SearchResultEntry;
+
+import com.sun.opends.sdk.util.RecursiveFutureResult;
+
+
 
 /**
- * A load generation tool that can be used to load a Directory Server with
- * Bind requests using one or more LDAP connections.
+ * A load generation tool that can be used to load a Directory Server with Bind
+ * requests using one or more LDAP connections.
  */
 public final class AuthRate extends ConsoleApplication
 {
   private final class BindPerformanceRunner extends PerformanceRunner
   {
-    private final AtomicLong searchWaitRecentTime = new AtomicLong();
-    private final AtomicInteger invalidCredRecentCount = new AtomicInteger();
-
     private final class BindStatsThread extends StatsThread
     {
       private final String[] extraColumn;
 
 
 
-      private BindStatsThread(boolean extraFieldRequired)
+      private BindStatsThread(final boolean extraFieldRequired)
       {
         super(extraFieldRequired ? new String[] { "bind time %" }
             : new String[0]);
@@ -93,15 +94,16 @@
     private final class BindUpdateStatsResultHandler extends
         UpdateStatsResultHandler<BindResult>
     {
-      private BindUpdateStatsResultHandler(long startTime)
+      private BindUpdateStatsResultHandler(final long startTime,
+          final AsynchronousConnection connection, final ConnectionWorker worker)
       {
-        super(startTime);
+        super(startTime, connection, worker);
       }
 
 
 
       @Override
-      public void handleErrorResult(ErrorResultException error)
+      public void handleErrorResult(final ErrorResultException error)
       {
         super.handleErrorResult(error);
 
@@ -112,7 +114,9 @@
       }
     }
 
-    private final class BindWorkerThread extends WorkerThread
+
+
+    private final class BindWorkerThread extends ConnectionWorker
     {
       private SearchRequest sr;
       private BindRequest br;
@@ -122,6 +126,7 @@
       private final ThreadLocal<Random> rng = new ThreadLocal<Random>()
       {
 
+        @Override
         protected Random initialValue()
         {
           return new Random();
@@ -130,6 +135,7 @@
       };
 
 
+
       private BindWorkerThread(final AsynchronousConnection connection,
           final ConnectionFactory connectionFactory)
       {
@@ -146,14 +152,14 @@
         if (dataSources != null)
         {
           data = DataSource.generateData(dataSources, data);
-          if(data.length == dataSources.length)
+          if (data.length == dataSources.length)
           {
-            Object[] newData = new Object[data.length + 1];
+            final Object[] newData = new Object[data.length + 1];
             System.arraycopy(data, 0, newData, 0, data.length);
             data = newData;
           }
         }
-        if(filter != null && baseDN != null)
+        if (filter != null && baseDN != null)
         {
           if (sr == null)
           {
@@ -163,8 +169,8 @@
             }
             else
             {
-              sr = Requests.newSearchRequest(String.format(baseDN, data), scope,
-                  String.format(filter, data), attributes);
+              sr = Requests.newSearchRequest(String.format(baseDN, data),
+                  scope, String.format(filter, data), attributes);
             }
             sr.setDereferenceAliasesPolicy(dereferencesAliasesPolicy);
           }
@@ -174,32 +180,32 @@
             sr.setName(String.format(baseDN, data));
           }
 
-          RecursiveFutureResult<SearchResultEntry, BindResult> future =
-              new RecursiveFutureResult<SearchResultEntry, BindResult>(
-                  new BindUpdateStatsResultHandler(startTime))
+          final RecursiveFutureResult<SearchResultEntry, BindResult> future =
+            new RecursiveFutureResult<SearchResultEntry, BindResult>(
+              new BindUpdateStatsResultHandler(startTime, connection, this))
+          {
+            @Override
+            protected FutureResult<? extends BindResult> chainResult(
+                final SearchResultEntry innerResult,
+                final ResultHandler<? super BindResult> resultHandler)
+                throws ErrorResultException
+            {
+              searchWaitRecentTime.getAndAdd(System.nanoTime() - startTime);
+              if (data == null)
               {
-                @Override
-                protected FutureResult<? extends BindResult> chainResult(
-                    SearchResultEntry innerResult,
-                    ResultHandler<? super BindResult> resultHandler)
-                    throws ErrorResultException
-                {
-                  searchWaitRecentTime.getAndAdd(System.nanoTime() - startTime);
-                  if(data == null)
-                  {
-                    data = new Object[1];
-                  }
-                  data[data.length-1] = innerResult.getName().toString();
-                  return performBind(connection, data, resultHandler);
-                }
-              };
+                data = new Object[1];
+              }
+              data[data.length - 1] = innerResult.getName().toString();
+              return performBind(connection, data, resultHandler);
+            }
+          };
           connection.searchSingleEntry(sr, future);
           return future;
         }
         else
         {
           return performBind(connection, data,
-              new BindUpdateStatsResultHandler(startTime));
+              new BindUpdateStatsResultHandler(startTime, connection, this));
         }
       }
 
@@ -229,13 +235,13 @@
 
         if (bindRequest instanceof SimpleBindRequest)
         {
-          SimpleBindRequest o = (SimpleBindRequest) bindRequest;
+          final SimpleBindRequest o = (SimpleBindRequest) bindRequest;
           if (br == null)
           {
             br = Requests.copyOfSimpleBindRequest(o);
           }
 
-          SimpleBindRequest sbr = (SimpleBindRequest) br;
+          final SimpleBindRequest sbr = (SimpleBindRequest) br;
           if (data != null && o.getName() != null)
           {
             sbr.setName(String.format(o.getName(), data));
@@ -251,13 +257,13 @@
         }
         else if (bindRequest instanceof DigestMD5SASLBindRequest)
         {
-          DigestMD5SASLBindRequest o = (DigestMD5SASLBindRequest) bindRequest;
+          final DigestMD5SASLBindRequest o = (DigestMD5SASLBindRequest) bindRequest;
           if (br == null)
           {
             br = Requests.copyOfDigestMD5SASLBindRequest(o);
           }
 
-          DigestMD5SASLBindRequest sbr = (DigestMD5SASLBindRequest) br;
+          final DigestMD5SASLBindRequest sbr = (DigestMD5SASLBindRequest) br;
           if (data != null)
           {
             if (o.getAuthenticationID() != null)
@@ -281,13 +287,13 @@
         }
         else if (bindRequest instanceof CRAMMD5SASLBindRequest)
         {
-          CRAMMD5SASLBindRequest o = (CRAMMD5SASLBindRequest) bindRequest;
+          final CRAMMD5SASLBindRequest o = (CRAMMD5SASLBindRequest) bindRequest;
           if (br == null)
           {
             br = Requests.copyOfCRAMMD5SASLBindRequest(o);
           }
 
-          CRAMMD5SASLBindRequest sbr = (CRAMMD5SASLBindRequest) br;
+          final CRAMMD5SASLBindRequest sbr = (CRAMMD5SASLBindRequest) br;
           if (data != null && o.getAuthenticationID() != null)
           {
             sbr.setAuthenticationID(String.format(o.getAuthenticationID(), data));
@@ -303,13 +309,13 @@
         }
         else if (bindRequest instanceof GSSAPISASLBindRequest)
         {
-          GSSAPISASLBindRequest o = (GSSAPISASLBindRequest) bindRequest;
+          final GSSAPISASLBindRequest o = (GSSAPISASLBindRequest) bindRequest;
           if (br == null)
           {
             br = Requests.copyOfGSSAPISASLBindRequest(o);
           }
 
-          GSSAPISASLBindRequest sbr = (GSSAPISASLBindRequest) br;
+          final GSSAPISASLBindRequest sbr = (GSSAPISASLBindRequest) br;
           if (data != null)
           {
             if (o.getAuthenticationID() != null)
@@ -333,13 +339,13 @@
         }
         else if (bindRequest instanceof ExternalSASLBindRequest)
         {
-          ExternalSASLBindRequest o = (ExternalSASLBindRequest) bindRequest;
+          final ExternalSASLBindRequest o = (ExternalSASLBindRequest) bindRequest;
           if (br == null)
           {
             br = Requests.copyOfExternalSASLBindRequest(o);
           }
 
-          ExternalSASLBindRequest sbr = (ExternalSASLBindRequest) br;
+          final ExternalSASLBindRequest sbr = (ExternalSASLBindRequest) br;
           if (data != null && o.getAuthorizationID() != null)
           {
             sbr.setAuthorizationID(String.format(o.getAuthorizationID(), data));
@@ -347,13 +353,13 @@
         }
         else if (bindRequest instanceof PlainSASLBindRequest)
         {
-          PlainSASLBindRequest o = (PlainSASLBindRequest) bindRequest;
+          final PlainSASLBindRequest o = (PlainSASLBindRequest) bindRequest;
           if (br == null)
           {
             br = Requests.copyOfPlainSASLBindRequest(o);
           }
 
-          PlainSASLBindRequest sbr = (PlainSASLBindRequest) br;
+          final PlainSASLBindRequest sbr = (PlainSASLBindRequest) br;
           if (data != null)
           {
             if (o.getAuthenticationID() != null)
@@ -382,6 +388,10 @@
 
 
 
+    private final AtomicLong searchWaitRecentTime = new AtomicLong();
+
+    private final AtomicInteger invalidCredRecentCount = new AtomicInteger();
+
     private String filter;
 
     private String baseDN;
@@ -407,20 +417,24 @@
 
 
     @Override
-    StatsThread newStatsThread()
+    ConnectionWorker newConnectionWorker(
+        final AsynchronousConnection connection,
+        final ConnectionFactory connectionFactory)
     {
-      return new BindStatsThread(filter != null && baseDN != null);
+      return new BindWorkerThread(connection, connectionFactory);
     }
 
 
 
     @Override
-    WorkerThread newWorkerThread(final AsynchronousConnection connection,
-        final ConnectionFactory connectionFactory)
+    StatsThread newStatsThread()
     {
-      return new BindWorkerThread(connection, connectionFactory);
+      return new BindStatsThread(filter != null && baseDN != null);
     }
   }
+
+
+
   /**
    * The main method for AuthRate tool.
    *
@@ -485,7 +499,6 @@
 
 
 
-
   private AuthRate(final InputStream in, final OutputStream out,
       final OutputStream err)
   {
@@ -583,10 +596,10 @@
   {
     // Create the command-line argument parser for use with this
     // program.
-    final LocalizableMessage toolDescription =
-        INFO_AUTHRATE_TOOL_DESCRIPTION.get();
-    final ArgumentParser argParser = new ArgumentParser(AuthRate.class
-        .getName(), toolDescription, false, true, 0, 0,
+    final LocalizableMessage toolDescription = INFO_AUTHRATE_TOOL_DESCRIPTION
+        .get();
+    final ArgumentParser argParser = new ArgumentParser(
+        AuthRate.class.getName(), toolDescription, false, true, 0, 0,
         "[filter format string] [attributes ...]");
 
     ConnectionFactoryProvider connectionFactoryProvider;
@@ -604,8 +617,7 @@
     try
     {
       TransportFactory.setInstance(new PerfToolTCPNIOTransportFactory());
-      connectionFactoryProvider =
-        new ConnectionFactoryProvider(argParser, this);
+      connectionFactoryProvider = new ConnectionFactoryProvider(argParser, this);
       runner = new BindPerformanceRunner(argParser, this);
 
       propertiesFileArgument = new StringArgument("propertiesFilePath", null,
@@ -627,34 +639,35 @@
       argParser.setUsageArgument(showUsage, getOutputStream());
 
       baseDN = new StringArgument("baseDN", OPTION_SHORT_BASEDN,
-          OPTION_LONG_BASEDN, false, false, true, INFO_BASEDN_PLACEHOLDER.get(),
-          null, null, INFO_SEARCHRATE_TOOL_DESCRIPTION_BASEDN.get());
+          OPTION_LONG_BASEDN, false, false, true,
+          INFO_BASEDN_PLACEHOLDER.get(), null, null,
+          INFO_SEARCHRATE_TOOL_DESCRIPTION_BASEDN.get());
       baseDN.setPropertyName(OPTION_LONG_BASEDN);
       argParser.addArgument(baseDN);
 
       searchScope = new MultiChoiceArgument<SearchScope>("searchScope", 's',
           "searchScope", false, true, INFO_SEARCH_SCOPE_PLACEHOLDER.get(),
-          SearchScope.values(), false, INFO_SEARCH_DESCRIPTION_SEARCH_SCOPE
-              .get());
+          SearchScope.values(), false,
+          INFO_SEARCH_DESCRIPTION_SEARCH_SCOPE.get());
       searchScope.setPropertyName("searchScope");
       searchScope.setDefaultValue(SearchScope.WHOLE_SUBTREE);
       argParser.addArgument(searchScope);
 
       dereferencePolicy = new MultiChoiceArgument<DereferenceAliasesPolicy>(
           "derefpolicy", 'a', "dereferencePolicy", false, true,
-          INFO_DEREFERENCE_POLICE_PLACEHOLDER.get(), DereferenceAliasesPolicy
-              .values(), false, INFO_SEARCH_DESCRIPTION_DEREFERENCE_POLICY
-              .get());
+          INFO_DEREFERENCE_POLICE_PLACEHOLDER.get(),
+          DereferenceAliasesPolicy.values(), false,
+          INFO_SEARCH_DESCRIPTION_DEREFERENCE_POLICY.get());
       dereferencePolicy.setPropertyName("dereferencePolicy");
       dereferencePolicy.setDefaultValue(DereferenceAliasesPolicy.NEVER);
       argParser.addArgument(dereferencePolicy);
 
       invalidCredPercent = new IntegerArgument("invalidPassword", 'I',
-        "invalidPassword", false, false, true, LocalizableMessage
-            .raw("{invalidPassword}"), 0, null, true, 0, true, 100,
-        LocalizableMessage
-            .raw("Percent of bind operations with simulated " +
-            "invalid password"));
+          "invalidPassword", false, false, true,
+          LocalizableMessage.raw("{invalidPassword}"), 0, null, true, 0, true,
+          100,
+          LocalizableMessage.raw("Percent of bind operations with simulated "
+              + "invalid password"));
       invalidCredPercent.setPropertyName("invalidPassword");
       argParser.addArgument(invalidCredPercent);
 
@@ -688,15 +701,15 @@
         return 0;
       }
 
-      connectionFactory =
-          connectionFactoryProvider.getConnectionFactory();
+      connectionFactory = connectionFactoryProvider.getConnectionFactory();
       runner.validate();
 
       runner.bindRequest = connectionFactoryProvider.getBindRequest();
-      if(runner.bindRequest == null)
+      if (runner.bindRequest == null)
       {
-        throw new ArgumentException(LocalizableMessage.raw(
-            "Authentication information must be provided to use this tool"));
+        throw new ArgumentException(
+            LocalizableMessage
+                .raw("Authentication information must be provided to use this tool"));
       }
     }
     catch (final ArgumentException ae)
@@ -739,11 +752,11 @@
 
     // Try it out to make sure the format string and data sources
     // match.
-    final Object[] data = DataSource.generateData(runner.getDataSources(),
-        null);
+    final Object[] data = DataSource
+        .generateData(runner.getDataSources(), null);
     try
     {
-      if(runner.baseDN != null && runner.filter != null)
+      if (runner.baseDN != null && runner.filter != null)
       {
         String.format(runner.filter, data);
         String.format(runner.baseDN, data);
@@ -759,4 +772,3 @@
     return runner.run(connectionFactory);
   }
 }
-

--
Gitblit v1.10.0