mirror of https://github.com/OpenIdentityPlatform/OpenDJ.git

Matthew Swift
27.23.2013 2dff202709790a27ef5cd47e6e10f43b82f4823d
opendj-sdk/opendj3/opendj-rest2ldap-servlet/src/main/java/org/forgerock/opendj/rest2ldap/servlet/Rest2LDAPAuthnFilter.java
@@ -49,8 +49,8 @@
import org.forgerock.json.fluent.JsonValue;
import org.forgerock.json.fluent.JsonValueException;
import org.forgerock.json.resource.ResourceException;
import org.forgerock.json.resource.servlet.CompletionHandler;
import org.forgerock.json.resource.servlet.CompletionHandlerFactory;
import org.forgerock.json.resource.servlet.ServletSynchronizer;
import org.forgerock.json.resource.servlet.ServletApiVersionAdapter;
import org.forgerock.opendj.ldap.AuthenticationException;
import org.forgerock.opendj.ldap.AuthorizationException;
import org.forgerock.opendj.ldap.ByteString;
@@ -101,7 +101,7 @@
    private SearchScope searchScope = SearchScope.WHOLE_SUBTREE;
    private boolean supportAltAuthentication;
    private boolean supportHTTPBasicAuthentication = true;
    private CompletionHandlerFactory completionHandlerFactory;
    private ServletApiVersionAdapter syncFactory;
    /**
     * {@inheritDoc}
@@ -139,16 +139,14 @@
         * completion.
         */
        final AtomicReference<Connection> savedConnection = new AtomicReference<Connection>();
        final CompletionHandler completionHandler =
                completionHandlerFactory.createCompletionHandler(req, res);
        if (completionHandler.isAsynchronous()) {
            completionHandler.addCompletionListener(new Runnable() {
                @Override
                public void run() {
                    closeConnection(savedConnection);
                }
            });
        }
        final ServletSynchronizer sync = syncFactory.createServletSynchronizer(req, res);
        sync.addAsyncListener(new Runnable() {
            @Override
            public void run() {
                closeConnection(savedConnection);
            }
        });
        try {
            final String headerUsername =
@@ -196,8 +194,8 @@
                authzid = new LinkedHashMap<String, Object>(2);
                authzid.put(AUTHZID_DN, username);
                authzid.put(AUTHZID_ID, username);
                doBind(req, response, newSimpleBindRequest(username, password), chain,
                        savedConnection, completionHandler, username, authzid);
                doBind(req, res, newSimpleBindRequest(username, password), chain, savedConnection,
                        sync, username, authzid);
                break;
            }
            case SASL_PLAIN: {
@@ -215,8 +213,8 @@
                    bindId = String.format(saslAuthzIdTemplate, username);
                    authzid = Collections.singletonMap(AUTHZID_ID, (Object) username);
                }
                doBind(req, response, newPlainSASLBindRequest(bindId, password), chain,
                        savedConnection, completionHandler, username, authzid);
                doBind(req, res, newPlainSASLBindRequest(bindId, password), chain, savedConnection,
                        sync, username, authzid);
                break;
            }
            default: // SEARCH_SIMPLE
@@ -232,7 +230,7 @@
                searchLDAPConnectionFactory.getConnectionAsync(new ResultHandler<Connection>() {
                    @Override
                    public void handleErrorResult(final ErrorResultException error) {
                        completionHandler.onError(asResourceException(error));
                        sync.signalAndComplete(asResourceException(error));
                    }
                    @Override
@@ -263,8 +261,7 @@
                                        } else {
                                            normalizedError = error;
                                        }
                                        completionHandler
                                                .onError(asResourceException(normalizedError));
                                        sync.signalAndComplete(asResourceException(normalizedError));
                                    }
                                    @Override
@@ -275,10 +272,8 @@
                                                new LinkedHashMap<String, Object>(2);
                                        authzid.put(AUTHZID_DN, bindDN);
                                        authzid.put(AUTHZID_ID, username);
                                        doBind(req, response,
                                                newSimpleBindRequest(bindDN, password), chain,
                                                savedConnection, completionHandler, username,
                                                authzid);
                                        doBind(req, res, newSimpleBindRequest(bindDN, password),
                                                chain, savedConnection, sync, username, authzid);
                                    }
                                });
                    }
@@ -286,20 +281,14 @@
                break;
            }
            }
            /*
             * Block until authentication completes if needed and then invoke
             * the remainder of the filter chain.
             */
            if (!completionHandler.isAsynchronous()) {
                completionHandler.awaitIfNeeded();
            sync.awaitIfNeeded();
            if (!sync.isAsync()) {
                chain.doFilter(request, response);
                closeConnection(savedConnection);
            }
        } catch (final Throwable t) {
            // Complete and close the connection if needed.
            completionHandler.onError(t);
            if (!completionHandler.isAsynchronous()) {
            sync.signalAndComplete(t);
        } finally {
            if (!sync.isAsync()) {
                closeConnection(savedConnection);
            }
        }
@@ -389,8 +378,7 @@
                                "ldapConnectionFactories").required(), ldapFactoryName);
                // Set the completion handler factory based on the Servlet API version.
                completionHandlerFactory =
                        CompletionHandlerFactory.getInstance(config.getServletContext());
                syncFactory = ServletApiVersionAdapter.getInstance(config.getServletContext());
                isEnabled = true;
            }
@@ -422,13 +410,12 @@
     */
    private void doBind(final HttpServletRequest request, final ServletResponse response,
            final BindRequest bindRequest, final FilterChain chain,
            final AtomicReference<Connection> savedConnection,
            final CompletionHandler completionHandler, final String authcid,
            final Map<String, Object> authzid) {
            final AtomicReference<Connection> savedConnection, final ServletSynchronizer sync,
            final String authcid, final Map<String, Object> authzid) {
        bindLDAPConnectionFactory.getConnectionAsync(new ResultHandler<Connection>() {
            @Override
            public void handleErrorResult(final ErrorResultException error) {
                completionHandler.onError(asResourceException(error));
                sync.signalAndComplete(asResourceException(error));
            }
            @Override
@@ -438,7 +425,7 @@
                    @Override
                    public void handleErrorResult(final ErrorResultException error) {
                        completionHandler.onError(asResourceException(error));
                        sync.signalAndComplete(asResourceException(error));
                    }
                    @Override
@@ -452,11 +439,14 @@
                        request.setAttribute(ATTRIBUTE_AUTHCID, authcid);
                        request.setAttribute(ATTRIBUTE_AUTHZID, authzid);
                        // Invoke the remained of the filter chain.
                        try {
                            chain.doFilter(request, response);
                        } catch (final Throwable t) {
                            completionHandler.onError(asResourceException(t));
                        // Invoke the remainder of the filter chain.
                        sync.signal();
                        if (sync.isAsync()) {
                            try {
                                chain.doFilter(request, response);
                            } catch (Throwable t) {
                                sync.signalAndComplete(asResourceException(t));
                            }
                        }
                    }
                });