Add PEP-484 type annotations to zerver/lib/.

This commit is contained in:
Tim Abbott
2016-01-25 14:42:16 -08:00
parent d8f7d89fb4
commit 2059f650ab
17 changed files with 62 additions and 47 deletions

View File

@@ -1,5 +1,6 @@
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import print_function from __future__ import print_function
from typing import *
from django.conf import settings from django.conf import settings
from django.core import validators from django.core import validators
@@ -575,7 +576,7 @@ def do_send_messages(messages):
message['message'].update_calculated_fields() message['message'].update_calculated_fields()
# Save the message receipts in the database # Save the message receipts in the database
user_message_flags = defaultdict(dict) user_message_flags = defaultdict(dict) # type: Dict[int, Dict[int, List[str]]]
with transaction.atomic(): with transaction.atomic():
Message.objects.bulk_create([message['message'] for message in messages]) Message.objects.bulk_create([message['message'] for message in messages])
ums = [] ums = []
@@ -1042,7 +1043,7 @@ def bulk_get_subscriber_user_ids(stream_dicts, user_profile, sub_dict):
user_profile__is_active=True, user_profile__is_active=True,
active=True).values("user_profile_id", "recipient__type_id") active=True).values("user_profile_id", "recipient__type_id")
result = dict((stream["id"], []) for stream in stream_dicts) result = dict((stream["id"], []) for stream in stream_dicts) # type: Dict[int, List[int]]
for sub in subscriptions: for sub in subscriptions:
result[sub["recipient__type_id"]].append(sub["user_profile_id"]) result[sub["recipient__type_id"]].append(sub["user_profile_id"])
@@ -1114,7 +1115,7 @@ def get_subscribers_to_streams(streams):
arrays of all the streams within 'streams' to which that user is arrays of all the streams within 'streams' to which that user is
subscribed. subscribed.
""" """
subscribes_to = {} subscribes_to = {} # type: Dict[str, List[Stream]]
for stream in streams: for stream in streams:
try: try:
subscribers = get_subscribers(stream) subscribers = get_subscribers(stream)
@@ -1160,7 +1161,7 @@ def bulk_add_subscriptions(streams, users):
for stream in streams: for stream in streams:
stream_map[recipients_map[stream.id].id] = stream stream_map[recipients_map[stream.id].id] = stream
subs_by_user = defaultdict(list) subs_by_user = defaultdict(list) # type: Dict[int, List[Subscription]]
all_subs_query = Subscription.objects.select_related("user_profile") all_subs_query = Subscription.objects.select_related("user_profile")
for sub in all_subs_query.filter(user_profile__in=users, for sub in all_subs_query.filter(user_profile__in=users,
recipient__type=Recipient.STREAM): recipient__type=Recipient.STREAM):
@@ -1222,8 +1223,8 @@ def bulk_add_subscriptions(streams, users):
user_profile__is_active=True, user_profile__is_active=True,
active=True).select_related('recipient', 'user_profile') active=True).select_related('recipient', 'user_profile')
all_subs_by_stream = defaultdict(list) all_subs_by_stream = defaultdict(list) # type: Dict[int, List[UserProfile]]
emails_by_stream = defaultdict(list) emails_by_stream = defaultdict(list) # type: Dict[int, List[str]]
for sub in all_subs: for sub in all_subs:
all_subs_by_stream[sub.recipient.type_id].append(sub.user_profile) all_subs_by_stream[sub.recipient.type_id].append(sub.user_profile)
emails_by_stream[sub.recipient.type_id].append(sub.user_profile.email) emails_by_stream[sub.recipient.type_id].append(sub.user_profile.email)
@@ -1233,7 +1234,7 @@ def bulk_add_subscriptions(streams, users):
return [] return []
return emails_by_stream[stream.id] return emails_by_stream[stream.id]
sub_tuples_by_user = defaultdict(list) sub_tuples_by_user = defaultdict(list) # type: Dict[int, List[Tuple[Subscription, Stream]]]
new_streams = set() new_streams = set()
for (sub, stream) in subs_to_add + subs_to_activate: for (sub, stream) in subs_to_add + subs_to_activate:
sub_tuples_by_user[sub.user_profile.id].append((sub, stream)) sub_tuples_by_user[sub.user_profile.id].append((sub, stream))
@@ -1336,7 +1337,7 @@ def bulk_remove_subscriptions(users, streams):
for stream in streams: for stream in streams:
stream_map[recipients_map[stream.id].id] = stream stream_map[recipients_map[stream.id].id] = stream
subs_by_user = dict((user_profile.id, []) for user_profile in users) subs_by_user = dict((user_profile.id, []) for user_profile in users) # type: Dict[int, List[Subscription]]
for sub in Subscription.objects.select_related("user_profile").filter(user_profile__in=users, for sub in Subscription.objects.select_related("user_profile").filter(user_profile__in=users,
recipient__in=list(recipients_map.values()), recipient__in=list(recipients_map.values()),
active=True): active=True):
@@ -1369,7 +1370,7 @@ def bulk_remove_subscriptions(users, streams):
for stream in new_vacant_streams]) for stream in new_vacant_streams])
send_event(event, active_user_ids(user_profile.realm)) send_event(event, active_user_ids(user_profile.realm))
streams_by_user = defaultdict(list) streams_by_user = defaultdict(list) # type: Dict[int, List[Stream]]
for (sub, stream) in subs_to_deactivate: for (sub, stream) in subs_to_deactivate:
streams_by_user[sub.user_profile_id].append(stream) streams_by_user[sub.user_profile_id].append(stream)
@@ -2786,7 +2787,7 @@ def do_invite_users(user_profile, invitee_emails, streams):
skipped = [] skipped = []
ret_error = None ret_error = None
ret_error_data = {} ret_error_data = {} # type: Dict[str, List[Tuple[str, str]]]
for email in invitee_emails: for email in invitee_emails:
if email == '': if email == '':

View File

@@ -1,6 +1,7 @@
from __future__ import absolute_import from __future__ import absolute_import
# Zulip's main markdown implementation. See docs/markdown.md for # Zulip's main markdown implementation. See docs/markdown.md for
# detailed documentation on our markdown syntax. # detailed documentation on our markdown syntax.
from typing import *
import codecs import codecs
import markdown import markdown
@@ -561,7 +562,7 @@ class Emoji(markdown.inlinepatterns.Pattern):
orig_syntax = match.group("syntax") orig_syntax = match.group("syntax")
name = orig_syntax[1:-1] name = orig_syntax[1:-1]
realm_emoji = {} realm_emoji = {} # type: Dict[str, str]
if db_data is not None: if db_data is not None:
realm_emoji = db_data['emoji'] realm_emoji = db_data['emoji']
@@ -992,7 +993,7 @@ def make_md_engine(key, opts):
def subject_links(domain, subject): def subject_links(domain, subject):
from zerver.models import get_realm, RealmFilter, realm_filters_for_domain from zerver.models import get_realm, RealmFilter, realm_filters_for_domain
matches = [] matches = [] # type: List[str]
try: try:
realm_filters = realm_filters_for_domain(domain) realm_filters = realm_filters_for_domain(domain)
@@ -1048,12 +1049,12 @@ def _sanitize_for_log(md):
# Filters such as UserMentionPattern need a message, but python-markdown # Filters such as UserMentionPattern need a message, but python-markdown
# provides no way to pass extra params through to a pattern. Thus, a global. # provides no way to pass extra params through to a pattern. Thus, a global.
current_message = None current_message = None # type: Any # Should be Message but bugdown doesn't import models.py.
# We avoid doing DB queries in our markdown thread to avoid the overhead of # We avoid doing DB queries in our markdown thread to avoid the overhead of
# opening a new DB connection. These connections tend to live longer than the # opening a new DB connection. These connections tend to live longer than the
# threads themselves, as well. # threads themselves, as well.
db_data = None db_data = None # type: Dict[str, Any]
def do_convert(md, realm_domain=None, message=None): def do_convert(md, realm_domain=None, message=None):
"""Convert Markdown to HTML, with Zulip-specific settings and hacks.""" """Convert Markdown to HTML, with Zulip-specific settings and hacks."""

View File

@@ -1,4 +1,5 @@
from __future__ import absolute_import from __future__ import absolute_import
from typing import *
# This file needs to be different from cache.py because cache.py # This file needs to be different from cache.py because cache.py
# cannot import anything from zerver.models or we'd have an import # cannot import anything from zerver.models or we'd have an import
@@ -75,7 +76,7 @@ cache_fillers = {
def fill_remote_cache(cache): def fill_remote_cache(cache):
remote_cache_time_start = get_remote_cache_time() remote_cache_time_start = get_remote_cache_time()
remote_cache_requests_start = get_remote_cache_requests() remote_cache_requests_start = get_remote_cache_requests()
items_for_remote_cache = {} items_for_remote_cache = {} # type: Dict[str, Any]
(objects, items_filler, timeout, batch_size) = cache_fillers[cache] (objects, items_filler, timeout, batch_size) = cache_fillers[cache]
count = 0 count = 0
for obj in objects(): for obj in objects():

View File

@@ -29,7 +29,7 @@ class TimeTrackingConnection(connection):
"""A psycopg2 connection class that uses TimeTrackingCursors.""" """A psycopg2 connection class that uses TimeTrackingCursors."""
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self.queries = [] self.queries = [] # type: List[Dict[str, str]]
super(TimeTrackingConnection, self).__init__(*args, **kwargs) super(TimeTrackingConnection, self).__init__(*args, **kwargs)
def cursor(self, name=None): def cursor(self, name=None):

View File

@@ -1,4 +1,5 @@
from __future__ import absolute_import from __future__ import absolute_import
from typing import *
from collections import defaultdict from collections import defaultdict
import datetime import datetime
@@ -40,8 +41,8 @@ def gather_hot_conversations(user_profile, stream_messages):
# Returns a list of dictionaries containing the templating # Returns a list of dictionaries containing the templating
# information for each hot conversation. # information for each hot conversation.
conversation_length = defaultdict(int) conversation_length = defaultdict(int) # type: Dict[Tuple[int, str], int]
conversation_diversity = defaultdict(set) conversation_diversity = defaultdict(set) # type: Dict[Tuple[int, str], Set[str]]
for user_message in stream_messages: for user_message in stream_messages:
if not user_message.message.sent_by_human(): if not user_message.message.sent_by_human():
# Don't include automated messages in the count. # Don't include automated messages in the count.
@@ -99,7 +100,7 @@ def gather_new_users(user_profile, threshold):
# Gather information on users in the realm who have recently # Gather information on users in the realm who have recently
# joined. # joined.
if user_profile.realm.domain == "mit.edu": if user_profile.realm.domain == "mit.edu":
new_users = [] new_users = [] # type: List[UserProfile]
else: else:
new_users = list(UserProfile.objects.filter( new_users = list(UserProfile.objects.filter(
realm=user_profile.realm, date_joined__gt=threshold, realm=user_profile.realm, date_joined__gt=threshold,
@@ -110,7 +111,7 @@ def gather_new_users(user_profile, threshold):
def gather_new_streams(user_profile, threshold): def gather_new_streams(user_profile, threshold):
if user_profile.realm.domain == "mit.edu": if user_profile.realm.domain == "mit.edu":
new_streams = [] new_streams = [] # type: List[Stream]
else: else:
new_streams = list(get_active_streams(user_profile.realm).filter( new_streams = list(get_active_streams(user_profile.realm).filter(
invite_only=False, date_created__gt=threshold)) invite_only=False, date_created__gt=threshold))

View File

@@ -235,7 +235,7 @@ def find_emailgateway_recipient(message):
# it is more accurate, so try to find the most-accurate # it is more accurate, so try to find the most-accurate
# recipient list in descending priority order # recipient list in descending priority order
recipient_headers = ["X-Gm-Original-To", "Delivered-To", "To"] recipient_headers = ["X-Gm-Original-To", "Delivered-To", "To"]
recipients = [] recipients = [] # type: List[str]
for recipient_header in recipient_headers: for recipient_header in recipient_headers:
r = message.get_all(recipient_header, None) r = message.get_all(recipient_header, None)
if r: if r:

View File

@@ -1,4 +1,5 @@
from __future__ import absolute_import from __future__ import absolute_import
from typing import *
from django.conf import settings from django.conf import settings
from django.utils.timezone import now from django.utils.timezone import now
@@ -58,8 +59,8 @@ class ClientDescriptor(object):
self.user_profile_id = user_profile_id self.user_profile_id = user_profile_id
self.user_profile_email = user_profile_email self.user_profile_email = user_profile_email
self.realm_id = realm_id self.realm_id = realm_id
self.current_handler_id = None self.current_handler_id = None # type: int
self.current_client_name = None self.current_client_name = None # type: str
self.event_queue = event_queue self.event_queue = event_queue
self.queue_timeout = lifespan_secs self.queue_timeout = lifespan_secs
self.event_types = event_types self.event_types = event_types
@@ -192,7 +193,7 @@ class ClientDescriptor(object):
do_gc_event_queues([self.event_queue.id], [self.user_profile_id], do_gc_event_queues([self.event_queue.id], [self.user_profile_id],
[self.realm_id]) [self.realm_id])
descriptors_by_handler_id = {} descriptors_by_handler_id = {} # type: Dict[int, ClientDescriptor]
def get_descriptor_by_handler_id(handler_id): def get_descriptor_by_handler_id(handler_id):
return descriptors_by_handler_id.get(handler_id) return descriptors_by_handler_id.get(handler_id)
@@ -213,10 +214,11 @@ def compute_full_event_type(event):
class EventQueue(object): class EventQueue(object):
def __init__(self, id): def __init__(self, id):
self.queue = deque() # type: (Any) -> None
self.queue = deque() # type: deque[Dict[str, str]]
self.next_event_id = 0 self.next_event_id = 0
self.id = id self.id = id
self.virtual_events = {} self.virtual_events = {} # type: Dict[str, Dict[str, str]]
def to_dict(self): def to_dict(self):
# If you add a new key to this dict, make sure you add appropriate # If you add a new key to this dict, make sure you add appropriate
@@ -296,7 +298,7 @@ class EventQueue(object):
return contents return contents
# maps queue ids to client descriptors # maps queue ids to client descriptors
clients = {} # type: Dict[int, ClientDescriptor] clients = {} # type: Dict[str, ClientDescriptor]
# maps user id to list of client descriptors # maps user id to list of client descriptors
user_clients = {} # type: Dict[int, List[ClientDescriptor]] user_clients = {} # type: Dict[int, List[ClientDescriptor]]
# maps realm id to list of client descriptors with all_public_streams=True # maps realm id to list of client descriptors with all_public_streams=True
@@ -432,7 +434,7 @@ def setup_event_queue():
atexit.register(dump_event_queues) atexit.register(dump_event_queues)
# Make sure we dump event queues even if we exit via signal # Make sure we dump event queues even if we exit via signal
signal.signal(signal.SIGTERM, lambda signum, stack: sys.exit(1)) signal.signal(signal.SIGTERM, lambda signum, stack: sys.exit(1))
tornado.autoreload.add_reload_hook(dump_event_queues) tornado.autoreload.add_reload_hook(dump_event_queues) # type: ignore # TODO: Fix missing tornado.autoreload stub
try: try:
os.rename(settings.JSON_PERSISTENT_QUEUE_FILENAME, "/var/tmp/event_queues.json.last") os.rename(settings.JSON_PERSISTENT_QUEUE_FILENAME, "/var/tmp/event_queues.json.last")

View File

@@ -1,8 +1,9 @@
import logging import logging
from zerver.middleware import async_request_restart from zerver.middleware import async_request_restart
from typing import *
current_handler_id = 0 current_handler_id = 0
handlers = {} handlers = {} # type: Dict[int, Any] # TODO: Should be AsyncDjangoHandler but we don't import runtornado.py.
def get_handler_by_id(handler_id): def get_handler_by_id(handler_id):
return handlers[handler_id] return handlers[handler_id]

View File

@@ -1,4 +1,6 @@
from __future__ import print_function from __future__ import print_function
from typing import *
from confirmation.models import Confirmation from confirmation.models import Confirmation
from django.conf import settings from django.conf import settings
from django.core.mail import EmailMultiAlternatives from django.core.mail import EmailMultiAlternatives
@@ -7,7 +9,7 @@ from zerver.decorator import statsd_increment, uses_mandrill
from zerver.models import Recipient, ScheduledJob, UserMessage, \ from zerver.models import Recipient, ScheduledJob, UserMessage, \
Stream, get_display_recipient, get_user_profile_by_email, \ Stream, get_display_recipient, get_user_profile_by_email, \
get_user_profile_by_id, receives_offline_notifications, \ get_user_profile_by_id, receives_offline_notifications, \
get_context_for_message get_context_for_message, Message
import datetime import datetime
import re import re
@@ -58,7 +60,7 @@ def build_message_list(user_profile, messages):
The messages are collapsed into per-recipient and per-sender blocks, like The messages are collapsed into per-recipient and per-sender blocks, like
our web interface our web interface
""" """
messages_to_render = [] messages_to_render = [] # type: List[Dict[str, Any]]
def sender_string(message): def sender_string(message):
sender = '' sender = ''
@@ -324,7 +326,7 @@ def handle_missedmessage_emails(user_profile_id, missed_email_events):
if not messages: if not messages:
return return
messages_by_recipient_subject = defaultdict(list) messages_by_recipient_subject = defaultdict(list) # type: Dict[Tuple[int, str], List[Message]]
for msg in messages: for msg in messages:
messages_by_recipient_subject[(msg.recipient_id, msg.subject)].append(msg) messages_by_recipient_subject[(msg.recipient_id, msg.subject)].append(msg)

View File

@@ -1,5 +1,6 @@
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import print_function from __future__ import print_function
from typing import *
import os import os
import pty import pty
@@ -7,7 +8,8 @@ import sys
import errno import errno
def run_parallel(job, data, threads=6): def run_parallel(job, data, threads=6):
pids = {} # type: (Any, Iterable[Any], int) -> Generator[Tuple[int, Any], None, None]
pids = {} # type: Dict[int, Any]
def wait_for_one(): def wait_for_one():
while True: while True:

View File

@@ -11,6 +11,7 @@ import atexit
from collections import defaultdict from collections import defaultdict
from zerver.lib.utils import statsd from zerver.lib.utils import statsd
from typing import *
# This simple queuing library doesn't expose much of the power of # This simple queuing library doesn't expose much of the power of
# rabbitmq/pika's queuing system; its purpose is to just provide an # rabbitmq/pika's queuing system; its purpose is to just provide an
@@ -19,9 +20,9 @@ from zerver.lib.utils import statsd
class SimpleQueueClient(object): class SimpleQueueClient(object):
def __init__(self): def __init__(self):
self.log = logging.getLogger('zulip.queue') self.log = logging.getLogger('zulip.queue')
self.queues = set() self.queues = set() # type: Set[str]
self.channel = None self.channel = None # type: Any
self.consumers = defaultdict(set) self.consumers = defaultdict(set) # type: Dict[str, Set[Any]]
self._connect() self._connect()
def _connect(self): def _connect(self):
@@ -156,7 +157,7 @@ class TornadoQueueClient(SimpleQueueClient):
# https://pika.readthedocs.org/en/0.9.8/examples/asynchronous_consumer_example.html # https://pika.readthedocs.org/en/0.9.8/examples/asynchronous_consumer_example.html
def __init__(self): def __init__(self):
super(TornadoQueueClient, self).__init__() super(TornadoQueueClient, self).__init__()
self._on_open_cbs = [] self._on_open_cbs = [] # type: List[Callable[[], None]]
def _connect(self, on_open_cb = None): def _connect(self, on_open_cb = None):
self.log.info("Beginning TornadoQueueClient connection") self.log.info("Beginning TornadoQueueClient connection")
@@ -230,7 +231,7 @@ class TornadoQueueClient(SimpleQueueClient):
lambda: self.channel.basic_consume(wrapped_consumer, queue=queue_name, lambda: self.channel.basic_consume(wrapped_consumer, queue=queue_name,
consumer_tag=self._generate_ctag(queue_name))) consumer_tag=self._generate_ctag(queue_name)))
queue_client = None queue_client = None # type: SimpleQueueClient
def get_queue_client(): def get_queue_client():
global queue_client global queue_client
if queue_client is None: if queue_client is None:

View File

@@ -1,4 +1,5 @@
from __future__ import absolute_import from __future__ import absolute_import
from typing import *
from django.conf import settings from django.conf import settings
from django.utils.importlib import import_module from django.utils.importlib import import_module
@@ -79,7 +80,7 @@ class SocketConnection(sockjs.tornado.SockJSConnection):
self.authenticated = False self.authenticated = False
self.session.user_profile = None self.session.user_profile = None
self.close_info = None self.close_info = None # type: CloseErrorInfo
self.did_close = False self.did_close = False
try: try:
@@ -226,7 +227,7 @@ class SocketConnection(sockjs.tornado.SockJSConnection):
self.did_close = True self.did_close = True
def fake_message_sender(event): def fake_message_sender(event):
log_data = dict() log_data = dict() # type: Dict[str, Any]
record_request_start_data(log_data) record_request_start_data(log_data)
req = event['request'] req = event['request']

View File

@@ -166,7 +166,7 @@ class DummyObject(object):
class DummyTornadoRequest(object): class DummyTornadoRequest(object):
def __init__(self): def __init__(self):
self.connection = DummyObject() self.connection = DummyObject()
self.connection.stream = DummyStream() self.connection.stream = DummyStream() # type: ignore # monkey-patching here
class DummyHandler(object): class DummyHandler(object):
def __init__(self, assert_callback): def __init__(self, assert_callback):
@@ -202,7 +202,7 @@ class POSTRequestMock(object):
self.user = user_profile self.user = user_profile
self._tornado_handler = DummyHandler(assert_callback) self._tornado_handler = DummyHandler(assert_callback)
self.session = DummySession() self.session = DummySession()
self._log_data = {} self._log_data = {} # type: Dict[str, Any]
self.META = {'PATH_INFO': 'test'} self.META = {'PATH_INFO': 'test'}
class AuthedTestCase(TestCase): class AuthedTestCase(TestCase):

View File

@@ -1,4 +1,5 @@
from __future__ import absolute_import from __future__ import absolute_import
from typing import *
import sys import sys
import time import time
@@ -33,8 +34,8 @@ def timeout(timeout, func, *args, **kwargs):
class TimeoutThread(threading.Thread): class TimeoutThread(threading.Thread):
def __init__(self): def __init__(self):
threading.Thread.__init__(self) threading.Thread.__init__(self)
self.result = None self.result = None # type: Any
self.exc_info = None self.exc_info = None # type: Tuple[type, BaseException, Any]
# Don't block the whole program from exiting # Don't block the whole program from exiting
# if this is the only thread left. # if this is the only thread left.

View File

@@ -1,5 +1,6 @@
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from typing import *
import logging import logging
import time import time
@@ -37,7 +38,7 @@ except:
class InstrumentedPoll(object): class InstrumentedPoll(object):
def __init__(self): def __init__(self):
self._underlying = orig_poll_impl() self._underlying = orig_poll_impl()
self._times = [] self._times = [] # type: List[Tuple[float, float]]
self._last_print = 0.0 self._last_print = 0.0
# Python won't let us subclass e.g. select.epoll, so instead # Python won't let us subclass e.g. select.epoll, so instead

View File

@@ -11,7 +11,7 @@ class SourceMap(object):
def __init__(self, sourcemap_dir): def __init__(self, sourcemap_dir):
self._dir = sourcemap_dir self._dir = sourcemap_dir
self._indices = {} self._indices = {} # type: Dict[str, sourcemap.SourceMapDecoder]
def _index_for(self, minified_src): def _index_for(self, minified_src):
'''Return the source map index for minified_src, loading it if not '''Return the source map index for minified_src, loading it if not

View File

@@ -59,7 +59,7 @@ class QueueProcessingWorker(object):
queue_name = None queue_name = None
def __init__(self): def __init__(self):
self.q = None self.q = None # type: SimpleQueueClient
if self.queue_name is None: if self.queue_name is None:
raise WorkerDeclarationException("Queue worker declared without queue_name") raise WorkerDeclarationException("Queue worker declared without queue_name")