1 package org.springframework.security.oauth2.provider.token.store;
2
3 import java.util.Collection;
4 import java.util.Collections;
5 import java.util.Date;
6 import java.util.HashSet;
7 import java.util.concurrent.ConcurrentHashMap;
8 import java.util.concurrent.DelayQueue;
9 import java.util.concurrent.Delayed;
10 import java.util.concurrent.TimeUnit;
11 import java.util.concurrent.atomic.AtomicInteger;
12
13 import org.springframework.security.oauth2.common.OAuth2AccessToken;
14 import org.springframework.security.oauth2.common.OAuth2RefreshToken;
15 import org.springframework.security.oauth2.provider.OAuth2Authentication;
16 import org.springframework.security.oauth2.provider.token.AuthenticationKeyGenerator;
17 import org.springframework.security.oauth2.provider.token.DefaultAuthenticationKeyGenerator;
18 import org.springframework.security.oauth2.provider.token.TokenStore;
19 import org.springframework.util.Assert;
20
21
22
23
24
25
26
27
28 public class InMemoryTokenStore implements TokenStore {
29
30 private static final int DEFAULT_FLUSH_INTERVAL = 1000;
31
32 private final ConcurrentHashMap<String, OAuth2AccessToken> accessTokenStore = new ConcurrentHashMap<String, OAuth2AccessToken>();
33
34 private final ConcurrentHashMap<String, OAuth2AccessToken> authenticationToAccessTokenStore = new ConcurrentHashMap<String, OAuth2AccessToken>();
35
36 private final ConcurrentHashMap<String, Collection<OAuth2AccessToken>> userNameToAccessTokenStore = new ConcurrentHashMap<String, Collection<OAuth2AccessToken>>();
37
38 private final ConcurrentHashMap<String, Collection<OAuth2AccessToken>> clientIdToAccessTokenStore = new ConcurrentHashMap<String, Collection<OAuth2AccessToken>>();
39
40 private final ConcurrentHashMap<String, OAuth2RefreshToken> refreshTokenStore = new ConcurrentHashMap<String, OAuth2RefreshToken>();
41
42 private final ConcurrentHashMap<String, String> accessTokenToRefreshTokenStore = new ConcurrentHashMap<String, String>();
43
44 private final ConcurrentHashMap<String, OAuth2Authentication> authenticationStore = new ConcurrentHashMap<String, OAuth2Authentication>();
45
46 private final ConcurrentHashMap<String, OAuth2Authentication> refreshTokenAuthenticationStore = new ConcurrentHashMap<String, OAuth2Authentication>();
47
48 private final ConcurrentHashMap<String, String> refreshTokenToAccessTokenStore = new ConcurrentHashMap<String, String>();
49
50 private final DelayQueue<TokenExpiry> expiryQueue = new DelayQueue<TokenExpiry>();
51
52 private final ConcurrentHashMap<String, TokenExpiry> expiryMap = new ConcurrentHashMap<String, TokenExpiry>();
53
54 private int flushInterval = DEFAULT_FLUSH_INTERVAL;
55
56 private AuthenticationKeyGenerator authenticationKeyGenerator = new DefaultAuthenticationKeyGenerator();
57
58 private AtomicInteger flushCounter = new AtomicInteger(0);
59
60
61
62
63
64
65 public void setFlushInterval(int flushInterval) {
66 this.flushInterval = flushInterval;
67 }
68
69
70
71
72
73
74 public int getFlushInterval() {
75 return flushInterval;
76 }
77
78
79
80
81 public void clear() {
82 accessTokenStore.clear();
83 authenticationToAccessTokenStore.clear();
84 clientIdToAccessTokenStore.clear();
85 refreshTokenStore.clear();
86 accessTokenToRefreshTokenStore.clear();
87 authenticationStore.clear();
88 refreshTokenAuthenticationStore.clear();
89 refreshTokenToAccessTokenStore.clear();
90 expiryQueue.clear();
91 }
92
93 public void setAuthenticationKeyGenerator(AuthenticationKeyGenerator authenticationKeyGenerator) {
94 this.authenticationKeyGenerator = authenticationKeyGenerator;
95 }
96
97 public int getAccessTokenCount() {
98 Assert.state(accessTokenStore.isEmpty() || accessTokenStore.size() >= accessTokenToRefreshTokenStore.size(),
99 "Too many refresh tokens");
100 Assert.state(accessTokenStore.size() == authenticationToAccessTokenStore.size(),
101 "Inconsistent token store state");
102 Assert.state(accessTokenStore.size() <= authenticationStore.size(), "Inconsistent authentication store state");
103 return accessTokenStore.size();
104 }
105
106 public int getRefreshTokenCount() {
107 Assert.state(refreshTokenStore.size() == refreshTokenToAccessTokenStore.size(),
108 "Inconsistent refresh token store state");
109 return accessTokenStore.size();
110 }
111
112 public int getExpiryTokenCount() {
113 return expiryQueue.size();
114 }
115
116 public OAuth2AccessToken getAccessToken(OAuth2Authentication authentication) {
117 String key = authenticationKeyGenerator.extractKey(authentication);
118 OAuth2AccessToken accessToken = authenticationToAccessTokenStore.get(key);
119 if (accessToken != null
120 && !key.equals(authenticationKeyGenerator.extractKey(readAuthentication(accessToken.getValue())))) {
121
122
123 storeAccessToken(accessToken, authentication);
124 }
125 return accessToken;
126 }
127
128 public OAuth2Authentication readAuthentication(OAuth2AccessToken token) {
129 return readAuthentication(token.getValue());
130 }
131
132 public OAuth2Authentication readAuthentication(String token) {
133 return this.authenticationStore.get(token);
134 }
135
136 public OAuth2Authentication readAuthenticationForRefreshToken(OAuth2RefreshToken token) {
137 return readAuthenticationForRefreshToken(token.getValue());
138 }
139
140 public OAuth2Authentication readAuthenticationForRefreshToken(String token) {
141 return this.refreshTokenAuthenticationStore.get(token);
142 }
143
144 public void storeAccessToken(OAuth2AccessToken token, OAuth2Authentication authentication) {
145 if (this.flushCounter.incrementAndGet() >= this.flushInterval) {
146 flush();
147 this.flushCounter.set(0);
148 }
149 this.accessTokenStore.put(token.getValue(), token);
150 this.authenticationStore.put(token.getValue(), authentication);
151 this.authenticationToAccessTokenStore.put(authenticationKeyGenerator.extractKey(authentication), token);
152 if (!authentication.isClientOnly()) {
153 addToCollection(this.userNameToAccessTokenStore, getApprovalKey(authentication), token);
154 }
155 addToCollection(this.clientIdToAccessTokenStore, authentication.getOAuth2Request().getClientId(), token);
156 if (token.getExpiration() != null) {
157 TokenExpiry expiry = new TokenExpiry(token.getValue(), token.getExpiration());
158
159 expiryQueue.remove(expiryMap.put(token.getValue(), expiry));
160 this.expiryQueue.put(expiry);
161 }
162 if (token.getRefreshToken() != null && token.getRefreshToken().getValue() != null) {
163 this.refreshTokenToAccessTokenStore.put(token.getRefreshToken().getValue(), token.getValue());
164 this.accessTokenToRefreshTokenStore.put(token.getValue(), token.getRefreshToken().getValue());
165 }
166 }
167
168 private String getApprovalKey(OAuth2Authentication authentication) {
169 String userName = authentication.getUserAuthentication() == null ? "" : authentication.getUserAuthentication()
170 .getName();
171 return getApprovalKey(authentication.getOAuth2Request().getClientId(), userName);
172 }
173
174 private String getApprovalKey(String clientId, String userName) {
175 return clientId + (userName==null ? "" : ":" + userName);
176 }
177
178 private void addToCollection(ConcurrentHashMap<String, Collection<OAuth2AccessToken>> store, String key,
179 OAuth2AccessToken token) {
180 if (!store.containsKey(key)) {
181 synchronized (store) {
182 if (!store.containsKey(key)) {
183 store.put(key, new HashSet<OAuth2AccessToken>());
184 }
185 }
186 }
187 store.get(key).add(token);
188 }
189
190 public void removeAccessToken(OAuth2AccessToken accessToken) {
191 removeAccessToken(accessToken.getValue());
192 }
193
194 public OAuth2AccessToken readAccessToken(String tokenValue) {
195 return this.accessTokenStore.get(tokenValue);
196 }
197
198 public void removeAccessToken(String tokenValue) {
199 OAuth2AccessToken removed = this.accessTokenStore.remove(tokenValue);
200 this.accessTokenToRefreshTokenStore.remove(tokenValue);
201
202 OAuth2Authentication authentication = this.authenticationStore.remove(tokenValue);
203 if (authentication != null) {
204 this.authenticationToAccessTokenStore.remove(authenticationKeyGenerator.extractKey(authentication));
205 Collection<OAuth2AccessToken> tokens;
206 String clientId = authentication.getOAuth2Request().getClientId();
207 tokens = this.userNameToAccessTokenStore.get(getApprovalKey(clientId, authentication.getName()));
208 if (tokens != null) {
209 tokens.remove(removed);
210 }
211 tokens = this.clientIdToAccessTokenStore.get(clientId);
212 if (tokens != null) {
213 tokens.remove(removed);
214 }
215 this.authenticationToAccessTokenStore.remove(authenticationKeyGenerator.extractKey(authentication));
216 }
217 }
218
219 public void storeRefreshToken(OAuth2RefreshToken refreshToken, OAuth2Authentication authentication) {
220 this.refreshTokenStore.put(refreshToken.getValue(), refreshToken);
221 this.refreshTokenAuthenticationStore.put(refreshToken.getValue(), authentication);
222 }
223
224 public OAuth2RefreshToken readRefreshToken(String tokenValue) {
225 return this.refreshTokenStore.get(tokenValue);
226 }
227
228 public void removeRefreshToken(OAuth2RefreshToken refreshToken) {
229 removeRefreshToken(refreshToken.getValue());
230 }
231
232 public void removeRefreshToken(String tokenValue) {
233 this.refreshTokenStore.remove(tokenValue);
234 this.refreshTokenAuthenticationStore.remove(tokenValue);
235 this.refreshTokenToAccessTokenStore.remove(tokenValue);
236 }
237
238 public void removeAccessTokenUsingRefreshToken(OAuth2RefreshToken refreshToken) {
239 removeAccessTokenUsingRefreshToken(refreshToken.getValue());
240 }
241
242 private void removeAccessTokenUsingRefreshToken(String refreshToken) {
243 String accessToken = this.refreshTokenToAccessTokenStore.remove(refreshToken);
244 if (accessToken != null) {
245 removeAccessToken(accessToken);
246 }
247 }
248
249 public Collection<OAuth2AccessToken> findTokensByClientIdAndUserName(String clientId, String userName) {
250 Collection<OAuth2AccessToken> result = userNameToAccessTokenStore.get(getApprovalKey(clientId, userName));
251 return result != null ? Collections.<OAuth2AccessToken> unmodifiableCollection(result) : Collections
252 .<OAuth2AccessToken> emptySet();
253 }
254
255 public Collection<OAuth2AccessToken> findTokensByClientId(String clientId) {
256 Collection<OAuth2AccessToken> result = clientIdToAccessTokenStore.get(clientId);
257 return result != null ? Collections.<OAuth2AccessToken> unmodifiableCollection(result) : Collections
258 .<OAuth2AccessToken> emptySet();
259 }
260
261 private void flush() {
262 TokenExpiry expiry = expiryQueue.poll();
263 while (expiry != null) {
264 removeAccessToken(expiry.getValue());
265 expiry = expiryQueue.poll();
266 }
267 }
268
269 private static class TokenExpiry implements Delayed {
270
271 private final long expiry;
272
273 private final String value;
274
275 public TokenExpiry(String value, Date date) {
276 this.value = value;
277 this.expiry = date.getTime();
278 }
279
280 public int compareTo(Delayed other) {
281 if (this == other) {
282 return 0;
283 }
284 long diff = getDelay(TimeUnit.MILLISECONDS) - other.getDelay(TimeUnit.MILLISECONDS);
285 return (diff == 0 ? 0 : ((diff < 0) ? -1 : 1));
286 }
287
288 public long getDelay(TimeUnit unit) {
289 return expiry - System.currentTimeMillis();
290 }
291
292 public String getValue() {
293 return value;
294 }
295
296 }
297
298 }