mirror of https://github.com/OpenIdentityPlatform/OpenDJ.git

Matthew Swift
23.27.2011 6ee1440f6f56ac066f97383315b2798287f0821a
opends/src/server/org/opends/server/replication/protocol/TLSSocketSession.java
@@ -27,6 +27,8 @@
 */
package org.opends.server.replication.protocol;
import static org.opends.server.loggers.debug.DebugLogger.debugEnabled;
import static org.opends.server.loggers.debug.DebugLogger.getTracer;
import static org.opends.server.util.StaticUtils.stackTraceToSingleLineString;
@@ -36,241 +38,430 @@
import java.io.OutputStream;
import java.net.Socket;
import java.net.SocketException;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.zip.DataFormatException;
import javax.net.ssl.SSLSocket;
import org.opends.server.loggers.debug.DebugTracer;
import javax.net.ssl.SSLSocket;
/**
 * This class implements a protocol session using TLS.
 */
public class TLSSocketSession implements ProtocolSession
public final class TLSSocketSession implements ProtocolSession
{
  /**
   * The tracer object for the debug logger.
   */
  private static final DebugTracer TRACER = getTracer();
  private Socket plainSocket;
  private SSLSocket secureSocket;
  private InputStream input;
  private OutputStream output;
  private InputStream plainInput;
  private OutputStream plainOutput;
  byte[] rcvLengthBuf = new byte[8];
  private final Socket plainSocket;
  private final SSLSocket secureSocket;
  private final InputStream plainInput;
  private final OutputStream plainOutput;
  private final byte[] rcvLengthBuf = new byte[8];
  /**
   * The time the last message published to this session.
   */
  private volatile long lastPublishTime = 0;
  /**
   * The time the last message was received on this session.
   */
  private long lastReceiveTime = 0;
  private volatile long lastReceiveTime = 0;
  // Close and error guarded by stateLock: use a different lock to publish since
  // publishing can block, and we don't want to block while closing failed
  // connections.
  private final Object stateLock = new Object();
  private boolean closeInitiated = false;
  private Throwable sessionError = null;
  // Publish guarded by publishLock: use a full lock here so that we can
  // optionally publish StopMsg during close.
  private final Lock publishLock = new ReentrantLock();
  // Does not need protecting: updated only during single threaded handshake.
  private short protocolVersion = ProtocolVersion.getCurrentVersion();
  private InputStream input;
  private OutputStream output;
  /**
   * Creates a new TLSSocketSession.
   *
   * @param socket       The regular Socket on which the SocketSession will be
   *                     based.
   * @param secureSocket The secure Socket on which the SocketSession will be
   *                     based.
   * @throws IOException When an IException happens on the socket.
   * @param socket
   *          The regular Socket on which the SocketSession will be based.
   * @param secureSocket
   *          The secure Socket on which the SocketSession will be based.
   * @throws IOException
   *           When an IException happens on the socket.
   */
  public TLSSocketSession(Socket socket, SSLSocket secureSocket)
       throws IOException
  public TLSSocketSession(final Socket socket,
      final SSLSocket secureSocket) throws IOException
  {
    plainSocket = socket;
    this.secureSocket = secureSocket;
    plainInput = plainSocket.getInputStream();
    plainOutput = plainSocket.getOutputStream();
    input = secureSocket.getInputStream();
    output = secureSocket.getOutputStream();
  }
  /**
   * {@inheritDoc}
   */
  public void close() throws IOException
  {
    closeInitiated = true;
    if (debugEnabled())
    {
      TRACER.debugInfo("Closing SocketSession." +
          stackTraceToSingleLineString(new Exception("Stack:")));
      TRACER.debugInfo(
          "Creating TLSSocketSession from %s to %s in %s",
          socket.getLocalSocketAddress(),
          socket.getRemoteSocketAddress(),
          stackTraceToSingleLineString(new Exception()));
    }
    if (plainSocket != null && !plainSocket.isClosed())
    {
      plainInput.close();
      plainOutput.close();
      plainSocket.close();
    }
    if (secureSocket != null && !secureSocket.isClosed())
    {
      input.close();
      output.close();
      secureSocket.close();
    }
    this.plainSocket = socket;
    this.secureSocket = secureSocket;
    this.plainInput = plainSocket.getInputStream();
    this.plainOutput = plainSocket.getOutputStream();
    this.input = secureSocket.getInputStream();
    this.output = secureSocket.getOutputStream();
  }
  /**
   * {@inheritDoc}
   */
  public synchronized void publish(ReplicationMsg msg)
         throws IOException
  @Override
  public void close()
  {
    publish(msg, ProtocolVersion.getCurrentVersion());
  }
    Throwable localSessionError;
  /**
   * {@inheritDoc}
   */
  public synchronized void publish(ReplicationMsg msg, short reqProtocolVersion)
         throws IOException
  {
    byte[] buffer = msg.getBytes(reqProtocolVersion);
    String str = String.format("%08x", buffer.length);
    byte[] sendLengthBuf = str.getBytes();
    output.write(sendLengthBuf);
    output.write(buffer);
    output.flush();
    lastPublishTime = System.currentTimeMillis();
  }
  /**
   * {@inheritDoc}
   */
  public ReplicationMsg receive() throws IOException,
      ClassNotFoundException, DataFormatException,
      NotSupportedOldVersionPDUException
  {
    /* Read the first 8 bytes containing the packet length */
    int length = 0;
    /* Let's start the stop-watch before waiting on read */
    /* for the heartbeat check to be operationnal        */
    lastReceiveTime = System.currentTimeMillis();
    while (length<8)
    synchronized (stateLock)
    {
      int read = input.read(rcvLengthBuf, length, 8-length);
      if (read == -1)
      if (closeInitiated)
      {
        lastReceiveTime=0;
        throw new IOException("no more data");
        return;
      }
      localSessionError = sessionError;
      closeInitiated = true;
    }
    // Perform close outside of critical section.
    if (debugEnabled())
    {
      if (localSessionError == null)
      {
        TRACER.debugInfo(
            "Closing TLSSocketSession from %s to %s in %s",
            plainSocket.getLocalSocketAddress(),
            plainSocket.getRemoteSocketAddress(),
            stackTraceToSingleLineString(new Exception()));
      }
      else
      {
        length += read;
        TRACER.debugInfo(
            "Aborting TLSSocketSession from %s to %s in %s due to the "
                + "following error: %s",
            plainSocket.getLocalSocketAddress(),
            plainSocket.getRemoteSocketAddress(),
            stackTraceToSingleLineString(new Exception()),
            stackTraceToSingleLineString(localSessionError));
      }
    }
    int totalLength = Integer.parseInt(new String(rcvLengthBuf), 16);
    // V4 protocol introduces a StopMsg to properly end communications.
    if (localSessionError == null)
    {
      if (protocolVersion >= ProtocolVersion.REPLICATION_PROTOCOL_V4)
      {
        if (publishLock.tryLock())
        {
          try
          {
            publish(new StopMsg());
          }
          catch (final IOException ignored)
          {
            // Ignore errors on close.
          }
          finally
          {
            publishLock.unlock();
          }
        }
      }
    }
    try
    {
      length = 0;
      byte[] buffer = new byte[totalLength];
      while (length < totalLength)
      {
        length += input.read(buffer, length, totalLength - length);
      }
      /* We do not want the heartbeat to close the session when */
      /* we are processing a message even a time consuming one. */
      lastReceiveTime=0;
      return ReplicationMsg.generateMsg(buffer, protocolVersion);
      plainSocket.close();
    }
    catch (OutOfMemoryError e)
    catch (final IOException ignored)
    {
      throw new IOException("Packet too large, can't allocate "
                            + totalLength + " bytes.");
      // Ignore errors on close.
    }
    try
    {
      secureSocket.close();
    }
    catch (final IOException ignored)
    {
      // Ignore errors on close.
    }
  }
  /**
   * {@inheritDoc}
   */
  public void stopEncryption()
  {
    input = plainInput;
    output = plainOutput;
  }
  /**
   * {@inheritDoc}
   */
  public boolean isEncrypted()
  @Override
  public boolean closeInitiated()
  {
    return !(input == plainInput);
    synchronized (stateLock)
    {
      return closeInitiated;
    }
  }
  /**
   * {@inheritDoc}
   */
  @Override
  public long getLastPublishTime()
  {
    return lastPublishTime;
  }
  /**
   * {@inheritDoc}
   */
  @Override
  public long getLastReceiveTime()
  {
    if (lastReceiveTime==0)
    if (lastReceiveTime == 0)
    {
      return System.currentTimeMillis();
    }
    return lastReceiveTime;
  }
  /**
   * {@inheritDoc}
   */
  public String getRemoteAddress()
  {
    return plainSocket.getInetAddress().getHostAddress();
  }
  /**
   * {@inheritDoc}
   */
  @Override
  public String getReadableRemoteAddress()
  {
    return plainSocket.getRemoteSocketAddress().toString();
  }
  /**
   * {@inheritDoc}
   */
  public void setSoTimeout(int timeout) throws SocketException
  @Override
  public String getRemoteAddress()
  {
    return plainSocket.getInetAddress().getHostAddress();
  }
  /**
   * {@inheritDoc}
   */
  @Override
  public boolean isEncrypted()
  {
    return input != plainInput;
  }
  /**
   * {@inheritDoc}
   */
  @Override
  public void publish(final ReplicationMsg msg) throws IOException
  {
    publish(msg, ProtocolVersion.getCurrentVersion());
  }
  /**
   * {@inheritDoc}
   */
  @Override
  public void publish(final ReplicationMsg msg,
      final short reqProtocolVersion) throws IOException
  {
    final byte[] buffer = msg.getBytes(reqProtocolVersion);
    final String str = String.format("%08x", buffer.length);
    final byte[] sendLengthBuf = str.getBytes();
    publishLock.lock();
    try
    {
      output.write(sendLengthBuf);
      output.write(buffer);
      output.flush();
    }
    catch (final IOException e)
    {
      setSessionError(e);
      throw e;
    }
    finally
    {
      publishLock.unlock();
    }
    lastPublishTime = System.currentTimeMillis();
  }
  /**
   * {@inheritDoc}
   */
  @Override
  public ReplicationMsg receive() throws IOException,
      DataFormatException, NotSupportedOldVersionPDUException
  {
    try
    {
      // Read the first 8 bytes containing the packet length.
      int length = 0;
      // Let's start the stop-watch before waiting on read for the heartbeat
      // check
      // to be operational.
      lastReceiveTime = System.currentTimeMillis();
      while (length < 8)
      {
        final int read = input.read(rcvLengthBuf, length, 8 - length);
        if (read == -1)
        {
          lastReceiveTime = 0;
          throw new IOException("no more data");
        }
        else
        {
          length += read;
        }
      }
      final int totalLength = Integer.parseInt(new String(
          rcvLengthBuf), 16);
      try
      {
        length = 0;
        final byte[] buffer = new byte[totalLength];
        while (length < totalLength)
        {
          final int read = input.read(buffer, length, totalLength
              - length);
          if (read == -1)
          {
            lastReceiveTime = 0;
            throw new IOException("no more data");
          }
          else
          {
            length += read;
          }
        }
        // We do not want the heartbeat to close the session when we are
        // processing a message even a time consuming one.
        lastReceiveTime = 0;
        return ReplicationMsg.generateMsg(buffer, protocolVersion);
      }
      catch (final OutOfMemoryError e)
      {
        throw new IOException("Packet too large, can't allocate "
            + totalLength + " bytes.");
      }
    }
    catch (final IOException e)
    {
      setSessionError(e);
      throw e;
    }
    catch (final DataFormatException e)
    {
      setSessionError(e);
      throw e;
    }
    catch (final NotSupportedOldVersionPDUException e)
    {
      setSessionError(e);
      throw e;
    }
  }
  /**
   * {@inheritDoc}
   */
  @Override
  public void setProtocolVersion(final short version)
  {
    protocolVersion = version;
  }
  /**
   * {@inheritDoc}
   */
  @Override
  public void setSoTimeout(final int timeout) throws SocketException
  {
    plainSocket.setSoTimeout(timeout);
  }
  /**
   * {@inheritDoc}
   */
  public boolean closeInitiated()
  {
    return closeInitiated;
  }
  /**
   * {@inheritDoc}
   */
  public void setProtocolVersion(short version)
  @Override
  public void stopEncryption()
  {
    protocolVersion = version;
    // The secure socket has been configured not to auto close the underlying
    // plain socket.
    try
    {
      secureSocket.close();
    }
    catch (IOException ignored)
    {
      // Ignore.
    }
    input = plainInput;
    output = plainOutput;
  }
  private void setSessionError(final Exception e)
  {
    synchronized (stateLock)
    {
      if (sessionError == null)
      {
        sessionError = e;
      }
    }
  }
}