1 package org.springframework.security.oauth2.client.token;
2
3 import org.apache.commons.logging.Log;
4 import org.apache.commons.logging.LogFactory;
5 import org.springframework.http.HttpHeaders;
6 import org.springframework.http.HttpMethod;
7 import org.springframework.http.MediaType;
8 import org.springframework.http.client.ClientHttpRequest;
9 import org.springframework.http.client.ClientHttpRequestFactory;
10 import org.springframework.http.client.ClientHttpRequestInterceptor;
11 import org.springframework.http.client.ClientHttpResponse;
12 import org.springframework.http.client.SimpleClientHttpRequestFactory;
13 import org.springframework.http.converter.FormHttpMessageConverter;
14 import org.springframework.http.converter.HttpMessageConverter;
15 import org.springframework.security.oauth2.client.resource.OAuth2AccessDeniedException;
16 import org.springframework.security.oauth2.client.resource.OAuth2ProtectedResourceDetails;
17 import org.springframework.security.oauth2.client.token.auth.ClientAuthenticationHandler;
18 import org.springframework.security.oauth2.client.token.auth.DefaultClientAuthenticationHandler;
19 import org.springframework.security.oauth2.common.OAuth2AccessToken;
20 import org.springframework.security.oauth2.common.exceptions.OAuth2Exception;
21 import org.springframework.security.oauth2.http.converter.FormOAuth2AccessTokenMessageConverter;
22 import org.springframework.security.oauth2.http.converter.FormOAuth2ExceptionHttpMessageConverter;
23 import org.springframework.util.Assert;
24 import org.springframework.util.MultiValueMap;
25 import org.springframework.web.client.DefaultResponseErrorHandler;
26 import org.springframework.web.client.HttpMessageConverterExtractor;
27 import org.springframework.web.client.RequestCallback;
28 import org.springframework.web.client.ResponseErrorHandler;
29 import org.springframework.web.client.ResponseExtractor;
30 import org.springframework.web.client.RestClientException;
31 import org.springframework.web.client.RestOperations;
32 import org.springframework.web.client.RestTemplate;
33
34 import java.io.IOException;
35 import java.net.HttpURLConnection;
36 import java.util.ArrayList;
37 import java.util.Arrays;
38 import java.util.List;
39
40
41
42
43
44
45
46 public abstract class OAuth2AccessTokenSupport {
47
48 protected final Log logger = LogFactory.getLog(getClass());
49
50 private static final FormHttpMessageConverter FORM_MESSAGE_CONVERTER = new FormHttpMessageConverter();
51
52 private RestOperations restTemplate;
53
54 private List<HttpMessageConverter<?>> messageConverters;
55
56 private ClientAuthenticationHandler authenticationHandler = new DefaultClientAuthenticationHandler();
57
58 private ResponseErrorHandler responseErrorHandler = new AccessTokenErrorHandler();
59
60 private List<ClientHttpRequestInterceptor> interceptors = new ArrayList<ClientHttpRequestInterceptor>();
61
62 private RequestEnhancer tokenRequestEnhancer = new DefaultRequestEnhancer();
63
64
65
66
67 public void setInterceptors(List<ClientHttpRequestInterceptor> interceptors) {
68 this.interceptors = interceptors;
69 }
70
71
72
73
74
75 public void setTokenRequestEnhancer(RequestEnhancer tokenRequestEnhancer) {
76 this.tokenRequestEnhancer = tokenRequestEnhancer;
77 }
78
79 private ClientHttpRequestFactory requestFactory = new SimpleClientHttpRequestFactory() {
80 @Override
81 protected void prepareConnection(HttpURLConnection connection, String httpMethod)
82 throws IOException {
83 super.prepareConnection(connection, httpMethod);
84 connection.setInstanceFollowRedirects(false);
85 connection.setUseCaches(false);
86 }
87 };
88
89 protected RestOperations getRestTemplate() {
90 if (restTemplate == null) {
91 synchronized (this) {
92 if (restTemplate == null) {
93 RestTemplate restTemplate = new RestTemplate();
94 restTemplate.setErrorHandler(getResponseErrorHandler());
95 restTemplate.setRequestFactory(requestFactory);
96 restTemplate.setInterceptors(interceptors);
97 this.restTemplate = restTemplate;
98 }
99 }
100 }
101 if (messageConverters == null) {
102 setMessageConverters(new RestTemplate().getMessageConverters());
103 }
104 return restTemplate;
105 }
106
107 public void setAuthenticationHandler(ClientAuthenticationHandler authenticationHandler) {
108 this.authenticationHandler = authenticationHandler;
109 }
110
111 public void setMessageConverters(List<HttpMessageConverter<?>> messageConverters) {
112 this.messageConverters = new ArrayList<HttpMessageConverter<?>>(messageConverters);
113 this.messageConverters.add(new FormOAuth2AccessTokenMessageConverter());
114 this.messageConverters.add(new FormOAuth2ExceptionHttpMessageConverter());
115 }
116
117 protected OAuth2AccessToken retrieveToken(AccessTokenRequest request, OAuth2ProtectedResourceDetails resource,
118 MultiValueMap<String, String> form, HttpHeaders headers) throws OAuth2AccessDeniedException {
119
120 try {
121
122 authenticationHandler.authenticateTokenRequest(resource, form, headers);
123
124 tokenRequestEnhancer.enhance(request, resource, form, headers);
125 final AccessTokenRequest copy = request;
126
127 final ResponseExtractor<OAuth2AccessToken> delegate = getResponseExtractor();
128 ResponseExtractor<OAuth2AccessToken> extractor = new ResponseExtractor<OAuth2AccessToken>() {
129 @Override
130 public OAuth2AccessToken extractData(ClientHttpResponse response) throws IOException {
131 if (response.getHeaders().containsKey("Set-Cookie")) {
132 copy.setCookie(response.getHeaders().getFirst("Set-Cookie"));
133 }
134 return delegate.extractData(response);
135 }
136 };
137 return getRestTemplate().execute(getAccessTokenUri(resource, form), getHttpMethod(),
138 getRequestCallback(resource, form, headers), extractor , form.toSingleValueMap());
139
140 }
141 catch (OAuth2Exception oe) {
142 throw new OAuth2AccessDeniedException("Access token denied.", resource, oe);
143 }
144 catch (RestClientException rce) {
145 throw new OAuth2AccessDeniedException("Error requesting access token.", resource, rce);
146 }
147
148 }
149
150 protected HttpMethod getHttpMethod() {
151 return HttpMethod.POST;
152 }
153
154 protected String getAccessTokenUri(OAuth2ProtectedResourceDetails resource, MultiValueMap<String, String> form) {
155
156 String accessTokenUri = resource.getAccessTokenUri();
157
158 if (logger.isDebugEnabled()) {
159 logger.debug("Retrieving token from " + accessTokenUri);
160 }
161
162 StringBuilder builder = new StringBuilder(accessTokenUri);
163
164 if (getHttpMethod() == HttpMethod.GET) {
165 String separator = "?";
166 if (accessTokenUri.contains("?")) {
167 separator = "&";
168 }
169
170 for (String key : form.keySet()) {
171 builder.append(separator);
172 builder.append(key + "={" + key + "}");
173 separator = "&";
174 }
175 }
176
177 return builder.toString();
178
179 }
180
181 protected ResponseErrorHandler getResponseErrorHandler() {
182 return responseErrorHandler;
183 }
184
185
186
187
188 public void setRequestFactory(ClientHttpRequestFactory requestFactory) {
189 Assert.notNull(requestFactory, "'requestFactory' must not be null");
190 this.requestFactory = requestFactory;
191 }
192
193 protected ResponseExtractor<OAuth2AccessToken> getResponseExtractor() {
194 getRestTemplate();
195 return new HttpMessageConverterExtractor<OAuth2AccessToken>(OAuth2AccessToken.class, this.messageConverters);
196 }
197
198 protected RequestCallback getRequestCallback(OAuth2ProtectedResourceDetails resource,
199 MultiValueMap<String, String> form, HttpHeaders headers) {
200 return new OAuth2AuthTokenCallback(form, headers);
201 }
202
203
204
205
206 private class OAuth2AuthTokenCallback implements RequestCallback {
207
208 private final MultiValueMap<String, String> form;
209
210 private final HttpHeaders headers;
211
212 private OAuth2AuthTokenCallback(MultiValueMap<String, String> form, HttpHeaders headers) {
213 this.form = form;
214 this.headers = headers;
215 }
216
217 public void doWithRequest(ClientHttpRequest request) throws IOException {
218 request.getHeaders().putAll(this.headers);
219 request.getHeaders().setAccept(
220 Arrays.asList(MediaType.APPLICATION_JSON, MediaType.APPLICATION_FORM_URLENCODED));
221 if (logger.isDebugEnabled()) {
222 logger.debug("Encoding and sending form: " + form);
223 }
224 FORM_MESSAGE_CONVERTER.write(this.form, MediaType.APPLICATION_FORM_URLENCODED, request);
225 }
226 }
227
228 private class AccessTokenErrorHandler extends DefaultResponseErrorHandler {
229
230 @SuppressWarnings("unchecked")
231 @Override
232 public void handleError(ClientHttpResponse response) throws IOException {
233 for (HttpMessageConverter<?> converter : messageConverters) {
234 if (converter.canRead(OAuth2Exception.class, response.getHeaders().getContentType())) {
235 OAuth2Exception ex;
236 try {
237 ex = ((HttpMessageConverter<OAuth2Exception>) converter).read(OAuth2Exception.class, response);
238 }
239 catch (Exception e) {
240
241 continue;
242 }
243 throw ex;
244 }
245 }
246 super.handleError(response);
247 }
248
249 }
250
251 }