1 package org.springframework.security.oauth2.client.filter;
2
3 import java.io.IOException;
4 import java.io.UnsupportedEncodingException;
5 import java.util.Map;
6
7 import javax.servlet.Filter;
8 import javax.servlet.FilterChain;
9 import javax.servlet.FilterConfig;
10 import javax.servlet.ServletException;
11 import javax.servlet.ServletRequest;
12 import javax.servlet.ServletResponse;
13 import javax.servlet.http.HttpServletRequest;
14 import javax.servlet.http.HttpServletResponse;
15
16 import org.springframework.beans.factory.InitializingBean;
17 import org.springframework.security.oauth2.client.resource.UserRedirectRequiredException;
18 import org.springframework.security.oauth2.common.DefaultThrowableAnalyzer;
19 import org.springframework.security.web.DefaultRedirectStrategy;
20 import org.springframework.security.web.RedirectStrategy;
21 import org.springframework.security.web.util.ThrowableAnalyzer;
22 import org.springframework.util.Assert;
23 import org.springframework.web.servlet.support.ServletUriComponentsBuilder;
24 import org.springframework.web.util.NestedServletException;
25 import org.springframework.web.util.UriComponents;
26 import org.springframework.web.util.UriComponentsBuilder;
27
28
29
30
31
32
33
34 public class OAuth2ClientContextFilter implements Filter, InitializingBean {
35
36
37
38
39
40
41 public static final String CURRENT_URI = "currentUri";
42
43 private ThrowableAnalyzer throwableAnalyzer = new DefaultThrowableAnalyzer();
44
45 private RedirectStrategy redirectStrategy = new DefaultRedirectStrategy();
46
47 public void afterPropertiesSet() throws Exception {
48 Assert.notNull(redirectStrategy,
49 "A redirect strategy must be supplied.");
50 }
51
52 public void doFilter(ServletRequest servletRequest,
53 ServletResponse servletResponse, FilterChain chain)
54 throws IOException, ServletException {
55 HttpServletRequest request = (HttpServletRequest) servletRequest;
56 HttpServletResponse response = (HttpServletResponse) servletResponse;
57 request.setAttribute(CURRENT_URI, calculateCurrentUri(request));
58
59 try {
60 chain.doFilter(servletRequest, servletResponse);
61 } catch (IOException ex) {
62 throw ex;
63 } catch (Exception ex) {
64
65 Throwable[] causeChain = throwableAnalyzer.determineCauseChain(ex);
66 UserRedirectRequiredException redirect = (UserRedirectRequiredException) throwableAnalyzer
67 .getFirstThrowableOfType(
68 UserRedirectRequiredException.class, causeChain);
69 if (redirect != null) {
70 redirectUser(redirect, request, response);
71 } else {
72 if (ex instanceof ServletException) {
73 throw (ServletException) ex;
74 }
75 if (ex instanceof RuntimeException) {
76 throw (RuntimeException) ex;
77 }
78 throw new NestedServletException("Unhandled exception", ex);
79 }
80 }
81 }
82
83
84
85
86
87
88
89
90
91
92
93 protected void redirectUser(UserRedirectRequiredException e,
94 HttpServletRequest request, HttpServletResponse response)
95 throws IOException {
96
97 String redirectUri = e.getRedirectUri();
98 UriComponentsBuilder builder = UriComponentsBuilder
99 .fromHttpUrl(redirectUri);
100 Map<String, String> requestParams = e.getRequestParams();
101 for (Map.Entry<String, String> param : requestParams.entrySet()) {
102 builder.queryParam(param.getKey(), param.getValue());
103 }
104
105 if (e.getStateKey() != null) {
106 builder.queryParam("state", e.getStateKey());
107 }
108
109 this.redirectStrategy.sendRedirect(request, response, builder.build()
110 .encode().toUriString());
111 }
112
113
114
115
116
117
118
119
120 protected String calculateCurrentUri(HttpServletRequest request)
121 throws UnsupportedEncodingException {
122 ServletUriComponentsBuilder builder = ServletUriComponentsBuilder
123 .fromRequest(request);
124
125 String queryString = request.getQueryString();
126 boolean legalSpaces = queryString != null && queryString.contains("+");
127 if (legalSpaces) {
128 builder.replaceQuery(queryString.replace("+", "%20"));
129 }
130 UriComponents uri = null;
131 try {
132 uri = builder.replaceQueryParam("code").build(true);
133 } catch (IllegalArgumentException ex) {
134
135
136 return null;
137 }
138 String query = uri.getQuery();
139 if (legalSpaces) {
140 query = query.replace("%20", "+");
141 }
142 return ServletUriComponentsBuilder.fromUri(uri.toUri())
143 .replaceQuery(query).build().toString();
144 }
145
146 public void init(FilterConfig filterConfig) throws ServletException {
147 }
148
149 public void destroy() {
150 }
151
152 public void setThrowableAnalyzer(ThrowableAnalyzer throwableAnalyzer) {
153 this.throwableAnalyzer = throwableAnalyzer;
154 }
155
156 public void setRedirectStrategy(RedirectStrategy redirectStrategy) {
157 this.redirectStrategy = redirectStrategy;
158 }
159
160 }