View Javadoc
1   /*
2    * Copyright 2008-2009 Web Cohesion
3    *
4    * Licensed under the Apache License, Version 2.0 (the "License");
5    * you may not use this file except in compliance with the License.
6    * You may obtain a copy of the License at
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS,
12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13   * See the License for the specific language governing permissions and
14   * limitations under the License.
15   */
16  
17  package org.springframework.security.oauth.consumer.filter;
18  
19  import java.io.IOException;
20  import java.io.UnsupportedEncodingException;
21  import java.net.URLEncoder;
22  import java.util.ArrayList;
23  import java.util.HashMap;
24  import java.util.Map;
25  import java.util.TreeMap;
26  
27  import javax.servlet.Filter;
28  import javax.servlet.FilterChain;
29  import javax.servlet.FilterConfig;
30  import javax.servlet.ServletException;
31  import javax.servlet.ServletRequest;
32  import javax.servlet.ServletResponse;
33  import javax.servlet.http.HttpServletRequest;
34  import javax.servlet.http.HttpServletResponse;
35  
36  import org.apache.commons.logging.Log;
37  import org.apache.commons.logging.LogFactory;
38  import org.springframework.beans.factory.InitializingBean;
39  import org.springframework.context.MessageSource;
40  import org.springframework.context.MessageSourceAware;
41  import org.springframework.context.support.MessageSourceAccessor;
42  import org.springframework.security.core.SpringSecurityMessageSource;
43  import org.springframework.security.oauth.common.OAuthProviderParameter;
44  import org.springframework.security.oauth.consumer.AccessTokenRequiredException;
45  import org.springframework.security.oauth.consumer.OAuthConsumerSupport;
46  import org.springframework.security.oauth.consumer.OAuthConsumerToken;
47  import org.springframework.security.oauth.consumer.OAuthRequestFailedException;
48  import org.springframework.security.oauth.consumer.OAuthSecurityContextHolder;
49  import org.springframework.security.oauth.consumer.OAuthSecurityContextImpl;
50  import org.springframework.security.oauth.consumer.ProtectedResourceDetails;
51  import org.springframework.security.oauth.consumer.rememberme.HttpSessionOAuthRememberMeServices;
52  import org.springframework.security.oauth.consumer.rememberme.OAuthRememberMeServices;
53  import org.springframework.security.oauth.consumer.token.HttpSessionBasedTokenServices;
54  import org.springframework.security.oauth.consumer.token.OAuthConsumerTokenServices;
55  import org.springframework.security.web.DefaultRedirectStrategy;
56  import org.springframework.security.web.PortResolver;
57  import org.springframework.security.web.PortResolverImpl;
58  import org.springframework.security.web.RedirectStrategy;
59  import org.springframework.security.web.access.AccessDeniedHandler;
60  import org.springframework.security.web.savedrequest.DefaultSavedRequest;
61  import org.springframework.security.web.util.ThrowableAnalyzer;
62  import org.springframework.security.web.util.ThrowableCauseExtractor;
63  import org.springframework.util.Assert;
64  
65  /**
66   * OAuth filter that establishes an OAuth security context.
67   *
68   * @author Ryan Heaton
69   */
70  public class OAuthConsumerContextFilter implements Filter, InitializingBean, MessageSourceAware {
71  
72  	public static final String ACCESS_TOKENS_DEFAULT_ATTRIBUTE = "OAUTH_ACCESS_TOKENS";
73  	public static final String OAUTH_FAILURE_KEY = "OAUTH_FAILURE_KEY";
74  	private static final Log LOG = LogFactory.getLog(OAuthConsumerContextFilter.class);
75  
76  	private AccessDeniedHandler OAuthFailureHandler;
77  	protected MessageSourceAccessor messages = SpringSecurityMessageSource.getAccessor();
78  	private OAuthRememberMeServices rememberMeServices = new HttpSessionOAuthRememberMeServices();
79  	private OAuthConsumerSupport consumerSupport;
80  	private String accessTokensRequestAttribute = ACCESS_TOKENS_DEFAULT_ATTRIBUTE;
81  	private PortResolver portResolver = new PortResolverImpl();
82  	private ThrowableAnalyzer throwableAnalyzer = new DefaultThrowableAnalyzer();
83  	private RedirectStrategy redirectStrategy = new DefaultRedirectStrategy();
84  
85  	private OAuthConsumerTokenServices tokenServices = new HttpSessionBasedTokenServices();
86  
87  	public void afterPropertiesSet() throws Exception {
88  		Assert.notNull(rememberMeServices, "Remember-me services must be provided.");
89  		Assert.notNull(consumerSupport, "Consumer support must be provided.");
90  		Assert.notNull(tokenServices, "OAuth token services are required.");
91  		Assert.notNull(redirectStrategy, "A redirect strategy must be supplied.");
92  	}
93  
94  	public void init(FilterConfig ignored) throws ServletException {
95  	}
96  
97  	public void destroy() {
98  	}
99  
100 	public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain chain) throws IOException, ServletException {
101 		HttpServletRequest request = (HttpServletRequest) servletRequest;
102 		HttpServletResponse response = (HttpServletResponse) servletResponse;
103 		OAuthSecurityContextImpl context = new OAuthSecurityContextImpl();
104 		context.setDetails(request);
105 
106 		Map<String, OAuthConsumerToken> rememberedTokens = getRememberMeServices().loadRememberedTokens(request, response);
107 		Map<String, OAuthConsumerToken> accessTokens = new TreeMap<String, OAuthConsumerToken>();
108 		Map<String, OAuthConsumerToken> requestTokens = new TreeMap<String, OAuthConsumerToken>();
109 		if (rememberedTokens != null) {
110 			for (Map.Entry<String, OAuthConsumerToken> tokenEntry : rememberedTokens.entrySet()) {
111 				OAuthConsumerToken token = tokenEntry.getValue();
112 				if (token != null) {
113 					if (token.isAccessToken()) {
114 						accessTokens.put(tokenEntry.getKey(), token);
115 					}
116 					else {
117 						requestTokens.put(tokenEntry.getKey(), token);
118 					}
119 				}
120 			}
121 		}
122 
123 		context.setAccessTokens(accessTokens);
124 		OAuthSecurityContextHolder.setContext(context);
125 		if (LOG.isDebugEnabled()) {
126 			LOG.debug("Storing access tokens in request attribute '" + getAccessTokensRequestAttribute() + "'.");
127 		}
128 
129 		try {
130 			try {
131 				request.setAttribute(getAccessTokensRequestAttribute(), new ArrayList<OAuthConsumerToken>(accessTokens.values()));
132 				chain.doFilter(request, response);
133 			}
134 			catch (Exception e) {
135 				try {
136 					ProtectedResourceDetails resourceThatNeedsAuthorization = checkForResourceThatNeedsAuthorization(e);
137 					String neededResourceId = resourceThatNeedsAuthorization.getId();
138 					while (!accessTokens.containsKey(neededResourceId)) {
139 						OAuthConsumerToken token = requestTokens.remove(neededResourceId);
140 						if (token == null) {
141 							token = getTokenServices().getToken(neededResourceId);
142 						}
143 
144 						String verifier = request.getParameter(OAuthProviderParameter.oauth_verifier.toString());
145 						// if the token is null OR
146 						// if there is NO access token and (we're not using 1.0a or the verifier is not null)
147 						if (token == null || (!token.isAccessToken() && (!resourceThatNeedsAuthorization.isUse10a() || verifier == null))) {
148 							//no token associated with the resource, start the oauth flow.
149 							//if there's a request token, but no verifier, we'll assume that a previous oauth request failed and we need to get a new request token.
150 							if (LOG.isDebugEnabled()) {
151 								LOG.debug("Obtaining request token for resource: " + neededResourceId);
152 							}
153 
154 							//obtain authorization.
155 							String callbackURL = response.encodeRedirectURL(getCallbackURL(request));
156 							token = getConsumerSupport().getUnauthorizedRequestToken(neededResourceId, callbackURL);
157 							if (LOG.isDebugEnabled()) {
158 								LOG.debug("Request token obtained for resource " + neededResourceId + ": " + token);
159 							}
160 
161 							//okay, we've got a request token, now we need to authorize it.
162 							requestTokens.put(neededResourceId, token);
163 							getTokenServices().storeToken(neededResourceId, token);
164 							String redirect = getUserAuthorizationRedirectURL(resourceThatNeedsAuthorization, token, callbackURL);
165 
166 							if (LOG.isDebugEnabled()) {
167 								LOG.debug("Redirecting request to " + redirect + " for user authorization of the request token for resource " + neededResourceId + ".");
168 							}
169 
170 							request.setAttribute("org.springframework.security.oauth.consumer.AccessTokenRequiredException", e);
171 							this.redirectStrategy.sendRedirect(request, response, redirect);
172 							return;
173 						}
174 						else if (!token.isAccessToken()) {
175 							//we have a presumably authorized request token, let's try to get an access token with it.
176 							if (LOG.isDebugEnabled()) {
177 								LOG.debug("Obtaining access token for resource: " + neededResourceId);
178 							}
179 
180 							//authorize the request token and store it.
181 							try {
182 								token = getConsumerSupport().getAccessToken(token, verifier);
183 							}
184 							finally {
185 								getTokenServices().removeToken(neededResourceId);
186 							}
187 
188 							if (LOG.isDebugEnabled()) {
189 								LOG.debug("Access token " + token + " obtained for resource " + neededResourceId + ". Now storing and using.");
190 							}
191 
192 							getTokenServices().storeToken(neededResourceId, token);
193 						}
194 
195 						accessTokens.put(neededResourceId, token);
196 
197 						try {
198 							//try again
199 							if (!response.isCommitted()) {
200 								request.setAttribute(getAccessTokensRequestAttribute(), new ArrayList<OAuthConsumerToken>(accessTokens.values()));
201 								chain.doFilter(request, response);
202 							}
203 							else {
204 								//dang. what do we do now?
205 								throw new IllegalStateException("Unable to reprocess filter chain with needed OAuth2 resources because the response is already committed.");
206 							}
207 						}
208 						catch (Exception e1) {
209 							resourceThatNeedsAuthorization = checkForResourceThatNeedsAuthorization(e1);
210 							neededResourceId = resourceThatNeedsAuthorization.getId();
211 						}
212 					}
213 				}
214 				catch (OAuthRequestFailedException eo) {
215 					fail(request, response, eo);
216 				}
217 				catch (Exception ex) {
218 					Throwable[] causeChain = getThrowableAnalyzer().determineCauseChain(ex);
219 					OAuthRequestFailedException rfe = (OAuthRequestFailedException) getThrowableAnalyzer().getFirstThrowableOfType(OAuthRequestFailedException.class, causeChain);
220 					if (rfe != null) {
221 						fail(request, response, rfe);
222 					}
223 					else {
224 						// Rethrow ServletExceptions and RuntimeExceptions as-is
225 						if (ex instanceof ServletException) {
226 							throw (ServletException) ex;
227 						}
228 						else if (ex instanceof RuntimeException) {
229 							throw (RuntimeException) ex;
230 						}
231 
232 						// Wrap other Exceptions. These are not expected to happen
233 						throw new RuntimeException(ex);
234 					}
235 				}
236 			}
237 		}
238 		finally {
239 			OAuthSecurityContextHolder.setContext(null);
240 			HashMap<String, OAuthConsumerToken> tokensToRemember = new HashMap<String, OAuthConsumerToken>();
241 			tokensToRemember.putAll(requestTokens);
242 			tokensToRemember.putAll(accessTokens);
243 			getRememberMeServices().rememberTokens(tokensToRemember, request, response);
244 		}
245 	}
246 
247 	/**
248 	 * Check the given exception for the resource that needs authorization. If the exception was not thrown because a resource needed authorization, then rethrow
249 	 * the exception.
250 	 *
251 	 * @param ex The exception.
252 	 * @return The resource that needed authorization (never null).
253 	 * @throws ServletException in the case of an underlying Servlet API exception
254 	 * @throws IOException in the case of general IO exceptions
255 	 */
256 	protected ProtectedResourceDetails checkForResourceThatNeedsAuthorization(Exception ex) throws ServletException, IOException {
257 		Throwable[] causeChain = getThrowableAnalyzer().determineCauseChain(ex);
258 		AccessTokenRequiredException ase = (AccessTokenRequiredException) getThrowableAnalyzer().getFirstThrowableOfType(AccessTokenRequiredException.class, causeChain);
259 		ProtectedResourceDetails resourceThatNeedsAuthorization;
260 		if (ase != null) {
261 			resourceThatNeedsAuthorization = ase.getResource();
262 			if (resourceThatNeedsAuthorization == null) {
263 				throw new OAuthRequestFailedException(ase.getMessage());
264 			}
265 		}
266 		else {
267 			// Rethrow ServletExceptions and RuntimeExceptions as-is
268 			if (ex instanceof ServletException) {
269 				throw (ServletException) ex;
270 			}
271 			if (ex instanceof IOException) {
272 				throw (IOException) ex;
273 			}
274 			else if (ex instanceof RuntimeException) {
275 				throw (RuntimeException) ex;
276 			}
277 
278 			// Wrap other Exceptions. These are not expected to happen
279 			throw new RuntimeException(ex);
280 		}
281 		return resourceThatNeedsAuthorization;
282 	}
283 
284 	/**
285 	 * Get the callback URL for the specified request.
286 	 *
287 	 * @param request The request.
288 	 * @return The callback URL.
289 	 */
290 	protected String getCallbackURL(HttpServletRequest request) {
291 		return new DefaultSavedRequest(request, getPortResolver()).getRedirectUrl();
292 	}
293 
294 	/**
295 	 * Get the URL to which to redirect the user for authorization of protected resources.
296 	 *
297 	 * @param details	  The resource for which to get the authorization url.
298 	 * @param requestToken The request token.
299 	 * @param callbackURL  The callback URL.
300 	 * @return The URL.
301 	 */
302 	protected String getUserAuthorizationRedirectURL(ProtectedResourceDetails details, OAuthConsumerToken requestToken, String callbackURL) {
303 		try {
304 			String baseURL = details.getUserAuthorizationURL();
305 			StringBuilder builder = new StringBuilder(baseURL);
306 			char appendChar = baseURL.indexOf('?') < 0 ? '?' : '&';
307 			builder.append(appendChar).append("oauth_token=");
308 			builder.append(URLEncoder.encode(requestToken.getValue(), "UTF-8"));
309 			if (!details.isUse10a()) {
310 				builder.append('&').append("oauth_callback=");
311 				builder.append(URLEncoder.encode(callbackURL, "UTF-8"));
312 			}
313 			return builder.toString();
314 		}
315 		catch (UnsupportedEncodingException e) {
316 			throw new IllegalStateException(e);
317 		}
318 	}
319 
320 	/**
321 	 * Common logic for OAuth failed. (Note that the default logic doesn't pass the failure through so as to not mess
322 	 * with the current authentication.)
323 	 *
324 	 * @param request  The request.
325 	 * @param response The response.
326 	 * @param failure  The failure.
327 	 * @throws ServletException in the case of an underlying Servlet API exception
328 	 * @throws IOException in the case of general IO exceptions
329 	 */
330 	protected void fail(HttpServletRequest request, HttpServletResponse response, OAuthRequestFailedException failure) throws IOException, ServletException {
331 		try {
332 			//attempt to set the last exception.
333 			request.getSession().setAttribute(OAUTH_FAILURE_KEY, failure);
334 		}
335 		catch (Exception e) {
336 			//fall through....
337 		}
338 
339 		if (LOG.isDebugEnabled()) {
340 			LOG.debug(failure);
341 		}
342 
343 		if (getOAuthFailureHandler() != null) {
344 			getOAuthFailureHandler().handle(request, response, failure);
345 		}
346 		else {
347 			throw failure;
348 		}
349 	}
350 
351 	/**
352 	 * The oauth failure handler.
353 	 *
354 	 * @return The oauth failure handler.
355 	 */
356 	public AccessDeniedHandler getOAuthFailureHandler() {
357 		return OAuthFailureHandler;
358 	}
359 
360 	/**
361 	 * The oauth failure handler.
362 	 *
363 	 * @param OAuthFailureHandler The oauth failure handler.
364 	 */
365 	public void setOAuthFailureHandler(AccessDeniedHandler OAuthFailureHandler) {
366 		this.OAuthFailureHandler = OAuthFailureHandler;
367 	}
368 
369 	/**
370 	 * The token services.
371 	 *
372 	 * @return The token services.
373 	 */
374 	public OAuthConsumerTokenServices getTokenServices() {
375 		return tokenServices;
376 	}
377 
378 	/**
379 	 * The token services.
380 	 *
381 	 * @param tokenServices The token services.
382 	 */
383 	public void setTokenServices(OAuthConsumerTokenServices tokenServices) {
384 		this.tokenServices = tokenServices;
385 	}
386 
387 	/**
388 	 * Set the message source.
389 	 *
390 	 * @param messageSource The message source.
391 	 */
392 	public void setMessageSource(MessageSource messageSource) {
393 		this.messages = new MessageSourceAccessor(messageSource);
394 	}
395 
396 	/**
397 	 * The OAuth consumer support.
398 	 *
399 	 * @return The OAuth consumer support.
400 	 */
401 	public OAuthConsumerSupport getConsumerSupport() {
402 		return consumerSupport;
403 	}
404 
405 	/**
406 	 * The OAuth consumer support.
407 	 *
408 	 * @param consumerSupport The OAuth consumer support.
409 	 */
410 	public void setConsumerSupport(OAuthConsumerSupport consumerSupport) {
411 		this.consumerSupport = consumerSupport;
412 	}
413 
414 	/**
415 	 * The default request attribute into which the OAuth access tokens are stored.
416 	 *
417 	 * @return The default request attribute into which the OAuth access tokens are stored.
418 	 */
419 	public String getAccessTokensRequestAttribute() {
420 		return accessTokensRequestAttribute;
421 	}
422 
423 	/**
424 	 * The default request attribute into which the OAuth access tokens are stored.
425 	 *
426 	 * @param accessTokensRequestAttribute The default request attribute into which the OAuth access tokens are stored.
427 	 */
428 	public void setAccessTokensRequestAttribute(String accessTokensRequestAttribute) {
429 		this.accessTokensRequestAttribute = accessTokensRequestAttribute;
430 	}
431 
432 	/**
433 	 * The port resolver.
434 	 *
435 	 * @return The port resolver.
436 	 */
437 	public PortResolver getPortResolver() {
438 		return portResolver;
439 	}
440 
441 	/**
442 	 * The port resolver.
443 	 *
444 	 * @param portResolver The port resolver.
445 	 */
446 	public void setPortResolver(PortResolver portResolver) {
447 		this.portResolver = portResolver;
448 	}
449 
450 	/**
451 	 * The remember-me services.
452 	 *
453 	 * @return The remember-me services.
454 	 */
455 	public OAuthRememberMeServices getRememberMeServices() {
456 		return rememberMeServices;
457 	}
458 
459 	/**
460 	 * The remember-me services.
461 	 *
462 	 * @param rememberMeServices The remember-me services.
463 	 */
464 	public void setRememberMeServices(OAuthRememberMeServices rememberMeServices) {
465 		this.rememberMeServices = rememberMeServices;
466 	}
467 
468 	/**
469 	 * The throwable analyzer.
470 	 *
471 	 * @return The throwable analyzer.
472 	 */
473 	public ThrowableAnalyzer getThrowableAnalyzer() {
474 		return throwableAnalyzer;
475 	}
476 
477 	/**
478 	 * The throwable analyzer.
479 	 *
480 	 * @param throwableAnalyzer The throwable analyzer.
481 	 */
482 	public void setThrowableAnalyzer(ThrowableAnalyzer throwableAnalyzer) {
483 		this.throwableAnalyzer = throwableAnalyzer;
484 	}
485 
486 	/**
487 	 * The redirect strategy.
488 	 *
489 	 * @return The redirect strategy.
490 	 */
491 	public RedirectStrategy getRedirectStrategy() {
492 		return redirectStrategy;
493 	}
494 
495 	/**
496 	 * The redirect strategy.
497 	 *
498 	 * @param redirectStrategy The redirect strategy.
499 	 */
500 	public void setRedirectStrategy(RedirectStrategy redirectStrategy) {
501 		this.redirectStrategy = redirectStrategy;
502 	}
503 
504 	/**
505 	 * Default implementation of <code>ThrowableAnalyzer</code> which is capable of also unwrapping
506 	 * <code>ServletException</code>s.
507 	 */
508 	private static final class DefaultThrowableAnalyzer extends ThrowableAnalyzer {
509 		/**
510 		 * @see org.springframework.security.web.util.ThrowableAnalyzer#initExtractorMap()
511 		 */
512 		protected void initExtractorMap() {
513 			super.initExtractorMap();
514 
515 			registerExtractor(ServletException.class, new ThrowableCauseExtractor() {
516 				public Throwable extractCause(Throwable throwable) {
517 					ThrowableAnalyzer.verifyThrowableHierarchy(throwable, ServletException.class);
518 					return ((ServletException) throwable).getRootCause();
519 				}
520 			});
521 		}
522 	}
523 }