View Javadoc
1   package org.springframework.security.oauth2.client.filter;
2   
3   import java.io.IOException;
4   import java.io.UnsupportedEncodingException;
5   import java.util.Map;
6   
7   import javax.servlet.Filter;
8   import javax.servlet.FilterChain;
9   import javax.servlet.FilterConfig;
10  import javax.servlet.ServletException;
11  import javax.servlet.ServletRequest;
12  import javax.servlet.ServletResponse;
13  import javax.servlet.http.HttpServletRequest;
14  import javax.servlet.http.HttpServletResponse;
15  
16  import org.springframework.beans.factory.InitializingBean;
17  import org.springframework.security.oauth2.client.resource.UserRedirectRequiredException;
18  import org.springframework.security.oauth2.common.DefaultThrowableAnalyzer;
19  import org.springframework.security.web.DefaultRedirectStrategy;
20  import org.springframework.security.web.RedirectStrategy;
21  import org.springframework.security.web.util.ThrowableAnalyzer;
22  import org.springframework.util.Assert;
23  import org.springframework.web.servlet.support.ServletUriComponentsBuilder;
24  import org.springframework.web.util.NestedServletException;
25  import org.springframework.web.util.UriComponents;
26  import org.springframework.web.util.UriComponentsBuilder;
27  
28  /**
29   * Security filter for an OAuth2 client.
30   * 
31   * @author Ryan Heaton
32   * @author Dave Syer
33   */
34  public class OAuth2ClientContextFilter implements Filter, InitializingBean {
35  
36  	/**
37  	 * Key in request attributes for the current URI in case it is needed by
38  	 * rest client code that needs to send a redirect URI to an authorization
39  	 * server.
40  	 */
41  	public static final String CURRENT_URI = "currentUri";
42  
43  	private ThrowableAnalyzer throwableAnalyzer = new DefaultThrowableAnalyzer();
44  
45  	private RedirectStrategy redirectStrategy = new DefaultRedirectStrategy();
46  
47  	public void afterPropertiesSet() throws Exception {
48  		Assert.notNull(redirectStrategy,
49  				"A redirect strategy must be supplied.");
50  	}
51  
52  	public void doFilter(ServletRequest servletRequest,
53  			ServletResponse servletResponse, FilterChain chain)
54  			throws IOException, ServletException {
55  		HttpServletRequest request = (HttpServletRequest) servletRequest;
56  		HttpServletResponse response = (HttpServletResponse) servletResponse;
57  		request.setAttribute(CURRENT_URI, calculateCurrentUri(request));
58  
59  		try {
60  			chain.doFilter(servletRequest, servletResponse);
61  		} catch (IOException ex) {
62  			throw ex;
63  		} catch (Exception ex) {
64  			// Try to extract a SpringSecurityException from the stacktrace
65  			Throwable[] causeChain = throwableAnalyzer.determineCauseChain(ex);
66  			UserRedirectRequiredException redirect = (UserRedirectRequiredException) throwableAnalyzer
67  					.getFirstThrowableOfType(
68  							UserRedirectRequiredException.class, causeChain);
69  			if (redirect != null) {
70  				redirectUser(redirect, request, response);
71  			} else {
72  				if (ex instanceof ServletException) {
73  					throw (ServletException) ex;
74  				}
75  				if (ex instanceof RuntimeException) {
76  					throw (RuntimeException) ex;
77  				}
78  				throw new NestedServletException("Unhandled exception", ex);
79  			}
80  		}
81  	}
82  
83  	/**
84  	 * Redirect the user according to the specified exception.
85  	 * 
86  	 * @param e
87  	 *            The user redirect exception.
88  	 * @param request
89  	 *            The request.
90  	 * @param response
91  	 *            The response.
92  	 */
93  	protected void redirectUser(UserRedirectRequiredException e,
94  			HttpServletRequest request, HttpServletResponse response)
95  			throws IOException {
96  
97  		String redirectUri = e.getRedirectUri();
98  		UriComponentsBuilder builder = UriComponentsBuilder
99  				.fromHttpUrl(redirectUri);
100 		Map<String, String> requestParams = e.getRequestParams();
101 		for (Map.Entry<String, String> param : requestParams.entrySet()) {
102 			builder.queryParam(param.getKey(), param.getValue());
103 		}
104 
105 		if (e.getStateKey() != null) {
106 			builder.queryParam("state", e.getStateKey());
107 		}
108 
109 		this.redirectStrategy.sendRedirect(request, response, builder.build()
110 				.encode().toUriString());
111 	}
112 
113 	/**
114 	 * Calculate the current URI given the request.
115 	 * 
116 	 * @param request
117 	 *            The request.
118 	 * @return The current uri.
119 	 */
120 	protected String calculateCurrentUri(HttpServletRequest request)
121 			throws UnsupportedEncodingException {
122 		ServletUriComponentsBuilder builder = ServletUriComponentsBuilder
123 				.fromRequest(request);
124 		// Now work around SPR-10172...
125 		String queryString = request.getQueryString();
126 		boolean legalSpaces = queryString != null && queryString.contains("+");
127 		if (legalSpaces) {
128 			builder.replaceQuery(queryString.replace("+", "%20"));
129 		}
130 		UriComponents uri = null;
131 		try {
132 			uri = builder.replaceQueryParam("code").build(true);
133 		} catch (IllegalArgumentException ex) {
134 			// ignore failures to parse the url (including query string). does't
135 			// make sense for redirection purposes anyway.
136 			return null;
137 		}
138 		String query = uri.getQuery();
139 		if (legalSpaces) {
140 			query = query.replace("%20", "+");
141 		}
142 		return ServletUriComponentsBuilder.fromUri(uri.toUri())
143 				.replaceQuery(query).build().toString();
144 	}
145 
146 	public void init(FilterConfig filterConfig) throws ServletException {
147 	}
148 
149 	public void destroy() {
150 	}
151 
152 	public void setThrowableAnalyzer(ThrowableAnalyzer throwableAnalyzer) {
153 		this.throwableAnalyzer = throwableAnalyzer;
154 	}
155 
156 	public void setRedirectStrategy(RedirectStrategy redirectStrategy) {
157 		this.redirectStrategy = redirectStrategy;
158 	}
159 
160 }