1 package org.springframework.security.oauth2.client.token;
2
3 import java.sql.ResultSet;
4 import java.sql.SQLException;
5 import java.sql.Types;
6
7 import javax.sql.DataSource;
8
9 import org.apache.commons.logging.Log;
10 import org.apache.commons.logging.LogFactory;
11 import org.springframework.dao.EmptyResultDataAccessException;
12 import org.springframework.jdbc.core.JdbcTemplate;
13 import org.springframework.jdbc.core.RowMapper;
14 import org.springframework.jdbc.core.support.SqlLobValue;
15 import org.springframework.security.core.Authentication;
16 import org.springframework.security.oauth2.client.resource.OAuth2ProtectedResourceDetails;
17 import org.springframework.security.oauth2.common.OAuth2AccessToken;
18 import org.springframework.security.oauth2.common.util.SerializationUtils;
19 import org.springframework.util.Assert;
20
21
22
23
24
25
26 public class JdbcClientTokenServices implements ClientTokenServices {
27
28 private static final Log LOG = LogFactory.getLog(JdbcClientTokenServices.class);
29
30 private static final String DEFAULT_ACCESS_TOKEN_INSERT_STATEMENT = "insert into oauth_client_token (token_id, token, authentication_id, user_name, client_id) values (?, ?, ?, ?, ?)";
31
32 private static final String DEFAULT_ACCESS_TOKEN_FROM_AUTHENTICATION_SELECT_STATEMENT = "select token_id, token from oauth_client_token where authentication_id = ?";
33
34 private static final String DEFAULT_ACCESS_TOKEN_DELETE_STATEMENT = "delete from oauth_client_token where authentication_id = ?";
35
36 private String insertAccessTokenSql = DEFAULT_ACCESS_TOKEN_INSERT_STATEMENT;
37
38 private String selectAccessTokenSql = DEFAULT_ACCESS_TOKEN_FROM_AUTHENTICATION_SELECT_STATEMENT;
39
40 private String deleteAccessTokenSql = DEFAULT_ACCESS_TOKEN_DELETE_STATEMENT;
41
42 private ClientKeyGenerator keyGenerator = new DefaultClientKeyGenerator();
43
44 private final JdbcTemplate jdbcTemplate;
45
46 public JdbcClientTokenServices(DataSource dataSource) {
47 Assert.notNull(dataSource, "DataSource required");
48 this.jdbcTemplate = new JdbcTemplate(dataSource);
49 }
50
51 public void setClientKeyGenerator(ClientKeyGenerator keyGenerator) {
52 this.keyGenerator = keyGenerator;
53 }
54
55 public OAuth2AccessToken getAccessToken(OAuth2ProtectedResourceDetails resource, Authentication authentication) {
56
57 OAuth2AccessToken accessToken = null;
58
59 try {
60 accessToken = jdbcTemplate.queryForObject(selectAccessTokenSql, new RowMapper<OAuth2AccessToken>() {
61 public OAuth2AccessToken mapRow(ResultSet rs, int rowNum) throws SQLException {
62 return SerializationUtils.deserialize(rs.getBytes(2));
63 }
64 }, keyGenerator.extractKey(resource, authentication));
65 }
66 catch (EmptyResultDataAccessException e) {
67 if (LOG.isInfoEnabled()) {
68 LOG.debug("Failed to find access token for authentication " + authentication);
69 }
70 }
71
72 return accessToken;
73 }
74
75 public void saveAccessToken(OAuth2ProtectedResourceDetails resource, Authentication authentication,
76 OAuth2AccessToken accessToken) {
77 removeAccessToken(resource, authentication);
78 String name = authentication==null ? null : authentication.getName();
79 jdbcTemplate.update(
80 insertAccessTokenSql,
81 new Object[] { accessToken.getValue(), new SqlLobValue(SerializationUtils.serialize(accessToken)),
82 keyGenerator.extractKey(resource, authentication), name,
83 resource.getClientId() }, new int[] { Types.VARCHAR, Types.BLOB, Types.VARCHAR, Types.VARCHAR,
84 Types.VARCHAR });
85 }
86
87 public void removeAccessToken(OAuth2ProtectedResourceDetails resource, Authentication authentication) {
88 jdbcTemplate.update(deleteAccessTokenSql, keyGenerator.extractKey(resource, authentication));
89 }
90
91 public void setInsertAccessTokenSql(String insertAccessTokenSql) {
92 this.insertAccessTokenSql = insertAccessTokenSql;
93 }
94
95 public void setSelectAccessTokenSql(String selectAccessTokenSql) {
96 this.selectAccessTokenSql = selectAccessTokenSql;
97 }
98
99 public void setDeleteAccessTokenSql(String deleteAccessTokenSql) {
100 this.deleteAccessTokenSql = deleteAccessTokenSql;
101 }
102
103 }