From 80774bcd0c732d9446cfc09fc9b7c39a3e4003ad Mon Sep 17 00:00:00 2001
From: Matthew Swift <matthew.swift@forgerock.com>
Date: Wed, 23 Mar 2011 22:27:01 +0000
Subject: [PATCH] Fix issue OpenDJ-95: Socket leak and constant disconnect/reconnect when a directory server can no longer reach its connected replication server

---
 opendj-sdk/opends/src/server/org/opends/server/replication/protocol/TLSSocketSession.java |  451 ++++++++++++++++++++++++++++++++++++++++----------------
 1 files changed, 321 insertions(+), 130 deletions(-)

diff --git a/opendj-sdk/opends/src/server/org/opends/server/replication/protocol/TLSSocketSession.java b/opendj-sdk/opends/src/server/org/opends/server/replication/protocol/TLSSocketSession.java
index adf16b2..8a54bc2 100644
--- a/opendj-sdk/opends/src/server/org/opends/server/replication/protocol/TLSSocketSession.java
+++ b/opendj-sdk/opends/src/server/org/opends/server/replication/protocol/TLSSocketSession.java
@@ -27,6 +27,8 @@
  */
 package org.opends.server.replication.protocol;
 
+
+
 import static org.opends.server.loggers.debug.DebugLogger.debugEnabled;
 import static org.opends.server.loggers.debug.DebugLogger.getTracer;
 import static org.opends.server.util.StaticUtils.stackTraceToSingleLineString;
@@ -36,241 +38,430 @@
 import java.io.OutputStream;
 import java.net.Socket;
 import java.net.SocketException;
+import java.util.concurrent.locks.Lock;
+import java.util.concurrent.locks.ReentrantLock;
 import java.util.zip.DataFormatException;
 
+import javax.net.ssl.SSLSocket;
+
 import org.opends.server.loggers.debug.DebugTracer;
 
-import javax.net.ssl.SSLSocket;
+
 
 /**
  * This class implements a protocol session using TLS.
  */
-public class TLSSocketSession implements ProtocolSession
+public final class TLSSocketSession implements ProtocolSession
 {
   /**
    * The tracer object for the debug logger.
    */
   private static final DebugTracer TRACER = getTracer();
 
-  private Socket plainSocket;
-  private SSLSocket secureSocket;
-  private InputStream input;
-  private OutputStream output;
-  private InputStream plainInput;
-  private OutputStream plainOutput;
-  byte[] rcvLengthBuf = new byte[8];
+  private final Socket plainSocket;
+  private final SSLSocket secureSocket;
+  private final InputStream plainInput;
+  private final OutputStream plainOutput;
+  private final byte[] rcvLengthBuf = new byte[8];
 
   /**
    * The time the last message published to this session.
    */
   private volatile long lastPublishTime = 0;
 
-
   /**
    * The time the last message was received on this session.
    */
-  private long lastReceiveTime = 0;
+  private volatile long lastReceiveTime = 0;
 
+  // Close and error guarded by stateLock: use a different lock to publish since
+  // publishing can block, and we don't want to block while closing failed
+  // connections.
+  private final Object stateLock = new Object();
   private boolean closeInitiated = false;
+  private Throwable sessionError = null;
 
+  // Publish guarded by publishLock: use a full lock here so that we can
+  // optionally publish StopMsg during close.
+  private final Lock publishLock = new ReentrantLock();
+
+  // Does not need protecting: updated only during single threaded handshake.
   private short protocolVersion = ProtocolVersion.getCurrentVersion();
+  private InputStream input;
+  private OutputStream output;
+
+
 
   /**
    * Creates a new TLSSocketSession.
    *
-   * @param socket       The regular Socket on which the SocketSession will be
-   *                     based.
-   * @param secureSocket The secure Socket on which the SocketSession will be
-   *                     based.
-   * @throws IOException When an IException happens on the socket.
+   * @param socket
+   *          The regular Socket on which the SocketSession will be based.
+   * @param secureSocket
+   *          The secure Socket on which the SocketSession will be based.
+   * @throws IOException
+   *           When an IException happens on the socket.
    */
-  public TLSSocketSession(Socket socket, SSLSocket secureSocket)
-       throws IOException
+  public TLSSocketSession(final Socket socket,
+      final SSLSocket secureSocket) throws IOException
   {
-    plainSocket = socket;
-    this.secureSocket = secureSocket;
-    plainInput = plainSocket.getInputStream();
-    plainOutput = plainSocket.getOutputStream();
-    input = secureSocket.getInputStream();
-    output = secureSocket.getOutputStream();
-  }
-
-
-  /**
-   * {@inheritDoc}
-   */
-  public void close() throws IOException
-  {
-    closeInitiated = true;
     if (debugEnabled())
     {
-      TRACER.debugInfo("Closing SocketSession." +
-          stackTraceToSingleLineString(new Exception("Stack:")));
+      TRACER.debugInfo(
+          "Creating TLSSocketSession from %s to %s in %s",
+          socket.getLocalSocketAddress(),
+          socket.getRemoteSocketAddress(),
+          stackTraceToSingleLineString(new Exception()));
     }
-    if (plainSocket != null && !plainSocket.isClosed())
-    {
-      plainInput.close();
-      plainOutput.close();
-      plainSocket.close();
-    }
-    if (secureSocket != null && !secureSocket.isClosed())
-    {
-      input.close();
-      output.close();
-      secureSocket.close();
-    }
+
+    this.plainSocket = socket;
+    this.secureSocket = secureSocket;
+    this.plainInput = plainSocket.getInputStream();
+    this.plainOutput = plainSocket.getOutputStream();
+    this.input = secureSocket.getInputStream();
+    this.output = secureSocket.getOutputStream();
   }
 
+
+
   /**
    * {@inheritDoc}
    */
-  public synchronized void publish(ReplicationMsg msg)
-         throws IOException
+  @Override
+  public void close()
   {
-    publish(msg, ProtocolVersion.getCurrentVersion());
-  }
+    Throwable localSessionError;
 
-  /**
-   * {@inheritDoc}
-   */
-  public synchronized void publish(ReplicationMsg msg, short reqProtocolVersion)
-         throws IOException
-  {
-    byte[] buffer = msg.getBytes(reqProtocolVersion);
-    String str = String.format("%08x", buffer.length);
-    byte[] sendLengthBuf = str.getBytes();
-
-    output.write(sendLengthBuf);
-    output.write(buffer);
-    output.flush();
-
-    lastPublishTime = System.currentTimeMillis();
-  }
-
-  /**
-   * {@inheritDoc}
-   */
-  public ReplicationMsg receive() throws IOException,
-      ClassNotFoundException, DataFormatException,
-      NotSupportedOldVersionPDUException
-  {
-    /* Read the first 8 bytes containing the packet length */
-    int length = 0;
-
-    /* Let's start the stop-watch before waiting on read */
-    /* for the heartbeat check to be operationnal        */
-    lastReceiveTime = System.currentTimeMillis();
-
-    while (length<8)
+    synchronized (stateLock)
     {
-      int read = input.read(rcvLengthBuf, length, 8-length);
-      if (read == -1)
+      if (closeInitiated)
       {
-        lastReceiveTime=0;
-        throw new IOException("no more data");
+        return;
+      }
+
+      localSessionError = sessionError;
+      closeInitiated = true;
+    }
+
+    // Perform close outside of critical section.
+    if (debugEnabled())
+    {
+      if (localSessionError == null)
+      {
+        TRACER.debugInfo(
+            "Closing TLSSocketSession from %s to %s in %s",
+            plainSocket.getLocalSocketAddress(),
+            plainSocket.getRemoteSocketAddress(),
+            stackTraceToSingleLineString(new Exception()));
       }
       else
       {
-        length += read;
+        TRACER.debugInfo(
+            "Aborting TLSSocketSession from %s to %s in %s due to the "
+                + "following error: %s",
+            plainSocket.getLocalSocketAddress(),
+            plainSocket.getRemoteSocketAddress(),
+            stackTraceToSingleLineString(new Exception()),
+            stackTraceToSingleLineString(localSessionError));
       }
     }
 
-    int totalLength = Integer.parseInt(new String(rcvLengthBuf), 16);
+    // V4 protocol introduces a StopMsg to properly end communications.
+    if (localSessionError == null)
+    {
+      if (protocolVersion >= ProtocolVersion.REPLICATION_PROTOCOL_V4)
+      {
+        if (publishLock.tryLock())
+        {
+          try
+          {
+            publish(new StopMsg());
+          }
+          catch (final IOException ignored)
+          {
+            // Ignore errors on close.
+          }
+          finally
+          {
+            publishLock.unlock();
+          }
+        }
+      }
+    }
 
     try
     {
-      length = 0;
-      byte[] buffer = new byte[totalLength];
-      while (length < totalLength)
-      {
-        length += input.read(buffer, length, totalLength - length);
-      }
-      /* We do not want the heartbeat to close the session when */
-      /* we are processing a message even a time consuming one. */
-      lastReceiveTime=0;
-      return ReplicationMsg.generateMsg(buffer, protocolVersion);
+      plainSocket.close();
     }
-    catch (OutOfMemoryError e)
+    catch (final IOException ignored)
     {
-      throw new IOException("Packet too large, can't allocate "
-                            + totalLength + " bytes.");
+      // Ignore errors on close.
+    }
+
+    try
+    {
+      secureSocket.close();
+    }
+    catch (final IOException ignored)
+    {
+      // Ignore errors on close.
     }
   }
 
-  /**
-   * {@inheritDoc}
-   */
-  public void stopEncryption()
-  {
-    input = plainInput;
-    output = plainOutput;
-  }
+
 
   /**
    * {@inheritDoc}
    */
-  public boolean isEncrypted()
+  @Override
+  public boolean closeInitiated()
   {
-    return !(input == plainInput);
+    synchronized (stateLock)
+    {
+      return closeInitiated;
+    }
   }
 
+
+
   /**
    * {@inheritDoc}
    */
+  @Override
   public long getLastPublishTime()
   {
     return lastPublishTime;
   }
 
+
+
   /**
    * {@inheritDoc}
    */
+  @Override
   public long getLastReceiveTime()
   {
-    if (lastReceiveTime==0)
+    if (lastReceiveTime == 0)
     {
       return System.currentTimeMillis();
     }
     return lastReceiveTime;
   }
 
-  /**
-   * {@inheritDoc}
-   */
-  public String getRemoteAddress()
-  {
-    return plainSocket.getInetAddress().getHostAddress();
-  }
+
 
   /**
    * {@inheritDoc}
    */
+  @Override
   public String getReadableRemoteAddress()
   {
     return plainSocket.getRemoteSocketAddress().toString();
   }
 
+
+
   /**
    * {@inheritDoc}
    */
-  public void setSoTimeout(int timeout) throws SocketException
+  @Override
+  public String getRemoteAddress()
+  {
+    return plainSocket.getInetAddress().getHostAddress();
+  }
+
+
+
+  /**
+   * {@inheritDoc}
+   */
+  @Override
+  public boolean isEncrypted()
+  {
+    return input != plainInput;
+  }
+
+
+
+  /**
+   * {@inheritDoc}
+   */
+  @Override
+  public void publish(final ReplicationMsg msg) throws IOException
+  {
+    publish(msg, ProtocolVersion.getCurrentVersion());
+  }
+
+
+
+  /**
+   * {@inheritDoc}
+   */
+  @Override
+  public void publish(final ReplicationMsg msg,
+      final short reqProtocolVersion) throws IOException
+  {
+    final byte[] buffer = msg.getBytes(reqProtocolVersion);
+    final String str = String.format("%08x", buffer.length);
+    final byte[] sendLengthBuf = str.getBytes();
+
+    publishLock.lock();
+    try
+    {
+      output.write(sendLengthBuf);
+      output.write(buffer);
+      output.flush();
+    }
+    catch (final IOException e)
+    {
+      setSessionError(e);
+      throw e;
+    }
+    finally
+    {
+      publishLock.unlock();
+    }
+
+    lastPublishTime = System.currentTimeMillis();
+  }
+
+
+
+  /**
+   * {@inheritDoc}
+   */
+  @Override
+  public ReplicationMsg receive() throws IOException,
+      DataFormatException, NotSupportedOldVersionPDUException
+  {
+    try
+    {
+      // Read the first 8 bytes containing the packet length.
+      int length = 0;
+
+      // Let's start the stop-watch before waiting on read for the heartbeat
+      // check
+      // to be operational.
+      lastReceiveTime = System.currentTimeMillis();
+
+      while (length < 8)
+      {
+        final int read = input.read(rcvLengthBuf, length, 8 - length);
+        if (read == -1)
+        {
+          lastReceiveTime = 0;
+          throw new IOException("no more data");
+        }
+        else
+        {
+          length += read;
+        }
+      }
+
+      final int totalLength = Integer.parseInt(new String(
+          rcvLengthBuf), 16);
+
+      try
+      {
+        length = 0;
+        final byte[] buffer = new byte[totalLength];
+        while (length < totalLength)
+        {
+          final int read = input.read(buffer, length, totalLength
+              - length);
+          if (read == -1)
+          {
+            lastReceiveTime = 0;
+            throw new IOException("no more data");
+          }
+          else
+          {
+            length += read;
+          }
+        }
+        // We do not want the heartbeat to close the session when we are
+        // processing a message even a time consuming one.
+        lastReceiveTime = 0;
+        return ReplicationMsg.generateMsg(buffer, protocolVersion);
+      }
+      catch (final OutOfMemoryError e)
+      {
+        throw new IOException("Packet too large, can't allocate "
+            + totalLength + " bytes.");
+      }
+    }
+    catch (final IOException e)
+    {
+      setSessionError(e);
+      throw e;
+    }
+    catch (final DataFormatException e)
+    {
+      setSessionError(e);
+      throw e;
+    }
+    catch (final NotSupportedOldVersionPDUException e)
+    {
+      setSessionError(e);
+      throw e;
+    }
+  }
+
+
+
+  /**
+   * {@inheritDoc}
+   */
+  @Override
+  public void setProtocolVersion(final short version)
+  {
+    protocolVersion = version;
+  }
+
+
+
+  /**
+   * {@inheritDoc}
+   */
+  @Override
+  public void setSoTimeout(final int timeout) throws SocketException
   {
     plainSocket.setSoTimeout(timeout);
   }
 
-  /**
-   * {@inheritDoc}
-   */
-  public boolean closeInitiated()
-  {
-    return closeInitiated;
-  }
+
 
   /**
    * {@inheritDoc}
    */
-  public void setProtocolVersion(short version)
+  @Override
+  public void stopEncryption()
   {
-    protocolVersion = version;
+    // The secure socket has been configured not to auto close the underlying
+    // plain socket.
+    try
+    {
+      secureSocket.close();
+    }
+    catch (IOException ignored)
+    {
+      // Ignore.
+    }
+
+    input = plainInput;
+    output = plainOutput;
+  }
+
+
+
+  private void setSessionError(final Exception e)
+  {
+    synchronized (stateLock)
+    {
+      if (sessionError == null)
+      {
+        sessionError = e;
+      }
+    }
   }
 }

--
Gitblit v1.10.0