mirror of
				https://github.com/zulip/zulip.git
				synced 2025-11-03 21:43:21 +00:00 
			
		
		
		
	run-dev: Rewrite development proxy with aiohttp.
This allows request cancellation to be propagated to the server.

Signed-off-by: Anders Kaseorg <anders@zulip.com>
This commit is contained in:
		
				
					committed by
					
						
						Tim Abbott
					
				
			
			
				
	
			
			
			
						parent
						
							c1988a14a7
						
					
				
				
					commit
					55b26da82b
				
			@@ -8,6 +8,9 @@
 | 
			
		||||
# moto s3 mock
 | 
			
		||||
moto[s3]
 | 
			
		||||
 | 
			
		||||
# For tools/run-dev
 | 
			
		||||
aiohttp
 | 
			
		||||
 | 
			
		||||
# Needed for documentation links test
 | 
			
		||||
Scrapy
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -88,6 +88,7 @@ aiohttp==3.9.1 \
 | 
			
		||||
    --hash=sha256:f800164276eec54e0af5c99feb9494c295118fc10a11b997bbb1348ba1a52065 \
 | 
			
		||||
    --hash=sha256:ffcd828e37dc219a72c9012ec44ad2e7e3066bec6ff3aaa19e7d435dbf4032ca
 | 
			
		||||
    # via
 | 
			
		||||
    #   -r requirements/dev.in
 | 
			
		||||
    #   aiohttp-retry
 | 
			
		||||
    #   twilio
 | 
			
		||||
aiohttp-retry==2.8.3 \
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										223
									
								
								tools/run-dev
									
									
									
									
									
								
							
							
						
						
									
										223
									
								
								tools/run-dev
									
									
									
									
									
								
							@@ -2,13 +2,13 @@
 | 
			
		||||
import argparse
 | 
			
		||||
import asyncio
 | 
			
		||||
import errno
 | 
			
		||||
import logging
 | 
			
		||||
import os
 | 
			
		||||
import pwd
 | 
			
		||||
import signal
 | 
			
		||||
import subprocess
 | 
			
		||||
import sys
 | 
			
		||||
from typing import List, Sequence
 | 
			
		||||
from urllib.parse import urlunsplit
 | 
			
		||||
from typing import List
 | 
			
		||||
 | 
			
		||||
TOOLS_DIR = os.path.dirname(os.path.abspath(__file__))
 | 
			
		||||
sys.path.insert(0, os.path.dirname(TOOLS_DIR))
 | 
			
		||||
@@ -18,9 +18,9 @@ from tools.lib import sanity_check
 | 
			
		||||
 | 
			
		||||
sanity_check.check_venv(__file__)
 | 
			
		||||
 | 
			
		||||
from tornado import httpclient, httputil, web
 | 
			
		||||
from tornado.platform.asyncio import AsyncIOMainLoop
 | 
			
		||||
from typing_extensions import override
 | 
			
		||||
import aiohttp
 | 
			
		||||
from aiohttp import hdrs, web
 | 
			
		||||
from returns.curry import partial
 | 
			
		||||
 | 
			
		||||
from tools.lib.test_script import add_provision_check_override_param, assert_provisioning_status_ok
 | 
			
		||||
 | 
			
		||||
@@ -54,11 +54,6 @@ parser.add_argument(
 | 
			
		||||
    help="Do not clear memcached on startup",
 | 
			
		||||
)
 | 
			
		||||
parser.add_argument("--streamlined", action="store_true", help="Avoid process_queue, etc.")
 | 
			
		||||
parser.add_argument(
 | 
			
		||||
    "--enable-tornado-logging",
 | 
			
		||||
    action="store_true",
 | 
			
		||||
    help="Enable access logs from tornado proxy server.",
 | 
			
		||||
)
 | 
			
		||||
parser.add_argument(
 | 
			
		||||
    "--behind-https-proxy",
 | 
			
		||||
    action="store_true",
 | 
			
		||||
@@ -204,135 +199,81 @@ def start_webpack_watcher() -> "subprocess.Popen[bytes]":
 | 
			
		||||
    return subprocess.Popen(webpack_cmd)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def transform_url(protocol: str, path: str, query: str, target_port: int, target_host: str) -> str:
 | 
			
		||||
    # generate url with target host
 | 
			
		||||
    host = ":".join((target_host, str(target_port)))
 | 
			
		||||
    # Here we are going to rewrite the path a bit so that it is in parity with
 | 
			
		||||
    # what we will have for production
 | 
			
		||||
    newpath = urlunsplit((protocol, host, path, query, ""))
 | 
			
		||||
    return newpath
 | 
			
		||||
session: aiohttp.ClientSession
 | 
			
		||||
 | 
			
		||||
# https://datatracker.ietf.org/doc/html/rfc2616#section-13.5.1
 | 
			
		||||
HOP_BY_HOP_HEADERS = {
 | 
			
		||||
    hdrs.CONNECTION,
 | 
			
		||||
    hdrs.KEEP_ALIVE,
 | 
			
		||||
    hdrs.PROXY_AUTHENTICATE,
 | 
			
		||||
    hdrs.PROXY_AUTHORIZATION,
 | 
			
		||||
    hdrs.TE,
 | 
			
		||||
    hdrs.TRAILER,
 | 
			
		||||
    hdrs.TRANSFER_ENCODING,
 | 
			
		||||
    hdrs.UPGRADE,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
# Headers that aiohttp would otherwise generate by default
 | 
			
		||||
SKIP_AUTO_HEADERS = {
 | 
			
		||||
    hdrs.ACCEPT,
 | 
			
		||||
    hdrs.ACCEPT_ENCODING,
 | 
			
		||||
    hdrs.CONTENT_TYPE,
 | 
			
		||||
    hdrs.USER_AGENT,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
client: httpclient.AsyncHTTPClient
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class BaseHandler(web.RequestHandler):
 | 
			
		||||
    # target server ip
 | 
			
		||||
    target_host: str = "127.0.0.1"
 | 
			
		||||
    # target server port
 | 
			
		||||
    target_port: int
 | 
			
		||||
 | 
			
		||||
    def _add_request_headers(
 | 
			
		||||
        self,
 | 
			
		||||
        exclude_lower_headers_list: Sequence[str] = [],
 | 
			
		||||
    ) -> httputil.HTTPHeaders:
 | 
			
		||||
        headers = httputil.HTTPHeaders()
 | 
			
		||||
        for header, v in self.request.headers.get_all():
 | 
			
		||||
            if header.lower() not in exclude_lower_headers_list:
 | 
			
		||||
                headers.add(header, v)
 | 
			
		||||
        return headers
 | 
			
		||||
 | 
			
		||||
    @override
 | 
			
		||||
    def get(self) -> None:
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    @override
 | 
			
		||||
    def head(self) -> None:
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    @override
 | 
			
		||||
    def post(self) -> None:
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    @override
 | 
			
		||||
    def put(self) -> None:
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    @override
 | 
			
		||||
    def patch(self) -> None:
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    @override
 | 
			
		||||
    def options(self) -> None:
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    @override
 | 
			
		||||
    def delete(self) -> None:
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    @override
 | 
			
		||||
    async def prepare(self) -> None:
 | 
			
		||||
        assert self.request.method is not None
 | 
			
		||||
        assert self.request.remote_ip is not None
 | 
			
		||||
        if "X-REAL-IP" not in self.request.headers:
 | 
			
		||||
            self.request.headers["X-REAL-IP"] = self.request.remote_ip
 | 
			
		||||
        if "X-FORWARDED_PORT" not in self.request.headers:
 | 
			
		||||
            self.request.headers["X-FORWARDED-PORT"] = str(proxy_port)
 | 
			
		||||
        url = transform_url(
 | 
			
		||||
            self.request.protocol,
 | 
			
		||||
            self.request.path,
 | 
			
		||||
            self.request.query,
 | 
			
		||||
            self.target_port,
 | 
			
		||||
            self.target_host,
 | 
			
		||||
        )
 | 
			
		||||
async def forward(upstream_port: int, request: web.Request) -> web.StreamResponse:
 | 
			
		||||
    try:
 | 
			
		||||
            request = httpclient.HTTPRequest(
 | 
			
		||||
                url=url,
 | 
			
		||||
                method=self.request.method,
 | 
			
		||||
                headers=self._add_request_headers(["upgrade-insecure-requests"]),
 | 
			
		||||
                follow_redirects=False,
 | 
			
		||||
                body=self.request.body,
 | 
			
		||||
                allow_nonstandard_methods=True,
 | 
			
		||||
                # use large timeouts to handle polling requests
 | 
			
		||||
                connect_timeout=240.0,
 | 
			
		||||
                request_timeout=240.0,
 | 
			
		||||
                # https://github.com/tornadoweb/tornado/issues/2743
 | 
			
		||||
                decompress_response=False,
 | 
			
		||||
            )
 | 
			
		||||
            response = await client.fetch(request, raise_error=False)
 | 
			
		||||
 | 
			
		||||
            self.set_status(response.code, response.reason)
 | 
			
		||||
            self._headers = httputil.HTTPHeaders()  # clear tornado default header
 | 
			
		||||
 | 
			
		||||
            for header, v in response.headers.get_all():
 | 
			
		||||
                # some headers appear multiple times, e.g. 'Set-Cookie'
 | 
			
		||||
                if header.lower() != "transfer-encoding":
 | 
			
		||||
                    self.add_header(header, v)
 | 
			
		||||
            if response.body:
 | 
			
		||||
                self.write(response.body)
 | 
			
		||||
        except (ConnectionError, httpclient.HTTPError) as e:
 | 
			
		||||
            self.set_status(500)
 | 
			
		||||
            self.write("Internal server error:\n" + str(e))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class WebPackHandler(BaseHandler):
 | 
			
		||||
    target_port = webpack_port
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class DjangoHandler(BaseHandler):
 | 
			
		||||
    target_port = django_port
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TornadoHandler(BaseHandler):
 | 
			
		||||
    target_port = tornado_port
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Application(web.Application):
 | 
			
		||||
    def __init__(self, enable_logging: bool = False) -> None:
 | 
			
		||||
        super().__init__(
 | 
			
		||||
            [
 | 
			
		||||
                (r"/json/events.*", TornadoHandler),
 | 
			
		||||
                (r"/api/v1/events.*", TornadoHandler),
 | 
			
		||||
                (r"/webpack.*", WebPackHandler),
 | 
			
		||||
                (r"/.*", DjangoHandler),
 | 
			
		||||
        upstream_response = await session.request(
 | 
			
		||||
            request.method,
 | 
			
		||||
            request.url.with_host("127.0.0.1").with_port(upstream_port),
 | 
			
		||||
            headers=[
 | 
			
		||||
                (key, value)
 | 
			
		||||
                for key, value in request.headers.items()
 | 
			
		||||
                if key not in HOP_BY_HOP_HEADERS
 | 
			
		||||
            ],
 | 
			
		||||
            enable_logging=enable_logging,
 | 
			
		||||
            data=request.content.iter_any() if request.body_exists else None,
 | 
			
		||||
            allow_redirects=False,
 | 
			
		||||
            auto_decompress=False,
 | 
			
		||||
            compress=False,
 | 
			
		||||
            skip_auto_headers=SKIP_AUTO_HEADERS,
 | 
			
		||||
        )
 | 
			
		||||
    except aiohttp.ClientError as error:
 | 
			
		||||
        logging.error(
 | 
			
		||||
            "Failed to forward %s %s to port %d: %s",
 | 
			
		||||
            request.method,
 | 
			
		||||
            request.url.path,
 | 
			
		||||
            upstream_port,
 | 
			
		||||
            error,
 | 
			
		||||
        )
 | 
			
		||||
        raise web.HTTPBadGateway from error
 | 
			
		||||
 | 
			
		||||
    @override
 | 
			
		||||
    def log_request(self, handler: web.RequestHandler) -> None:
 | 
			
		||||
        if self.settings["enable_logging"]:
 | 
			
		||||
            super().log_request(handler)
 | 
			
		||||
    response = web.StreamResponse(status=upstream_response.status, reason=upstream_response.reason)
 | 
			
		||||
    response.headers.extend(
 | 
			
		||||
        (key, value)
 | 
			
		||||
        for key, value in upstream_response.headers.items()
 | 
			
		||||
        if key not in HOP_BY_HOP_HEADERS
 | 
			
		||||
    )
 | 
			
		||||
    assert request.remote is not None
 | 
			
		||||
    response.headers["X-Real-IP"] = request.remote
 | 
			
		||||
    response.headers["X-Forwarded-Port"] = str(proxy_port)
 | 
			
		||||
    await response.prepare(request)
 | 
			
		||||
    async for data in upstream_response.content.iter_any():
 | 
			
		||||
        await response.write(data)
 | 
			
		||||
    await response.write_eof()
 | 
			
		||||
    return response
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
app = web.Application()
 | 
			
		||||
app.add_routes(
 | 
			
		||||
    [
 | 
			
		||||
        web.route(
 | 
			
		||||
            hdrs.METH_ANY, r"/{path:json/events|api/v1/events}", partial(forward, tornado_port)
 | 
			
		||||
        ),
 | 
			
		||||
        web.route(hdrs.METH_ANY, r"/{path:webpack/.*}", partial(forward, webpack_port)),
 | 
			
		||||
        web.route(hdrs.METH_ANY, r"/{path:.*}", partial(forward, django_port)),
 | 
			
		||||
    ]
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def print_listeners() -> None:
 | 
			
		||||
@@ -365,13 +306,12 @@ def print_listeners() -> None:
 | 
			
		||||
    print()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
runner: web.AppRunner
 | 
			
		||||
children: List["subprocess.Popen[bytes]"] = []
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def serve() -> None:
 | 
			
		||||
    global client
 | 
			
		||||
 | 
			
		||||
    AsyncIOMainLoop().install()
 | 
			
		||||
    global runner, session
 | 
			
		||||
 | 
			
		||||
    if options.test:
 | 
			
		||||
        do_one_time_webpack_compile()
 | 
			
		||||
@@ -380,10 +320,12 @@ async def serve() -> None:
 | 
			
		||||
 | 
			
		||||
    children.extend(subprocess.Popen(cmd) for cmd in server_processes())
 | 
			
		||||
 | 
			
		||||
    client = httpclient.AsyncHTTPClient()
 | 
			
		||||
    app = Application(enable_logging=options.enable_tornado_logging)
 | 
			
		||||
    session = aiohttp.ClientSession()
 | 
			
		||||
    runner = web.AppRunner(app, auto_decompress=False, handler_cancellation=True)
 | 
			
		||||
    await runner.setup()
 | 
			
		||||
    site = web.TCPSite(runner, host=options.interface, port=proxy_port)
 | 
			
		||||
    try:
 | 
			
		||||
        app.listen(proxy_port, address=options.interface)
 | 
			
		||||
        await site.start()
 | 
			
		||||
    except OSError as e:
 | 
			
		||||
        if e.errno == errno.EADDRINUSE:
 | 
			
		||||
            print("\n\nERROR: You probably have another server running!!!\n\n")
 | 
			
		||||
@@ -400,6 +342,9 @@ try:
 | 
			
		||||
        loop.add_signal_handler(s, loop.stop)
 | 
			
		||||
    loop.run_forever()
 | 
			
		||||
finally:
 | 
			
		||||
    loop.run_until_complete(runner.cleanup())
 | 
			
		||||
    loop.run_until_complete(session.close())
 | 
			
		||||
 | 
			
		||||
    for child in children:
 | 
			
		||||
        child.terminate()
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -214,11 +214,6 @@ class AsyncDjangoHandler(tornado.web.RequestHandler):
 | 
			
		||||
    def on_connection_close(self) -> None:
 | 
			
		||||
        # Register a Tornado handler that runs when client-side
 | 
			
		||||
        # connections are closed to notify the events system.
 | 
			
		||||
        #
 | 
			
		||||
        # Note that in the development environment, the development
 | 
			
		||||
        # proxy does not correctly close connections to Tornado when
 | 
			
		||||
        # its clients (e.g. `curl`) close their connections.  This
 | 
			
		||||
        # code path is thus _unreachable except in production_.
 | 
			
		||||
 | 
			
		||||
        # If the client goes away, garbage collect the handler (with
 | 
			
		||||
        # attached request information).
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user