View Javadoc
1   package org.springframework.security.oauth2.client.token;
2   
3   import org.apache.commons.logging.Log;
4   import org.apache.commons.logging.LogFactory;
5   import org.springframework.http.HttpHeaders;
6   import org.springframework.http.HttpMethod;
7   import org.springframework.http.MediaType;
8   import org.springframework.http.client.ClientHttpRequest;
9   import org.springframework.http.client.ClientHttpRequestFactory;
10  import org.springframework.http.client.ClientHttpRequestInterceptor;
11  import org.springframework.http.client.ClientHttpResponse;
12  import org.springframework.http.client.SimpleClientHttpRequestFactory;
13  import org.springframework.http.converter.FormHttpMessageConverter;
14  import org.springframework.http.converter.HttpMessageConverter;
15  import org.springframework.security.oauth2.client.resource.OAuth2AccessDeniedException;
16  import org.springframework.security.oauth2.client.resource.OAuth2ProtectedResourceDetails;
17  import org.springframework.security.oauth2.client.token.auth.ClientAuthenticationHandler;
18  import org.springframework.security.oauth2.client.token.auth.DefaultClientAuthenticationHandler;
19  import org.springframework.security.oauth2.common.OAuth2AccessToken;
20  import org.springframework.security.oauth2.common.exceptions.OAuth2Exception;
21  import org.springframework.security.oauth2.http.converter.FormOAuth2AccessTokenMessageConverter;
22  import org.springframework.security.oauth2.http.converter.FormOAuth2ExceptionHttpMessageConverter;
23  import org.springframework.util.Assert;
24  import org.springframework.util.MultiValueMap;
25  import org.springframework.web.client.DefaultResponseErrorHandler;
26  import org.springframework.web.client.HttpMessageConverterExtractor;
27  import org.springframework.web.client.RequestCallback;
28  import org.springframework.web.client.ResponseErrorHandler;
29  import org.springframework.web.client.ResponseExtractor;
30  import org.springframework.web.client.RestClientException;
31  import org.springframework.web.client.RestOperations;
32  import org.springframework.web.client.RestTemplate;
33  
34  import java.io.IOException;
35  import java.net.HttpURLConnection;
36  import java.util.ArrayList;
37  import java.util.Arrays;
38  import java.util.List;
39  
40  /**
41   * Base support logic for obtaining access tokens.
42   * 
43   * @author Ryan Heaton
44   * @author Dave Syer
45   */
46  public abstract class OAuth2AccessTokenSupport {
47  
48  	protected final Log logger = LogFactory.getLog(getClass());
49  
50  	private static final FormHttpMessageConverter FORM_MESSAGE_CONVERTER = new FormHttpMessageConverter();
51  
52  	private RestOperations restTemplate;
53  
54  	private List<HttpMessageConverter<?>> messageConverters;
55  
56  	private ClientAuthenticationHandler authenticationHandler = new DefaultClientAuthenticationHandler();
57  
58  	private ResponseErrorHandler responseErrorHandler = new AccessTokenErrorHandler();
59  
60  	private List<ClientHttpRequestInterceptor> interceptors = new ArrayList<ClientHttpRequestInterceptor>();
61  	
62  	private RequestEnhancer tokenRequestEnhancer = new DefaultRequestEnhancer();
63  	
64  	/**
65  	 * Sets the request interceptors that this accessor should use.
66  	 */
67  	public void setInterceptors(List<ClientHttpRequestInterceptor> interceptors) {
68  		this.interceptors = interceptors;
69  	}
70  	
71  	/**
72  	 * A custom enhancer for the access token request
73  	 * @param tokenRequestEnhancer
74  	 */
75  	public void setTokenRequestEnhancer(RequestEnhancer tokenRequestEnhancer) {
76  		this.tokenRequestEnhancer = tokenRequestEnhancer;
77  	}
78  
79  	private ClientHttpRequestFactory requestFactory = new SimpleClientHttpRequestFactory() {
80  		@Override
81  		protected void prepareConnection(HttpURLConnection connection, String httpMethod)
82  				throws IOException {
83  			super.prepareConnection(connection, httpMethod);
84  			connection.setInstanceFollowRedirects(false);
85  			connection.setUseCaches(false);
86  		}
87  	};
88  
89  	protected RestOperations getRestTemplate() {
90  		if (restTemplate == null) {
91  			synchronized (this) {
92  				if (restTemplate == null) {
93  					RestTemplate restTemplate = new RestTemplate();
94  					restTemplate.setErrorHandler(getResponseErrorHandler());
95  					restTemplate.setRequestFactory(requestFactory);
96  					restTemplate.setInterceptors(interceptors);
97  					this.restTemplate = restTemplate;
98  				}
99  			}
100 		}
101 		if (messageConverters == null) {
102 			setMessageConverters(new RestTemplate().getMessageConverters());
103 		}
104 		return restTemplate;
105 	}
106 
107 	public void setAuthenticationHandler(ClientAuthenticationHandler authenticationHandler) {
108 		this.authenticationHandler = authenticationHandler;
109 	}
110 
111 	public void setMessageConverters(List<HttpMessageConverter<?>> messageConverters) {
112 		this.messageConverters = new ArrayList<HttpMessageConverter<?>>(messageConverters);
113 		this.messageConverters.add(new FormOAuth2AccessTokenMessageConverter());
114 		this.messageConverters.add(new FormOAuth2ExceptionHttpMessageConverter());
115 	}
116 
117 	protected OAuth2AccessToken retrieveToken(AccessTokenRequest request, OAuth2ProtectedResourceDetails resource,
118 			MultiValueMap<String, String> form, HttpHeaders headers) throws OAuth2AccessDeniedException {
119 
120 		try {
121 			// Prepare headers and form before going into rest template call in case the URI is affected by the result
122 			authenticationHandler.authenticateTokenRequest(resource, form, headers);
123 			// Opportunity to customize form and headers
124 			tokenRequestEnhancer.enhance(request, resource, form, headers);
125 			final AccessTokenRequest copy = request;
126 
127 			final ResponseExtractor<OAuth2AccessToken> delegate = getResponseExtractor();
128 			ResponseExtractor<OAuth2AccessToken> extractor = new ResponseExtractor<OAuth2AccessToken>() {
129 				@Override
130 				public OAuth2AccessToken extractData(ClientHttpResponse response) throws IOException {
131 					if (response.getHeaders().containsKey("Set-Cookie")) {
132 						copy.setCookie(response.getHeaders().getFirst("Set-Cookie"));
133 					}
134 					return delegate.extractData(response);
135 				}
136 			};
137 			return getRestTemplate().execute(getAccessTokenUri(resource, form), getHttpMethod(),
138 					getRequestCallback(resource, form, headers), extractor , form.toSingleValueMap());
139 
140 		}
141 		catch (OAuth2Exception oe) {
142 			throw new OAuth2AccessDeniedException("Access token denied.", resource, oe);
143 		}
144 		catch (RestClientException rce) {
145 			throw new OAuth2AccessDeniedException("Error requesting access token.", resource, rce);
146 		}
147 
148 	}
149 
150 	protected HttpMethod getHttpMethod() {
151 		return HttpMethod.POST;
152 	}
153 
154 	protected String getAccessTokenUri(OAuth2ProtectedResourceDetails resource, MultiValueMap<String, String> form) {
155 
156 		String accessTokenUri = resource.getAccessTokenUri();
157 
158 		if (logger.isDebugEnabled()) {
159 			logger.debug("Retrieving token from " + accessTokenUri);
160 		}
161 
162 		StringBuilder builder = new StringBuilder(accessTokenUri);
163 
164 		if (getHttpMethod() == HttpMethod.GET) {
165 			String separator = "?";
166 			if (accessTokenUri.contains("?")) {
167 				separator = "&";
168 			}
169 
170 			for (String key : form.keySet()) {
171 				builder.append(separator);
172 				builder.append(key + "={" + key + "}");
173 				separator = "&";
174 			}
175 		}
176 
177 		return builder.toString();
178 
179 	}
180 
181 	protected ResponseErrorHandler getResponseErrorHandler() {
182 		return responseErrorHandler;
183 	}
184 
185 	/**
186 	 * Set the request factory that this template uses for obtaining {@link ClientHttpRequest HttpRequests}.
187 	 */
188 	public void setRequestFactory(ClientHttpRequestFactory requestFactory) {
189 		Assert.notNull(requestFactory, "'requestFactory' must not be null");
190 		this.requestFactory = requestFactory;
191 	}
192 
193 	protected ResponseExtractor<OAuth2AccessToken> getResponseExtractor() {
194 		getRestTemplate(); // force initialization
195 		return new HttpMessageConverterExtractor<OAuth2AccessToken>(OAuth2AccessToken.class, this.messageConverters);
196 	}
197 
198 	protected RequestCallback getRequestCallback(OAuth2ProtectedResourceDetails resource,
199 			MultiValueMap<String, String> form, HttpHeaders headers) {
200 		return new OAuth2AuthTokenCallback(form, headers);
201 	}
202 
203 	/**
204 	 * Request callback implementation that writes the given object to the request stream.
205 	 */
206 	private class OAuth2AuthTokenCallback implements RequestCallback {
207 
208 		private final MultiValueMap<String, String> form;
209 
210 		private final HttpHeaders headers;
211 
212 		private OAuth2AuthTokenCallback(MultiValueMap<String, String> form, HttpHeaders headers) {
213 			this.form = form;
214 			this.headers = headers;
215 		}
216 
217 		public void doWithRequest(ClientHttpRequest request) throws IOException {
218 			request.getHeaders().putAll(this.headers);
219 			request.getHeaders().setAccept(
220 					Arrays.asList(MediaType.APPLICATION_JSON, MediaType.APPLICATION_FORM_URLENCODED));
221 			if (logger.isDebugEnabled()) {
222 				logger.debug("Encoding and sending form: " + form);
223 			}
224 			FORM_MESSAGE_CONVERTER.write(this.form, MediaType.APPLICATION_FORM_URLENCODED, request);
225 		}
226 	}
227 
228 	private class AccessTokenErrorHandler extends DefaultResponseErrorHandler {
229 
230 		@SuppressWarnings("unchecked")
231 		@Override
232 		public void handleError(ClientHttpResponse response) throws IOException {
233 			for (HttpMessageConverter<?> converter : messageConverters) {
234 				if (converter.canRead(OAuth2Exception.class, response.getHeaders().getContentType())) {
235 					OAuth2Exception ex;
236 					try {
237 						ex = ((HttpMessageConverter<OAuth2Exception>) converter).read(OAuth2Exception.class, response);
238 					}
239 					catch (Exception e) {
240 						// ignore
241 						continue;
242 					}
243 					throw ex;
244 				}
245 			}
246 			super.handleError(response);
247 		}
248 
249 	}
250 
251 }