1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package org.springframework.security.oauth2.provider.token.store.jwk;
17
18 import org.springframework.security.jwt.codec.Codecs;
19 import org.springframework.security.jwt.crypto.sign.EllipticCurveVerifier;
20 import org.springframework.security.jwt.crypto.sign.RsaVerifier;
21 import org.springframework.security.jwt.crypto.sign.SignatureVerifier;
22
23 import java.io.IOException;
24 import java.io.InputStream;
25 import java.math.BigInteger;
26 import java.net.MalformedURLException;
27 import java.net.URL;
28 import java.security.KeyFactory;
29 import java.security.interfaces.RSAPublicKey;
30 import java.security.spec.RSAPublicKeySpec;
31 import java.util.ArrayList;
32 import java.util.Arrays;
33 import java.util.LinkedHashMap;
34 import java.util.List;
35 import java.util.Map;
36 import java.util.Set;
37 import java.util.concurrent.ConcurrentHashMap;
38
39
40
41
42
43
44
45
46
47
48
49
50
51 class JwkDefinitionSource {
52 private final List<URL> jwkSetUrls;
53 private final Map<String, JwkDefinitionHolder> jwkDefinitions = new ConcurrentHashMap<String, JwkDefinitionHolder>();
54 private static final JwkSetConverter2/provider/token/store/jwk/JwkSetConverter.html#JwkSetConverter">JwkSetConverter jwkSetConverter = new JwkSetConverter();
55
56
57
58
59
60
61 JwkDefinitionSource(String jwkSetUrl) {
62 this(Arrays.asList(jwkSetUrl));
63 }
64
65
66
67
68
69
70 JwkDefinitionSource(List<String> jwkSetUrls) {
71 this.jwkSetUrls = new ArrayList<URL>();
72 for(String jwkSetUrl : jwkSetUrls) {
73 try {
74 this.jwkSetUrls.add(new URL(jwkSetUrl));
75 } catch (MalformedURLException ex) {
76 throw new IllegalArgumentException("Invalid JWK Set URL: " + ex.getMessage(), ex);
77 }
78 }
79 }
80
81
82
83
84
85
86
87
88
89 JwkDefinitionHolder getDefinitionLoadIfNecessary(String keyId) {
90 JwkDefinitionHolder result = this.getDefinition(keyId);
91 if (result != null) {
92 return result;
93 }
94 synchronized (this.jwkDefinitions) {
95 result = this.getDefinition(keyId);
96 if (result != null) {
97 return result;
98 }
99 Map<String, JwkDefinitionHolder> newJwkDefinitions = new LinkedHashMap<String, JwkDefinitionHolder>();
100 for (URL jwkSetUrl : jwkSetUrls) {
101 newJwkDefinitions.putAll(loadJwkDefinitions(jwkSetUrl));
102 }
103 this.jwkDefinitions.clear();
104 this.jwkDefinitions.putAll(newJwkDefinitions);
105 return this.getDefinition(keyId);
106 }
107 }
108
109
110
111
112
113
114
115 private JwkDefinitionHolder getDefinition(String keyId) {
116 return this.jwkDefinitions.get(keyId);
117 }
118
119
120
121
122
123
124
125
126
127
128
129
130 static Map<String, JwkDefinitionHolder> loadJwkDefinitions(URL jwkSetUrl) {
131 InputStream jwkSetSource;
132 try {
133 jwkSetSource = jwkSetUrl.openStream();
134 } catch (IOException ex) {
135 throw new JwkException("An I/O error occurred while reading from the JWK Set source: " + ex.getMessage(), ex);
136 }
137
138 Set<JwkDefinition> jwkDefinitionSet = jwkSetConverter.convert(jwkSetSource);
139
140 Map<String, JwkDefinitionHolder> jwkDefinitions = new LinkedHashMap<String, JwkDefinitionHolder>();
141
142 for (JwkDefinition jwkDefinition : jwkDefinitionSet) {
143 if (JwkDefinition.KeyType.RSA.equals(jwkDefinition.getKeyType())) {
144 jwkDefinitions.put(jwkDefinition.getKeyId(),
145 new JwkDefinitionHolder(jwkDefinition, createRsaVerifier((RsaJwkDefinition) jwkDefinition)));
146 } else if (JwkDefinition.KeyType.EC.equals(jwkDefinition.getKeyType())) {
147 jwkDefinitions.put(jwkDefinition.getKeyId(),
148 new JwkDefinitionHolder(jwkDefinition, createEcVerifier((EllipticCurveJwkDefinition) jwkDefinition)));
149 }
150 }
151
152 return jwkDefinitions;
153 }
154
155 private static RsaVerifier createRsaVerifier(RsaJwkDefinition rsaDefinition) {
156 RsaVerifier result;
157 try {
158 BigInteger modulus = new BigInteger(1, Codecs.b64UrlDecode(rsaDefinition.getModulus()));
159 BigInteger exponent = new BigInteger(1, Codecs.b64UrlDecode(rsaDefinition.getExponent()));
160
161 RSAPublicKey rsaPublicKey = (RSAPublicKey) KeyFactory.getInstance("RSA")
162 .generatePublic(new RSAPublicKeySpec(modulus, exponent));
163
164 if (rsaDefinition.getAlgorithm() != null) {
165 result = new RsaVerifier(rsaPublicKey, rsaDefinition.getAlgorithm().standardName());
166 } else {
167 result = new RsaVerifier(rsaPublicKey);
168 }
169
170 } catch (Exception ex) {
171 throw new JwkException("An error occurred while creating a RSA Public Key Verifier for " +
172 rsaDefinition.getKeyId() + " : " + ex.getMessage(), ex);
173 }
174 return result;
175 }
176
177 private static EllipticCurveVerifier createEcVerifier(EllipticCurveJwkDefinition ecDefinition) {
178 EllipticCurveVerifier result;
179 try {
180 BigInteger x = new BigInteger(1, Codecs.b64UrlDecode(ecDefinition.getX()));
181 BigInteger y = new BigInteger(1, Codecs.b64UrlDecode(ecDefinition.getY()));
182
183 String algorithm = null;
184 if (EllipticCurveJwkDefinition.NamedCurve.P256.value().equals(ecDefinition.getCurve())) {
185 algorithm = JwkDefinition.CryptoAlgorithm.ES256.standardName();
186 } else if (EllipticCurveJwkDefinition.NamedCurve.P384.value().equals(ecDefinition.getCurve())) {
187 algorithm = JwkDefinition.CryptoAlgorithm.ES384.standardName();
188 } else if (EllipticCurveJwkDefinition.NamedCurve.P521.value().equals(ecDefinition.getCurve())) {
189 algorithm = JwkDefinition.CryptoAlgorithm.ES512.standardName();
190 }
191
192 result = new EllipticCurveVerifier(x, y, ecDefinition.getCurve(), algorithm);
193
194 } catch (Exception ex) {
195 throw new JwkException("An error occurred while creating an EC Public Key Verifier for " +
196 ecDefinition.getKeyId() + " : " + ex.getMessage(), ex);
197 }
198 return result;
199 }
200
201 static class JwkDefinitionHolder {
202 private final JwkDefinition jwkDefinition;
203 private final SignatureVerifier signatureVerifier;
204
205 private JwkDefinitionHolder(JwkDefinition jwkDefinition, SignatureVerifier signatureVerifier) {
206 this.jwkDefinition = jwkDefinition;
207 this.signatureVerifier = signatureVerifier;
208 }
209
210 JwkDefinition getJwkDefinition() {
211 return jwkDefinition;
212 }
213
214 SignatureVerifier getSignatureVerifier() {
215 return signatureVerifier;
216 }
217 }
218 }