/*
 * Copyright 2012-2019 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16  package org.springframework.security.oauth2.provider.token.store.jwk;
17  
18  import org.springframework.security.jwt.codec.Codecs;
19  import org.springframework.security.jwt.crypto.sign.EllipticCurveVerifier;
20  import org.springframework.security.jwt.crypto.sign.RsaVerifier;
21  import org.springframework.security.jwt.crypto.sign.SignatureVerifier;
22  
23  import java.io.IOException;
24  import java.io.InputStream;
25  import java.math.BigInteger;
26  import java.net.MalformedURLException;
27  import java.net.URL;
28  import java.security.KeyFactory;
29  import java.security.interfaces.RSAPublicKey;
30  import java.security.spec.RSAPublicKeySpec;
31  import java.util.ArrayList;
32  import java.util.Arrays;
33  import java.util.LinkedHashMap;
34  import java.util.List;
35  import java.util.Map;
36  import java.util.Set;
37  import java.util.concurrent.ConcurrentHashMap;
38  
39  /**
40   * A source for JSON Web Key(s) (JWK) that is solely responsible for fetching (and caching)
41   * the JWK Set (a set of JWKs) from the URL supplied to the constructor.
42   *
43   * @see JwkSetConverter
44   * @see JwkDefinition
45   * @see SignatureVerifier
46   * @see <a target="_blank" href="https://tools.ietf.org/html/rfc7517#page-10">JWK Set Format</a>
47   *
48   * @author Joe Grandja
49   * @author Michael Duergner
50   */
51  class JwkDefinitionSource {
52  	private final List<URL> jwkSetUrls;
53  	private final Map<String, JwkDefinitionHolder> jwkDefinitions = new ConcurrentHashMap<String, JwkDefinitionHolder>();
54  	private static final JwkSetConverter2/provider/token/store/jwk/JwkSetConverter.html#JwkSetConverter">JwkSetConverter jwkSetConverter = new JwkSetConverter();
55  
56  	/**
57  	 * Creates a new instance using the provided URL as the location for the JWK Set.
58  	 *
59  	 * @param jwkSetUrl the JWK Set URL
60  	 */
61  	JwkDefinitionSource(String jwkSetUrl) {
62  		this(Arrays.asList(jwkSetUrl));
63  	}
64  
65  	/**
66  	 * Creates a new instance using the provided URLs as the location for the JWK Sets.
67  	 *
68  	 * @param jwkSetUrls the JWK Set URLs
69  	 */
70  	JwkDefinitionSource(List<String> jwkSetUrls) {
71  		this.jwkSetUrls = new ArrayList<URL>();
72  		for(String jwkSetUrl : jwkSetUrls) {
73  			try {
74  				this.jwkSetUrls.add(new URL(jwkSetUrl));
75  			} catch (MalformedURLException ex) {
76  				throw new IllegalArgumentException("Invalid JWK Set URL: " + ex.getMessage(), ex);
77  			}
78  		}
79  	}
80  
81  	/**
82  	 * Returns the JWK definition matching the provided keyId (&quot;kid&quot;).
83  	 * If the JWK definition is not available in the internal cache then {@link #loadJwkDefinitions(URL)}
84  	 * will be called (to re-load the cache) and then followed-up with a second attempt to locate the JWK definition.
85  	 *
86  	 * @param keyId the Key ID (&quot;kid&quot;)
87  	 * @return the matching {@link JwkDefinition} or null if not found
88  	 */
89  	JwkDefinitionHolder getDefinitionLoadIfNecessary(String keyId) {
90  		JwkDefinitionHolder result = this.getDefinition(keyId);
91  		if (result != null) {
92  			return result;
93  		}
94  		synchronized (this.jwkDefinitions) {
95  			result = this.getDefinition(keyId);
96  			if (result != null) {
97  				return result;
98  			}
99  			Map<String, JwkDefinitionHolder> newJwkDefinitions = new LinkedHashMap<String, JwkDefinitionHolder>();
100 			for (URL jwkSetUrl : jwkSetUrls) {
101 				newJwkDefinitions.putAll(loadJwkDefinitions(jwkSetUrl));
102 			}
103 			this.jwkDefinitions.clear();
104 			this.jwkDefinitions.putAll(newJwkDefinitions);
105 			return this.getDefinition(keyId);
106 		}
107 	}
108 
109 	/**
110 	 * Returns the JWK definition matching the provided keyId (&quot;kid&quot;).
111 	 *
112 	 * @param keyId the Key ID (&quot;kid&quot;)
113 	 * @return the matching {@link JwkDefinition} or null if not found
114 	 */
115 	private JwkDefinitionHolder getDefinition(String keyId) {
116 		return this.jwkDefinitions.get(keyId);
117 	}
118 
119 	/**
120 	 * Fetches the JWK Set from the provided <code>URL</code> and
121 	 * returns a <code>Map</code> keyed by the JWK keyId (&quot;kid&quot;)
122 	 * and mapped to an association of the {@link JwkDefinition} and {@link SignatureVerifier}.
123 	 * Uses a {@link JwkSetConverter} to convert the JWK Set URL source to a set of {@link JwkDefinition}(s)
124 	 * followed by the instantiation of a {@link SignatureVerifier} which is associated to it's {@link JwkDefinition}.
125 	 *
126 	 * @param jwkSetUrl the JWK Set URL
127 	 * @return a <code>Map</code> keyed by the JWK keyId and mapped to an association of {@link JwkDefinition} and {@link SignatureVerifier}
128 	 * @see JwkSetConverter
129 	 */
130 	static Map<String, JwkDefinitionHolder> loadJwkDefinitions(URL jwkSetUrl) {
131 		InputStream jwkSetSource;
132 		try {
133 			jwkSetSource = jwkSetUrl.openStream();
134 		} catch (IOException ex) {
135 			throw new JwkException("An I/O error occurred while reading from the JWK Set source: " + ex.getMessage(), ex);
136 		}
137 
138 		Set<JwkDefinition> jwkDefinitionSet = jwkSetConverter.convert(jwkSetSource);
139 
140 		Map<String, JwkDefinitionHolder> jwkDefinitions = new LinkedHashMap<String, JwkDefinitionHolder>();
141 
142 		for (JwkDefinition jwkDefinition : jwkDefinitionSet) {
143 			if (JwkDefinition.KeyType.RSA.equals(jwkDefinition.getKeyType())) {
144 				jwkDefinitions.put(jwkDefinition.getKeyId(),
145 						new JwkDefinitionHolder(jwkDefinition, createRsaVerifier((RsaJwkDefinition) jwkDefinition)));
146 			} else if (JwkDefinition.KeyType.EC.equals(jwkDefinition.getKeyType())) {
147 				jwkDefinitions.put(jwkDefinition.getKeyId(),
148 						new JwkDefinitionHolder(jwkDefinition, createEcVerifier((EllipticCurveJwkDefinition) jwkDefinition)));
149 			}
150 		}
151 
152 		return jwkDefinitions;
153 	}
154 
155 	private static RsaVerifier createRsaVerifier(RsaJwkDefinition rsaDefinition) {
156 		RsaVerifier result;
157 		try {
158 			BigInteger modulus = new BigInteger(1, Codecs.b64UrlDecode(rsaDefinition.getModulus()));
159 			BigInteger exponent = new BigInteger(1, Codecs.b64UrlDecode(rsaDefinition.getExponent()));
160 
161 			RSAPublicKey rsaPublicKey = (RSAPublicKey) KeyFactory.getInstance("RSA")
162 					.generatePublic(new RSAPublicKeySpec(modulus, exponent));
163 
164 			if (rsaDefinition.getAlgorithm() != null) {
165 				result = new RsaVerifier(rsaPublicKey, rsaDefinition.getAlgorithm().standardName());
166 			} else {
167 				result = new RsaVerifier(rsaPublicKey);
168 			}
169 
170 		} catch (Exception ex) {
171 			throw new JwkException("An error occurred while creating a RSA Public Key Verifier for " +
172 					rsaDefinition.getKeyId() + " : " + ex.getMessage(), ex);
173 		}
174 		return result;
175 	}
176 
177 	private static EllipticCurveVerifier createEcVerifier(EllipticCurveJwkDefinition ecDefinition) {
178 		EllipticCurveVerifier result;
179 		try {
180 			BigInteger x = new BigInteger(1, Codecs.b64UrlDecode(ecDefinition.getX()));
181 			BigInteger y = new BigInteger(1, Codecs.b64UrlDecode(ecDefinition.getY()));
182 
183 			String algorithm = null;
184 			if (EllipticCurveJwkDefinition.NamedCurve.P256.value().equals(ecDefinition.getCurve())) {
185 				algorithm = JwkDefinition.CryptoAlgorithm.ES256.standardName();
186 			} else if (EllipticCurveJwkDefinition.NamedCurve.P384.value().equals(ecDefinition.getCurve())) {
187 				algorithm = JwkDefinition.CryptoAlgorithm.ES384.standardName();
188 			} else if (EllipticCurveJwkDefinition.NamedCurve.P521.value().equals(ecDefinition.getCurve())) {
189 				algorithm = JwkDefinition.CryptoAlgorithm.ES512.standardName();
190 			}
191 
192 			result = new EllipticCurveVerifier(x, y, ecDefinition.getCurve(), algorithm);
193 
194 		} catch (Exception ex) {
195 			throw new JwkException("An error occurred while creating an EC Public Key Verifier for " +
196 					ecDefinition.getKeyId() + " : " + ex.getMessage(), ex);
197 		}
198 		return result;
199 	}
200 
201 	static class JwkDefinitionHolder {
202 		private final JwkDefinition jwkDefinition;
203 		private final SignatureVerifier signatureVerifier;
204 
205 		private JwkDefinitionHolder(JwkDefinition jwkDefinition, SignatureVerifier signatureVerifier) {
206 			this.jwkDefinition = jwkDefinition;
207 			this.signatureVerifier = signatureVerifier;
208 		}
209 
210 		JwkDefinition getJwkDefinition() {
211 			return jwkDefinition;
212 		}
213 
214 		SignatureVerifier getSignatureVerifier() {
215 			return signatureVerifier;
216 		}
217 	}
218 }