From 20b30000c1266606a8cdae50a79982f415f11286 Mon Sep 17 00:00:00 2001
From: Ludovic Poitou <ludovic.poitou@forgerock.com>
Date: Wed, 22 Dec 2010 12:06:31 +0000
Subject: [PATCH] Ensure that correct Grizzly MemoryManager is used for SASL and ASN1 filters.

---
 opendj-sdk/sdk/src/com/sun/opends/sdk/tools/PerformanceRunner.java |  321 +++++++++++++++++++++++++++++++++++++++++-----------
 1 files changed, 250 insertions(+), 71 deletions(-)

diff --git a/opendj-sdk/sdk/src/com/sun/opends/sdk/tools/PerformanceRunner.java b/opendj-sdk/sdk/src/com/sun/opends/sdk/tools/PerformanceRunner.java
index b4c5300..9f4ab43 100644
--- a/opendj-sdk/sdk/src/com/sun/opends/sdk/tools/PerformanceRunner.java
+++ b/opendj-sdk/sdk/src/com/sun/opends/sdk/tools/PerformanceRunner.java
@@ -33,16 +33,18 @@
 import java.lang.management.GarbageCollectorMXBean;
 import java.lang.management.ManagementFactory;
 import java.util.*;
+import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicLong;
 import java.util.concurrent.atomic.AtomicReference;
 
-import com.sun.opends.sdk.util.StaticUtils;
 import org.opends.sdk.*;
+import org.opends.sdk.responses.BindResult;
 import org.opends.sdk.responses.ExtendedResult;
 import org.opends.sdk.responses.Result;
 
 import com.sun.opends.sdk.tools.AuthenticatedConnectionFactory.AuthenticatedAsynchronousConnection;
+import com.sun.opends.sdk.util.StaticUtils;
 
 
 
@@ -51,6 +53,170 @@
  */
 abstract class PerformanceRunner implements ConnectionEventListener
 {
+  abstract class ConnectionWorker
+  {
+    private final AtomicInteger operationsInFlight = new AtomicInteger();
+
+    private volatile int count;
+
+    private final AsynchronousConnection staticConnection;
+
+    private final ConnectionFactory connectionFactory;
+
+    private final CountDownLatch latch = new CountDownLatch(1);
+
+
+
+    ConnectionWorker(final AsynchronousConnection staticConnection,
+        final ConnectionFactory connectionFactory)
+    {
+      this.staticConnection = staticConnection;
+      this.connectionFactory = connectionFactory;
+    }
+
+
+
+    public void operationCompleted(final AsynchronousConnection connection)
+    {
+      if (operationsInFlight.decrementAndGet() == 0
+          && this.staticConnection == null)
+      {
+        connection.close();
+      }
+      startWork();
+    }
+
+
+
+    public abstract FutureResult<?> performOperation(
+        final AsynchronousConnection connection,
+        final DataSource[] dataSources, final long startTime);
+
+
+
+    public void startWork()
+    {
+      if (!stopRequested && !(maxIterations > 0 && count >= maxIterations))
+      {
+        if (this.staticConnection == null)
+        {
+          connectionFactory
+              .getAsynchronousConnection(new ResultHandler<AsynchronousConnection>()
+              {
+                public void handleErrorResult(final ErrorResultException e)
+                {
+                  app.println(LocalizableMessage.raw(e.getResult()
+                      .getDiagnosticMessage()));
+                  if (e.getCause() != null && app.isVerbose())
+                  {
+                    e.getCause().printStackTrace(app.getErrorStream());
+                  }
+                  stopRequested = true;
+                }
+
+
+
+                public void handleResult(final AsynchronousConnection result)
+                {
+                  doWork(result);
+                }
+              });
+        }
+        else
+        {
+          if (!noRebind
+              && this.staticConnection instanceof AuthenticatedAsynchronousConnection)
+          {
+            final AuthenticatedAsynchronousConnection ac =
+              (AuthenticatedAsynchronousConnection) this.staticConnection;
+            ac.rebind(new ResultHandler<BindResult>()
+            {
+              public void handleErrorResult(final ErrorResultException e)
+              {
+                app.println(LocalizableMessage.raw(e.getResult().toString()));
+                if (e.getCause() != null && app.isVerbose())
+                {
+                  e.getCause().printStackTrace(app.getErrorStream());
+                }
+                stopRequested = true;
+              }
+
+
+
+              public void handleResult(final BindResult result)
+              {
+                doWork(staticConnection);
+              }
+            });
+          }
+          else
+          {
+            doWork(staticConnection);
+          }
+        }
+      }
+      else
+      {
+        latch.countDown();
+      }
+    }
+
+
+
+    public void waitFor() throws InterruptedException
+    {
+      latch.await();
+    }
+
+
+
+    private void doWork(final AsynchronousConnection connection)
+    {
+      long start;
+      double sleepTimeInMS = 0;
+      final int opsToPerform = isAsync ? numConcurrentTasks : numConcurrentTasks
+          - operationsInFlight.get();
+      for (int i = 0; i < opsToPerform; i++)
+      {
+        if (maxIterations > 0 && count >= maxIterations)
+        {
+          break;
+        }
+        start = System.nanoTime();
+        performOperation(connection, dataSources.get(), start);
+        operationRecentCount.getAndIncrement();
+        operationsInFlight.getAndIncrement();
+        count++;
+
+        if (targetThroughput > 0)
+        {
+          try
+          {
+            if (sleepTimeInMS > 1)
+            {
+              Thread.sleep((long) Math.floor(sleepTimeInMS));
+            }
+          }
+          catch (final InterruptedException e)
+          {
+            continue;
+          }
+
+          sleepTimeInMS += targetTimeInMS
+              - ((System.nanoTime() - start) / 1000000.0);
+          if (sleepTimeInMS < -60000)
+          {
+            // If we fall behind by 60 seconds, just forget about
+            // catching up
+            sleepTimeInMS = -60000;
+          }
+        }
+      }
+    }
+  }
+
+
+
   /**
    * Statistics thread base implementation.
    */
@@ -263,8 +429,8 @@
         if (resultCount > 0)
         {
           strings[2] = String.format("%.3f",
-              (waitTime - (gcDuration - lastGCDuration))
-                  / (double) resultCount / 1000000.0);
+              (waitTime - (gcDuration - lastGCDuration)) / (double) resultCount
+                  / 1000000.0);
         }
         else
         {
@@ -370,7 +536,7 @@
         {
           // Script-friendly.
           app.getOutputStream().print(averageDuration);
-          for (String s : strings)
+          for (final String s : strings)
           {
             app.getOutputStream().print(",");
             app.getOutputStream().print(s);
@@ -399,12 +565,17 @@
   class UpdateStatsResultHandler<S extends Result> implements ResultHandler<S>
   {
     private final long startTime;
+    private final AsynchronousConnection connection;
+    private final ConnectionWorker worker;
 
 
 
-    UpdateStatsResultHandler(final long startTime)
+    UpdateStatsResultHandler(final long startTime,
+        final AsynchronousConnection connection, final ConnectionWorker worker)
     {
       this.startTime = startTime;
+      this.connection = connection;
+      this.worker = worker;
     }
 
 
@@ -418,6 +589,8 @@
       {
         app.println(LocalizableMessage.raw(error.getResult().toString()));
       }
+
+      worker.operationCompleted(connection);
     }
 
 
@@ -426,6 +599,7 @@
     {
       successRecentCount.getAndIncrement();
       updateStats();
+      worker.operationCompleted(connection);
     }
 
 
@@ -487,8 +661,7 @@
       AsynchronousConnection connection;
 
       final double targetTimeInMS =
-        (1.0 / (targetThroughput /
-            (double) (numThreads * numConnections))) * 1000.0;
+        (1.0 / (targetThroughput / (double) (numConcurrentTasks * numConnections))) * 1000.0;
       double sleepTimeInMS = 0;
       long start;
       while (!stopRequested && !(maxIterations > 0 && count >= maxIterations))
@@ -797,20 +970,20 @@
 
   private final AtomicLong waitRecentTime = new AtomicLong();
 
-  private final AtomicReference<ReversableArray> eTimeBuffer =
-    new AtomicReference<ReversableArray>(new ReversableArray(100000));
+  private final AtomicReference<ReversableArray> eTimeBuffer = new AtomicReference<ReversableArray>(
+      new ReversableArray(100000));
 
   private final ConsoleApplication app;
 
   private DataSource[] dataSourcePrototypes;
 
   // Thread local copies of the data sources
-  private final ThreadLocal<DataSource[]> dataSources =
-    new ThreadLocal<DataSource[]>()
+  private final ThreadLocal<DataSource[]> dataSources = new ThreadLocal<DataSource[]>()
   {
     /**
      * {@inheritDoc}
      */
+    @Override
     protected DataSource[] initialValue()
     {
       final DataSource[] prototypes = getDataSources();
@@ -827,7 +1000,7 @@
 
   private volatile boolean stopRequested;
 
-  private int numThreads;
+  private int numConcurrentTasks;
 
   private int numConnections;
 
@@ -841,7 +1014,9 @@
 
   private int statsInterval;
 
-  private final IntegerArgument numThreadsArgument;
+  private double targetTimeInMS;
+
+  private final IntegerArgument numConcurrentTasksArgument;
 
   private final IntegerArgument maxIterationsArgument;
 
@@ -864,52 +1039,52 @@
 
 
   PerformanceRunner(final ArgumentParser argParser,
-                    final ConsoleApplication app,
-                    boolean neverRebind, boolean neverAsynchronous,
-                    boolean alwaysSingleThreaded)
+      final ConsoleApplication app, final boolean neverRebind,
+      final boolean neverAsynchronous, final boolean alwaysSingleThreaded)
       throws ArgumentException
   {
     this.app = app;
-    numThreadsArgument = new IntegerArgument("numThreads", 't', "numThreads",
-        false, false, true, LocalizableMessage.raw("{numThreads}"), 1, null,
-        true, 1, false, 0, LocalizableMessage
-            .raw("Number of worker threads per connection"));
-    numThreadsArgument.setPropertyName("numThreads");
-    if(!alwaysSingleThreaded)
+    numConcurrentTasksArgument = new IntegerArgument("numConcurrentTasks", 't',
+        "numConcurrentTasks", false, false, true,
+        LocalizableMessage.raw("{numConcurrentTasks}"), 1, null, true, 1,
+        false, 0,
+        LocalizableMessage.raw("Number of concurrent tasks per connection"));
+    numConcurrentTasksArgument.setPropertyName("numConcurrentTasks");
+    if (!alwaysSingleThreaded)
     {
-      argParser.addArgument(numThreadsArgument);
+      argParser.addArgument(numConcurrentTasksArgument);
     }
     else
     {
-      numThreadsArgument.addValue("1");
+      numConcurrentTasksArgument.addValue("1");
     }
 
     numConnectionsArgument = new IntegerArgument("numConnections", 'c',
-        "numConnections", false, false, true, LocalizableMessage
-            .raw("{numConnections}"), 1, null, true, 1, false, 0,
+        "numConnections", false, false, true,
+        LocalizableMessage.raw("{numConnections}"), 1, null, true, 1, false, 0,
         LocalizableMessage.raw("Number of connections"));
     numConnectionsArgument.setPropertyName("numConnections");
     argParser.addArgument(numConnectionsArgument);
 
     maxIterationsArgument = new IntegerArgument("maxIterations", 'm',
-        "maxIterations", false, false, true, LocalizableMessage
-            .raw("{maxIterations}"), 0, null, LocalizableMessage
-            .raw("Max iterations, 0 for unlimited"));
+        "maxIterations", false, false, true,
+        LocalizableMessage.raw("{maxIterations}"), 0, null,
+        LocalizableMessage.raw("Max iterations, 0 for unlimited"));
     maxIterationsArgument.setPropertyName("maxIterations");
     argParser.addArgument(maxIterationsArgument);
 
     statsIntervalArgument = new IntegerArgument("statInterval", 'i',
-        "statInterval", false, false, true, LocalizableMessage
-            .raw("{statInterval}"), 5, null, true, 1, false, 0,
+        "statInterval", false, false, true,
+        LocalizableMessage.raw("{statInterval}"), 5, null, true, 1, false, 0,
         LocalizableMessage
             .raw("Display results each specified number of seconds"));
     statsIntervalArgument.setPropertyName("statInterval");
     argParser.addArgument(statsIntervalArgument);
 
     targetThroughputArgument = new IntegerArgument("targetThroughput", 'M',
-        "targetThroughput", false, false, true, LocalizableMessage
-            .raw("{targetThroughput}"), 0, null, LocalizableMessage
-            .raw("Target average throughput to achieve"));
+        "targetThroughput", false, false, true,
+        LocalizableMessage.raw("{targetThroughput}"), 0, null,
+        LocalizableMessage.raw("Target average throughput to achieve"));
     targetThroughputArgument.setPropertyName("targetThroughput");
     argParser.addArgument(targetThroughputArgument);
 
@@ -929,7 +1104,7 @@
     noRebindArgument = new BooleanArgument("noRebind", 'F', "noRebind",
         LocalizableMessage.raw("Keep connections open and don't rebind"));
     noRebindArgument.setPropertyName("noRebind");
-    if(!neverRebind)
+    if (!neverRebind)
     {
       argParser.addArgument(noRebindArgument);
     }
@@ -939,24 +1114,30 @@
     }
 
     asyncArgument = new BooleanArgument("asynchronous", 'A', "asynchronous",
-        LocalizableMessage.raw("Use asynchronous mode and don't " +
-            "wait for results before sending the next request"));
+        LocalizableMessage.raw("Use asynchronous mode and don't "
+            + "wait for results before sending the next request"));
     asyncArgument.setPropertyName("asynchronous");
-    if(!neverAsynchronous)
+    if (!neverAsynchronous)
     {
       argParser.addArgument(asyncArgument);
     }
 
-    arguments = new StringArgument("argument", 'g', "argument", false, true,
-        true, LocalizableMessage.raw("{generator function or static string}"),
-        null, null,
-        LocalizableMessage.raw("Argument used to evaluate the Java " +
-            "style format strings in program parameters (ie. Base DN, " +
-            "Search Filter). The set of all arguments provided form the " +
-            "the argument list in order. Besides static string " +
-            "arguments, they can be generated per iteration with the " +
-            "following functions: " + StaticUtils.EOL +
-            DataSource.getUsage()));
+    arguments = new StringArgument(
+        "argument",
+        'g',
+        "argument",
+        false,
+        true,
+        true,
+        LocalizableMessage.raw("{generator function or static string}"),
+        null,
+        null,
+        LocalizableMessage.raw("Argument used to evaluate the Java "
+            + "style format strings in program parameters (ie. Base DN, "
+            + "Search Filter). The set of all arguments provided form the "
+            + "the argument list in order. Besides static string "
+            + "arguments, they can be generated per iteration with the "
+            + "following functions: " + StaticUtils.EOL + DataSource.getUsage()));
     argParser.addArgument(arguments);
   }
 
@@ -986,8 +1167,7 @@
 
 
 
-  public void handleUnsolicitedNotification(
-      final ExtendedResult notification)
+  public void handleUnsolicitedNotification(final ExtendedResult notification)
   {
     // Ignore
   }
@@ -997,20 +1177,19 @@
   public final void validate() throws ArgumentException
   {
     numConnections = numConnectionsArgument.getIntValue();
-    numThreads = numThreadsArgument.getIntValue();
-    maxIterations = maxIterationsArgument.getIntValue() /
-        numConnections / numThreads;
+    numConcurrentTasks = numConcurrentTasksArgument.getIntValue();
+    maxIterations = maxIterationsArgument.getIntValue() / numConnections;
     statsInterval = statsIntervalArgument.getIntValue() * 1000;
     targetThroughput = targetThroughputArgument.getIntValue();
 
     isAsync = asyncArgument.isPresent();
     noRebind = noRebindArgument.isPresent();
 
-    if (!noRebindArgument.isPresent() && this.numThreads > 1)
+    if (!noRebindArgument.isPresent() && this.numConcurrentTasks > 1)
     {
       throw new ArgumentException(LocalizableMessage.raw("--"
           + noRebindArgument.getLongIdentifier() + " must be used if --"
-          + numThreadsArgument.getLongIdentifier() + " is > 1"));
+          + numConcurrentTasksArgument.getLongIdentifier() + " is > 1"));
     }
 
     if (!noRebindArgument.isPresent() && asyncArgument.isPresent())
@@ -1021,6 +1200,9 @@
     }
 
     dataSourcePrototypes = DataSource.parse(arguments.getValues());
+
+    targetTimeInMS =
+      (1.0 / (targetThroughput / (double) (numConcurrentTasks * numConnections))) * 1000.0;
   }
 
 
@@ -1037,22 +1219,22 @@
 
 
 
+  abstract ConnectionWorker newConnectionWorker(
+      final AsynchronousConnection connection,
+      final ConnectionFactory connectionFactory);
+
+
+
   abstract StatsThread newStatsThread();
 
 
 
-  abstract WorkerThread newWorkerThread(AsynchronousConnection connection,
-      ConnectionFactory connectionFactory);
-
-
-
   final int run(final ConnectionFactory connectionFactory)
   {
-    final List<Thread> threads = new ArrayList<Thread>();
+    final List<ConnectionWorker> workers = new ArrayList<ConnectionWorker>();
     final List<AsynchronousConnection> connections = new ArrayList<AsynchronousConnection>();
 
     AsynchronousConnection connection = null;
-    Thread thread;
     try
     {
       for (int i = 0; i < numConnections; i++)
@@ -1063,21 +1245,18 @@
           connection.addConnectionEventListener(this);
           connections.add(connection);
         }
-        for (int j = 0; j < numThreads; j++)
-        {
-          thread = newWorkerThread(connection, connectionFactory);
-
-          threads.add(thread);
-          thread.start();
-        }
+        final ConnectionWorker worker = newConnectionWorker(connection,
+            connectionFactory);
+        workers.add(worker);
+        worker.startWork();
       }
 
       final Thread statsThread = newStatsThread();
       statsThread.start();
 
-      for (final Thread t : threads)
+      for (final ConnectionWorker w : workers)
       {
-        t.join();
+        w.waitFor();
       }
       stopRequested = true;
       statsThread.join();

--
Gitblit v1.10.0