View Javadoc
1   /*
2    * Copyright 2002-2011 the original author or authors.
3    *
4    * Licensed under the Apache License, Version 2.0 (the "License");
5    * you may not use this file except in compliance with the License.
6    * You may obtain a copy of the License at
7    *
8    *      https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS,
12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13   * See the License for the specific language governing permissions and
14   * limitations under the License.
15   */
16  package org.springframework.security.oauth2.provider.error;
17  
18  import java.io.IOException;
19  
20  import org.springframework.http.HttpHeaders;
21  import org.springframework.http.HttpStatus;
22  import org.springframework.http.ResponseEntity;
23  import org.springframework.security.access.AccessDeniedException;
24  import org.springframework.security.core.AuthenticationException;
25  import org.springframework.security.oauth2.common.DefaultThrowableAnalyzer;
26  import org.springframework.security.oauth2.common.OAuth2AccessToken;
27  import org.springframework.security.oauth2.common.exceptions.InsufficientScopeException;
28  import org.springframework.security.oauth2.common.exceptions.OAuth2Exception;
29  import org.springframework.security.web.util.ThrowableAnalyzer;
30  import org.springframework.web.HttpRequestMethodNotSupportedException;
31  
32  /**
33   * Default translator that converts exceptions into {@link OAuth2Exception}s. The output matches the OAuth 2.0
34   * specification in terms of error response format and HTTP status code.
35   * 
36   * @author Dave Syer
37   * 
38   */
39  public class DefaultWebResponseExceptionTranslator implements WebResponseExceptionTranslator<OAuth2Exception> {
40  
41  	private ThrowableAnalyzer throwableAnalyzer = new DefaultThrowableAnalyzer();
42  
43  	@Override
44  	public ResponseEntity<OAuth2Exception> translate(Exception e) throws Exception {
45  
46  		// Try to extract a SpringSecurityException from the stacktrace
47  		Throwable[] causeChain = throwableAnalyzer.determineCauseChain(e);
48  		Exception ase = (OAuth2Exception) throwableAnalyzer.getFirstThrowableOfType(OAuth2Exception.class, causeChain);
49  
50  		if (ase != null) {
51  			return handleOAuth2Exception((OAuth2Exception) ase);
52  		}
53  
54  		ase = (AuthenticationException) throwableAnalyzer.getFirstThrowableOfType(AuthenticationException.class,
55  				causeChain);
56  		if (ase != null) {
57  			return handleOAuth2Exception(new UnauthorizedException(e.getMessage(), e));
58  		}
59  
60  		ase = (AccessDeniedException) throwableAnalyzer
61  				.getFirstThrowableOfType(AccessDeniedException.class, causeChain);
62  		if (ase instanceof AccessDeniedException) {
63  			return handleOAuth2Exception(new ForbiddenException(ase.getMessage(), ase));
64  		}
65  
66  		ase = (HttpRequestMethodNotSupportedException) throwableAnalyzer.getFirstThrowableOfType(
67  				HttpRequestMethodNotSupportedException.class, causeChain);
68  		if (ase instanceof HttpRequestMethodNotSupportedException) {
69  			return handleOAuth2Exception(new MethodNotAllowed(ase.getMessage(), ase));
70  		}
71  
72  		return handleOAuth2Exception(new ServerErrorException(HttpStatus.INTERNAL_SERVER_ERROR.getReasonPhrase(), e));
73  
74  	}
75  
76  	private ResponseEntity<OAuth2Exception> handleOAuth2Exception(OAuth2Exception e) throws IOException {
77  
78  		int status = e.getHttpErrorCode();
79  		HttpHeaders headers = new HttpHeaders();
80  		headers.set("Cache-Control", "no-store");
81  		headers.set("Pragma", "no-cache");
82  		if (status == HttpStatus.UNAUTHORIZED.value() || (e instanceof InsufficientScopeException)) {
83  			headers.set("WWW-Authenticate", String.format("%s %s", OAuth2AccessToken.BEARER_TYPE, e.getSummary()));
84  		}
85  
86  		ResponseEntity<OAuth2Exception> response = new ResponseEntity<OAuth2Exception>(e, headers,
87  				HttpStatus.valueOf(status));
88  
89  		return response;
90  
91  	}
92  
93  	public void setThrowableAnalyzer(ThrowableAnalyzer throwableAnalyzer) {
94  		this.throwableAnalyzer = throwableAnalyzer;
95  	}
96  
97  	@SuppressWarnings("serial")
98  	private static class ForbiddenException extends OAuth2Exception {
99  
100 		public ForbiddenException(String msg, Throwable t) {
101 			super(msg, t);
102 		}
103 
104 		@Override
105 		public String getOAuth2ErrorCode() {
106 			return "access_denied";
107 		}
108 
109 		@Override
110 		public int getHttpErrorCode() {
111 			return 403;
112 		}
113 
114 	}
115 
116 	@SuppressWarnings("serial")
117 	private static class ServerErrorException extends OAuth2Exception {
118 
119 		public ServerErrorException(String msg, Throwable t) {
120 			super(msg, t);
121 		}
122 
123 		@Override
124 		public String getOAuth2ErrorCode() {
125 			return "server_error";
126 		}
127 
128 		@Override
129 		public int getHttpErrorCode() {
130 			return 500;
131 		}
132 
133 	}
134 
135 	@SuppressWarnings("serial")
136 	private static class UnauthorizedException extends OAuth2Exception {
137 
138 		public UnauthorizedException(String msg, Throwable t) {
139 			super(msg, t);
140 		}
141 
142 		@Override
143 		public String getOAuth2ErrorCode() {
144 			return "unauthorized";
145 		}
146 
147 		@Override
148 		public int getHttpErrorCode() {
149 			return 401;
150 		}
151 
152 	}
153 
154 	@SuppressWarnings("serial")
155 	private static class MethodNotAllowed extends OAuth2Exception {
156 
157 		public MethodNotAllowed(String msg, Throwable t) {
158 			super(msg, t);
159 		}
160 
161 		@Override
162 		public String getOAuth2ErrorCode() {
163 			return "method_not_allowed";
164 		}
165 
166 		@Override
167 		public int getHttpErrorCode() {
168 			return 405;
169 		}
170 
171 	}
172 }