Type annotation of zerver/models.py

[Substantially revised by tabbott]

This probably still has some bugs in it, but having mostly complete
annotations for models.py will help a lot for the annotations folks
are adding to other files.
This commit is contained in:
Ashish Kumar
2016-05-07 21:32:57 +05:30
committed by Tim Abbott
parent 37015fd7c5
commit 31bf6b8259
4 changed files with 106 additions and 8 deletions

View File

@@ -3142,7 +3142,7 @@ def do_remove_realm_filter(realm, pattern):
notify_realm_filters(realm) notify_realm_filters(realm)
def get_emails_from_user_ids(user_ids): def get_emails_from_user_ids(user_ids):
# type: (Iterable[int]) -> Dict[int, text_type] # type: (Sequence[int]) -> Dict[int, text_type]
# We may eventually use memcached to speed this up, but the DB is fast. # We may eventually use memcached to speed this up, but the DB is fast.
return UserProfile.emails_from_ids(user_ids) return UserProfile.emails_from_ids(user_ids)

View File

@@ -202,7 +202,7 @@ def generic_bulk_cached_fetch(cache_key_function, query_function, object_ids,
setter=lambda obj: obj, setter=lambda obj: obj,
id_fetcher=lambda obj: obj.id, id_fetcher=lambda obj: obj.id,
cache_transformer=lambda obj: obj): cache_transformer=lambda obj: obj):
# type: (Callable[[Any], str], Callable[[List[int]], List[Any]], List[Any], Callable[[Any], Any], Callable[[Any], Any], Callable[[Any], Any], Callable[[Any], Any]) -> Dict[int, Any] # type: (Callable[[Any], str], Callable[[List[Any]], List[Any]], List[Any], Callable[[Any], Any], Callable[[Any], Any], Callable[[Any], Any], Callable[[Any], Any]) -> Dict[Any, Any]
cache_keys = {} # type: Dict[int, str] cache_keys = {} # type: Dict[int, str]
for object_id in object_ids: for object_id in object_ids:
cache_keys[object_id] = cache_key_function(object_id) cache_keys[object_id] = cache_key_function(object_id)

View File

@@ -1,7 +1,10 @@
from __future__ import absolute_import from __future__ import absolute_import
from typing import Any, Tuple from typing import Any, List, Set, Tuple, TypeVar, \
Union, Optional, Sequence, AbstractSet
from typing.re import Match
from django.db import models from django.db import models
from django.db.models.query import QuerySet
from django.conf import settings from django.conf import settings
from django.contrib.auth.models import AbstractBaseUser, UserManager, \ from django.contrib.auth.models import AbstractBaseUser, UserManager, \
PermissionsMixin PermissionsMixin
@@ -31,6 +34,7 @@ import pylibmc
import re import re
import ujson import ujson
import logging import logging
from six import text_type
import time import time
import datetime import datetime
@@ -39,16 +43,20 @@ bugdown = None # type: Any
MAX_SUBJECT_LENGTH = 60 MAX_SUBJECT_LENGTH = 60
MAX_MESSAGE_LENGTH = 10000 MAX_MESSAGE_LENGTH = 10000
STREAM_NAMES = TypeVar('STREAM_NAMES', Sequence[str], AbstractSet[str])
# Doing 1000 remote cache requests to get_display_recipient is quite slow, # Doing 1000 remote cache requests to get_display_recipient is quite slow,
# so add a local cache as well as the remote cache cache. # so add a local cache as well as the remote cache cache.
per_request_display_recipient_cache = {} # type: Dict[int, List[Dict[str, Any]]] per_request_display_recipient_cache = {} # type: Dict[int, List[Dict[text_type, Any]]]
def get_display_recipient_by_id(recipient_id, recipient_type, recipient_type_id): def get_display_recipient_by_id(recipient_id, recipient_type, recipient_type_id):
## type: (int, int, int) -> Union[text_type, List[Dict[text_type, Any]]]
if recipient_id not in per_request_display_recipient_cache: if recipient_id not in per_request_display_recipient_cache:
result = get_display_recipient_remote_cache(recipient_id, recipient_type, recipient_type_id) result = get_display_recipient_remote_cache(recipient_id, recipient_type, recipient_type_id)
per_request_display_recipient_cache[recipient_id] = result per_request_display_recipient_cache[recipient_id] = result
return per_request_display_recipient_cache[recipient_id] return per_request_display_recipient_cache[recipient_id]
def get_display_recipient(recipient): def get_display_recipient(recipient):
## type: (Recipient) -> Union[text_type, List[Dict[text_type, Any]]]
return get_display_recipient_by_id( return get_display_recipient_by_id(
recipient.id, recipient.id,
recipient.type, recipient.type,
@@ -56,6 +64,7 @@ def get_display_recipient(recipient):
) )
def flush_per_request_caches(): def flush_per_request_caches():
# type: () -> None
global per_request_display_recipient_cache global per_request_display_recipient_cache
per_request_display_recipient_cache = {} per_request_display_recipient_cache = {}
global per_request_realm_filters_cache global per_request_realm_filters_cache
@@ -64,6 +73,7 @@ def flush_per_request_caches():
@cache_with_key(lambda *args: display_recipient_cache_key(args[0]), @cache_with_key(lambda *args: display_recipient_cache_key(args[0]),
timeout=3600*24*7) timeout=3600*24*7)
def get_display_recipient_remote_cache(recipient_id, recipient_type, recipient_type_id): def get_display_recipient_remote_cache(recipient_id, recipient_type, recipient_type_id):
## type: (int, int, int) -> Union[text_type, List[Dict[text_type, Any]]]
""" """
returns: an appropriate object describing the recipient. For a returns: an appropriate object describing the recipient. For a
stream this will be the stream name as a string. For a huddle or stream this will be the stream name as a string. For a huddle or
@@ -85,6 +95,7 @@ def get_display_recipient_remote_cache(recipient_id, recipient_type, recipient_t
'is_mirror_dummy': user_profile.is_mirror_dummy,} for user_profile in user_profile_list] 'is_mirror_dummy': user_profile.is_mirror_dummy,} for user_profile in user_profile_list]
def completely_open(domain): def completely_open(domain):
# type: (text_type) -> bool
# This domain is completely open to everyone on the internet to # This domain is completely open to everyone on the internet to
# join. E-mail addresses do not need to match the domain and # join. E-mail addresses do not need to match the domain and
# an invite from an existing user is not required. # an invite from an existing user is not required.
@@ -94,6 +105,7 @@ def completely_open(domain):
return not realm.invite_required and not realm.restricted_to_domain return not realm.invite_required and not realm.restricted_to_domain
def get_unique_open_realm(): def get_unique_open_realm():
# type: () -> Optional[Realm]
# We only return a realm if there is a unique realm and it is completely open. # We only return a realm if there is a unique realm and it is completely open.
realms = Realm.objects.filter(deactivated=False) realms = Realm.objects.filter(deactivated=False)
if settings.VOYAGER: if settings.VOYAGER:
@@ -109,6 +121,7 @@ def get_unique_open_realm():
return realm return realm
def get_realm_emoji_cache_key(realm): def get_realm_emoji_cache_key(realm):
# type: (Realm) -> str
return 'realm_emoji:%s' % (realm.id,) return 'realm_emoji:%s' % (realm.id,)
class Realm(models.Model): class Realm(models.Model):
@@ -154,10 +167,12 @@ class Realm(models.Model):
self._deployments = [value] # type: Any self._deployments = [value] # type: Any
def get_admin_users(self): def get_admin_users(self):
# type: () -> List[UserProfile]
return UserProfile.objects.filter(realm=self, is_realm_admin=True, return UserProfile.objects.filter(realm=self, is_realm_admin=True,
is_active=True).select_related() is_active=True).select_related()
def get_active_users(self): def get_active_users(self):
# type: () -> List[UserProfile]
return UserProfile.objects.filter(realm=self, is_active=True).select_related() return UserProfile.objects.filter(realm=self, is_active=True).select_related()
class Meta(object): class Meta(object):
@@ -179,15 +194,18 @@ class RealmAlias(models.Model):
# "tabbott@test"@zulip.com # "tabbott@test"@zulip.com
# is valid email address # is valid email address
def email_to_username(email): def email_to_username(email):
# type: (text_type) -> text_type
return "@".join(email.split("@")[:-1]).lower() return "@".join(email.split("@")[:-1]).lower()
# Returns the raw domain portion of the desired email address # Returns the raw domain portion of the desired email address
def split_email_to_domain(email): def split_email_to_domain(email):
# type: (text_type) -> text_type
return email.split("@")[-1].lower() return email.split("@")[-1].lower()
# Returns the domain, potentually de-aliased, for the realm # Returns the domain, potentually de-aliased, for the realm
# that this user's email is in # that this user's email is in
def resolve_email_to_domain(email): def resolve_email_to_domain(email):
# type: (text_type) -> text_type
domain = split_email_to_domain(email) domain = split_email_to_domain(email)
alias = alias_for_realm(domain) alias = alias_for_realm(domain)
if alias is not None: if alias is not None:
@@ -199,6 +217,7 @@ def resolve_email_to_domain(email):
# So for invite-only realms, this is the test for whether a user can be invited, # So for invite-only realms, this is the test for whether a user can be invited,
# not whether the user can sign up currently.) # not whether the user can sign up currently.)
def email_allowed_for_realm(email, realm): def email_allowed_for_realm(email, realm):
# type: (text_type, Realm) -> bool
# Anyone can be in an open realm # Anyone can be in an open realm
if not realm.restricted_to_domain: if not realm.restricted_to_domain:
return True return True
@@ -208,12 +227,14 @@ def email_allowed_for_realm(email, realm):
return email_domain == realm.domain.lower() return email_domain == realm.domain.lower()
def alias_for_realm(domain): def alias_for_realm(domain):
# type: (text_type) -> Optional[RealmAlias]
try: try:
return RealmAlias.objects.get(domain=domain) return RealmAlias.objects.get(domain=domain)
except RealmAlias.DoesNotExist: except RealmAlias.DoesNotExist:
return None return None
def remote_user_to_email(remote_user): def remote_user_to_email(remote_user):
# type: (text_type) -> text_type
if settings.SSO_APPEND_DOMAIN is not None: if settings.SSO_APPEND_DOMAIN is not None:
remote_user += "@" + settings.SSO_APPEND_DOMAIN remote_user += "@" + settings.SSO_APPEND_DOMAIN
return remote_user return remote_user
@@ -234,6 +255,7 @@ class RealmEmoji(models.Model):
return "<RealmEmoji(%s): %s %s>" % (self.realm.domain, self.name, self.img_url) return "<RealmEmoji(%s): %s %s>" % (self.realm.domain, self.name, self.img_url)
def get_realm_emoji_uncached(realm): def get_realm_emoji_uncached(realm):
# type: (Realm) -> Dict[str, Dict[str, str]]
d = {} d = {}
for row in RealmEmoji.objects.filter(realm=realm): for row in RealmEmoji.objects.filter(realm=realm):
d[row.name] = dict(source_url=row.img_url, d[row.name] = dict(source_url=row.img_url,
@@ -241,6 +263,7 @@ def get_realm_emoji_uncached(realm):
return d return d
def flush_realm_emoji(sender, **kwargs): def flush_realm_emoji(sender, **kwargs):
# type: (Any, **Any) -> None
realm = kwargs['instance'].realm realm = kwargs['instance'].realm
cache_set(get_realm_emoji_cache_key(realm), cache_set(get_realm_emoji_cache_key(realm),
get_realm_emoji_uncached(realm), get_realm_emoji_uncached(realm),
@@ -261,11 +284,13 @@ class RealmFilter(models.Model):
return "<RealmFilter(%s): %s %s>" % (self.realm.domain, self.pattern, self.url_format_string) return "<RealmFilter(%s): %s %s>" % (self.realm.domain, self.pattern, self.url_format_string)
def get_realm_filters_cache_key(domain): def get_realm_filters_cache_key(domain):
# type: (str) -> str
return 'all_realm_filters:%s' % (domain,) return 'all_realm_filters:%s' % (domain,)
# We have a per-process cache to avoid doing 1000 remote cache queries during page load # We have a per-process cache to avoid doing 1000 remote cache queries during page load
per_request_realm_filters_cache = {} # type: Dict[str, List[RealmFilter]] per_request_realm_filters_cache = {} # type: Dict[str, List[RealmFilter]]
def realm_filters_for_domain(domain): def realm_filters_for_domain(domain):
# type: (str) -> List[RealmFilter]
domain = domain.lower() domain = domain.lower()
if domain not in per_request_realm_filters_cache: if domain not in per_request_realm_filters_cache:
per_request_realm_filters_cache[domain] = realm_filters_for_domain_remote_cache(domain) per_request_realm_filters_cache[domain] = realm_filters_for_domain_remote_cache(domain)
@@ -273,6 +298,7 @@ def realm_filters_for_domain(domain):
@cache_with_key(get_realm_filters_cache_key, timeout=3600*24*7) @cache_with_key(get_realm_filters_cache_key, timeout=3600*24*7)
def realm_filters_for_domain_remote_cache(domain): def realm_filters_for_domain_remote_cache(domain):
# type: (str) -> List[Tuple[str, str]]
filters = [] filters = []
for realm_filter in RealmFilter.objects.filter(realm=get_realm(domain)): for realm_filter in RealmFilter.objects.filter(realm=get_realm(domain)):
filters.append((realm_filter.pattern, realm_filter.url_format_string)) filters.append((realm_filter.pattern, realm_filter.url_format_string))
@@ -288,6 +314,7 @@ def all_realm_filters():
return filters return filters
def flush_realm_filter(sender, **kwargs): def flush_realm_filter(sender, **kwargs):
# type: (Any, **Any) -> None
realm = kwargs['instance'].realm realm = kwargs['instance'].realm
cache_delete(get_realm_filters_cache_key(realm.domain)) cache_delete(get_realm_filters_cache_key(realm.domain))
try: try:
@@ -404,6 +431,7 @@ class UserProfile(AbstractBaseUser, PermissionsMixin):
objects = UserManager() # type: UserManager objects = UserManager() # type: UserManager
def can_admin_user(self, target_user): def can_admin_user(self, target_user):
# type: (UserProfile) -> bool
"""Returns whether this user has permission to modify target_user""" """Returns whether this user has permission to modify target_user"""
if target_user.bot_owner == self: if target_user.bot_owner == self:
return True return True
@@ -413,6 +441,7 @@ class UserProfile(AbstractBaseUser, PermissionsMixin):
return False return False
def last_reminder_tzaware(self): def last_reminder_tzaware(self):
# type: () -> str
if self.last_reminder is not None and timezone.is_naive(self.last_reminder): if self.last_reminder is not None and timezone.is_naive(self.last_reminder):
logging.warning("Loaded a user_profile.last_reminder for user %s that's not tz-aware: %s" logging.warning("Loaded a user_profile.last_reminder for user %s that's not tz-aware: %s"
% (self.email, self.last_reminder)) % (self.email, self.last_reminder))
@@ -427,16 +456,19 @@ class UserProfile(AbstractBaseUser, PermissionsMixin):
@staticmethod @staticmethod
def emails_from_ids(user_ids): def emails_from_ids(user_ids):
# type: (Sequence[int]) -> Dict[int, text_type]
rows = UserProfile.objects.filter(id__in=user_ids).values('id', 'email') rows = UserProfile.objects.filter(id__in=user_ids).values('id', 'email')
return {row['id']: row['email'] for row in rows} return {row['id']: row['email'] for row in rows}
def can_create_streams(self): def can_create_streams(self):
# type: () -> bool
if self.is_realm_admin or not self.realm.create_stream_by_admins_only: if self.is_realm_admin or not self.realm.create_stream_by_admins_only:
return True return True
else: else:
return False return False
def receives_offline_notifications(user_profile): def receives_offline_notifications(user_profile):
# type: (UserProfile) -> bool
return ((user_profile.enable_offline_email_notifications or return ((user_profile.enable_offline_email_notifications or
user_profile.enable_offline_push_notifications) and user_profile.enable_offline_push_notifications) and
not user_profile.is_bot) not user_profile.is_bot)
@@ -488,6 +520,7 @@ class MitUser(models.Model):
status = models.IntegerField(default=0) status = models.IntegerField(default=0)
def generate_email_token_for_stream(): def generate_email_token_for_stream():
# type: () -> str
return generate_random_token(32) return generate_random_token(32)
class Stream(models.Model): class Stream(models.Model):
@@ -511,6 +544,7 @@ class Stream(models.Model):
return self.__repr__() return self.__repr__()
def is_public(self): def is_public(self):
# type: () -> bool
# All streams are private at MIT. # All streams are private at MIT.
return self.realm.domain != "mit.edu" and not self.invite_only return self.realm.domain != "mit.edu" and not self.invite_only
@@ -519,6 +553,7 @@ class Stream(models.Model):
@classmethod @classmethod
def create(cls, name, realm): def create(cls, name, realm):
# type: (Any, str, Realm) -> Tuple[Stream, Recipient]
stream = cls(name=name, realm=realm) stream = cls(name=name, realm=realm)
stream.save() stream.save()
@@ -527,6 +562,7 @@ class Stream(models.Model):
return (stream, recipient) return (stream, recipient)
def num_subscribers(self): def num_subscribers(self):
# type: () -> int
return Subscription.objects.filter( return Subscription.objects.filter(
recipient__type=Recipient.STREAM, recipient__type=Recipient.STREAM,
recipient__type_id=self.id, recipient__type_id=self.id,
@@ -536,6 +572,7 @@ class Stream(models.Model):
# This is stream information that is sent to clients # This is stream information that is sent to clients
def to_dict(self): def to_dict(self):
# type: () -> Dict[str, Any]
return dict(name=self.name, return dict(name=self.name,
stream_id=self.id, stream_id=self.id,
description=self.description, description=self.description,
@@ -545,6 +582,7 @@ post_save.connect(flush_stream, sender=Stream)
post_delete.connect(flush_stream, sender=Stream) post_delete.connect(flush_stream, sender=Stream)
def valid_stream_name(name): def valid_stream_name(name):
# type: (text_type) -> bool
return name != "" return name != ""
# The Recipient table is used to map Messages to the set of users who # The Recipient table is used to map Messages to the set of users who
@@ -572,6 +610,7 @@ class Recipient(models.Model):
HUDDLE: 'huddle' } HUDDLE: 'huddle' }
def type_name(self): def type_name(self):
# type: () -> str
# Raises KeyError if invalid # Raises KeyError if invalid
return self._type_names[self.type] return self._type_names[self.type]
@@ -587,22 +626,26 @@ class Client(models.Model):
get_client_cache = {} # type: Dict[str, Client] get_client_cache = {} # type: Dict[str, Client]
def get_client(name): def get_client(name):
# type: (str) -> Client
if name not in get_client_cache: if name not in get_client_cache:
result = get_client_remote_cache(name) result = get_client_remote_cache(name)
get_client_cache[name] = result get_client_cache[name] = result
return get_client_cache[name] return get_client_cache[name]
def get_client_cache_key(name): def get_client_cache_key(name):
# type: (str) -> str
return 'get_client:%s' % (make_safe_digest(name),) return 'get_client:%s' % (make_safe_digest(name),)
@cache_with_key(get_client_cache_key, timeout=3600*24*7) @cache_with_key(get_client_cache_key, timeout=3600*24*7)
def get_client_remote_cache(name): def get_client_remote_cache(name):
# type: (str) -> Client
(client, _) = Client.objects.get_or_create(name=name) (client, _) = Client.objects.get_or_create(name=name)
return client return client
# get_stream_backend takes either a realm id or a realm # get_stream_backend takes either a realm id or a realm
@cache_with_key(get_stream_cache_key, timeout=3600*24*7) @cache_with_key(get_stream_cache_key, timeout=3600*24*7)
def get_stream_backend(stream_name, realm): def get_stream_backend(stream_name, realm):
# type: (text_type, Realm) -> Stream
if isinstance(realm, Realm): if isinstance(realm, Realm):
realm_id = realm.id realm_id = realm.id
else: else:
@@ -611,6 +654,7 @@ def get_stream_backend(stream_name, realm):
name__iexact=stream_name.strip(), realm_id=realm_id) name__iexact=stream_name.strip(), realm_id=realm_id)
def get_active_streams(realm): def get_active_streams(realm):
# type: (Realm) -> QuerySet
""" """
Return all streams (including invite-only streams) that have not been deactivated. Return all streams (including invite-only streams) that have not been deactivated.
""" """
@@ -618,12 +662,14 @@ def get_active_streams(realm):
# get_stream takes either a realm id or a realm # get_stream takes either a realm id or a realm
def get_stream(stream_name, realm): def get_stream(stream_name, realm):
# type: (text_type, Union[int, Realm]) -> Optional[Stream]
try: try:
return get_stream_backend(stream_name, realm) return get_stream_backend(stream_name, realm)
except Stream.DoesNotExist: except Stream.DoesNotExist:
return None return None
def bulk_get_streams(realm, stream_names): def bulk_get_streams(realm, stream_names):
# type: (Realm, STREAM_NAMES) -> Dict[text_type, Any]
if isinstance(realm, Realm): if isinstance(realm, Realm):
realm_id = realm.id realm_id = realm.id
else: else:
@@ -651,13 +697,16 @@ def bulk_get_streams(realm, stream_names):
id_fetcher=lambda stream: stream.name.lower()) id_fetcher=lambda stream: stream.name.lower())
def get_recipient_cache_key(type, type_id): def get_recipient_cache_key(type, type_id):
# type: (int, int) -> str
return "get_recipient:%s:%s" % (type, type_id,) return "get_recipient:%s:%s" % (type, type_id,)
@cache_with_key(get_recipient_cache_key, timeout=3600*24*7) @cache_with_key(get_recipient_cache_key, timeout=3600*24*7)
def get_recipient(type, type_id): def get_recipient(type, type_id):
# type: (int, int) -> Recipient
return Recipient.objects.get(type_id=type_id, type=type) return Recipient.objects.get(type_id=type_id, type=type)
def bulk_get_recipients(type, type_ids): def bulk_get_recipients(type, type_ids):
# type: (int, List[int]) -> Dict[int, Any]
def cache_key_function(type_id): def cache_key_function(type_id):
return get_recipient_cache_key(type, type_id) return get_recipient_cache_key(type, type_id)
def query_function(type_ids): def query_function(type_ids):
@@ -668,18 +717,23 @@ def bulk_get_recipients(type, type_ids):
# NB: This function is currently unused, but may come in handy. # NB: This function is currently unused, but may come in handy.
def linebreak(string): def linebreak(string):
# type: (str) -> str
return string.replace('\n\n', '<p/>').replace('\n', '<br/>') return string.replace('\n\n', '<p/>').replace('\n', '<br/>')
def extract_message_dict(message_str): def extract_message_dict(message_str):
# type: (str) -> Dict[str, Any]
return ujson.loads(zlib.decompress(message_str).decode("utf-8")) return ujson.loads(zlib.decompress(message_str).decode("utf-8"))
def stringify_message_dict(message_dict): def stringify_message_dict(message_dict):
# type: (Dict[Any, Any]) -> str
return zlib.compress(ujson.dumps(message_dict).encode("utf-8")) return zlib.compress(ujson.dumps(message_dict).encode("utf-8"))
def to_dict_cache_key_id(message_id, apply_markdown): def to_dict_cache_key_id(message_id, apply_markdown):
# type: (int, bool) -> str
return 'message_dict:%d:%d' % (message_id, apply_markdown) return 'message_dict:%d:%d' % (message_id, apply_markdown)
def to_dict_cache_key(message, apply_markdown): def to_dict_cache_key(message, apply_markdown):
# type: (Message, bool) -> str
return to_dict_cache_key_id(message.id, apply_markdown) return to_dict_cache_key_id(message.id, apply_markdown)
class Message(models.Model): class Message(models.Model):
@@ -705,9 +759,11 @@ class Message(models.Model):
return self.__repr__() return self.__repr__()
def get_realm(self): def get_realm(self):
# type: () -> Realm
return self.sender.realm return self.sender.realm
def render_markdown(self, content, domain=None): def render_markdown(self, content, domain=None):
# type: (str, Optional[str]) -> str
"""Return HTML for given markdown. Bugdown may add properties to the """Return HTML for given markdown. Bugdown may add properties to the
message object such as `mentions_user_ids` and `mentions_wildcard`. message object such as `mentions_user_ids` and `mentions_wildcard`.
These are only on this Django object and are not saved in the These are only on this Django object and are not saved in the
@@ -739,6 +795,7 @@ class Message(models.Model):
return rendered_content return rendered_content
def set_rendered_content(self, rendered_content, save = False): def set_rendered_content(self, rendered_content, save = False):
# type: (str, bool) -> bool
"""Set the content on the message. """Set the content on the message.
""" """
global bugdown global bugdown
@@ -756,9 +813,11 @@ class Message(models.Model):
return False return False
def save_rendered_content(self): def save_rendered_content(self):
# type: () -> None
self.save(update_fields=["rendered_content", "rendered_content_version"]) self.save(update_fields=["rendered_content", "rendered_content_version"])
def maybe_render_content(self, domain, save = False): def maybe_render_content(self, domain, save = False):
# type: (str, bool) -> bool
"""Render the markdown if there is no existing rendered_content""" """Render the markdown if there is no existing rendered_content"""
global bugdown global bugdown
if bugdown is None: if bugdown is None:
@@ -771,16 +830,20 @@ class Message(models.Model):
@staticmethod @staticmethod
def need_to_render_content(rendered_content, rendered_content_version): def need_to_render_content(rendered_content, rendered_content_version):
# type: (str, int) -> bool
return rendered_content is None or rendered_content_version < bugdown.version return rendered_content is None or rendered_content_version < bugdown.version
def to_dict(self, apply_markdown): def to_dict(self, apply_markdown):
# type: (bool) -> Dict[str, Any]
return extract_message_dict(self.to_dict_json(apply_markdown)) return extract_message_dict(self.to_dict_json(apply_markdown))
@cache_with_key(to_dict_cache_key, timeout=3600*24) @cache_with_key(to_dict_cache_key, timeout=3600*24)
def to_dict_json(self, apply_markdown): def to_dict_json(self, apply_markdown):
# type: (bool) -> str
return stringify_message_dict(self.to_dict_uncached(apply_markdown)) return stringify_message_dict(self.to_dict_uncached(apply_markdown))
def to_dict_uncached(self, apply_markdown): def to_dict_uncached(self, apply_markdown):
# type: (bool) -> Dict[str, Any]
return Message.build_message_dict( return Message.build_message_dict(
apply_markdown = apply_markdown, apply_markdown = apply_markdown,
message = self, message = self,
@@ -807,6 +870,7 @@ class Message(models.Model):
@staticmethod @staticmethod
def build_dict_from_raw_db_row(row, apply_markdown): def build_dict_from_raw_db_row(row, apply_markdown):
# type: (Dict[str, Any], bool) -> Dict[str, Any]
''' '''
row is a row from a .values() call, and it needs to have row is a row from a .values() call, and it needs to have
all the relevant fields populated all the relevant fields populated
@@ -859,6 +923,7 @@ class Message(models.Model):
recipient_type, recipient_type,
recipient_type_id, recipient_type_id,
): ):
# type: (bool, Message, int, int, str, str, str, int, str, int, int, str, str, str, str, str, bool, str, int, int, int) -> Dict[str, Any]
global bugdown global bugdown
if bugdown is None: if bugdown is None:
from zerver.lib import bugdown from zerver.lib import bugdown
@@ -943,6 +1008,7 @@ class Message(models.Model):
return obj return obj
def to_log_dict(self): def to_log_dict(self):
# type: () -> Dict[str, Any]
return dict( return dict(
id = self.id, id = self.id,
sender_id = self.sender.id, sender_id = self.sender.id,
@@ -959,6 +1025,7 @@ class Message(models.Model):
@staticmethod @staticmethod
def get_raw_db_rows(needed_ids): def get_raw_db_rows(needed_ids):
# type: (List[int]) -> List[Dict[str, Any]]
# This is a special purpose function optimized for # This is a special purpose function optimized for
# callers like get_old_messages_backend(). # callers like get_old_messages_backend().
fields = [ fields = [
@@ -987,10 +1054,12 @@ class Message(models.Model):
@classmethod @classmethod
def remove_unreachable(cls): def remove_unreachable(cls):
# type: (Any) -> None
"""Remove all Messages that are not referred to by any UserMessage.""" """Remove all Messages that are not referred to by any UserMessage."""
cls.objects.exclude(id__in = UserMessage.objects.values('message_id')).delete() cls.objects.exclude(id__in = UserMessage.objects.values('message_id')).delete()
def sent_by_human(self): def sent_by_human(self):
# type: () -> bool
sending_client = self.sending_client.name.lower() sending_client = self.sending_client.name.lower()
return (sending_client in ('zulipandroid', 'zulipios', 'zulipdesktop', return (sending_client in ('zulipandroid', 'zulipios', 'zulipdesktop',
@@ -999,17 +1068,21 @@ class Message(models.Model):
@staticmethod @staticmethod
def content_has_attachment(content): def content_has_attachment(content):
# type: (text_type) -> Match
return re.search('[/\-]user[\-_]uploads[/\.-]', content) return re.search('[/\-]user[\-_]uploads[/\.-]', content)
@staticmethod @staticmethod
def content_has_image(content): def content_has_image(content):
# type: (text_type) -> bool
return bool(re.search('[/\-]user[\-_]uploads[/\.-]\S+\.(bmp|gif|jpg|jpeg|png|webp)', content, re.IGNORECASE)) return bool(re.search('[/\-]user[\-_]uploads[/\.-]\S+\.(bmp|gif|jpg|jpeg|png|webp)', content, re.IGNORECASE))
@staticmethod @staticmethod
def content_has_link(content): def content_has_link(content):
# type: (text_type) -> bool
return 'http://' in content or 'https://' in content or '/user_uploads' in content return 'http://' in content or 'https://' in content or '/user_uploads' in content
def update_calculated_fields(self): def update_calculated_fields(self):
# type: () -> None
# TODO: rendered_content could also be considered a calculated field # TODO: rendered_content could also be considered a calculated field
content = self.content content = self.content
self.has_attachment = bool(Message.content_has_attachment(content)) self.has_attachment = bool(Message.content_has_attachment(content))
@@ -1018,11 +1091,13 @@ class Message(models.Model):
@receiver(pre_save, sender=Message) @receiver(pre_save, sender=Message)
def pre_save_message(sender, **kwargs): def pre_save_message(sender, **kwargs):
# type: (Any, **Any) -> None
if kwargs['update_fields'] is None or "content" in kwargs['update_fields']: if kwargs['update_fields'] is None or "content" in kwargs['update_fields']:
message = kwargs['instance'] message = kwargs['instance']
message.update_calculated_fields() message.update_calculated_fields()
def get_context_for_message(message): def get_context_for_message(message):
# type: (Message) -> List[Message]
return Message.objects.filter( return Message.objects.filter(
recipient_id=message.recipient_id, recipient_id=message.recipient_id,
subject=message.subject, subject=message.subject,
@@ -1063,9 +1138,11 @@ class UserMessage(models.Model):
return (u"<UserMessage: %s / %s (%s)>" % (display_recipient, self.user_profile.email, self.flags_list())).encode("utf-8") return (u"<UserMessage: %s / %s (%s)>" % (display_recipient, self.user_profile.email, self.flags_list())).encode("utf-8")
def flags_list(self): def flags_list(self):
# type: () -> List[str]
return [flag for flag in self.flags.keys() if getattr(self.flags, flag).is_set] return [flag for flag in self.flags.keys() if getattr(self.flags, flag).is_set]
def parse_usermessage_flags(val): def parse_usermessage_flags(val):
# type: (int) -> List[str]
flags = [] flags = []
mask = 1 mask = 1
for flag in UserMessage.ALL_FLAGS: for flag in UserMessage.ALL_FLAGS:
@@ -1089,20 +1166,25 @@ class Attachment(models.Model):
return (u"<Attachment: %s>" % (self.file_name)) return (u"<Attachment: %s>" % (self.file_name))
def is_claimed(self): def is_claimed(self):
# type: () -> bool
return self.messages.count() > 0 return self.messages.count() > 0
def get_url(self): def get_url(self):
# type: () -> str
return "/user_uploads/%s" % (self.path_id) return "/user_uploads/%s" % (self.path_id)
def get_attachments_by_owner_id(uid): def get_attachments_by_owner_id(uid):
# type: (int) -> List[Attachment]
return Attachment.objects.filter(owner=uid).select_related('owner') return Attachment.objects.filter(owner=uid).select_related('owner')
def get_owners_from_file_name(file_name): def get_owners_from_file_name(file_name):
# type: (str) -> List[Attachment]
# The returned vaule will list of owners since different users can upload # The returned vaule will list of owners since different users can upload
# same files with the same filename. # same files with the same filename.
return Attachment.objects.filter(file_name=file_name).select_related('owner') return Attachment.objects.filter(file_name=file_name).select_related('owner')
def get_old_unclaimed_attachments(weeks_ago): def get_old_unclaimed_attachments(weeks_ago):
# type: (int) -> List[Attachment]
delta_weeks_ago = timezone.now() - datetime.timedelta(weeks=weeks_ago) delta_weeks_ago = timezone.now() - datetime.timedelta(weeks=weeks_ago)
old_attachments = Attachment.objects.filter(messages=None, create_time__lt=delta_weeks_ago) old_attachments = Attachment.objects.filter(messages=None, create_time__lt=delta_weeks_ago)
return old_attachments return old_attachments
@@ -1133,23 +1215,28 @@ class Subscription(models.Model):
@cache_with_key(user_profile_by_id_cache_key, timeout=3600*24*7) @cache_with_key(user_profile_by_id_cache_key, timeout=3600*24*7)
def get_user_profile_by_id(uid): def get_user_profile_by_id(uid):
# type: (int) -> UserProfile
return UserProfile.objects.select_related().get(id=uid) return UserProfile.objects.select_related().get(id=uid)
@cache_with_key(user_profile_by_email_cache_key, timeout=3600*24*7) @cache_with_key(user_profile_by_email_cache_key, timeout=3600*24*7)
def get_user_profile_by_email(email): def get_user_profile_by_email(email):
# type: (text_type) -> UserProfile
return UserProfile.objects.select_related().get(email__iexact=email.strip()) return UserProfile.objects.select_related().get(email__iexact=email.strip())
@cache_with_key(active_user_dicts_in_realm_cache_key, timeout=3600*24*7) @cache_with_key(active_user_dicts_in_realm_cache_key, timeout=3600*24*7)
def get_active_user_dicts_in_realm(realm): def get_active_user_dicts_in_realm(realm):
# type: (Realm) -> List[Dict[str, Any]]
return UserProfile.objects.filter(realm=realm, is_active=True) \ return UserProfile.objects.filter(realm=realm, is_active=True) \
.values(*active_user_dict_fields) .values(*active_user_dict_fields)
@cache_with_key(active_bot_dicts_in_realm_cache_key, timeout=3600*24*7) @cache_with_key(active_bot_dicts_in_realm_cache_key, timeout=3600*24*7)
def get_active_bot_dicts_in_realm(realm): def get_active_bot_dicts_in_realm(realm):
# type: (Realm) -> List[Dict[str, Any]]
return UserProfile.objects.filter(realm=realm, is_active=True, is_bot=True) \ return UserProfile.objects.filter(realm=realm, is_active=True, is_bot=True) \
.values(*active_bot_dict_fields) .values(*active_bot_dict_fields)
def get_owned_bot_dicts(user_profile, include_all_realm_bots_if_admin=True): def get_owned_bot_dicts(user_profile, include_all_realm_bots_if_admin=True):
# type: (UserProfile, bool) -> List[Dict[str, Any]]
if user_profile.is_realm_admin and include_all_realm_bots_if_admin: if user_profile.is_realm_admin and include_all_realm_bots_if_admin:
result = get_active_bot_dicts_in_realm(user_profile.realm) result = get_active_bot_dicts_in_realm(user_profile.realm)
else: else:
@@ -1167,6 +1254,7 @@ def get_owned_bot_dicts(user_profile, include_all_realm_bots_if_admin=True):
for botdict in result] for botdict in result]
def get_prereg_user_by_email(email): def get_prereg_user_by_email(email):
# type: (str) -> PreregistrationUser
# A user can be invited many times, so only return the result of the latest # A user can be invited many times, so only return the result of the latest
# invite. # invite.
return PreregistrationUser.objects.filter(email__iexact=email.strip()).latest("invited_at") return PreregistrationUser.objects.filter(email__iexact=email.strip()).latest("invited_at")
@@ -1183,19 +1271,23 @@ class Huddle(models.Model):
huddle_hash = models.CharField(max_length=40, db_index=True, unique=True) huddle_hash = models.CharField(max_length=40, db_index=True, unique=True)
def get_huddle_hash(id_list): def get_huddle_hash(id_list):
# type: (List[int]) -> str
id_list = sorted(set(id_list)) id_list = sorted(set(id_list))
hash_key = ",".join(str(x) for x in id_list) hash_key = ",".join(str(x) for x in id_list)
return make_safe_digest(hash_key) return make_safe_digest(hash_key)
def huddle_hash_cache_key(huddle_hash): def huddle_hash_cache_key(huddle_hash):
# type: (str) -> str
return "huddle_by_hash:%s" % (huddle_hash,) return "huddle_by_hash:%s" % (huddle_hash,)
def get_huddle(id_list): def get_huddle(id_list):
# type: (List[int]) -> Huddle
huddle_hash = get_huddle_hash(id_list) huddle_hash = get_huddle_hash(id_list)
return get_huddle_backend(huddle_hash, id_list) return get_huddle_backend(huddle_hash, id_list)
@cache_with_key(lambda huddle_hash, id_list: huddle_hash_cache_key(huddle_hash), timeout=3600*24*7) @cache_with_key(lambda huddle_hash, id_list: huddle_hash_cache_key(huddle_hash), timeout=3600*24*7)
def get_huddle_backend(huddle_hash, id_list): def get_huddle_backend(huddle_hash, id_list):
# type: (str, List[int]) -> Huddle
(huddle, created) = Huddle.objects.get_or_create(huddle_hash=huddle_hash) (huddle, created) = Huddle.objects.get_or_create(huddle_hash=huddle_hash)
if created: if created:
with transaction.atomic(): with transaction.atomic():
@@ -1208,6 +1300,7 @@ def get_huddle_backend(huddle_hash, id_list):
return huddle return huddle
def get_realm(domain): def get_realm(domain):
# type: (text_type) -> Optional[Realm]
if not domain: if not domain:
return None return None
try: try:
@@ -1216,6 +1309,7 @@ def get_realm(domain):
return None return None
def clear_database(): def clear_database():
# type: () -> None
pylibmc.Client(['127.0.0.1']).flush_all() pylibmc.Client(['127.0.0.1']).flush_all()
model = None # type: Any model = None # type: Any
for model in [Message, Stream, UserProfile, Recipient, for model in [Message, Stream, UserProfile, Recipient,
@@ -1253,6 +1347,7 @@ class UserPresence(models.Model):
@staticmethod @staticmethod
def status_to_string(status): def status_to_string(status):
# type: (int) -> str
if status == UserPresence.ACTIVE: if status == UserPresence.ACTIVE:
return 'active' return 'active'
elif status == UserPresence.IDLE: elif status == UserPresence.IDLE:
@@ -1260,8 +1355,8 @@ class UserPresence(models.Model):
@staticmethod @staticmethod
def get_status_dict_by_realm(realm_id): def get_status_dict_by_realm(realm_id):
# type: (Any) -> Any # type: (int) -> defaultdict[Any, Dict[Any, Any]]
user_statuses = defaultdict(dict) # type: Dict[Any, Dict[Any, Any]] user_statuses = defaultdict(dict) # type: defaultdict[Any, Dict[Any, Any]]
query = UserPresence.objects.filter( query = UserPresence.objects.filter(
user_profile__realm_id=realm_id, user_profile__realm_id=realm_id,
@@ -1300,6 +1395,7 @@ class UserPresence(models.Model):
@staticmethod @staticmethod
def to_presence_dict(client_name=None, status=None, timestamp=None, push_enabled=None, def to_presence_dict(client_name=None, status=None, timestamp=None, push_enabled=None,
has_push_devices=None, is_mirror_dummy=None): has_push_devices=None, is_mirror_dummy=None):
# type: (Optional[str], Optional[int], Optional[int], Optional[bool], Optional[bool], Optional[bool]) -> Dict[str, Any]
presence_val = UserPresence.status_to_string(status) presence_val = UserPresence.status_to_string(status)
timestamp = datetime_to_timestamp(timestamp) timestamp = datetime_to_timestamp(timestamp)
return dict( return dict(
@@ -1310,6 +1406,7 @@ class UserPresence(models.Model):
) )
def to_dict(self): def to_dict(self):
# type: () -> Dict[str, Any]
return UserPresence.to_presence_dict( return UserPresence.to_presence_dict(
client_name=self.client.name, client_name=self.client.name,
status=self.status, status=self.status,
@@ -1318,6 +1415,7 @@ class UserPresence(models.Model):
@staticmethod @staticmethod
def status_from_string(status): def status_from_string(status):
# type: (str) -> Optional[int]
if status == 'active': if status == 'active':
status_val = UserPresence.ACTIVE status_val = UserPresence.ACTIVE
elif status == 'idle': elif status == 'idle':

View File

@@ -346,7 +346,7 @@ def restore_saved_messages():
sender_email = old_message["sender_email"] sender_email = old_message["sender_email"]
domain = split_email_to_domain(sender_email) domain = str(split_email_to_domain(sender_email))
realm_set.add(domain) realm_set.add(domain)
if old_message["sender_email"] not in email_set: if old_message["sender_email"] not in email_set: