1   /*
2    * Copyright 2005-2010 the original author or authors.
3    *
4    * Licensed under the Apache License, Version 2.0 (the "License");
5    * you may not use this file except in compliance with the License.
6    * You may obtain a copy of the License at
7    *
8    *      http://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS,
12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13   * See the License for the specific language governing permissions and
14   * limitations under the License.
15   */
16  
17  package org.springframework.ws.server.endpoint;
18  
19  import java.io.IOException;
20  import java.io.StringReader;
21  import java.io.StringWriter;
22  import javax.xml.transform.Result;
23  import javax.xml.transform.Source;
24  import javax.xml.transform.Transformer;
25  import javax.xml.transform.TransformerException;
26  import javax.xml.transform.TransformerFactory;
27  import javax.xml.transform.stream.StreamResult;
28  import javax.xml.transform.stream.StreamSource;
29  
30  import org.springframework.oxm.Marshaller;
31  import org.springframework.oxm.Unmarshaller;
32  import org.springframework.oxm.XmlMappingException;
33  import org.springframework.oxm.mime.MimeContainer;
34  import org.springframework.oxm.mime.MimeMarshaller;
35  import org.springframework.oxm.mime.MimeUnmarshaller;
36  import org.springframework.ws.MockWebServiceMessage;
37  import org.springframework.ws.WebServiceMessageFactory;
38  import org.springframework.ws.context.DefaultMessageContext;
39  import org.springframework.ws.context.MessageContext;
40  import org.springframework.ws.mime.MimeMessage;
41  import org.springframework.xml.transform.StringResult;
42  import org.springframework.xml.transform.StringSource;
43  
44  import org.junit.Assert;
45  import org.junit.Before;
46  import org.junit.Test;
47  
48  import static org.custommonkey.xmlunit.XMLAssert.assertXMLEqual;
49  import static org.easymock.EasyMock.*;
50  import static org.junit.Assert.fail;
51  
52  public class MarshallingPayloadEndpointTest {
53  
54      private Transformer transformer;
55  
56      private MessageContext context;
57  
58      private WebServiceMessageFactory factoryMock;
59  
60      @Before
61      public void setUp() throws Exception {
62          MockWebServiceMessage request = new MockWebServiceMessage("<request/>");
63          transformer = TransformerFactory.newInstance().newTransformer();
64          factoryMock = createMock(WebServiceMessageFactory.class);
65  
66          context = new DefaultMessageContext(request, factoryMock);
67      }
68  
69      @Test
70      public void testInvoke() throws Exception {
71          Unmarshaller unmarshaller = new SimpleMarshaller() {
72              @Override
73              public Object unmarshal(Source source) throws XmlMappingException {
74                  try {
75                      StringWriter writer = new StringWriter();
76                      transformer.transform(source, new StreamResult(writer));
77                      assertXMLEqual("Invalid source", "<request/>", writer.toString());
78                      return 42L;
79                  }
80                  catch (Exception e) {
81                      Assert.fail(e.getMessage());
82                      return null;
83                  }
84              }
85          };
86          Marshaller marshaller = new SimpleMarshaller() {
87              @Override
88              public void marshal(Object graph, Result result) throws XmlMappingException {
89                  Assert.assertEquals("Invalid graph", "result", graph);
90                  try {
91                      transformer.transform(new StreamSource(new StringReader("<result/>")), result);
92                  }
93                  catch (TransformerException e) {
94                      Assert.fail(e.getMessage());
95                  }
96              }
97          };
98          AbstractMarshallingPayloadEndpoint endpoint = new AbstractMarshallingPayloadEndpoint() {
99              @Override
100             protected Object invokeInternal(Object requestObject) throws Exception {
101                 Assert.assertEquals("Invalid request object", 42L, requestObject);
102                 return "result";
103             }
104         };
105         endpoint.setMarshaller(marshaller);
106         endpoint.setUnmarshaller(unmarshaller);
107         endpoint.afterPropertiesSet();
108 
109         expect(factoryMock.createWebServiceMessage()).andReturn(new MockWebServiceMessage());
110 
111         replay(factoryMock);
112 
113         endpoint.invoke(context);
114         MockWebServiceMessage response = (MockWebServiceMessage) context.getResponse();
115         Assert.assertNotNull("Invalid result", response);
116         assertXMLEqual("Invalid response", "<result/>", response.getPayloadAsString());
117 
118         verify(factoryMock);
119     }
120 
121     @Test
122     public void testInvokeNullResponse() throws Exception {
123         Unmarshaller unmarshaller = new SimpleMarshaller() {
124             @Override
125             public Object unmarshal(Source source) throws XmlMappingException {
126                 try {
127                     StringWriter writer = new StringWriter();
128                     transformer.transform(source, new StreamResult(writer));
129                     assertXMLEqual("Invalid source", "<request/>", writer.toString());
130                     return (long) 42;
131                 }
132                 catch (Exception e) {
133                     Assert.fail(e.getMessage());
134                     return null;
135                 }
136             }
137         };
138         Marshaller marshaller = new SimpleMarshaller() {
139             @Override
140             public void marshal(Object graph, Result result) throws XmlMappingException {
141                 Assert.fail("marshal not expected");
142             }
143         };
144         AbstractMarshallingPayloadEndpoint endpoint = new AbstractMarshallingPayloadEndpoint() {
145             @Override
146             protected Object invokeInternal(Object requestObject) throws Exception {
147                 Assert.assertEquals("Invalid request object", (long) 42, requestObject);
148                 return null;
149             }
150         };
151         endpoint.setMarshaller(marshaller);
152         endpoint.setUnmarshaller(unmarshaller);
153         endpoint.afterPropertiesSet();
154         replay(factoryMock);
155         endpoint.invoke(context);
156         Assert.assertFalse("Response created", context.hasResponse());
157         verify(factoryMock);
158     }
159 
160     @Test
161     public void testInvokeNoRequest() throws Exception {
162         MockWebServiceMessage request = new MockWebServiceMessage((StringBuilder) null);
163         context = new DefaultMessageContext(request, factoryMock);
164         AbstractMarshallingPayloadEndpoint endpoint = new AbstractMarshallingPayloadEndpoint() {
165 
166             @Override
167             protected Object invokeInternal(Object requestObject) throws Exception {
168                 Assert.assertNull("No request expected", requestObject);
169                 return null;
170             }
171         };
172         endpoint.setMarshaller(new SimpleMarshaller());
173         endpoint.setUnmarshaller(new SimpleMarshaller());
174         endpoint.afterPropertiesSet();
175         replay(factoryMock);
176         endpoint.invoke(context);
177         Assert.assertFalse("Response created", context.hasResponse());
178         verify(factoryMock);
179     }
180 
181     @Test
182     public void testInvokeMimeMarshaller() throws Exception {
183         MimeUnmarshaller unmarshaller = createMock(MimeUnmarshaller.class);
184         MimeMarshaller marshaller = createMock(MimeMarshaller.class);
185         MimeMessage request = createMock("request", MimeMessage.class);
186         MimeMessage response = createMock("response", MimeMessage.class);
187         Source requestSource = new StringSource("<request/>");
188         expect(request.getPayloadSource()).andReturn(requestSource);
189         expect(factoryMock.createWebServiceMessage()).andReturn(response);
190         expect(unmarshaller.unmarshal(eq(requestSource), isA(MimeContainer.class))).andReturn(42L);
191         Result responseResult = new StringResult();
192         expect(response.getPayloadResult()).andReturn(responseResult);
193         marshaller.marshal(eq("result"), eq(responseResult), isA(MimeContainer.class));
194 
195         replay(factoryMock, unmarshaller, marshaller, request, response);
196 
197         AbstractMarshallingPayloadEndpoint endpoint = new AbstractMarshallingPayloadEndpoint() {
198             @Override
199             protected Object invokeInternal(Object requestObject) throws Exception {
200                 Assert.assertEquals("Invalid request object", 42L, requestObject);
201                 return "result";
202             }
203         };
204         endpoint.setMarshaller(marshaller);
205         endpoint.setUnmarshaller(unmarshaller);
206         endpoint.afterPropertiesSet();
207 
208         context = new DefaultMessageContext(request, factoryMock);
209         endpoint.invoke(context);
210         Assert.assertNotNull("Invalid result", response);
211 
212         verify(factoryMock, unmarshaller, marshaller, request, response);
213     }
214 
215     private static class SimpleMarshaller implements Marshaller, Unmarshaller {
216 
217         public void marshal(Object graph, Result result) throws XmlMappingException, IOException {
218             fail("Not expected");
219         }
220 
221         public Object unmarshal(Source source) throws XmlMappingException, IOException {
222             fail("Not expected");
223             return null;
224         }
225 
226         public boolean supports(Class<?> clazz) {
227             return false;
228         }
229     }
230 
231 }