1
2
3
4
5
6
7
8
9
10
11
12
13 package org.springframework.security.oauth2.provider.error;
14
15 import java.io.IOException;
16 import java.util.ArrayList;
17 import java.util.Collections;
18 import java.util.List;
19
20 import javax.servlet.http.HttpServletRequest;
21 import javax.servlet.http.HttpServletResponse;
22
23 import org.apache.commons.logging.Log;
24 import org.apache.commons.logging.LogFactory;
25 import org.springframework.http.HttpEntity;
26 import org.springframework.http.HttpHeaders;
27 import org.springframework.http.HttpInputMessage;
28 import org.springframework.http.HttpOutputMessage;
29 import org.springframework.http.MediaType;
30 import org.springframework.http.ResponseEntity;
31 import org.springframework.http.converter.HttpMessageConverter;
32 import org.springframework.http.server.ServerHttpResponse;
33 import org.springframework.http.server.ServletServerHttpRequest;
34 import org.springframework.http.server.ServletServerHttpResponse;
35 import org.springframework.security.oauth2.http.converter.jaxb.JaxbOAuth2ExceptionMessageConverter;
36 import org.springframework.web.HttpMediaTypeNotAcceptableException;
37 import org.springframework.web.client.RestTemplate;
38 import org.springframework.web.context.request.NativeWebRequest;
39 import org.springframework.web.context.request.ServletWebRequest;
40
41
42
43
44
45
46
47
48
49 public class DefaultOAuth2ExceptionRenderer implements OAuth2ExceptionRenderer {
50
51 private final Log logger = LogFactory.getLog(DefaultOAuth2ExceptionRenderer.class);
52
53 private List<HttpMessageConverter<?>> messageConverters = geDefaultMessageConverters();
54
55 public void setMessageConverters(List<HttpMessageConverter<?>> messageConverters) {
56 this.messageConverters = messageConverters;
57 }
58
59 public void handleHttpEntityResponse(HttpEntity<?> responseEntity, ServletWebRequest webRequest) throws Exception {
60 if (responseEntity == null) {
61 return;
62 }
63 HttpInputMessage inputMessage = createHttpInputMessage(webRequest);
64 HttpOutputMessage outputMessage = createHttpOutputMessage(webRequest);
65 if (responseEntity instanceof ResponseEntity && outputMessage instanceof ServerHttpResponse) {
66 ((ServerHttpResponse) outputMessage).setStatusCode(((ResponseEntity<?>) responseEntity).getStatusCode());
67 }
68 HttpHeaders entityHeaders = responseEntity.getHeaders();
69 if (!entityHeaders.isEmpty()) {
70 outputMessage.getHeaders().putAll(entityHeaders);
71 }
72 Object body = responseEntity.getBody();
73 if (body != null) {
74 writeWithMessageConverters(body, inputMessage, outputMessage);
75 }
76 else {
77
78 outputMessage.getBody();
79 }
80 }
81
82 @SuppressWarnings({ "unchecked", "rawtypes" })
83 private void writeWithMessageConverters(Object returnValue, HttpInputMessage inputMessage,
84 HttpOutputMessage outputMessage) throws IOException, HttpMediaTypeNotAcceptableException {
85 List<MediaType> acceptedMediaTypes = inputMessage.getHeaders().getAccept();
86 if (acceptedMediaTypes.isEmpty()) {
87 acceptedMediaTypes = Collections.singletonList(MediaType.ALL);
88 }
89 MediaType.sortByQualityValue(acceptedMediaTypes);
90 Class<?> returnValueType = returnValue.getClass();
91 List<MediaType> allSupportedMediaTypes = new ArrayList<MediaType>();
92 for (MediaType acceptedMediaType : acceptedMediaTypes) {
93 for (HttpMessageConverter messageConverter : messageConverters) {
94 if (messageConverter.canWrite(returnValueType, acceptedMediaType)) {
95 messageConverter.write(returnValue, acceptedMediaType, outputMessage);
96 if (logger.isDebugEnabled()) {
97 MediaType contentType = outputMessage.getHeaders().getContentType();
98 if (contentType == null) {
99 contentType = acceptedMediaType;
100 }
101 logger.debug("Written [" + returnValue + "] as \"" + contentType + "\" using ["
102 + messageConverter + "]");
103 }
104 return;
105 }
106 }
107 }
108 for (HttpMessageConverter messageConverter : messageConverters) {
109 allSupportedMediaTypes.addAll(messageConverter.getSupportedMediaTypes());
110 }
111 throw new HttpMediaTypeNotAcceptableException(allSupportedMediaTypes);
112 }
113
114 private List<HttpMessageConverter<?>> geDefaultMessageConverters() {
115 List<HttpMessageConverter<?>> result = new ArrayList<HttpMessageConverter<?>>();
116 result.addAll(new RestTemplate().getMessageConverters());
117 result.add(new JaxbOAuth2ExceptionMessageConverter());
118 return result;
119 }
120
121 private HttpInputMessage createHttpInputMessage(NativeWebRequest webRequest) throws Exception {
122 HttpServletRequest servletRequest = webRequest.getNativeRequest(HttpServletRequest.class);
123 return new ServletServerHttpRequest(servletRequest);
124 }
125
126 private HttpOutputMessage createHttpOutputMessage(NativeWebRequest webRequest) throws Exception {
127 HttpServletResponse servletResponse = (HttpServletResponse) webRequest.getNativeResponse();
128 return new ServletServerHttpResponse(servletResponse);
129 }
130
131 }