Commit eb01d39c authored by Emmanuel Bourg's avatar Emmanuel Bourg

New upstream version 4.3.16

parent e9dafb5c
version=4.3.15.RELEASE
version=4.3.16.RELEASE
......@@ -1003,4 +1003,19 @@ public class CodeFlow implements Opcodes {
void generateCode(MethodVisitor mv, CodeFlow codeflow);
}
public static String toBoxedDescriptor(String primitiveDescriptor) {
switch (primitiveDescriptor.charAt(0)) {
case 'I': return "Ljava/lang/Integer";
case 'J': return "Ljava/lang/Long";
case 'F': return "Ljava/lang/Float";
case 'D': return "Ljava/lang/Double";
case 'B': return "Ljava/lang/Byte";
case 'C': return "Ljava/lang/Character";
case 'S': return "Ljava/lang/Short";
case 'Z': return "Ljava/lang/Boolean";
default:
throw new IllegalArgumentException("Unexpected non primitive descriptor "+primitiveDescriptor);
}
}
}
......@@ -24,6 +24,7 @@ import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.springframework.asm.Label;
import org.springframework.asm.MethodVisitor;
import org.springframework.core.convert.TypeDescriptor;
import org.springframework.expression.AccessException;
......@@ -53,6 +54,8 @@ public class MethodReference extends SpelNodeImpl {
private final boolean nullSafe;
private String originalPrimitiveExitTypeDescriptor = null;
private volatile CachedMethodExecutor cachedExecutor;
......@@ -233,7 +236,14 @@ public class MethodReference extends SpelNodeImpl {
CachedMethodExecutor executorToCheck = this.cachedExecutor;
if (executorToCheck != null && executorToCheck.get() instanceof ReflectiveMethodExecutor) {
Method method = ((ReflectiveMethodExecutor) executorToCheck.get()).getMethod();
this.exitTypeDescriptor = CodeFlow.toDescriptor(method.getReturnType());
String descriptor = CodeFlow.toDescriptor(method.getReturnType());
if (this.nullSafe && CodeFlow.isPrimitive(descriptor)) {
originalPrimitiveExitTypeDescriptor = descriptor;
this.exitTypeDescriptor = CodeFlow.toBoxedDescriptor(descriptor);
}
else {
this.exitTypeDescriptor = descriptor;
}
}
}
......@@ -293,17 +303,23 @@ public class MethodReference extends SpelNodeImpl {
boolean isStaticMethod = Modifier.isStatic(method.getModifiers());
String descriptor = cf.lastDescriptor();
if (descriptor == null) {
if (!isStaticMethod) {
// Nothing on the stack but something is needed
cf.loadTarget(mv);
}
Label skipIfNull = null;
if (descriptor == null && !isStaticMethod) {
// Nothing on the stack but something is needed
cf.loadTarget(mv);
}
else {
if (isStaticMethod) {
// Something on the stack when nothing is needed
mv.visitInsn(POP);
}
if ((descriptor != null || !isStaticMethod) && nullSafe) {
mv.visitInsn(DUP);
skipIfNull = new Label();
Label continueLabel = new Label();
mv.visitJumpInsn(IFNONNULL,continueLabel);
CodeFlow.insertCheckCast(mv, this.exitTypeDescriptor);
mv.visitJumpInsn(GOTO, skipIfNull);
mv.visitLabel(continueLabel);
}
if (descriptor != null && isStaticMethod) {
// Something on the stack when nothing is needed
mv.visitInsn(POP);
}
if (CodeFlow.isPrimitive(descriptor)) {
......@@ -323,6 +339,14 @@ public class MethodReference extends SpelNodeImpl {
mv.visitMethodInsn((isStaticMethod ? INVOKESTATIC : INVOKEVIRTUAL), classDesc, method.getName(),
CodeFlow.createSignatureDescriptor(method), method.getDeclaringClass().isInterface());
cf.pushDescriptor(this.exitTypeDescriptor);
if (originalPrimitiveExitTypeDescriptor != null) {
// The output of the accessor will be a primitive but from the block above it might be null,
// so to have a 'common stack' element at skipIfNull target we need to box the primitive
CodeFlow.insertBoxIfNecessary(mv, originalPrimitiveExitTypeDescriptor);
}
if (skipIfNull != null) {
mv.visitLabel(skipIfNull);
}
}
......
......@@ -21,6 +21,7 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.springframework.asm.Label;
import org.springframework.asm.MethodVisitor;
import org.springframework.core.convert.TypeDescriptor;
import org.springframework.expression.AccessException;
......@@ -47,6 +48,8 @@ public class PropertyOrFieldReference extends SpelNodeImpl {
private final boolean nullSafe;
private String originalPrimitiveExitTypeDescriptor = null;
private final String name;
private volatile PropertyAccessor cachedReadAccessor;
......@@ -83,7 +86,7 @@ public class PropertyOrFieldReference extends SpelNodeImpl {
PropertyAccessor accessorToUse = this.cachedReadAccessor;
if (accessorToUse instanceof CompilablePropertyAccessor) {
CompilablePropertyAccessor accessor = (CompilablePropertyAccessor) accessorToUse;
this.exitTypeDescriptor = CodeFlow.toDescriptor(accessor.getPropertyType());
setExitTypeDescriptor(CodeFlow.toDescriptor(accessor.getPropertyType()));
}
return tv;
}
......@@ -350,8 +353,40 @@ public class PropertyOrFieldReference extends SpelNodeImpl {
if (!(accessorToUse instanceof CompilablePropertyAccessor)) {
throw new IllegalStateException("Property accessor is not compilable: " + accessorToUse);
}
Label skipIfNull = null;
if (nullSafe) {
mv.visitInsn(DUP);
skipIfNull = new Label();
Label continueLabel = new Label();
mv.visitJumpInsn(IFNONNULL,continueLabel);
CodeFlow.insertCheckCast(mv, this.exitTypeDescriptor);
mv.visitJumpInsn(GOTO, skipIfNull);
mv.visitLabel(continueLabel);
}
((CompilablePropertyAccessor) accessorToUse).generateCode(this.name, mv, cf);
cf.pushDescriptor(this.exitTypeDescriptor);
if (originalPrimitiveExitTypeDescriptor != null) {
// The output of the accessor is a primitive but from the block above it might be null,
// so to have a common stack element type at skipIfNull target it is necessary
// to box the primitive
CodeFlow.insertBoxIfNecessary(mv, originalPrimitiveExitTypeDescriptor);
}
if (skipIfNull != null) {
mv.visitLabel(skipIfNull);
}
}
void setExitTypeDescriptor(String descriptor) {
// If this property or field access would return a primitive - and yet
// it is also marked null safe - then the exit type descriptor must be
// promoted to the box type to allow a null value to be passed on
if (this.nullSafe && CodeFlow.isPrimitive(descriptor)) {
this.originalPrimitiveExitTypeDescriptor = descriptor;
this.exitTypeDescriptor = CodeFlow.toBoxedDescriptor(descriptor);
}
else {
this.exitTypeDescriptor = descriptor;
}
}
......@@ -379,8 +414,7 @@ public class PropertyOrFieldReference extends SpelNodeImpl {
this.ref.getValueInternal(this.contextObject, this.evalContext, this.autoGrowNullReferences);
PropertyAccessor accessorToUse = this.ref.cachedReadAccessor;
if (accessorToUse instanceof CompilablePropertyAccessor) {
this.ref.exitTypeDescriptor =
CodeFlow.toDescriptor(((CompilablePropertyAccessor) accessorToUse).getPropertyType());
this.ref.setExitTypeDescriptor(CodeFlow.toDescriptor(((CompilablePropertyAccessor) accessorToUse).getPropertyType()));
}
return value;
}
......
......@@ -701,7 +701,167 @@ public class SpelCompilationCoverageTests extends AbstractExpressionTests {
assertCanCompile(expression);
assertEquals("def", expression.getValue());
}
@Test
public void nullsafeFieldPropertyDereferencing_SPR16489() throws Exception {
FooObjectHolder foh = new FooObjectHolder();
StandardEvaluationContext context = new StandardEvaluationContext();
context.setRootObject(foh);
// First non compiled:
SpelExpression expression = (SpelExpression) parser.parseExpression("foo?.object");
assertEquals("hello",expression.getValue(context));
foh.foo = null;
assertNull(expression.getValue(context));
// Now revert state of foh and try compiling it:
foh.foo = new FooObject();
assertEquals("hello",expression.getValue(context));
assertCanCompile(expression);
assertEquals("hello",expression.getValue(context));
foh.foo = null;
assertNull(expression.getValue(context));
// Static references
expression = (SpelExpression)parser.parseExpression("#var?.propertya");
context.setVariable("var", StaticsHelper.class);
assertEquals("sh",expression.getValue(context).toString());
context.setVariable("var", null);
assertNull(expression.getValue(context));
assertCanCompile(expression);
context.setVariable("var", StaticsHelper.class);
assertEquals("sh",expression.getValue(context).toString());
context.setVariable("var", null);
assertNull(expression.getValue(context));
// Single size primitive (boolean)
expression = (SpelExpression)parser.parseExpression("#var?.a");
context.setVariable("var", new TestClass4());
assertFalse((Boolean)expression.getValue(context));
context.setVariable("var", null);
assertNull(expression.getValue(context));
assertCanCompile(expression);
context.setVariable("var", new TestClass4());
assertFalse((Boolean)expression.getValue(context));
context.setVariable("var", null);
assertNull(expression.getValue(context));
// Double slot primitives
expression = (SpelExpression)parser.parseExpression("#var?.four");
context.setVariable("var", new Three());
assertEquals("0.04",expression.getValue(context).toString());
context.setVariable("var", null);
assertNull(expression.getValue(context));
assertCanCompile(expression);
context.setVariable("var", new Three());
assertEquals("0.04",expression.getValue(context).toString());
context.setVariable("var", null);
assertNull(expression.getValue(context));
}
@Test
public void nullsafeMethodChaining_SPR16489() throws Exception {
FooObjectHolder foh = new FooObjectHolder();
StandardEvaluationContext context = new StandardEvaluationContext();
context.setRootObject(foh);
// First non compiled:
SpelExpression expression = (SpelExpression) parser.parseExpression("getFoo()?.getObject()");
assertEquals("hello",expression.getValue(context));
foh.foo = null;
assertNull(expression.getValue(context));
assertCanCompile(expression);
foh.foo = new FooObject();
assertEquals("hello",expression.getValue(context));
foh.foo = null;
assertNull(expression.getValue(context));
// Static method references
expression = (SpelExpression)parser.parseExpression("#var?.methoda()");
context.setVariable("var", StaticsHelper.class);
assertEquals("sh",expression.getValue(context).toString());
context.setVariable("var", null);
assertNull(expression.getValue(context));
assertCanCompile(expression);
context.setVariable("var", StaticsHelper.class);
assertEquals("sh",expression.getValue(context).toString());
context.setVariable("var", null);
assertNull(expression.getValue(context));
// Nullsafe guard on expression element evaluating to primitive/null
expression = (SpelExpression)parser.parseExpression("#var?.intValue()");
context.setVariable("var", 4);
assertEquals("4",expression.getValue(context).toString());
context.setVariable("var", null);
assertNull(expression.getValue(context));
assertCanCompile(expression);
context.setVariable("var", 4);
assertEquals("4",expression.getValue(context).toString());
context.setVariable("var", null);
assertNull(expression.getValue(context));
// Nullsafe guard on expression element evaluating to primitive/null
expression = (SpelExpression)parser.parseExpression("#var?.booleanValue()");
context.setVariable("var", false);
assertEquals("false",expression.getValue(context).toString());
context.setVariable("var", null);
assertNull(expression.getValue(context));
assertCanCompile(expression);
context.setVariable("var", false);
assertEquals("false",expression.getValue(context).toString());
context.setVariable("var", null);
assertNull(expression.getValue(context));
// Nullsafe guard on expression element evaluating to primitive/null
expression = (SpelExpression)parser.parseExpression("#var?.booleanValue()");
context.setVariable("var", true);
assertEquals("true",expression.getValue(context).toString());
context.setVariable("var", null);
assertNull(expression.getValue(context));
assertCanCompile(expression);
context.setVariable("var", true);
assertEquals("true",expression.getValue(context).toString());
context.setVariable("var", null);
assertNull(expression.getValue(context));
// Nullsafe guard on expression element evaluating to primitive/null
expression = (SpelExpression)parser.parseExpression("#var?.longValue()");
context.setVariable("var", 5L);
assertEquals("5",expression.getValue(context).toString());
context.setVariable("var", null);
assertNull(expression.getValue(context));
assertCanCompile(expression);
context.setVariable("var", 5L);
assertEquals("5",expression.getValue(context).toString());
context.setVariable("var", null);
assertNull(expression.getValue(context));
// Nullsafe guard on expression element evaluating to primitive/null
expression = (SpelExpression)parser.parseExpression("#var?.floatValue()");
context.setVariable("var", 3f);
assertEquals("3.0",expression.getValue(context).toString());
context.setVariable("var", null);
assertNull(expression.getValue(context));
assertCanCompile(expression);
context.setVariable("var", 3f);
assertEquals("3.0",expression.getValue(context).toString());
context.setVariable("var", null);
assertNull(expression.getValue(context));
// Nullsafe guard on expression element evaluating to primitive/null
expression = (SpelExpression)parser.parseExpression("#var?.shortValue()");
context.setVariable("var", (short)8);
assertEquals("8",expression.getValue(context).toString());
context.setVariable("var", null);
assertNull(expression.getValue(context));
assertCanCompile(expression);
context.setVariable("var", (short)8);
assertEquals("8",expression.getValue(context).toString());
context.setVariable("var", null);
assertNull(expression.getValue(context));
}
@Test
public void elvis() throws Exception {
Expression expression = parser.parseExpression("'a'?:'b'");
......@@ -3063,19 +3223,47 @@ public class SpelCompilationCoverageTests extends AbstractExpressionTests {
assertEquals(1.0f, expression.getValue());
}
@Test
public void compilationOfBasicNullSafeMethodReference() {
SpelExpressionParser parser = new SpelExpressionParser(
new SpelParserConfiguration(SpelCompilerMode.OFF, getClass().getClassLoader()));
SpelExpression expression = parser.parseRaw("#it?.equals(3)");
StandardEvaluationContext context = new StandardEvaluationContext(new Object[] {1});
context.setVariable("it", 3);
expression.setEvaluationContext(context);
assertTrue(expression.getValue(Boolean.class));
context.setVariable("it", null);
assertNull(expression.getValue(Boolean.class));
assertCanCompile(expression);
context.setVariable("it", 3);
assertTrue(expression.getValue(Boolean.class));
context.setVariable("it", null);
assertNull(expression.getValue(Boolean.class));
}
@Test
public void failsWhenSettingContextForExpression_SPR12326() {
SpelExpressionParser parser = new SpelExpressionParser(
new SpelParserConfiguration(SpelCompilerMode.IMMEDIATE, getClass().getClassLoader()));
new SpelParserConfiguration(SpelCompilerMode.OFF, getClass().getClassLoader()));
Person3 person = new Person3("foo", 1);
SpelExpression expression = parser.parseRaw("#it?.age?.equals([0])");
StandardEvaluationContext context = new StandardEvaluationContext(new Object[] {1});
context.setVariable("it", person);
expression.setEvaluationContext(context);
assertTrue(expression.getValue(Boolean.class));
// This will trigger compilation (second usage)
assertTrue(expression.getValue(Boolean.class));
context.setVariable("it", null);
assertNull(expression.getValue(Boolean.class));
assertCanCompile(expression);
context.setVariable("it", person);
assertTrue(expression.getValue(Boolean.class));
context.setVariable("it", null);
assertNull(expression.getValue(Boolean.class));
}
......@@ -5078,6 +5266,14 @@ public class SpelCompilationCoverageTests extends AbstractExpressionTests {
}
}
public static class FooObjectHolder {
private FooObject foo = new FooObject();
public FooObject getFoo() {
return foo;
}
}
public static class FooObject {
......
......@@ -26,7 +26,6 @@ import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.CopyOnWriteArraySet;
import org.springframework.expression.AccessException;
import org.springframework.expression.EvaluationContext;
import org.springframework.expression.Expression;
import org.springframework.expression.ExpressionParser;
......@@ -34,7 +33,7 @@ import org.springframework.expression.PropertyAccessor;
import org.springframework.expression.TypedValue;
import org.springframework.expression.spel.SpelEvaluationException;
import org.springframework.expression.spel.standard.SpelExpressionParser;
import org.springframework.expression.spel.support.StandardEvaluationContext;
import org.springframework.expression.spel.support.SimpleEvaluationContext;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageHeaders;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
......@@ -64,6 +63,10 @@ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry {
/** Default maximum number of entries for the destination cache: 1024 */
public static final int DEFAULT_CACHE_LIMIT = 1024;
/** Static evaluation context to reuse */
private static EvaluationContext messageEvalContext =
SimpleEvaluationContext.forPropertyAccessors(new SimpMessageHeaderPropertyAccessor()).build();
private PathMatcher pathMatcher = new AntPathMatcher();
......@@ -191,7 +194,6 @@ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry {
if (!this.selectorHeaderInUse) {
return allMatches;
}
EvaluationContext context = null;
MultiValueMap<String, String> result = new LinkedMultiValueMap<String, String>(allMatches.size());
for (String sessionId : allMatches.keySet()) {
for (String subId : allMatches.get(sessionId)) {
......@@ -208,12 +210,8 @@ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry {
result.add(sessionId, subId);
continue;
}
if (context == null) {
context = new StandardEvaluationContext(message);
context.getPropertyAccessors().add(new SimpMessageHeaderPropertyAccessor());
}
try {
if (expression.getValue(context, boolean.class)) {
if (Boolean.TRUE.equals(expression.getValue(messageEvalContext, message, Boolean.class))) {
result.add(sessionId, subId);
}
}
......@@ -525,7 +523,7 @@ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry {
@Override
public Class<?>[] getSpecificTargetClasses() {
return new Class<?>[] {MessageHeaders.class};
return new Class<?>[] {Message.class, MessageHeaders.class};
}
@Override
......@@ -534,19 +532,29 @@ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry {
}
@Override
public TypedValue read(EvaluationContext context, Object target, String name) throws AccessException {
MessageHeaders headers = (MessageHeaders) target;
SimpMessageHeaderAccessor accessor =
MessageHeaderAccessor.getAccessor(headers, SimpMessageHeaderAccessor.class);
public TypedValue read(EvaluationContext context, Object target, String name) {
Object value;
if ("destination".equalsIgnoreCase(name)) {
value = accessor.getDestination();
if (target instanceof Message) {
value = name.equals("headers") ? ((Message) target).getHeaders() : null;
}
else {
value = accessor.getFirstNativeHeader(name);
if (value == null) {
value = headers.get(name);
else if (target instanceof MessageHeaders) {
MessageHeaders headers = (MessageHeaders) target;
SimpMessageHeaderAccessor accessor =
MessageHeaderAccessor.getAccessor(headers, SimpMessageHeaderAccessor.class);
Assert.state(accessor != null, "No SimpMessageHeaderAccessor");
if ("destination".equalsIgnoreCase(name)) {
value = accessor.getDestination();
}
else {
value = accessor.getFirstNativeHeader(name);
if (value == null) {
value = headers.get(name);
}
}
}
else {
// Should never happen...
throw new IllegalStateException("Expected Message or MessageHeaders.");
}
return new TypedValue(value);
}
......
/*
* Copyright 2002-2014 the original author or authors.
* Copyright 2002-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -25,6 +25,8 @@ import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.springframework.mock.web.MockAsyncContext;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.util.Assert;
import org.springframework.web.context.WebApplicationContext;
import org.springframework.web.context.request.NativeWebRequest;
import org.springframework.web.context.request.async.CallableProcessingInterceptorAdapter;
......@@ -35,6 +37,7 @@ import org.springframework.web.context.request.async.WebAsyncUtils;
import org.springframework.web.servlet.DispatcherServlet;
import org.springframework.web.servlet.HandlerExecutionChain;
import org.springframework.web.servlet.ModelAndView;
import org.springframework.web.util.WebUtils;
/**
* A sub-class of {@code DispatcherServlet} that saves the result in an
......@@ -64,8 +67,24 @@ final class TestDispatcherServlet extends DispatcherServlet {
throws ServletException, IOException {
registerAsyncResultInterceptors(request);
super.service(request, response);
initAsyncDispatchLatch(request);
if (request.getAsyncContext() != null) {
MockHttpServletRequest mockRequest = WebUtils.getNativeRequest(request, MockHttpServletRequest.class);
Assert.notNull(mockRequest, "Expected MockHttpServletRequest");
MockAsyncContext mockAsyncContext = ((MockAsyncContext) mockRequest.getAsyncContext());
Assert.notNull(mockAsyncContext, "MockAsyncContext not found. Did request wrapper not delegate startAsync?");
final CountDownLatch dispatchLatch = new CountDownLatch(1);
mockAsyncContext.addDispatchHandler(new Runnable() {
@Override
public void run() {
dispatchLatch.countDown();
}
});
getMvcResult(request).setAsyncDispatchLatch(dispatchLatch);
}
}
private void registerAsyncResultInterceptors(final HttpServletRequest request) {
......@@ -84,19 +103,6 @@ final class TestDispatcherServlet extends DispatcherServlet {
});
}
private void initAsyncDispatchLatch(HttpServletRequest request) {
if (request.getAsyncContext() != null) {
final CountDownLatch dispatchLatch = new CountDownLatch(1);
((MockAsyncContext) request.getAsyncContext()).addDispatchHandler(new Runnable() {
@Override
public void run() {
dispatchLatch.countDown();
}
});
getMvcResult(request).setAsyncDispatchLatch(dispatchLatch);
}
}
protected DefaultMvcResult getMvcResult(ServletRequest request) {
return (DefaultMvcResult) request.getAttribute(MockMvc.MVC_RESULT_ATTRIBUTE);
}
......
/*
* Copyright 2002-2016 the original author or authors.
* Copyright 2002-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -18,9 +18,15 @@ package org.springframework.test.web.servlet.samples.standalone;
import java.io.IOException;
import java.security.Principal;
import java.util.concurrent.CompletableFuture;
import javax.servlet.AsyncContext;
import javax.servlet.AsyncListener;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.ServletContext;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import javax.servlet.http.HttpServletResponse;
......@@ -29,17 +35,23 @@ import javax.validation.Valid;
import org.junit.Test;
import org.springframework.http.MediaType;
import org.springframework.stereotype.Controller;
import org.springframework.test.web.Person;
import org.springframework.test.web.servlet.MockMvc;
import org.springframework.test.web.servlet.MvcResult;
import org.springframework.validation.Errors;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestMethod;
import org.springframework.web.bind.annotation.ResponseBody;
import org.springframework.web.filter.OncePerRequestFilter;
import org.springframework.web.servlet.ModelAndView;
import org.springframework.web.servlet.mvc.support.RedirectAttributes;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.*;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.*;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.request;
import static org.springframework.test.web.servlet.setup.MockMvcBuilders.*;
/**
......@@ -107,6 +119,22 @@ public class FilterTests {
.andExpect(model().attribute("principal", WrappingRequestResponseFilter.PRINCIPAL_NAME));
}
@Test // SPR-16695
public void filterWrapsRequestResponseAndPerformsAsyncDispatch() throws Exception {
MockMvc mockMvc = standaloneSetup(new PersonController())
.addFilters(new WrappingRequestResponseFilter())
.build();
MvcResult mvcResult = mockMvc.perform(get("/persons/1").accept(MediaType.APPLICATION_JSON))
.andExpect(request().asyncStarted())
.andExpect(request().asyncResult(new Person("Lukas")))
.andReturn();
mockMvc.perform(asyncDispatch(mvcResult))
.andExpect(status().isOk())
.andExpect(content().string("{\"name\":\"Lukas\",\"someDouble\":0.0,\"someBoolean\":false}"));
}
@Controller
private static class PersonController {
......@@ -129,6 +157,12 @@ public class FilterTests {
public String forward() {
return "forward:/persons";
}
@GetMapping("persons/{id}")
@ResponseBody