1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.springframework.ws.server.endpoint.adapter.method.jaxb;
18
19 import java.io.IOException;
20 import java.io.InputStream;
21 import java.io.OutputStream;
22 import java.io.Reader;
23 import java.io.Writer;
24 import java.net.URL;
25 import java.util.concurrent.ConcurrentHashMap;
26 import java.util.concurrent.ConcurrentMap;
27 import javax.xml.bind.JAXBContext;
28 import javax.xml.bind.JAXBElement;
29 import javax.xml.bind.JAXBException;
30 import javax.xml.bind.JAXBIntrospector;
31 import javax.xml.bind.Marshaller;
32 import javax.xml.bind.Unmarshaller;
33 import javax.xml.namespace.QName;
34 import javax.xml.stream.XMLEventReader;
35 import javax.xml.stream.XMLEventWriter;
36 import javax.xml.stream.XMLStreamException;
37 import javax.xml.stream.XMLStreamReader;
38 import javax.xml.stream.XMLStreamWriter;
39 import javax.xml.transform.Result;
40 import javax.xml.transform.Source;
41 import javax.xml.transform.sax.SAXSource;
42 import javax.xml.transform.stream.StreamResult;
43 import javax.xml.transform.stream.StreamSource;
44
45 import org.springframework.util.Assert;
46 import org.springframework.ws.WebServiceMessage;
47 import org.springframework.ws.context.MessageContext;
48 import org.springframework.ws.server.endpoint.adapter.method.AbstractPayloadMethodProcessor;
49 import org.springframework.ws.stream.StreamingPayload;
50 import org.springframework.ws.stream.StreamingWebServiceMessage;
51 import org.springframework.xml.transform.TraxUtils;
52
53 import org.w3c.dom.Node;
54 import org.xml.sax.ContentHandler;
55 import org.xml.sax.InputSource;
56 import org.xml.sax.XMLReader;
57 import org.xml.sax.ext.LexicalHandler;
58
59
60
61
62
63
64
65
66
67
68
69
70 public abstract class AbstractJaxb2PayloadMethodProcessor extends AbstractPayloadMethodProcessor {
71
72 private final ConcurrentMap<Class, JAXBContext> jaxbContexts = new ConcurrentHashMap<Class, JAXBContext>();
73
74
75
76
77
78
79
80
81
82 protected final void marshalToResponsePayload(MessageContext messageContext, Class<?> clazz, Object jaxbElement)
83 throws JAXBException {
84 Assert.notNull(messageContext, "'messageContext' must not be null");
85 Assert.notNull(clazz, "'clazz' must not be null");
86 Assert.notNull(jaxbElement, "'jaxbElement' must not be null");
87 if (logger.isDebugEnabled()) {
88 logger.debug("Marshalling [" + jaxbElement + "] to response payload");
89 }
90 WebServiceMessage response = messageContext.getResponse();
91 if (response instanceof StreamingWebServiceMessage) {
92 StreamingWebServiceMessage streamingResponse = (StreamingWebServiceMessage) response;
93
94 StreamingPayload payload = new JaxbStreamingPayload(clazz, jaxbElement);
95 streamingResponse.setStreamingPayload(payload);
96 }
97 else {
98 Result responsePayload = response.getPayloadResult();
99 try {
100 Jaxb2ResultCallback callback = new Jaxb2ResultCallback(clazz, jaxbElement);
101 TraxUtils.doWithResult(responsePayload, callback);
102 }
103 catch (Exception ex) {
104 throw convertToJaxbException(ex);
105 }
106 }
107 }
108
109
110
111
112
113
114
115
116
117 protected final Object unmarshalFromRequestPayload(MessageContext messageContext, Class<?> clazz)
118 throws JAXBException {
119 Source requestPayload = getRequestPayload(messageContext);
120 if (requestPayload == null) {
121 return null;
122 }
123 try {
124 Jaxb2SourceCallback callback = new Jaxb2SourceCallback(clazz);
125 TraxUtils.doWithSource(requestPayload, callback);
126 if (logger.isDebugEnabled()) {
127 logger.debug("Unmarshalled payload request to [" + callback.result + "]");
128 }
129 return callback.result;
130 }
131 catch (Exception ex) {
132 throw convertToJaxbException(ex);
133 }
134 }
135
136
137
138
139
140
141
142
143
144 protected final <T> JAXBElement<T> unmarshalElementFromRequestPayload(MessageContext messageContext, Class<T> clazz)
145 throws JAXBException {
146 Source requestPayload = getRequestPayload(messageContext);
147 if (requestPayload == null) {
148 return null;
149 }
150 try {
151 JaxbElementSourceCallback<T> callback = new JaxbElementSourceCallback<T>(clazz);
152 TraxUtils.doWithSource(requestPayload, callback);
153 if (logger.isDebugEnabled()) {
154 logger.debug("Unmarshalled payload request to [" + callback.result + "]");
155 }
156 return callback.result;
157 }
158 catch (Exception ex) {
159 throw convertToJaxbException(ex);
160 }
161 }
162
163 private Source getRequestPayload(MessageContext messageContext) {
164 WebServiceMessage request = messageContext.getRequest();
165 return request != null ? request.getPayloadSource() : null;
166 }
167
168 private JAXBException convertToJaxbException(Exception ex) {
169 if (ex instanceof JAXBException) {
170 return (JAXBException) ex;
171 }
172 else {
173 return new JAXBException(ex);
174 }
175 }
176
177
178
179
180
181
182
183
184
185
186 protected Marshaller createMarshaller(JAXBContext jaxbContext) throws JAXBException {
187 return jaxbContext.createMarshaller();
188 }
189
190 private Marshaller createMarshaller(Class<?> clazz) throws JAXBException {
191 return createMarshaller(getJaxbContext(clazz));
192 }
193
194
195
196
197
198
199
200
201
202
203 protected Unmarshaller createUnmarshaller(JAXBContext jaxbContext) throws JAXBException {
204 return jaxbContext.createUnmarshaller();
205 }
206
207 private Unmarshaller createUnmarshaller(Class<?> clazz) throws JAXBException {
208 return createUnmarshaller(getJaxbContext(clazz));
209 }
210
211
212 private JAXBContext getJaxbContext(Class<?> clazz) throws JAXBException {
213 Assert.notNull(clazz, "'clazz' must not be null");
214 JAXBContext jaxbContext = jaxbContexts.get(clazz);
215 if (jaxbContext == null) {
216 jaxbContext = JAXBContext.newInstance(clazz);
217 jaxbContexts.putIfAbsent(clazz, jaxbContext);
218 }
219 return jaxbContext;
220 }
221
222
223
224 private class Jaxb2SourceCallback implements TraxUtils.SourceCallback {
225
226 private final Unmarshaller unmarshaller;
227
228 private Object result;
229
230 public Jaxb2SourceCallback(Class<?> clazz) throws JAXBException {
231 this.unmarshaller = createUnmarshaller(clazz);
232 }
233
234 public void domSource(Node node) throws JAXBException {
235 result = unmarshaller.unmarshal(node);
236 }
237
238 public void saxSource(XMLReader reader, InputSource inputSource) throws JAXBException {
239 result = unmarshaller.unmarshal(inputSource);
240 }
241
242 public void staxSource(XMLEventReader eventReader) throws JAXBException {
243 result = unmarshaller.unmarshal(eventReader);
244 }
245
246 public void staxSource(XMLStreamReader streamReader) throws JAXBException {
247 result = unmarshaller.unmarshal(streamReader);
248 }
249
250 public void streamSource(InputStream inputStream) throws IOException, JAXBException {
251 result = unmarshaller.unmarshal(inputStream);
252 }
253
254 public void streamSource(Reader reader) throws IOException, JAXBException {
255 result = unmarshaller.unmarshal(reader);
256 }
257
258 public void source(String systemId) throws Exception {
259 result = unmarshaller.unmarshal(new URL(systemId));
260 }
261 }
262
263 private class JaxbElementSourceCallback<T> implements TraxUtils.SourceCallback {
264
265 private final Unmarshaller unmarshaller;
266
267 private final Class<T> declaredType;
268
269 private JAXBElement<T> result;
270
271 public JaxbElementSourceCallback(Class<T> declaredType) throws JAXBException {
272 this.unmarshaller = createUnmarshaller(declaredType);
273 this.declaredType = declaredType;
274 }
275
276 public void domSource(Node node) throws JAXBException {
277 result = unmarshaller.unmarshal(node, declaredType);
278 }
279
280 public void saxSource(XMLReader reader, InputSource inputSource) throws JAXBException {
281 result = unmarshaller.unmarshal(new SAXSource(reader, inputSource), declaredType);
282 }
283
284 public void staxSource(XMLEventReader eventReader) throws JAXBException {
285 result = unmarshaller.unmarshal(eventReader, declaredType);
286 }
287
288 public void staxSource(XMLStreamReader streamReader) throws JAXBException {
289 result = unmarshaller.unmarshal(streamReader, declaredType);
290 }
291
292 public void streamSource(InputStream inputStream) throws IOException, JAXBException {
293 result = unmarshaller.unmarshal(new StreamSource(inputStream), declaredType);
294 }
295
296 public void streamSource(Reader reader) throws IOException, JAXBException {
297 result = unmarshaller.unmarshal(new StreamSource(reader), declaredType);
298 }
299
300 public void source(String systemId) throws Exception {
301 result = unmarshaller.unmarshal(new StreamSource(systemId), declaredType);
302 }
303 }
304
305 private class Jaxb2ResultCallback implements TraxUtils.ResultCallback {
306
307 private final Marshaller marshaller;
308
309 private final Object jaxbElement;
310
311 private Jaxb2ResultCallback(Class<?> clazz, Object jaxbElement) throws JAXBException {
312 this.marshaller = createMarshaller(clazz);
313 this.jaxbElement = jaxbElement;
314 }
315
316 public void domResult(Node node) throws JAXBException {
317 marshaller.marshal(jaxbElement, node);
318 }
319
320 public void saxResult(ContentHandler contentHandler, LexicalHandler lexicalHandler) throws JAXBException {
321 marshaller.marshal(jaxbElement, contentHandler);
322 }
323
324 public void staxResult(XMLEventWriter eventWriter) throws JAXBException {
325 marshaller.marshal(jaxbElement, eventWriter);
326 }
327
328 public void staxResult(XMLStreamWriter streamWriter) throws JAXBException {
329 marshaller.marshal(jaxbElement, streamWriter);
330 }
331
332 public void streamResult(OutputStream outputStream) throws JAXBException {
333 marshaller.marshal(jaxbElement, outputStream);
334 }
335
336 public void streamResult(Writer writer) throws JAXBException {
337 marshaller.marshal(jaxbElement, writer);
338 }
339
340 public void result(String systemId) throws Exception {
341 marshaller.marshal(jaxbElement, new StreamResult(systemId));
342 }
343 }
344
345 private class JaxbStreamingPayload implements StreamingPayload {
346
347 private final Object jaxbElement;
348
349 private final Marshaller marshaller;
350
351 private final QName name;
352
353 private JaxbStreamingPayload(Class<?> clazz, Object jaxbElement) throws JAXBException {
354 JAXBContext jaxbContext = getJaxbContext(clazz);
355 this.marshaller = jaxbContext.createMarshaller();
356 this.marshaller.setProperty(Marshaller.JAXB_FRAGMENT, Boolean.TRUE);
357 this.jaxbElement = jaxbElement;
358 JAXBIntrospector introspector = jaxbContext.createJAXBIntrospector();
359 this.name = introspector.getElementName(jaxbElement);
360 }
361
362 public QName getName() {
363 return name;
364 }
365
366 public void writeTo(XMLStreamWriter streamWriter) throws XMLStreamException {
367 try {
368 marshaller.marshal(jaxbElement, streamWriter);
369 }
370 catch (JAXBException ex) {
371 throw new XMLStreamException("Could not marshal [" + jaxbElement + "]: " + ex.getMessage(), ex);
372 }
373 }
374 }
375
376
377 }