通过 Tomcat servlet 代理常规 HTTP 和 WebSocket

问题描述 投票:0回答:2

我正在实现一个 Web 应用程序,除其他外,它必须显示代理到后端服务的网页并与之交互。为此,我使用的是 HTTP-Proxy-Servlet,它在大多数情况下都运行良好。

但是,某些后端服务的网页使用了websockets,而上面的代理servlet不支持websockets.

我尝试重新构造发往后端的 WebSocket 握手请求,然后在两端的流之间复制数据来实现它,但这不起作用。浏览器报“Invalid frame header”,Tomcat 失败并显示

Error parsing HTTP request header
Invalid character found in method name. HTTP method names must be tokens
at org.apache.coyote.http11.Http11InputBuffer.parseRequestLine(Http11InputBuffer.java:414)

我的代码:

import java.io.IOException;
import java.net.*;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.*;

import javax.servlet.ServletException;
import javax.servlet.http.*;

import org.apache.http.HttpRequest;
import org.mitre.dsmiley.httpproxy.ProxyServlet;

/**
 * Proxy servlet that attempts to tunnel WebSocket handshakes by opening a raw
 * socket to the backend and copying bytes in both directions.
 *
 * <p>Requests carrying a {@code Sec-WebSocket-Key} header take the raw-socket
 * path in {@link #service}; every other request is delegated to
 * {@link ProxyServlet}.
 *
 * <p>NOTE(review): this variant relays data through the ordinary servlet
 * request/response streams and never calls {@code HttpServletRequest.upgrade};
 * presumably that is why the container keeps treating the connection as HTTP
 * — confirm against the servlet container's documentation.
 */
public class ProxyWithWebSocket extends ProxyServlet {

    private static final long serialVersionUID = -2566573965489129976L;

    /** Thread pool running the two byte-copy tasks; created in {@link #init()}. */
    protected ExecutorService exec;
    
    /** Initializes the parent servlet, then creates the cached thread pool. */
    @Override
    public void init() throws ServletException {
        super.init();
        exec = Executors.newCachedThreadPool();
    }
    
    /** Destroys the parent servlet, then requests an orderly shutdown of the pool. */
    @Override
    public void destroy() {
        super.destroy();
        exec.shutdown();
    }

    /**
     * Handles one request: WebSocket handshakes are replayed over a raw socket
     * to the backend, everything else goes to {@code super.service}.
     *
     * @param servletRequest  incoming client request
     * @param servletResponse response to the client
     * @throws ServletException propagated from the parent implementation
     * @throws IOException      on socket or stream failures
     */
    @Override
    protected void service(HttpServletRequest servletRequest, HttpServletResponse servletResponse)
            throws ServletException, IOException {
        // Presence of Sec-WebSocket-Key is used as the "this is a WebSocket handshake" marker.
        var wsKey = servletRequest.getHeader("Sec-WebSocket-Key");
        if (wsKey != null) {
            //initialize request attributes from caches if unset by a subclass by this point
            if (servletRequest.getAttribute(ATTR_TARGET_URI) == null) {
              servletRequest.setAttribute(ATTR_TARGET_URI, targetUri);
            }
            if (servletRequest.getAttribute(ATTR_TARGET_HOST) == null) {
              servletRequest.setAttribute(ATTR_TARGET_HOST, targetHost);
            }
            // Backend URL for this request, as computed by the parent class helper.
            String proxyRequestUri = rewriteUrlFromRequest(servletRequest);
            URL u = new URL(proxyRequestUri);

            var servletIn = servletRequest.getInputStream();
            var servletOut = servletResponse.getOutputStream();

            // The socket is closed when this try block exits, i.e. after both copy tasks ended.
            try (Socket sock = new Socket(u.getHost(), u.getPort())) {
                var sockIn = sock.getInputStream();
                var sockOut = sock.getOutputStream();
                
                // Re-create the handshake request by hand: request line plus every client header.
                StringBuilder req = new StringBuilder(512);
                req.append("GET " + u.getFile()).append(" HTTP/1.1");
                System.out.println("  > WS|" + req);
                req.append("\r\n");
                var en = servletRequest.getHeaderNames();
                while (en.hasMoreElements()) {
                    var n = en.nextElement();
                    // getHeader returns a single value per name.
                    String header = servletRequest.getHeader(n);
                    System.out.println("  > WS| " + n + ": " + header);
                    req.append(n + ": " + header + "\r\n");
                }
                // Blank line terminates the header section.
                req.append("\r\n");
                
                sockOut.write(req.toString().getBytes(StandardCharsets.UTF_8));
                sockOut.flush();
    
                // Read the backend's response head one byte at a time, stopping right after
                // the first CRLF CRLF so that no bytes following the head are consumed here.
                StringBuilder responseBytes = new StringBuilder(512);
                int b = 0;
                while (b != -1) {
                    b = sockIn.read();
                    if (b != -1) {
                        responseBytes.append((char)b);
                        var len = responseBytes.length();
                        if (len >= 4
                                && responseBytes.charAt(len - 4) == '\r'
                                && responseBytes.charAt(len - 3) == '\n'
                                && responseBytes.charAt(len - 2) == '\r'
                                && responseBytes.charAt(len - 1) == '\n'
                        ) {
                            break;
                        }
                    }
                }
                
                // rows[0] is the status line, the remaining rows are header lines.
                String[] rows = responseBytes.toString().split("\r\n"); 
                
                String response = rows[0];
                System.out.println("  < WS|" + response);
                
                // The status code sits between the first and second space of the status line.
                int idx1 = response.indexOf(' ');
                int idx2 = response.indexOf(' ', idx1 + 1);
                
                // Copy each backend header onto the servlet response.
                for (int i = 1; i < rows.length; i++) {
                    String line = rows[i];
                    int idx3 = line.indexOf(":");
                    var k = line.substring(0, idx3);
                    // Skips the colon and the one character after it.
                    var headerField = line.substring(idx3 + 2);
                    System.out.println("  < WS| " + k + ": " + headerField);
                    servletResponse.setHeader(k, headerField);
                }
                
                servletResponse.setStatus(Integer.parseInt(response.substring(idx1 + 1, idx2)));
                servletResponse.flushBuffer();
                
                System.out.println("  < WS| Flush");
    
                // Task 1: client -> backend, byte by byte, counting the bytes copied.
                var f1 = exec.submit(() -> {
                    var c = 0;
                    
                    var bs = 0;
                    while ((bs = servletIn.read()) != -1) {
                        sockOut.write(bs);
                        c++;
                    }
                    System.out.println("  > WS| Done: " + c);
                    return null;
                });
                // Task 2: backend -> client, flushing after every byte.
                var f2 = exec.submit(() -> {
                    var c = 0;
                    
                    var bs = 0;
                    while ((bs = sockIn.read()) != -1) {
                        servletOut.write(bs);
                        servletOut.flush();
                        c++;
                    }
                    System.out.println("  < WS| Done: " + c);
                    return null;
                });
    
                // Wait for the client->backend task; if it failed, cancel the other direction.
                try {
                    f1.get();
                } catch (Exception ex) {
                    f2.cancel(true);
                    return;
                }
                // Then wait for the backend->client task; its failure is ignored.
                try {
                    f2.get();
                } catch (Exception ex) {
                    
                }
            }
        } else {
            super.service(servletRequest, servletResponse);
        }
    }
}

典型的交换看起来像这样(通过那些 println):

  > WS|GET /cellhub?id=NhWO8SnGyDb_Vrk23rmhVQ HTTP/1.1
  > WS| host: localhost:8080
  > WS| connection: Upgrade
  > WS| pragma: no-cache
  > WS| cache-control: no-cache
  > WS| user-agent: Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/94.0.4606.71 Safari/537.36
  > WS| upgrade: websocket
  > WS| origin: http://localhost:8080
  > WS| sec-websocket-version: 13
  > WS| accept-encoding: gzip, deflate, br
  > WS| accept-language: hu,hu-HU;q=0.9,en-US;q=0.8,en;q=0.7
  > WS| cookie: JSESSIONID=57E4B30452BC3EB2657139DAF70E65AD; JSESSIONID=AD5E7BB5FE17B4072F3ABEE32B9479AC
  > WS| sec-websocket-key: nrZWEb6Co4DKggUNwPeV8g==
  > WS| sec-websocket-extensions: permessage-deflate; client_max_window_bits
  < WS|HTTP/1.1 101 Switching Protocols
  < WS| Connection:  Upgrade
  < WS| Date:  Thu, 07 Oct 2021 13:18:41 GMT
  < WS| Server:  Kestrel
  < WS| Upgrade:  websocket
  < WS| Sec-WebSocket-Accept:  /9uN8ZF67WepGJQ3+DPBLMCBotc=
  < WS| Flush
  > WS| Done: 0
  < WS| Done: 42

我怎样才能完成这项工作?

编辑

我找到了

HttpServletRequest.upgrade
方法,它似乎是用于切换协议的。我更新了复制响应头之后的那部分代码:

                // Numeric status code taken from the backend's status line (between idx1 and idx2).
                int respCode = Integer.parseInt(response.substring(idx1 + 1, idx2));
                if (respCode != 101) {
                    // Backend did not switch protocols: relay its status and mark the socket for closing.
                    servletResponse.setStatus(respCode);
                    servletResponse.flushBuffer();
                    System.out.println("  < WS| Flush");
                    closeSocket = true;
                } else {
                    // Backend answered 101: upgrade the servlet connection and hand the
                    // backend socket and its streams to the upgrade handler.
                    var uh = servletRequest.upgrade(WsUpgradeHandler.class);
                    uh.preInit(exec, sockIn, sockOut, sock);
                }

其中 WsUpgradeHandler 的定义如下:

    /**
     * Upgrade handler that relays bytes between the upgraded servlet connection
     * and a backend socket supplied through {@link #preInit}.
     *
     * <p>In this variant {@link #init} submits both copy directions to the
     * executor and returns immediately.
     *
     * <p>NOTE(review): {@link #destroy} dereferences {@code f1}/{@code f2}
     * without a null check; they are only assigned inside {@code init}.
     */
    public static class WsUpgradeHandler implements HttpUpgradeHandler {

        // Collaborators injected via preInit().
        ExecutorService exec;
        InputStream sockIn;
        OutputStream sockOut;
        Socket sock;
        // Handles of the two copy tasks, kept so destroy() can cancel them.
        Future<?> f1;
        Future<?> f2;
        
        /** No-arg constructor; the instance is obtained from {@code servletRequest.upgrade(Class)}. */
        public WsUpgradeHandler() { }
        
        /**
         * Stores the executor and the backend socket with its streams.
         *
         * @param exec    executor that runs the copy tasks
         * @param sockIn  input stream of the backend socket
         * @param sockOut output stream of the backend socket
         * @param sock    the backend socket itself, closed in {@link #destroy}
         */
        public void preInit(ExecutorService exec, InputStream sockIn, OutputStream sockOut, Socket sock) {
            this.exec = exec;
            this.sockIn = sockIn;
            this.sockOut = sockOut;
            this.sock = sock;
        }
        
        /**
         * Starts both relay directions as background tasks and returns.
         *
         * @param wc the upgraded connection to the client
         */
        @Override
        public void init(WebConnection wc) {
            System.out.println("  * WS| Upgrade begin");
            try {
                var servletIn = wc.getInputStream();
                var servletOut = wc.getOutputStream();
                // Client -> backend: copy byte by byte, close the backend output when done.
                f1 = exec.submit(() -> {
                    System.out.println("  > WS| Client -> Backend");
                    var c = 0;
                    
                    var bs = 0;
                    try {
                        while ((bs = servletIn.read()) != -1) {
                            sockOut.write(bs);
                            c++;
                        }
                    } catch (Exception exc) {
                        exc.printStackTrace();
                    } finally {
                        sockOut.close();
                    }
                    System.out.println("  > WS| Done: " + c);
                    return null;
                });
                // Backend -> client: copy byte by byte, flushing after each byte,
                // and close the client output when done.
                f2 = exec.submit(() -> {
                    System.out.println("  > WS| Backend -> Client");
                    var c = 0;
                    
                    try {
                        var bs = 0;
                        while ((bs = sockIn.read()) != -1) {
                            servletOut.write(bs);
                            servletOut.flush();
                            c++;
                        }
                    } catch (Exception exc) {
                        exc.printStackTrace();
                    } finally {
                        servletOut.close();
                    }
                    System.out.println("  < WS| Done: " + c);
                    return null;
                });

            } catch (IOException ex) {
                ex.printStackTrace();
            }
        }

        /** Cancels both copy tasks and closes the backend socket, ignoring close failures. */
        @Override
        public void destroy() {
            System.out.println("  * WS| Upgrade closing");
            f1.cancel(true);
            f2.cancel(true);
            try {
                sock.close();
            } catch (IOException ex) {
                
            }
            System.out.println("  * WS| Upgrade close");
        }
        
    }

这确实可以传递消息,但一旦来自浏览器的 websocket 连接结束,Tomcat 的 CPU 利用率就会变得非常高(而此时不应有任何其他活动)。看来 Tomcat 的部分或全部 NIO 线程(threads)在空转,而我使用的线程池里已经没有线程了。

tomcat servlets websocket proxy
2个回答
1
投票

我想我设法解决了这个问题。

上面的代码几乎是正确的,只有一个例外:显然,

init()
方法在使用阻塞模式时不应返回,如这个 Tomcat 测试示例所示。

第二个问题,即高 CPU 使用率被追踪到 tomcat 中的一个轮询线程,该线程之前有 bugs。我在 Tomcat 9.0.12 中运行我的代码,一旦升级到 Tomcat 9.0.54,CPU 使用问题就消失了。

因此,完整的工作代码如下所示:(我知道,我知道,逐字节复制和手动拼装 HTTP 请求并不是最优的,但这就是 Loom 的用途,对吧;)

import java.io.*;
import java.net.*;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.*;

import javax.servlet.ServletException;
import javax.servlet.http.*;

import org.apache.http.HttpRequest;
import org.mitre.dsmiley.httpproxy.ProxyServlet;

/**
 * {@link ProxyServlet} extension that additionally tunnels WebSocket connections.
 *
 * <p>Plain HTTP requests are handled by the parent class. A request carrying a
 * {@code Sec-WebSocket-Key} header is replayed over a raw (plain-text) socket to the
 * backend; if the backend answers {@code 101}, the servlet connection is upgraded and
 * bytes are relayed in both directions by {@link WsUpgradeHandler}.
 *
 * <p>The backend connection is a plain {@link Socket}, so TLS backends are not supported.
 */
public class ProxyWithWebSocket extends ProxyServlet {

    private static final long serialVersionUID = -2566573965489129976L;

    /** Size of the buffers used when relaying tunnel traffic. */
    private static final int BUFFER_SIZE = 8192;

    /** Thread pool running the backend-to-client relay tasks; created in {@link #init()}. */
    protected ExecutorService exec;

    /** Initializes the parent servlet, then creates the cached thread pool. */
    @Override
    public void init() throws ServletException {
        super.init();
        exec = Executors.newCachedThreadPool();
    }

    /** Destroys the parent servlet, then requests an orderly shutdown of the pool. */
    @Override
    public void destroy() {
        super.destroy();
        exec.shutdown();
    }

    /**
     * Copies the request headers as the parent does and, when the request has a
     * {@code UserID} attribute, forwards it to the backend as a {@code UserID} header.
     */
    @Override
    protected void copyRequestHeaders(HttpServletRequest servletRequest, HttpRequest proxyRequest) {
        super.copyRequestHeaders(servletRequest, proxyRequest);

        String userId = (String) servletRequest.getAttribute("UserID");
        if (userId != null) {
            proxyRequest.addHeader("UserID", userId);
        }
    }

    /**
     * Handles one request. WebSocket handshakes are replayed to the backend over a raw
     * socket and, on {@code 101}, the connection is upgraded; all other requests are
     * delegated to {@code super.service}.
     *
     * <p>When the backend refuses the upgrade, only its status code and headers are
     * relayed to the client; a response body, if any, is not forwarded.
     *
     * @param servletRequest  incoming client request
     * @param servletResponse response to the client
     * @throws ServletException propagated from the parent implementation
     * @throws IOException      on socket failures or a malformed backend response head
     */
    @Override
    protected void service(HttpServletRequest servletRequest, HttpServletResponse servletResponse)
            throws ServletException, IOException {
        if (servletRequest.getHeader("Sec-WebSocket-Key") == null) {
            super.service(servletRequest, servletResponse);
            return;
        }

        //initialize request attributes from caches if unset by a subclass by this point
        if (servletRequest.getAttribute(ATTR_TARGET_URI) == null) {
            servletRequest.setAttribute(ATTR_TARGET_URI, targetUri);
        }
        if (servletRequest.getAttribute(ATTR_TARGET_HOST) == null) {
            servletRequest.setAttribute(ATTR_TARGET_HOST, targetHost);
        }
        URL u = new URL(rewriteUrlFromRequest(servletRequest));
        // URL.getPort() is -1 when the target URL names no port; fall back to the scheme default.
        int port = u.getPort() != -1 ? u.getPort() : u.getDefaultPort();

        Socket sock = new Socket(u.getHost(), port);
        // Ownership of the socket moves to the upgrade handler only after a successful
        // upgrade; on every other path (including exceptions) it is closed here.
        boolean handedOff = false;
        try {
            var sockIn = sock.getInputStream();
            var sockOut = sock.getOutputStream();

            sockOut.write(buildHandshakeRequest(servletRequest, u).getBytes(StandardCharsets.UTF_8));
            sockOut.flush();

            String[] rows = readResponseHead(sockIn).split("\r\n");
            String response = rows[0];
            System.out.println("  < WS|" + response);
            int respCode = parseStatusCode(response);

            for (int i = 1; i < rows.length; i++) {
                String line = rows[i];
                int colon = line.indexOf(':');
                if (colon <= 0) {
                    // Not a "name: value" line; nothing sensible to relay.
                    continue;
                }
                String k = line.substring(0, colon).trim();
                // Whitespace around the value is optional, so trim instead of assuming ": ".
                String headerField = line.substring(colon + 1).trim();
                System.out.println("  < WS| " + k + ": " + headerField);
                servletResponse.setHeader(k, headerField);
            }

            if (respCode != 101) {
                servletResponse.setStatus(respCode);
                servletResponse.flushBuffer();
                System.out.println("  < WS| Flush");
            } else {
                var uh = servletRequest.upgrade(WsUpgradeHandler.class);
                uh.preInit(exec, sockIn, sockOut, sock);
                handedOff = true;
            }
        } finally {
            if (!handedOff) {
                sock.close();
            }
        }
    }

    /**
     * Builds the raw HTTP/1.1 handshake request sent to the backend: a {@code GET} for the
     * target path, a {@code Host} header naming the backend, and every other client header
     * with all of its values.
     */
    private static String buildHandshakeRequest(HttpServletRequest servletRequest, URL u) {
        String file = u.getFile();
        // The request target must start with '/', even for "http://host" or "http://host?x=1".
        String target = file.startsWith("/") ? file : "/" + file;

        StringBuilder req = new StringBuilder(512);
        req.append("GET ").append(target).append(" HTTP/1.1");
        System.out.println("  > WS|" + req);
        req.append("\r\n");

        // Address the backend by its own authority; some servers answer 400 when Host still
        // names the proxy. Keep the port when the target URL states one explicitly.
        String backendHost = u.getPort() == -1 ? u.getHost() : u.getHost() + ":" + u.getPort();
        appendHeader(req, "host", backendHost);

        var names = servletRequest.getHeaderNames();
        while (names.hasMoreElements()) {
            var n = names.nextElement();
            if ("host".equalsIgnoreCase(n)) {
                continue; // already written above
            }
            // getHeaders() yields every value of a repeated header, not just the first one.
            var values = servletRequest.getHeaders(n);
            while (values.hasMoreElements()) {
                appendHeader(req, n, values.nextElement());
            }
        }
        req.append("\r\n");
        return req.toString();
    }

    /** Appends one {@code name: value} header line to the request and echoes it to stdout. */
    private static void appendHeader(StringBuilder req, String name, String value) {
        System.out.println("  > WS| " + name + ": " + value);
        req.append(name).append(": ").append(value).append("\r\n");
    }

    /**
     * Reads the backend's response head up to and including the terminating blank line.
     *
     * <p>Reading is deliberately byte-by-byte so that nothing after the head (the first
     * WebSocket frames) is consumed from the stream.
     *
     * @throws IOException if the backend closes the connection before the head is complete
     */
    private static String readResponseHead(InputStream sockIn) throws IOException {
        StringBuilder head = new StringBuilder(512);
        int b;
        while ((b = sockIn.read()) != -1) {
            head.append((char) b);
            int len = head.length();
            if (len >= 4
                    && head.charAt(len - 4) == '\r'
                    && head.charAt(len - 3) == '\n'
                    && head.charAt(len - 2) == '\r'
                    && head.charAt(len - 1) == '\n') {
                return head.toString();
            }
        }
        throw new IOException("Backend closed the connection before completing the response head");
    }

    /**
     * Extracts the numeric status code from a status line such as
     * {@code HTTP/1.1 101 Switching Protocols}; the reason phrase may be absent.
     *
     * @throws IOException if the line does not contain a numeric status code
     */
    private static int parseStatusCode(String statusLine) throws IOException {
        int start = statusLine.indexOf(' ');
        if (start < 0) {
            throw new IOException("Malformed status line from backend: " + statusLine);
        }
        int end = statusLine.indexOf(' ', start + 1);
        if (end < 0) {
            end = statusLine.length();
        }
        try {
            return Integer.parseInt(statusLine.substring(start + 1, end));
        } catch (NumberFormatException ex) {
            throw new IOException("Malformed status line from backend: " + statusLine, ex);
        }
    }

    /**
     * Upgrade handler that relays bytes between the upgraded client connection and the
     * backend socket supplied through {@link #preInit}.
     *
     * <p>{@link #init} does not return until the tunnel has ended: the backend-to-client
     * direction runs on the executor while the client-to-backend direction runs on the
     * calling thread.
     */
    public static class WsUpgradeHandler implements HttpUpgradeHandler {

        ExecutorService exec;
        InputStream sockIn;
        OutputStream sockOut;
        Socket sock;
        /** Handle of the backend-to-client relay task, kept so it can be cancelled. */
        Future<?> f2;

        /** No-arg constructor; the instance is obtained from {@code servletRequest.upgrade(Class)}. */
        public WsUpgradeHandler() { }

        /**
         * Stores the executor and the backend socket with its streams.
         *
         * @param exec    executor that runs the backend-to-client relay
         * @param sockIn  input stream of the backend socket
         * @param sockOut output stream of the backend socket
         * @param sock    the backend socket itself, closed in {@link #destroy}
         */
        public void preInit(ExecutorService exec, InputStream sockIn, OutputStream sockOut, Socket sock) {
            this.exec = exec;
            this.sockIn = sockIn;
            this.sockOut = sockOut;
            this.sock = sock;
        }

        /**
         * Runs the tunnel until either side closes it.
         *
         * @param wc the upgraded connection to the client
         */
        @Override
        public void init(WebConnection wc) {
            System.out.println("  * WS| Upgrade begin");
            try {
                var servletIn = wc.getInputStream();
                var servletOut = wc.getOutputStream();
                f2 = exec.submit(() -> {
                    System.out.println("  > WS| Backend -> Client");
                    pump(sockIn, servletOut, "  < WS| Done: ");
                    return null;
                });

                System.out.println("  > WS| Client -> Backend");
                pump(servletIn, sockOut, "  > WS| Done: ");

                f2.get();
            } catch (InterruptedException ex) {
                // Preserve the interrupt for the container thread that called us.
                Thread.currentThread().interrupt();
            } catch (Exception ex) {
                ex.printStackTrace();
            } finally {
                if (f2 != null) {
                    f2.cancel(true);
                }
            }
        }

        /**
         * Copies {@code in} to {@code out} until end of stream, flushing after every chunk
         * so that frames are not held back, then closes {@code out} and prints the number
         * of bytes copied prefixed by {@code doneLabel}.
         *
         * @throws IOException if closing {@code out} fails
         */
        private static void pump(InputStream in, OutputStream out, String doneLabel) throws IOException {
            long copied = 0;
            byte[] buf = new byte[BUFFER_SIZE];
            try {
                int n;
                while ((n = in.read(buf)) != -1) {
                    out.write(buf, 0, n);
                    out.flush();
                    copied += n;
                }
            } catch (SocketException | EOFException expected) {
                // this is fine: one side going away is the normal end of a tunnel
            } catch (Exception exc) {
                exc.printStackTrace();
            } finally {
                out.close();
            }
            System.out.println(doneLabel + copied);
        }

        /** Cancels the relay task, if started, and closes the backend socket, if set. */
        @Override
        public void destroy() {
            System.out.println("  * WS| Upgrade closing");
            if (f2 != null) {
                f2.cancel(true);
            }
            if (sock != null) {
                try {
                    sock.close();
                } catch (IOException ignored) {
                    // Best effort: the tunnel is being torn down anyway.
                }
            }
            System.out.println("  * WS| Upgrade close");
        }

    }
}

0
投票

谢谢你的代码!我用它来桥接 rasa 聊天机器人和 socket.io

遗憾的是我不能评论你的解决方案。

请求头(header)有一个小错误。

如果主机标头不正确,有些服务器会返回 400。

所以我加了一个小小的 if 判断:

                // Send the backend's own authority as Host (some servers answer 400 otherwise);
                // keep the port when the target URL states one explicitly.
                String backendHost = u.getPort() == -1 ? u.getHost() : u.getHost() + ":" + u.getPort();
                String header = "host".equalsIgnoreCase(n) ? backendHost : servletRequest.getHeader(n);
© www.soinside.com 2019 - 2024. All rights reserved.