1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.springframework.security.oauth.consumer.filter;
18
19 import java.io.IOException;
20 import java.io.UnsupportedEncodingException;
21 import java.net.URLEncoder;
22 import java.util.ArrayList;
23 import java.util.HashMap;
24 import java.util.Map;
25 import java.util.TreeMap;
26
27 import javax.servlet.Filter;
28 import javax.servlet.FilterChain;
29 import javax.servlet.FilterConfig;
30 import javax.servlet.ServletException;
31 import javax.servlet.ServletRequest;
32 import javax.servlet.ServletResponse;
33 import javax.servlet.http.HttpServletRequest;
34 import javax.servlet.http.HttpServletResponse;
35
36 import org.apache.commons.logging.Log;
37 import org.apache.commons.logging.LogFactory;
38 import org.springframework.beans.factory.InitializingBean;
39 import org.springframework.context.MessageSource;
40 import org.springframework.context.MessageSourceAware;
41 import org.springframework.context.support.MessageSourceAccessor;
42 import org.springframework.security.core.SpringSecurityMessageSource;
43 import org.springframework.security.oauth.common.OAuthProviderParameter;
44 import org.springframework.security.oauth.consumer.AccessTokenRequiredException;
45 import org.springframework.security.oauth.consumer.OAuthConsumerSupport;
46 import org.springframework.security.oauth.consumer.OAuthConsumerToken;
47 import org.springframework.security.oauth.consumer.OAuthRequestFailedException;
48 import org.springframework.security.oauth.consumer.OAuthSecurityContextHolder;
49 import org.springframework.security.oauth.consumer.OAuthSecurityContextImpl;
50 import org.springframework.security.oauth.consumer.ProtectedResourceDetails;
51 import org.springframework.security.oauth.consumer.rememberme.HttpSessionOAuthRememberMeServices;
52 import org.springframework.security.oauth.consumer.rememberme.OAuthRememberMeServices;
53 import org.springframework.security.oauth.consumer.token.HttpSessionBasedTokenServices;
54 import org.springframework.security.oauth.consumer.token.OAuthConsumerTokenServices;
55 import org.springframework.security.web.DefaultRedirectStrategy;
56 import org.springframework.security.web.PortResolver;
57 import org.springframework.security.web.PortResolverImpl;
58 import org.springframework.security.web.RedirectStrategy;
59 import org.springframework.security.web.access.AccessDeniedHandler;
60 import org.springframework.security.web.savedrequest.DefaultSavedRequest;
61 import org.springframework.security.web.util.ThrowableAnalyzer;
62 import org.springframework.security.web.util.ThrowableCauseExtractor;
63 import org.springframework.util.Assert;
64
65
66
67
68
69
70 public class OAuthConsumerContextFilter implements Filter, InitializingBean, MessageSourceAware {
71
72 public static final String ACCESS_TOKENS_DEFAULT_ATTRIBUTE = "OAUTH_ACCESS_TOKENS";
73 public static final String OAUTH_FAILURE_KEY = "OAUTH_FAILURE_KEY";
74 private static final Log LOG = LogFactory.getLog(OAuthConsumerContextFilter.class);
75
76 private AccessDeniedHandler OAuthFailureHandler;
77 protected MessageSourceAccessor messages = SpringSecurityMessageSource.getAccessor();
78 private OAuthRememberMeServices rememberMeServices = new HttpSessionOAuthRememberMeServices();
79 private OAuthConsumerSupport consumerSupport;
80 private String accessTokensRequestAttribute = ACCESS_TOKENS_DEFAULT_ATTRIBUTE;
81 private PortResolver portResolver = new PortResolverImpl();
82 private ThrowableAnalyzer throwableAnalyzer = new DefaultThrowableAnalyzer();
83 private RedirectStrategy redirectStrategy = new DefaultRedirectStrategy();
84
85 private OAuthConsumerTokenServices tokenServices = new HttpSessionBasedTokenServices();
86
87 public void afterPropertiesSet() throws Exception {
88 Assert.notNull(rememberMeServices, "Remember-me services must be provided.");
89 Assert.notNull(consumerSupport, "Consumer support must be provided.");
90 Assert.notNull(tokenServices, "OAuth token services are required.");
91 Assert.notNull(redirectStrategy, "A redirect strategy must be supplied.");
92 }
93
94 public void init(FilterConfig ignored) throws ServletException {
95 }
96
97 public void destroy() {
98 }
99
100 public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain chain) throws IOException, ServletException {
101 HttpServletRequest request = (HttpServletRequest) servletRequest;
102 HttpServletResponse response = (HttpServletResponse) servletResponse;
103 OAuthSecurityContextImpl context = new OAuthSecurityContextImpl();
104 context.setDetails(request);
105
106 Map<String, OAuthConsumerToken> rememberedTokens = getRememberMeServices().loadRememberedTokens(request, response);
107 Map<String, OAuthConsumerToken> accessTokens = new TreeMap<String, OAuthConsumerToken>();
108 Map<String, OAuthConsumerToken> requestTokens = new TreeMap<String, OAuthConsumerToken>();
109 if (rememberedTokens != null) {
110 for (Map.Entry<String, OAuthConsumerToken> tokenEntry : rememberedTokens.entrySet()) {
111 OAuthConsumerToken token = tokenEntry.getValue();
112 if (token != null) {
113 if (token.isAccessToken()) {
114 accessTokens.put(tokenEntry.getKey(), token);
115 }
116 else {
117 requestTokens.put(tokenEntry.getKey(), token);
118 }
119 }
120 }
121 }
122
123 context.setAccessTokens(accessTokens);
124 OAuthSecurityContextHolder.setContext(context);
125 if (LOG.isDebugEnabled()) {
126 LOG.debug("Storing access tokens in request attribute '" + getAccessTokensRequestAttribute() + "'.");
127 }
128
129 try {
130 try {
131 request.setAttribute(getAccessTokensRequestAttribute(), new ArrayList<OAuthConsumerToken>(accessTokens.values()));
132 chain.doFilter(request, response);
133 }
134 catch (Exception e) {
135 try {
136 ProtectedResourceDetails resourceThatNeedsAuthorization = checkForResourceThatNeedsAuthorization(e);
137 String neededResourceId = resourceThatNeedsAuthorization.getId();
138 while (!accessTokens.containsKey(neededResourceId)) {
139 OAuthConsumerToken token = requestTokens.remove(neededResourceId);
140 if (token == null) {
141 token = getTokenServices().getToken(neededResourceId);
142 }
143
144 String verifier = request.getParameter(OAuthProviderParameter.oauth_verifier.toString());
145
146
147 if (token == null || (!token.isAccessToken() && (!resourceThatNeedsAuthorization.isUse10a() || verifier == null))) {
148
149
150 if (LOG.isDebugEnabled()) {
151 LOG.debug("Obtaining request token for resource: " + neededResourceId);
152 }
153
154
155 String callbackURL = response.encodeRedirectURL(getCallbackURL(request));
156 token = getConsumerSupport().getUnauthorizedRequestToken(neededResourceId, callbackURL);
157 if (LOG.isDebugEnabled()) {
158 LOG.debug("Request token obtained for resource " + neededResourceId + ": " + token);
159 }
160
161
162 requestTokens.put(neededResourceId, token);
163 getTokenServices().storeToken(neededResourceId, token);
164 String redirect = getUserAuthorizationRedirectURL(resourceThatNeedsAuthorization, token, callbackURL);
165
166 if (LOG.isDebugEnabled()) {
167 LOG.debug("Redirecting request to " + redirect + " for user authorization of the request token for resource " + neededResourceId + ".");
168 }
169
170 request.setAttribute("org.springframework.security.oauth.consumer.AccessTokenRequiredException", e);
171 this.redirectStrategy.sendRedirect(request, response, redirect);
172 return;
173 }
174 else if (!token.isAccessToken()) {
175
176 if (LOG.isDebugEnabled()) {
177 LOG.debug("Obtaining access token for resource: " + neededResourceId);
178 }
179
180
181 try {
182 token = getConsumerSupport().getAccessToken(token, verifier);
183 }
184 finally {
185 getTokenServices().removeToken(neededResourceId);
186 }
187
188 if (LOG.isDebugEnabled()) {
189 LOG.debug("Access token " + token + " obtained for resource " + neededResourceId + ". Now storing and using.");
190 }
191
192 getTokenServices().storeToken(neededResourceId, token);
193 }
194
195 accessTokens.put(neededResourceId, token);
196
197 try {
198
199 if (!response.isCommitted()) {
200 request.setAttribute(getAccessTokensRequestAttribute(), new ArrayList<OAuthConsumerToken>(accessTokens.values()));
201 chain.doFilter(request, response);
202 }
203 else {
204
205 throw new IllegalStateException("Unable to reprocess filter chain with needed OAuth2 resources because the response is already committed.");
206 }
207 }
208 catch (Exception e1) {
209 resourceThatNeedsAuthorization = checkForResourceThatNeedsAuthorization(e1);
210 neededResourceId = resourceThatNeedsAuthorization.getId();
211 }
212 }
213 }
214 catch (OAuthRequestFailedException eo) {
215 fail(request, response, eo);
216 }
217 catch (Exception ex) {
218 Throwable[] causeChain = getThrowableAnalyzer().determineCauseChain(ex);
219 OAuthRequestFailedException rfe = (OAuthRequestFailedException) getThrowableAnalyzer().getFirstThrowableOfType(OAuthRequestFailedException.class, causeChain);
220 if (rfe != null) {
221 fail(request, response, rfe);
222 }
223 else {
224
225 if (ex instanceof ServletException) {
226 throw (ServletException) ex;
227 }
228 else if (ex instanceof RuntimeException) {
229 throw (RuntimeException) ex;
230 }
231
232
233 throw new RuntimeException(ex);
234 }
235 }
236 }
237 }
238 finally {
239 OAuthSecurityContextHolder.setContext(null);
240 HashMap<String, OAuthConsumerToken> tokensToRemember = new HashMap<String, OAuthConsumerToken>();
241 tokensToRemember.putAll(requestTokens);
242 tokensToRemember.putAll(accessTokens);
243 getRememberMeServices().rememberTokens(tokensToRemember, request, response);
244 }
245 }
246
247
248
249
250
251
252
253
254
255
256 protected ProtectedResourceDetails checkForResourceThatNeedsAuthorization(Exception ex) throws ServletException, IOException {
257 Throwable[] causeChain = getThrowableAnalyzer().determineCauseChain(ex);
258 AccessTokenRequiredException ase = (AccessTokenRequiredException) getThrowableAnalyzer().getFirstThrowableOfType(AccessTokenRequiredException.class, causeChain);
259 ProtectedResourceDetails resourceThatNeedsAuthorization;
260 if (ase != null) {
261 resourceThatNeedsAuthorization = ase.getResource();
262 if (resourceThatNeedsAuthorization == null) {
263 throw new OAuthRequestFailedException(ase.getMessage());
264 }
265 }
266 else {
267
268 if (ex instanceof ServletException) {
269 throw (ServletException) ex;
270 }
271 if (ex instanceof IOException) {
272 throw (IOException) ex;
273 }
274 else if (ex instanceof RuntimeException) {
275 throw (RuntimeException) ex;
276 }
277
278
279 throw new RuntimeException(ex);
280 }
281 return resourceThatNeedsAuthorization;
282 }
283
284
285
286
287
288
289
290 protected String getCallbackURL(HttpServletRequest request) {
291 return new DefaultSavedRequest(request, getPortResolver()).getRedirectUrl();
292 }
293
294
295
296
297
298
299
300
301
302 protected String getUserAuthorizationRedirectURL(ProtectedResourceDetails details, OAuthConsumerToken requestToken, String callbackURL) {
303 try {
304 String baseURL = details.getUserAuthorizationURL();
305 StringBuilder builder = new StringBuilder(baseURL);
306 char appendChar = baseURL.indexOf('?') < 0 ? '?' : '&';
307 builder.append(appendChar).append("oauth_token=");
308 builder.append(URLEncoder.encode(requestToken.getValue(), "UTF-8"));
309 if (!details.isUse10a()) {
310 builder.append('&').append("oauth_callback=");
311 builder.append(URLEncoder.encode(callbackURL, "UTF-8"));
312 }
313 return builder.toString();
314 }
315 catch (UnsupportedEncodingException e) {
316 throw new IllegalStateException(e);
317 }
318 }
319
320
321
322
323
324
325
326
327
328
329
330 protected void fail(HttpServletRequest request, HttpServletResponse response, OAuthRequestFailedException failure) throws IOException, ServletException {
331 try {
332
333 request.getSession().setAttribute(OAUTH_FAILURE_KEY, failure);
334 }
335 catch (Exception e) {
336
337 }
338
339 if (LOG.isDebugEnabled()) {
340 LOG.debug(failure);
341 }
342
343 if (getOAuthFailureHandler() != null) {
344 getOAuthFailureHandler().handle(request, response, failure);
345 }
346 else {
347 throw failure;
348 }
349 }
350
351
352
353
354
355
356 public AccessDeniedHandler getOAuthFailureHandler() {
357 return OAuthFailureHandler;
358 }
359
360
361
362
363
364
365 public void setOAuthFailureHandler(AccessDeniedHandler OAuthFailureHandler) {
366 this.OAuthFailureHandler = OAuthFailureHandler;
367 }
368
369
370
371
372
373
374 public OAuthConsumerTokenServices getTokenServices() {
375 return tokenServices;
376 }
377
378
379
380
381
382
383 public void setTokenServices(OAuthConsumerTokenServices tokenServices) {
384 this.tokenServices = tokenServices;
385 }
386
387
388
389
390
391
392 public void setMessageSource(MessageSource messageSource) {
393 this.messages = new MessageSourceAccessor(messageSource);
394 }
395
396
397
398
399
400
401 public OAuthConsumerSupport getConsumerSupport() {
402 return consumerSupport;
403 }
404
405
406
407
408
409
410 public void setConsumerSupport(OAuthConsumerSupport consumerSupport) {
411 this.consumerSupport = consumerSupport;
412 }
413
414
415
416
417
418
419 public String getAccessTokensRequestAttribute() {
420 return accessTokensRequestAttribute;
421 }
422
423
424
425
426
427
428 public void setAccessTokensRequestAttribute(String accessTokensRequestAttribute) {
429 this.accessTokensRequestAttribute = accessTokensRequestAttribute;
430 }
431
432
433
434
435
436
437 public PortResolver getPortResolver() {
438 return portResolver;
439 }
440
441
442
443
444
445
446 public void setPortResolver(PortResolver portResolver) {
447 this.portResolver = portResolver;
448 }
449
450
451
452
453
454
455 public OAuthRememberMeServices getRememberMeServices() {
456 return rememberMeServices;
457 }
458
459
460
461
462
463
464 public void setRememberMeServices(OAuthRememberMeServices rememberMeServices) {
465 this.rememberMeServices = rememberMeServices;
466 }
467
468
469
470
471
472
473 public ThrowableAnalyzer getThrowableAnalyzer() {
474 return throwableAnalyzer;
475 }
476
477
478
479
480
481
482 public void setThrowableAnalyzer(ThrowableAnalyzer throwableAnalyzer) {
483 this.throwableAnalyzer = throwableAnalyzer;
484 }
485
486
487
488
489
490
491 public RedirectStrategy getRedirectStrategy() {
492 return redirectStrategy;
493 }
494
495
496
497
498
499
500 public void setRedirectStrategy(RedirectStrategy redirectStrategy) {
501 this.redirectStrategy = redirectStrategy;
502 }
503
504
505
506
507
508 private static final class DefaultThrowableAnalyzer extends ThrowableAnalyzer {
509
510
511
512 protected void initExtractorMap() {
513 super.initExtractorMap();
514
515 registerExtractor(ServletException.class, new ThrowableCauseExtractor() {
516 public Throwable extractCause(Throwable throwable) {
517 ThrowableAnalyzer.verifyThrowableHierarchy(throwable, ServletException.class);
518 return ((ServletException) throwable).getRootCause();
519 }
520 });
521 }
522 }
523 }