1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package org.springframework.security.oauth2.provider.error;
17
18 import java.io.IOException;
19
20 import org.springframework.http.HttpHeaders;
21 import org.springframework.http.HttpStatus;
22 import org.springframework.http.ResponseEntity;
23 import org.springframework.security.access.AccessDeniedException;
24 import org.springframework.security.core.AuthenticationException;
25 import org.springframework.security.oauth2.common.DefaultThrowableAnalyzer;
26 import org.springframework.security.oauth2.common.OAuth2AccessToken;
27 import org.springframework.security.oauth2.common.exceptions.InsufficientScopeException;
28 import org.springframework.security.oauth2.common.exceptions.OAuth2Exception;
29 import org.springframework.security.web.util.ThrowableAnalyzer;
30 import org.springframework.web.HttpRequestMethodNotSupportedException;
31
32
33
34
35
36
37
38
39 public class DefaultWebResponseExceptionTranslator implements WebResponseExceptionTranslator<OAuth2Exception> {
40
41 private ThrowableAnalyzer throwableAnalyzer = new DefaultThrowableAnalyzer();
42
43 @Override
44 public ResponseEntity<OAuth2Exception> translate(Exception e) throws Exception {
45
46
47 Throwable[] causeChain = throwableAnalyzer.determineCauseChain(e);
48 Exception ase = (OAuth2Exception) throwableAnalyzer.getFirstThrowableOfType(OAuth2Exception.class, causeChain);
49
50 if (ase != null) {
51 return handleOAuth2Exception((OAuth2Exception) ase);
52 }
53
54 ase = (AuthenticationException) throwableAnalyzer.getFirstThrowableOfType(AuthenticationException.class,
55 causeChain);
56 if (ase != null) {
57 return handleOAuth2Exception(new UnauthorizedException(e.getMessage(), e));
58 }
59
60 ase = (AccessDeniedException) throwableAnalyzer
61 .getFirstThrowableOfType(AccessDeniedException.class, causeChain);
62 if (ase instanceof AccessDeniedException) {
63 return handleOAuth2Exception(new ForbiddenException(ase.getMessage(), ase));
64 }
65
66 ase = (HttpRequestMethodNotSupportedException) throwableAnalyzer.getFirstThrowableOfType(
67 HttpRequestMethodNotSupportedException.class, causeChain);
68 if (ase instanceof HttpRequestMethodNotSupportedException) {
69 return handleOAuth2Exception(new MethodNotAllowed(ase.getMessage(), ase));
70 }
71
72 return handleOAuth2Exception(new ServerErrorException(HttpStatus.INTERNAL_SERVER_ERROR.getReasonPhrase(), e));
73
74 }
75
76 private ResponseEntity<OAuth2Exception> handleOAuth2Exception(OAuth2Exception e) throws IOException {
77
78 int status = e.getHttpErrorCode();
79 HttpHeaders headers = new HttpHeaders();
80 headers.set("Cache-Control", "no-store");
81 headers.set("Pragma", "no-cache");
82 if (status == HttpStatus.UNAUTHORIZED.value() || (e instanceof InsufficientScopeException)) {
83 headers.set("WWW-Authenticate", String.format("%s %s", OAuth2AccessToken.BEARER_TYPE, e.getSummary()));
84 }
85
86 ResponseEntity<OAuth2Exception> response = new ResponseEntity<OAuth2Exception>(e, headers,
87 HttpStatus.valueOf(status));
88
89 return response;
90
91 }
92
93 public void setThrowableAnalyzer(ThrowableAnalyzer throwableAnalyzer) {
94 this.throwableAnalyzer = throwableAnalyzer;
95 }
96
97 @SuppressWarnings("serial")
98 private static class ForbiddenException extends OAuth2Exception {
99
100 public ForbiddenException(String msg, Throwable t) {
101 super(msg, t);
102 }
103
104 @Override
105 public String getOAuth2ErrorCode() {
106 return "access_denied";
107 }
108
109 @Override
110 public int getHttpErrorCode() {
111 return 403;
112 }
113
114 }
115
116 @SuppressWarnings("serial")
117 private static class ServerErrorException extends OAuth2Exception {
118
119 public ServerErrorException(String msg, Throwable t) {
120 super(msg, t);
121 }
122
123 @Override
124 public String getOAuth2ErrorCode() {
125 return "server_error";
126 }
127
128 @Override
129 public int getHttpErrorCode() {
130 return 500;
131 }
132
133 }
134
135 @SuppressWarnings("serial")
136 private static class UnauthorizedException extends OAuth2Exception {
137
138 public UnauthorizedException(String msg, Throwable t) {
139 super(msg, t);
140 }
141
142 @Override
143 public String getOAuth2ErrorCode() {
144 return "unauthorized";
145 }
146
147 @Override
148 public int getHttpErrorCode() {
149 return 401;
150 }
151
152 }
153
154 @SuppressWarnings("serial")
155 private static class MethodNotAllowed extends OAuth2Exception {
156
157 public MethodNotAllowed(String msg, Throwable t) {
158 super(msg, t);
159 }
160
161 @Override
162 public String getOAuth2ErrorCode() {
163 return "method_not_allowed";
164 }
165
166 @Override
167 public int getHttpErrorCode() {
168 return 405;
169 }
170
171 }
172 }