zerver/lib: Use python 3 syntax for typing.

With tweaks by tabbott to fix line spacing.
This commit is contained in:
rht
2017-11-05 11:15:10 +01:00
committed by Tim Abbott
parent 4a07214725
commit 33b1a541d7
30 changed files with 328 additions and 544 deletions

View File

@@ -3,22 +3,19 @@ import ujson
import random
from typing import List, Dict, Any, Text, Optional
def load_config():
# type: () -> Dict [str, Any]
def load_config() -> Dict[str, Any]:
with open("zerver/fixtures/config.generate_data.json", "r") as infile:
config = ujson.load(infile)
return config
def get_stream_title(gens):
# type: (Dict[str, Any]) -> str
def get_stream_title(gens: Dict[str, Any]) -> str:
return next(gens["adjectives"]) + " " + next(gens["nouns"]) + " " + \
next(gens["connectors"]) + " " + next(gens["verbs"]) + " " + \
next(gens["adverbs"])
def load_generators(config):
# type: (Dict[str, Any]) -> Dict[str, Any]
def load_generators(config: Dict[str, Any]) -> Dict[str, Any]:
results = {}
cfg = config["gen_fodder"]
@@ -40,8 +37,7 @@ def load_generators(config):
return results
def parse_file(config, gens, corpus_file):
# type: (Dict[str, Any], Dict[str, Any], str) -> List[str]
def parse_file(config: Dict[str, Any], gens: Dict[str, Any], corpus_file: str) -> List[str]:
# First, load the entire file into a dictionary,
# then apply our custom filters to it as needed.
@@ -55,8 +51,7 @@ def parse_file(config, gens, corpus_file):
return paragraphs
def get_flair_gen(length):
# type: (int) -> List[str]
def get_flair_gen(length: int) -> List[str]:
# Grab the percentages from the config file
# create a list that we can consume that will guarantee the distribution
@@ -70,8 +65,7 @@ def get_flair_gen(length):
random.shuffle(result)
return result
def add_flair(paragraphs, gens):
# type: (List[str], Dict[str, Any]) -> List[str]
def add_flair(paragraphs: List[str], gens: Dict[str, Any]) -> List[str]:
# roll the dice and see what kind of flair we should add, if any
results = []
@@ -111,8 +105,7 @@ def add_flair(paragraphs, gens):
return results
def add_md(mode, text):
# type: (str, str) -> str
def add_md(mode: str, text: str) -> str:
# mode means: bold, italic, etc.
# to add a list at the end of a paragraph, * iterm one\n * item two
@@ -127,8 +120,7 @@ def add_md(mode, text):
return " ".join(vals).strip()
def add_emoji(text, emoji):
# type: (str, str) -> str
def add_emoji(text: str, emoji: str) -> str:
vals = text.split()
start = random.randrange(len(vals))
@@ -136,8 +128,7 @@ def add_emoji(text, emoji):
vals[start] = vals[start] + " " + emoji + " "
return " ".join(vals)
def add_link(text, link):
# type: (str, str) -> str
def add_link(text: str, link: str) -> str:
vals = text.split()
start = random.randrange(len(vals))
@@ -146,8 +137,7 @@ def add_link(text, link):
return " ".join(vals)
def remove_line_breaks(fh):
# type: (Any) -> List[str]
def remove_line_breaks(fh: Any) -> List[str]:
# We're going to remove line breaks from paragraphs
results = [] # save the dialogs as tuples with (author, dialog)
@@ -168,14 +158,12 @@ def remove_line_breaks(fh):
return results
def write_file(paragraphs, filename):
# type: (List[str], str) -> None
def write_file(paragraphs: List[str], filename: str) -> None:
with open(filename, "w") as outfile:
outfile.write(ujson.dumps(paragraphs))
def create_test_data():
# type: () -> None
def create_test_data() -> None:
gens = load_generators(config) # returns a dictionary of generators

View File

@@ -27,8 +27,7 @@ ALL_HOTSPOTS = {
},
} # type: Dict[str, Dict[str, Text]]
def get_next_hotspots(user):
# type: (UserProfile) -> List[Dict[str, object]]
def get_next_hotspots(user: UserProfile) -> List[Dict[str, object]]:
# Only used for manual testing
SEND_ALL = False
if settings.DEVELOPMENT and SEND_ALL:

View File

@@ -3,12 +3,10 @@ import lxml
from lxml.html.diff import htmldiff
from typing import Optional
def highlight_with_class(text, klass):
# type: (str, str) -> str
def highlight_with_class(text: str, klass: str) -> str:
return '<span class="%s">%s</span>' % (klass, text)
def highlight_html_differences(s1, s2, msg_id=None):
# type: (str, str, Optional[int]) -> str
def highlight_html_differences(s1: str, s2: str, msg_id: Optional[int]=None) -> str:
retval = htmldiff(s1, s2)
fragment = lxml.html.fromstring(retval) # type: ignore # https://github.com/python/typeshed/issues/525

View File

@@ -7,8 +7,7 @@ import base64
from typing import Optional, Text
def initial_password(email):
# type: (Text) -> Optional[Text]
def initial_password(email: Text) -> Optional[Text]:
"""Given an email address, returns the initial password for that account, as
created by populate_db."""

View File

@@ -9,8 +9,7 @@ from typing import Any, Dict, Optional, Text, List
from zerver.models import Realm, UserProfile
def is_integer_string(val):
# type: (str) -> bool
def is_integer_string(val: str) -> bool:
try:
int(val)
return True
@@ -18,8 +17,8 @@ def is_integer_string(val):
return False
class ZulipBaseCommand(BaseCommand):
def add_realm_args(self, parser, required=False, help=None):
# type: (ArgumentParser, bool, Optional[str]) -> None
def add_realm_args(self, parser: ArgumentParser, required: bool=False,
help: Optional[str]=None) -> None:
if help is None:
help = """The numeric or string ID (subdomain) of the Zulip organization to modify.
You can use the command list_realms to find ID of the realms in this server."""
@@ -31,8 +30,11 @@ You can use the command list_realms to find ID of the realms in this server."""
type=str,
help=help)
def add_user_list_args(self, parser, required=False, help=None, all_users_arg=True, all_users_help=None):
# type: (ArgumentParser, bool, Optional[str], bool, Optional[str]) -> None
def add_user_list_args(self, parser: ArgumentParser,
required: bool=False,
help: Optional[str]=None,
all_users_arg: bool=True,
all_users_help: Optional[str]=None) -> None:
if help is None:
help = 'A comma-separated list of email addresses.'
@@ -54,8 +56,7 @@ You can use the command list_realms to find ID of the realms in this server."""
default=False,
help=all_users_help)
def get_realm(self, options):
# type: (Dict[str, Any]) -> Optional[Realm]
def get_realm(self, options: Dict[str, Any]) -> Optional[Realm]:
val = options["realm_id"]
if val is None:
return None
@@ -71,8 +72,7 @@ You can use the command list_realms to find ID of the realms in this server."""
raise CommandError("There is no realm with id '%s'. Aborting." %
(options["realm_id"],))
def get_users(self, options, realm):
# type: (Dict[str, Any], Optional[Realm]) -> List[UserProfile]
def get_users(self, options: Dict[str, Any], realm: Optional[Realm]) -> List[UserProfile]:
if "all_users" in options:
all_users = options["all_users"]
@@ -96,8 +96,7 @@ You can use the command list_realms to find ID of the realms in this server."""
user_profiles.append(self.get_user(email, realm))
return user_profiles
def get_user(self, email, realm):
# type: (Text, Optional[Realm]) -> UserProfile
def get_user(self, email: Text, realm: Optional[Realm]) -> UserProfile:
# If a realm is specified, try to find the user there, and
# throw an error if they don't exist.

View File

@@ -12,40 +12,34 @@ import binascii
from zerver.lib.str_utils import force_str
from zerver.models import UserProfile
def xor_hex_strings(bytes_a, bytes_b):
# type: (str, str) -> str
def xor_hex_strings(bytes_a: str, bytes_b: str) -> str:
"""Given two hex strings of equal length, return a hex string with
the bitwise xor of the two hex strings."""
assert len(bytes_a) == len(bytes_b)
return ''.join(["%x" % (int(x, 16) ^ int(y, 16))
for x, y in zip(bytes_a, bytes_b)])
def ascii_to_hex(input_string):
# type: (str) -> str
def ascii_to_hex(input_string: str) -> str:
"""Given an ascii string, encode it as a hex string"""
return "".join([hex(ord(c))[2:].zfill(2) for c in input_string])
def hex_to_ascii(input_string):
# type: (str) -> str
def hex_to_ascii(input_string: str) -> str:
"""Given a hex array, decode it back to a string"""
return force_str(binascii.unhexlify(input_string))
def otp_encrypt_api_key(user_profile, otp):
# type: (UserProfile, str) -> str
def otp_encrypt_api_key(user_profile: UserProfile, otp: str) -> str:
assert len(otp) == UserProfile.API_KEY_LENGTH * 2
hex_encoded_api_key = ascii_to_hex(force_str(user_profile.api_key))
assert len(hex_encoded_api_key) == UserProfile.API_KEY_LENGTH * 2
return xor_hex_strings(hex_encoded_api_key, otp)
def otp_decrypt_api_key(otp_encrypted_api_key, otp):
# type: (str, str) -> str
def otp_decrypt_api_key(otp_encrypted_api_key: str, otp: str) -> str:
assert len(otp) == UserProfile.API_KEY_LENGTH * 2
assert len(otp_encrypted_api_key) == UserProfile.API_KEY_LENGTH * 2
hex_encoded_api_key = xor_hex_strings(otp_encrypted_api_key, otp)
return hex_to_ascii(hex_encoded_api_key)
def is_valid_otp(otp):
# type: (str) -> bool
def is_valid_otp(otp: str) -> bool:
try:
assert len(otp) == UserProfile.API_KEY_LENGTH * 2
[int(c, 16) for c in otp]

View File

@@ -1,7 +1,6 @@
from typing import Text
def is_reserved_subdomain(subdomain):
# type: (Text) -> bool
def is_reserved_subdomain(subdomain: Text) -> bool:
if subdomain in ZULIP_RESERVED_SUBDOMAINS:
return True
if subdomain[-1] == 's' and subdomain[:-1] in ZULIP_RESERVED_SUBDOMAINS:
@@ -12,8 +11,7 @@ def is_reserved_subdomain(subdomain):
return True
return False
def is_disposable_domain(domain):
# type: (Text) -> bool
def is_disposable_domain(domain: Text) -> bool:
return domain.lower() in DISPOSABLE_DOMAINS
ZULIP_RESERVED_SUBDOMAINS = frozenset([

View File

@@ -7,12 +7,12 @@ import errno
JobData = TypeVar('JobData')
def run_parallel(job, data, threads=6):
# type: (Callable[[JobData], int], Iterable[JobData], int) -> Iterator[Tuple[int, JobData]]
def run_parallel(job: Callable[[JobData], int],
data: Iterable[JobData],
threads: int=6) -> Iterator[Tuple[int, JobData]]:
pids = {} # type: Dict[int, JobData]
def wait_for_one():
# type: () -> Tuple[int, JobData]
def wait_for_one() -> Tuple[int, JobData]:
while True:
try:
(pid, status) = os.wait()
@@ -59,8 +59,7 @@ if __name__ == "__main__":
jobs = [10, 19, 18, 6, 14, 12, 8, 2, 1, 13, 3, 17, 9, 11, 5, 16, 7, 15, 4]
expected_output = [6, 10, 12, 2, 1, 14, 8, 3, 18, 19, 5, 9, 13, 11, 4, 7, 17, 16, 15]
def wait_and_print(x):
# type: (int) -> int
def wait_and_print(x: int) -> int:
time.sleep(x * 0.1)
return 0

View File

@@ -21,8 +21,7 @@ Consumer = Callable[[BlockingChannel, Basic.Deliver, pika.BasicProperties, str],
# interface for external files to put things into queues and take them
# out from bots without having to import pika code all over our codebase.
class SimpleQueueClient:
def __init__(self):
# type: () -> None
def __init__(self) -> None:
self.log = logging.getLogger('zulip.queue')
self.queues = set() # type: Set[str]
self.channel = None # type: Optional[BlockingChannel]
@@ -31,22 +30,19 @@ class SimpleQueueClient:
self.rabbitmq_heartbeat = 0 # type: Optional[int]
self._connect()
def _connect(self):
# type: () -> None
def _connect(self) -> None:
start = time.time()
self.connection = pika.BlockingConnection(self._get_parameters())
self.channel = self.connection.channel()
self.log.info('SimpleQueueClient connected (connecting took %.3fs)' % (time.time() - start,))
def _reconnect(self):
# type: () -> None
def _reconnect(self) -> None:
self.connection = None
self.channel = None
self.queues = set()
self._connect()
def _get_parameters(self):
# type: () -> pika.ConnectionParameters
def _get_parameters(self) -> pika.ConnectionParameters:
# We explicitly disable the RabbitMQ heartbeat feature, since
# it doesn't make sense with BlockingConnection
credentials = pika.PlainCredentials(settings.RABBITMQ_USERNAME,
@@ -55,34 +51,28 @@ class SimpleQueueClient:
heartbeat_interval=self.rabbitmq_heartbeat,
credentials=credentials)
def _generate_ctag(self, queue_name):
# type: (str) -> str
def _generate_ctag(self, queue_name: str) -> str:
return "%s_%s" % (queue_name, str(random.getrandbits(16)))
def _reconnect_consumer_callback(self, queue, consumer):
# type: (str, Consumer) -> None
def _reconnect_consumer_callback(self, queue: str, consumer: Consumer) -> None:
self.log.info("Queue reconnecting saved consumer %s to queue %s" % (consumer, queue))
self.ensure_queue(queue, lambda: self.channel.basic_consume(consumer,
queue=queue,
consumer_tag=self._generate_ctag(queue)))
def _reconnect_consumer_callbacks(self):
# type: () -> None
def _reconnect_consumer_callbacks(self) -> None:
for queue, consumers in self.consumers.items():
for consumer in consumers:
self._reconnect_consumer_callback(queue, consumer)
def close(self):
# type: () -> None
def close(self) -> None:
if self.connection:
self.connection.close()
def ready(self):
# type: () -> bool
def ready(self) -> bool:
return self.channel is not None
def ensure_queue(self, queue_name, callback):
# type: (str, Callable[[], None]) -> None
def ensure_queue(self, queue_name: str, callback: Callable[[], None]) -> None:
'''Ensure that a given queue has been declared, and then call
the callback with no arguments.'''
if self.connection is None or not self.connection.is_open:
@@ -93,10 +83,8 @@ class SimpleQueueClient:
self.queues.add(queue_name)
callback()
def publish(self, queue_name, body):
# type: (str, str) -> None
def do_publish():
# type: () -> None
def publish(self, queue_name: str, body: str) -> None:
def do_publish() -> None:
self.channel.basic_publish(
exchange='',
routing_key=queue_name,
@@ -107,8 +95,7 @@ class SimpleQueueClient:
self.ensure_queue(queue_name, do_publish)
def json_publish(self, queue_name, body):
# type: (str, Union[Mapping[str, Any], str]) -> None
def json_publish(self, queue_name: str, body: Union[Mapping[str, Any], str]) -> None:
# Union because of zerver.middleware.write_log_line uses a str
try:
self.publish(queue_name, ujson.dumps(body))
@@ -119,10 +106,11 @@ class SimpleQueueClient:
self._reconnect()
self.publish(queue_name, ujson.dumps(body))
def register_consumer(self, queue_name, consumer):
# type: (str, Consumer) -> None
def wrapped_consumer(ch, method, properties, body):
# type: (BlockingChannel, Basic.Deliver, pika.BasicProperties, str) -> None
def register_consumer(self, queue_name: str, consumer: Consumer) -> None:
def wrapped_consumer(ch: BlockingChannel,
method: Basic.Deliver,
properties: pika.BasicProperties,
body: str) -> None:
try:
consumer(ch, method, properties, body)
ch.basic_ack(delivery_tag=method.delivery_tag)
@@ -135,20 +123,20 @@ class SimpleQueueClient:
lambda: self.channel.basic_consume(wrapped_consumer, queue=queue_name,
consumer_tag=self._generate_ctag(queue_name)))
def register_json_consumer(self, queue_name, callback):
# type: (str, Callable[[Dict[str, Any]], None]) -> None
def wrapped_callback(ch, method, properties, body):
# type: (BlockingChannel, Basic.Deliver, pika.BasicProperties, str) -> None
def register_json_consumer(self, queue_name: str,
callback: Callable[[Dict[str, Any]], None]) -> None:
def wrapped_callback(ch: BlockingChannel,
method: Basic.Deliver,
properties: pika.BasicProperties,
body: str) -> None:
callback(ujson.loads(body))
self.register_consumer(queue_name, wrapped_callback)
def drain_queue(self, queue_name, json=False):
# type: (str, bool) -> List[Dict[str, Any]]
def drain_queue(self, queue_name: str, json: bool=False) -> List[Dict[str, Any]]:
"Returns all messages in the desired queue"
messages = []
def opened():
# type: () -> None
def opened() -> None:
while True:
(meta, _, message) = self.channel.basic_get(queue_name)
@@ -163,20 +151,17 @@ class SimpleQueueClient:
self.ensure_queue(queue_name, opened)
return messages
def start_consuming(self):
# type: () -> None
def start_consuming(self) -> None:
self.channel.start_consuming()
def stop_consuming(self):
# type: () -> None
def stop_consuming(self) -> None:
self.channel.stop_consuming()
# Patch pika.adapters.TornadoConnection so that a socket error doesn't
# throw an exception and disconnect the tornado process from the rabbitmq
# queue. Instead, just re-connect as usual
class ExceptionFreeTornadoConnection(pika.adapters.TornadoConnection):
def _adapter_disconnect(self):
# type: () -> None
def _adapter_disconnect(self) -> None:
try:
super()._adapter_disconnect()
except (pika.exceptions.ProbableAuthenticationError,
@@ -189,15 +174,13 @@ calling _adapter_disconnect, ignoring" % (e,))
class TornadoQueueClient(SimpleQueueClient):
# Based on:
# https://pika.readthedocs.io/en/0.9.8/examples/asynchronous_consumer_example.html
def __init__(self):
# type: () -> None
def __init__(self) -> None:
super().__init__()
# Enable rabbitmq heartbeat since TornadoConection can process them
self.rabbitmq_heartbeat = None
self._on_open_cbs = [] # type: List[Callable[[], None]]
def _connect(self, on_open_cb = None):
# type: (Optional[Callable[[], None]]) -> None
def _connect(self, on_open_cb: Optional[Callable[[], None]]=None) -> None:
self.log.info("Beginning TornadoQueueClient connection")
if on_open_cb is not None:
self._on_open_cbs.append(on_open_cb)
@@ -207,36 +190,32 @@ class TornadoQueueClient(SimpleQueueClient):
stop_ioloop_on_close = False)
self.connection.add_on_close_callback(self._on_connection_closed)
def _reconnect(self):
# type: () -> None
def _reconnect(self) -> None:
self.connection = None
self.channel = None
self.queues = set()
self._connect()
def _on_open(self, connection):
# type: (pika.Connection) -> None
def _on_open(self, connection: pika.connection.Connection) -> None:
self.connection.channel(
on_open_callback = self._on_channel_open)
def _on_channel_open(self, channel):
# type: (BlockingChannel) -> None
def _on_channel_open(self, channel: BlockingChannel) -> None:
self.channel = channel
for callback in self._on_open_cbs:
callback()
self._reconnect_consumer_callbacks()
self.log.info('TornadoQueueClient connected')
def _on_connection_closed(self, connection, reply_code, reply_text):
# type: (pika.Connection, int, str) -> None
def _on_connection_closed(self, connection: pika.connection.Connection,
reply_code: int, reply_text: str) -> None:
self.log.warning("TornadoQueueClient lost connection to RabbitMQ, reconnecting...")
from tornado import ioloop
# Try to reconnect in two seconds
retry_seconds = 2
def on_timeout():
# type: () -> None
def on_timeout() -> None:
try:
self._reconnect()
except pika.exceptions.AMQPConnectionError:
@@ -245,10 +224,8 @@ class TornadoQueueClient(SimpleQueueClient):
ioloop.IOLoop.instance().add_timeout(time.time() + retry_seconds, on_timeout)
def ensure_queue(self, queue_name, callback):
# type: (str, Callable[[], None]) -> None
def finish(frame):
# type: (Any) -> None
def ensure_queue(self, queue_name: str, callback: Callable[[], None]) -> None:
def finish(frame: Any) -> None:
self.queues.add(queue_name)
callback()
@@ -263,10 +240,11 @@ class TornadoQueueClient(SimpleQueueClient):
else:
callback()
def register_consumer(self, queue_name, consumer):
# type: (str, Consumer) -> None
def wrapped_consumer(ch, method, properties, body):
# type: (BlockingChannel, Basic.Deliver, pika.BasicProperties, str) -> None
def register_consumer(self, queue_name: str, consumer: Consumer) -> None:
def wrapped_consumer(ch: BlockingChannel,
method: Basic.Deliver,
properties: pika.BasicProperties,
body: str) -> None:
consumer(ch, method, properties, body)
ch.basic_ack(delivery_tag=method.delivery_tag)
@@ -280,8 +258,7 @@ class TornadoQueueClient(SimpleQueueClient):
consumer_tag=self._generate_ctag(queue_name)))
queue_client = None # type: Optional[SimpleQueueClient]
def get_queue_client():
# type: () -> SimpleQueueClient
def get_queue_client() -> SimpleQueueClient:
global queue_client
if queue_client is None:
if settings.RUNNING_INSIDE_TORNADO and settings.USING_RABBITMQ:
@@ -298,8 +275,10 @@ def get_queue_client():
# randomly close.
queue_lock = threading.RLock()
def queue_json_publish(queue_name, event, processor, call_consume_in_tests=False):
# type: (str, Union[Dict[str, Any], str], Callable[[Any], None], bool) -> None
def queue_json_publish(queue_name: str,
event: Union[Dict[str, Any], str],
processor: Callable[[Any], None],
call_consume_in_tests: bool=False) -> None:
# most events are dicts, but zerver.middleware.write_log_line uses a str
with queue_lock:
if settings.USING_RABBITMQ:
@@ -311,8 +290,9 @@ def queue_json_publish(queue_name, event, processor, call_consume_in_tests=False
else:
processor(event)
def retry_event(queue_name, event, failure_processor):
# type: (str, Dict[str, Any], Callable[[Dict[str, Any]], None]) -> None
def retry_event(queue_name: str,
event: Dict[str, Any],
failure_processor: Callable[[Dict[str, Any]], None]) -> None:
if 'failed_tries' not in event:
event['failed_tries'] = 0
event['failed_tries'] += 1

View File

@@ -6,12 +6,10 @@ from zerver.lib.avatar_hash import gravatar_hash, user_avatar_hash
from zerver.lib.upload import upload_backend
from zerver.models import Realm
def realm_icon_url(realm):
# type: (Realm) -> Text
def realm_icon_url(realm: Realm) -> Text:
return get_realm_icon_url(realm)
def get_realm_icon_url(realm):
# type: (Realm) -> Text
def get_realm_icon_url(realm: Realm) -> Text:
if realm.icon_source == 'U':
return upload_backend.get_realm_icon_url(realm.id, realm.icon_version)
elif settings.ENABLE_GRAVATAR:

View File

@@ -3,7 +3,6 @@ from django.conf import settings
import redis
def get_redis_client():
# type: () -> redis.StrictRedis
def get_redis_client() -> redis.StrictRedis:
return redis.StrictRedis(host=settings.REDIS_HOST, port=settings.REDIS_PORT,
password=settings.REDIS_PASSWORD, db=0)

View File

@@ -14,31 +14,29 @@ from zerver.lib.exceptions import JsonableError, ErrorCode
from django.http import HttpRequest, HttpResponse
from typing import Any
class RequestVariableMissingError(JsonableError):
code = ErrorCode.REQUEST_VARIABLE_MISSING
data_fields = ['var_name']
def __init__(self, var_name):
# type: (str) -> None
def __init__(self, var_name: str) -> None:
self.var_name = var_name # type: str
@staticmethod
def msg_format():
# type: () -> str
def msg_format() -> str:
return _("Missing '{var_name}' argument")
class RequestVariableConversionError(JsonableError):
code = ErrorCode.REQUEST_VARIABLE_INVALID
data_fields = ['var_name', 'bad_value']
def __init__(self, var_name, bad_value):
# type: (str, Any) -> None
def __init__(self, var_name: str, bad_value: Any) -> None:
self.var_name = var_name # type: str
self.bad_value = bad_value
@staticmethod
def msg_format():
# type: () -> str
def msg_format() -> str:
return _("Bad value for '{var_name}': {bad_value}")
# Used in conjunction with @has_request_variables, below
@@ -100,7 +98,7 @@ class REQ:
# expected to call json_error or json_success, as it uses json_error
# internally when it encounters an error
def has_request_variables(view_func):
# type: (Callable[[HttpRequest, *Any, **Any], HttpResponse]) -> Callable[[HttpRequest, *Any, **Any], HttpResponse]
# type: (Callable[[HttpRequest, Any, Any], HttpResponse]) -> Callable[[HttpRequest, *Any, **Any], HttpResponse]
num_params = view_func.__code__.co_argcount
if view_func.__defaults__ is None:
num_default_params = 0
@@ -121,8 +119,7 @@ def has_request_variables(view_func):
post_params.append(value)
@wraps(view_func)
def _wrapped_view_func(request, *args, **kwargs):
# type: (HttpRequest, *Any, **Any) -> HttpResponse
def _wrapped_view_func(request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse:
for param in post_params:
if param.func_var_name in kwargs:
continue

View File

@@ -8,8 +8,7 @@ from zerver.lib.exceptions import JsonableError
class HttpResponseUnauthorized(HttpResponse):
status_code = 401
def __init__(self, realm, www_authenticate=None):
# type: (Text, Optional[Text]) -> None
def __init__(self, realm: Text, www_authenticate: Optional[Text]=None) -> None:
HttpResponse.__init__(self)
if www_authenticate is None:
self["WWW-Authenticate"] = 'Basic realm="%s"' % (realm,)
@@ -18,35 +17,33 @@ class HttpResponseUnauthorized(HttpResponse):
else:
raise AssertionError("Invalid www_authenticate value!")
def json_unauthorized(message, www_authenticate=None):
# type: (Text, Optional[Text]) -> HttpResponse
def json_unauthorized(message: Text, www_authenticate: Optional[Text]=None) -> HttpResponse:
resp = HttpResponseUnauthorized("zulip", www_authenticate=www_authenticate)
resp.content = (ujson.dumps({"result": "error",
"msg": message}) + "\n").encode()
return resp
def json_method_not_allowed(methods):
# type: (List[Text]) -> HttpResponseNotAllowed
def json_method_not_allowed(methods: List[Text]) -> HttpResponseNotAllowed:
resp = HttpResponseNotAllowed(methods)
resp.content = ujson.dumps({"result": "error",
"msg": "Method Not Allowed",
"allowed_methods": methods}).encode()
return resp
def json_response(res_type="success", msg="", data=None, status=200):
# type: (Text, Text, Optional[Dict[str, Any]], int) -> HttpResponse
def json_response(res_type: Text="success",
msg: Text="",
data: Optional[Dict[str, Any]]=None,
status: int=200) -> HttpResponse:
content = {"result": res_type, "msg": msg}
if data is not None:
content.update(data)
return HttpResponse(content=ujson.dumps(content) + "\n",
content_type='application/json', status=status)
def json_success(data=None):
# type: (Optional[Dict[str, Any]]) -> HttpResponse
def json_success(data: Optional[Dict[str, Any]]=None) -> HttpResponse:
return json_response(data=data)
def json_response_from_error(exception):
# type: (JsonableError) -> HttpResponse
def json_response_from_error(exception: JsonableError) -> HttpResponse:
'''
This should only be needed in middleware; in app code, just raise.
@@ -59,6 +56,5 @@ def json_response_from_error(exception):
data=exception.data,
status=exception.http_status_code)
def json_error(msg, data=None, status=400):
# type: (Text, Optional[Dict[str, Any]], int) -> HttpResponse
def json_error(msg: Text, data: Optional[Dict[str, Any]]=None, status: int=400) -> HttpResponse:
return json_response(res_type="error", msg=msg, data=data, status=status)

View File

@@ -12,48 +12,40 @@ from zerver.models import Realm, UserProfile, get_user_profile_by_id
session_engine = import_module(settings.SESSION_ENGINE)
def get_session_dict_user(session_dict):
# type: (Mapping[Text, int]) -> Optional[int]
def get_session_dict_user(session_dict: Mapping[Text, int]) -> Optional[int]:
# Compare django.contrib.auth._get_user_session_key
try:
return get_user_model()._meta.pk.to_python(session_dict[SESSION_KEY])
except KeyError:
return None
def get_session_user(session):
# type: (Session) -> Optional[int]
def get_session_user(session: Session) -> Optional[int]:
return get_session_dict_user(session.get_decoded())
def user_sessions(user_profile):
# type: (UserProfile) -> List[Session]
def user_sessions(user_profile: UserProfile) -> List[Session]:
return [s for s in Session.objects.all()
if get_session_user(s) == user_profile.id]
def delete_session(session):
# type: (Session) -> None
def delete_session(session: Session) -> None:
session_engine.SessionStore(session.session_key).delete() # type: ignore # import_module
def delete_user_sessions(user_profile):
# type: (UserProfile) -> None
def delete_user_sessions(user_profile: UserProfile) -> None:
for session in Session.objects.all():
if get_session_user(session) == user_profile.id:
delete_session(session)
def delete_realm_user_sessions(realm):
# type: (Realm) -> None
def delete_realm_user_sessions(realm: Realm) -> None:
realm_user_ids = [user_profile.id for user_profile in
UserProfile.objects.filter(realm=realm)]
for session in Session.objects.filter(expire_date__gte=timezone_now()):
if get_session_user(session) in realm_user_ids:
delete_session(session)
def delete_all_user_sessions():
# type: () -> None
def delete_all_user_sessions() -> None:
for session in Session.objects.all():
delete_session(session)
def delete_all_deactivated_user_sessions():
# type: () -> None
def delete_all_deactivated_user_sessions() -> None:
for session in Session.objects.all():
user_profile_id = get_session_user(session)
if user_profile_id is None:

View File

@@ -17,13 +17,11 @@ class StreamRecipientMap:
Note that this class uses raw SQL, because we want to highly
optimize page loads.
'''
def __init__(self):
# type: () -> None
def __init__(self) -> None:
self.recip_to_stream = dict() # type: Dict[int, int]
self.stream_to_recip = dict() # type: Dict[int, int]
def populate_for_stream_ids(self, stream_ids):
# type: (List[int]) -> None
def populate_for_stream_ids(self, stream_ids: List[int]) -> None:
stream_ids = sorted([
stream_id for stream_id in stream_ids
if stream_id not in self.stream_to_recip
@@ -49,8 +47,7 @@ class StreamRecipientMap:
''' % (Recipient.STREAM, id_list)
self._process_query(query)
def populate_for_recipient_ids(self, recipient_ids):
# type: (List[int]) -> None
def populate_for_recipient_ids(self, recipient_ids: List[int]) -> None:
recipient_ids = sorted([
recip_id for recip_id in recipient_ids
if recip_id not in self.recip_to_stream
@@ -77,8 +74,7 @@ class StreamRecipientMap:
self._process_query(query)
def _process_query(self, query):
# type: (str) -> None
def _process_query(self, query: str) -> None:
cursor = connection.cursor()
cursor.execute(query)
rows = cursor.fetchall()
@@ -87,14 +83,11 @@ class StreamRecipientMap:
self.recip_to_stream[recip_id] = stream_id
self.stream_to_recip[stream_id] = recip_id
def recipient_id_for(self, stream_id):
# type: (int) -> int
def recipient_id_for(self, stream_id: int) -> int:
return self.stream_to_recip[stream_id]
def stream_id_for(self, recip_id):
# type: (int) -> int
def stream_id_for(self, recip_id: int) -> int:
return self.recip_to_stream[recip_id]
def recipient_to_stream_id_dict(self):
# type: () -> Dict[int, int]
def recipient_to_stream_id_dict(self) -> Dict[int, int]:
return self.recip_to_stream

View File

@@ -44,8 +44,9 @@ SubInfo = TypedDict('SubInfo', {
'stream': Stream,
})
def get_bulk_stream_subscriber_info(user_profiles, stream_dict):
# type: (List[UserProfile], Dict[int, Stream]) -> Dict[int, List[Tuple[Subscription, Stream]]]
def get_bulk_stream_subscriber_info(
user_profiles: List[UserProfile],
stream_dict: Dict[int, Stream]) -> Dict[int, List[Tuple[Subscription, Stream]]]:
stream_ids = stream_dict.keys()
@@ -69,8 +70,7 @@ def get_bulk_stream_subscriber_info(user_profiles, stream_dict):
return result
def num_subscribers_for_stream_id(stream_id):
# type: (int) -> int
def num_subscribers_for_stream_id(stream_id: int) -> int:
return get_active_subscriptions_for_stream_id(stream_id).filter(
user_profile__is_active=True,
).count()

View File

@@ -16,13 +16,11 @@ class StreamTopicTarget:
places where we are are still using `subject` or
`topic_name` as a key into tables.
'''
def __init__(self, stream_id, topic_name):
# type: (int, Text) -> None
def __init__(self, stream_id: int, topic_name: Text) -> None:
self.stream_id = stream_id
self.topic_name = topic_name
def user_ids_muting_topic(self):
# type: () -> Set[int]
def user_ids_muting_topic(self) -> Set[int]:
query = MutedTopic.objects.filter(
stream_id=self.stream_id,
topic_name__iexact=self.topic_name,
@@ -34,6 +32,5 @@ class StreamTopicTarget:
for row in query
}
def get_active_subscriptions(self):
# type: () -> QuerySet
def get_active_subscriptions(self) -> QuerySet:
return get_active_subscriptions_for_stream_id(self.stream_id)

View File

@@ -7,8 +7,7 @@ from typing import Optional, Text
from zerver.models import get_realm, Realm, UserProfile
def get_subdomain(request):
# type: (HttpRequest) -> Text
def get_subdomain(request: HttpRequest) -> Text:
# The HTTP spec allows, but doesn't require, a client to omit the
# port in the `Host` header if it's "the default port for the
@@ -39,18 +38,15 @@ def get_subdomain(request):
return Realm.SUBDOMAIN_FOR_ROOT_DOMAIN
def is_subdomain_root_or_alias(request):
# type: (HttpRequest) -> bool
def is_subdomain_root_or_alias(request: HttpRequest) -> bool:
return get_subdomain(request) == Realm.SUBDOMAIN_FOR_ROOT_DOMAIN
def user_matches_subdomain(realm_subdomain, user_profile):
# type: (Optional[Text], UserProfile) -> bool
def user_matches_subdomain(realm_subdomain: Optional[Text], user_profile: UserProfile) -> bool:
if realm_subdomain is None:
return True
return user_profile.realm.subdomain == realm_subdomain
def is_root_domain_available() -> bool:
    """Whether a new realm may still claim the root domain.

    Unavailable when the root domain is reserved for the landing page,
    or when some realm already occupies it.
    """
    if settings.ROOT_DOMAIN_LANDING_PAGE:
        return False
    return get_realm(Realm.SUBDOMAIN_FOR_ROOT_DOMAIN) is None

View File

@@ -20,8 +20,7 @@ from scripts.lib.zulip_tools import get_dev_uuid_var_path
UUID_VAR_DIR = get_dev_uuid_var_path()
FILENAME_SPLITTER = re.compile('[\W\-_]')
def database_exists(database_name, **options):
# type: (Text, **Any) -> bool
def database_exists(database_name: Text, **options: Any) -> bool:
db = options.get('database', DEFAULT_DB_ALIAS)
try:
connection = connections[db]
@@ -34,8 +33,7 @@ def database_exists(database_name, **options):
except OperationalError:
return False
def get_migration_status(**options):
# type: (**Any) -> str
def get_migration_status(**options: Any) -> str:
verbosity = options.get('verbosity', 1)
for app_config in apps.get_app_configs():
@@ -61,8 +59,7 @@ def get_migration_status(**options):
output = out.read()
return re.sub('\x1b\[(1|0)m', '', output)
def are_migrations_the_same(migration_file, **options):
# type: (Text, **Any) -> bool
def are_migrations_the_same(migration_file: Text, **options: Any) -> bool:
if not os.path.exists(migration_file):
return False
@@ -70,14 +67,12 @@ def are_migrations_the_same(migration_file, **options):
migration_content = f.read()
return migration_content == get_migration_status(**options)
def _get_hash_file_path(source_file_path: str, status_dir: str) -> str:
    """Map a source file path to its status-hash file inside status_dir.

    The basename is normalized by splitting on non-word characters
    (FILENAME_SPLITTER), joining with underscores, and lowercasing.
    """
    basename = os.path.basename(source_file_path)
    filename = '_'.join(FILENAME_SPLITTER.split(basename)).lower()
    return os.path.join(status_dir, filename)
def _check_hash(target_hash_file, status_dir):
# type: (str, str) -> bool
def _check_hash(target_hash_file: str, status_dir: str) -> bool:
"""
This function has a side effect of creating a new hash file or
updating the old hash file.

View File

@@ -79,8 +79,7 @@ class MockLDAP(fakeldap.MockLDAP):
pass
@contextmanager
def stub_event_queue_user_events(event_queue_return, user_events_return):
# type: (Any, Any) -> Iterator[None]
def stub_event_queue_user_events(event_queue_return: Any, user_events_return: Any) -> Iterator[None]:
with mock.patch('zerver.lib.events.request_event_queue',
return_value=event_queue_return):
with mock.patch('zerver.lib.events.get_user_events',
@@ -88,16 +87,14 @@ def stub_event_queue_user_events(event_queue_return, user_events_return):
yield
@contextmanager
def simulated_queue_client(client: Callable[..., Any]) -> Iterator[None]:
    """Temporarily replace SimpleQueueClient with the given test double.

    Restores the real class even if the caller's body raises, so one
    failing test cannot leak the monkey-patch into later tests.
    """
    real_SimpleQueueClient = queue_processors.SimpleQueueClient
    queue_processors.SimpleQueueClient = client  # type: ignore # https://github.com/JukkaL/mypy/issues/1152
    try:
        yield
    finally:
        queue_processors.SimpleQueueClient = real_SimpleQueueClient  # type: ignore # https://github.com/JukkaL/mypy/issues/1152
@contextmanager
def tornado_redirected_to_list(lst):
# type: (List[Mapping[str, Any]]) -> Iterator[None]
def tornado_redirected_to_list(lst: List[Mapping[str, Any]]) -> Iterator[None]:
real_event_queue_process_notification = event_queue.process_notification
event_queue.process_notification = lambda notice: lst.append(notice)
# process_notification takes a single parameter called 'notice'.
@@ -109,12 +106,11 @@ def tornado_redirected_to_list(lst):
event_queue.process_notification = real_event_queue_process_notification
@contextmanager
def simulated_empty_cache():
# type: () -> Generator[List[Tuple[str, Union[Text, List[Text]], Text]], None, None]
def simulated_empty_cache() -> Generator[
List[Tuple[str, Union[Text, List[Text]], Text]], None, None]:
cache_queries = [] # type: List[Tuple[str, Union[Text, List[Text]], Text]]
def my_cache_get(key, cache_name=None):
# type: (Text, Optional[str]) -> Optional[Dict[Text, Any]]
def my_cache_get(key: Text, cache_name: Optional[str]=None) -> Optional[Dict[Text, Any]]:
cache_queries.append(('get', key, cache_name))
return None
@@ -132,8 +128,8 @@ def simulated_empty_cache():
cache.cache_get_many = old_get_many
@contextmanager
def queries_captured(include_savepoints=False):
# type: (Optional[bool]) -> Generator[List[Dict[str, Union[str, bytes]]], None, None]
def queries_captured(include_savepoints: Optional[bool]=False) -> Generator[
List[Dict[str, Union[str, bytes]]], None, None]:
'''
Allow a user to capture just the queries executed during
the with statement.
@@ -141,8 +137,10 @@ def queries_captured(include_savepoints=False):
queries = [] # type: List[Dict[str, Union[str, bytes]]]
def wrapper_execute(self, action, sql, params=()):
# type: (TimeTrackingCursor, Callable[[NonBinaryStr, Iterable[Any]], None], NonBinaryStr, Iterable[Any]) -> None
def wrapper_execute(self: TimeTrackingCursor,
action: Callable[[NonBinaryStr, Iterable[Any]], None],
sql: NonBinaryStr,
params: Iterable[Any]=()) -> None:
cache = get_cache_backend(None)
cache.clear()
start = time.time()
@@ -160,13 +158,13 @@ def queries_captured(include_savepoints=False):
old_execute = TimeTrackingCursor.execute
old_executemany = TimeTrackingCursor.executemany
def cursor_execute(self, sql, params=()):
# type: (TimeTrackingCursor, NonBinaryStr, Iterable[Any]) -> None
def cursor_execute(self: TimeTrackingCursor, sql: NonBinaryStr,
params: Iterable[Any]=()) -> None:
return wrapper_execute(self, super(TimeTrackingCursor, self).execute, sql, params) # type: ignore # https://github.com/JukkaL/mypy/issues/1167
TimeTrackingCursor.execute = cursor_execute # type: ignore # https://github.com/JukkaL/mypy/issues/1167
def cursor_executemany(self, sql, params=()):
# type: (TimeTrackingCursor, NonBinaryStr, Iterable[Any]) -> None
def cursor_executemany(self: TimeTrackingCursor, sql: NonBinaryStr,
params: Iterable[Any]=()) -> None:
return wrapper_execute(self, super(TimeTrackingCursor, self).executemany, sql, params) # type: ignore # https://github.com/JukkaL/mypy/issues/1167 # nocoverage -- doesn't actually get used in tests
TimeTrackingCursor.executemany = cursor_executemany # type: ignore # https://github.com/JukkaL/mypy/issues/1167
@@ -176,8 +174,7 @@ def queries_captured(include_savepoints=False):
TimeTrackingCursor.executemany = old_executemany # type: ignore # https://github.com/JukkaL/mypy/issues/1167
@contextmanager
def stdout_suppressed():
# type: () -> Iterator[IO[str]]
def stdout_suppressed() -> Iterator[IO[str]]:
"""Redirect stdout to /dev/null."""
with open(os.devnull, 'a') as devnull:
@@ -185,26 +182,22 @@ def stdout_suppressed():
yield stdout
sys.stdout = stdout
def get_test_image_file(filename: str) -> IO[Any]:
    """Open (binary, read-only) a fixture image from zerver/tests/images.

    The caller is responsible for closing the returned file object.
    """
    test_avatar_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '../tests/images'))
    return open(os.path.join(test_avatar_dir, filename), 'rb')
def avatar_disk_path(user_profile: UserProfile, medium: bool=False) -> Text:
    """Compute the on-disk path of a user's avatar under LOCAL_UPLOADS_DIR.

    Reconstructed from the last two URL path components, with any query
    string stripped from the final component.
    """
    avatar_url_path = avatar_url(user_profile, medium)
    return os.path.join(settings.LOCAL_UPLOADS_DIR, "avatars",
                        avatar_url_path.split("/")[-2],
                        avatar_url_path.split("/")[-1].split("?")[0])
def make_client(name: str) -> Client:
    """Fetch the Client row with this name, creating it if absent."""
    client, _ = Client.objects.get_or_create(name=name)
    return client
def find_key_by_email(address):
# type: (Text) -> Optional[Text]
def find_key_by_email(address: Text) -> Optional[Text]:
from django.core.mail import outbox
key_regex = re.compile("accounts/do_confirm/([a-z0-9]{24})>")
for message in reversed(outbox):
@@ -212,8 +205,7 @@ def find_key_by_email(address):
return key_regex.search(message.body).groups()[0]
return None # nocoverage -- in theory a test might want this case, but none do
def find_pattern_in_email(address, pattern):
# type: (Text, Text) -> Optional[Text]
def find_pattern_in_email(address: Text, pattern: Text) -> Optional[Text]:
from django.core.mail import outbox
key_regex = re.compile(pattern)
for message in reversed(outbox):
@@ -221,35 +213,30 @@ def find_pattern_in_email(address, pattern):
return key_regex.search(message.body).group(0)
return None # nocoverage -- in theory a test might want this case, but none do
def message_stream_count(user_profile: UserProfile) -> int:
    """Total number of UserMessage rows (received messages) for this user."""
    # NOTE(review): select_related("message") looks unnecessary for a pure
    # count; kept to preserve the exact query shape — confirm before removing.
    return UserMessage.objects. \
        select_related("message"). \
        filter(user_profile=user_profile). \
        count()
def most_recent_usermessage(user_profile: UserProfile) -> UserMessage:
    """Return the user's UserMessage with the highest message id.

    Raises IndexError if the user has no messages at all.
    """
    query = UserMessage.objects. \
        select_related("message"). \
        filter(user_profile=user_profile). \
        order_by('-message')
    return query[0]  # Django applies a LIMIT here
def most_recent_message(user_profile: UserProfile) -> Message:
    """Return the newest Message the user has received."""
    usermessage = most_recent_usermessage(user_profile)
    return usermessage.message
def get_subscription(stream_name: Text, user_profile: UserProfile) -> Subscription:
    """Look up the user's active Subscription to the named stream.

    Raises Subscription.DoesNotExist if the user is not actively subscribed.
    """
    stream = get_stream(stream_name, user_profile.realm)
    recipient = get_stream_recipient(stream.id)
    return Subscription.objects.get(user_profile=user_profile,
                                    recipient=recipient, active=True)
def get_user_messages(user_profile):
# type: (UserProfile) -> List[Message]
def get_user_messages(user_profile: UserProfile) -> List[Message]:
query = UserMessage.objects. \
select_related("message"). \
filter(user_profile=user_profile). \
@@ -257,15 +244,13 @@ def get_user_messages(user_profile):
return [um.message for um in query]
class DummyHandler:
    """Minimal stand-in for a Tornado handler object in tests."""

    def __init__(self) -> None:
        # Register this mock so code needing a handler id can find one.
        allocate_handler_id(self)  # type: ignore # this is a testing mock
class POSTRequestMock:
method = "POST"
def __init__(self, post_data, user_profile):
# type: (Dict[str, Any], Optional[UserProfile]) -> None
def __init__(self, post_data: Dict[str, Any], user_profile: Optional[UserProfile]) -> None:
self.GET = {} # type: Dict[str, Any]
self.POST = post_data
self.user = user_profile
@@ -278,8 +263,7 @@ class HostRequestMock:
"""A mock request object where get_host() works. Useful for testing
routes that use Zulip's subdomains feature"""
def __init__(self, user_profile=None, host=settings.EXTERNAL_HOST):
# type: (UserProfile, Text) -> None
def __init__(self, user_profile: UserProfile=None, host: Text=settings.EXTERNAL_HOST) -> None:
self.host = host
self.GET = {} # type: Dict[str, Any]
self.POST = {} # type: Dict[str, Any]
@@ -291,19 +275,16 @@ class HostRequestMock:
self.content_type = ''
self._email = ''
def get_host(self) -> Text:
    # Mirror django's HttpRequest.get_host() using the configured host.
    return self.host
class MockPythonResponse:
    """Tiny stand-in for a requests.Response-like object.

    Exposes only the ``text``, ``status_code``, and ``ok`` members that
    the code under test inspects.
    """

    def __init__(self, text: Text, status_code: int) -> None:
        self.text = text
        self.status_code = status_code

    @property
    def ok(self) -> bool:
        # Deliberately stricter than requests: only HTTP 200 counts as ok.
        return self.status_code == 200
INSTRUMENTING = os.environ.get('TEST_INSTRUMENT_URL_COVERAGE', '') == 'TRUE'
@@ -311,17 +292,15 @@ INSTRUMENTED_CALLS = [] # type: List[Dict[str, Any]]
UrlFuncT = Callable[..., HttpResponse] # TODO: make more specific
def append_instrumentation_data(data: Dict[str, Any]) -> None:
    """Record one instrumented URL call in the module-global list."""
    INSTRUMENTED_CALLS.append(data)
def instrument_url(f):
# type: (UrlFuncT) -> UrlFuncT
def instrument_url(f: UrlFuncT) -> UrlFuncT:
if not INSTRUMENTING: # nocoverage -- option is always enabled; should we remove?
return f
else:
def wrapper(self, url, info={}, **kwargs):
# type: (Any, Text, Dict[str, Any], **Any) -> HttpResponse
def wrapper(self: Any, url: Text, info: Dict[str, Any]={},
**kwargs: Any) -> HttpResponse:
start = time.time()
result = f(self, url, info, **kwargs)
delay = time.time() - start
@@ -343,8 +322,7 @@ def instrument_url(f):
return result
return wrapper
def write_instrumentation_reports(full_suite):
# type: (bool) -> None
def write_instrumentation_reports(full_suite: bool) -> None:
if INSTRUMENTING:
calls = INSTRUMENTED_CALLS
@@ -353,17 +331,14 @@ def write_instrumentation_reports(full_suite):
# Find our untested urls.
pattern_cnt = collections.defaultdict(int) # type: Dict[str, int]
def re_strip(r: Any) -> str:
    """Render a URL regex object as a string without its ^/$ anchors."""
    return str(r).lstrip('^').rstrip('$')
def find_patterns(patterns: List[Any], prefixes: List[str]) -> None:
    """Record coverage info for each URL pattern via find_pattern."""
    for pattern in patterns:
        find_pattern(pattern, prefixes)
def cleanup_url(url):
# type: (str) -> str
def cleanup_url(url: str) -> str:
if url.startswith('/'):
url = url[1:]
if url.startswith('http://testserver/'):
@@ -374,8 +349,7 @@ def write_instrumentation_reports(full_suite):
url = url[len('http://testserver:9080/'):]
return url
def find_pattern(pattern, prefixes):
# type: (Any, List[str]) -> None
def find_pattern(pattern: Any, prefixes: List[str]) -> None:
if isinstance(pattern, type(LocaleRegexURLResolver)):
return # nocoverage -- shouldn't actually happen
@@ -447,16 +421,14 @@ def write_instrumentation_reports(full_suite):
print(" %s" % (untested_pattern,))
sys.exit(1)
def get_all_templates():
# type: () -> List[str]
def get_all_templates() -> List[str]:
templates = []
relpath = os.path.relpath
isfile = os.path.isfile
path_exists = os.path.exists
def is_valid_template(p, n):
# type: (Text, Text) -> bool
def is_valid_template(p: Text, n: Text) -> bool:
return 'webhooks' not in p \
and not n.startswith('.') \
and not n.startswith('__init__') \
@@ -464,8 +436,7 @@ def get_all_templates():
and not n.endswith('.source.html') \
and isfile(p)
def process(template_dir, dirname, fnames):
# type: (str, str, Iterable[str]) -> None
def process(template_dir: str, dirname: str, fnames: Iterable[str]) -> None:
for name in fnames:
path = os.path.join(dirname, name)
if is_valid_template(path, name):
@@ -480,20 +451,17 @@ def get_all_templates():
return templates
def load_subdomain_token(response):
# type: (HttpResponse) -> Dict[str, Any]
def load_subdomain_token(response: HttpResponse) -> Dict[str, Any]:
assert isinstance(response, HttpResponseRedirect)
token = response.url.rsplit('/', 1)[1]
return signing.loads(token, salt='zerver.views.auth.log_into_subdomain')
FuncT = TypeVar('FuncT', bound=Callable[..., None])
def use_s3_backend(method):
# type: (FuncT) -> FuncT
def use_s3_backend(method: FuncT) -> FuncT:
@mock_s3_deprecated
@override_settings(LOCAL_UPLOADS_DIR=None)
def new_method(*args, **kwargs):
# type: (*Any, **Any) -> Any
def new_method(*args: Any, **kwargs: Any) -> Any:
zerver.lib.upload.upload_backend = S3UploadBackend()
try:
return method(*args, **kwargs)

View File

@@ -32,50 +32,42 @@ import time
import traceback
import unittest
if False:
# Only needed by mypy.
from multiprocessing.sharedctypes import Synchronized
from multiprocessing.sharedctypes import Synchronized
_worker_id = 0 # Used to identify the worker process.
ReturnT = TypeVar('ReturnT') # Constrain return type to match
def slow(slowness_reason: str) -> Callable[[Callable[..., ReturnT]], Callable[..., ReturnT]]:
    '''
    Mark a test as "known to be slow."

    The decorator records slowness_reason as an attribute on the test
    function; other code can use this annotation as needed, e.g. to
    exclude these tests in "fast" mode.
    '''
    # Annotate the original return annotation was ReturnT, but the
    # decorator returns the (unwrapped) callable itself.
    def decorator(f: Callable[..., ReturnT]) -> Callable[..., ReturnT]:
        f.slowness_reason = slowness_reason  # type: ignore # monkey-patching attribute
        return f

    return decorator
def is_known_slow_test(test_method: Any) -> bool:
    """True when the test method was annotated with the @slow decorator."""
    return hasattr(test_method, 'slowness_reason')
def full_test_name(test: TestCase) -> str:
    """Dotted 'module.Class.method' name identifying a test case."""
    test_module = test.__module__
    test_class = test.__class__.__name__
    test_method = test._testMethodName
    return '%s.%s.%s' % (test_module, test_class, test_method)
def get_test_method(test: TestCase) -> Callable[[], None]:
    """Return the bound method that this TestCase instance will run."""
    return getattr(test, test._testMethodName)
# Each tuple is delay, test_name, slowness_reason
TEST_TIMINGS = [] # type: List[Tuple[float, str, str]]
def report_slow_tests():
# type: () -> None
def report_slow_tests() -> None:
timings = sorted(TEST_TIMINGS, reverse=True)
print('SLOWNESS REPORT')
print(' delay test')
@@ -92,8 +84,8 @@ def report_slow_tests():
print(' consider removing @slow decorator')
print(' This may no longer be true: %s' % (slowness_reason,))
def enforce_timely_test_completion(test_method, test_name, delay, result):
# type: (Any, str, float, TestResult) -> None
def enforce_timely_test_completion(test_method: Any, test_name: str,
delay: float, result: TestResult) -> None:
if hasattr(test_method, 'slowness_reason'):
max_delay = 2.0 # seconds
else:
@@ -103,12 +95,10 @@ def enforce_timely_test_completion(test_method, test_name, delay, result):
msg = '** Test is TOO slow: %s (%.3f s)\n' % (test_name, delay)
result.addInfo(test_method, msg)
def fast_tests_only() -> bool:
    """Whether the developer asked (via env var) to skip slow tests."""
    return "FAST_TESTS_ONLY" in os.environ
def run_test(test, result):
# type: (TestCase, TestResult) -> bool
def run_test(test: TestCase, result: TestResult) -> bool:
failed = False
test_method = get_test_method(test)
@@ -175,44 +165,36 @@ class TextTestResult(runner.TextTestResult):
This class has unpythonic function names because base class follows
this style.
"""
def __init__(self, *args: Any, **kwargs: Any) -> None:
    super().__init__(*args, **kwargs)
    # Dotted names of tests that errored or failed, for the final report.
    self.failed_tests = []  # type: List[str]
def addInfo(self, test: TestCase, msg: Text) -> None:
    # Write and flush immediately so progress output interleaves correctly.
    self.stream.write(msg)
    self.stream.flush()
def addInstrumentation(self, test: TestCase, data: Dict[str, Any]) -> None:
    # Forward URL-coverage data to the module-global accumulator.
    append_instrumentation_data(data)
def startTest(self, test: TestCase) -> None:
    TestResult.startTest(self, test)
    self.stream.writeln("Running {}".format(full_test_name(test)))
    self.stream.flush()
def addSuccess(self, *args: Any, **kwargs: Any) -> None:
    # Bypass the base TextTestResult output; just record the result.
    TestResult.addSuccess(self, *args, **kwargs)
def addError(self, *args: Any, **kwargs: Any) -> None:
    TestResult.addError(self, *args, **kwargs)
    # args[0] is the failing TestCase; remember its name for reporting.
    self.failed_tests.append(full_test_name(args[0]))
def addFailure(self, *args: Any, **kwargs: Any) -> None:
    TestResult.addFailure(self, *args, **kwargs)
    # args[0] is the failing TestCase; remember its name for reporting.
    self.failed_tests.append(full_test_name(args[0]))
def addSkip(self, test: TestCase, reason: Text) -> None:
    TestResult.addSkip(self, test, reason)
    self.stream.writeln("** Skipping {}: {}".format(full_test_name(test),
                                                    reason))
@@ -223,28 +205,24 @@ class RemoteTestResult(django_runner.RemoteTestResult):
The class follows the unpythonic style of function names of the
base class.
"""
def addInfo(self, test: TestCase, msg: Text) -> None:
    # Queue the event for replay in the parent (coordinating) process.
    self.events.append(('addInfo', self.test_index, msg))
def addInstrumentation(self, test: TestCase, data: Dict[str, Any]) -> None:
    # Some elements of data['info'] cannot be serialized for the parent
    # process, so drop that key before queueing the event.
    data.pop('info', None)
    self.events.append(('addInstrumentation', self.test_index, data))
def process_instrumented_calls(func: Callable[[Dict[str, Any]], None]) -> None:
    """Apply func to every recorded instrumented URL call."""
    for call in test_helpers.INSTRUMENTED_CALLS:
        func(call)
SerializedSubsuite = Tuple[Type['TestSuite'], List[str]]
SubsuiteArgs = Tuple[Type['RemoteTestRunner'], int, SerializedSubsuite, bool]
def run_subsuite(args):
# type: (SubsuiteArgs) -> Tuple[int, Any]
def run_subsuite(args: SubsuiteArgs) -> Tuple[int, Any]:
# Reset the accumulated INSTRUMENTED_CALLS before running this subsuite.
test_helpers.INSTRUMENTED_CALLS = []
# The first argument is the test runner class but we don't need it
@@ -262,8 +240,9 @@ def run_subsuite(args):
# Monkey-patch database creation to fix unnecessary sleep(1)
from django.db.backends.postgresql.creation import DatabaseCreation
def _replacement_destroy_test_db(self, test_database_name, verbosity):
# type: (Any, str, Any) -> None
def _replacement_destroy_test_db(self: Any,
test_database_name: str,
verbosity: Any) -> None:
"""Replacement for Django's _destroy_test_db that removes the
unnecessary sleep(1)."""
with self.connection._nodb_connection.cursor() as cursor:
@@ -271,8 +250,7 @@ def _replacement_destroy_test_db(self, test_database_name, verbosity):
% self.connection.ops.quote_name(test_database_name))
DatabaseCreation._destroy_test_db = _replacement_destroy_test_db
def destroy_test_databases(database_id=None):
# type: (Optional[int]) -> None
def destroy_test_databases(database_id: Optional[int]=None) -> None:
"""
When database_id is None, the name of the databases is picked up
by the database settings.
@@ -285,8 +263,7 @@ def destroy_test_databases(database_id=None):
# DB doesn't exist. No need to do anything.
pass
def create_test_databases(database_id):
# type: (int) -> None
def create_test_databases(database_id: int) -> None:
for alias in connections:
connection = connections[alias]
connection.creation.clone_test_db(
@@ -302,8 +279,7 @@ def create_test_databases(database_id):
connection.settings_dict.update(settings_dict)
connection.close()
def init_worker(counter):
# type: (Synchronized) -> None
def init_worker(counter: Synchronized) -> None:
"""
This function runs only under parallel mode. It initializes the
individual processes which are also called workers.
@@ -336,8 +312,7 @@ def init_worker(counter):
settings.LOCAL_UPLOADS_DIR = '{}_{}'.format(settings.LOCAL_UPLOADS_DIR,
_worker_id)
def is_upload_avatar_url(url: RegexURLPattern) -> bool:
    """True for the URL pattern that serves locally uploaded avatars."""
    # Direct comparison instead of `if ...: return True / return False`.
    return url.regex.pattern == r'^user_avatars/(?P<path>.*)$'
@@ -355,8 +330,7 @@ def init_worker(counter):
print("*** Upload directory not found.")
class TestSuite(unittest.TestSuite):
def run(self, result, debug=False):
# type: (TestResult, Optional[bool]) -> TestResult
def run(self, result: TestResult, debug: Optional[bool]=False) -> TestResult:
"""
This function mostly contains the code from
unittest.TestSuite.run. The need to override this function
@@ -400,8 +374,7 @@ class ParallelTestSuite(django_runner.ParallelTestSuite):
run_subsuite = run_subsuite
init_worker = init_worker
def __init__(self, suite, processes, failfast):
# type: (TestSuite, int, bool) -> None
def __init__(self, suite: TestSuite, processes: int, failfast: bool) -> None:
super().__init__(suite, processes, failfast)
# We can't specify a consistent type for self.subsuites, since
# the whole idea here is to monkey-patch that so we can use
@@ -414,8 +387,7 @@ class Runner(DiscoverRunner):
test_loader = TestLoader()
parallel_test_suite = ParallelTestSuite
def __init__(self, *args, **kwargs):
# type: (*Any, **Any) -> None
def __init__(self, *args: Any, **kwargs: Any) -> None:
DiscoverRunner.__init__(self, *args, **kwargs)
# `templates_rendered` holds templates which were rendered
@@ -427,12 +399,10 @@ class Runner(DiscoverRunner):
template_rendered.connect(self.on_template_rendered)
self.database_id = random.randint(1, 10000)
def get_resultclass(self) -> Type[TestResult]:
    # Use our TextTestResult so failed tests are tracked and reported.
    return TextTestResult
def on_template_rendered(self, sender, context, **kwargs):
# type: (Any, Dict[str, Any], **Any) -> None
def on_template_rendered(self, sender: Any, context: Dict[str, Any], **kwargs: Any) -> None:
if hasattr(sender, 'template'):
template_name = sender.template.name
if template_name not in self.templates_rendered:
@@ -442,19 +412,16 @@ class Runner(DiscoverRunner):
self.templates_rendered.add(template_name)
self.shallow_tested_templates.discard(template_name)
def get_shallow_tested_templates(self) -> Set[str]:
    """Return the set of templates currently marked as only shallowly tested."""
    return self.shallow_tested_templates
def setup_test_environment(self, *args: Any, **kwargs: Any) -> Any:
    # Point the default database at the backend template database; the
    # actual test databases are created/destroyed in run_tests to avoid
    # duplicate work when running in parallel mode.
    settings.DATABASES['default']['NAME'] = settings.BACKEND_DATABASE_TEMPLATE
    return super().setup_test_environment(*args, **kwargs)
def teardown_test_environment(self, *args, **kwargs):
# type: (*Any, **Any) -> Any
def teardown_test_environment(self, *args: Any, **kwargs: Any) -> Any:
# No need to pass the database id now. It will be picked up
# automatically through settings.
if self.parallel == 1:
@@ -502,12 +469,10 @@ class Runner(DiscoverRunner):
write_instrumentation_reports(full_suite=full_suite)
return failed, result.failed_tests
def get_test_names(suite: TestSuite) -> List[str]:
    """Dotted names of every test contained (recursively) in suite."""
    return [full_test_name(t) for t in get_tests_from_suite(suite)]
def get_tests_from_suite(suite):
# type: (TestSuite) -> TestCase
def get_tests_from_suite(suite: TestSuite) -> TestCase:
for test in suite:
if isinstance(test, TestSuite):
for child in get_tests_from_suite(test):
@@ -515,12 +480,10 @@ def get_tests_from_suite(suite):
else:
yield test
def serialize_suite(suite: TestSuite) -> Tuple[Type[TestSuite], List[str]]:
    """Reduce a suite to a picklable (suite class, test names) pair."""
    return type(suite), get_test_names(suite)
def deserialize_suite(args):
# type: (Tuple[Type[TestSuite], List[str]]) -> TestSuite
def deserialize_suite(args: Tuple[Type[TestSuite], List[str]]) -> TestSuite:
suite_class, test_names = args
suite = suite_class()
tests = TestLoader().loadTestsFromNames(test_names)
@@ -536,12 +499,10 @@ class SubSuiteList(List[Tuple[Type[TestSuite], List[str]]]):
This class allows us to avoid changing the main logic of
ParallelTestSuite and still make it serializable.
"""
def __init__(self, suites: List[TestSuite]) -> None:
    # Store each suite in its serialized (picklable) form.
    serialized_suites = [serialize_suite(s) for s in suites]
    super().__init__(serialized_suites)
def __getitem__(self, index: Any) -> Any:
    # Transparently rebuild the real TestSuite on access.
    serialized = super().__getitem__(index)
    return deserialize_suite(serialized)

View File

@@ -5,8 +5,7 @@ import subprocess
from django.conf import settings
from typing import Optional, Text
def render_tex(tex, is_inline=True):
# type: (Text, bool) -> Optional[Text]
def render_tex(tex: Text, is_inline: bool=True) -> Optional[Text]:
"""Render a TeX string into HTML using KaTeX
Returns the HTML string, or None if there was some error in the TeX syntax

View File

@@ -12,14 +12,12 @@ import threading
class TimeoutExpired(Exception):
    '''Exception raised when a function times out.'''

    def __str__(self) -> str:
        return 'Function call timed out.'
ResultT = TypeVar('ResultT')
def timeout(timeout, func, *args, **kwargs):
# type: (float, Callable[..., ResultT], *Any, **Any) -> ResultT
def timeout(timeout: float, func: Callable[..., ResultT], *args: Any, **kwargs: Any) -> ResultT:
'''Call the function in a separate thread.
Return its return value, or raise an exception,
within approximately 'timeout' seconds.
@@ -36,8 +34,7 @@ def timeout(timeout, func, *args, **kwargs):
operation.'''
class TimeoutThread(threading.Thread):
def __init__(self):
# type: () -> None
def __init__(self) -> None:
threading.Thread.__init__(self)
self.result = None # type: Optional[ResultT]
self.exc_info = None # type: Optional[Tuple[Optional[Type[BaseException]], Optional[BaseException], Optional[TracebackType]]]
@@ -46,15 +43,13 @@ def timeout(timeout, func, *args, **kwargs):
# if this is the only thread left.
self.daemon = True
def run(self):
# type: () -> None
def run(self) -> None:
try:
self.result = func(*args, **kwargs)
except BaseException:
self.exc_info = sys.exc_info()
def raise_async_timeout(self):
# type: () -> None
def raise_async_timeout(self) -> None:
# Called from another thread.
# Attempt to raise a TimeoutExpired in the thread represented by 'self'.
tid = ctypes.c_long(self.ident)

View File

@@ -5,48 +5,40 @@ from django.utils.timezone import utc as timezone_utc
class TimezoneNotUTCException(Exception):
    """Raised when a datetime that must be UTC is naive or non-UTC."""
    pass
def verify_UTC(dt: datetime.datetime) -> None:
    """Raise TimezoneNotUTCException unless dt is timezone-aware UTC."""
    if dt.tzinfo is None or dt.tzinfo.utcoffset(dt) != timezone_utc.utcoffset(dt):
        raise TimezoneNotUTCException("Datetime %s does not have a UTC timezone." % (dt,))
def convert_to_UTC(dt: datetime.datetime) -> datetime.datetime:
    """Return dt expressed in UTC; naive datetimes are assumed to be UTC."""
    if dt.tzinfo is None:
        return dt.replace(tzinfo=timezone_utc)
    return dt.astimezone(timezone_utc)
def floor_to_hour(dt: datetime.datetime) -> datetime.datetime:
    """Truncate a UTC datetime down to the start of its hour."""
    verify_UTC(dt)
    # timetuple()[:4] is (year, month, day, hour).
    return datetime.datetime(*dt.timetuple()[:4]) \
        .replace(tzinfo=timezone_utc)
def floor_to_day(dt: datetime.datetime) -> datetime.datetime:
    """Truncate a UTC datetime down to midnight of its day."""
    verify_UTC(dt)
    # timetuple()[:3] is (year, month, day).
    return datetime.datetime(*dt.timetuple()[:3]) \
        .replace(tzinfo=timezone_utc)
def ceiling_to_hour(dt: datetime.datetime) -> datetime.datetime:
    """Round a UTC datetime up to the next hour boundary (identity if exact)."""
    floor = floor_to_hour(dt)
    if floor == dt:
        return floor
    return floor + datetime.timedelta(hours=1)
def ceiling_to_day(dt: datetime.datetime) -> datetime.datetime:
    """Round a UTC datetime up to the next midnight (identity if exact)."""
    floor = floor_to_day(dt)
    if floor == dt:
        return floor
    return floor + datetime.timedelta(days=1)
def timestamp_to_datetime(timestamp: float) -> datetime.datetime:
    """Convert a UNIX timestamp to a timezone-aware UTC datetime."""
    return datetime.datetime.fromtimestamp(float(timestamp), tz=timezone_utc)
def datetime_to_timestamp(dt: datetime.datetime) -> int:
    """Convert an aware UTC datetime to a UNIX timestamp (whole seconds)."""
    verify_UTC(dt)
    return calendar.timegm(dt.timetuple())

View File

@@ -3,6 +3,5 @@ from typing import Text, List
import pytz
def get_all_timezones() -> List[Text]:
    """All timezone names known to pytz, sorted alphabetically."""
    return sorted(pytz.all_timezones)

View File

@@ -4,8 +4,7 @@ import functools
from typing import Any, Callable, IO, Mapping, Sequence, TypeVar, Text
def get_mapping_type_str(x):
# type: (Mapping[Any, Any]) -> str
def get_mapping_type_str(x: Mapping[Any, Any]) -> str:
container_type = type(x).__name__
if not x:
if container_type == 'dict':
@@ -26,8 +25,7 @@ def get_mapping_type_str(x):
else:
return '%s([(%s, %s), ...])' % (container_type, key_type, value_type)
def get_sequence_type_str(x):
# type: (Sequence[Any]) -> str
def get_sequence_type_str(x: Sequence[Any]) -> str:
container_type = type(x).__name__
if not x:
if container_type == 'list':
@@ -48,8 +46,7 @@ def get_sequence_type_str(x):
expansion_blacklist = [Text, bytes]
def get_type_str(x):
# type: (Any) -> str
def get_type_str(x: Any) -> str:
if x is None:
return 'None'
elif isinstance(x, tuple):
@@ -69,13 +66,10 @@ def get_type_str(x):
FuncT = TypeVar('FuncT', bound=Callable[..., Any])
def print_types_to(file_obj):
# type: (IO[str]) -> Callable[[FuncT], FuncT]
def decorator(func):
# type: (FuncT) -> FuncT
def print_types_to(file_obj: IO[str]) -> Callable[[FuncT], FuncT]:
def decorator(func: FuncT) -> FuncT:
@functools.wraps(func)
def wrapper(*args, **kwargs):
# type: (*Any, **Any) -> Any
def wrapper(*args: Any, **kwargs: Any) -> Any:
arg_types = [get_type_str(arg) for arg in args]
kwarg_types = [key + "=" + get_type_str(value) for key, value in kwargs.items()]
ret_val = func(*args, **kwargs)
@@ -87,6 +81,5 @@ def print_types_to(file_obj):
return wrapper # type: ignore # https://github.com/python/mypy/issues/1927
return decorator
def print_types(func):
# type: (FuncT) -> FuncT
def print_types(func: FuncT) -> FuncT:
return print_types_to(sys.stdout)(func)

View File

@@ -9,13 +9,11 @@ from typing import Dict, List, Text
class SourceMap:
'''Map (line, column) pairs from generated to source file.'''
def __init__(self, sourcemap_dirs):
# type: (List[Text]) -> None
def __init__(self, sourcemap_dirs: List[Text]) -> None:
self._dirs = sourcemap_dirs
self._indices = {} # type: Dict[Text, sourcemap.SourceMapDecoder]
def _index_for(self, minified_src):
# type: (Text) -> sourcemap.SourceMapDecoder
def _index_for(self, minified_src: Text) -> sourcemap.SourceMapDecoder:
'''Return the source map index for minified_src, loading it if not
already loaded.'''
if minified_src not in self._indices:
@@ -28,8 +26,7 @@ class SourceMap:
return self._indices[minified_src]
def annotate_stacktrace(self, stacktrace):
# type: (Text) -> Text
def annotate_stacktrace(self, stacktrace: Text) -> Text:
out = '' # type: Text
for ln in stacktrace.splitlines():
out += ln + '\n'

View File

@@ -53,14 +53,12 @@ DEFAULT_EMOJI_SIZE = 64
attachment_url_re = re.compile('[/\-]user[\-_]uploads[/\.-].*?(?=[ )]|\Z)')
def attachment_url_to_path_id(attachment_url):
# type: (Text) -> Text
def attachment_url_to_path_id(attachment_url: Text) -> Text:
path_id_raw = re.sub('[/\-]user[\-_]uploads[/\.-]', '', attachment_url)
# Remove any extra '.' after file extension. These are probably added by the user
return re.sub('[.]+$', '', path_id_raw, re.M)
def sanitize_name(value):
# type: (NonBinaryStr) -> Text
def sanitize_name(value: NonBinaryStr) -> Text:
"""
Sanitizes a value to be safe to store in a Linux filesystem, in
S3, and in a URL. So unicode is allowed, but not special
@@ -75,8 +73,7 @@ def sanitize_name(value):
value = re.sub('[^\w\s._-]', '', value, flags=re.U).strip()
return mark_safe(re.sub('[-\s]+', '-', value, flags=re.U))
def random_name(bytes=60):
# type: (int) -> Text
def random_name(bytes: int=60) -> Text:
return base64.urlsafe_b64encode(os.urandom(bytes)).decode('utf-8')
class BadImageError(JsonableError):
@@ -85,8 +82,7 @@ class BadImageError(JsonableError):
class ExceededQuotaError(JsonableError):
code = ErrorCode.QUOTA_EXCEEDED
def resize_avatar(image_data, size=DEFAULT_AVATAR_SIZE):
# type: (bytes, int) -> bytes
def resize_avatar(image_data: bytes, size: int=DEFAULT_AVATAR_SIZE) -> bytes:
try:
im = Image.open(io.BytesIO(image_data))
im = ImageOps.fit(im, (size, size), Image.ANTIALIAS)
@@ -97,8 +93,7 @@ def resize_avatar(image_data, size=DEFAULT_AVATAR_SIZE):
return out.getvalue()
def resize_emoji(image_data, size=DEFAULT_EMOJI_SIZE):
# type: (bytes, int) -> bytes
def resize_emoji(image_data: bytes, size: int=DEFAULT_EMOJI_SIZE) -> bytes:
try:
im = Image.open(io.BytesIO(image_data))
image_format = im.format
@@ -127,43 +122,36 @@ class ZulipUploadBackend:
# type: (Text, int, Optional[Text], bytes, UserProfile, Optional[Realm]) -> Text
raise NotImplementedError()
def upload_avatar_image(self, user_file, acting_user_profile, target_user_profile):
# type: (File, UserProfile, UserProfile) -> None
def upload_avatar_image(self, user_file: File,
acting_user_profile: UserProfile,
target_user_profile: UserProfile) -> None:
raise NotImplementedError()
def delete_message_image(self, path_id):
# type: (Text) -> bool
def delete_message_image(self, path_id: Text) -> bool:
raise NotImplementedError()
def get_avatar_url(self, hash_key, medium=False):
# type: (Text, bool) -> Text
def get_avatar_url(self, hash_key: Text, medium: bool=False) -> Text:
raise NotImplementedError()
def ensure_medium_avatar_image(self, user_profile):
# type: (UserProfile) -> None
def ensure_medium_avatar_image(self, user_profile: UserProfile) -> None:
raise NotImplementedError()
def upload_realm_icon_image(self, icon_file, user_profile):
# type: (File, UserProfile) -> None
def upload_realm_icon_image(self, icon_file: File, user_profile: UserProfile) -> None:
raise NotImplementedError()
def get_realm_icon_url(self, realm_id, version):
# type: (int, int) -> Text
def get_realm_icon_url(self, realm_id: int, version: int) -> Text:
raise NotImplementedError()
def upload_emoji_image(self, emoji_file, emoji_file_name, user_profile):
# type: (File, Text, UserProfile) -> None
def upload_emoji_image(self, emoji_file: File, emoji_file_name: Text, user_profile: UserProfile) -> None:
raise NotImplementedError()
def get_emoji_url(self, emoji_file_name, realm_id):
# type: (Text, int) -> Text
def get_emoji_url(self, emoji_file_name: Text, realm_id: int) -> Text:
raise NotImplementedError()
### S3
def get_bucket(conn, bucket_name):
# type: (S3Connection, Text) -> Bucket
def get_bucket(conn: S3Connection, bucket_name: Text) -> Bucket:
# Calling get_bucket() with validate=True can apparently lead
# to expensive S3 bills:
# http://www.appneta.com/blog/s3-list-get-bucket-default/
@@ -197,8 +185,7 @@ def upload_image_to_s3(
key.set_contents_from_string(contents, headers=headers) # type: ignore # https://github.com/python/typeshed/issues/1552
def get_total_uploads_size_for_user(user):
# type: (UserProfile) -> int
def get_total_uploads_size_for_user(user: UserProfile) -> int:
uploads = Attachment.objects.filter(owner=user)
total_quota = uploads.aggregate(Sum('size'))['size__sum']
@@ -207,16 +194,14 @@ def get_total_uploads_size_for_user(user):
total_quota = 0
return total_quota
def within_upload_quota(user, uploaded_file_size):
# type: (UserProfile, int) -> bool
def within_upload_quota(user: UserProfile, uploaded_file_size: int) -> bool:
total_quota = get_total_uploads_size_for_user(user)
if (total_quota + uploaded_file_size > user.quota):
return False
else:
return True
def get_file_info(request, user_file):
# type: (HttpRequest, File) -> Tuple[Text, int, Optional[Text]]
def get_file_info(request: HttpRequest, user_file: File) -> Tuple[Text, int, Optional[Text]]:
uploaded_file_name = user_file.name
assert isinstance(uploaded_file_name, str)
@@ -237,13 +222,11 @@ def get_file_info(request, user_file):
return uploaded_file_name, uploaded_file_size, content_type
def get_signed_upload_url(path):
# type: (Text) -> Text
def get_signed_upload_url(path: Text) -> Text:
conn = S3Connection(settings.S3_KEY, settings.S3_SECRET_KEY)
return conn.generate_url(15, 'GET', bucket=settings.S3_AUTH_UPLOADS_BUCKET, key=path)
def get_realm_for_filename(path):
# type: (Text) -> Optional[int]
def get_realm_for_filename(path: Text) -> Optional[int]:
conn = S3Connection(settings.S3_KEY, settings.S3_SECRET_KEY)
key = get_bucket(conn, settings.S3_AUTH_UPLOADS_BUCKET).get_key(path)
if key is None:
@@ -278,8 +261,7 @@ class S3UploadBackend(ZulipUploadBackend):
create_attachment(uploaded_file_name, s3_file_name, user_profile, uploaded_file_size)
return url
def delete_message_image(self, path_id):
# type: (Text) -> bool
def delete_message_image(self, path_id: Text) -> bool:
conn = S3Connection(settings.S3_KEY, settings.S3_SECRET_KEY)
bucket = get_bucket(conn, settings.S3_AUTH_UPLOADS_BUCKET)
@@ -293,8 +275,9 @@ class S3UploadBackend(ZulipUploadBackend):
logging.warning("%s does not exist. Its entry in the database will be removed." % (file_name,))
return False
def upload_avatar_image(self, user_file, acting_user_profile, target_user_profile):
# type: (File, UserProfile, UserProfile) -> None
def upload_avatar_image(self, user_file: File,
acting_user_profile: UserProfile,
target_user_profile: UserProfile) -> None:
content_type = guess_type(user_file.name)[0]
bucket_name = settings.S3_AVATAR_BUCKET
s3_file_name = user_avatar_path(target_user_profile)
@@ -329,15 +312,13 @@ class S3UploadBackend(ZulipUploadBackend):
# See avatar_url in avatar.py for URL. (That code also handles the case
# that users use gravatar.)
def get_avatar_url(self, hash_key, medium=False):
# type: (Text, bool) -> Text
def get_avatar_url(self, hash_key: Text, medium: bool=False) -> Text:
bucket = settings.S3_AVATAR_BUCKET
medium_suffix = "-medium.png" if medium else ""
# ?x=x allows templates to append additional parameters with &s
return u"https://%s.s3.amazonaws.com/%s%s?x=x" % (bucket, hash_key, medium_suffix)
def upload_realm_icon_image(self, icon_file, user_profile):
# type: (File, UserProfile) -> None
def upload_realm_icon_image(self, icon_file: File, user_profile: UserProfile) -> None:
content_type = guess_type(icon_file.name)[0]
bucket_name = settings.S3_AVATAR_BUCKET
s3_file_name = os.path.join(str(user_profile.realm.id), 'realm', 'icon')
@@ -362,14 +343,12 @@ class S3UploadBackend(ZulipUploadBackend):
# See avatar_url in avatar.py for URL. (That code also handles the case
# that users use gravatar.)
def get_realm_icon_url(self, realm_id, version):
# type: (int, int) -> Text
def get_realm_icon_url(self, realm_id: int, version: int) -> Text:
bucket = settings.S3_AVATAR_BUCKET
# ?x=x allows templates to append additional parameters with &s
return u"https://%s.s3.amazonaws.com/%s/realm/icon.png?version=%s" % (bucket, realm_id, version)
def ensure_medium_avatar_image(self, user_profile):
# type: (UserProfile) -> None
def ensure_medium_avatar_image(self, user_profile: UserProfile) -> None:
file_path = user_avatar_path(user_profile)
s3_file_name = file_path
@@ -388,8 +367,8 @@ class S3UploadBackend(ZulipUploadBackend):
resized_medium
)
def upload_emoji_image(self, emoji_file, emoji_file_name, user_profile):
# type: (File, Text, UserProfile) -> None
def upload_emoji_image(self, emoji_file: File, emoji_file_name: Text,
user_profile: UserProfile) -> None:
content_type = guess_type(emoji_file.name)[0]
bucket_name = settings.S3_AVATAR_BUCKET
emoji_path = RealmEmoji.PATH_ID_TEMPLATE.format(
@@ -414,8 +393,7 @@ class S3UploadBackend(ZulipUploadBackend):
resized_image_data,
)
def get_emoji_url(self, emoji_file_name, realm_id):
# type: (Text, int) -> Text
def get_emoji_url(self, emoji_file_name: Text, realm_id: int) -> Text:
bucket = settings.S3_AVATAR_BUCKET
emoji_path = RealmEmoji.PATH_ID_TEMPLATE.format(realm_id=realm_id,
emoji_file_name=emoji_file_name)
@@ -424,15 +402,13 @@ class S3UploadBackend(ZulipUploadBackend):
### Local
def write_local_file(type, path, file_data):
# type: (Text, Text, bytes) -> None
def write_local_file(type: Text, path: Text, file_data: bytes) -> None:
file_path = os.path.join(settings.LOCAL_UPLOADS_DIR, type, path)
os.makedirs(os.path.dirname(file_path), exist_ok=True)
with open(file_path, 'wb') as f:
f.write(file_data)
def get_local_file_path(path_id):
# type: (Text) -> Optional[Text]
def get_local_file_path(path_id: Text) -> Optional[Text]:
local_path = os.path.join(settings.LOCAL_UPLOADS_DIR, 'files', path_id)
if os.path.isfile(local_path):
return local_path
@@ -455,8 +431,7 @@ class LocalUploadBackend(ZulipUploadBackend):
create_attachment(uploaded_file_name, path, user_profile, uploaded_file_size)
return '/user_uploads/' + path
def delete_message_image(self, path_id):
# type: (Text) -> bool
def delete_message_image(self, path_id: Text) -> bool:
file_path = os.path.join(settings.LOCAL_UPLOADS_DIR, 'files', path_id)
if os.path.isfile(file_path):
# This removes the file but the empty folders still remain.
@@ -467,8 +442,9 @@ class LocalUploadBackend(ZulipUploadBackend):
logging.warning("%s does not exist. Its entry in the database will be removed." % (file_name,))
return False
def upload_avatar_image(self, user_file, acting_user_profile, target_user_profile):
# type: (File, UserProfile, UserProfile) -> None
def upload_avatar_image(self, user_file: File,
acting_user_profile: UserProfile,
target_user_profile: UserProfile) -> None:
file_path = user_avatar_path(target_user_profile)
image_data = user_file.read()
@@ -480,14 +456,12 @@ class LocalUploadBackend(ZulipUploadBackend):
resized_medium = resize_avatar(image_data, MEDIUM_AVATAR_SIZE)
write_local_file('avatars', file_path + '-medium.png', resized_medium)
def get_avatar_url(self, hash_key, medium=False):
# type: (Text, bool) -> Text
def get_avatar_url(self, hash_key: Text, medium: bool=False) -> Text:
# ?x=x allows templates to append additional parameters with &s
medium_suffix = "-medium" if medium else ""
return u"/user_avatars/%s%s.png?x=x" % (hash_key, medium_suffix)
def upload_realm_icon_image(self, icon_file, user_profile):
# type: (File, UserProfile) -> None
def upload_realm_icon_image(self, icon_file: File, user_profile: UserProfile) -> None:
upload_path = os.path.join('avatars', str(user_profile.realm.id), 'realm')
image_data = icon_file.read()
@@ -499,13 +473,11 @@ class LocalUploadBackend(ZulipUploadBackend):
resized_data = resize_avatar(image_data)
write_local_file(upload_path, 'icon.png', resized_data)
def get_realm_icon_url(self, realm_id, version):
# type: (int, int) -> Text
def get_realm_icon_url(self, realm_id: int, version: int) -> Text:
# ?x=x allows templates to append additional parameters with &s
return u"/user_avatars/%s/realm/icon.png?version=%s" % (realm_id, version)
def ensure_medium_avatar_image(self, user_profile):
# type: (UserProfile) -> None
def ensure_medium_avatar_image(self, user_profile: UserProfile) -> None:
file_path = user_avatar_path(user_profile)
output_path = os.path.join(settings.LOCAL_UPLOADS_DIR, "avatars", file_path + "-medium.png")
@@ -517,8 +489,8 @@ class LocalUploadBackend(ZulipUploadBackend):
resized_medium = resize_avatar(image_data, MEDIUM_AVATAR_SIZE)
write_local_file('avatars', file_path + '-medium.png', resized_medium)
def upload_emoji_image(self, emoji_file, emoji_file_name, user_profile):
# type: (File, Text, UserProfile) -> None
def upload_emoji_image(self, emoji_file: File, emoji_file_name: Text,
user_profile: UserProfile) -> None:
emoji_path = RealmEmoji.PATH_ID_TEMPLATE.format(
realm_id= user_profile.realm_id,
emoji_file_name=emoji_file_name
@@ -535,8 +507,7 @@ class LocalUploadBackend(ZulipUploadBackend):
emoji_path,
resized_image_data)
def get_emoji_url(self, emoji_file_name, realm_id):
# type: (Text, int) -> Text
def get_emoji_url(self, emoji_file_name: Text, realm_id: int) -> Text:
return os.path.join(
u"/user_avatars",
RealmEmoji.PATH_ID_TEMPLATE.format(realm_id=realm_id, emoji_file_name=emoji_file_name))
@@ -547,20 +518,17 @@ if settings.LOCAL_UPLOADS_DIR is not None:
else:
upload_backend = S3UploadBackend()
def delete_message_image(path_id):
# type: (Text) -> bool
def delete_message_image(path_id: Text) -> bool:
return upload_backend.delete_message_image(path_id)
def upload_avatar_image(user_file, acting_user_profile, target_user_profile):
# type: (File, UserProfile, UserProfile) -> None
def upload_avatar_image(user_file: File, acting_user_profile: UserProfile,
target_user_profile: UserProfile) -> None:
upload_backend.upload_avatar_image(user_file, acting_user_profile, target_user_profile)
def upload_icon_image(user_file, user_profile):
# type: (File, UserProfile) -> None
def upload_icon_image(user_file: File, user_profile: UserProfile) -> None:
upload_backend.upload_realm_icon_image(user_file, user_profile)
def upload_emoji_image(emoji_file, emoji_file_name, user_profile):
# type: (File, Text, UserProfile) -> None
def upload_emoji_image(emoji_file: File, emoji_file_name: Text, user_profile: UserProfile) -> None:
upload_backend.upload_emoji_image(emoji_file, emoji_file_name, user_profile)
def upload_message_image(uploaded_file_name, uploaded_file_size,
@@ -570,21 +538,23 @@ def upload_message_image(uploaded_file_name, uploaded_file_size,
content_type, file_data, user_profile,
target_realm=target_realm)
def claim_attachment(user_profile, path_id, message, is_message_realm_public):
# type: (UserProfile, Text, Message, bool) -> None
def claim_attachment(user_profile: UserProfile,
path_id: Text,
message: Message,
is_message_realm_public: bool) -> None:
attachment = Attachment.objects.get(path_id=path_id)
attachment.messages.add(message)
attachment.is_realm_public = attachment.is_realm_public or is_message_realm_public
attachment.save()
def create_attachment(file_name, path_id, user_profile, file_size):
# type: (Text, Text, UserProfile, int) -> bool
def create_attachment(file_name: Text, path_id: Text, user_profile: UserProfile,
file_size: int) -> bool:
Attachment.objects.create(file_name=file_name, path_id=path_id, owner=user_profile,
realm=user_profile.realm, size=file_size)
return True
def upload_message_image_from_request(request, user_file, user_profile):
# type: (HttpRequest, File, UserProfile) -> Text
def upload_message_image_from_request(request: HttpRequest, user_file: File,
user_profile: UserProfile) -> Text:
uploaded_file_name, uploaded_file_size, content_type = get_file_info(request, user_file)
return upload_message_image(uploaded_file_name, uploaded_file_size,
content_type, user_file.read(), user_profile)

View File

@@ -4,8 +4,7 @@ from typing import Optional, Dict
# Warning: If you change this parsing, please test using
# zerver/tests/test_decorators.py
# And extend zerver/fixtures/user_agents_unique with any new test cases
def parse_user_agent(user_agent):
# type: (str) -> Optional[Dict[str, str]]
def parse_user_agent(user_agent: str) -> Optional[Dict[str, str]]:
match = re.match("^(?P<name>[^/ ]*[^0-9/(]*)(/(?P<version>[^/ ]*))?([ /].*)?$", user_agent)
if match is None:
return None

View File

@@ -14,13 +14,11 @@ def access_user_group_by_id(user_group_id: int, realm: Realm) -> UserGroup:
raise JsonableError(_("Invalid user group"))
return user_group
def user_groups_in_realm(realm):
# type: (Realm) -> List[UserGroup]
def user_groups_in_realm(realm: Realm) -> List[UserGroup]:
user_groups = UserGroup.objects.filter(realm=realm)
return list(user_groups)
def user_groups_in_realm_serialized(realm):
# type: (Realm) -> List[Dict[Text, Any]]
def user_groups_in_realm_serialized(realm: Realm) -> List[Dict[Text, Any]]:
"""
This function is used in do_events_register code path so this code should
be performant. This is the reason why we get the groups through
@@ -43,32 +41,28 @@ def user_groups_in_realm_serialized(realm):
user_groups.sort(key=lambda item: item['id'])
return user_groups
def get_user_groups(user_profile):
# type: (UserProfile) -> List[UserGroup]
def get_user_groups(user_profile: UserProfile) -> List[UserGroup]:
return list(user_profile.usergroup_set.all())
def check_add_user_to_user_group(user_profile, user_group):
# type: (UserProfile, UserGroup) -> bool
def check_add_user_to_user_group(user_profile: UserProfile, user_group: UserGroup) -> bool:
member_obj, created = UserGroupMembership.objects.get_or_create(
user_group=user_group, user_profile=user_profile)
return created
def remove_user_from_user_group(user_profile, user_group):
# type: (UserProfile, UserGroup) -> int
def remove_user_from_user_group(user_profile: UserProfile, user_group: UserGroup) -> int:
num_deleted, _ = UserGroupMembership.objects.filter(
user_profile=user_profile, user_group=user_group).delete()
return num_deleted
def check_remove_user_from_user_group(user_profile, user_group):
# type: (UserProfile, UserGroup) -> bool
def check_remove_user_from_user_group(user_profile: UserProfile, user_group: UserGroup) -> bool:
try:
num_deleted = remove_user_from_user_group(user_profile, user_group)
return bool(num_deleted)
except Exception:
return False
def create_user_group(name, members, realm, description=''):
# type: (Text, List[UserProfile], Realm, Text) -> UserGroup
def create_user_group(name: Text, members: List[UserProfile], realm: Realm,
description: Text='') -> UserGroup:
with transaction.atomic():
user_group = UserGroup.objects.create(name=name, realm=realm,
description=description)