mirror of
				https://github.com/zulip/zulip.git
				synced 2025-11-04 05:53:43 +00:00 
			
		
		
		
	An implicit coercion from an untyped dict to the TypedDict was hiding
a type error: CapturedQuery.sql was really str, not bytes.  We should
always prefer dataclass over TypedDict to prevent such errors.
Signed-off-by: Anders Kaseorg <anders@zulip.com>
(cherry picked from commit 92db6eba78)
		
	
		
			
				
	
	
		
			765 lines
		
	
	
		
			27 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			765 lines
		
	
	
		
			27 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import collections
 | 
						|
import os
 | 
						|
import re
 | 
						|
import sys
 | 
						|
import time
 | 
						|
from contextlib import contextmanager
 | 
						|
from dataclasses import dataclass
 | 
						|
from typing import (
 | 
						|
    IO,
 | 
						|
    TYPE_CHECKING,
 | 
						|
    Any,
 | 
						|
    Callable,
 | 
						|
    Dict,
 | 
						|
    Iterable,
 | 
						|
    Iterator,
 | 
						|
    List,
 | 
						|
    Mapping,
 | 
						|
    Optional,
 | 
						|
    Tuple,
 | 
						|
    TypeVar,
 | 
						|
    Union,
 | 
						|
    cast,
 | 
						|
)
 | 
						|
from unittest import mock
 | 
						|
from unittest.mock import patch
 | 
						|
 | 
						|
import boto3.session
 | 
						|
import fakeldap
 | 
						|
import ldap
 | 
						|
import orjson
 | 
						|
from django.conf import settings
 | 
						|
from django.contrib.auth.models import AnonymousUser
 | 
						|
from django.db.migrations.state import StateApps
 | 
						|
from django.http import HttpRequest, HttpResponse, HttpResponseRedirect
 | 
						|
from django.http.request import QueryDict
 | 
						|
from django.http.response import HttpResponseBase
 | 
						|
from django.test import override_settings
 | 
						|
from django.urls import URLResolver
 | 
						|
from moto.s3 import mock_s3
 | 
						|
from mypy_boto3_s3.service_resource import Bucket
 | 
						|
 | 
						|
from zerver.actions.realm_settings import do_set_realm_user_default_setting
 | 
						|
from zerver.actions.user_settings import do_change_user_setting
 | 
						|
from zerver.lib import cache
 | 
						|
from zerver.lib.avatar import avatar_url
 | 
						|
from zerver.lib.cache import get_cache_backend
 | 
						|
from zerver.lib.db import Params, ParamsT, Query, TimeTrackingCursor
 | 
						|
from zerver.lib.integrations import WEBHOOK_INTEGRATIONS
 | 
						|
from zerver.lib.rate_limiter import RateLimitedIPAddr, rules
 | 
						|
from zerver.lib.request import RequestNotes
 | 
						|
from zerver.lib.upload.s3 import S3UploadBackend
 | 
						|
from zerver.models import (
 | 
						|
    Client,
 | 
						|
    Message,
 | 
						|
    RealmUserDefault,
 | 
						|
    Subscription,
 | 
						|
    UserMessage,
 | 
						|
    UserProfile,
 | 
						|
    get_client,
 | 
						|
    get_realm,
 | 
						|
    get_stream,
 | 
						|
)
 | 
						|
from zerver.tornado.handlers import AsyncDjangoHandler, allocate_handler_id
 | 
						|
from zilencer.models import RemoteZulipServer
 | 
						|
from zproject.backends import ExternalAuthDataDict, ExternalAuthResult
 | 
						|
 | 
						|
if TYPE_CHECKING:
 | 
						|
    from django.test.client import _MonkeyPatchedWSGIResponse as TestHttpResponse
 | 
						|
 | 
						|
    # Avoid an import cycle; we only need these for type annotations.
 | 
						|
    from zerver.lib.test_classes import MigrationsTestCase, ZulipTestCase
 | 
						|
 | 
						|
 | 
						|
class MockLDAP(fakeldap.MockLDAP):
    """fakeldap.MockLDAP extended with LDAP exception classes.

    Each nested class subclasses the exception of the same name from the
    ``ldap`` module, so ``except ldap.<NAME>`` handlers also catch them.
    """

    class LDAPError(ldap.LDAPError):
        """Subclass of ldap.LDAPError exposed on the mock."""

    class INVALID_CREDENTIALS(ldap.INVALID_CREDENTIALS):  # noqa: N801
        """Subclass of ldap.INVALID_CREDENTIALS exposed on the mock."""

    class NO_SUCH_OBJECT(ldap.NO_SUCH_OBJECT):  # noqa: N801
        """Subclass of ldap.NO_SUCH_OBJECT exposed on the mock."""

    class ALREADY_EXISTS(ldap.ALREADY_EXISTS):  # noqa: N801
        """Subclass of ldap.ALREADY_EXISTS exposed on the mock."""
 | 
						|
 | 
						|
 | 
						|
@contextmanager
def stub_event_queue_user_events(
    event_queue_return: Any, user_events_return: Any
) -> Iterator[None]:
    """Stub out event-queue registration and event fetching.

    While the ``with`` block runs, ``zerver.lib.events.request_event_queue``
    returns *event_queue_return* and ``zerver.lib.events.get_user_events``
    returns *user_events_return*.
    """
    queue_stub = mock.patch(
        "zerver.lib.events.request_event_queue", return_value=event_queue_return
    )
    events_stub = mock.patch(
        "zerver.lib.events.get_user_events", return_value=user_events_return
    )
    with queue_stub, events_stub:
        yield
 | 
						|
 | 
						|
 | 
						|
@contextmanager
def cache_tries_captured() -> Iterator[List[Tuple[str, Union[str, List[str]], Optional[str]]]]:
    """Record every cache lookup made inside the ``with`` block.

    ``cache.cache_get`` and ``cache.cache_get_many`` are wrapped so that each
    call first appends ``(kind, key_or_keys, cache_name)`` to the yielded list
    (kind is ``"get"`` or ``"getmany"``) and then delegates to the real
    implementation, so lookups still return real cached values.
    """
    lookups: List[Tuple[str, Union[str, List[str]], Optional[str]]] = []

    real_get = cache.cache_get
    real_get_many = cache.cache_get_many

    def recording_get(key: str, cache_name: Optional[str] = None) -> Optional[Dict[str, Any]]:
        lookups.append(("get", key, cache_name))
        return real_get(key, cache_name)

    def recording_get_many(
        keys: List[str], cache_name: Optional[str] = None
    ) -> Dict[str, Any]:  # nocoverage -- simulated code doesn't use this
        lookups.append(("getmany", keys, cache_name))
        return real_get_many(keys, cache_name)

    with mock.patch.multiple(cache, cache_get=recording_get, cache_get_many=recording_get_many):
        yield lookups
 | 
						|
 | 
						|
 | 
						|
@contextmanager
def simulated_empty_cache() -> Iterator[List[Tuple[str, Union[str, List[str]], Optional[str]]]]:
    """Make every cache lookup inside the ``with`` block miss.

    ``cache.cache_get`` is replaced to always return None, and
    ``cache.cache_get_many`` to always return ``{}``. Each attempted lookup
    is still recorded in the yielded list as
    ``(kind, key_or_keys, cache_name)``.
    """
    lookups: List[Tuple[str, Union[str, List[str]], Optional[str]]] = []

    def missing_get(key: str, cache_name: Optional[str] = None) -> Optional[Dict[str, Any]]:
        lookups.append(("get", key, cache_name))
        return None

    def missing_get_many(
        keys: List[str], cache_name: Optional[str] = None
    ) -> Dict[str, Any]:  # nocoverage -- simulated code doesn't use this
        lookups.append(("getmany", keys, cache_name))
        return {}

    with mock.patch.multiple(cache, cache_get=missing_get, cache_get_many=missing_get_many):
        yield lookups
 | 
						|
 | 
						|
 | 
						|
@dataclass
class CapturedQuery:
    """One SQL statement recorded by ``queries_captured``."""

    # Statement text produced by the cursor's mogrify(), decoded to str.
    sql: str
    # Elapsed wall-clock seconds, preformatted with three decimal places.
    time: str
 | 
						|
 | 
						|
 | 
						|
@contextmanager
def queries_captured(
    include_savepoints: bool = False, keep_cache_warm: bool = False
) -> Iterator[List[CapturedQuery]]:
    """
    Allow a user to capture just the queries executed during
    the with statement.

    While active, ``TimeTrackingCursor.execute`` and ``executemany`` are
    patched so each statement is timed and appended to the yielded list as
    a ``CapturedQuery``. The statement is recorded even when it raises.
    Plain-string statements containing "SAVEPOINT" are skipped unless
    *include_savepoints* is true. Unless *keep_cache_warm* is true, the
    default cache backend is cleared before patching.
    """
    captured: List[CapturedQuery] = []

    def timed_call(
        self: TimeTrackingCursor,
        action: Callable[[Query, ParamsT], None],
        sql: Query,
        params: ParamsT,
    ) -> None:
        started_at = time.time()
        try:
            return action(sql, params)
        finally:
            elapsed = time.time() - started_at
            # Only plain-string SQL can be recognized as a savepoint statement.
            is_savepoint = isinstance(sql, str) and "SAVEPOINT" in sql
            if include_savepoints or not is_savepoint:
                captured.append(
                    CapturedQuery(
                        sql=self.mogrify(sql, params).decode(),
                        time=f"{elapsed:.3f}",
                    )
                )

    def patched_execute(
        self: TimeTrackingCursor, sql: Query, params: Optional[Params] = None
    ) -> None:
        return timed_call(self, super(TimeTrackingCursor, self).execute, sql, params)

    def patched_executemany(self: TimeTrackingCursor, sql: Query, params: Iterable[Params]) -> None:
        return timed_call(
            self, super(TimeTrackingCursor, self).executemany, sql, params
        )  # nocoverage -- doesn't actually get used in tests

    if not keep_cache_warm:
        get_cache_backend(None).clear()

    with mock.patch.multiple(
        TimeTrackingCursor, execute=patched_execute, executemany=patched_executemany
    ):
        yield captured
 | 
						|
 | 
						|
 | 
						|
@contextmanager
def stdout_suppressed() -> Iterator[IO[str]]:
    """Send sys.stdout to the null device for the duration of the block.

    Yields the stream that was sys.stdout on entry. That stream is restored
    on exit, even if the block raises.
    """
    with open(os.devnull, "a") as null_sink:
        saved_stdout = sys.stdout
        sys.stdout = null_sink
        try:
            yield saved_stdout
        finally:
            sys.stdout = saved_stdout
 | 
						|
 | 
						|
 | 
						|
def reset_email_visibility_to_everyone_in_zulip_realm() -> None:
    """
    Set email_address_visibility to "everyone" throughout the zulip realm.

    Updates both the realm's RealmUserDefault row and every UserProfile in
    the realm. The development default is "EMAIL_ADDRESS_VISIBILITY_ADMINS",
    so tests that need the "email" field of users to hold their real email
    address call this first.
    """
    zulip_realm = get_realm("zulip")
    do_set_realm_user_default_setting(
        RealmUserDefault.objects.get(realm=zulip_realm),
        "email_address_visibility",
        RealmUserDefault.EMAIL_ADDRESS_VISIBILITY_EVERYONE,
        acting_user=None,
    )
    for user_profile in UserProfile.objects.filter(realm=zulip_realm):
        do_change_user_setting(
            user_profile,
            "email_address_visibility",
            UserProfile.EMAIL_ADDRESS_VISIBILITY_EVERYONE,
            acting_user=None,
        )
 | 
						|
 | 
						|
 | 
						|
def get_test_image_file(filename: str) -> IO[bytes]:
    """Open *filename* from the test image directory in binary read mode.

    The directory is ``../tests/images`` relative to this module. The caller
    owns the returned file object and is responsible for closing it.
    """
    images_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../tests/images"))
    image_path = os.path.join(images_dir, filename)
    return open(image_path, "rb")  # noqa: SIM115
 | 
						|
 | 
						|
 | 
						|
def read_test_image_file(filename: str) -> bytes:
    """Return the full contents of test image *filename* as bytes."""
    with get_test_image_file(filename) as image:
        contents = image.read()
    return contents
 | 
						|
 | 
						|
 | 
						|
def avatar_disk_path(
    user_profile: UserProfile, medium: bool = False, original: bool = False
) -> str:
    """Return the local filesystem path of *user_profile*'s avatar.

    The path joins settings.LOCAL_AVATARS_DIR with the last two components
    of the avatar URL, with any query string removed from the file name.
    When *original* is true, every ".png" in the path is replaced by
    ".original".
    """
    url_path = avatar_url(user_profile, medium)
    assert url_path is not None
    assert settings.LOCAL_UPLOADS_DIR is not None
    assert settings.LOCAL_AVATARS_DIR is not None

    url_parts = url_path.split("/")
    directory = url_parts[-2]
    file_name = url_parts[-1].split("?")[0]
    disk_path = os.path.join(settings.LOCAL_AVATARS_DIR, directory, file_name)
    return disk_path.replace(".png", ".original") if original else disk_path
 | 
						|
 | 
						|
 | 
						|
def make_client(name: str) -> Client:
    """Return the Client named *name*, creating it first if it does not exist."""
    return Client.objects.get_or_create(name=name)[0]
 | 
						|
 | 
						|
 | 
						|
def find_key_by_email(address: str) -> Optional[str]:
    """Return the confirmation key from the last email sent to *address*.

    Scans django.core.mail.outbox from the end backwards. The first message
    with *address* among its recipients must contain an
    ``accounts/do_confirm/<24-char key>>`` link; this is asserted. Returns
    None if no message in the outbox was sent to *address*.
    """
    from django.core.mail import outbox

    confirm_link = re.compile("accounts/do_confirm/([a-z0-9]{24})>")
    addressed = (sent for sent in reversed(outbox) if address in sent.to)
    latest = next(addressed, None)
    if latest is None:
        return None  # nocoverage -- in theory a test might want this case, but none do
    found = confirm_link.search(str(latest.body))
    assert found is not None
    (key,) = found.groups()
    return key
 | 
						|
 | 
						|
 | 
						|
def message_stream_count(user_profile: UserProfile) -> int:
    """Return the number of UserMessage rows belonging to *user_profile*."""
    user_messages = UserMessage.objects.select_related("message")
    return user_messages.filter(user_profile=user_profile).count()
 | 
						|
 | 
						|
 | 
						|
def most_recent_usermessage(user_profile: UserProfile) -> UserMessage:
    """Return *user_profile*'s UserMessage with the highest message id.

    Raises IndexError if the user has no UserMessage rows.
    """
    newest_first = (
        UserMessage.objects.filter(user_profile=user_profile)
        .select_related("message")
        .order_by("-message")
    )
    # Indexing the queryset makes Django issue a LIMIT query.
    return newest_first[0]
 | 
						|
 | 
						|
 | 
						|
def most_recent_message(user_profile: UserProfile) -> Message:
    """Return the Message referenced by *user_profile*'s latest UserMessage."""
    return most_recent_usermessage(user_profile).message
 | 
						|
 | 
						|
 | 
						|
def get_subscription(stream_name: str, user_profile: UserProfile) -> Subscription:
    """Return *user_profile*'s active Subscription to the stream *stream_name*.

    The stream is looked up in the user's own realm.
    """
    recipient_id = get_stream(stream_name, user_profile.realm).recipient_id
    assert recipient_id is not None
    return Subscription.objects.get(
        user_profile=user_profile,
        recipient_id=recipient_id,
        active=True,
    )
 | 
						|
 | 
						|
 | 
						|
def get_user_messages(user_profile: UserProfile) -> List[Message]:
    """Return the Messages *user_profile* has UserMessage rows for, by message id."""
    user_messages = (
        UserMessage.objects.filter(user_profile=user_profile)
        .select_related("message")
        .order_by("message")
    )
    return [row.message for row in user_messages]
 | 
						|
 | 
						|
 | 
						|
class DummyHandler(AsyncDjangoHandler):
    """Minimal AsyncDjangoHandler whose only state is an allocated handler id.

    AsyncDjangoHandler.__init__ is not called.
    """

    def __init__(self) -> None:
        # Register this instance and remember the id it was given.
        self.handler_id = allocate_handler_id(self)
 | 
						|
 | 
						|
 | 
						|
# Module-level DummyHandler instance, created once at import time; its
# constructor calls allocate_handler_id to obtain a handler id.
dummy_handler = DummyHandler()
 | 
						|
 | 
						|
 | 
						|
class HostRequestMock(HttpRequest):
    """HttpRequest stand-in whose get_host() returns a configurable host.

    Useful for exercising routes that depend on Zulip's subdomain support
    without going through the full HTTP stack.
    """

    # HttpRequest declares GET and POST as immutable QueryDicts. This mock
    # fills POST in-place, and some callers modify GET, so both are
    # redeclared here as mutable QueryDicts.
    GET: QueryDict  # type: ignore[assignment] # See previous comment.
    POST: QueryDict  # type: ignore[assignment] # See previous comment.

    def __init__(
        self,
        post_data: Mapping[str, Any] = {},
        user_profile: Union[UserProfile, None] = None,
        remote_server: Optional[RemoteZulipServer] = None,
        host: str = settings.EXTERNAL_HOST,
        client_name: Optional[str] = None,
        meta_data: Optional[Dict[str, Any]] = None,
        tornado_handler: Optional[AsyncDjangoHandler] = None,
        path: str = "",
    ) -> None:
        self.host = host
        self.GET = QueryDict(mutable=True)
        self.method = ""

        # Every POST value is stringified, mirroring what the real HTTP API
        # would deliver. Ideally tests would use the HTTP API directly and
        # this abstraction would go away, but until then it must convert.
        self.POST = QueryDict(mutable=True)
        for field_name, field_value in post_data.items():
            self.POST[field_name] = str(field_value)
        if self.POST:
            self.method = "POST"

        self.META = {"PATH_INFO": "test"} if meta_data is None else meta_data
        self.path = path
        self.user = user_profile or AnonymousUser()
        self._body = b""
        self.content_type = ""

        notes = RequestNotes(
            client_name="",
            log_data={},
            tornado_handler_id=tornado_handler.handler_id if tornado_handler is not None else None,
            client=None if client_name is None else get_client(client_name),
            remote_server=remote_server,
        )
        RequestNotes.set_notes(self, notes)

    @property
    def body(self) -> bytes:
        """Request body, read through HttpRequest's own property."""
        return super().body

    @body.setter
    def body(self, val: bytes) -> None:
        # Store directly in the attribute HttpRequest's property reads from.
        self._body = val

    def get_host(self) -> str:
        """Return the host passed to the constructor."""
        return self.host
 | 
						|
 | 
						|
 | 
						|
# URL-coverage instrumentation is on only when the environment variable
# TEST_INSTRUMENT_URL_COVERAGE is exactly "TRUE".
INSTRUMENTING = os.environ.get("TEST_INSTRUMENT_URL_COVERAGE", "") == "TRUE"
# One dict per instrumented test-client call, appended by
# append_instrumentation_data and consumed by write_instrumentation_reports.
INSTRUMENTED_CALLS: List[Dict[str, Any]] = []

# Type of the test-client request methods that instrument_url wraps.
UrlFuncT = TypeVar("UrlFuncT", bound=Callable[..., HttpResponseBase])  # TODO: make more specific
 | 
						|
 | 
						|
 | 
						|
def append_instrumentation_data(data: Dict[str, Any]) -> None:
    """Add one instrumented-call record to the module-level INSTRUMENTED_CALLS list."""
    INSTRUMENTED_CALLS.extend((data,))
 | 
						|
 | 
						|
 | 
						|
def instrument_url(f: UrlFuncT) -> UrlFuncT:
    """Wrap a test-client request method so that each call is recorded.

    When INSTRUMENTING is false, *f* is returned unchanged. Otherwise the
    wrapper calls *f* and then passes a dict describing the call to
    append_instrumentation_data. The dict holds the URL without its query
    string, the status code, the method name, the elapsed seconds, the
    query string, a JSON-friendly summary of *info*, the test id, and the
    kwargs. Calls where *f* raises are not recorded.
    """
    # TODO: Type this with ParamSpec to preserve the function signature.
    if not INSTRUMENTING:  # nocoverage -- option is always enabled; should we remove?
        return f

    def wrapper(
        self: "ZulipTestCase", url: str, info: object = {}, **kwargs: Union[bool, str]
    ) -> HttpResponseBase:
        began = time.time()
        result = f(self, url, info, **kwargs)
        elapsed = time.time() - began
        test_name = self.id()

        # Split off everything after the first "?" as extra_info.
        url, _, extra_info = url.partition("?")

        # Replace values that can't be written to the JSON report with placeholders.
        if isinstance(info, HostRequestMock):
            logged_info: object = "<HostRequestMock>"
        elif isinstance(info, bytes):
            logged_info = "<bytes>"
        elif isinstance(info, dict):
            logged_info = {
                key: "<file object>" if hasattr(value, "read") and callable(value.read) else value
                for key, value in info.items()
            }
        else:
            logged_info = info

        append_instrumentation_data(
            {
                "url": url,
                "status_code": result.status_code,
                "method": f.__name__,
                "delay": elapsed,
                "extra_info": extra_info,
                "info": logged_info,
                "test_name": test_name,
                "kwargs": kwargs,
            }
        )
        return result

    return cast(UrlFuncT, wrapper)  # https://github.com/python/mypy/issues/1927
 | 
						|
 | 
						|
 | 
						|
def write_instrumentation_reports(full_suite: bool, include_webhooks: bool) -> None:
    """Match recorded test-client calls to URL patterns and write a report.

    Does nothing unless INSTRUMENTING is true. Otherwise each call in
    INSTRUMENTED_CALLS is matched against ``zproject.urls.urlpatterns``
    (under the "", "en/" and "de/" prefixes) and against
    ``v1_api_and_json_patterns`` (under "api/v1/" and "json/"). A matched
    call gets a "pattern" key. All calls are then written as JSON lines to
    var/url_coverage.txt.

    :param full_suite: when true, print where the report was written and,
        if any non-exempt pattern was never hit with a 200/204/301/302
        response, list those patterns and exit with status 1.
    :param include_webhooks: when false, webhook integration URLs are exempt
        from the untested-URL check.
    """
    if INSTRUMENTING:
        calls = INSTRUMENTED_CALLS

        from zproject.urls import urlpatterns, v1_api_and_json_patterns

        # Find our untested urls.
        pattern_cnt: Dict[str, int] = collections.defaultdict(int)

        def re_strip(r: str) -> str:
            # Drop the leading "^" and the trailing "$" or "\Z" anchor.
            assert r.startswith(r"^")
            if r.endswith(r"$"):
                return r[1:-1]
            else:
                assert r.endswith(r"\Z")
                return r[1:-2]

        def find_patterns(patterns: List[Any], prefixes: List[str]) -> None:
            for pattern in patterns:
                find_pattern(pattern, prefixes)

        def cleanup_url(url: str) -> str:
            # Normalize a recorded URL to a path without its leading "/" or test host.
            if url.startswith("/"):
                url = url[1:]
            if url.startswith("http://testserver/"):
                url = url[len("http://testserver/") :]
            if url.startswith("http://zulip.testserver/"):
                url = url[len("http://zulip.testserver/") :]
            if url.startswith("http://testserver:9080/"):
                url = url[len("http://testserver:9080/") :]
            return url

        def find_pattern(pattern: Any, prefixes: List[str]) -> None:
            # Only leaf patterns are matched against calls. Resolvers, and
            # anything else exposing url_patterns, group other patterns
            # rather than matching URLs themselves, so skip them. (Checking
            # type(URLResolver) would test whether `pattern` is a class,
            # which is never true for a pattern instance.)
            if isinstance(pattern, URLResolver) or hasattr(pattern, "url_patterns"):
                return

            canon_pattern = prefixes[0] + re_strip(pattern.pattern.regex.pattern)
            cnt = 0
            for call in calls:
                # Each call is attributed to the first pattern that resolves it.
                if "pattern" in call:
                    continue

                url = cleanup_url(call["url"])

                for prefix in prefixes:
                    if url.startswith(prefix):
                        match_url = url[len(prefix) :]
                        if pattern.resolve(match_url):
                            if call["status_code"] in [200, 204, 301, 302]:
                                cnt += 1
                            call["pattern"] = canon_pattern
            pattern_cnt[canon_pattern] += cnt

        find_patterns(urlpatterns, ["", "en/", "de/"])
        find_patterns(v1_api_and_json_patterns, ["api/v1/", "json/"])

        assert len(pattern_cnt) > 100
        untested_patterns = {p.replace("\\", "") for p in pattern_cnt if pattern_cnt[p] == 0}

        exempt_patterns = {
            # We exempt some patterns that are called via Tornado.
            "api/v1/events",
            "api/v1/events/internal",
            "api/v1/register",
            # We also exempt some development environment debugging
            # static content URLs, since the content they point to may
            # or may not exist.
            "coverage/(?P<path>.+)",
            "confirmation_key/",
            "node-coverage/(?P<path>.+)",
            "docs/(?P<path>.+)",
            "casper/(?P<path>.+)",
            "static/(?P<path>.+)",
            "flush_caches",
            "external_content/(?P<digest>[^/]+)/(?P<received_url>[^/]+)",
            # These are SCIM2 urls overridden from django-scim2 to return Not Implemented.
            # We actually test them, but it's not being detected as a tested pattern,
            # possibly due to the use of re_path. TODO: Investigate and get them
            # recognized as tested.
            "scim/v2/",
            "scim/v2/.search",
            "scim/v2/Bulk",
            "scim/v2/Me",
            "scim/v2/ResourceTypes(?:/(?P<uuid>[^/]+))?",
            "scim/v2/Schemas(?:/(?P<uuid>[^/]+))?",
            "scim/v2/ServiceProviderConfig",
            "scim/v2/Groups(?:/(?P<uuid>[^/]+))?",
            "scim/v2/Groups/.search",
            *(webhook.url for webhook in WEBHOOK_INTEGRATIONS if not include_webhooks),
        }

        untested_patterns -= exempt_patterns

        var_dir = "var"  # TODO make sure path is robust here
        fn = os.path.join(var_dir, "url_coverage.txt")
        with open(fn, "wb") as f:
            for call in calls:
                f.write(orjson.dumps(call, option=orjson.OPT_APPEND_NEWLINE))

        if full_suite:
            print(f"INFO: URL coverage report is in {fn}")

        if full_suite and untested_patterns:  # nocoverage -- test suite error handling
            print("\nERROR: Some URLs are untested!  Here's the list of untested URLs:")
            for untested_pattern in sorted(untested_patterns):
                print(f"   {untested_pattern}")
            sys.exit(1)
 | 
						|
 | 
						|
 | 
						|
def load_subdomain_token(response: Union["TestHttpResponse", HttpResponse]) -> ExternalAuthDataDict:
    """Return the external-auth data stored under the token in a redirect.

    *response* must be an HttpResponseRedirect whose URL ends with the login
    token as its last path component. The data is loaded with
    delete_stored_data=False.
    """
    assert isinstance(response, HttpResponseRedirect)
    token = response.url.rsplit("/", maxsplit=1)[1]
    auth_result = ExternalAuthResult(login_token=token, delete_stored_data=False)
    data = auth_result.data_dict
    assert data is not None
    return data
 | 
						|
 | 
						|
 | 
						|
# Type variable for None-returning callables; used in use_s3_backend's signature.
FuncT = TypeVar("FuncT", bound=Callable[..., None])
 | 
						|
 | 
						|
 | 
						|
def use_s3_backend(method: FuncT) -> FuncT:
    """Decorate a test so it runs against a moto-mocked S3 upload backend.

    Each call of the decorated test runs under ``mock_s3``, with
    LOCAL_UPLOADS_DIR, LOCAL_AVATARS_DIR and LOCAL_FILES_DIR overridden to
    None, and with ``zerver.lib.upload.upload_backend`` patched to a freshly
    constructed S3UploadBackend.
    """
    from functools import wraps

    @mock_s3
    @override_settings(LOCAL_UPLOADS_DIR=None)
    @override_settings(LOCAL_AVATARS_DIR=None)
    @override_settings(LOCAL_FILES_DIR=None)
    # Keep the test's __name__/__doc__ instead of reporting "new_method".
    @wraps(method)
    def new_method(*args: Any, **kwargs: Any) -> Any:
        with mock.patch("zerver.lib.upload.upload_backend", S3UploadBackend()):
            return method(*args, **kwargs)

    return new_method
 | 
						|
 | 
						|
 | 
						|
def create_s3_buckets(*bucket_names: str) -> List[Bucket]:
    """Create one S3 bucket per name and return them in argument order.

    The buckets are created through a boto3 session built from
    settings.S3_KEY and settings.S3_SECRET_KEY.
    """
    s3 = boto3.session.Session(settings.S3_KEY, settings.S3_SECRET_KEY).resource("s3")
    return [s3.create_bucket(Bucket=bucket_name) for bucket_name in bucket_names]
 | 
						|
 | 
						|
 | 
						|
# Type variable for MigrationsTestCase subclasses, used in use_db_models's
# signature. The bound is a string because MigrationsTestCase is imported
# only under TYPE_CHECKING.
TestCaseT = TypeVar("TestCaseT", bound="MigrationsTestCase")
 | 
						|
 | 
						|
 | 
						|
def use_db_models(
    method: Callable[[TestCaseT, StateApps], None]
) -> Callable[[TestCaseT, StateApps], None]:  # nocoverage
    """Run a migration test method against the historical models in *apps*.

    The returned wrapper looks up each model named below in the "zerver"
    app of the migration state *apps*. For the duration of the call it then
    patches those names in zerver.models, and a subset of them in
    zerver.lib.test_helpers and zerver.lib.test_classes. The historical
    Recipient model also gets the PERSONAL/STREAM/HUDDLE constants (1/2/3)
    set on it.
    """
    # Models looked up from *apps* and patched into zerver.models, in this order.
    zerver_model_names = (
        "ArchivedAttachment",
        "ArchivedMessage",
        "ArchivedUserMessage",
        "Attachment",
        "BotConfigData",
        "BotStorageData",
        "Client",
        "CustomProfileField",
        "CustomProfileFieldValue",
        "DefaultStream",
        "DefaultStreamGroup",
        "EmailChangeStatus",
        "Huddle",
        "Message",
        "MultiuseInvite",
        "UserTopic",
        "PreregistrationUser",
        "PushDeviceToken",
        "Reaction",
        "Realm",
        "RealmAuditLog",
        "RealmDomain",
        "RealmEmoji",
        "RealmFilter",
        "Recipient",
        "ScheduledEmail",
        "ScheduledMessage",
        "Service",
        "Stream",
        "Subscription",
        "UserActivity",
        "UserActivityInterval",
        "UserGroup",
        "UserGroupMembership",
        "UserHotspot",
        "UserMessage",
        "UserPresence",
        "UserProfile",
    )
    # Names also patched into the test helper modules.
    test_helpers_model_names = ("Client", "Message", "Subscription", "UserMessage", "UserProfile")
    test_classes_model_names = (
        "Client",
        "Message",
        "Realm",
        "Recipient",
        "Stream",
        "Subscription",
        "UserProfile",
    )

    def method_patched_with_mock(self: TestCaseT, apps: StateApps) -> None:
        historical: Dict[str, Any] = {}
        for model_name in zerver_model_names:
            model = apps.get_model("zerver", model_name)
            if model_name == "Recipient":
                # Set the recipient type constants on the historical Recipient model.
                model.PERSONAL = 1
                model.STREAM = 2
                model.HUDDLE = 3
            historical[model_name] = model

        zerver_models_patch = mock.patch.multiple("zerver.models", **historical)
        zerver_test_helpers_patch = mock.patch.multiple(
            "zerver.lib.test_helpers",
            **{name: historical[name] for name in test_helpers_model_names},
        )
        zerver_test_classes_patch = mock.patch.multiple(
            "zerver.lib.test_classes",
            **{name: historical[name] for name in test_classes_model_names},
        )

        with zerver_models_patch, zerver_test_helpers_patch, zerver_test_classes_patch:
            method(self, apps)

    return method_patched_with_mock
 | 
						|
 | 
						|
 | 
						|
def create_dummy_file(filename: str) -> str:
    """Write a small placeholder file into ``settings.TEST_WORKER_DIR``.

    The file is named ``filename`` and contains the text ``zulip!``; it is
    opened in ``"w"`` mode, so an existing file of that name is overwritten.

    Returns the full path of the written file.
    """
    filepath = os.path.join(settings.TEST_WORKER_DIR, filename)
    # Pin the encoding so the written bytes don't depend on the host locale
    # (PEP 597). The content is ASCII, so the bytes are the same in practice.
    with open(filepath, "w", encoding="utf-8") as f:
        f.write("zulip!")
    return filepath
 | 
						|
 | 
						|
 | 
						|
def zulip_reaction_info() -> Dict[str, str]:
    """Return reaction fields for the ``zulip`` emoji (type ``zulip_extra_emoji``).

    A new dict is built on every call, so callers may mutate the result
    without affecting later calls.
    """
    reaction: Dict[str, str] = {
        "emoji_name": "zulip",
        "emoji_code": "zulip",
        "reaction_type": "zulip_extra_emoji",
    }
    return reaction
 | 
						|
 | 
						|
 | 
						|
@contextmanager
def mock_queue_publish(
    method_to_patch: str,
    **kwargs: object,
) -> Iterator[mock.MagicMock]:
    """Patch a queue-publishing function and verify its events are JSON-safe.

    ``method_to_patch`` is the dotted path of the function to replace. Each
    call to it first asserts that ``event`` is unchanged by an orjson
    dumps/loads round-trip. The call is then forwarded to a ``MagicMock``
    built from ``kwargs``, which is yielded so the test can inspect calls.
    """
    recorder = mock.MagicMock(**kwargs)

    # Events published to the queues must be serializable as JSON; an
    # unserializable event would make RabbitMQ crash in production.
    def check_and_record(
        queue_name: str,
        event: Dict[str, object],
        processor: Optional[Callable[[object], None]] = None,
    ) -> None:
        round_tripped = orjson.loads(orjson.dumps(event))
        assert round_tripped == event
        recorder(queue_name, event, processor)

    patcher = mock.patch(method_to_patch, side_effect=check_and_record)
    with patcher:
        yield recorder
 | 
						|
 | 
						|
 | 
						|
@contextmanager
def timeout_mock(mock_path: str) -> Iterator[None]:
    """Replace ``<mock_path>.timeout`` with a pass-through while active.

    timeout() doesn't work in the test environment with database operations,
    and their changes don't get committed. The stand-in therefore ignores
    the time limit, calls ``func`` directly, and returns its result.
    """

    def run_immediately(seconds: int, func: Callable[[], object]) -> object:
        # ``seconds`` is deliberately unused.
        return func()

    with mock.patch(f"{mock_path}.timeout", new=run_immediately):
        yield
 | 
						|
 | 
						|
 | 
						|
@contextmanager
def ratelimit_rule(
    range_seconds: int,
    num_requests: int,
    domain: str = "api_by_user",
) -> Iterator[None]:
    """Temporarily add a ``(range_seconds, num_requests)`` rule for ``domain``.

    First clears the recorded history for 127.0.0.1 in ``domain``. The
    domain's rules are then replaced by a copy that includes the new rule,
    sorted by its first element, and ``RATE_LIMITING`` is forced on. Both
    changes are undone when the context exits.
    """
    RateLimitedIPAddr("127.0.0.1", domain=domain).clear_history()

    # Build a fresh list instead of mutating the shared one; sorted() is
    # stable, matching an append followed by an in-place sort.
    updated_rules = sorted(
        [*rules.get(domain, []), (range_seconds, num_requests)],
        key=lambda rule: rule[0],
    )

    with patch.dict(rules, {domain: updated_rules}):
        with override_settings(RATE_LIMITING=True):
            yield
 |