zerver/lib: Use python 3 syntax for typing.

Split by tabbott from a larger commit; this covers a batch of files
with no open PRs touching them.
This commit is contained in:
rht
2017-11-05 11:15:10 +01:00
committed by Tim Abbott
parent 73d0f1db81
commit 561ba33f69
15 changed files with 119 additions and 206 deletions

View File

@@ -6,20 +6,17 @@ import ujson
from typing import Dict, Iterable, List, Text
@cache_with_key(realm_alert_words_cache_key, timeout=3600*24)
def alert_words_in_realm(realm):
# type: (Realm) -> Dict[int, List[Text]]
def alert_words_in_realm(realm: Realm) -> Dict[int, List[Text]]:
users_query = UserProfile.objects.filter(realm=realm, is_active=True)
alert_word_data = users_query.filter(~Q(alert_words=ujson.dumps([]))).values('id', 'alert_words')
all_user_words = dict((elt['id'], ujson.loads(elt['alert_words'])) for elt in alert_word_data)
user_ids_with_words = dict((user_id, w) for (user_id, w) in all_user_words.items() if len(w))
return user_ids_with_words
def user_alert_words(user_profile):
# type: (UserProfile) -> List[Text]
def user_alert_words(user_profile: UserProfile) -> List[Text]:
return ujson.loads(user_profile.alert_words)
def add_user_alert_words(user_profile, alert_words):
# type: (UserProfile, Iterable[Text]) -> List[Text]
def add_user_alert_words(user_profile: UserProfile, alert_words: Iterable[Text]) -> List[Text]:
words = user_alert_words(user_profile)
new_words = [w for w in alert_words if w not in words]
@@ -29,8 +26,7 @@ def add_user_alert_words(user_profile, alert_words):
return words
def remove_user_alert_words(user_profile, alert_words):
# type: (UserProfile, Iterable[Text]) -> List[Text]
def remove_user_alert_words(user_profile: UserProfile, alert_words: Iterable[Text]) -> List[Text]:
words = user_alert_words(user_profile)
words = [w for w in words if w not in alert_words]
@@ -38,7 +34,6 @@ def remove_user_alert_words(user_profile, alert_words):
return words
def set_user_alert_words(user_profile, alert_words):
# type: (UserProfile, List[Text]) -> None
def set_user_alert_words(user_profile: UserProfile, alert_words: List[Text]) -> None:
user_profile.alert_words = ujson.dumps(alert_words)
user_profile.save(update_fields=['alert_words'])

View File

@@ -10,8 +10,7 @@ from zerver.lib.upload import upload_backend, MEDIUM_AVATAR_SIZE
from zerver.models import UserProfile
import urllib
def avatar_url(user_profile, medium=False, client_gravatar=False):
# type: (UserProfile, bool, bool) -> Text
def avatar_url(user_profile: UserProfile, medium: bool=False, client_gravatar: bool=False) -> Text:
return get_avatar_field(
user_id=user_profile.id,
@@ -23,8 +22,7 @@ def avatar_url(user_profile, medium=False, client_gravatar=False):
client_gravatar=client_gravatar,
)
def avatar_url_from_dict(userdict, medium=False):
# type: (Dict[str, Any], bool) -> Text
def avatar_url_from_dict(userdict: Dict[str, Any], medium: bool=False) -> Text:
'''
DEPRECATED: We should start using
get_avatar_field to populate users,
@@ -91,30 +89,30 @@ def get_avatar_field(user_id,
url += '&version=%d' % (avatar_version,)
return url
def get_gravatar_url(email, avatar_version, medium=False):
# type: (Text, int, bool) -> Text
def get_gravatar_url(email: Text, avatar_version: int, medium: bool=False) -> Text:
url = _get_unversioned_gravatar_url(email, medium)
url += '&version=%d' % (avatar_version,)
return url
def _get_unversioned_gravatar_url(email, medium):
# type: (Text, bool) -> Text
def _get_unversioned_gravatar_url(email: Text, medium: bool) -> Text:
if settings.ENABLE_GRAVATAR:
gravitar_query_suffix = "&s=%s" % (MEDIUM_AVATAR_SIZE,) if medium else ""
hash_key = gravatar_hash(email)
return "https://secure.gravatar.com/avatar/%s?d=identicon%s" % (hash_key, gravitar_query_suffix)
return settings.DEFAULT_AVATAR_URI+'?x=x'
def _get_unversioned_avatar_url(user_profile_id, avatar_source, realm_id, email=None, medium=False):
# type: (int, Text, int, Optional[Text], bool) -> Text
def _get_unversioned_avatar_url(user_profile_id: int,
avatar_source: Text,
realm_id: int,
email: Optional[Text]=None,
medium: bool=False) -> Text:
if avatar_source == 'U':
hash_key = user_avatar_path_from_ids(user_profile_id, realm_id)
return upload_backend.get_avatar_url(hash_key, medium=medium)
assert email is not None
return _get_unversioned_gravatar_url(email, medium)
def absolute_avatar_url(user_profile):
# type: (UserProfile) -> Text
def absolute_avatar_url(user_profile: UserProfile) -> Text:
"""Absolute URLs are used to simplify logic for applications that
won't be served by browsers, such as rendering GCM notifications."""
return urllib.parse.urljoin(user_profile.realm.uri, avatar_url(user_profile))

View File

@@ -10,8 +10,7 @@ if False:
import hashlib
def gravatar_hash(email):
# type: (Text) -> Text
def gravatar_hash(email: Text) -> Text:
"""Compute the Gravatar hash for an email address."""
# Non-ASCII characters aren't permitted by the currently active e-mail
# RFCs. However, the IETF has published https://tools.ietf.org/html/rfc4952,
@@ -20,8 +19,7 @@ def gravatar_hash(email):
# not error out on it.
return make_safe_digest(email.lower(), hashlib.md5)
def user_avatar_hash(uid):
# type: (Text) -> Text
def user_avatar_hash(uid: Text) -> Text:
# WARNING: If this method is changed, you may need to do a migration
# similar to zerver/migrations/0060_move_avatars_to_be_uid_based.py .
@@ -39,7 +37,6 @@ def user_avatar_path(user_profile):
# similar to zerver/migrations/0060_move_avatars_to_be_uid_based.py .
return user_avatar_path_from_ids(user_profile.id, user_profile.realm_id)
def user_avatar_path_from_ids(user_profile_id, realm_id):
# type: (int, int) -> Text
def user_avatar_path_from_ids(user_profile_id: int, realm_id: int) -> Text:
user_id_hash = user_avatar_hash(str(user_profile_id))
return '%s/%s' % (str(realm_id), user_id_hash)

View File

@@ -5,8 +5,12 @@ from zerver.models import Realm, Stream, UserProfile, Huddle, \
Subscription, Recipient, Client, RealmAuditLog, get_huddle_hash
from zerver.lib.create_user import create_user_profile
def bulk_create_users(realm, users_raw, bot_type=None, bot_owner=None, tos_version=None, timezone=""):
# type: (Realm, Set[Tuple[Text, Text, Text, bool]], Optional[int], Optional[UserProfile], Optional[Text], Text) -> None
def bulk_create_users(realm: Realm,
users_raw: Set[Tuple[Text, Text, Text, bool]],
bot_type: Optional[int]=None,
bot_owner: Optional[UserProfile]=None,
tos_version: Optional[Text]=None,
timezone: Text="") -> None:
"""
Creates and saves a UserProfile with the given email.
Has some code based off of UserManage.create_user, but doesn't .save()
@@ -54,8 +58,8 @@ def bulk_create_users(realm, users_raw, bot_type=None, bot_owner=None, tos_versi
recipient=recipients_by_email[email]))
Subscription.objects.bulk_create(subscriptions_to_create)
def bulk_create_streams(realm, stream_dict):
# type: (Realm, Dict[Text, Dict[Text, Any]]) -> None
def bulk_create_streams(realm: Realm,
stream_dict: Dict[Text, Dict[Text, Any]]) -> None:
existing_streams = frozenset([name.lower() for name in
Stream.objects.filter(realm=realm)
.values_list('name', flat=True)])
@@ -87,8 +91,7 @@ def bulk_create_streams(realm, stream_dict):
type=Recipient.STREAM))
Recipient.objects.bulk_create(recipients_to_create)
def bulk_create_clients(client_list):
# type: (Iterable[Text]) -> None
def bulk_create_clients(client_list: Iterable[Text]) -> None:
existing_clients = set(client.name for client in Client.objects.select_related().all()) # type: Set[Text]
clients_to_create = [] # type: List[Client]
@@ -98,8 +101,7 @@ def bulk_create_clients(client_list):
existing_clients.add(name)
Client.objects.bulk_create(clients_to_create)
def bulk_create_huddles(users, huddle_user_list):
# type: (Dict[Text, UserProfile], Iterable[Iterable[Text]]) -> None
def bulk_create_huddles(users: Dict[Text, UserProfile], huddle_user_list: Iterable[Iterable[Text]]) -> None:
huddles = {} # type: Dict[Text, Huddle]
huddles_by_id = {} # type: Dict[int, Huddle]
huddle_set = set() # type: Set[Tuple[Text, Tuple[int, ...]]]

View File

@@ -7,8 +7,7 @@ from typing import Text
# Encodes the provided URL using the same algorithm used by the camo
# caching https image proxy
def get_camo_url(url):
# type: (Text) -> Text
def get_camo_url(url: Text) -> Text:
# Only encode the url if Camo is enabled
if settings.CAMO_URI == '':
return url

View File

@@ -36,8 +36,7 @@ import struct
# there is already an ASN.1 implementation, but in the interest of
# limiting MIT Kerberos's exposure to malformed ccaches, encode it
# ourselves. To that end, here's the laziest DER encoder ever.
def der_encode_length(length):
# type: (int) -> bytes
def der_encode_length(length: int) -> bytes:
if length <= 127:
return struct.pack('!B', length)
out = b""
@@ -47,12 +46,10 @@ def der_encode_length(length):
out = struct.pack('!B', len(out) | 0x80) + out
return out
def der_encode_tlv(tag, value):
# type: (int, bytes) -> bytes
def der_encode_tlv(tag: int, value: bytes) -> bytes:
return struct.pack('!B', tag) + der_encode_length(len(value)) + value
def der_encode_integer_value(val):
# type: (int) -> bytes
def der_encode_integer_value(val: int) -> bytes:
if not isinstance(val, int):
raise TypeError("int")
# base 256, MSB first, two's complement, minimum number of octets
@@ -74,34 +71,28 @@ def der_encode_integer_value(val):
val >>= 8
return out
def der_encode_integer(val):
# type: (int) -> bytes
def der_encode_integer(val: int) -> bytes:
return der_encode_tlv(0x02, der_encode_integer_value(val))
def der_encode_int32(val):
# type: (int) -> bytes
def der_encode_int32(val: int) -> bytes:
if val < -2147483648 or val > 2147483647:
raise ValueError("Bad value")
return der_encode_integer(val)
def der_encode_uint32(val):
# type: (int) -> bytes
def der_encode_uint32(val: int) -> bytes:
if val < 0 or val > 4294967295:
raise ValueError("Bad value")
return der_encode_integer(val)
def der_encode_string(val):
# type: (Text) -> bytes
def der_encode_string(val: Text) -> bytes:
if not isinstance(val, Text):
raise TypeError("unicode")
return der_encode_tlv(0x1b, val.encode("utf-8"))
def der_encode_octet_string(val):
# type: (bytes) -> bytes
def der_encode_octet_string(val: bytes) -> bytes:
if not isinstance(val, bytes):
raise TypeError("bytes")
return der_encode_tlv(0x04, val)
def der_encode_sequence(tlvs, tagged=True):
# type: (List[Optional[bytes]], Optional[bool]) -> bytes
def der_encode_sequence(tlvs: List[Optional[bytes]], tagged: Optional[bool]=True) -> bytes:
body = []
for i, tlv in enumerate(tlvs):
# Missing optional elements represented as None.
@@ -113,8 +104,7 @@ def der_encode_sequence(tlvs, tagged=True):
body.append(tlv)
return der_encode_tlv(0x30, b"".join(body))
def der_encode_ticket(tkt):
# type: (Dict[str, Any]) -> bytes
def der_encode_ticket(tkt: Dict[str, Any]) -> bytes:
return der_encode_tlv(
0x61, # Ticket
der_encode_sequence(
@@ -136,34 +126,29 @@ def der_encode_ticket(tkt):
# Kerberos ccache writing code. Using format documentation from here:
# http://www.gnu.org/software/shishi/manual/html_node/The-Credential-Cache-Binary-File-Format.html
def ccache_counted_octet_string(data):
# type: (bytes) -> bytes
def ccache_counted_octet_string(data: bytes) -> bytes:
if not isinstance(data, bytes):
raise TypeError("bytes")
return struct.pack("!I", len(data)) + data
def ccache_principal(name, realm):
# type: (Dict[str, str], str) -> bytes
def ccache_principal(name: Dict[str, str], realm: str) -> bytes:
header = struct.pack("!II", name["nameType"], len(name["nameString"]))
return (header + ccache_counted_octet_string(force_bytes(realm)) +
b"".join(ccache_counted_octet_string(force_bytes(c))
for c in name["nameString"]))
def ccache_key(key):
# type: (Dict[str, str]) -> bytes
def ccache_key(key: Dict[str, str]) -> bytes:
return (struct.pack("!H", key["keytype"]) +
ccache_counted_octet_string(base64.b64decode(key["keyvalue"])))
def flags_to_uint32(flags):
# type: (List[str]) -> int
def flags_to_uint32(flags: List[str]) -> int:
ret = 0
for i, v in enumerate(flags):
if v:
ret |= 1 << (31 - i)
return ret
def ccache_credential(cred):
# type: (Dict[str, Any]) -> bytes
def ccache_credential(cred: Dict[str, Any]) -> bytes:
out = ccache_principal(cred["cname"], cred["crealm"])
out += ccache_principal(cred["sname"], cred["srealm"])
out += ccache_key(cred["key"])
@@ -181,8 +166,7 @@ def ccache_credential(cred):
out += ccache_counted_octet_string(b"")
return out
def make_ccache(cred):
# type: (Dict[str, Any]) -> bytes
def make_ccache(cred: Dict[str, Any]) -> bytes:
# Do we need a DeltaTime header? The ccache I get just puts zero
# in there, so do the same.
out = struct.pack("!HHHHII",

View File

@@ -8,8 +8,7 @@ from contextlib import contextmanager
from typing import Iterator, IO, Any, Union
@contextmanager
def flock(lockfile, shared=False):
# type: (Union[int, IO[Any]], bool) -> Iterator[None]
def flock(lockfile: Union[int, IO[Any]], shared: bool=False) -> Iterator[None]:
"""Lock a file object using flock(2) for the duration of a 'with' statement.
If shared is True, use a LOCK_SH lock, otherwise LOCK_EX."""
@@ -21,8 +20,7 @@ def flock(lockfile, shared=False):
fcntl.flock(lockfile, fcntl.LOCK_UN)
@contextmanager
def lockfile(filename, shared=False):
# type: (str, bool) -> Iterator[None]
def lockfile(filename: str, shared: bool=False) -> Iterator[None]:
"""Lock a file using flock(2) for the duration of a 'with' statement.
If shared is True, use a LOCK_SH lock, otherwise LOCK_EX.

View File

@@ -11,8 +11,10 @@ ParamsT = Union[Iterable[Any], Mapping[Text, Any]]
# Similar to the tracking done in Django's CursorDebugWrapper, but done at the
# psycopg2 cursor level so it works with SQLAlchemy.
def wrapper_execute(self, action, sql, params=()):
# type: (CursorObj, Callable[[NonBinaryStr, Optional[ParamsT]], CursorObj], NonBinaryStr, Optional[ParamsT]) -> CursorObj
def wrapper_execute(self: CursorObj,
action: Callable[[NonBinaryStr, Optional[ParamsT]], CursorObj],
sql: NonBinaryStr,
params: Optional[ParamsT]=()) -> CursorObj:
start = time.time()
try:
return action(sql, params)
@@ -26,29 +28,26 @@ def wrapper_execute(self, action, sql, params=()):
class TimeTrackingCursor(cursor):
"""A psycopg2 cursor class that tracks the time spent executing queries."""
def execute(self, query, vars=None):
# type: (NonBinaryStr, Optional[ParamsT]) -> TimeTrackingCursor
def execute(self, query: NonBinaryStr,
vars: Optional[ParamsT]=None) -> 'TimeTrackingCursor':
return wrapper_execute(self, super().execute, query, vars)
def executemany(self, query, vars):
# type: (NonBinaryStr, Iterable[Any]) -> TimeTrackingCursor
def executemany(self, query: NonBinaryStr,
vars: Iterable[Any]) -> 'TimeTrackingCursor':
return wrapper_execute(self, super().executemany, query, vars)
class TimeTrackingConnection(connection):
"""A psycopg2 connection class that uses TimeTrackingCursors."""
def __init__(self, *args, **kwargs):
# type: (*Any, **Any) -> None
def __init__(self, *args: Any, **kwargs: Any) -> None:
self.queries = [] # type: List[Dict[str, str]]
super().__init__(*args, **kwargs)
def cursor(self, *args, **kwargs):
# type: (*Any, **Any) -> TimeTrackingCursor
def cursor(self, *args: Any, **kwargs: Any) -> TimeTrackingCursor:
kwargs.setdefault('cursor_factory', TimeTrackingCursor)
return connection.cursor(self, *args, **kwargs)
def reset_queries():
# type: () -> None
def reset_queries() -> None:
from django.db import connections
for conn in connections.all():
if conn.connection is not None:

View File

@@ -27,8 +27,7 @@ talon.init()
logger = logging.getLogger(__name__)
def redact_stream(error_message):
# type: (Text) -> Text
def redact_stream(error_message: Text) -> Text:
domain = settings.EMAIL_GATEWAY_PATTERN.rsplit('@')[-1]
stream_match = re.search('\\b(.*?)@' + domain, error_message)
if stream_match:
@@ -36,8 +35,7 @@ def redact_stream(error_message):
return error_message.replace(stream_name, "X" * len(stream_name))
return error_message
def report_to_zulip(error_message):
# type: (Text) -> None
def report_to_zulip(error_message: Text) -> None:
if settings.ERROR_BOT is None:
return
error_bot = get_system_bot(settings.ERROR_BOT)
@@ -45,8 +43,7 @@ def report_to_zulip(error_message):
send_zulip(settings.ERROR_BOT, error_stream, u"email mirror error",
u"""~~~\n%s\n~~~""" % (error_message,))
def log_and_report(email_message, error_message, debug_info):
# type: (message.Message, Text, Dict[str, Any]) -> None
def log_and_report(email_message: message.Message, error_message: Text, debug_info: Dict[str, Any]) -> None:
scrubbed_error = u"Sender: %s\n%s" % (email_message.get("From"),
redact_stream(error_message))
@@ -67,26 +64,22 @@ def log_and_report(email_message, error_message, debug_info):
redis_client = get_redis_client()
def missed_message_redis_key(token):
# type: (Text) -> Text
def missed_message_redis_key(token: Text) -> Text:
return 'missed_message:' + token
def is_missed_message_address(address):
# type: (Text) -> bool
def is_missed_message_address(address: Text) -> bool:
msg_string = get_email_gateway_message_string_from_address(address)
return is_mm_32_format(msg_string)
def is_mm_32_format(msg_string):
# type: (Optional[Text]) -> bool
def is_mm_32_format(msg_string: Optional[Text]) -> bool:
'''
Missed message strings are formatted with a little "mm" prefix
followed by a randomly generated 32-character string.
'''
return msg_string is not None and msg_string.startswith('mm') and len(msg_string) == 34
def get_missed_message_token_from_address(address):
# type: (Text) -> Text
def get_missed_message_token_from_address(address: Text) -> Text:
msg_string = get_email_gateway_message_string_from_address(address)
if msg_string is None:
@@ -98,8 +91,7 @@ def get_missed_message_token_from_address(address):
# strip off the 'mm' before returning the redis key
return msg_string[2:]
def create_missed_message_address(user_profile, message):
# type: (UserProfile, Message) -> str
def create_missed_message_address(user_profile: UserProfile, message: Message) -> str:
if settings.EMAIL_GATEWAY_PATTERN == '':
logger.warning("EMAIL_GATEWAY_PATTERN is an empty string, using "
"NOREPLY_EMAIL_ADDRESS in the 'from' field.")
@@ -132,8 +124,7 @@ def create_missed_message_address(user_profile, message):
return settings.EMAIL_GATEWAY_PATTERN % (address,)
def mark_missed_message_address_as_used(address):
# type: (Text) -> None
def mark_missed_message_address_as_used(address: Text) -> None:
token = get_missed_message_token_from_address(address)
key = missed_message_redis_key(token)
with redis_client.pipeline() as pipeline:
@@ -144,8 +135,7 @@ def mark_missed_message_address_as_used(address):
redis_client.delete(key)
raise ZulipEmailForwardError('Missed message address has already been used')
def construct_zulip_body(message, realm):
# type: (message.Message, Realm) -> Text
def construct_zulip_body(message: message.Message, realm: Realm) -> Text:
body = extract_body(message)
# Remove null characters, since Zulip will reject
body = body.replace("\x00", "")
@@ -156,8 +146,7 @@ def construct_zulip_body(message, realm):
body = '(No email body)'
return body
def send_to_missed_message_address(address, message):
# type: (Text, message.Message) -> None
def send_to_missed_message_address(address: Text, message: message.Message) -> None:
token = get_missed_message_token_from_address(address)
key = missed_message_redis_key(token)
result = redis_client.hmget(key, 'user_profile_id', 'recipient_id', 'subject')
@@ -194,8 +183,7 @@ def send_to_missed_message_address(address, message):
class ZulipEmailForwardError(Exception):
pass
def send_zulip(sender, stream, topic, content):
# type: (Text, Stream, Text, Text) -> None
def send_zulip(sender: Text, stream: Stream, topic: Text, content: Text) -> None:
internal_send_message(
stream.realm,
sender,
@@ -205,16 +193,14 @@ def send_zulip(sender, stream, topic, content):
content[:2000],
email_gateway=True)
def valid_stream(stream_name, token):
# type: (Text, Text) -> bool
def valid_stream(stream_name: Text, token: Text) -> bool:
try:
stream = Stream.objects.get(email_token=token)
return stream.name.lower() == stream_name.lower()
except Stream.DoesNotExist:
return False
def get_message_part_by_type(message, content_type):
# type: (message.Message, Text) -> Optional[Text]
def get_message_part_by_type(message: message.Message, content_type: Text) -> Optional[Text]:
charsets = message.get_charsets()
for idx, part in enumerate(message.walk()):
@@ -225,8 +211,7 @@ def get_message_part_by_type(message, content_type):
return content.decode(charsets[idx], errors="ignore")
return None
def extract_body(message):
# type: (message.Message) -> Text
def extract_body(message: message.Message) -> Text:
# If the message contains a plaintext version of the body, use
# that.
plaintext_content = get_message_part_by_type(message, "text/plain")
@@ -240,8 +225,7 @@ def extract_body(message):
raise ZulipEmailForwardError("Unable to find plaintext or HTML message body")
def filter_footer(text):
# type: (Text) -> Text
def filter_footer(text: Text) -> Text:
# Try to filter out obvious footers.
possible_footers = [line for line in text.split("\n") if line.strip().startswith("--")]
if len(possible_footers) != 1:
@@ -251,8 +235,7 @@ def filter_footer(text):
return text.partition("--")[0].strip()
def extract_and_upload_attachments(message, realm):
# type: (message.Message, Realm) -> Text
def extract_and_upload_attachments(message: message.Message, realm: Realm) -> Text:
user_profile = get_system_bot(settings.EMAIL_GATEWAY_BOT)
attachment_links = []
@@ -279,8 +262,7 @@ def extract_and_upload_attachments(message, realm):
return u"\n".join(attachment_links)
def extract_and_validate(email):
# type: (Text) -> Stream
def extract_and_validate(email: Text) -> Stream:
temp = decode_email_address(email)
if temp is None:
raise ZulipEmailForwardError("Malformed email recipient " + email)
@@ -291,8 +273,7 @@ def extract_and_validate(email):
return Stream.objects.get(email_token=token)
def find_emailgateway_recipient(message):
# type: (message.Message) -> Text
def find_emailgateway_recipient(message: message.Message) -> Text:
# We can't use Delivered-To; if there is a X-Gm-Original-To
# it is more accurate, so try to find the most-accurate
# recipient list in descending priority order
@@ -312,8 +293,8 @@ def find_emailgateway_recipient(message):
raise ZulipEmailForwardError("Missing recipient in mirror email")
def process_stream_message(to, subject, message, debug_info):
# type: (Text, Text, message.Message, Dict[str, Any]) -> None
def process_stream_message(to: Text, subject: Text, message: message.Message,
debug_info: Dict[str, Any]) -> None:
stream = extract_and_validate(to)
body = construct_zulip_body(message, stream.realm)
debug_info["stream"] = stream
@@ -321,14 +302,12 @@ def process_stream_message(to, subject, message, debug_info):
logger.info("Successfully processed email to %s (%s)" % (
stream.name, stream.realm.string_id))
def process_missed_message(to, message, pre_checked):
# type: (Text, message.Message, bool) -> None
def process_missed_message(to: Text, message: message.Message, pre_checked: bool) -> None:
if not pre_checked:
mark_missed_message_address_as_used(to)
send_to_missed_message_address(to, message)
def process_message(message, rcpt_to=None, pre_checked=False):
# type: (message.Message, Optional[Text], bool) -> None
def process_message(message: message.Message, rcpt_to: Optional[Text]=None, pre_checked: bool=False) -> None:
subject_header = message.get("Subject", "(no subject)")
encoded_subject, encoding = decode_header(subject_header)[0]
if encoding is None:
@@ -357,8 +336,7 @@ def process_message(message, rcpt_to=None, pre_checked=False):
log_and_report(message, str(e), debug_info)
def mirror_email_message(data):
# type: (Dict[Text, Text]) -> Dict[str, str]
def mirror_email_message(data: Dict[Text, Text]) -> Dict[str, str]:
rcpt_to = data['recipient']
if is_missed_message_address(rcpt_to):
try:

View File

@@ -13,8 +13,7 @@ import time
client = get_redis_client()
def has_enough_time_expired_since_last_message(sender_email, min_delay):
# type: (Text, float) -> bool
def has_enough_time_expired_since_last_message(sender_email: Text, min_delay: float) -> bool:
# This function returns a boolean, but it also has the side effect
# of noting that a new message was received.
key = 'zilencer:feedback:%s' % (sender_email,)
@@ -25,8 +24,7 @@ def has_enough_time_expired_since_last_message(sender_email, min_delay):
delay = t - int(last_time)
return delay > min_delay
def get_ticket_number():
# type: () -> int
def get_ticket_number() -> int:
num_file = '/var/tmp/.feedback-bot-ticket-number'
try:
ticket_number = int(open(num_file).read()) + 1
@@ -35,8 +33,7 @@ def get_ticket_number():
open(num_file, 'w').write('%d' % (ticket_number,))
return ticket_number
def deliver_feedback_by_zulip(message):
# type: (Mapping[str, Any]) -> None
def deliver_feedback_by_zulip(message: Mapping[str, Any]) -> None:
subject = "%s" % (message["sender_email"],)
if len(subject) > 60:
@@ -67,8 +64,7 @@ def deliver_feedback_by_zulip(message):
internal_send_message(user_profile.realm, settings.FEEDBACK_BOT,
"stream", settings.FEEDBACK_STREAM, subject, content)
def handle_feedback(event):
# type: (Mapping[str, Any]) -> None
def handle_feedback(event: Mapping[str, Any]) -> None:
if not settings.ENABLE_FEEDBACK:
return
if settings.FEEDBACK_EMAIL is not None:

View File

@@ -20,8 +20,7 @@ migration runs.
logger = logging.getLogger('zulip.fix_unreads')
logger.setLevel(logging.WARNING)
def build_topic_mute_checker(cursor, user_profile):
# type: (CursorObj, UserProfile) -> Callable[[int, Text], bool]
def build_topic_mute_checker(cursor: CursorObj, user_profile: UserProfile) -> Callable[[int, Text], bool]:
'''
This function is similar to the function of the same name
in zerver/lib/topic_mutes.py, but it works without the ORM,
@@ -44,14 +43,12 @@ def build_topic_mute_checker(cursor, user_profile):
for (recipient_id, topic_name) in rows
}
def is_muted(recipient_id, topic):
# type: (int, Text) -> bool
def is_muted(recipient_id: int, topic: Text) -> bool:
return (recipient_id, topic.lower()) in tups
return is_muted
def update_unread_flags(cursor, user_message_ids):
# type: (CursorObj, List[int]) -> None
def update_unread_flags(cursor: CursorObj, user_message_ids: List[int]) -> None:
um_id_list = ', '.join(str(id) for id in user_message_ids)
query = '''
UPDATE zerver_usermessage
@@ -62,8 +59,7 @@ def update_unread_flags(cursor, user_message_ids):
cursor.execute(query)
def get_timing(message, f):
# type: (str, Callable[[], None]) -> None
def get_timing(message: str, f: Callable[[], None]) -> None:
start = time.time()
logger.info(message)
f()
@@ -71,13 +67,11 @@ def get_timing(message, f):
logger.info('elapsed time: %.03f\n' % (elapsed,))
def fix_unsubscribed(cursor, user_profile):
# type: (CursorObj, UserProfile) -> None
def fix_unsubscribed(cursor: CursorObj, user_profile: UserProfile) -> None:
recipient_ids = []
def find_recipients():
# type: () -> None
def find_recipients() -> None:
query = '''
SELECT
zerver_subscription.recipient_id
@@ -108,8 +102,7 @@ def fix_unsubscribed(cursor, user_profile):
user_message_ids = []
def find():
# type: () -> None
def find() -> None:
recips = ', '.join(str(id) for id in recipient_ids)
query = '''
@@ -144,8 +137,7 @@ def fix_unsubscribed(cursor, user_profile):
if not user_message_ids:
return
def fix():
# type: () -> None
def fix() -> None:
update_unread_flags(cursor, user_message_ids)
get_timing(
@@ -153,8 +145,7 @@ def fix_unsubscribed(cursor, user_profile):
fix
)
def fix_pre_pointer(cursor, user_profile):
# type: (CursorObj, UserProfile) -> None
def fix_pre_pointer(cursor: CursorObj, user_profile: UserProfile) -> None:
pointer = user_profile.pointer
@@ -163,8 +154,7 @@ def fix_pre_pointer(cursor, user_profile):
recipient_ids = []
def find_non_muted_recipients():
# type: () -> None
def find_non_muted_recipients() -> None:
query = '''
SELECT
zerver_subscription.recipient_id
@@ -196,8 +186,7 @@ def fix_pre_pointer(cursor, user_profile):
user_message_ids = []
def find_old_ids():
# type: () -> None
def find_old_ids() -> None:
recips = ', '.join(str(id) for id in recipient_ids)
is_topic_muted = build_topic_mute_checker(cursor, user_profile)
@@ -238,8 +227,7 @@ def fix_pre_pointer(cursor, user_profile):
if not user_message_ids:
return
def fix():
# type: () -> None
def fix() -> None:
update_unread_flags(cursor, user_message_ids)
get_timing(
@@ -247,8 +235,7 @@ def fix_pre_pointer(cursor, user_profile):
fix
)
def fix(user_profile):
# type: (UserProfile) -> None
def fix(user_profile: UserProfile) -> None:
logger.info('\n---\nFixing %s:' % (user_profile.email,))
with connection.cursor() as cursor:
fix_unsubscribed(cursor, user_profile)

View File

@@ -82,12 +82,10 @@ class Integration:
stream_name = self.name
self.stream_name = stream_name
def is_enabled(self):
# type: () -> bool
def is_enabled(self) -> bool:
return True
def get_logo_url(self):
# type: () -> Optional[str]
def get_logo_url(self) -> Optional[str]:
logo_file_path_svg = str(pathlib.PurePath(
settings.STATIC_ROOT,
*self.DEFAULT_LOGO_STATIC_PATH_SVG.format(name=self.name).split('/')[1:]
@@ -139,8 +137,7 @@ class BotIntegration(Integration):
self.doc = doc
class EmailIntegration(Integration):
def is_enabled(self):
# type: () -> bool
def is_enabled(self) -> bool:
return settings.EMAIL_GATEWAY_PATTERN != ""
class WebhookIntegration(Integration):
@@ -183,8 +180,7 @@ class WebhookIntegration(Integration):
self.doc = doc
@property
def url_object(self):
# type: () -> LocaleRegexProvider
def url_object(self) -> LocaleRegexProvider:
return url(self.url, self.function)
class HubotIntegration(Integration):
@@ -234,8 +230,7 @@ class GithubIntegration(WebhookIntegration):
)
@property
def url_object(self):
# type: () -> None
def url_object(self) -> None:
return
class EmbeddedBotIntegration(Integration):
@@ -245,8 +240,7 @@ class EmbeddedBotIntegration(Integration):
'''
DEFAULT_CLIENT_NAME = 'Zulip{name}EmbeddedBot'
def __init__(self, name, *args, **kwargs):
# type: (str, *Any, **Any) -> None
def __init__(self, name: str, *args: Any, **kwargs: Any) -> None:
assert kwargs.get("client_name") is None
client_name = self.DEFAULT_CLIENT_NAME.format(name=name.title())
super().__init__(

View File

@@ -9,8 +9,7 @@ from zerver.models import Realm, UserProfile, Message, Reaction, get_system_bot
from typing import Any, Dict, List, Mapping, Text
def send_initial_pms(user):
# type: (UserProfile) -> None
def send_initial_pms(user: UserProfile) -> None:
organization_setup_text = ""
if user.is_realm_admin:
help_url = user.realm.uri + "/help/getting-your-organization-started-with-zulip"
@@ -33,8 +32,7 @@ def send_initial_pms(user):
internal_send_private_message(user.realm, get_system_bot(settings.WELCOME_BOT),
user, content)
def setup_initial_streams(realm):
# type: (Realm) -> None
def setup_initial_streams(realm: Realm) -> None:
stream_dicts = [
{'name': "general"},
{'name': "new members",
@@ -46,8 +44,7 @@ def setup_initial_streams(realm):
create_streams_if_needed(realm, stream_dicts)
set_default_streams(realm, {stream['name']: {} for stream in stream_dicts})
def send_initial_realm_messages(realm):
# type: (Realm) -> None
def send_initial_realm_messages(realm: Realm) -> None:
welcome_bot = get_system_bot(settings.WELCOME_BOT)
# Make sure each stream created in the realm creation process has at least one message below
# Order corresponds to the ordering of the streams on the left sidebar, to make the initial Home

View File

@@ -6,16 +6,13 @@ import sqlalchemy
# This is a Pool that doesn't close connections. Therefore it can be used with
# existing Django database connections.
class NonClosingPool(sqlalchemy.pool.NullPool):
def status(self):
# type: () -> str
def status(self) -> str:
return "NonClosingPool"
def _do_return_conn(self, conn):
# type: (sqlalchemy.engine.base.Connection) -> None
def _do_return_conn(self, conn: sqlalchemy.engine.base.Connection) -> None:
pass
def recreate(self):
# type: () -> NonClosingPool
def recreate(self) -> 'NonClosingPool':
return self.__class__(creator=self._creator,
recycle=self._recycle,
use_threadlocal=self._use_threadlocal,
@@ -25,12 +22,10 @@ class NonClosingPool(sqlalchemy.pool.NullPool):
_dispatch=self.dispatch)
sqlalchemy_engine = None
def get_sqlalchemy_connection():
# type: () -> sqlalchemy.engine.base.Connection
def get_sqlalchemy_connection() -> sqlalchemy.engine.base.Connection:
global sqlalchemy_engine
if sqlalchemy_engine is None:
def get_dj_conn():
# type: () -> TimeTrackingConnection
def get_dj_conn() -> TimeTrackingConnection:
connection.ensure_connection()
return connection.connection
sqlalchemy_engine = sqlalchemy.create_engine('postgresql://',

View File

@@ -34,8 +34,7 @@ from typing import Any, Dict, Mapping, Union, TypeVar, Text
NonBinaryStr = TypeVar('NonBinaryStr', str, Text)
# This is used to represent text or native strings
def force_text(s, encoding='utf-8'):
# type: (Union[Text, bytes], str) -> Text
def force_text(s: Union[Text, bytes], encoding: str='utf-8') -> Text:
"""converts a string to a text string"""
if isinstance(s, Text):
return s
@@ -44,8 +43,7 @@ def force_text(s, encoding='utf-8'):
else:
raise TypeError("force_text expects a string type")
def force_bytes(s, encoding='utf-8'):
# type: (Union[Text, bytes], str) -> bytes
def force_bytes(s: Union[Text, bytes], encoding: str='utf-8') -> bytes:
"""converts a string to binary string"""
if isinstance(s, bytes):
return s
@@ -54,8 +52,7 @@ def force_bytes(s, encoding='utf-8'):
else:
raise TypeError("force_bytes expects a string type")
def force_str(s, encoding='utf-8'):
# type: (Union[Text, bytes], str) -> str
def force_str(s: Union[Text, bytes], encoding: str='utf-8') -> str:
"""converts a string to a native string"""
if isinstance(s, str):
return s
@@ -74,16 +71,13 @@ class ModelReprMixin:
This mixin will automatically define __str__ and __repr__.
"""
def __unicode__(self):
# type: () -> Text
def __unicode__(self) -> Text:
# Originally raised an exception, but Django (e.g. the ./manage.py shell)
# was catching the exception and not displaying any sort of error
return u"Implement __unicode__ in your subclass of ModelReprMixin!"
def __str__(self):
# type: () -> str
def __str__(self) -> str:
return force_str(self.__unicode__())
def __repr__(self):
# type: () -> str
def __repr__(self) -> str:
return force_str(self.__unicode__())