解决HttpServletRequest的inputStream只能读取一次的问题

问题

在实际web项目中,有时需要再过滤器或拦截器中对请求参数做各种处理,比如验证签名、解密、参数安全过滤(判断参数是否有sql恶意注入的关键字)和参数修改等处理。如果参数是form表单格式(即key-value)提交,则通过request.getParameterMap()方法即可获取参数名和参数值,如果是json格式提交(使用controller中@requestBody注解接收参数),则需要request.getInputStream()request.getReader()方法从流中读取请求体(body)。但是存在一个问题,httpServletRequest中的inputStream只能被读取一次,即request.getInputStream()request.getReader()方法执行了一次之后,再到controller中通过@requestBody读取参数的时候会出现如下错误:

I/O error while reading input message; nested exception is java.io.IOException: Stream closed

原因

为什么HttpServletRequest的流只能读取一次呢?
调用httpServletRequest.getInputStream()可以看到获取的流类型为ServletInputStream,继承InputStream
InputStream的read方法内部有一个postion,标志当前流读取到的位置,每读取一次,位置就会移动一次,如果读到最后,read()会返回-1,表示流已经读取完了。如果想要重新读取则需要重写reset()方法,当然能否reset是有条件的,它取决于markSupported()是否返回true。
InputStream源码中默认不实现reset(),并且markSupported()默认返回false:

1
2
3
4
5
6
7
8
public synchronized void reset() throws IOException {
// 调用重新读取则抛出异常
throw new IOException("mark/reset not supported");
}
public boolean markSupported() {
// 不支持重新读取
return false;
}

而查看ServletInputStream源码可以发现,该类没有重写mark()reset()以及markSupported(),因此Request IO流无法重复读取

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
public abstract class ServletInputStream extends InputStream {
protected ServletInputStream() {
}

public int readLine(byte[] b, int off, int len) throws IOException {
if (len <= 0) {
return 0;
} else {
int count = 0;

int c;
while((c = this.read()) != -1) {
b[off++] = (byte)c;
++count;
if (c == 10 || count == len) {
break;
}
}

return count > 0 ? count : -1;
}
}

public abstract boolean isFinished();

public abstract boolean isReady();

public abstract void setReadListener(ReadListener var1);
}

解决方案

为了解决这个问题,Servlet中提供了一个请求的包装类:HttpServletRequestWrapper,它基于装饰者模式实现了HttpServletRequest接口,封装了HttpServletRequest的所有行为,所以我们只需要继承该类,然后使用过滤器(filter)将原始的HttpServletRequest对象替换为HttpServletRequestWrapper对象。

并且继承的时候重写getInputStreamgetReader方法,将流的读取改为从类变量body中获取,这样就可以做到不限次数地获取请求体。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.*;

public class BodyReadHttpServletRequestWrapper extends HttpServletRequestWrapper {

/**
* 请求体
*/
private final byte[] body;


public BodyReadHttpServletRequestWrapper(HttpServletRequest request) throws IOException {
super(request);
body = InputStreamToByte(request.getInputStream());
}


/**
* 读取输入流中的字节数组
*/
private byte[] InputStreamToByte(InputStream is) throws IOException {
try (ByteArrayOutputStream bytestream = new ByteArrayOutputStream();) {
byte[] buffer = new byte[1024];
int ch;
while ((ch = is.read(buffer)) != -1) {
bytestream.write(buffer, 0, ch);
}
return bytestream.toByteArray();
}

}


@Override
public BufferedReader getReader() throws IOException {
return new BufferedReader(new InputStreamReader(getInputStream()));
}


/**
* 在使用@RequestBody注解的时候,其实框架是调用了getInputStream()方法,所以我们要重写这个方法
*/
@Override
public ServletInputStream getInputStream() throws IOException {
try (final ByteArrayInputStream bais = new ByteArrayInputStream(body)) {
return new ServletInputStream() {

@Override
public boolean isFinished() {
return false;
}

@Override
public boolean isReady() {
return false;
}

@Override
public void setReadListener(ReadListener readListener) {

}

@Override
public int read() {
return bais.read();
}

};
}
}


}

这样在filter中可以去读取输入流了,也不会影响controller中参数的获取。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;

import javax.servlet.*;
import javax.servlet.http.HttpServletRequest;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

@Component
public class MyFilter implements Filter {


@Override
public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException {
HttpServletRequest req = (HttpServletRequest) request;
BodyReadHttpServletRequestWrapper requestWrapper = new BodyReadHttpServletRequestWrapper(req);
StringBuilder sb = new StringBuilder();
// 获得所有请求参数名
Map<String, String[]> parameterMap = req.getParameterMap();
for (Map.Entry<String, String[]> stringEntry : parameterMap.entrySet()) {
String[] value = stringEntry.getValue();
for (int i = 0; i < value.length; i++) {
//得到参数对应值
sb.append(value[i]).append(",");
}
}

//通过流获取请求体
try (ServletInputStream inputStream = requestWrapper.getInputStream();) {
byte[] bytes = new byte[1024];
int n;
while ((n = inputStream.read(bytes)) != -1) {
sb.append(new String(bytes, 0, n, StandardCharsets.UTF_8));
}
}

//1.做其他处理....
//2.传递给链中的下一个过滤器
chain.doFilter(requestWrapper, response);
}


@Override
public void init(FilterConfig filterConfig) {
}

@Override
public void destroy() {
}

}
------ 本文完 ------