mirror of https://github.com/OpenIdentityPlatform/OpenDJ.git

Gaetan Boismal
28.52.2014 a4e2fc0298e8d60aa0e4bcfd3304303d952e0972
opendj-grizzly/src/main/java/org/forgerock/opendj/grizzly/GrizzlyLDAPConnectionFactory.java
@@ -28,8 +28,11 @@
package org.forgerock.opendj.grizzly;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import javax.net.ssl.SSLEngine;
@@ -40,12 +43,15 @@
import org.forgerock.opendj.ldap.ResultCode;
import org.forgerock.opendj.ldap.TimeoutChecker;
import org.forgerock.opendj.ldap.TimeoutEventListener;
import org.forgerock.opendj.ldap.spi.AbstractLdapConnectionFactoryImpl;
import org.forgerock.opendj.ldap.spi.AbstractLdapConnectionImpl;
import org.forgerock.opendj.ldap.requests.Requests;
import org.forgerock.opendj.ldap.requests.StartTLSExtendedRequest;
import org.forgerock.opendj.ldap.responses.ExtendedResult;
import org.forgerock.opendj.ldap.spi.LDAPConnectionFactoryImpl;
import org.forgerock.util.promise.Function;
import org.forgerock.util.promise.FailureHandler;
import org.forgerock.util.promise.Promise;
import org.forgerock.util.promise.PromiseImpl;
import org.forgerock.util.promise.SuccessHandler;
import org.glassfish.grizzly.CompletionHandler;
import org.glassfish.grizzly.EmptyCompletionHandler;
import org.glassfish.grizzly.SocketConnectorHandler;
import org.glassfish.grizzly.filterchain.FilterChain;
@@ -64,21 +70,19 @@
/**
 * LDAP connection factory implementation using Grizzly for transport.
 */
public final class GrizzlyLDAPConnectionFactory extends AbstractLdapConnectionFactoryImpl implements
        LDAPConnectionFactoryImpl {
public final class GrizzlyLDAPConnectionFactory implements LDAPConnectionFactoryImpl {
    private static final LocalizedLogger logger = LocalizedLogger.getLoggerForThisClass();
    /**
     * Adapts a Grizzly connection completion handler to an LDAP connection
     * promise.
     * Adapts a Grizzly connection completion handler to an LDAP connection promise.
     */
    @SuppressWarnings("rawtypes")
    private final class CompletionHandlerAdapter extends EmptyCompletionHandler<org.glassfish.grizzly.Connection>
            implements TimeoutEventListener {
        private final PromiseImpl<org.glassfish.grizzly.Connection, LdapException> promise;
    private final class CompletionHandlerAdapter implements
            CompletionHandler<org.glassfish.grizzly.Connection>, TimeoutEventListener {
        private final PromiseImpl<Connection, LdapException> promise;
        private final long timeoutEndTime;
        private CompletionHandlerAdapter(final PromiseImpl<org.glassfish.grizzly.Connection, LdapException> promise) {
        private CompletionHandlerAdapter(final PromiseImpl<Connection, LdapException> promise) {
            this.promise = promise;
            final long timeoutMS = getTimeout();
            this.timeoutEndTime = timeoutMS > 0 ? System.currentTimeMillis() + timeoutMS : 0;
@@ -86,11 +90,68 @@
        }
        @Override
        public void completed(final org.glassfish.grizzly.Connection connection) {
            timeoutChecker.get().removeListener(this);
            if (!promise.tryHandleResult(connection)) {
                // The connection has been either cancelled or it has timed out.
        public void cancelled() {
            // Ignore this.
        }
        @Override
        public void completed(final org.glassfish.grizzly.Connection result) {
            // Adapt the connection.
            final GrizzlyLDAPConnection connection = adaptConnection(result);
            // Plain connection.
            if (options.getSSLContext() == null) {
                onSuccess(connection);
                return;
            }
            // Start TLS or install SSL layer asynchronously.
            // Give up immediately if the promise has been cancelled or timed out.
            if (promise.isDone()) {
                timeoutChecker.get().removeListener(this);
                connection.close();
                return;
            }
            if (options.useStartTLS()) {
                // Chain StartTLS extended request.
                final StartTLSExtendedRequest startTLS =
                        Requests.newStartTLSExtendedRequest(options.getSSLContext());
                startTLS.addEnabledCipherSuite(options.getEnabledCipherSuites().toArray(
                    new String[options.getEnabledCipherSuites().size()]));
                startTLS.addEnabledProtocol(options.getEnabledProtocols().toArray(
                    new String[options.getEnabledProtocols().size()]));
                connection.extendedRequestAsync(startTLS).onSuccess(new SuccessHandler<ExtendedResult>() {
                    @Override
                    public void handleResult(final ExtendedResult result) {
                        onSuccess(connection);
                    }
                }).onFailure(new FailureHandler<LdapException>() {
                    @Override
                    public void handleError(final LdapException error) {
                        onFailure(connection, error);
                    }
                });
            } else {
                // Install SSL/TLS layer.
                try {
                    connection.startTLS(options.getSSLContext(), options.getEnabledProtocols(),
                        options.getEnabledCipherSuites(), new EmptyCompletionHandler<SSLEngine>() {
                            @Override
                            public void completed(final SSLEngine result) {
                                onSuccess(connection);
                            }
                            @Override
                            public void failed(final Throwable throwable) {
                                onFailure(connection, throwable);
                            }
                        });
                } catch (final IOException e) {
                    onFailure(connection, e);
                }
            }
        }
@@ -99,6 +160,51 @@
            // Adapt and forward.
            timeoutChecker.get().removeListener(this);
            promise.handleError(adaptConnectionException(throwable));
            releaseTransportAndTimeoutChecker();
        }
        @Override
        public void updated(final org.glassfish.grizzly.Connection result) {
            // Ignore this.
        }
        private GrizzlyLDAPConnection adaptConnection(
                final org.glassfish.grizzly.Connection<?> connection) {
            configureConnection(connection, options.isTCPNoDelay(), options.isKeepAlive(), options
                    .isReuseAddress(), options.getLinger(), logger);
            final GrizzlyLDAPConnection ldapConnection =
                    new GrizzlyLDAPConnection(connection, GrizzlyLDAPConnectionFactory.this);
            timeoutChecker.get().addListener(ldapConnection);
            clientFilter.registerConnection(connection, ldapConnection);
            return ldapConnection;
        }
        private LdapException adaptConnectionException(Throwable t) {
            if (!(t instanceof LdapException) && t instanceof ExecutionException) {
                t = t.getCause() != null ? t.getCause() : t;
            }
            if (t instanceof LdapException) {
                return (LdapException) t;
            } else {
                return newLdapException(ResultCode.CLIENT_SIDE_CONNECT_ERROR, t.getMessage(), t);
            }
        }
        private void onFailure(final GrizzlyLDAPConnection connection, final Throwable t) {
            // Abort connection attempt due to error.
            timeoutChecker.get().removeListener(this);
            promise.handleError(adaptConnectionException(t));
            connection.close();
        }
        private void onSuccess(final GrizzlyLDAPConnection connection) {
            timeoutChecker.get().removeListener(this);
            if (!promise.tryHandleResult(connection)) {
                // The connection has been either cancelled or it has timed out.
                connection.close();
            }
        }
        @Override
@@ -122,24 +228,25 @@
    private final LDAPClientFilter clientFilter;
    private final FilterChain defaultFilterChain;
    private final ReferenceCountedObject<TCPNIOTransport>.Reference transport;
    private final ReferenceCountedObject<TimeoutChecker>.Reference timeoutChecker = TIMEOUT_CHECKER.acquire();
    private final LDAPOptions options;
    private final String host;
    private final int port;
    @SuppressWarnings("rawtypes")
    private final Function<org.glassfish.grizzly.Connection, AbstractLdapConnectionImpl<?>, LdapException>
    convertToLDAPConnection =
        new Function<org.glassfish.grizzly.Connection, AbstractLdapConnectionImpl<?>, LdapException>() {
            @Override
            public GrizzlyLDAPConnection apply(org.glassfish.grizzly.Connection connection) throws LdapException {
                configureConnection(connection, options.isTCPNoDelay(), options.isKeepAlive(),
                    options.isReuseAddress(), options.getLinger(), logger);
                final GrizzlyLDAPConnection ldapConnection =
                    new GrizzlyLDAPConnection(connection, GrizzlyLDAPConnectionFactory.this);
                timeoutChecker.get().addListener(ldapConnection);
                clientFilter.registerConnection(connection, ldapConnection);
                return ldapConnection;
            }
        };
    /**
     * Prevents the transport and timeoutChecker being released when there are
     * remaining references (this factory or any connections). It is initially
     * set to 1 because this factory has a reference.
     */
    private final AtomicInteger referenceCount = new AtomicInteger(1);
    /**
     * Indicates whether this factory has been closed or not.
     */
    private final AtomicBoolean isClosed = new AtomicBoolean();
    private final ReferenceCountedObject<TCPNIOTransport>.Reference transport;
    private final ReferenceCountedObject<TimeoutChecker>.Reference timeoutChecker = TIMEOUT_CHECKER
            .acquire();
    /**
     * Creates a new LDAP connection factory based on Grizzly which can be used
@@ -155,15 +262,6 @@
     */
    public GrizzlyLDAPConnectionFactory(final String host, final int port, final LDAPOptions options) {
        this(host, port, options, null);
    }
    private LdapException adaptConnectionException(Throwable t) {
        if (t instanceof LdapException) {
            return (LdapException) t;
        }
        t = t instanceof ExecutionException && t.getCause() != null ? t.getCause() : t;
        return newLdapException(ResultCode.CLIENT_SIDE_CONNECT_ERROR, t.getMessage(), t);
    }
    /**
@@ -183,60 +281,88 @@
     *            connections. If {@code null}, default transport will be used.
     */
    public GrizzlyLDAPConnectionFactory(final String host, final int port, final LDAPOptions options,
            final TCPNIOTransport transport) {
        super(host, port, options);
                                        TCPNIOTransport transport) {
        this.transport = DEFAULT_TRANSPORT.acquireIfNull(transport);
        this.clientFilter = new LDAPClientFilter(options.getDecodeOptions(), 0);
        this.defaultFilterChain = buildFilterChain(this.transport.get().getProcessor(), clientFilter);
        this.host = host;
        this.port = port;
        this.options = new LDAPOptions(options);
        this.clientFilter = new LDAPClientFilter(this.options.getDecodeOptions(), 0);
        this.defaultFilterChain =
                buildFilterChain(this.transport.get().getProcessor(), clientFilter);
    }
    @Override
    public void close() {
        if (isClosed.compareAndSet(false, true)) {
            releaseTransportAndTimeoutChecker();
        }
    }
    @Override
    public Connection getConnection() throws LdapException {
        try {
            return getConnectionAsync().getOrThrow();
        } catch (final InterruptedException e) {
            throw newLdapException(ResultCode.CLIENT_SIDE_USER_CANCELLED, e);
        }
    }
    @Override
    public Promise<Connection, LdapException> getConnectionAsync() {
        acquireTransportAndTimeoutChecker(); // Protect resources.
        final SocketConnectorHandler connectorHandler =
                TCPNIOConnectorHandler.builder(transport.get()).processor(defaultFilterChain)
                        .build();
        final PromiseImpl<Connection, LdapException> promise = PromiseImpl.create();
        connectorHandler.connect(getSocketAddress(), new CompletionHandlerAdapter(promise));
        return promise;
    }
    @Override
    public InetSocketAddress getSocketAddress() {
        return new InetSocketAddress(host, port);
    }
    @Override
    public String getHostName() {
        return host;
    }
    @Override
    public int getPort() {
        return port;
    }
    @Override
    public String toString() {
        return getClass().getSimpleName() + "(" + host + ':' + port + ')';
    }
    TimeoutChecker getTimeoutChecker() {
        return timeoutChecker.get();
    }
    @Override
    @SuppressWarnings("rawtypes")
    protected Promise<AbstractLdapConnectionImpl<?>, LdapException> getConnectionAsync0() {
        final SocketConnectorHandler connectorHandler = TCPNIOConnectorHandler.builder(transport.get())
                .processor(defaultFilterChain).build();
        final PromiseImpl<org.glassfish.grizzly.Connection, LdapException> promise = PromiseImpl.create();
        connectorHandler.connect(getSocketAddress(), new CompletionHandlerAdapter(promise));
        return promise.then(convertToLDAPConnection);
    LDAPOptions getLDAPOptions() {
        return options;
    }
    @Override
    protected Promise<Void, LdapException> installSecureLayer(final Connection connection) {
        final PromiseImpl<Void, LdapException> sslHandshakePromise = PromiseImpl.create();
        try {
            final GrizzlyLDAPConnection grizzlyConnection = (GrizzlyLDAPConnection) connection;
            grizzlyConnection.startTLS(options.getSSLContext(), options.getEnabledProtocols(),
                    options.getEnabledCipherSuites(), new EmptyCompletionHandler<SSLEngine>() {
                        @Override
                        public void completed(final SSLEngine result) {
                            if (!sslHandshakePromise.tryHandleResult(null)) {
                                // The connection has been either cancelled or
                                // it has timed out.
                                connection.close();
                            }
                        }
                        @Override
                        public void failed(final Throwable throwable) {
                            sslHandshakePromise.handleError(adaptConnectionException(throwable));
                        }
                    });
        } catch (final IOException e) {
            sslHandshakePromise.handleError(adaptConnectionException(e));
    void releaseTransportAndTimeoutChecker() {
        if (referenceCount.decrementAndGet() == 0) {
            transport.release();
            timeoutChecker.release();
        }
        return sslHandshakePromise;
    }
    @Override
    protected void releaseImplResources() {
        transport.release();
        timeoutChecker.release();
    private void acquireTransportAndTimeoutChecker() {
        /*
         * If the factory is not closed then we need to prevent the resources
         * (transport, timeout checker) from being released while the connection
         * attempt is in progress.
         */
        referenceCount.incrementAndGet();
        if (isClosed.get()) {
            releaseTransportAndTimeoutChecker();
            throw new IllegalStateException("Attempted to get a connection after factory close");
        }
    }
}