View Javadoc
1   package org.springframework.security.oauth2.client;
2   
3   import java.io.IOException;
4   import java.io.UnsupportedEncodingException;
5   import java.net.URI;
6   import java.net.URISyntaxException;
7   import java.net.URLEncoder;
8   import java.util.Arrays;
9   
10  import org.springframework.http.HttpMethod;
11  import org.springframework.http.client.ClientHttpRequest;
12  import org.springframework.security.oauth2.client.http.AccessTokenRequiredException;
13  import org.springframework.security.oauth2.client.http.OAuth2ErrorHandler;
14  import org.springframework.security.oauth2.client.resource.OAuth2AccessDeniedException;
15  import org.springframework.security.oauth2.client.resource.OAuth2ProtectedResourceDetails;
16  import org.springframework.security.oauth2.client.resource.UserRedirectRequiredException;
17  import org.springframework.security.oauth2.client.token.AccessTokenProvider;
18  import org.springframework.security.oauth2.client.token.AccessTokenProviderChain;
19  import org.springframework.security.oauth2.client.token.AccessTokenRequest;
20  import org.springframework.security.oauth2.client.token.grant.client.ClientCredentialsAccessTokenProvider;
21  import org.springframework.security.oauth2.client.token.grant.code.AuthorizationCodeAccessTokenProvider;
22  import org.springframework.security.oauth2.client.token.grant.implicit.ImplicitAccessTokenProvider;
23  import org.springframework.security.oauth2.client.token.grant.password.ResourceOwnerPasswordAccessTokenProvider;
24  import org.springframework.security.oauth2.common.AuthenticationScheme;
25  import org.springframework.security.oauth2.common.OAuth2AccessToken;
26  import org.springframework.security.oauth2.common.exceptions.InvalidTokenException;
27  import org.springframework.web.client.RequestCallback;
28  import org.springframework.web.client.ResponseErrorHandler;
29  import org.springframework.web.client.ResponseExtractor;
30  import org.springframework.web.client.RestClientException;
31  import org.springframework.web.client.RestTemplate;
32  
33  /**
34   * Rest template that is able to make OAuth2-authenticated REST requests with the credentials of the provided resource.
35   * 
36   * @author Ryan Heaton
37   * @author Dave Syer
38   */
39  public class OAuth2RestTemplate extends RestTemplate implements OAuth2RestOperations {
40  
41  	private final OAuth2ProtectedResourceDetails resource;
42  
43  	private AccessTokenProvider accessTokenProvider = new AccessTokenProviderChain(Arrays.<AccessTokenProvider> asList(
44  			new AuthorizationCodeAccessTokenProvider(), new ImplicitAccessTokenProvider(),
45  			new ResourceOwnerPasswordAccessTokenProvider(), new ClientCredentialsAccessTokenProvider()));
46  
47  	private OAuth2ClientContext context;
48  
49  	private boolean retryBadAccessTokens = true;
50  
51  	private OAuth2RequestAuthenticator authenticator = new DefaultOAuth2RequestAuthenticator();
52  
53  	public OAuth2RestTemplate(OAuth2ProtectedResourceDetails resource) {
54  		this(resource, new DefaultOAuth2ClientContext());
55  	}
56  
57  	public OAuth2RestTemplate(OAuth2ProtectedResourceDetails resource, OAuth2ClientContext context) {
58  		super();
59  		if (resource == null) {
60  			throw new IllegalArgumentException("An OAuth2 resource must be supplied.");
61  		}
62  
63  		this.resource = resource;
64  		this.context = context;
65  		setErrorHandler(new OAuth2ErrorHandler(resource));
66  	}
67  
68  	/**
69  	 * Strategy for extracting an Authorization header from an access token and the request details. Defaults to the
70  	 * simple form "TOKEN_TYPE TOKEN_VALUE".
71  	 * 
72  	 * @param authenticator the authenticator to use
73  	 */
74  	public void setAuthenticator(OAuth2RequestAuthenticator authenticator) {
75  		this.authenticator = authenticator;
76  	}
77  
78  	/**
79  	 * Flag to determine whether a request that has an existing access token, and which then leads to an
80  	 * AccessTokenRequiredException should be retried (immediately, once). Useful if the remote server doesn't recognize
81  	 * an old token which is stored in the client, but is happy to re-grant it.
82  	 * 
83  	 * @param retryBadAccessTokens the flag to set (default true)
84  	 */
85  	public void setRetryBadAccessTokens(boolean retryBadAccessTokens) {
86  		this.retryBadAccessTokens = retryBadAccessTokens;
87  	}
88  
89  	@Override
90  	public void setErrorHandler(ResponseErrorHandler errorHandler) {
91  		if (!(errorHandler instanceof OAuth2ErrorHandler)) {
92  			errorHandler = new OAuth2ErrorHandler(errorHandler, resource);
93  		}
94  		super.setErrorHandler(errorHandler);
95  	}
96  	
97  	@Override
98  	public OAuth2ProtectedResourceDetails getResource() {
99  		return resource;
100 	}
101 
102 	@Override
103 	protected ClientHttpRequest createRequest(URI uri, HttpMethod method) throws IOException {
104 
105 		OAuth2AccessToken accessToken = getAccessToken();
106 
107 		AuthenticationScheme authenticationScheme = resource.getAuthenticationScheme();
108 		if (AuthenticationScheme.query.equals(authenticationScheme)
109 				|| AuthenticationScheme.form.equals(authenticationScheme)) {
110 			uri = appendQueryParameter(uri, accessToken);
111 		}
112 
113 		ClientHttpRequest req = super.createRequest(uri, method);
114 
115 		if (AuthenticationScheme.header.equals(authenticationScheme)) {
116 			authenticator.authenticate(resource, getOAuth2ClientContext(), req);
117 		}
118 		return req;
119 
120 	}
121 
122 	@Override
123 	protected <T> T doExecute(URI url, HttpMethod method, RequestCallback requestCallback,
124 			ResponseExtractor<T> responseExtractor) throws RestClientException {
125 		OAuth2AccessToken accessToken = context.getAccessToken();
126 		RuntimeException rethrow = null;
127 		try {
128 			return super.doExecute(url, method, requestCallback, responseExtractor);
129 		}
130 		catch (AccessTokenRequiredException e) {
131 			rethrow = e;
132 		}
133 		catch (OAuth2AccessDeniedException e) {
134 			rethrow = e;
135 		}
136 		catch (InvalidTokenException e) {
137 			// Don't reveal the token value in case it is logged
138 			rethrow = new OAuth2AccessDeniedException("Invalid token for client=" + getClientId());
139 		}
140 		if (accessToken != null && retryBadAccessTokens) {
141 			context.setAccessToken(null);
142 			try {
143 				return super.doExecute(url, method, requestCallback, responseExtractor);
144 			}
145 			catch (InvalidTokenException e) {
146 				// Don't reveal the token value in case it is logged
147 				rethrow = new OAuth2AccessDeniedException("Invalid token for client=" + getClientId());
148 			}
149 		}
150 		throw rethrow;
151 	}
152 
153 	/**
154 	 * @return the client id for this resource.
155 	 */
156 	private String getClientId() {
157 		return resource.getClientId();
158 	}
159 
160 	/**
161 	 * Acquire or renew an access token for the current context if necessary. This method will be called automatically
162 	 * when a request is executed (and the result is cached), but can also be called as a standalone method to
163 	 * pre-populate the token.
164 	 * 
165 	 * @return an access token
166 	 */
167 	public OAuth2AccessToken getAccessToken() throws UserRedirectRequiredException {
168 
169 		OAuth2AccessToken accessToken = context.getAccessToken();
170 
171 		if (accessToken == null || accessToken.isExpired()) {
172 			try {
173 				accessToken = acquireAccessToken(context);
174 			}
175 			catch (UserRedirectRequiredException e) {
176 				context.setAccessToken(null); // No point hanging onto it now
177 				accessToken = null;
178 				String stateKey = e.getStateKey();
179 				if (stateKey != null) {
180 					Object stateToPreserve = e.getStateToPreserve();
181 					if (stateToPreserve == null) {
182 						stateToPreserve = "NONE";
183 					}
184 					context.setPreservedState(stateKey, stateToPreserve);
185 				}
186 				throw e;
187 			}
188 		}
189 		return accessToken;
190 	}
191 
192 	/**
193 	 * @return the context for this template
194 	 */
195 	public OAuth2ClientContext getOAuth2ClientContext() {
196 		return context;
197 	}
198 
199 	protected OAuth2AccessToken acquireAccessToken(OAuth2ClientContext oauth2Context)
200 			throws UserRedirectRequiredException {
201 
202 		AccessTokenRequest accessTokenRequest = oauth2Context.getAccessTokenRequest();
203 		if (accessTokenRequest == null) {
204 			throw new AccessTokenRequiredException(
205 					"No OAuth 2 security context has been established. Unable to access resource '"
206 							+ this.resource.getId() + "'.", resource);
207 		}
208 
209 		// Transfer the preserved state from the (longer lived) context to the current request.
210 		String stateKey = accessTokenRequest.getStateKey();
211 		if (stateKey != null) {
212 			accessTokenRequest.setPreservedState(oauth2Context.removePreservedState(stateKey));
213 		}
214 
215 		OAuth2AccessToken existingToken = oauth2Context.getAccessToken();
216 		if (existingToken != null) {
217 			accessTokenRequest.setExistingToken(existingToken);
218 		}
219 
220 		OAuth2AccessToken accessToken = null;
221 		accessToken = accessTokenProvider.obtainAccessToken(resource, accessTokenRequest);
222 		if (accessToken == null || accessToken.getValue() == null) {
223 			throw new IllegalStateException(
224 					"Access token provider returned a null access token, which is illegal according to the contract.");
225 		}
226 		oauth2Context.setAccessToken(accessToken);
227 		return accessToken;
228 	}
229 
230 	protected URI appendQueryParameter(URI uri, OAuth2AccessToken accessToken) {
231 
232 		try {
233 
234 			// TODO: there is some duplication with UriUtils here. Probably unavoidable as long as this
235 			// method signature uses URI not String.
236 			String query = uri.getRawQuery(); // Don't decode anything here
237 			String queryFragment = resource.getTokenName() + "=" + URLEncoder.encode(accessToken.getValue(), "UTF-8");
238 			if (query == null) {
239 				query = queryFragment;
240 			}
241 			else {
242 				query = query + "&" + queryFragment;
243 			}
244 
245 			// first form the URI without query and fragment parts, so that it doesn't re-encode some query string chars
246 			// (SECOAUTH-90)
247 			URI update = new URI(uri.getScheme(), uri.getUserInfo(), uri.getHost(), uri.getPort(), uri.getPath(), null,
248 					null);
249 			// now add the encoded query string and the then fragment
250 			StringBuffer sb = new StringBuffer(update.toString());
251 			sb.append("?");
252 			sb.append(query);
253 			if (uri.getFragment() != null) {
254 				sb.append("#");
255 				sb.append(uri.getFragment());
256 			}
257 
258 			return new URI(sb.toString());
259 
260 		}
261 		catch (URISyntaxException e) {
262 			throw new IllegalArgumentException("Could not parse URI", e);
263 		}
264 		catch (UnsupportedEncodingException e) {
265 			throw new IllegalArgumentException("Could not encode URI", e);
266 		}
267 
268 	}
269 
270 	public void setAccessTokenProvider(AccessTokenProvider accessTokenProvider) {
271 		this.accessTokenProvider = accessTokenProvider;
272 	}
273 
274 }