View Javadoc
1   /*
2    * Copyright 2006-2010 the original author or authors.
3    *
4    * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
5    * the License. You may obtain a copy of the License at
6    *
7    * https://www.apache.org/licenses/LICENSE-2.0
8    *
9    * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
10   * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
11   * specific language governing permissions and limitations under the License.
12   */
13  package org.springframework.security.oauth2.client.test;
14  
15  import java.io.IOException;
16  import java.lang.reflect.Constructor;
17  import java.net.HttpURLConnection;
18  import java.net.URI;
19  import java.util.Arrays;
20  import java.util.LinkedHashMap;
21  import java.util.List;
22  import java.util.Map;
23  
24  import org.apache.commons.logging.Log;
25  import org.apache.commons.logging.LogFactory;
26  import org.apache.http.client.config.CookieSpecs;
27  import org.apache.http.client.config.RequestConfig;
28  import org.apache.http.client.config.RequestConfig.Builder;
29  import org.apache.http.client.protocol.HttpClientContext;
30  import org.apache.http.protocol.HttpContext;
31  import org.hamcrest.CoreMatchers;
32  import org.junit.Assert;
33  import org.junit.internal.AssumptionViolatedException;
34  import org.junit.internal.runners.statements.RunBefores;
35  import org.junit.rules.TestWatchman;
36  import org.junit.runners.model.FrameworkMethod;
37  import org.junit.runners.model.Statement;
38  import org.junit.runners.model.TestClass;
39  import org.springframework.beans.BeanUtils;
40  import org.springframework.core.env.Environment;
41  import org.springframework.http.HttpMethod;
42  import org.springframework.http.client.ClientHttpResponse;
43  import org.springframework.http.client.HttpComponentsClientHttpRequestFactory;
44  import org.springframework.http.client.SimpleClientHttpRequestFactory;
45  import org.springframework.security.oauth2.client.DefaultOAuth2ClientContext;
46  import org.springframework.security.oauth2.client.OAuth2ClientContext;
47  import org.springframework.security.oauth2.client.OAuth2RestTemplate;
48  import org.springframework.security.oauth2.client.resource.OAuth2AccessDeniedException;
49  import org.springframework.security.oauth2.client.resource.OAuth2ProtectedResourceDetails;
50  import org.springframework.security.oauth2.client.token.AccessTokenProvider;
51  import org.springframework.security.oauth2.client.token.AccessTokenRequest;
52  import org.springframework.security.oauth2.client.token.DefaultAccessTokenRequest;
53  import org.springframework.security.oauth2.common.OAuth2AccessToken;
54  import org.springframework.util.ClassUtils;
55  import org.springframework.web.client.DefaultResponseErrorHandler;
56  import org.springframework.web.client.RestOperations;
57  
58  /**
59   * <p>
60   * A rule that sets up an OAuth2 context for tests and makes the access token available
61   * inside a test method. In combination with the {@link OAuth2ContextConfiguration}
62   * annotation provides a number of different strategies for configuring an
63   * {@link OAuth2ProtectedResourceDetails} instance that will be used to create the OAuth2
64   * context for tests. Example:
65   * </p>
66   * 
67   * <pre>
68   * &#064;OAuth2ContextConfiguration(ResourceOwnerPasswordProtectedResourceDetails.class)
69   * public class MyIntegrationTests implements RestTemplateHolder {
70   * 
71   * 	&#064;Rule
72   * 	public OAuth2ContextSetup context = OAuth2ContextSetup.withEnvironment(this,
73   * 			TestEnvironment.instance());
74   * 
75   * 	&#064;Test
76   * 	public void testSomethingWithClientCredentials() {
77   * 		// This call will be authenticated with the client credentials in
78   * 		// MyClientDetailsResource
79   * 		getRestTemplate().getForObject(&quot;https://myserver/resource&quot;, String.class);
80   * 	}
81   * 
82   * 	// This class is used to initialize the OAuth2 context for the test methods.
83   * 	static class MyClientDetailsResource extends
84   * 			ResourceOwnerPasswordProtectedResourceDetails {
85   * 		public MyClientDetailsResource(Environment environment) {
86   *             ... do stuff with environment to initialize the password credentials
87   *         }
88   * 	}
89   * 
90   * }
91   * </pre>
92   * 
93   * @see OAuth2ContextConfiguration
94   * @see BeforeOAuth2Context
95   * 
96   * @author Dave Syer
97   * 
98   */
99  @SuppressWarnings("deprecation")
100 public class OAuth2ContextSetup extends TestWatchman {
101 
102 	private static Log logger = LogFactory.getLog(OAuth2ContextSetup.class);
103 
104 	private OAuth2ProtectedResourceDetails resource;
105 
106 	private OAuth2RestTemplate client;
107 
108 	private Map<String, String> parameters = new LinkedHashMap<String, String>();
109 
110 	private final RestTemplateHolder clientHolder;
111 
112 	private final TestAccounts testAccounts;
113 
114 	private OAuth2AccessToken accessToken;
115 
116 	private boolean initializeAccessToken = true;
117 
118 	private RestOperations savedClient;
119 
120 	private AccessTokenProvider accessTokenProvider;
121 
122 	private final Environment environment;
123 
124 	/**
125 	 * Create a new client that can inject an Environment into its protected resource
126 	 * details.
127 	 * 
128 	 * @param clientHolder receives an OAuth2RestTemplate with the authenticated client
129 	 * for the duration of a test
130 	 * @param environment a Spring Environment that can be used to initialize the client
131 	 * 
132 	 * @return a rule that wraps test methods in an OAuth2 context
133 	 */
134 	public static OAuth2ContextSetup withEnvironment(RestTemplateHolder clientHolder,
135 			Environment environment) {
136 		return new OAuth2ContextSetup(clientHolder, null, environment);
137 	}
138 
139 	/**
140 	 * Create a new client that can inject a {@link TestAccounts} instance into its
141 	 * protected resource details.
142 	 * 
143 	 * @param clientHolder receives an OAuth2RestTemplate with the authenticated client
144 	 * for the duration of a test
145 	 * @param testAccounts a test account generator that can be used to initialize the
146 	 * client
147 	 * 
148 	 * @return a rule that wraps test methods in an OAuth2 context
149 	 */
150 	public static OAuth2ContextSetup withTestAccounts(RestTemplateHolder clientHolder,
151 			TestAccounts testAccounts) {
152 		return new OAuth2ContextSetup(clientHolder, testAccounts, null);
153 	}
154 
155 	/**
156 	 * Create a new client that knows how to create its protected resource with no
157 	 * externalization help. Typically it will use resource details which accept an
158 	 * instance of the current test case (downcasting it from Object). For example
159 	 * 
160 	 * <pre>
161 	 * static class MyClientDetailsResource extends ClientCredentialsProtectedResourceDetails {
162 	 * 	public MyClientDetailsResource(Object target) {
163 	 *             MyIntegrationTests test = (MyIntegrationTests) target;
164 	 *             ... do stuff with test instance to initialize the client credentials
165 	 *         }
166 	 * }
167 	 * </pre>
168 	 * 
169 	 * @param clientHolder receives an OAuth2RestTemplate with the authenticated client
170 	 * for the duration of a test
171 	 * 
172 	 * @return a rule that wraps test methods in an OAuth2 context
173 	 */
174 	public static OAuth2ContextSetup standard(RestTemplateHolder clientHolder) {
175 		return new OAuth2ContextSetup(clientHolder, null, null);
176 	}
177 
178 	private OAuth2ContextSetup(RestTemplateHolder clientHolder,
179 			TestAccounts testAccounts, Environment environment) {
180 		this.clientHolder = clientHolder;
181 		this.testAccounts = testAccounts;
182 		this.environment = environment;
183 	}
184 
185 	@Override
186 	public Statement apply(Statement base, FrameworkMethod method, Object target) {
187 		initializeIfNecessary(method, target);
188 		return super.apply(base, method, target);
189 	}
190 
191 	@Override
192 	public void starting(FrameworkMethod method) {
193 		if (resource != null) {
194 			logger.info("Starting OAuth2 context for: " + resource);
195 			AccessTokenRequest request = new DefaultAccessTokenRequest();
196 			request.setAll(parameters);
197 			client = createRestTemplate(resource, request);
198 			if (initializeAccessToken) {
199 				this.accessToken = null;
200 				this.accessToken = getAccessToken();
201 			}
202 			savedClient = clientHolder.getRestTemplate();
203 			clientHolder.setRestTemplate(client);
204 		}
205 	}
206 
207 	@Override
208 	public void finished(FrameworkMethod method) {
209 		if (resource != null) {
210 			logger.info("Ending OAuth2 context for: " + resource);
211 			if (savedClient != null) {
212 				clientHolder.setRestTemplate(savedClient);
213 			}
214 		}
215 	}
216 
217 	public void setAccessTokenProvider(AccessTokenProvider accessTokenProvider) {
218 		this.accessTokenProvider = accessTokenProvider;
219 	}
220 
221 	public void setParameters(Map<String, String> parameters) {
222 		this.parameters = parameters;
223 	}
224 
225 	/**
226 	 * Get the current access token. Should be available inside a test method as long as a
227 	 * resource has been setup with {@link OAuth2ContextConfiguration
228 	 * &#64;OAuth2ContextConfiguration}.
229 	 * 
230 	 * @return the current access token initializing it if necessary
231 	 */
232 	public OAuth2AccessToken getAccessToken() {
233 		if (resource == null || client == null) {
234 			return null;
235 		}
236 		if (accessToken != null) {
237 			return accessToken;
238 		}
239 		if (accessTokenProvider != null) {
240 			client.setAccessTokenProvider(accessTokenProvider);
241 		}
242 		try {
243 			return client.getAccessToken();
244 		}
245 		catch (OAuth2AccessDeniedException e) {
246 			Throwable cause = e.getCause();
247 			if (cause instanceof RuntimeException) {
248 				throw (RuntimeException) cause;
249 			}
250 			if (cause instanceof Error) {
251 				throw (Error) cause;
252 			}
253 			throw e;
254 		}
255 	}
256 
257 	/**
258 	 * @return the client template
259 	 */
260 	public OAuth2RestTemplate getRestTemplate() {
261 		return client;
262 	}
263 
264 	/**
265 	 * @return the current client resource details
266 	 */
267 	public OAuth2ProtectedResourceDetails getResource() {
268 		return resource;
269 	}
270 
271 	/**
272 	 * @return the current access token request
273 	 */
274 	public AccessTokenRequest getAccessTokenRequest() {
275 		return client.getOAuth2ClientContext().getAccessTokenRequest();
276 	}
277 
278 	/**
279 	 * @return the current OAuth2 context
280 	 */
281 	public OAuth2ClientContext getOAuth2ClientContext() {
282 		return client.getOAuth2ClientContext();
283 	}
284 
285 	private void initializeIfNecessary(FrameworkMethod method, final Object target) {
286 
287 		final TestClass testClass = new TestClass(target.getClass());
288 		OAuth2ContextConfiguration contextConfiguration = findOAuthContextConfiguration(
289 				method, testClass);
290 		if (contextConfiguration == null) {
291 			// Nothing to do
292 			return;
293 		}
294 
295 		this.initializeAccessToken = contextConfiguration.initialize();
296 
297 		this.resource = creatResource(target, contextConfiguration);
298 
299 		final List<FrameworkMethod> befores = testClass
300 				.getAnnotatedMethods(BeforeOAuth2Context.class);
301 		if (!befores.isEmpty()) {
302 
303 			logger.debug("Running @BeforeOAuth2Context methods");
304 
305 			for (FrameworkMethod before : befores) {
306 
307 				RestOperations savedServerClient = clientHolder.getRestTemplate();
308 
309 				OAuth2ContextConfiguration beforeConfiguration = findOAuthContextConfiguration(
310 						before, testClass);
311 				if (beforeConfiguration != null) {
312 
313 					OAuth2ProtectedResourceDetails resource = creatResource(target,
314 							beforeConfiguration);
315 					AccessTokenRequest beforeRequest = new DefaultAccessTokenRequest();
316 					beforeRequest.setAll(parameters);
317 					OAuth2RestTemplate client = createRestTemplate(resource,
318 							beforeRequest);
319 					clientHolder.setRestTemplate(client);
320 
321 				}
322 
323 				AccessTokenRequest request = new DefaultAccessTokenRequest();
324 				request.setAll(parameters);
325 				this.client = createRestTemplate(this.resource, request);
326 
327 				List<FrameworkMethod> list = Arrays.asList(before);
328 				try {
329 					new RunBefores(new Statement() {
330 						public void evaluate() {
331 						}
332 					}, list, target).evaluate();
333 				}
334 				catch (AssumptionViolatedException e) {
335 					throw e;
336 				}
337 				catch (RuntimeException e) {
338 					throw e;
339 				}
340 				catch (AssertionError e) {
341 					throw e;
342 				}
343 				catch (Throwable e) {
344 					logger.debug("Exception in befores", e);
345 					Assert.assertThat(e, CoreMatchers.not(CoreMatchers.anything()));
346 				}
347 				finally {
348 					clientHolder.setRestTemplate(savedServerClient);
349 				}
350 
351 			}
352 
353 		}
354 
355 	}
356 
357 	private OAuth2RestTemplate createRestTemplate(
358 			OAuth2ProtectedResourceDetails resource, AccessTokenRequest request) {
359 		OAuth2ClientContext context = new DefaultOAuth2ClientContext(request);
360 		OAuth2RestTemplate client = new OAuth2RestTemplate(resource, context);
361 		setupConnectionFactory(client);
362 		client.setErrorHandler(new DefaultResponseErrorHandler() {
363 			// Pass errors through in response entity for status code analysis
364 			public boolean hasError(ClientHttpResponse response) throws IOException {
365 				return false;
366 			}
367 		});
368 		if (accessTokenProvider != null) {
369 			client.setAccessTokenProvider(accessTokenProvider);
370 		}
371 		return client;
372 	}
373 
374 	private void setupConnectionFactory(OAuth2RestTemplate client) {
375 		if (Boolean.getBoolean("http.components.enabled")
376 				&& ClassUtils.isPresent("org.apache.http.client.config.RequestConfig",
377 						null)) {
378 			client.setRequestFactory(new HttpComponentsClientHttpRequestFactory() {
379 				@Override
380 				protected HttpContext createHttpContext(HttpMethod httpMethod, URI uri) {
381 					HttpClientContext context = HttpClientContext.create();
382 					context.setRequestConfig(getRequestConfig());
383 					return context;
384 				}
385 
386 				protected RequestConfig getRequestConfig() {
387 					Builder builder = RequestConfig.custom()
388 							.setCookieSpec(CookieSpecs.IGNORE_COOKIES)
389 							.setAuthenticationEnabled(false).setRedirectsEnabled(false);
390 					return builder.build();
391 				}
392 			});
393 		}
394 		else {
395 			client.setRequestFactory(new SimpleClientHttpRequestFactory() {
396 				@Override
397 				protected void prepareConnection(HttpURLConnection connection,
398 						String httpMethod) throws IOException {
399 					super.prepareConnection(connection, httpMethod);
400 					connection.setInstanceFollowRedirects(false);
401 				}
402 			});
403 		}
404 	}
405 
406 	private OAuth2ProtectedResourceDetails creatResource(Object target,
407 			OAuth2ContextConfiguration contextLoader) {
408 		Class<? extends OAuth2ProtectedResourceDetails> type = contextLoader.value();
409 		if (type == OAuth2ProtectedResourceDetails.class) {
410 			type = contextLoader.resource();
411 		}
412 		Constructor<? extends OAuth2ProtectedResourceDetails> constructor = ClassUtils
413 				.getConstructorIfAvailable(type, TestAccounts.class);
414 		if (constructor != null && testAccounts != null) {
415 			return BeanUtils.instantiateClass(constructor, testAccounts);
416 		}
417 		constructor = ClassUtils.getConstructorIfAvailable(type, Environment.class);
418 		if (constructor != null && environment != null) {
419 			return BeanUtils.instantiateClass(constructor, environment);
420 		}
421 		constructor = ClassUtils.getConstructorIfAvailable(type, Object.class);
422 		if (constructor != null) {
423 			return BeanUtils.instantiateClass(constructor, target);
424 		}
425 		// Fallback to default constructor
426 		return BeanUtils.instantiate(type);
427 	}
428 
429 	private OAuth2ContextConfiguration findOAuthContextConfiguration(
430 			FrameworkMethod method, TestClass testClass) {
431 		OAuth2ContextConfiguration methodConfiguration = method
432 				.getAnnotation(OAuth2ContextConfiguration.class);
433 		if (methodConfiguration != null) {
434 			return methodConfiguration;
435 		}
436 		if (testClass.getJavaClass()
437 				.isAnnotationPresent(OAuth2ContextConfiguration.class)) {
438 			return testClass.getJavaClass().getAnnotation(
439 					OAuth2ContextConfiguration.class);
440 		}
441 		return null;
442 	}
443 
444 }