1
2
3
4
5
6
7
8
9
10
11
12
13 package org.springframework.security.oauth2.client.test;
14
15 import java.io.IOException;
16 import java.lang.reflect.Constructor;
17 import java.net.HttpURLConnection;
18 import java.net.URI;
19 import java.util.Arrays;
20 import java.util.LinkedHashMap;
21 import java.util.List;
22 import java.util.Map;
23
24 import org.apache.commons.logging.Log;
25 import org.apache.commons.logging.LogFactory;
26 import org.apache.http.client.config.CookieSpecs;
27 import org.apache.http.client.config.RequestConfig;
28 import org.apache.http.client.config.RequestConfig.Builder;
29 import org.apache.http.client.protocol.HttpClientContext;
30 import org.apache.http.protocol.HttpContext;
31 import org.hamcrest.CoreMatchers;
32 import org.junit.Assert;
33 import org.junit.internal.AssumptionViolatedException;
34 import org.junit.internal.runners.statements.RunBefores;
35 import org.junit.rules.TestWatchman;
36 import org.junit.runners.model.FrameworkMethod;
37 import org.junit.runners.model.Statement;
38 import org.junit.runners.model.TestClass;
39 import org.springframework.beans.BeanUtils;
40 import org.springframework.core.env.Environment;
41 import org.springframework.http.HttpMethod;
42 import org.springframework.http.client.ClientHttpResponse;
43 import org.springframework.http.client.HttpComponentsClientHttpRequestFactory;
44 import org.springframework.http.client.SimpleClientHttpRequestFactory;
45 import org.springframework.security.oauth2.client.DefaultOAuth2ClientContext;
46 import org.springframework.security.oauth2.client.OAuth2ClientContext;
47 import org.springframework.security.oauth2.client.OAuth2RestTemplate;
48 import org.springframework.security.oauth2.client.resource.OAuth2AccessDeniedException;
49 import org.springframework.security.oauth2.client.resource.OAuth2ProtectedResourceDetails;
50 import org.springframework.security.oauth2.client.token.AccessTokenProvider;
51 import org.springframework.security.oauth2.client.token.AccessTokenRequest;
52 import org.springframework.security.oauth2.client.token.DefaultAccessTokenRequest;
53 import org.springframework.security.oauth2.common.OAuth2AccessToken;
54 import org.springframework.util.ClassUtils;
55 import org.springframework.web.client.DefaultResponseErrorHandler;
56 import org.springframework.web.client.RestOperations;
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99 @SuppressWarnings("deprecation")
100 public class OAuth2ContextSetup extends TestWatchman {
101
102 private static Log logger = LogFactory.getLog(OAuth2ContextSetup.class);
103
104 private OAuth2ProtectedResourceDetails resource;
105
106 private OAuth2RestTemplate client;
107
108 private Map<String, String> parameters = new LinkedHashMap<String, String>();
109
110 private final RestTemplateHolder clientHolder;
111
112 private final TestAccounts testAccounts;
113
114 private OAuth2AccessToken accessToken;
115
116 private boolean initializeAccessToken = true;
117
118 private RestOperations savedClient;
119
120 private AccessTokenProvider accessTokenProvider;
121
122 private final Environment environment;
123
124
125
126
127
128
129
130
131
132
133
134 public static OAuth2ContextSetup withEnvironment(RestTemplateHolder clientHolder,
135 Environment environment) {
136 return new OAuth2ContextSetup(clientHolder, null, environment);
137 }
138
139
140
141
142
143
144
145
146
147
148
149
150 public static OAuth2ContextSetup withTestAccounts(RestTemplateHolder clientHolder,
151 TestAccounts testAccounts) {
152 return new OAuth2ContextSetup(clientHolder, testAccounts, null);
153 }
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174 public static OAuth2ContextSetup standard(RestTemplateHolder clientHolder) {
175 return new OAuth2ContextSetup(clientHolder, null, null);
176 }
177
178 private OAuth2ContextSetup(RestTemplateHolder clientHolder,
179 TestAccounts testAccounts, Environment environment) {
180 this.clientHolder = clientHolder;
181 this.testAccounts = testAccounts;
182 this.environment = environment;
183 }
184
185 @Override
186 public Statement apply(Statement base, FrameworkMethod method, Object target) {
187 initializeIfNecessary(method, target);
188 return super.apply(base, method, target);
189 }
190
191 @Override
192 public void starting(FrameworkMethod method) {
193 if (resource != null) {
194 logger.info("Starting OAuth2 context for: " + resource);
195 AccessTokenRequest request = new DefaultAccessTokenRequest();
196 request.setAll(parameters);
197 client = createRestTemplate(resource, request);
198 if (initializeAccessToken) {
199 this.accessToken = null;
200 this.accessToken = getAccessToken();
201 }
202 savedClient = clientHolder.getRestTemplate();
203 clientHolder.setRestTemplate(client);
204 }
205 }
206
207 @Override
208 public void finished(FrameworkMethod method) {
209 if (resource != null) {
210 logger.info("Ending OAuth2 context for: " + resource);
211 if (savedClient != null) {
212 clientHolder.setRestTemplate(savedClient);
213 }
214 }
215 }
216
217 public void setAccessTokenProvider(AccessTokenProvider accessTokenProvider) {
218 this.accessTokenProvider = accessTokenProvider;
219 }
220
221 public void setParameters(Map<String, String> parameters) {
222 this.parameters = parameters;
223 }
224
225
226
227
228
229
230
231
232 public OAuth2AccessToken getAccessToken() {
233 if (resource == null || client == null) {
234 return null;
235 }
236 if (accessToken != null) {
237 return accessToken;
238 }
239 if (accessTokenProvider != null) {
240 client.setAccessTokenProvider(accessTokenProvider);
241 }
242 try {
243 return client.getAccessToken();
244 }
245 catch (OAuth2AccessDeniedException e) {
246 Throwable cause = e.getCause();
247 if (cause instanceof RuntimeException) {
248 throw (RuntimeException) cause;
249 }
250 if (cause instanceof Error) {
251 throw (Error) cause;
252 }
253 throw e;
254 }
255 }
256
257
258
259
260 public OAuth2RestTemplate getRestTemplate() {
261 return client;
262 }
263
264
265
266
267 public OAuth2ProtectedResourceDetails getResource() {
268 return resource;
269 }
270
271
272
273
274 public AccessTokenRequest getAccessTokenRequest() {
275 return client.getOAuth2ClientContext().getAccessTokenRequest();
276 }
277
278
279
280
281 public OAuth2ClientContext getOAuth2ClientContext() {
282 return client.getOAuth2ClientContext();
283 }
284
285 private void initializeIfNecessary(FrameworkMethod method, final Object target) {
286
287 final TestClass testClass = new TestClass(target.getClass());
288 OAuth2ContextConfiguration contextConfiguration = findOAuthContextConfiguration(
289 method, testClass);
290 if (contextConfiguration == null) {
291
292 return;
293 }
294
295 this.initializeAccessToken = contextConfiguration.initialize();
296
297 this.resource = creatResource(target, contextConfiguration);
298
299 final List<FrameworkMethod> befores = testClass
300 .getAnnotatedMethods(BeforeOAuth2Context.class);
301 if (!befores.isEmpty()) {
302
303 logger.debug("Running @BeforeOAuth2Context methods");
304
305 for (FrameworkMethod before : befores) {
306
307 RestOperations savedServerClient = clientHolder.getRestTemplate();
308
309 OAuth2ContextConfiguration beforeConfiguration = findOAuthContextConfiguration(
310 before, testClass);
311 if (beforeConfiguration != null) {
312
313 OAuth2ProtectedResourceDetails resource = creatResource(target,
314 beforeConfiguration);
315 AccessTokenRequest beforeRequest = new DefaultAccessTokenRequest();
316 beforeRequest.setAll(parameters);
317 OAuth2RestTemplate client = createRestTemplate(resource,
318 beforeRequest);
319 clientHolder.setRestTemplate(client);
320
321 }
322
323 AccessTokenRequest request = new DefaultAccessTokenRequest();
324 request.setAll(parameters);
325 this.client = createRestTemplate(this.resource, request);
326
327 List<FrameworkMethod> list = Arrays.asList(before);
328 try {
329 new RunBefores(new Statement() {
330 public void evaluate() {
331 }
332 }, list, target).evaluate();
333 }
334 catch (AssumptionViolatedException e) {
335 throw e;
336 }
337 catch (RuntimeException e) {
338 throw e;
339 }
340 catch (AssertionError e) {
341 throw e;
342 }
343 catch (Throwable e) {
344 logger.debug("Exception in befores", e);
345 Assert.assertThat(e, CoreMatchers.not(CoreMatchers.anything()));
346 }
347 finally {
348 clientHolder.setRestTemplate(savedServerClient);
349 }
350
351 }
352
353 }
354
355 }
356
357 private OAuth2RestTemplate createRestTemplate(
358 OAuth2ProtectedResourceDetails resource, AccessTokenRequest request) {
359 OAuth2ClientContext context = new DefaultOAuth2ClientContext(request);
360 OAuth2RestTemplate client = new OAuth2RestTemplate(resource, context);
361 setupConnectionFactory(client);
362 client.setErrorHandler(new DefaultResponseErrorHandler() {
363
364 public boolean hasError(ClientHttpResponse response) throws IOException {
365 return false;
366 }
367 });
368 if (accessTokenProvider != null) {
369 client.setAccessTokenProvider(accessTokenProvider);
370 }
371 return client;
372 }
373
374 private void setupConnectionFactory(OAuth2RestTemplate client) {
375 if (Boolean.getBoolean("http.components.enabled")
376 && ClassUtils.isPresent("org.apache.http.client.config.RequestConfig",
377 null)) {
378 client.setRequestFactory(new HttpComponentsClientHttpRequestFactory() {
379 @Override
380 protected HttpContext createHttpContext(HttpMethod httpMethod, URI uri) {
381 HttpClientContext context = HttpClientContext.create();
382 context.setRequestConfig(getRequestConfig());
383 return context;
384 }
385
386 protected RequestConfig getRequestConfig() {
387 Builder builder = RequestConfig.custom()
388 .setCookieSpec(CookieSpecs.IGNORE_COOKIES)
389 .setAuthenticationEnabled(false).setRedirectsEnabled(false);
390 return builder.build();
391 }
392 });
393 }
394 else {
395 client.setRequestFactory(new SimpleClientHttpRequestFactory() {
396 @Override
397 protected void prepareConnection(HttpURLConnection connection,
398 String httpMethod) throws IOException {
399 super.prepareConnection(connection, httpMethod);
400 connection.setInstanceFollowRedirects(false);
401 }
402 });
403 }
404 }
405
406 private OAuth2ProtectedResourceDetails creatResource(Object target,
407 OAuth2ContextConfiguration contextLoader) {
408 Class<? extends OAuth2ProtectedResourceDetails> type = contextLoader.value();
409 if (type == OAuth2ProtectedResourceDetails.class) {
410 type = contextLoader.resource();
411 }
412 Constructor<? extends OAuth2ProtectedResourceDetails> constructor = ClassUtils
413 .getConstructorIfAvailable(type, TestAccounts.class);
414 if (constructor != null && testAccounts != null) {
415 return BeanUtils.instantiateClass(constructor, testAccounts);
416 }
417 constructor = ClassUtils.getConstructorIfAvailable(type, Environment.class);
418 if (constructor != null && environment != null) {
419 return BeanUtils.instantiateClass(constructor, environment);
420 }
421 constructor = ClassUtils.getConstructorIfAvailable(type, Object.class);
422 if (constructor != null) {
423 return BeanUtils.instantiateClass(constructor, target);
424 }
425
426 return BeanUtils.instantiate(type);
427 }
428
429 private OAuth2ContextConfiguration findOAuthContextConfiguration(
430 FrameworkMethod method, TestClass testClass) {
431 OAuth2ContextConfiguration methodConfiguration = method
432 .getAnnotation(OAuth2ContextConfiguration.class);
433 if (methodConfiguration != null) {
434 return methodConfiguration;
435 }
436 if (testClass.getJavaClass()
437 .isAnnotationPresent(OAuth2ContextConfiguration.class)) {
438 return testClass.getJavaClass().getAnnotation(
439 OAuth2ContextConfiguration.class);
440 }
441 return null;
442 }
443
444 }