1 package org.springframework.security.oauth.consumer.client;
2
3 import org.springframework.http.HttpMethod;
4 import org.springframework.http.client.ClientHttpRequest;
5 import org.springframework.http.client.ClientHttpRequestFactory;
6 import org.springframework.security.oauth.consumer.OAuthConsumerSupport;
7 import org.springframework.security.oauth.consumer.OAuthConsumerToken;
8 import org.springframework.security.oauth.consumer.OAuthSecurityContext;
9 import org.springframework.security.oauth.consumer.OAuthSecurityContextHolder;
10 import org.springframework.security.oauth.consumer.OAuthSecurityContextImpl;
11 import org.springframework.security.oauth.consumer.ProtectedResourceDetails;
12 import org.springframework.util.CollectionUtils;
13
14 import java.io.IOException;
15 import java.net.URI;
16 import java.util.Collections;
17 import java.util.HashMap;
18 import java.util.Map;
19
20
21
22
23
24
25 public class OAuthClientHttpRequestFactory implements ClientHttpRequestFactory {
26
27 private final ClientHttpRequestFactory delegate;
28 private final ProtectedResourceDetails resource;
29 private final OAuthConsumerSupport support;
30 private Map<String, String> additionalOAuthParameters;
31
32 public OAuthClientHttpRequestFactory(ClientHttpRequestFactory delegate, ProtectedResourceDetails resource, OAuthConsumerSupport support) {
33 this.delegate = delegate;
34 this.resource = resource;
35 this.support = support;
36
37 if (delegate == null) {
38 throw new IllegalArgumentException("A delegate must be supplied for an OAuth2ClientHttpRequestFactory.");
39 }
40 if (resource == null) {
41 throw new IllegalArgumentException("A resource must be supplied for an OAuth2ClientHttpRequestFactory.");
42 }
43 this.additionalOAuthParameters = !CollectionUtils.isEmpty(resource.getAdditionalParameters()) ?
44 new HashMap<String, String>(resource.getAdditionalParameters()) :
45 Collections.<String, String>emptyMap();
46 }
47
48 public ClientHttpRequest createRequest(URI uri, HttpMethod httpMethod) throws IOException {
49 OAuthSecurityContext context = OAuthSecurityContextHolder.getContext();
50 if (context == null) {
51 context = new OAuthSecurityContextImpl();
52 }
53
54 Map<String, OAuthConsumerToken> accessTokens = context.getAccessTokens();
55 OAuthConsumerToken accessToken = accessTokens == null ? null : accessTokens.get(this.resource.getId());
56
57 boolean useAuthHeader = this.resource.isAcceptsAuthorizationHeader();
58 if (!useAuthHeader) {
59 String queryString = this.support.getOAuthQueryString(this.resource, accessToken, uri.toURL(), httpMethod.name(), this.additionalOAuthParameters);
60 String uriValue = String.valueOf(uri);
61 uri = URI.create((uriValue.contains("?") ? uriValue.substring(0, uriValue.indexOf('?')) : uriValue) + "?" + queryString);
62 }
63
64 ClientHttpRequest req = delegate.createRequest(uri, httpMethod);
65 if (useAuthHeader) {
66 String authHeader = this.support.getAuthorizationHeader(this.resource, accessToken, uri.toURL(), httpMethod.name(), this.additionalOAuthParameters);
67 req.getHeaders().add("Authorization", authHeader);
68 }
69
70 Map<String, String> additionalHeaders = this.resource.getAdditionalRequestHeaders();
71 if (additionalHeaders != null) {
72 for (Map.Entry<String, String> header : additionalHeaders.entrySet()) {
73 req.getHeaders().add(header.getKey(), header.getValue());
74 }
75 }
76 return req;
77 }
78
79
80
81
82
83
84 public Map<String, String> getAdditionalOAuthParameters() {
85 return additionalOAuthParameters;
86 }
87
88
89
90
91
92
93 public void setAdditionalOAuthParameters(Map<String, String> additionalOAuthParameters) {
94 this.additionalOAuthParameters = additionalOAuthParameters;
95 }
96 }