1 package org.springframework.security.oauth2.client;
2
3 import java.io.IOException;
4 import java.io.UnsupportedEncodingException;
5 import java.net.URI;
6 import java.net.URISyntaxException;
7 import java.net.URLEncoder;
8 import java.util.Arrays;
9
10 import org.springframework.http.HttpMethod;
11 import org.springframework.http.client.ClientHttpRequest;
12 import org.springframework.security.oauth2.client.http.AccessTokenRequiredException;
13 import org.springframework.security.oauth2.client.http.OAuth2ErrorHandler;
14 import org.springframework.security.oauth2.client.resource.OAuth2AccessDeniedException;
15 import org.springframework.security.oauth2.client.resource.OAuth2ProtectedResourceDetails;
16 import org.springframework.security.oauth2.client.resource.UserRedirectRequiredException;
17 import org.springframework.security.oauth2.client.token.AccessTokenProvider;
18 import org.springframework.security.oauth2.client.token.AccessTokenProviderChain;
19 import org.springframework.security.oauth2.client.token.AccessTokenRequest;
20 import org.springframework.security.oauth2.client.token.grant.client.ClientCredentialsAccessTokenProvider;
21 import org.springframework.security.oauth2.client.token.grant.code.AuthorizationCodeAccessTokenProvider;
22 import org.springframework.security.oauth2.client.token.grant.implicit.ImplicitAccessTokenProvider;
23 import org.springframework.security.oauth2.client.token.grant.password.ResourceOwnerPasswordAccessTokenProvider;
24 import org.springframework.security.oauth2.common.AuthenticationScheme;
25 import org.springframework.security.oauth2.common.OAuth2AccessToken;
26 import org.springframework.security.oauth2.common.exceptions.InvalidTokenException;
27 import org.springframework.web.client.RequestCallback;
28 import org.springframework.web.client.ResponseErrorHandler;
29 import org.springframework.web.client.ResponseExtractor;
30 import org.springframework.web.client.RestClientException;
31 import org.springframework.web.client.RestTemplate;
32
33
34
35
36
37
38
39 public class OAuth2RestTemplate extends RestTemplate implements OAuth2RestOperations {
40
41 private final OAuth2ProtectedResourceDetails resource;
42
43 private AccessTokenProvider accessTokenProvider = new AccessTokenProviderChain(Arrays.<AccessTokenProvider> asList(
44 new AuthorizationCodeAccessTokenProvider(), new ImplicitAccessTokenProvider(),
45 new ResourceOwnerPasswordAccessTokenProvider(), new ClientCredentialsAccessTokenProvider()));
46
47 private OAuth2ClientContext context;
48
49 private boolean retryBadAccessTokens = true;
50
51 private OAuth2RequestAuthenticator authenticator = new DefaultOAuth2RequestAuthenticator();
52
53 public OAuth2RestTemplate(OAuth2ProtectedResourceDetails resource) {
54 this(resource, new DefaultOAuth2ClientContext());
55 }
56
57 public OAuth2RestTemplate(OAuth2ProtectedResourceDetails resource, OAuth2ClientContext context) {
58 super();
59 if (resource == null) {
60 throw new IllegalArgumentException("An OAuth2 resource must be supplied.");
61 }
62
63 this.resource = resource;
64 this.context = context;
65 setErrorHandler(new OAuth2ErrorHandler(resource));
66 }
67
68
69
70
71
72
73
74 public void setAuthenticator(OAuth2RequestAuthenticator authenticator) {
75 this.authenticator = authenticator;
76 }
77
78
79
80
81
82
83
84
85 public void setRetryBadAccessTokens(boolean retryBadAccessTokens) {
86 this.retryBadAccessTokens = retryBadAccessTokens;
87 }
88
89 @Override
90 public void setErrorHandler(ResponseErrorHandler errorHandler) {
91 if (!(errorHandler instanceof OAuth2ErrorHandler)) {
92 errorHandler = new OAuth2ErrorHandler(errorHandler, resource);
93 }
94 super.setErrorHandler(errorHandler);
95 }
96
97 @Override
98 public OAuth2ProtectedResourceDetails getResource() {
99 return resource;
100 }
101
102 @Override
103 protected ClientHttpRequest createRequest(URI uri, HttpMethod method) throws IOException {
104
105 OAuth2AccessToken accessToken = getAccessToken();
106
107 AuthenticationScheme authenticationScheme = resource.getAuthenticationScheme();
108 if (AuthenticationScheme.query.equals(authenticationScheme)
109 || AuthenticationScheme.form.equals(authenticationScheme)) {
110 uri = appendQueryParameter(uri, accessToken);
111 }
112
113 ClientHttpRequest req = super.createRequest(uri, method);
114
115 if (AuthenticationScheme.header.equals(authenticationScheme)) {
116 authenticator.authenticate(resource, getOAuth2ClientContext(), req);
117 }
118 return req;
119
120 }
121
122 @Override
123 protected <T> T doExecute(URI url, HttpMethod method, RequestCallback requestCallback,
124 ResponseExtractor<T> responseExtractor) throws RestClientException {
125 OAuth2AccessToken accessToken = context.getAccessToken();
126 RuntimeException rethrow = null;
127 try {
128 return super.doExecute(url, method, requestCallback, responseExtractor);
129 }
130 catch (AccessTokenRequiredException e) {
131 rethrow = e;
132 }
133 catch (OAuth2AccessDeniedException e) {
134 rethrow = e;
135 }
136 catch (InvalidTokenException e) {
137
138 rethrow = new OAuth2AccessDeniedException("Invalid token for client=" + getClientId());
139 }
140 if (accessToken != null && retryBadAccessTokens) {
141 context.setAccessToken(null);
142 try {
143 return super.doExecute(url, method, requestCallback, responseExtractor);
144 }
145 catch (InvalidTokenException e) {
146
147 rethrow = new OAuth2AccessDeniedException("Invalid token for client=" + getClientId());
148 }
149 }
150 throw rethrow;
151 }
152
153
154
155
156 private String getClientId() {
157 return resource.getClientId();
158 }
159
160
161
162
163
164
165
166
167 public OAuth2AccessToken getAccessToken() throws UserRedirectRequiredException {
168
169 OAuth2AccessToken accessToken = context.getAccessToken();
170
171 if (accessToken == null || accessToken.isExpired()) {
172 try {
173 accessToken = acquireAccessToken(context);
174 }
175 catch (UserRedirectRequiredException e) {
176 context.setAccessToken(null);
177 accessToken = null;
178 String stateKey = e.getStateKey();
179 if (stateKey != null) {
180 Object stateToPreserve = e.getStateToPreserve();
181 if (stateToPreserve == null) {
182 stateToPreserve = "NONE";
183 }
184 context.setPreservedState(stateKey, stateToPreserve);
185 }
186 throw e;
187 }
188 }
189 return accessToken;
190 }
191
192
193
194
195 public OAuth2ClientContext getOAuth2ClientContext() {
196 return context;
197 }
198
199 protected OAuth2AccessToken acquireAccessToken(OAuth2ClientContext oauth2Context)
200 throws UserRedirectRequiredException {
201
202 AccessTokenRequest accessTokenRequest = oauth2Context.getAccessTokenRequest();
203 if (accessTokenRequest == null) {
204 throw new AccessTokenRequiredException(
205 "No OAuth 2 security context has been established. Unable to access resource '"
206 + this.resource.getId() + "'.", resource);
207 }
208
209
210 String stateKey = accessTokenRequest.getStateKey();
211 if (stateKey != null) {
212 accessTokenRequest.setPreservedState(oauth2Context.removePreservedState(stateKey));
213 }
214
215 OAuth2AccessToken existingToken = oauth2Context.getAccessToken();
216 if (existingToken != null) {
217 accessTokenRequest.setExistingToken(existingToken);
218 }
219
220 OAuth2AccessToken accessToken = null;
221 accessToken = accessTokenProvider.obtainAccessToken(resource, accessTokenRequest);
222 if (accessToken == null || accessToken.getValue() == null) {
223 throw new IllegalStateException(
224 "Access token provider returned a null access token, which is illegal according to the contract.");
225 }
226 oauth2Context.setAccessToken(accessToken);
227 return accessToken;
228 }
229
230 protected URI appendQueryParameter(URI uri, OAuth2AccessToken accessToken) {
231
232 try {
233
234
235
236 String query = uri.getRawQuery();
237 String queryFragment = resource.getTokenName() + "=" + URLEncoder.encode(accessToken.getValue(), "UTF-8");
238 if (query == null) {
239 query = queryFragment;
240 }
241 else {
242 query = query + "&" + queryFragment;
243 }
244
245
246
247 URI update = new URI(uri.getScheme(), uri.getUserInfo(), uri.getHost(), uri.getPort(), uri.getPath(), null,
248 null);
249
250 StringBuffer sb = new StringBuffer(update.toString());
251 sb.append("?");
252 sb.append(query);
253 if (uri.getFragment() != null) {
254 sb.append("#");
255 sb.append(uri.getFragment());
256 }
257
258 return new URI(sb.toString());
259
260 }
261 catch (URISyntaxException e) {
262 throw new IllegalArgumentException("Could not parse URI", e);
263 }
264 catch (UnsupportedEncodingException e) {
265 throw new IllegalArgumentException("Could not encode URI", e);
266 }
267
268 }
269
270 public void setAccessTokenProvider(AccessTokenProvider accessTokenProvider) {
271 this.accessTokenProvider = accessTokenProvider;
272 }
273
274 }