1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31 package org.springframework.security.oauth2.client.token.grant.code;
32
33 import java.io.IOException;
34 import java.net.URI;
35 import java.util.Collections;
36 import java.util.Iterator;
37 import java.util.List;
38 import java.util.Map;
39 import java.util.TreeMap;
40
41 import org.springframework.http.HttpHeaders;
42 import org.springframework.http.HttpMethod;
43 import org.springframework.http.HttpStatus;
44 import org.springframework.http.ResponseEntity;
45 import org.springframework.http.client.ClientHttpResponse;
46 import org.springframework.security.access.AccessDeniedException;
47 import org.springframework.security.oauth2.client.filter.state.DefaultStateKeyGenerator;
48 import org.springframework.security.oauth2.client.filter.state.StateKeyGenerator;
49 import org.springframework.security.oauth2.client.resource.OAuth2AccessDeniedException;
50 import org.springframework.security.oauth2.client.resource.OAuth2ProtectedResourceDetails;
51 import org.springframework.security.oauth2.client.resource.UserApprovalRequiredException;
52 import org.springframework.security.oauth2.client.resource.UserRedirectRequiredException;
53 import org.springframework.security.oauth2.client.token.AccessTokenProvider;
54 import org.springframework.security.oauth2.client.token.AccessTokenRequest;
55 import org.springframework.security.oauth2.client.token.DefaultRequestEnhancer;
56 import org.springframework.security.oauth2.client.token.OAuth2AccessTokenSupport;
57 import org.springframework.security.oauth2.client.token.RequestEnhancer;
58 import org.springframework.security.oauth2.common.OAuth2AccessToken;
59 import org.springframework.security.oauth2.common.OAuth2RefreshToken;
60 import org.springframework.security.oauth2.common.exceptions.InvalidRequestException;
61 import org.springframework.security.oauth2.common.util.OAuth2Utils;
62 import org.springframework.util.LinkedMultiValueMap;
63 import org.springframework.util.MultiValueMap;
64 import org.springframework.web.client.ResponseExtractor;
65
66
67
68
69
70
71
72 public class AuthorizationCodeAccessTokenProvider extends OAuth2AccessTokenSupport implements AccessTokenProvider {
73
74 private StateKeyGenerator stateKeyGenerator = new DefaultStateKeyGenerator();
75
76 private String scopePrefix = OAuth2Utils.SCOPE_PREFIX;
77
78 private RequestEnhancer authorizationRequestEnhancer = new DefaultRequestEnhancer();
79
80 private boolean stateMandatory = true;
81
82
83
84
85
86
87 public void setStateMandatory(boolean stateMandatory) {
88 this.stateMandatory = stateMandatory;
89 }
90
91
92
93
94
95 public void setAuthorizationRequestEnhancer(RequestEnhancer authorizationRequestEnhancer) {
96 this.authorizationRequestEnhancer = authorizationRequestEnhancer;
97 }
98
99
100
101
102
103
104 public void setScopePrefix(String scopePrefix) {
105 this.scopePrefix = scopePrefix;
106 }
107
108
109
110
111 public void setStateKeyGenerator(StateKeyGenerator stateKeyGenerator) {
112 this.stateKeyGenerator = stateKeyGenerator;
113 }
114
115 public boolean supportsResource(OAuth2ProtectedResourceDetails resource) {
116 return resource instanceof AuthorizationCodeResourceDetails
117 && "authorization_code".equals(resource.getGrantType());
118 }
119
120 public boolean supportsRefresh(OAuth2ProtectedResourceDetails resource) {
121 return supportsResource(resource);
122 }
123
124 public String obtainAuthorizationCode(OAuth2ProtectedResourceDetails details, AccessTokenRequest request)
125 throws UserRedirectRequiredException, UserApprovalRequiredException, AccessDeniedException,
126 OAuth2AccessDeniedException {
127
128 AuthorizationCodeResourceDetails resource = (AuthorizationCodeResourceDetails) details;
129
130 HttpHeaders headers = getHeadersForAuthorizationRequest(request);
131 MultiValueMap<String, String> form = new LinkedMultiValueMap<String, String>();
132 if (request.containsKey(OAuth2Utils.USER_OAUTH_APPROVAL)) {
133 form.set(OAuth2Utils.USER_OAUTH_APPROVAL, request.getFirst(OAuth2Utils.USER_OAUTH_APPROVAL));
134 for (String scope : details.getScope()) {
135 form.set(scopePrefix + scope, request.getFirst(OAuth2Utils.USER_OAUTH_APPROVAL));
136 }
137 }
138 else {
139 form.putAll(getParametersForAuthorizeRequest(resource, request));
140 }
141 authorizationRequestEnhancer.enhance(request, resource, form, headers);
142 final AccessTokenRequest copy = request;
143
144 final ResponseExtractor<ResponseEntity<Void>> delegate = getAuthorizationResponseExtractor();
145 ResponseExtractor<ResponseEntity<Void>> extractor = new ResponseExtractor<ResponseEntity<Void>>() {
146 @Override
147 public ResponseEntity<Void> extractData(ClientHttpResponse response) throws IOException {
148 if (response.getHeaders().containsKey("Set-Cookie")) {
149 copy.setCookie(response.getHeaders().getFirst("Set-Cookie"));
150 }
151 return delegate.extractData(response);
152 }
153 };
154
155
156 ResponseEntity<Void> response = getRestTemplate().execute(resource.getUserAuthorizationUri(), HttpMethod.POST,
157 getRequestCallback(resource, form, headers), extractor, form.toSingleValueMap());
158
159 if (response.getStatusCode() == HttpStatus.OK) {
160
161 throw getUserApprovalSignal(resource, request);
162 }
163
164 URI location = response.getHeaders().getLocation();
165 String query = location.getQuery();
166 Map<String, String> map = OAuth2Utils.extractMap(query);
167 if (map.containsKey("state")) {
168 request.setStateKey(map.get("state"));
169 if (request.getPreservedState() == null) {
170 String redirectUri = resource.getRedirectUri(request);
171 if (redirectUri != null) {
172 request.setPreservedState(redirectUri);
173 }
174 else {
175 request.setPreservedState(new Object());
176 }
177 }
178 }
179
180 String code = map.get("code");
181 if (code == null) {
182 throw new UserRedirectRequiredException(location.toString(), form.toSingleValueMap());
183 }
184 request.set("code", code);
185 return code;
186
187 }
188
189 protected ResponseExtractor<ResponseEntity<Void>> getAuthorizationResponseExtractor() {
190 return new ResponseExtractor<ResponseEntity<Void>>() {
191 public ResponseEntity<Void> extractData(ClientHttpResponse response) throws IOException {
192 return new ResponseEntity<Void>(response.getHeaders(), response.getStatusCode());
193 }
194 };
195 }
196
197 public OAuth2AccessToken obtainAccessToken(OAuth2ProtectedResourceDetails details, AccessTokenRequest request)
198 throws UserRedirectRequiredException, UserApprovalRequiredException, AccessDeniedException,
199 OAuth2AccessDeniedException {
200
201 AuthorizationCodeResourceDetails resource = (AuthorizationCodeResourceDetails) details;
202
203 if (request.getAuthorizationCode() == null) {
204 if (request.getStateKey() == null) {
205 throw getRedirectForAuthorization(resource, request);
206 }
207 obtainAuthorizationCode(resource, request);
208 }
209 return retrieveToken(request, resource, getParametersForTokenRequest(resource, request),
210 getHeadersForTokenRequest(request));
211
212 }
213
214 public OAuth2AccessToken refreshAccessToken(OAuth2ProtectedResourceDetails resource,
215 OAuth2RefreshToken refreshToken, AccessTokenRequest request) throws UserRedirectRequiredException,
216 OAuth2AccessDeniedException {
217 MultiValueMap<String, String> form = new LinkedMultiValueMap<String, String>();
218 form.add("grant_type", "refresh_token");
219 form.add("refresh_token", refreshToken.getValue());
220 try {
221 return retrieveToken(request, resource, form, getHeadersForTokenRequest(request));
222 }
223 catch (OAuth2AccessDeniedException e) {
224 throw getRedirectForAuthorization((AuthorizationCodeResourceDetails) resource, request);
225 }
226 }
227
228 private HttpHeaders getHeadersForTokenRequest(AccessTokenRequest request) {
229 HttpHeaders headers = new HttpHeaders();
230
231 return headers;
232 }
233
234 private HttpHeaders getHeadersForAuthorizationRequest(AccessTokenRequest request) {
235 HttpHeaders headers = new HttpHeaders();
236 headers.putAll(request.getHeaders());
237 if (request.getCookie() != null) {
238 headers.set("Cookie", request.getCookie());
239 }
240 return headers;
241 }
242
243 private MultiValueMap<String, String> getParametersForTokenRequest(AuthorizationCodeResourceDetails resource,
244 AccessTokenRequest request) {
245
246 MultiValueMap<String, String> form = new LinkedMultiValueMap<String, String>();
247 form.set("grant_type", "authorization_code");
248 form.set("code", request.getAuthorizationCode());
249
250 Object preservedState = request.getPreservedState();
251 if (request.getStateKey() != null || stateMandatory) {
252
253
254 if (preservedState == null) {
255 throw new InvalidRequestException(
256 "Possible CSRF detected - state parameter was required but no state could be found");
257 }
258 }
259
260
261
262 String redirectUri = null;
263
264 if (preservedState instanceof String) {
265
266
267 redirectUri = String.valueOf(preservedState);
268 }
269 else {
270 redirectUri = resource.getRedirectUri(request);
271 }
272
273 if (redirectUri != null && !"NONE".equals(redirectUri)) {
274 form.set("redirect_uri", redirectUri);
275 }
276
277 return form;
278
279 }
280
281 private MultiValueMap<String, String> getParametersForAuthorizeRequest(AuthorizationCodeResourceDetails resource,
282 AccessTokenRequest request) {
283
284 MultiValueMap<String, String> form = new LinkedMultiValueMap<String, String>();
285 form.set("response_type", "code");
286 form.set("client_id", resource.getClientId());
287
288 if (request.get("scope") != null) {
289 form.set("scope", request.getFirst("scope"));
290 }
291 else {
292 form.set("scope", OAuth2Utils.formatParameterList(resource.getScope()));
293 }
294
295
296
297 String redirectUri = resource.getPreEstablishedRedirectUri();
298
299 Object preservedState = request.getPreservedState();
300 if (redirectUri == null && preservedState != null) {
301
302
303 redirectUri = String.valueOf(preservedState);
304 }
305 else {
306 redirectUri = request.getCurrentUri();
307 }
308
309 String stateKey = request.getStateKey();
310 if (stateKey != null) {
311 form.set("state", stateKey);
312 if (preservedState == null) {
313 throw new InvalidRequestException(
314 "Possible CSRF detected - state parameter was present but no state could be found");
315 }
316 }
317
318 if (redirectUri != null) {
319 form.set("redirect_uri", redirectUri);
320 }
321
322 return form;
323
324 }
325
326 private UserRedirectRequiredException getRedirectForAuthorization(AuthorizationCodeResourceDetails resource,
327 AccessTokenRequest request) {
328
329
330 TreeMap<String, String> requestParameters = new TreeMap<String, String>();
331 requestParameters.put("response_type", "code");
332 requestParameters.put("client_id", resource.getClientId());
333
334
335 String redirectUri = resource.getRedirectUri(request);
336 if (redirectUri != null) {
337 requestParameters.put("redirect_uri", redirectUri);
338 }
339
340 if (resource.isScoped()) {
341
342 StringBuilder builder = new StringBuilder();
343 List<String> scope = resource.getScope();
344
345 if (scope != null) {
346 Iterator<String> scopeIt = scope.iterator();
347 while (scopeIt.hasNext()) {
348 builder.append(scopeIt.next());
349 if (scopeIt.hasNext()) {
350 builder.append(' ');
351 }
352 }
353 }
354
355 requestParameters.put("scope", builder.toString());
356 }
357
358 UserRedirectRequiredException redirectException = new UserRedirectRequiredException(
359 resource.getUserAuthorizationUri(), requestParameters);
360
361 String stateKey = stateKeyGenerator.generateKey(resource);
362 redirectException.setStateKey(stateKey);
363 request.setStateKey(stateKey);
364 redirectException.setStateToPreserve(redirectUri);
365 request.setPreservedState(redirectUri);
366
367 return redirectException;
368
369 }
370
371 protected UserApprovalRequiredException getUserApprovalSignal(AuthorizationCodeResourceDetails resource,
372 AccessTokenRequest request) {
373 String message = String.format("Do you approve the client '%s' to access your resources with scope=%s",
374 resource.getClientId(), resource.getScope());
375 return new UserApprovalRequiredException(resource.getUserAuthorizationUri(), Collections.singletonMap(
376 OAuth2Utils.USER_OAUTH_APPROVAL, message), resource.getClientId(), resource.getScope());
377 }
378
379 }