View Javadoc
1   package org.springframework.security.oauth2.provider.token.store;
2   
3   import java.util.Collection;
4   import java.util.Collections;
5   import java.util.Date;
6   import java.util.HashSet;
7   import java.util.concurrent.ConcurrentHashMap;
8   import java.util.concurrent.DelayQueue;
9   import java.util.concurrent.Delayed;
10  import java.util.concurrent.TimeUnit;
11  import java.util.concurrent.atomic.AtomicInteger;
12  
13  import org.springframework.security.oauth2.common.OAuth2AccessToken;
14  import org.springframework.security.oauth2.common.OAuth2RefreshToken;
15  import org.springframework.security.oauth2.provider.OAuth2Authentication;
16  import org.springframework.security.oauth2.provider.token.AuthenticationKeyGenerator;
17  import org.springframework.security.oauth2.provider.token.DefaultAuthenticationKeyGenerator;
18  import org.springframework.security.oauth2.provider.token.TokenStore;
19  import org.springframework.util.Assert;
20  
21  /**
22   * Implementation of token services that stores tokens in memory.
23   * 
24   * @author Ryan Heaton
25   * @author Luke Taylor
26   * @author Dave Syer
27   */
28  public class InMemoryTokenStore implements TokenStore {
29  
30  	private static final int DEFAULT_FLUSH_INTERVAL = 1000;
31  
32  	private final ConcurrentHashMap<String, OAuth2AccessToken> accessTokenStore = new ConcurrentHashMap<String, OAuth2AccessToken>();
33  
34  	private final ConcurrentHashMap<String, OAuth2AccessToken> authenticationToAccessTokenStore = new ConcurrentHashMap<String, OAuth2AccessToken>();
35  
36  	private final ConcurrentHashMap<String, Collection<OAuth2AccessToken>> userNameToAccessTokenStore = new ConcurrentHashMap<String, Collection<OAuth2AccessToken>>();
37  
38  	private final ConcurrentHashMap<String, Collection<OAuth2AccessToken>> clientIdToAccessTokenStore = new ConcurrentHashMap<String, Collection<OAuth2AccessToken>>();
39  
40  	private final ConcurrentHashMap<String, OAuth2RefreshToken> refreshTokenStore = new ConcurrentHashMap<String, OAuth2RefreshToken>();
41  
42  	private final ConcurrentHashMap<String, String> accessTokenToRefreshTokenStore = new ConcurrentHashMap<String, String>();
43  
44  	private final ConcurrentHashMap<String, OAuth2Authentication> authenticationStore = new ConcurrentHashMap<String, OAuth2Authentication>();
45  
46  	private final ConcurrentHashMap<String, OAuth2Authentication> refreshTokenAuthenticationStore = new ConcurrentHashMap<String, OAuth2Authentication>();
47  
48  	private final ConcurrentHashMap<String, String> refreshTokenToAccessTokenStore = new ConcurrentHashMap<String, String>();
49  
50  	private final DelayQueue<TokenExpiry> expiryQueue = new DelayQueue<TokenExpiry>();
51  
52  	private final ConcurrentHashMap<String, TokenExpiry> expiryMap = new ConcurrentHashMap<String, TokenExpiry>();
53  
54  	private int flushInterval = DEFAULT_FLUSH_INTERVAL;
55  
56  	private AuthenticationKeyGenerator authenticationKeyGenerator = new DefaultAuthenticationKeyGenerator();
57  
58  	private AtomicInteger flushCounter = new AtomicInteger(0);
59  
60  	/**
61  	 * The number of tokens to store before flushing expired tokens. Defaults to 1000.
62  	 * 
63  	 * @param flushInterval the interval to set
64  	 */
65  	public void setFlushInterval(int flushInterval) {
66  		this.flushInterval = flushInterval;
67  	}
68  
69  	/**
70  	 * The interval (count of token inserts) between flushing expired tokens.
71  	 * 
72  	 * @return the flushInterval the flush interval
73  	 */
74  	public int getFlushInterval() {
75  		return flushInterval;
76  	}
77  
78  	/**
79  	 * Convenience method for super admin users to remove all tokens (useful for testing, not really in production)
80  	 */
81  	public void clear() {
82  		accessTokenStore.clear();
83  		authenticationToAccessTokenStore.clear();
84  		clientIdToAccessTokenStore.clear();
85  		refreshTokenStore.clear();
86  		accessTokenToRefreshTokenStore.clear();
87  		authenticationStore.clear();
88  		refreshTokenAuthenticationStore.clear();
89  		refreshTokenToAccessTokenStore.clear();
90  		expiryQueue.clear();
91  	}
92  
93  	public void setAuthenticationKeyGenerator(AuthenticationKeyGenerator authenticationKeyGenerator) {
94  		this.authenticationKeyGenerator = authenticationKeyGenerator;
95  	}
96  
97  	public int getAccessTokenCount() {
98  		Assert.state(accessTokenStore.isEmpty() || accessTokenStore.size() >= accessTokenToRefreshTokenStore.size(),
99  				"Too many refresh tokens");
100 		Assert.state(accessTokenStore.size() == authenticationToAccessTokenStore.size(),
101 				"Inconsistent token store state");
102 		Assert.state(accessTokenStore.size() <= authenticationStore.size(), "Inconsistent authentication store state");
103 		return accessTokenStore.size();
104 	}
105 
106 	public int getRefreshTokenCount() {
107 		Assert.state(refreshTokenStore.size() == refreshTokenToAccessTokenStore.size(),
108 				"Inconsistent refresh token store state");
109 		return accessTokenStore.size();
110 	}
111 
112 	public int getExpiryTokenCount() {
113 		return expiryQueue.size();
114 	}
115 
116 	public OAuth2AccessToken getAccessToken(OAuth2Authentication authentication) {
117 		String key = authenticationKeyGenerator.extractKey(authentication);
118 		OAuth2AccessToken accessToken = authenticationToAccessTokenStore.get(key);
119 		if (accessToken != null
120 				&& !key.equals(authenticationKeyGenerator.extractKey(readAuthentication(accessToken.getValue())))) {
121 			// Keep the stores consistent (maybe the same user is represented by this authentication but the details
122 			// have changed)
123 			storeAccessToken(accessToken, authentication);
124 		}
125 		return accessToken;
126 	}
127 
128 	public OAuth2Authentication readAuthentication(OAuth2AccessToken token) {
129 		return readAuthentication(token.getValue());
130 	}
131 
132 	public OAuth2Authentication readAuthentication(String token) {
133 		return this.authenticationStore.get(token);
134 	}
135 
136 	public OAuth2Authentication readAuthenticationForRefreshToken(OAuth2RefreshToken token) {
137 		return readAuthenticationForRefreshToken(token.getValue());
138 	}
139 
140 	public OAuth2Authentication readAuthenticationForRefreshToken(String token) {
141 		return this.refreshTokenAuthenticationStore.get(token);
142 	}
143 
144 	public void storeAccessToken(OAuth2AccessToken token, OAuth2Authentication authentication) {
145 		if (this.flushCounter.incrementAndGet() >= this.flushInterval) {
146 			flush();
147 			this.flushCounter.set(0);
148 		}
149 		this.accessTokenStore.put(token.getValue(), token);
150 		this.authenticationStore.put(token.getValue(), authentication);
151 		this.authenticationToAccessTokenStore.put(authenticationKeyGenerator.extractKey(authentication), token);
152 		if (!authentication.isClientOnly()) {
153 			addToCollection(this.userNameToAccessTokenStore, getApprovalKey(authentication), token);
154 		}
155 		addToCollection(this.clientIdToAccessTokenStore, authentication.getOAuth2Request().getClientId(), token);
156 		if (token.getExpiration() != null) {
157 			TokenExpiry expiry = new TokenExpiry(token.getValue(), token.getExpiration());
158 			// Remove existing expiry for this token if present
159 			expiryQueue.remove(expiryMap.put(token.getValue(), expiry));
160 			this.expiryQueue.put(expiry);
161 		}
162 		if (token.getRefreshToken() != null && token.getRefreshToken().getValue() != null) {
163 			this.refreshTokenToAccessTokenStore.put(token.getRefreshToken().getValue(), token.getValue());
164 			this.accessTokenToRefreshTokenStore.put(token.getValue(), token.getRefreshToken().getValue());
165 		}
166 	}
167 
168 	private String getApprovalKey(OAuth2Authentication authentication) {
169 		String userName = authentication.getUserAuthentication() == null ? "" : authentication.getUserAuthentication()
170 				.getName();
171 		return getApprovalKey(authentication.getOAuth2Request().getClientId(), userName);
172 	}
173 
174 	private String getApprovalKey(String clientId, String userName) {
175 		return clientId + (userName==null ? "" : ":" + userName);
176 	}
177 
178 	private void addToCollection(ConcurrentHashMap<String, Collection<OAuth2AccessToken>> store, String key,
179 			OAuth2AccessToken token) {
180 		if (!store.containsKey(key)) {
181 			synchronized (store) {
182 				if (!store.containsKey(key)) {
183 					store.put(key, new HashSet<OAuth2AccessToken>());
184 				}
185 			}
186 		}
187 		store.get(key).add(token);
188 	}
189 
190 	public void removeAccessToken(OAuth2AccessToken accessToken) {
191 		removeAccessToken(accessToken.getValue());
192 	}
193 
194 	public OAuth2AccessToken readAccessToken(String tokenValue) {
195 		return this.accessTokenStore.get(tokenValue);
196 	}
197 
198 	public void removeAccessToken(String tokenValue) {
199 		OAuth2AccessToken removed = this.accessTokenStore.remove(tokenValue);
200 		this.accessTokenToRefreshTokenStore.remove(tokenValue);
201 		// Don't remove the refresh token - it's up to the caller to do that
202 		OAuth2Authentication authentication = this.authenticationStore.remove(tokenValue);
203 		if (authentication != null) {
204 			this.authenticationToAccessTokenStore.remove(authenticationKeyGenerator.extractKey(authentication));
205 			Collection<OAuth2AccessToken> tokens;
206 			String clientId = authentication.getOAuth2Request().getClientId();
207 			tokens = this.userNameToAccessTokenStore.get(getApprovalKey(clientId, authentication.getName()));
208 			if (tokens != null) {
209 				tokens.remove(removed);
210 			}
211 			tokens = this.clientIdToAccessTokenStore.get(clientId);
212 			if (tokens != null) {
213 				tokens.remove(removed);
214 			}
215 			this.authenticationToAccessTokenStore.remove(authenticationKeyGenerator.extractKey(authentication));
216 		}
217 	}
218 
219 	public void storeRefreshToken(OAuth2RefreshToken refreshToken, OAuth2Authentication authentication) {
220 		this.refreshTokenStore.put(refreshToken.getValue(), refreshToken);
221 		this.refreshTokenAuthenticationStore.put(refreshToken.getValue(), authentication);
222 	}
223 
224 	public OAuth2RefreshToken readRefreshToken(String tokenValue) {
225 		return this.refreshTokenStore.get(tokenValue);
226 	}
227 
228 	public void removeRefreshToken(OAuth2RefreshToken refreshToken) {
229 		removeRefreshToken(refreshToken.getValue());
230 	}
231 
232 	public void removeRefreshToken(String tokenValue) {
233 		this.refreshTokenStore.remove(tokenValue);
234 		this.refreshTokenAuthenticationStore.remove(tokenValue);
235 		this.refreshTokenToAccessTokenStore.remove(tokenValue);
236 	}
237 
238 	public void removeAccessTokenUsingRefreshToken(OAuth2RefreshToken refreshToken) {
239 		removeAccessTokenUsingRefreshToken(refreshToken.getValue());
240 	}
241 
242 	private void removeAccessTokenUsingRefreshToken(String refreshToken) {
243 		String accessToken = this.refreshTokenToAccessTokenStore.remove(refreshToken);
244 		if (accessToken != null) {
245 			removeAccessToken(accessToken);
246 		}
247 	}
248 
249 	public Collection<OAuth2AccessToken> findTokensByClientIdAndUserName(String clientId, String userName) {
250 		Collection<OAuth2AccessToken> result = userNameToAccessTokenStore.get(getApprovalKey(clientId, userName));
251 		return result != null ? Collections.<OAuth2AccessToken> unmodifiableCollection(result) : Collections
252 				.<OAuth2AccessToken> emptySet();
253 	}
254 
255 	public Collection<OAuth2AccessToken> findTokensByClientId(String clientId) {
256 		Collection<OAuth2AccessToken> result = clientIdToAccessTokenStore.get(clientId);
257 		return result != null ? Collections.<OAuth2AccessToken> unmodifiableCollection(result) : Collections
258 				.<OAuth2AccessToken> emptySet();
259 	}
260 
261 	private void flush() {
262 		TokenExpiry expiry = expiryQueue.poll();
263 		while (expiry != null) {
264 			removeAccessToken(expiry.getValue());
265 			expiry = expiryQueue.poll();
266 		}
267 	}
268 
269 	private static class TokenExpiry implements Delayed {
270 
271 		private final long expiry;
272 
273 		private final String value;
274 
275 		public TokenExpiry(String value, Date date) {
276 			this.value = value;
277 			this.expiry = date.getTime();
278 		}
279 
280 		public int compareTo(Delayed other) {
281 			if (this == other) {
282 				return 0;
283 			}
284 			long diff = getDelay(TimeUnit.MILLISECONDS) - other.getDelay(TimeUnit.MILLISECONDS);
285 			return (diff == 0 ? 0 : ((diff < 0) ? -1 : 1));
286 		}
287 
288 		public long getDelay(TimeUnit unit) {
289 			return expiry - System.currentTimeMillis();
290 		}
291 
292 		public String getValue() {
293 			return value;
294 		}
295 
296 	}
297 
298 }