View Javadoc
1   package org.springframework.security.oauth2.provider.token.store;
2   
3   import java.io.UnsupportedEncodingException;
4   import java.math.BigInteger;
5   import java.security.MessageDigest;
6   import java.security.NoSuchAlgorithmException;
7   import java.sql.ResultSet;
8   import java.sql.SQLException;
9   import java.sql.Types;
10  import java.util.ArrayList;
11  import java.util.Collection;
12  import java.util.List;
13  
14  import javax.sql.DataSource;
15  
16  import org.apache.commons.logging.Log;
17  import org.apache.commons.logging.LogFactory;
18  import org.springframework.dao.EmptyResultDataAccessException;
19  import org.springframework.jdbc.core.JdbcTemplate;
20  import org.springframework.jdbc.core.RowMapper;
21  import org.springframework.jdbc.core.support.SqlLobValue;
22  import org.springframework.security.oauth2.common.OAuth2AccessToken;
23  import org.springframework.security.oauth2.common.OAuth2RefreshToken;
24  import org.springframework.security.oauth2.common.util.SerializationUtils;
25  import org.springframework.security.oauth2.provider.OAuth2Authentication;
26  import org.springframework.security.oauth2.provider.token.AuthenticationKeyGenerator;
27  import org.springframework.security.oauth2.provider.token.DefaultAuthenticationKeyGenerator;
28  import org.springframework.security.oauth2.provider.token.TokenStore;
29  import org.springframework.util.Assert;
30  
31  /**
32   * Implementation of token services that stores tokens in a database.
33   *
34   * @author Ken Dombeck
35   * @author Luke Taylor
36   * @author Dave Syer
37   */
38  public class JdbcTokenStore implements TokenStore {
39  
40  	private static final Log LOG = LogFactory.getLog(JdbcTokenStore.class);
41  
42  	private static final String DEFAULT_ACCESS_TOKEN_INSERT_STATEMENT = "insert into oauth_access_token (token_id, token, authentication_id, user_name, client_id, authentication, refresh_token) values (?, ?, ?, ?, ?, ?, ?)";
43  
44  	private static final String DEFAULT_ACCESS_TOKEN_SELECT_STATEMENT = "select token_id, token from oauth_access_token where token_id = ?";
45  
46  	private static final String DEFAULT_ACCESS_TOKEN_AUTHENTICATION_SELECT_STATEMENT = "select token_id, authentication from oauth_access_token where token_id = ?";
47  
48  	private static final String DEFAULT_ACCESS_TOKEN_FROM_AUTHENTICATION_SELECT_STATEMENT = "select token_id, token from oauth_access_token where authentication_id = ?";
49  
50  	private static final String DEFAULT_ACCESS_TOKENS_FROM_USERNAME_AND_CLIENT_SELECT_STATEMENT = "select token_id, token from oauth_access_token where user_name = ? and client_id = ?";
51  
52  	private static final String DEFAULT_ACCESS_TOKENS_FROM_USERNAME_SELECT_STATEMENT = "select token_id, token from oauth_access_token where user_name = ?";
53  
54  	private static final String DEFAULT_ACCESS_TOKENS_FROM_CLIENTID_SELECT_STATEMENT = "select token_id, token from oauth_access_token where client_id = ?";
55  
56  	private static final String DEFAULT_ACCESS_TOKEN_DELETE_STATEMENT = "delete from oauth_access_token where token_id = ?";
57  
58  	private static final String DEFAULT_ACCESS_TOKEN_DELETE_FROM_REFRESH_TOKEN_STATEMENT = "delete from oauth_access_token where refresh_token = ?";
59  
60  	private static final String DEFAULT_REFRESH_TOKEN_INSERT_STATEMENT = "insert into oauth_refresh_token (token_id, token, authentication) values (?, ?, ?)";
61  
62  	private static final String DEFAULT_REFRESH_TOKEN_SELECT_STATEMENT = "select token_id, token from oauth_refresh_token where token_id = ?";
63  
64  	private static final String DEFAULT_REFRESH_TOKEN_AUTHENTICATION_SELECT_STATEMENT = "select token_id, authentication from oauth_refresh_token where token_id = ?";
65  
66  	private static final String DEFAULT_REFRESH_TOKEN_DELETE_STATEMENT = "delete from oauth_refresh_token where token_id = ?";
67  
68  	private String insertAccessTokenSql = DEFAULT_ACCESS_TOKEN_INSERT_STATEMENT;
69  
70  	private String selectAccessTokenSql = DEFAULT_ACCESS_TOKEN_SELECT_STATEMENT;
71  
72  	private String selectAccessTokenAuthenticationSql = DEFAULT_ACCESS_TOKEN_AUTHENTICATION_SELECT_STATEMENT;
73  
74  	private String selectAccessTokenFromAuthenticationSql = DEFAULT_ACCESS_TOKEN_FROM_AUTHENTICATION_SELECT_STATEMENT;
75  
76  	private String selectAccessTokensFromUserNameAndClientIdSql = DEFAULT_ACCESS_TOKENS_FROM_USERNAME_AND_CLIENT_SELECT_STATEMENT;
77  
78  	private String selectAccessTokensFromUserNameSql = DEFAULT_ACCESS_TOKENS_FROM_USERNAME_SELECT_STATEMENT;
79  
80  	private String selectAccessTokensFromClientIdSql = DEFAULT_ACCESS_TOKENS_FROM_CLIENTID_SELECT_STATEMENT;
81  
82  	private String deleteAccessTokenSql = DEFAULT_ACCESS_TOKEN_DELETE_STATEMENT;
83  
84  	private String insertRefreshTokenSql = DEFAULT_REFRESH_TOKEN_INSERT_STATEMENT;
85  
86  	private String selectRefreshTokenSql = DEFAULT_REFRESH_TOKEN_SELECT_STATEMENT;
87  
88  	private String selectRefreshTokenAuthenticationSql = DEFAULT_REFRESH_TOKEN_AUTHENTICATION_SELECT_STATEMENT;
89  
90  	private String deleteRefreshTokenSql = DEFAULT_REFRESH_TOKEN_DELETE_STATEMENT;
91  
92  	private String deleteAccessTokenFromRefreshTokenSql = DEFAULT_ACCESS_TOKEN_DELETE_FROM_REFRESH_TOKEN_STATEMENT;
93  
94  	private AuthenticationKeyGenerator authenticationKeyGenerator = new DefaultAuthenticationKeyGenerator();
95  
96  	private final JdbcTemplate jdbcTemplate;
97  
98  	public JdbcTokenStore(DataSource dataSource) {
99  		Assert.notNull(dataSource, "DataSource required");
100 		this.jdbcTemplate = new JdbcTemplate(dataSource);
101 	}
102 
103 	public void setAuthenticationKeyGenerator(AuthenticationKeyGenerator authenticationKeyGenerator) {
104 		this.authenticationKeyGenerator = authenticationKeyGenerator;
105 	}
106 
107 	public OAuth2AccessToken getAccessToken(OAuth2Authentication authentication) {
108 		OAuth2AccessToken accessToken = null;
109 
110 		String key = authenticationKeyGenerator.extractKey(authentication);
111 		try {
112 			accessToken = jdbcTemplate.queryForObject(selectAccessTokenFromAuthenticationSql,
113 					new RowMapper<OAuth2AccessToken>() {
114 						public OAuth2AccessToken mapRow(ResultSet rs, int rowNum) throws SQLException {
115 							return deserializeAccessToken(rs.getBytes(2));
116 						}
117 					}, key);
118 		}
119 		catch (EmptyResultDataAccessException e) {
120 			if (LOG.isDebugEnabled()) {
121 				LOG.debug("Failed to find access token for authentication " + authentication);
122 			}
123 		}
124 		catch (IllegalArgumentException e) {
125 			LOG.error("Could not extract access token for authentication " + authentication, e);
126 		}
127 
128 		if (accessToken != null
129 				&& !key.equals(authenticationKeyGenerator.extractKey(readAuthentication(accessToken.getValue())))) {
130 			removeAccessToken(accessToken.getValue());
131 			// Keep the store consistent (maybe the same user is represented by this authentication but the details have
132 			// changed)
133 			storeAccessToken(accessToken, authentication);
134 		}
135 		return accessToken;
136 	}
137 
138 	public void storeAccessToken(OAuth2AccessToken token, OAuth2Authentication authentication) {
139 		String refreshToken = null;
140 		if (token.getRefreshToken() != null) {
141 			refreshToken = token.getRefreshToken().getValue();
142 		}
143 		
144 		if (readAccessToken(token.getValue())!=null) {
145 			removeAccessToken(token.getValue());
146 		}
147 
148 		jdbcTemplate.update(insertAccessTokenSql, new Object[] { extractTokenKey(token.getValue()),
149 				new SqlLobValue(serializeAccessToken(token)), authenticationKeyGenerator.extractKey(authentication),
150 				authentication.isClientOnly() ? null : authentication.getName(),
151 				authentication.getOAuth2Request().getClientId(),
152 				new SqlLobValue(serializeAuthentication(authentication)), extractTokenKey(refreshToken) }, new int[] {
153 				Types.VARCHAR, Types.BLOB, Types.VARCHAR, Types.VARCHAR, Types.VARCHAR, Types.BLOB, Types.VARCHAR });
154 	}
155 
156 	public OAuth2AccessToken readAccessToken(String tokenValue) {
157 		OAuth2AccessToken accessToken = null;
158 
159 		try {
160 			accessToken = jdbcTemplate.queryForObject(selectAccessTokenSql, new RowMapper<OAuth2AccessToken>() {
161 				public OAuth2AccessToken mapRow(ResultSet rs, int rowNum) throws SQLException {
162 					return deserializeAccessToken(rs.getBytes(2));
163 				}
164 			}, extractTokenKey(tokenValue));
165 		}
166 		catch (EmptyResultDataAccessException e) {
167 			if (LOG.isInfoEnabled()) {
168 				LOG.info("Failed to find access token");
169 			}
170 		}
171 		catch (IllegalArgumentException e) {
172 			LOG.warn("Failed to deserialize access token", e);
173 			removeAccessToken(tokenValue);
174 		}
175 
176 		return accessToken;
177 	}
178 
179 	public void removeAccessToken(OAuth2AccessToken token) {
180 		removeAccessToken(token.getValue());
181 	}
182 
183 	public void removeAccessToken(String tokenValue) {
184 		jdbcTemplate.update(deleteAccessTokenSql, extractTokenKey(tokenValue));
185 	}
186 
187 	public OAuth2Authentication readAuthentication(OAuth2AccessToken token) {
188 		return readAuthentication(token.getValue());
189 	}
190 
191 	public OAuth2Authentication readAuthentication(String token) {
192 		OAuth2Authentication authentication = null;
193 
194 		try {
195 			authentication = jdbcTemplate.queryForObject(selectAccessTokenAuthenticationSql,
196 					new RowMapper<OAuth2Authentication>() {
197 						public OAuth2Authentication mapRow(ResultSet rs, int rowNum) throws SQLException {
198 							return deserializeAuthentication(rs.getBytes(2));
199 						}
200 					}, extractTokenKey(token));
201 		}
202 		catch (EmptyResultDataAccessException e) {
203 			if (LOG.isInfoEnabled()) {
204 				LOG.info("Failed to find access token");
205 			}
206 		}
207 		catch (IllegalArgumentException e) {
208 			LOG.warn("Failed to deserialize authentication", e);
209 			removeAccessToken(token);
210 		}
211 
212 		return authentication;
213 	}
214 
215 	public void storeRefreshToken(OAuth2RefreshToken refreshToken, OAuth2Authentication authentication) {
216 		jdbcTemplate.update(insertRefreshTokenSql, new Object[] { extractTokenKey(refreshToken.getValue()),
217 				new SqlLobValue(serializeRefreshToken(refreshToken)),
218 				new SqlLobValue(serializeAuthentication(authentication)) }, new int[] { Types.VARCHAR, Types.BLOB,
219 				Types.BLOB });
220 	}
221 
222 	public OAuth2RefreshToken readRefreshToken(String token) {
223 		OAuth2RefreshToken refreshToken = null;
224 
225 		try {
226 			refreshToken = jdbcTemplate.queryForObject(selectRefreshTokenSql, new RowMapper<OAuth2RefreshToken>() {
227 				public OAuth2RefreshToken mapRow(ResultSet rs, int rowNum) throws SQLException {
228 					return deserializeRefreshToken(rs.getBytes(2));
229 				}
230 			}, extractTokenKey(token));
231 		}
232 		catch (EmptyResultDataAccessException e) {
233 			if (LOG.isInfoEnabled()) {
234 				LOG.info("Failed to find refresh token");
235 			}
236 		}
237 		catch (IllegalArgumentException e) {
238 			LOG.warn("Failed to deserialize refresh token", e);
239 			removeRefreshToken(token);
240 		}
241 
242 		return refreshToken;
243 	}
244 
245 	public void removeRefreshToken(OAuth2RefreshToken token) {
246 		removeRefreshToken(token.getValue());
247 	}
248 
249 	public void removeRefreshToken(String token) {
250 		jdbcTemplate.update(deleteRefreshTokenSql, extractTokenKey(token));
251 	}
252 
253 	public OAuth2Authentication readAuthenticationForRefreshToken(OAuth2RefreshToken token) {
254 		return readAuthenticationForRefreshToken(token.getValue());
255 	}
256 
257 	public OAuth2Authentication readAuthenticationForRefreshToken(String value) {
258 		OAuth2Authentication authentication = null;
259 
260 		try {
261 			authentication = jdbcTemplate.queryForObject(selectRefreshTokenAuthenticationSql,
262 					new RowMapper<OAuth2Authentication>() {
263 						public OAuth2Authentication mapRow(ResultSet rs, int rowNum) throws SQLException {
264 							return deserializeAuthentication(rs.getBytes(2));
265 						}
266 					}, extractTokenKey(value));
267 		}
268 		catch (EmptyResultDataAccessException e) {
269 			if (LOG.isInfoEnabled()) {
270 				LOG.info("Failed to find access token");
271 			}
272 		}
273 		catch (IllegalArgumentException e) {
274 			LOG.warn("Failed to deserialize access token", e);
275 			removeRefreshToken(value);
276 		}
277 
278 		return authentication;
279 	}
280 
281 	public void removeAccessTokenUsingRefreshToken(OAuth2RefreshToken refreshToken) {
282 		removeAccessTokenUsingRefreshToken(refreshToken.getValue());
283 	}
284 
285 	public void removeAccessTokenUsingRefreshToken(String refreshToken) {
286 		jdbcTemplate.update(deleteAccessTokenFromRefreshTokenSql, new Object[] { extractTokenKey(refreshToken) },
287 				new int[] { Types.VARCHAR });
288 	}
289 
290 	public Collection<OAuth2AccessToken> findTokensByClientId(String clientId) {
291 		List<OAuth2AccessToken> accessTokens = new ArrayList<OAuth2AccessToken>();
292 
293 		try {
294 			accessTokens = jdbcTemplate.query(selectAccessTokensFromClientIdSql, new SafeAccessTokenRowMapper(),
295 					clientId);
296 		}
297 		catch (EmptyResultDataAccessException e) {
298 			if (LOG.isInfoEnabled()) {
299 				LOG.info("Failed to find access token for clientId " + clientId);
300 			}
301 		}
302 		accessTokens = removeNulls(accessTokens);
303 
304 		return accessTokens;
305 	}
306 
307 	public Collection<OAuth2AccessToken> findTokensByUserName(String userName) {
308 		List<OAuth2AccessToken> accessTokens = new ArrayList<OAuth2AccessToken>();
309 
310 		try {
311 			accessTokens = jdbcTemplate.query(selectAccessTokensFromUserNameSql, new SafeAccessTokenRowMapper(),
312 					userName);
313 		}
314 		catch (EmptyResultDataAccessException e) {
315 			if (LOG.isInfoEnabled())
316 				LOG.info("Failed to find access token for userName " + userName);
317 		}
318 		accessTokens = removeNulls(accessTokens);
319 
320 		return accessTokens;
321 	}
322 
323 	public Collection<OAuth2AccessToken> findTokensByClientIdAndUserName(String clientId, String userName) {
324 		List<OAuth2AccessToken> accessTokens = new ArrayList<OAuth2AccessToken>();
325 
326 		try {
327 			accessTokens = jdbcTemplate.query(selectAccessTokensFromUserNameAndClientIdSql, new SafeAccessTokenRowMapper(),
328 					userName, clientId);
329 		}
330 		catch (EmptyResultDataAccessException e) {
331 			if (LOG.isInfoEnabled()) {
332 				LOG.info("Failed to find access token for clientId " + clientId + " and userName " + userName);
333 			}
334 		}
335 		accessTokens = removeNulls(accessTokens);
336 
337 		return accessTokens;
338 	}
339 
340 	private List<OAuth2AccessToken> removeNulls(List<OAuth2AccessToken> accessTokens) {
341 		List<OAuth2AccessToken> tokens = new ArrayList<OAuth2AccessToken>();
342 		for (OAuth2AccessToken token : accessTokens) {
343 			if (token != null) {
344 				tokens.add(token);
345 			}
346 		}
347 		return tokens;
348 	}
349 
350 	protected String extractTokenKey(String value) {
351 		if (value == null) {
352 			return null;
353 		}
354 		MessageDigest digest;
355 		try {
356 			digest = MessageDigest.getInstance("MD5");
357 		}
358 		catch (NoSuchAlgorithmException e) {
359 			throw new IllegalStateException("MD5 algorithm not available.  Fatal (should be in the JDK).");
360 		}
361 
362 		try {
363 			byte[] bytes = digest.digest(value.getBytes("UTF-8"));
364 			return String.format("%032x", new BigInteger(1, bytes));
365 		}
366 		catch (UnsupportedEncodingException e) {
367 			throw new IllegalStateException("UTF-8 encoding not available.  Fatal (should be in the JDK).");
368 		}
369 	}
370 
371 	private final class SafeAccessTokenRowMapper implements RowMapper<OAuth2AccessToken> {
372 		public OAuth2AccessToken mapRow(ResultSet rs, int rowNum) throws SQLException {
373 			try {
374 				return deserializeAccessToken(rs.getBytes(2));
375 			}
376 			catch (IllegalArgumentException e) {
377 				String token = rs.getString(1);
378 				jdbcTemplate.update(deleteAccessTokenSql, token);
379 				return null;
380 			}
381 		}
382 	}
383 
384 	protected byte[] serializeAccessToken(OAuth2AccessToken token) {
385 		return SerializationUtils.serialize(token);
386 	}
387 
388 	protected byte[] serializeRefreshToken(OAuth2RefreshToken token) {
389 		return SerializationUtils.serialize(token);
390 	}
391 
392 	protected byte[] serializeAuthentication(OAuth2Authentication authentication) {
393 		return SerializationUtils.serialize(authentication);
394 	}
395 
396 	protected OAuth2AccessToken deserializeAccessToken(byte[] token) {
397 		return SerializationUtils.deserialize(token);
398 	}
399 
400 	protected OAuth2RefreshToken deserializeRefreshToken(byte[] token) {
401 		return SerializationUtils.deserialize(token);
402 	}
403 
404 	protected OAuth2Authentication deserializeAuthentication(byte[] authentication) {
405 		return SerializationUtils.deserialize(authentication);
406 	}
407 
408 	public void setInsertAccessTokenSql(String insertAccessTokenSql) {
409 		this.insertAccessTokenSql = insertAccessTokenSql;
410 	}
411 
412 	public void setSelectAccessTokenSql(String selectAccessTokenSql) {
413 		this.selectAccessTokenSql = selectAccessTokenSql;
414 	}
415 
416 	public void setDeleteAccessTokenSql(String deleteAccessTokenSql) {
417 		this.deleteAccessTokenSql = deleteAccessTokenSql;
418 	}
419 
420 	public void setInsertRefreshTokenSql(String insertRefreshTokenSql) {
421 		this.insertRefreshTokenSql = insertRefreshTokenSql;
422 	}
423 
424 	public void setSelectRefreshTokenSql(String selectRefreshTokenSql) {
425 		this.selectRefreshTokenSql = selectRefreshTokenSql;
426 	}
427 
428 	public void setDeleteRefreshTokenSql(String deleteRefreshTokenSql) {
429 		this.deleteRefreshTokenSql = deleteRefreshTokenSql;
430 	}
431 
432 	public void setSelectAccessTokenAuthenticationSql(String selectAccessTokenAuthenticationSql) {
433 		this.selectAccessTokenAuthenticationSql = selectAccessTokenAuthenticationSql;
434 	}
435 
436 	public void setSelectRefreshTokenAuthenticationSql(String selectRefreshTokenAuthenticationSql) {
437 		this.selectRefreshTokenAuthenticationSql = selectRefreshTokenAuthenticationSql;
438 	}
439 
440 	public void setSelectAccessTokenFromAuthenticationSql(String selectAccessTokenFromAuthenticationSql) {
441 		this.selectAccessTokenFromAuthenticationSql = selectAccessTokenFromAuthenticationSql;
442 	}
443 
444 	public void setDeleteAccessTokenFromRefreshTokenSql(String deleteAccessTokenFromRefreshTokenSql) {
445 		this.deleteAccessTokenFromRefreshTokenSql = deleteAccessTokenFromRefreshTokenSql;
446 	}
447 
448 	public void setSelectAccessTokensFromUserNameSql(String selectAccessTokensFromUserNameSql) {
449 		this.selectAccessTokensFromUserNameSql = selectAccessTokensFromUserNameSql;
450 	}
451 
452 	public void setSelectAccessTokensFromUserNameAndClientIdSql(String selectAccessTokensFromUserNameAndClientIdSql) {
453 		this.selectAccessTokensFromUserNameAndClientIdSql = selectAccessTokensFromUserNameAndClientIdSql;
454 	}
455 
456 	public void setSelectAccessTokensFromClientIdSql(String selectAccessTokensFromClientIdSql) {
457 		this.selectAccessTokensFromClientIdSql = selectAccessTokensFromClientIdSql;
458 	}
459 
460 }