1 package org.springframework.security.oauth2.provider.token.store;
2
3 import java.io.UnsupportedEncodingException;
4 import java.math.BigInteger;
5 import java.security.MessageDigest;
6 import java.security.NoSuchAlgorithmException;
7 import java.sql.ResultSet;
8 import java.sql.SQLException;
9 import java.sql.Types;
10 import java.util.ArrayList;
11 import java.util.Collection;
12 import java.util.List;
13
14 import javax.sql.DataSource;
15
16 import org.apache.commons.logging.Log;
17 import org.apache.commons.logging.LogFactory;
18 import org.springframework.dao.EmptyResultDataAccessException;
19 import org.springframework.jdbc.core.JdbcTemplate;
20 import org.springframework.jdbc.core.RowMapper;
21 import org.springframework.jdbc.core.support.SqlLobValue;
22 import org.springframework.security.oauth2.common.OAuth2AccessToken;
23 import org.springframework.security.oauth2.common.OAuth2RefreshToken;
24 import org.springframework.security.oauth2.common.util.SerializationUtils;
25 import org.springframework.security.oauth2.provider.OAuth2Authentication;
26 import org.springframework.security.oauth2.provider.token.AuthenticationKeyGenerator;
27 import org.springframework.security.oauth2.provider.token.DefaultAuthenticationKeyGenerator;
28 import org.springframework.security.oauth2.provider.token.TokenStore;
29 import org.springframework.util.Assert;
30
31
32
33
34
35
36
37
38 public class JdbcTokenStore implements TokenStore {
39
40 private static final Log LOG = LogFactory.getLog(JdbcTokenStore.class);
41
42 private static final String DEFAULT_ACCESS_TOKEN_INSERT_STATEMENT = "insert into oauth_access_token (token_id, token, authentication_id, user_name, client_id, authentication, refresh_token) values (?, ?, ?, ?, ?, ?, ?)";
43
44 private static final String DEFAULT_ACCESS_TOKEN_SELECT_STATEMENT = "select token_id, token from oauth_access_token where token_id = ?";
45
46 private static final String DEFAULT_ACCESS_TOKEN_AUTHENTICATION_SELECT_STATEMENT = "select token_id, authentication from oauth_access_token where token_id = ?";
47
48 private static final String DEFAULT_ACCESS_TOKEN_FROM_AUTHENTICATION_SELECT_STATEMENT = "select token_id, token from oauth_access_token where authentication_id = ?";
49
50 private static final String DEFAULT_ACCESS_TOKENS_FROM_USERNAME_AND_CLIENT_SELECT_STATEMENT = "select token_id, token from oauth_access_token where user_name = ? and client_id = ?";
51
52 private static final String DEFAULT_ACCESS_TOKENS_FROM_USERNAME_SELECT_STATEMENT = "select token_id, token from oauth_access_token where user_name = ?";
53
54 private static final String DEFAULT_ACCESS_TOKENS_FROM_CLIENTID_SELECT_STATEMENT = "select token_id, token from oauth_access_token where client_id = ?";
55
56 private static final String DEFAULT_ACCESS_TOKEN_DELETE_STATEMENT = "delete from oauth_access_token where token_id = ?";
57
58 private static final String DEFAULT_ACCESS_TOKEN_DELETE_FROM_REFRESH_TOKEN_STATEMENT = "delete from oauth_access_token where refresh_token = ?";
59
60 private static final String DEFAULT_REFRESH_TOKEN_INSERT_STATEMENT = "insert into oauth_refresh_token (token_id, token, authentication) values (?, ?, ?)";
61
62 private static final String DEFAULT_REFRESH_TOKEN_SELECT_STATEMENT = "select token_id, token from oauth_refresh_token where token_id = ?";
63
64 private static final String DEFAULT_REFRESH_TOKEN_AUTHENTICATION_SELECT_STATEMENT = "select token_id, authentication from oauth_refresh_token where token_id = ?";
65
66 private static final String DEFAULT_REFRESH_TOKEN_DELETE_STATEMENT = "delete from oauth_refresh_token where token_id = ?";
67
68 private String insertAccessTokenSql = DEFAULT_ACCESS_TOKEN_INSERT_STATEMENT;
69
70 private String selectAccessTokenSql = DEFAULT_ACCESS_TOKEN_SELECT_STATEMENT;
71
72 private String selectAccessTokenAuthenticationSql = DEFAULT_ACCESS_TOKEN_AUTHENTICATION_SELECT_STATEMENT;
73
74 private String selectAccessTokenFromAuthenticationSql = DEFAULT_ACCESS_TOKEN_FROM_AUTHENTICATION_SELECT_STATEMENT;
75
76 private String selectAccessTokensFromUserNameAndClientIdSql = DEFAULT_ACCESS_TOKENS_FROM_USERNAME_AND_CLIENT_SELECT_STATEMENT;
77
78 private String selectAccessTokensFromUserNameSql = DEFAULT_ACCESS_TOKENS_FROM_USERNAME_SELECT_STATEMENT;
79
80 private String selectAccessTokensFromClientIdSql = DEFAULT_ACCESS_TOKENS_FROM_CLIENTID_SELECT_STATEMENT;
81
82 private String deleteAccessTokenSql = DEFAULT_ACCESS_TOKEN_DELETE_STATEMENT;
83
84 private String insertRefreshTokenSql = DEFAULT_REFRESH_TOKEN_INSERT_STATEMENT;
85
86 private String selectRefreshTokenSql = DEFAULT_REFRESH_TOKEN_SELECT_STATEMENT;
87
88 private String selectRefreshTokenAuthenticationSql = DEFAULT_REFRESH_TOKEN_AUTHENTICATION_SELECT_STATEMENT;
89
90 private String deleteRefreshTokenSql = DEFAULT_REFRESH_TOKEN_DELETE_STATEMENT;
91
92 private String deleteAccessTokenFromRefreshTokenSql = DEFAULT_ACCESS_TOKEN_DELETE_FROM_REFRESH_TOKEN_STATEMENT;
93
94 private AuthenticationKeyGenerator authenticationKeyGenerator = new DefaultAuthenticationKeyGenerator();
95
96 private final JdbcTemplate jdbcTemplate;
97
98 public JdbcTokenStore(DataSource dataSource) {
99 Assert.notNull(dataSource, "DataSource required");
100 this.jdbcTemplate = new JdbcTemplate(dataSource);
101 }
102
103 public void setAuthenticationKeyGenerator(AuthenticationKeyGenerator authenticationKeyGenerator) {
104 this.authenticationKeyGenerator = authenticationKeyGenerator;
105 }
106
107 public OAuth2AccessToken getAccessToken(OAuth2Authentication authentication) {
108 OAuth2AccessToken accessToken = null;
109
110 String key = authenticationKeyGenerator.extractKey(authentication);
111 try {
112 accessToken = jdbcTemplate.queryForObject(selectAccessTokenFromAuthenticationSql,
113 new RowMapper<OAuth2AccessToken>() {
114 public OAuth2AccessToken mapRow(ResultSet rs, int rowNum) throws SQLException {
115 return deserializeAccessToken(rs.getBytes(2));
116 }
117 }, key);
118 }
119 catch (EmptyResultDataAccessException e) {
120 if (LOG.isDebugEnabled()) {
121 LOG.debug("Failed to find access token for authentication " + authentication);
122 }
123 }
124 catch (IllegalArgumentException e) {
125 LOG.error("Could not extract access token for authentication " + authentication, e);
126 }
127
128 if (accessToken != null
129 && !key.equals(authenticationKeyGenerator.extractKey(readAuthentication(accessToken.getValue())))) {
130 removeAccessToken(accessToken.getValue());
131
132
133 storeAccessToken(accessToken, authentication);
134 }
135 return accessToken;
136 }
137
138 public void storeAccessToken(OAuth2AccessToken token, OAuth2Authentication authentication) {
139 String refreshToken = null;
140 if (token.getRefreshToken() != null) {
141 refreshToken = token.getRefreshToken().getValue();
142 }
143
144 if (readAccessToken(token.getValue())!=null) {
145 removeAccessToken(token.getValue());
146 }
147
148 jdbcTemplate.update(insertAccessTokenSql, new Object[] { extractTokenKey(token.getValue()),
149 new SqlLobValue(serializeAccessToken(token)), authenticationKeyGenerator.extractKey(authentication),
150 authentication.isClientOnly() ? null : authentication.getName(),
151 authentication.getOAuth2Request().getClientId(),
152 new SqlLobValue(serializeAuthentication(authentication)), extractTokenKey(refreshToken) }, new int[] {
153 Types.VARCHAR, Types.BLOB, Types.VARCHAR, Types.VARCHAR, Types.VARCHAR, Types.BLOB, Types.VARCHAR });
154 }
155
156 public OAuth2AccessToken readAccessToken(String tokenValue) {
157 OAuth2AccessToken accessToken = null;
158
159 try {
160 accessToken = jdbcTemplate.queryForObject(selectAccessTokenSql, new RowMapper<OAuth2AccessToken>() {
161 public OAuth2AccessToken mapRow(ResultSet rs, int rowNum) throws SQLException {
162 return deserializeAccessToken(rs.getBytes(2));
163 }
164 }, extractTokenKey(tokenValue));
165 }
166 catch (EmptyResultDataAccessException e) {
167 if (LOG.isInfoEnabled()) {
168 LOG.info("Failed to find access token");
169 }
170 }
171 catch (IllegalArgumentException e) {
172 LOG.warn("Failed to deserialize access token", e);
173 removeAccessToken(tokenValue);
174 }
175
176 return accessToken;
177 }
178
179 public void removeAccessToken(OAuth2AccessToken token) {
180 removeAccessToken(token.getValue());
181 }
182
183 public void removeAccessToken(String tokenValue) {
184 jdbcTemplate.update(deleteAccessTokenSql, extractTokenKey(tokenValue));
185 }
186
187 public OAuth2Authentication readAuthentication(OAuth2AccessToken token) {
188 return readAuthentication(token.getValue());
189 }
190
191 public OAuth2Authentication readAuthentication(String token) {
192 OAuth2Authentication authentication = null;
193
194 try {
195 authentication = jdbcTemplate.queryForObject(selectAccessTokenAuthenticationSql,
196 new RowMapper<OAuth2Authentication>() {
197 public OAuth2Authentication mapRow(ResultSet rs, int rowNum) throws SQLException {
198 return deserializeAuthentication(rs.getBytes(2));
199 }
200 }, extractTokenKey(token));
201 }
202 catch (EmptyResultDataAccessException e) {
203 if (LOG.isInfoEnabled()) {
204 LOG.info("Failed to find access token");
205 }
206 }
207 catch (IllegalArgumentException e) {
208 LOG.warn("Failed to deserialize authentication", e);
209 removeAccessToken(token);
210 }
211
212 return authentication;
213 }
214
215 public void storeRefreshToken(OAuth2RefreshToken refreshToken, OAuth2Authentication authentication) {
216 jdbcTemplate.update(insertRefreshTokenSql, new Object[] { extractTokenKey(refreshToken.getValue()),
217 new SqlLobValue(serializeRefreshToken(refreshToken)),
218 new SqlLobValue(serializeAuthentication(authentication)) }, new int[] { Types.VARCHAR, Types.BLOB,
219 Types.BLOB });
220 }
221
222 public OAuth2RefreshToken readRefreshToken(String token) {
223 OAuth2RefreshToken refreshToken = null;
224
225 try {
226 refreshToken = jdbcTemplate.queryForObject(selectRefreshTokenSql, new RowMapper<OAuth2RefreshToken>() {
227 public OAuth2RefreshToken mapRow(ResultSet rs, int rowNum) throws SQLException {
228 return deserializeRefreshToken(rs.getBytes(2));
229 }
230 }, extractTokenKey(token));
231 }
232 catch (EmptyResultDataAccessException e) {
233 if (LOG.isInfoEnabled()) {
234 LOG.info("Failed to find refresh token");
235 }
236 }
237 catch (IllegalArgumentException e) {
238 LOG.warn("Failed to deserialize refresh token", e);
239 removeRefreshToken(token);
240 }
241
242 return refreshToken;
243 }
244
245 public void removeRefreshToken(OAuth2RefreshToken token) {
246 removeRefreshToken(token.getValue());
247 }
248
249 public void removeRefreshToken(String token) {
250 jdbcTemplate.update(deleteRefreshTokenSql, extractTokenKey(token));
251 }
252
253 public OAuth2Authentication readAuthenticationForRefreshToken(OAuth2RefreshToken token) {
254 return readAuthenticationForRefreshToken(token.getValue());
255 }
256
257 public OAuth2Authentication readAuthenticationForRefreshToken(String value) {
258 OAuth2Authentication authentication = null;
259
260 try {
261 authentication = jdbcTemplate.queryForObject(selectRefreshTokenAuthenticationSql,
262 new RowMapper<OAuth2Authentication>() {
263 public OAuth2Authentication mapRow(ResultSet rs, int rowNum) throws SQLException {
264 return deserializeAuthentication(rs.getBytes(2));
265 }
266 }, extractTokenKey(value));
267 }
268 catch (EmptyResultDataAccessException e) {
269 if (LOG.isInfoEnabled()) {
270 LOG.info("Failed to find access token");
271 }
272 }
273 catch (IllegalArgumentException e) {
274 LOG.warn("Failed to deserialize access token", e);
275 removeRefreshToken(value);
276 }
277
278 return authentication;
279 }
280
281 public void removeAccessTokenUsingRefreshToken(OAuth2RefreshToken refreshToken) {
282 removeAccessTokenUsingRefreshToken(refreshToken.getValue());
283 }
284
285 public void removeAccessTokenUsingRefreshToken(String refreshToken) {
286 jdbcTemplate.update(deleteAccessTokenFromRefreshTokenSql, new Object[] { extractTokenKey(refreshToken) },
287 new int[] { Types.VARCHAR });
288 }
289
290 public Collection<OAuth2AccessToken> findTokensByClientId(String clientId) {
291 List<OAuth2AccessToken> accessTokens = new ArrayList<OAuth2AccessToken>();
292
293 try {
294 accessTokens = jdbcTemplate.query(selectAccessTokensFromClientIdSql, new SafeAccessTokenRowMapper(),
295 clientId);
296 }
297 catch (EmptyResultDataAccessException e) {
298 if (LOG.isInfoEnabled()) {
299 LOG.info("Failed to find access token for clientId " + clientId);
300 }
301 }
302 accessTokens = removeNulls(accessTokens);
303
304 return accessTokens;
305 }
306
307 public Collection<OAuth2AccessToken> findTokensByUserName(String userName) {
308 List<OAuth2AccessToken> accessTokens = new ArrayList<OAuth2AccessToken>();
309
310 try {
311 accessTokens = jdbcTemplate.query(selectAccessTokensFromUserNameSql, new SafeAccessTokenRowMapper(),
312 userName);
313 }
314 catch (EmptyResultDataAccessException e) {
315 if (LOG.isInfoEnabled())
316 LOG.info("Failed to find access token for userName " + userName);
317 }
318 accessTokens = removeNulls(accessTokens);
319
320 return accessTokens;
321 }
322
323 public Collection<OAuth2AccessToken> findTokensByClientIdAndUserName(String clientId, String userName) {
324 List<OAuth2AccessToken> accessTokens = new ArrayList<OAuth2AccessToken>();
325
326 try {
327 accessTokens = jdbcTemplate.query(selectAccessTokensFromUserNameAndClientIdSql, new SafeAccessTokenRowMapper(),
328 userName, clientId);
329 }
330 catch (EmptyResultDataAccessException e) {
331 if (LOG.isInfoEnabled()) {
332 LOG.info("Failed to find access token for clientId " + clientId + " and userName " + userName);
333 }
334 }
335 accessTokens = removeNulls(accessTokens);
336
337 return accessTokens;
338 }
339
340 private List<OAuth2AccessToken> removeNulls(List<OAuth2AccessToken> accessTokens) {
341 List<OAuth2AccessToken> tokens = new ArrayList<OAuth2AccessToken>();
342 for (OAuth2AccessToken token : accessTokens) {
343 if (token != null) {
344 tokens.add(token);
345 }
346 }
347 return tokens;
348 }
349
350 protected String extractTokenKey(String value) {
351 if (value == null) {
352 return null;
353 }
354 MessageDigest digest;
355 try {
356 digest = MessageDigest.getInstance("MD5");
357 }
358 catch (NoSuchAlgorithmException e) {
359 throw new IllegalStateException("MD5 algorithm not available. Fatal (should be in the JDK).");
360 }
361
362 try {
363 byte[] bytes = digest.digest(value.getBytes("UTF-8"));
364 return String.format("%032x", new BigInteger(1, bytes));
365 }
366 catch (UnsupportedEncodingException e) {
367 throw new IllegalStateException("UTF-8 encoding not available. Fatal (should be in the JDK).");
368 }
369 }
370
371 private final class SafeAccessTokenRowMapper implements RowMapper<OAuth2AccessToken> {
372 public OAuth2AccessToken mapRow(ResultSet rs, int rowNum) throws SQLException {
373 try {
374 return deserializeAccessToken(rs.getBytes(2));
375 }
376 catch (IllegalArgumentException e) {
377 String token = rs.getString(1);
378 jdbcTemplate.update(deleteAccessTokenSql, token);
379 return null;
380 }
381 }
382 }
383
384 protected byte[] serializeAccessToken(OAuth2AccessToken token) {
385 return SerializationUtils.serialize(token);
386 }
387
388 protected byte[] serializeRefreshToken(OAuth2RefreshToken token) {
389 return SerializationUtils.serialize(token);
390 }
391
392 protected byte[] serializeAuthentication(OAuth2Authentication authentication) {
393 return SerializationUtils.serialize(authentication);
394 }
395
396 protected OAuth2AccessToken deserializeAccessToken(byte[] token) {
397 return SerializationUtils.deserialize(token);
398 }
399
400 protected OAuth2RefreshToken deserializeRefreshToken(byte[] token) {
401 return SerializationUtils.deserialize(token);
402 }
403
404 protected OAuth2Authentication deserializeAuthentication(byte[] authentication) {
405 return SerializationUtils.deserialize(authentication);
406 }
407
408 public void setInsertAccessTokenSql(String insertAccessTokenSql) {
409 this.insertAccessTokenSql = insertAccessTokenSql;
410 }
411
412 public void setSelectAccessTokenSql(String selectAccessTokenSql) {
413 this.selectAccessTokenSql = selectAccessTokenSql;
414 }
415
416 public void setDeleteAccessTokenSql(String deleteAccessTokenSql) {
417 this.deleteAccessTokenSql = deleteAccessTokenSql;
418 }
419
420 public void setInsertRefreshTokenSql(String insertRefreshTokenSql) {
421 this.insertRefreshTokenSql = insertRefreshTokenSql;
422 }
423
424 public void setSelectRefreshTokenSql(String selectRefreshTokenSql) {
425 this.selectRefreshTokenSql = selectRefreshTokenSql;
426 }
427
428 public void setDeleteRefreshTokenSql(String deleteRefreshTokenSql) {
429 this.deleteRefreshTokenSql = deleteRefreshTokenSql;
430 }
431
432 public void setSelectAccessTokenAuthenticationSql(String selectAccessTokenAuthenticationSql) {
433 this.selectAccessTokenAuthenticationSql = selectAccessTokenAuthenticationSql;
434 }
435
436 public void setSelectRefreshTokenAuthenticationSql(String selectRefreshTokenAuthenticationSql) {
437 this.selectRefreshTokenAuthenticationSql = selectRefreshTokenAuthenticationSql;
438 }
439
440 public void setSelectAccessTokenFromAuthenticationSql(String selectAccessTokenFromAuthenticationSql) {
441 this.selectAccessTokenFromAuthenticationSql = selectAccessTokenFromAuthenticationSql;
442 }
443
444 public void setDeleteAccessTokenFromRefreshTokenSql(String deleteAccessTokenFromRefreshTokenSql) {
445 this.deleteAccessTokenFromRefreshTokenSql = deleteAccessTokenFromRefreshTokenSql;
446 }
447
448 public void setSelectAccessTokensFromUserNameSql(String selectAccessTokensFromUserNameSql) {
449 this.selectAccessTokensFromUserNameSql = selectAccessTokensFromUserNameSql;
450 }
451
452 public void setSelectAccessTokensFromUserNameAndClientIdSql(String selectAccessTokensFromUserNameAndClientIdSql) {
453 this.selectAccessTokensFromUserNameAndClientIdSql = selectAccessTokensFromUserNameAndClientIdSql;
454 }
455
456 public void setSelectAccessTokensFromClientIdSql(String selectAccessTokensFromClientIdSql) {
457 this.selectAccessTokensFromClientIdSql = selectAccessTokensFromClientIdSql;
458 }
459
460 }