zerver/lib: Use python 3 syntax for typing.

Edited by tabbott to improve various line-wrapping decisions.
This commit is contained in:
rht
2017-11-05 11:15:10 +01:00
committed by Tim Abbott
parent 229a8b38c0
commit ee546a33a3
8 changed files with 523 additions and 746 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -9,15 +9,13 @@ from typing import Text, Optional, List, Tuple
class StateError(Exception): class StateError(Exception):
pass pass
def get_bot_storage(bot_profile, key): def get_bot_storage(bot_profile: UserProfile, key: Text) -> Text:
# type: (UserProfile, Text) -> Text
try: try:
return BotStorageData.objects.get(bot_profile=bot_profile, key=key).value return BotStorageData.objects.get(bot_profile=bot_profile, key=key).value
except BotStorageData.DoesNotExist: except BotStorageData.DoesNotExist:
raise StateError("Key does not exist.") raise StateError("Key does not exist.")
def get_bot_storage_size(bot_profile, key=None): def get_bot_storage_size(bot_profile: UserProfile, key: Optional[Text]=None) -> int:
# type: (UserProfile, Optional[Text]) -> int
if key is None: if key is None:
return BotStorageData.objects.filter(bot_profile=bot_profile) \ return BotStorageData.objects.filter(bot_profile=bot_profile) \
.annotate(key_size=Length('key'), value_size=Length('value')) \ .annotate(key_size=Length('key'), value_size=Length('value')) \

View File

@@ -61,11 +61,6 @@ FullNameInfo = TypedDict('FullNameInfo', {
version = 1 version = 1
_T = TypeVar('_T') _T = TypeVar('_T')
# We need to avoid this running at runtime, but mypy will see this.
# The problem is that under python 2, Element isn't exactly a type,
# which means that at runtime Union causes this to blow up.
if False:
# mypy requires the Optional to be inside Union
ElementStringNone = Union[Element, Optional[Text]] ElementStringNone = Union[Element, Optional[Text]]
AVATAR_REGEX = r'!avatar\((?P<email>[^)]*)\)' AVATAR_REGEX = r'!avatar\((?P<email>[^)]*)\)'
@@ -82,8 +77,7 @@ STREAM_LINK_REGEX = r"""
class BugdownRenderingException(Exception): class BugdownRenderingException(Exception):
pass pass
def url_embed_preview_enabled_for_realm(message): def url_embed_preview_enabled_for_realm(message: Optional[Message]) -> bool:
# type: (Optional[Message]) -> bool
if message is not None: if message is not None:
realm = message.get_realm() # type: Optional[Realm] realm = message.get_realm() # type: Optional[Realm]
else: else:
@@ -95,8 +89,7 @@ def url_embed_preview_enabled_for_realm(message):
return True return True
return realm.inline_url_embed_preview return realm.inline_url_embed_preview
def image_preview_enabled_for_realm(): def image_preview_enabled_for_realm() -> bool:
# type: () -> bool
global current_message global current_message
if current_message is not None: if current_message is not None:
realm = current_message.get_realm() # type: Optional[Realm] realm = current_message.get_realm() # type: Optional[Realm]
@@ -108,8 +101,7 @@ def image_preview_enabled_for_realm():
return True return True
return realm.inline_image_preview return realm.inline_image_preview
def list_of_tlds(): def list_of_tlds() -> List[Text]:
# type: () -> List[Text]
# HACK we manually blacklist a few domains # HACK we manually blacklist a few domains
blacklist = ['PY\n', "MD\n"] blacklist = ['PY\n', "MD\n"]
@@ -120,8 +112,9 @@ def list_of_tlds():
tlds.sort(key=len, reverse=True) tlds.sort(key=len, reverse=True)
return tlds return tlds
def walk_tree(root, processor, stop_after_first=False): def walk_tree(root: Element,
# type: (Element, Callable[[Element], Optional[_T]], bool) -> List[_T] processor: Callable[[Element], Optional[_T]],
stop_after_first: bool=False) -> List[_T]:
results = [] results = []
queue = deque([root]) queue = deque([root])
@@ -166,8 +159,7 @@ def add_a(root, url, link, title=None, desc=None,
desc_div.set("class", "message_inline_image_desc") desc_div.set("class", "message_inline_image_desc")
def add_embed(root, link, extracted_data): def add_embed(root: Element, link: Text, extracted_data: Dict[Text, Any]) -> None:
# type: (Element, Text, Dict[Text, Any]) -> None
container = markdown.util.etree.SubElement(root, "div") container = markdown.util.etree.SubElement(root, "div")
container.set("class", "message_embed") container.set("class", "message_embed")
@@ -206,8 +198,7 @@ def add_embed(root, link, extracted_data):
@cache_with_key(lambda tweet_id: tweet_id, cache_name="database", with_statsd_key="tweet_data") @cache_with_key(lambda tweet_id: tweet_id, cache_name="database", with_statsd_key="tweet_data")
def fetch_tweet_data(tweet_id): def fetch_tweet_data(tweet_id: Text) -> Optional[Dict[Text, Any]]:
# type: (Text) -> Optional[Dict[Text, Any]]
if settings.TEST_SUITE: if settings.TEST_SUITE:
from . import testing_mocks from . import testing_mocks
res = testing_mocks.twitter(tweet_id) res = testing_mocks.twitter(tweet_id)
@@ -266,8 +257,7 @@ HEAD_END_RE = re.compile('^/head[ >]')
META_START_RE = re.compile('^meta[ >]') META_START_RE = re.compile('^meta[ >]')
META_END_RE = re.compile('^/meta[ >]') META_END_RE = re.compile('^/meta[ >]')
def fetch_open_graph_image(url): def fetch_open_graph_image(url: Text) -> Optional[Dict[str, Any]]:
# type: (Text) -> Optional[Dict[str, Any]]
in_head = False in_head = False
# HTML will auto close meta tags, when we start the next tag add # HTML will auto close meta tags, when we start the next tag add
# a closing tag if it has not been closed yet. # a closing tag if it has not been closed yet.
@@ -333,8 +323,7 @@ def fetch_open_graph_image(url):
desc = og_desc.get('content') desc = og_desc.get('content')
return {'image': image, 'title': title, 'desc': desc} return {'image': image, 'title': title, 'desc': desc}
def get_tweet_id(url): def get_tweet_id(url: Text) -> Optional[Text]:
# type: (Text) -> Optional[Text]
parsed_url = urllib.parse.urlparse(url) parsed_url = urllib.parse.urlparse(url)
if not (parsed_url.netloc == 'twitter.com' or parsed_url.netloc.endswith('.twitter.com')): if not (parsed_url.netloc == 'twitter.com' or parsed_url.netloc.endswith('.twitter.com')):
return None return None
@@ -350,8 +339,7 @@ def get_tweet_id(url):
return tweet_id_match.group("tweetid") return tweet_id_match.group("tweetid")
class InlineHttpsProcessor(markdown.treeprocessors.Treeprocessor): class InlineHttpsProcessor(markdown.treeprocessors.Treeprocessor):
def run(self, root): def run(self, root: Element) -> None:
# type: (Element) -> None
# Get all URLs from the blob # Get all URLs from the blob
found_imgs = walk_tree(root, lambda e: e if e.tag == "img" else None) found_imgs = walk_tree(root, lambda e: e if e.tag == "img" else None)
for img in found_imgs: for img in found_imgs:
@@ -365,14 +353,12 @@ class InlineInterestingLinkProcessor(markdown.treeprocessors.Treeprocessor):
TWITTER_MAX_IMAGE_HEIGHT = 400 TWITTER_MAX_IMAGE_HEIGHT = 400
TWITTER_MAX_TO_PREVIEW = 3 TWITTER_MAX_TO_PREVIEW = 3
def __init__(self, md, bugdown): def __init__(self, md: markdown.Markdown, bugdown: 'Bugdown') -> None:
# type: (markdown.Markdown, Bugdown) -> None
# Passing in bugdown for access to config to check if realm is zulip.com # Passing in bugdown for access to config to check if realm is zulip.com
self.bugdown = bugdown self.bugdown = bugdown
markdown.treeprocessors.Treeprocessor.__init__(self, md) markdown.treeprocessors.Treeprocessor.__init__(self, md)
def get_actual_image_url(self, url): def get_actual_image_url(self, url: Text) -> Text:
# type: (Text) -> Text
# Add specific per-site cases to convert image-preview urls to image urls. # Add specific per-site cases to convert image-preview urls to image urls.
# See https://github.com/zulip/zulip/issues/4658 for more information # See https://github.com/zulip/zulip/issues/4658 for more information
parsed_url = urllib.parse.urlparse(url) parsed_url = urllib.parse.urlparse(url)
@@ -386,8 +372,7 @@ class InlineInterestingLinkProcessor(markdown.treeprocessors.Treeprocessor):
return url return url
def is_image(self, url): def is_image(self, url: Text) -> bool:
# type: (Text) -> bool
if not image_preview_enabled_for_realm(): if not image_preview_enabled_for_realm():
return False return False
parsed_url = urllib.parse.urlparse(url) parsed_url = urllib.parse.urlparse(url)
@@ -397,8 +382,7 @@ class InlineInterestingLinkProcessor(markdown.treeprocessors.Treeprocessor):
return True return True
return False return False
def dropbox_image(self, url): def dropbox_image(self, url: Text) -> Optional[Dict[str, Any]]:
# type: (Text) -> Optional[Dict[str, Any]]
# TODO: The returned Dict could possibly be a TypedDict in future. # TODO: The returned Dict could possibly be a TypedDict in future.
parsed_url = urllib.parse.urlparse(url) parsed_url = urllib.parse.urlparse(url)
if (parsed_url.netloc == 'dropbox.com' or parsed_url.netloc.endswith('.dropbox.com')): if (parsed_url.netloc == 'dropbox.com' or parsed_url.netloc.endswith('.dropbox.com')):
@@ -443,8 +427,7 @@ class InlineInterestingLinkProcessor(markdown.treeprocessors.Treeprocessor):
return image_info return image_info
return None return None
def youtube_id(self, url): def youtube_id(self, url: Text) -> Optional[Text]:
# type: (Text) -> Optional[Text]
if not image_preview_enabled_for_realm(): if not image_preview_enabled_for_realm():
return None return None
# Youtube video id extraction regular expression from http://pastebin.com/KyKAFv1s # Youtube video id extraction regular expression from http://pastebin.com/KyKAFv1s
@@ -457,16 +440,17 @@ class InlineInterestingLinkProcessor(markdown.treeprocessors.Treeprocessor):
return None return None
return match.group(2) return match.group(2)
def youtube_image(self, url): def youtube_image(self, url: Text) -> Optional[Text]:
# type: (Text) -> Optional[Text]
yt_id = self.youtube_id(url) yt_id = self.youtube_id(url)
if yt_id is not None: if yt_id is not None:
return "https://i.ytimg.com/vi/%s/default.jpg" % (yt_id,) return "https://i.ytimg.com/vi/%s/default.jpg" % (yt_id,)
return None return None
def twitter_text(self, text, urls, user_mentions, media): def twitter_text(self, text: Text,
# type: (Text, List[Dict[Text, Text]], List[Dict[Text, Any]], List[Dict[Text, Any]]) -> Element urls: List[Dict[Text, Text]],
user_mentions: List[Dict[Text, Any]],
media: List[Dict[Text, Any]]) -> Element:
""" """
Use data from the twitter API to turn links, mentions and media into A Use data from the twitter API to turn links, mentions and media into A
tags. Also convert unicode emojis to images. tags. Also convert unicode emojis to images.
@@ -542,8 +526,7 @@ class InlineInterestingLinkProcessor(markdown.treeprocessors.Treeprocessor):
to_process.sort(key=lambda x: x['start']) to_process.sort(key=lambda x: x['start'])
p = current_node = markdown.util.etree.Element('p') p = current_node = markdown.util.etree.Element('p')
def set_text(text): def set_text(text: Text) -> None:
# type: (Text) -> None
""" """
Helper to set the text or the tail of the current_node Helper to set the text or the tail of the current_node
""" """
@@ -571,8 +554,7 @@ class InlineInterestingLinkProcessor(markdown.treeprocessors.Treeprocessor):
set_text(text[current_index:]) set_text(text[current_index:])
return p return p
def twitter_link(self, url): def twitter_link(self, url: Text) -> Optional[Element]:
# type: (Text) -> Optional[Element]
tweet_id = get_tweet_id(url) tweet_id = get_tweet_id(url)
if tweet_id is None: if tweet_id is None:
@@ -641,16 +623,14 @@ class InlineInterestingLinkProcessor(markdown.treeprocessors.Treeprocessor):
logging.warning(traceback.format_exc()) logging.warning(traceback.format_exc())
return None return None
def get_url_data(self, e): def get_url_data(self, e: Element) -> Optional[Tuple[Text, Text]]:
# type: (Element) -> Optional[Tuple[Text, Text]]
if e.tag == "a": if e.tag == "a":
if e.text is not None: if e.text is not None:
return (e.get("href"), e.text) return (e.get("href"), e.text)
return (e.get("href"), e.get("href")) return (e.get("href"), e.get("href"))
return None return None
def is_only_element(self, root, url): def is_only_element(self, root: Element, url: str) -> bool:
# type: (Element, str) -> bool
# Check if the url is the only content of the message. # Check if the url is the only content of the message.
if not len(root) == 1: if not len(root) == 1:
@@ -668,8 +648,7 @@ class InlineInterestingLinkProcessor(markdown.treeprocessors.Treeprocessor):
return True return True
def run(self, root): def run(self, root: Element) -> None:
# type: (Element) -> None
# Get all URLs from the blob # Get all URLs from the blob
found_urls = walk_tree(root, self.get_url_data) found_urls = walk_tree(root, self.get_url_data)
@@ -735,8 +714,7 @@ class InlineInterestingLinkProcessor(markdown.treeprocessors.Treeprocessor):
class Avatar(markdown.inlinepatterns.Pattern): class Avatar(markdown.inlinepatterns.Pattern):
def handleMatch(self, match): def handleMatch(self, match: Match[Text]) -> Optional[Element]:
# type: (Match[Text]) -> Optional[Element]
img = markdown.util.etree.Element('img') img = markdown.util.etree.Element('img')
email_address = match.group('email') email_address = match.group('email')
email = email_address.strip().lower() email = email_address.strip().lower()
@@ -753,8 +731,7 @@ class Avatar(markdown.inlinepatterns.Pattern):
img.set('alt', email) img.set('alt', email)
return img return img
def possible_avatar_emails(content): def possible_avatar_emails(content: Text) -> Set[Text]:
# type: (Text) -> Set[Text]
emails = set() emails = set()
for regex in [AVATAR_REGEX, GRAVATAR_REGEX]: for regex in [AVATAR_REGEX, GRAVATAR_REGEX]:
matches = re.findall(regex, content) matches = re.findall(regex, content)
@@ -819,8 +796,7 @@ unicode_emoji_regex = '(?P<syntax>['\
# For more information, please refer to the following article: # For more information, please refer to the following article:
# http://crocodillon.com/blog/parsing-emoji-unicode-in-javascript # http://crocodillon.com/blog/parsing-emoji-unicode-in-javascript
def make_emoji(codepoint, display_string): def make_emoji(codepoint: Text, display_string: Text) -> Element:
# type: (Text, Text) -> Element
# Replace underscore in emoji's title with space # Replace underscore in emoji's title with space
title = display_string[1:-1].replace("_", " ") title = display_string[1:-1].replace("_", " ")
span = markdown.util.etree.Element('span') span = markdown.util.etree.Element('span')
@@ -829,8 +805,7 @@ def make_emoji(codepoint, display_string):
span.text = display_string span.text = display_string
return span return span
def make_realm_emoji(src, display_string): def make_realm_emoji(src: Text, display_string: Text) -> Element:
# type: (Text, Text) -> Element
elt = markdown.util.etree.Element('img') elt = markdown.util.etree.Element('img')
elt.set('src', src) elt.set('src', src)
elt.set('class', 'emoji') elt.set('class', 'emoji')
@@ -838,8 +813,7 @@ def make_realm_emoji(src, display_string):
elt.set("title", display_string[1:-1].replace("_", " ")) elt.set("title", display_string[1:-1].replace("_", " "))
return elt return elt
def unicode_emoji_to_codepoint(unicode_emoji): def unicode_emoji_to_codepoint(unicode_emoji: Text) -> Text:
# type: (Text) -> Text
codepoint = hex(ord(unicode_emoji))[2:] codepoint = hex(ord(unicode_emoji))[2:]
# Unicode codepoints are minimum of length 4, padded # Unicode codepoints are minimum of length 4, padded
# with zeroes if the length is less than zero. # with zeroes if the length is less than zero.
@@ -848,8 +822,7 @@ def unicode_emoji_to_codepoint(unicode_emoji):
return codepoint return codepoint
class UnicodeEmoji(markdown.inlinepatterns.Pattern): class UnicodeEmoji(markdown.inlinepatterns.Pattern):
def handleMatch(self, match): def handleMatch(self, match: Match[Text]) -> Optional[Element]:
# type: (Match[Text]) -> Optional[Element]
orig_syntax = match.group('syntax') orig_syntax = match.group('syntax')
codepoint = unicode_emoji_to_codepoint(orig_syntax) codepoint = unicode_emoji_to_codepoint(orig_syntax)
if codepoint in codepoint_to_name: if codepoint in codepoint_to_name:
@@ -859,8 +832,7 @@ class UnicodeEmoji(markdown.inlinepatterns.Pattern):
return None return None
class Emoji(markdown.inlinepatterns.Pattern): class Emoji(markdown.inlinepatterns.Pattern):
def handleMatch(self, match): def handleMatch(self, match: Match[Text]) -> Optional[Element]:
# type: (Match[Text]) -> Optional[Element]
orig_syntax = match.group("syntax") orig_syntax = match.group("syntax")
name = orig_syntax[1:-1] name = orig_syntax[1:-1]
@@ -877,15 +849,13 @@ class Emoji(markdown.inlinepatterns.Pattern):
else: else:
return None return None
def content_has_emoji_syntax(content): def content_has_emoji_syntax(content: Text) -> bool:
# type: (Text) -> bool
return re.search(EMOJI_REGEX, content) is not None return re.search(EMOJI_REGEX, content) is not None
class StreamSubscribeButton(markdown.inlinepatterns.Pattern): class StreamSubscribeButton(markdown.inlinepatterns.Pattern):
# This markdown extension has required javascript in # This markdown extension has required javascript in
# static/js/custom_markdown.js # static/js/custom_markdown.js
def handleMatch(self, match): def handleMatch(self, match: Match[Text]) -> Element:
# type: (Match[Text]) -> Element
stream_name = match.group('stream_name') stream_name = match.group('stream_name')
stream_name = stream_name.replace('\\)', ')').replace('\\\\', '\\') stream_name = stream_name.replace('\\)', ')').replace('\\\\', '\\')
@@ -907,8 +877,7 @@ class ModalLink(markdown.inlinepatterns.Pattern):
A pattern that allows including in-app modal links in messages. A pattern that allows including in-app modal links in messages.
""" """
def handleMatch(self, match): def handleMatch(self, match: Match[Text]) -> Element:
# type: (Match[Text]) -> Element
relative_url = match.group('relative_url') relative_url = match.group('relative_url')
text = match.group('text') text = match.group('text')
@@ -920,8 +889,7 @@ class ModalLink(markdown.inlinepatterns.Pattern):
return a_tag return a_tag
class Tex(markdown.inlinepatterns.Pattern): class Tex(markdown.inlinepatterns.Pattern):
def handleMatch(self, match): def handleMatch(self, match: Match[Text]) -> Element:
# type: (Match[Text]) -> Element
rendered = render_tex(match.group('body'), is_inline=True) rendered = render_tex(match.group('body'), is_inline=True)
if rendered is not None: if rendered is not None:
return etree.fromstring(rendered.encode('utf-8')) return etree.fromstring(rendered.encode('utf-8'))
@@ -932,8 +900,7 @@ class Tex(markdown.inlinepatterns.Pattern):
return span return span
upload_title_re = re.compile("^(https?://[^/]*)?(/user_uploads/\\d+)(/[^/]*)?/[^/]*/(?P<filename>[^/]*)$") upload_title_re = re.compile("^(https?://[^/]*)?(/user_uploads/\\d+)(/[^/]*)?/[^/]*/(?P<filename>[^/]*)$")
def url_filename(url): def url_filename(url: Text) -> Text:
# type: (Text) -> Text
"""Extract the filename if a URL is an uploaded file, or return the original URL""" """Extract the filename if a URL is an uploaded file, or return the original URL"""
match = upload_title_re.match(url) match = upload_title_re.match(url)
if match: if match:
@@ -941,16 +908,14 @@ def url_filename(url):
else: else:
return url return url
def fixup_link(link, target_blank=True): def fixup_link(link: markdown.util.etree.Element, target_blank: bool=True) -> None:
# type: (markdown.util.etree.Element, bool) -> None
"""Set certain attributes we want on every link.""" """Set certain attributes we want on every link."""
if target_blank: if target_blank:
link.set('target', '_blank') link.set('target', '_blank')
link.set('title', url_filename(link.get('href'))) link.set('title', url_filename(link.get('href')))
def sanitize_url(url): def sanitize_url(url: Text) -> Optional[Text]:
# type: (Text) -> Optional[Text]
""" """
Sanitize a url against xss attacks. Sanitize a url against xss attacks.
See the docstring on markdown.inlinepatterns.LinkPattern.sanitize_url. See the docstring on markdown.inlinepatterns.LinkPattern.sanitize_url.
@@ -1004,8 +969,7 @@ def sanitize_url(url):
# Url passes all tests. Return url as-is. # Url passes all tests. Return url as-is.
return urllib.parse.urlunparse((scheme, netloc, path, params, query, fragment)) return urllib.parse.urlunparse((scheme, netloc, path, params, query, fragment))
def url_to_a(url, text = None): def url_to_a(url: Text, text: Optional[Text]=None) -> Union[Element, Text]:
# type: (Text, Optional[Text]) -> Union[Element, Text]
a = markdown.util.etree.Element('a') a = markdown.util.etree.Element('a')
href = sanitize_url(url) href = sanitize_url(url)
@@ -1032,8 +996,7 @@ def url_to_a(url, text = None):
return a return a
class VerbosePattern(markdown.inlinepatterns.Pattern): class VerbosePattern(markdown.inlinepatterns.Pattern):
def __init__(self, pattern): def __init__(self, pattern: Text) -> None:
# type: (Text) -> None
markdown.inlinepatterns.Pattern.__init__(self, ' ') markdown.inlinepatterns.Pattern.__init__(self, ' ')
# HACK: we just had python-markdown compile an empty regex. # HACK: we just had python-markdown compile an empty regex.
@@ -1044,8 +1007,7 @@ class VerbosePattern(markdown.inlinepatterns.Pattern):
re.DOTALL | re.UNICODE | re.VERBOSE) re.DOTALL | re.UNICODE | re.VERBOSE)
class AutoLink(VerbosePattern): class AutoLink(VerbosePattern):
def handleMatch(self, match): def handleMatch(self, match: Match[Text]) -> ElementStringNone:
# type: (Match[Text]) -> ElementStringNone
url = match.group('url') url = match.group('url')
return url_to_a(url) return url_to_a(url)
@@ -1058,8 +1020,7 @@ class UListProcessor(markdown.blockprocessors.UListProcessor):
TAG = 'ul' TAG = 'ul'
RE = re.compile('^[ ]{0,3}[*][ ]+(.*)') RE = re.compile('^[ ]{0,3}[*][ ]+(.*)')
def __init__(self, parser): def __init__(self, parser: Any) -> None:
# type: (Any) -> None
# HACK: Set the tab length to 2 just for the initialization of # HACK: Set the tab length to 2 just for the initialization of
# this class, so that bulleted lists (and only bulleted lists) # this class, so that bulleted lists (and only bulleted lists)
@@ -1074,8 +1035,7 @@ class ListIndentProcessor(markdown.blockprocessors.ListIndentProcessor):
Based on markdown.blockprocessors.ListIndentProcessor, but with 2-space indent Based on markdown.blockprocessors.ListIndentProcessor, but with 2-space indent
""" """
def __init__(self, parser): def __init__(self, parser: Any) -> None:
# type: (Any) -> None
# HACK: Set the tab length to 2 just for the initialization of # HACK: Set the tab length to 2 just for the initialization of
# this class, so that bulleted lists (and only bulleted lists) # this class, so that bulleted lists (and only bulleted lists)
@@ -1095,8 +1055,7 @@ class BugdownUListPreprocessor(markdown.preprocessors.Preprocessor):
LI_RE = re.compile('^[ ]{0,3}[*][ ]+(.*)', re.MULTILINE) LI_RE = re.compile('^[ ]{0,3}[*][ ]+(.*)', re.MULTILINE)
HANGING_ULIST_RE = re.compile('^.+\\n([ ]{0,3}[*][ ]+.*)', re.MULTILINE) HANGING_ULIST_RE = re.compile('^.+\\n([ ]{0,3}[*][ ]+.*)', re.MULTILINE)
def run(self, lines): def run(self, lines: List[Text]) -> List[Text]:
# type: (List[Text]) -> List[Text]
""" Insert a newline between a paragraph and ulist if missing """ """ Insert a newline between a paragraph and ulist if missing """
inserts = 0 inserts = 0
fence = None fence = None
@@ -1123,8 +1082,7 @@ class BugdownUListPreprocessor(markdown.preprocessors.Preprocessor):
class LinkPattern(markdown.inlinepatterns.Pattern): class LinkPattern(markdown.inlinepatterns.Pattern):
""" Return a link element from the given match. """ """ Return a link element from the given match. """
def handleMatch(self, m): def handleMatch(self, m: Match[Text]) -> Optional[Element]:
# type: (Match[Text]) -> Optional[Element]
href = m.group(9) href = m.group(9)
if not href: if not href:
return None return None
@@ -1141,8 +1099,7 @@ class LinkPattern(markdown.inlinepatterns.Pattern):
fixup_link(el, target_blank = (href[:1] != '#')) fixup_link(el, target_blank = (href[:1] != '#'))
return el return el
def prepare_realm_pattern(source): def prepare_realm_pattern(source: Text) -> Text:
# type: (Text) -> Text
""" Augment a realm filter so it only matches after start-of-string, """ Augment a realm filter so it only matches after start-of-string,
whitespace, or opening delimiters, won't match if there are word whitespace, or opening delimiters, won't match if there are word
characters directly after, and saves what was matched as "name". """ characters directly after, and saves what was matched as "name". """
@@ -1153,20 +1110,19 @@ def prepare_realm_pattern(source):
class RealmFilterPattern(markdown.inlinepatterns.Pattern): class RealmFilterPattern(markdown.inlinepatterns.Pattern):
""" Applied a given realm filter to the input """ """ Applied a given realm filter to the input """
def __init__(self, source_pattern, format_string, markdown_instance=None): def __init__(self, source_pattern: Text,
# type: (Text, Text, Optional[markdown.Markdown]) -> None format_string: Text,
markdown_instance: Optional[markdown.Markdown]=None) -> None:
self.pattern = prepare_realm_pattern(source_pattern) self.pattern = prepare_realm_pattern(source_pattern)
self.format_string = format_string self.format_string = format_string
markdown.inlinepatterns.Pattern.__init__(self, self.pattern, markdown_instance) markdown.inlinepatterns.Pattern.__init__(self, self.pattern, markdown_instance)
def handleMatch(self, m): def handleMatch(self, m: Match[Text]) -> Union[Element, Text]:
# type: (Match[Text]) -> Union[Element, Text]
return url_to_a(self.format_string % m.groupdict(), return url_to_a(self.format_string % m.groupdict(),
m.group("name")) m.group("name"))
class UserMentionPattern(markdown.inlinepatterns.Pattern): class UserMentionPattern(markdown.inlinepatterns.Pattern):
def handleMatch(self, m): def handleMatch(self, m: Match[Text]) -> Optional[Element]:
# type: (Match[Text]) -> Optional[Element]
match = m.group(2) match = m.group(2)
if current_message and db_data is not None: if current_message and db_data is not None:
@@ -1202,8 +1158,7 @@ class UserMentionPattern(markdown.inlinepatterns.Pattern):
return None return None
class UserGroupMentionPattern(markdown.inlinepatterns.Pattern): class UserGroupMentionPattern(markdown.inlinepatterns.Pattern):
def handleMatch(self, m): def handleMatch(self, m: Match[Text]) -> Optional[Element]:
# type: (Match[Text]) -> Optional[Element]
match = m.group(2) match = m.group(2)
if current_message and db_data is not None: if current_message and db_data is not None:
@@ -1226,15 +1181,13 @@ class UserGroupMentionPattern(markdown.inlinepatterns.Pattern):
return None return None
class StreamPattern(VerbosePattern): class StreamPattern(VerbosePattern):
def find_stream_by_name(self, name): def find_stream_by_name(self, name: Match[Text]) -> Optional[Dict[str, Any]]:
# type: (Match[Text]) -> Optional[Dict[str, Any]]
if db_data is None: if db_data is None:
return None return None
stream = db_data['stream_names'].get(name) stream = db_data['stream_names'].get(name)
return stream return stream
def handleMatch(self, m): def handleMatch(self, m: Match[Text]) -> Optional[Element]:
# type: (Match[Text]) -> Optional[Element]
name = m.group('stream_name') name = m.group('stream_name')
if current_message: if current_message:
@@ -1254,14 +1207,12 @@ class StreamPattern(VerbosePattern):
return el return el
return None return None
def possible_linked_stream_names(content): def possible_linked_stream_names(content: Text) -> Set[Text]:
# type: (Text) -> Set[Text]
matches = re.findall(STREAM_LINK_REGEX, content, re.VERBOSE) matches = re.findall(STREAM_LINK_REGEX, content, re.VERBOSE)
return set(matches) return set(matches)
class AlertWordsNotificationProcessor(markdown.preprocessors.Preprocessor): class AlertWordsNotificationProcessor(markdown.preprocessors.Preprocessor):
def run(self, lines): def run(self, lines: Iterable[Text]) -> Iterable[Text]:
# type: (Iterable[Text]) -> Iterable[Text]
if current_message and db_data is not None: if current_message and db_data is not None:
# We check for alert words here, the set of which are # We check for alert words here, the set of which are
# dependent on which users may see this message. # dependent on which users may see this message.
@@ -1292,8 +1243,7 @@ class AlertWordsNotificationProcessor(markdown.preprocessors.Preprocessor):
# Markdown link, breaking up the link. This is a monkey-patch, but it # Markdown link, breaking up the link. This is a monkey-patch, but it
# might be worth sending a version of this change upstream. # might be worth sending a version of this change upstream.
class AtomicLinkPattern(LinkPattern): class AtomicLinkPattern(LinkPattern):
def handleMatch(self, m): def handleMatch(self, m: Match[Text]) -> Optional[Element]:
# type: (Match[Text]) -> Optional[Element]
ret = LinkPattern.handleMatch(self, m) ret = LinkPattern.handleMatch(self, m)
if ret is None: if ret is None:
return None return None
@@ -1307,8 +1257,7 @@ DEFAULT_BUGDOWN_KEY = -1
ZEPHYR_MIRROR_BUGDOWN_KEY = -2 ZEPHYR_MIRROR_BUGDOWN_KEY = -2
class Bugdown(markdown.Extension): class Bugdown(markdown.Extension):
def __init__(self, *args, **kwargs): def __init__(self, *args: Any, **kwargs: Union[bool, int, List[Any]]) -> None:
# type: (*Any, **Union[bool, int, List[Any]]) -> None
# define default configs # define default configs
self.config = { self.config = {
"realm_filters": [kwargs['realm_filters'], "realm_filters": [kwargs['realm_filters'],
@@ -1320,8 +1269,7 @@ class Bugdown(markdown.Extension):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
def extendMarkdown(self, md, md_globals): def extendMarkdown(self, md: markdown.Markdown, md_globals: Dict[str, Any]) -> None:
# type: (markdown.Markdown, Dict[str, Any]) -> None
del md.preprocessors['reference'] del md.preprocessors['reference']
if self.getConfig('code_block_processor_disabled'): if self.getConfig('code_block_processor_disabled'):
@@ -1476,13 +1424,11 @@ md_engines = {} # type: Dict[Tuple[int, bool], markdown.Markdown]
realm_filter_data = {} # type: Dict[int, List[Tuple[Text, Text, int]]] realm_filter_data = {} # type: Dict[int, List[Tuple[Text, Text, int]]]
class EscapeHtml(markdown.Extension): class EscapeHtml(markdown.Extension):
def extendMarkdown(self, md, md_globals): def extendMarkdown(self, md: markdown.Markdown, md_globals: Dict[str, Any]) -> None:
# type: (markdown.Markdown, Dict[str, Any]) -> None
del md.preprocessors['html_block'] del md.preprocessors['html_block']
del md.inlinePatterns['html'] del md.inlinePatterns['html']
def make_md_engine(realm_filters_key, email_gateway): def make_md_engine(realm_filters_key: int, email_gateway: bool) -> None:
# type: (int, bool) -> None
md_engine_key = (realm_filters_key, email_gateway) md_engine_key = (realm_filters_key, email_gateway)
if md_engine_key in md_engines: if md_engine_key in md_engines:
del md_engines[md_engine_key] del md_engines[md_engine_key]
@@ -1503,8 +1449,7 @@ def make_md_engine(realm_filters_key, email_gateway):
realm=realm_filters_key, realm=realm_filters_key,
code_block_processor_disabled=email_gateway)]) code_block_processor_disabled=email_gateway)])
def subject_links(realm_filters_key, subject): def subject_links(realm_filters_key: int, subject: Text) -> List[Text]:
# type: (int, Text) -> List[Text]
matches = [] # type: List[Text] matches = [] # type: List[Text]
realm_filters = realm_filters_for_realm(realm_filters_key) realm_filters = realm_filters_for_realm(realm_filters_key)
@@ -1515,8 +1460,7 @@ def subject_links(realm_filters_key, subject):
matches += [realm_filter[1] % m.groupdict()] matches += [realm_filter[1] % m.groupdict()]
return matches return matches
def maybe_update_markdown_engines(realm_filters_key, email_gateway): def maybe_update_markdown_engines(realm_filters_key: Optional[int], email_gateway: bool) -> None:
# type: (Optional[int], bool) -> None
# If realm_filters_key is None, load all filters # If realm_filters_key is None, load all filters
global realm_filter_data global realm_filter_data
if realm_filters_key is None: if realm_filters_key is None:
@@ -1551,8 +1495,7 @@ def maybe_update_markdown_engines(realm_filters_key, email_gateway):
# We also use repr() to improve reproducibility, and to escape terminal control # We also use repr() to improve reproducibility, and to escape terminal control
# codes, which can do surprisingly nasty things. # codes, which can do surprisingly nasty things.
_privacy_re = re.compile('\\w', flags=re.UNICODE) _privacy_re = re.compile('\\w', flags=re.UNICODE)
def privacy_clean_markdown(content): def privacy_clean_markdown(content: Text) -> Text:
# type: (Text) -> Text
return repr(_privacy_re.sub('x', content)) return repr(_privacy_re.sub('x', content))
@@ -1565,16 +1508,14 @@ current_message = None # type: Optional[Message]
# threads themselves, as well. # threads themselves, as well.
db_data = None # type: Optional[Dict[Text, Any]] db_data = None # type: Optional[Dict[Text, Any]]
def log_bugdown_error(msg): def log_bugdown_error(msg: str) -> None:
# type: (str) -> None
"""We use this unusual logging approach to log the bugdown error, in """We use this unusual logging approach to log the bugdown error, in
order to prevent AdminZulipHandler from sending the santized order to prevent AdminZulipHandler from sending the santized
original markdown formatting into another Zulip message, which original markdown formatting into another Zulip message, which
could cause an infinite exception loop.""" could cause an infinite exception loop."""
logging.getLogger('').error(msg) logging.getLogger('').error(msg)
def get_email_info(realm_id, emails): def get_email_info(realm_id: int, emails: Set[Text]) -> Dict[Text, FullNameInfo]:
# type: (int, Set[Text]) -> Dict[Text, FullNameInfo]
if not emails: if not emails:
return dict() return dict()
@@ -1598,8 +1539,7 @@ def get_email_info(realm_id, emails):
} }
return dct return dct
def get_full_name_info(realm_id, full_names): def get_full_name_info(realm_id: int, full_names: Set[Text]) -> Dict[Text, FullNameInfo]:
# type: (int, Set[Text]) -> Dict[Text, FullNameInfo]
if not full_names: if not full_names:
return dict() return dict()
@@ -1626,8 +1566,7 @@ def get_full_name_info(realm_id, full_names):
return dct return dct
class MentionData: class MentionData:
def __init__(self, realm_id, content): def __init__(self, realm_id: int, content: Text) -> None:
# type: (int, Text) -> None
full_names = possible_mentions(content) full_names = possible_mentions(content)
self.full_name_info = get_full_name_info(realm_id, full_names) self.full_name_info = get_full_name_info(realm_id, full_names)
self.user_ids = { self.user_ids = {
@@ -1645,12 +1584,10 @@ class MentionData:
user_profile_id = info['user_profile_id'] user_profile_id = info['user_profile_id']
self.user_group_members[group_id].append(user_profile_id) self.user_group_members[group_id].append(user_profile_id)
def get_user(self, name): def get_user(self, name: Text) -> Optional[FullNameInfo]:
# type: (Text) -> Optional[FullNameInfo]
return self.full_name_info.get(name.lower(), None) return self.full_name_info.get(name.lower(), None)
def get_user_ids(self): def get_user_ids(self) -> Set[int]:
# type: () -> Set[int]
""" """
Returns the user IDs that might have been mentioned by this Returns the user IDs that might have been mentioned by this
content. Note that because this data structure has not parsed content. Note that because this data structure has not parsed
@@ -1659,16 +1596,13 @@ class MentionData:
""" """
return self.user_ids return self.user_ids
def get_user_group(self, name): def get_user_group(self, name: Text) -> Optional[UserGroup]:
# type: (Text) -> Optional[UserGroup]
return self.user_group_name_info.get(name.lower(), None) return self.user_group_name_info.get(name.lower(), None)
def get_group_members(self, user_group_id): def get_group_members(self, user_group_id: int) -> List[int]:
# type: (int) -> List[int]
return self.user_group_members.get(user_group_id, []) return self.user_group_members.get(user_group_id, [])
def get_user_group_name_info(realm_id, user_group_names): def get_user_group_name_info(realm_id: int, user_group_names: Set[Text]) -> Dict[Text, UserGroup]:
# type: (int, Set[Text]) -> Dict[Text, UserGroup]
if not user_group_names: if not user_group_names:
return dict() return dict()
@@ -1677,8 +1611,7 @@ def get_user_group_name_info(realm_id, user_group_names):
dct = {row.name.lower(): row for row in rows} dct = {row.name.lower(): row for row in rows}
return dct return dct
def get_stream_name_info(realm, stream_names): def get_stream_name_info(realm: Realm, stream_names: Set[Text]) -> Dict[Text, FullNameInfo]:
# type: (Realm, Set[Text]) -> Dict[Text, FullNameInfo]
if not stream_names: if not stream_names:
return dict() return dict()
@@ -1703,9 +1636,13 @@ def get_stream_name_info(realm, stream_names):
return dct return dct
def do_convert(content, message=None, message_realm=None, possible_words=None, sent_by_bot=False, def do_convert(content: Text,
mention_data=None, email_gateway=False): message: Optional[Message]=None,
# type: (Text, Optional[Message], Optional[Realm], Optional[Set[Text]], Optional[bool], Optional[MentionData], Optional[bool]) -> Text message_realm: Optional[Realm]=None,
possible_words: Optional[Set[Text]]=None,
sent_by_bot: Optional[bool]=False,
mention_data: Optional[MentionData]=None,
email_gateway: Optional[bool]=False) -> Text:
"""Convert Markdown to HTML, with Zulip-specific settings and hacks.""" """Convert Markdown to HTML, with Zulip-specific settings and hacks."""
# This logic is a bit convoluted, but the overall goal is to support a range of use cases: # This logic is a bit convoluted, but the overall goal is to support a range of use cases:
# * Nothing is passed in other than content -> just run default options (e.g. for docs) # * Nothing is passed in other than content -> just run default options (e.g. for docs)
@@ -1803,30 +1740,30 @@ bugdown_time_start = 0.0
bugdown_total_time = 0.0 bugdown_total_time = 0.0
bugdown_total_requests = 0 bugdown_total_requests = 0
def get_bugdown_time(): def get_bugdown_time() -> float:
# type: () -> float
return bugdown_total_time return bugdown_total_time
def get_bugdown_requests(): def get_bugdown_requests() -> int:
# type: () -> int
return bugdown_total_requests return bugdown_total_requests
def bugdown_stats_start(): def bugdown_stats_start() -> None:
# type: () -> None
global bugdown_time_start global bugdown_time_start
bugdown_time_start = time.time() bugdown_time_start = time.time()
def bugdown_stats_finish(): def bugdown_stats_finish() -> None:
# type: () -> None
global bugdown_total_time global bugdown_total_time
global bugdown_total_requests global bugdown_total_requests
global bugdown_time_start global bugdown_time_start
bugdown_total_requests += 1 bugdown_total_requests += 1
bugdown_total_time += (time.time() - bugdown_time_start) bugdown_total_time += (time.time() - bugdown_time_start)
def convert(content, message=None, message_realm=None, possible_words=None, sent_by_bot=False, def convert(content: Text,
mention_data=None, email_gateway=False): message: Optional[Message]=None,
# type: (Text, Optional[Message], Optional[Realm], Optional[Set[Text]], Optional[bool], Optional[MentionData], Optional[bool]) -> Text message_realm: Optional[Realm]=None,
possible_words: Optional[Set[Text]]=None,
sent_by_bot: Optional[bool]=False,
mention_data: Optional[MentionData]=None,
email_gateway: Optional[bool]=False) -> Text:
bugdown_stats_start() bugdown_stats_start()
ret = do_convert(content, message, message_realm, ret = do_convert(content, message, message_realm,
possible_words, sent_by_bot, mention_data, email_gateway) possible_words, sent_by_bot, mention_data, email_gateway)

View File

@@ -110,8 +110,7 @@ LANG_TAG = ' class="%s"'
class FencedCodeExtension(markdown.Extension): class FencedCodeExtension(markdown.Extension):
def extendMarkdown(self, md, md_globals): def extendMarkdown(self, md: markdown.Markdown, md_globals: Dict[str, Any]) -> None:
# type: (markdown.Markdown, Dict[str, Any]) -> None
""" Add FencedBlockPreprocessor to the Markdown instance. """ """ Add FencedBlockPreprocessor to the Markdown instance. """
md.registerExtension(self) md.registerExtension(self)
@@ -127,41 +126,34 @@ class FencedCodeExtension(markdown.Extension):
class FencedBlockPreprocessor(markdown.preprocessors.Preprocessor): class FencedBlockPreprocessor(markdown.preprocessors.Preprocessor):
def __init__(self, md): def __init__(self, md: markdown.Markdown) -> None:
# type: (markdown.Markdown) -> None
markdown.preprocessors.Preprocessor.__init__(self, md) markdown.preprocessors.Preprocessor.__init__(self, md)
self.checked_for_codehilite = False self.checked_for_codehilite = False
self.codehilite_conf = {} # type: Dict[str, List[Any]] self.codehilite_conf = {} # type: Dict[str, List[Any]]
def run(self, lines): def run(self, lines: Iterable[Text]) -> List[Text]:
# type: (Iterable[Text]) -> List[Text]
""" Match and store Fenced Code Blocks in the HtmlStash. """ """ Match and store Fenced Code Blocks in the HtmlStash. """
output = [] # type: List[Text] output = [] # type: List[Text]
class BaseHandler: class BaseHandler:
def handle_line(self, line): def handle_line(self, line: Text) -> None:
# type: (Text) -> None
raise NotImplementedError() raise NotImplementedError()
def done(self): def done(self) -> None:
# type: () -> None
raise NotImplementedError() raise NotImplementedError()
processor = self processor = self
handlers = [] # type: List[BaseHandler] handlers = [] # type: List[BaseHandler]
def push(handler): def push(handler: BaseHandler) -> None:
# type: (BaseHandler) -> None
handlers.append(handler) handlers.append(handler)
def pop(): def pop() -> None:
# type: () -> None
handlers.pop() handlers.pop()
def check_for_new_fence(output, line): def check_for_new_fence(output: MutableSequence[Text], line: Text) -> None:
# type: (MutableSequence[Text], Text) -> None
m = FENCE_RE.match(line) m = FENCE_RE.match(line)
if m: if m:
fence = m.group('fence') fence = m.group('fence')
@@ -172,20 +164,16 @@ class FencedBlockPreprocessor(markdown.preprocessors.Preprocessor):
output.append(line) output.append(line)
class OuterHandler(BaseHandler): class OuterHandler(BaseHandler):
def __init__(self, output): def __init__(self, output: MutableSequence[Text]) -> None:
# type: (MutableSequence[Text]) -> None
self.output = output self.output = output
def handle_line(self, line): def handle_line(self, line: Text) -> None:
# type: (Text) -> None
check_for_new_fence(self.output, line) check_for_new_fence(self.output, line)
def done(self): def done(self) -> None:
# type: () -> None
pop() pop()
def generic_handler(output, fence, lang): def generic_handler(output: MutableSequence[Text], fence: Text, lang: Text) -> BaseHandler:
# type: (MutableSequence[Text], Text, Text) -> BaseHandler
if lang in ('quote', 'quoted'): if lang in ('quote', 'quoted'):
return QuoteHandler(output, fence) return QuoteHandler(output, fence)
elif lang in ('math', 'tex', 'latex'): elif lang in ('math', 'tex', 'latex'):
@@ -194,22 +182,19 @@ class FencedBlockPreprocessor(markdown.preprocessors.Preprocessor):
return CodeHandler(output, fence, lang) return CodeHandler(output, fence, lang)
class CodeHandler(BaseHandler): class CodeHandler(BaseHandler):
def __init__(self, output, fence, lang): def __init__(self, output: MutableSequence[Text], fence: Text, lang: Text) -> None:
# type: (MutableSequence[Text], Text, Text) -> None
self.output = output self.output = output
self.fence = fence self.fence = fence
self.lang = lang self.lang = lang
self.lines = [] # type: List[Text] self.lines = [] # type: List[Text]
def handle_line(self, line): def handle_line(self, line: Text) -> None:
# type: (Text) -> None
if line.rstrip() == self.fence: if line.rstrip() == self.fence:
self.done() self.done()
else: else:
self.lines.append(line.rstrip()) self.lines.append(line.rstrip())
def done(self): def done(self) -> None:
# type: () -> None
text = '\n'.join(self.lines) text = '\n'.join(self.lines)
text = processor.format_code(self.lang, text) text = processor.format_code(self.lang, text)
text = processor.placeholder(text) text = processor.placeholder(text)
@@ -220,21 +205,18 @@ class FencedBlockPreprocessor(markdown.preprocessors.Preprocessor):
pop() pop()
class QuoteHandler(BaseHandler): class QuoteHandler(BaseHandler):
def __init__(self, output, fence): def __init__(self, output: MutableSequence[Text], fence: Text) -> None:
# type: (MutableSequence[Text], Text) -> None
self.output = output self.output = output
self.fence = fence self.fence = fence
self.lines = [] # type: List[Text] self.lines = [] # type: List[Text]
def handle_line(self, line): def handle_line(self, line: Text) -> None:
# type: (Text) -> None
if line.rstrip() == self.fence: if line.rstrip() == self.fence:
self.done() self.done()
else: else:
check_for_new_fence(self.lines, line) check_for_new_fence(self.lines, line)
def done(self): def done(self) -> None:
# type: () -> None
text = '\n'.join(self.lines) text = '\n'.join(self.lines)
text = processor.format_quote(text) text = processor.format_quote(text)
processed_lines = text.split('\n') processed_lines = text.split('\n')
@@ -244,21 +226,18 @@ class FencedBlockPreprocessor(markdown.preprocessors.Preprocessor):
pop() pop()
class TexHandler(BaseHandler): class TexHandler(BaseHandler):
def __init__(self, output, fence): def __init__(self, output: MutableSequence[Text], fence: Text) -> None:
# type: (MutableSequence[Text], Text) -> None
self.output = output self.output = output
self.fence = fence self.fence = fence
self.lines = [] # type: List[Text] self.lines = [] # type: List[Text]
def handle_line(self, line): def handle_line(self, line: Text) -> None:
# type: (Text) -> None
if line.rstrip() == self.fence: if line.rstrip() == self.fence:
self.done() self.done()
else: else:
self.lines.append(line) self.lines.append(line)
def done(self): def done(self) -> None:
# type: () -> None
text = '\n'.join(self.lines) text = '\n'.join(self.lines)
text = processor.format_tex(text) text = processor.format_tex(text)
text = processor.placeholder(text) text = processor.placeholder(text)
@@ -284,8 +263,7 @@ class FencedBlockPreprocessor(markdown.preprocessors.Preprocessor):
output.append('') output.append('')
return output return output
def format_code(self, lang, text): def format_code(self, lang: Text, text: Text) -> Text:
# type: (Text, Text) -> Text
if lang: if lang:
langclass = LANG_TAG % (lang,) langclass = LANG_TAG % (lang,)
else: else:
@@ -318,8 +296,7 @@ class FencedBlockPreprocessor(markdown.preprocessors.Preprocessor):
return code return code
def format_quote(self, text): def format_quote(self, text: Text) -> Text:
# type: (Text) -> Text
paragraphs = text.split("\n\n") paragraphs = text.split("\n\n")
quoted_paragraphs = [] quoted_paragraphs = []
for paragraph in paragraphs: for paragraph in paragraphs:
@@ -327,8 +304,7 @@ class FencedBlockPreprocessor(markdown.preprocessors.Preprocessor):
quoted_paragraphs.append("\n".join("> " + line for line in lines if line != '')) quoted_paragraphs.append("\n".join("> " + line for line in lines if line != ''))
return "\n\n".join(quoted_paragraphs) return "\n\n".join(quoted_paragraphs)
def format_tex(self, text): def format_tex(self, text: Text) -> Text:
# type: (Text) -> Text
paragraphs = text.split("\n\n") paragraphs = text.split("\n\n")
tex_paragraphs = [] tex_paragraphs = []
for paragraph in paragraphs: for paragraph in paragraphs:
@@ -340,12 +316,10 @@ class FencedBlockPreprocessor(markdown.preprocessors.Preprocessor):
escape(paragraph) + '</span>') escape(paragraph) + '</span>')
return "\n\n".join(tex_paragraphs) return "\n\n".join(tex_paragraphs)
def placeholder(self, code): def placeholder(self, code: Text) -> Text:
# type: (Text) -> Text
return self.markdown.htmlStash.store(code, safe=True) return self.markdown.htmlStash.store(code, safe=True)
def _escape(self, txt): def _escape(self, txt: Text) -> Text:
# type: (Text) -> Text
""" basic html escaping """ """ basic html escaping """
txt = txt.replace('&', '&amp;') txt = txt.replace('&', '&amp;')
txt = txt.replace('<', '&lt;') txt = txt.replace('<', '&lt;')
@@ -354,8 +328,7 @@ class FencedBlockPreprocessor(markdown.preprocessors.Preprocessor):
return txt return txt
def makeExtension(*args, **kwargs): def makeExtension(*args: Any, **kwargs: None) -> FencedCodeExtension:
# type: (*Any, **Union[bool, None, Text]) -> FencedCodeExtension
return FencedCodeExtension(*args, **kwargs) return FencedCodeExtension(*args, **kwargs)
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -33,29 +33,24 @@ remote_cache_time_start = 0.0
remote_cache_total_time = 0.0 remote_cache_total_time = 0.0
remote_cache_total_requests = 0 remote_cache_total_requests = 0
def get_remote_cache_time(): def get_remote_cache_time() -> float:
# type: () -> float
return remote_cache_total_time return remote_cache_total_time
def get_remote_cache_requests(): def get_remote_cache_requests() -> int:
# type: () -> int
return remote_cache_total_requests return remote_cache_total_requests
def remote_cache_stats_start(): def remote_cache_stats_start() -> None:
# type: () -> None
global remote_cache_time_start global remote_cache_time_start
remote_cache_time_start = time.time() remote_cache_time_start = time.time()
def remote_cache_stats_finish(): def remote_cache_stats_finish() -> None:
# type: () -> None
global remote_cache_total_time global remote_cache_total_time
global remote_cache_total_requests global remote_cache_total_requests
global remote_cache_time_start global remote_cache_time_start
remote_cache_total_requests += 1 remote_cache_total_requests += 1
remote_cache_total_time += (time.time() - remote_cache_time_start) remote_cache_total_time += (time.time() - remote_cache_time_start)
def get_or_create_key_prefix(): def get_or_create_key_prefix() -> Text:
# type: () -> Text
if settings.CASPER_TESTS: if settings.CASPER_TESTS:
# This sets the prefix for the benefit of the Casper tests. # This sets the prefix for the benefit of the Casper tests.
# #
@@ -99,32 +94,27 @@ def get_or_create_key_prefix():
KEY_PREFIX = get_or_create_key_prefix() # type: Text KEY_PREFIX = get_or_create_key_prefix() # type: Text
def bounce_key_prefix_for_testing(test_name): def bounce_key_prefix_for_testing(test_name: Text) -> None:
# type: (Text) -> None
global KEY_PREFIX global KEY_PREFIX
KEY_PREFIX = test_name + ':' + Text(os.getpid()) + ':' KEY_PREFIX = test_name + ':' + Text(os.getpid()) + ':'
# We are taking the hash of the KEY_PREFIX to decrease the size of the key. # We are taking the hash of the KEY_PREFIX to decrease the size of the key.
# Memcached keys should have a length of less than 256. # Memcached keys should have a length of less than 256.
KEY_PREFIX = hashlib.sha1(KEY_PREFIX.encode('utf-8')).hexdigest() KEY_PREFIX = hashlib.sha1(KEY_PREFIX.encode('utf-8')).hexdigest()
def get_cache_backend(cache_name): def get_cache_backend(cache_name: Optional[str]) -> BaseCache:
# type: (Optional[str]) -> BaseCache
if cache_name is None: if cache_name is None:
return djcache return djcache
return caches[cache_name] return caches[cache_name]
def get_cache_with_key(keyfunc, cache_name=None): def get_cache_with_key(keyfunc: Any, cache_name: Optional[str]=None) -> Any:
# type: (Any, Optional[str]) -> Any
""" """
The main goal of this function getting value from the cache like in the "cache_with_key". The main goal of this function getting value from the cache like in the "cache_with_key".
A cache value can contain any data including the "None", so A cache value can contain any data including the "None", so
here used exception for case if value isn't found in the cache. here used exception for case if value isn't found in the cache.
""" """
def decorator(func): def decorator(func: Callable[..., Any]) -> (Callable[..., Any]):
# type: (Callable[..., Any]) -> (Callable[..., Any])
@wraps(func) @wraps(func)
def func_with_caching(*args, **kwargs): def func_with_caching(*args: Any, **kwargs: Any) -> Callable[..., Any]:
# type: (*Any, **Any) -> Callable[..., Any]
key = keyfunc(*args, **kwargs) key = keyfunc(*args, **kwargs)
val = cache_get(key, cache_name=cache_name) val = cache_get(key, cache_name=cache_name)
if val is not None: if val is not None:
@@ -144,11 +134,9 @@ def cache_with_key(keyfunc, cache_name=None, timeout=None, with_statsd_key=None)
for avoiding collisions with other uses of this decorator or for avoiding collisions with other uses of this decorator or
other uses of caching.""" other uses of caching."""
def decorator(func): def decorator(func: Callable[..., ReturnT]) -> Callable[..., ReturnT]:
# type: (Callable[..., ReturnT]) -> Callable[..., ReturnT]
@wraps(func) @wraps(func)
def func_with_caching(*args, **kwargs): def func_with_caching(*args: Any, **kwargs: Any) -> ReturnT:
# type: (*Any, **Any) -> ReturnT
key = keyfunc(*args, **kwargs) key = keyfunc(*args, **kwargs)
val = cache_get(key, cache_name=cache_name) val = cache_get(key, cache_name=cache_name)
@@ -180,31 +168,28 @@ def cache_with_key(keyfunc, cache_name=None, timeout=None, with_statsd_key=None)
return decorator return decorator
def cache_set(key, val, cache_name=None, timeout=None): def cache_set(key: Text, val: Any, cache_name: Optional[str]=None, timeout: Optional[int]=None) -> None:
# type: (Text, Any, Optional[str], Optional[int]) -> None
remote_cache_stats_start() remote_cache_stats_start()
cache_backend = get_cache_backend(cache_name) cache_backend = get_cache_backend(cache_name)
cache_backend.set(KEY_PREFIX + key, (val,), timeout=timeout) cache_backend.set(KEY_PREFIX + key, (val,), timeout=timeout)
remote_cache_stats_finish() remote_cache_stats_finish()
def cache_get(key, cache_name=None): def cache_get(key: Text, cache_name: Optional[str]=None) -> Any:
# type: (Text, Optional[str]) -> Any
remote_cache_stats_start() remote_cache_stats_start()
cache_backend = get_cache_backend(cache_name) cache_backend = get_cache_backend(cache_name)
ret = cache_backend.get(KEY_PREFIX + key) ret = cache_backend.get(KEY_PREFIX + key)
remote_cache_stats_finish() remote_cache_stats_finish()
return ret return ret
def cache_get_many(keys, cache_name=None): def cache_get_many(keys: List[Text], cache_name: Optional[str]=None) -> Dict[Text, Any]:
# type: (List[Text], Optional[str]) -> Dict[Text, Any]
keys = [KEY_PREFIX + key for key in keys] keys = [KEY_PREFIX + key for key in keys]
remote_cache_stats_start() remote_cache_stats_start()
ret = get_cache_backend(cache_name).get_many(keys) ret = get_cache_backend(cache_name).get_many(keys)
remote_cache_stats_finish() remote_cache_stats_finish()
return dict([(key[len(KEY_PREFIX):], value) for key, value in ret.items()]) return dict([(key[len(KEY_PREFIX):], value) for key, value in ret.items()])
def cache_set_many(items, cache_name=None, timeout=None): def cache_set_many(items: Dict[Text, Any], cache_name: Optional[str]=None,
# type: (Dict[Text, Any], Optional[str], Optional[int]) -> None timeout: Optional[int]=None) -> None:
new_items = {} new_items = {}
for key in items: for key in items:
new_items[KEY_PREFIX + key] = items[key] new_items[KEY_PREFIX + key] = items[key]
@@ -213,14 +198,12 @@ def cache_set_many(items, cache_name=None, timeout=None):
get_cache_backend(cache_name).set_many(items, timeout=timeout) get_cache_backend(cache_name).set_many(items, timeout=timeout)
remote_cache_stats_finish() remote_cache_stats_finish()
def cache_delete(key, cache_name=None): def cache_delete(key: Text, cache_name: Optional[str]=None) -> None:
# type: (Text, Optional[str]) -> None
remote_cache_stats_start() remote_cache_stats_start()
get_cache_backend(cache_name).delete(KEY_PREFIX + key) get_cache_backend(cache_name).delete(KEY_PREFIX + key)
remote_cache_stats_finish() remote_cache_stats_finish()
def cache_delete_many(items, cache_name=None): def cache_delete_many(items: Iterable[Text], cache_name: Optional[str]=None) -> None:
# type: (Iterable[Text], Optional[str]) -> None
remote_cache_stats_start() remote_cache_stats_start()
get_cache_backend(cache_name).delete_many( get_cache_backend(cache_name).delete_many(
KEY_PREFIX + item for item in items) KEY_PREFIX + item for item in items)
@@ -289,8 +272,7 @@ def generic_bulk_cached_fetch(
return dict((object_id, cached_objects[cache_keys[object_id]]) for object_id in object_ids return dict((object_id, cached_objects[cache_keys[object_id]]) for object_id in object_ids
if cache_keys[object_id] in cached_objects) if cache_keys[object_id] in cached_objects)
def cache(func): def cache(func: Callable[..., ReturnT]) -> Callable[..., ReturnT]:
# type: (Callable[..., ReturnT]) -> Callable[..., ReturnT]
"""Decorator which applies Django caching to a function. """Decorator which applies Django caching to a function.
Uses a key based on the function's name, filename, and Uses a key based on the function's name, filename, and
@@ -299,20 +281,17 @@ def cache(func):
func_uniqifier = '%s-%s' % (func.__code__.co_filename, func.__name__) func_uniqifier = '%s-%s' % (func.__code__.co_filename, func.__name__)
@wraps(func) @wraps(func)
def keyfunc(*args, **kwargs): def keyfunc(*args: Any, **kwargs: Any) -> str:
# type: (*Any, **Any) -> str
# Django complains about spaces because memcached rejects them # Django complains about spaces because memcached rejects them
key = func_uniqifier + repr((args, kwargs)) key = func_uniqifier + repr((args, kwargs))
return key.replace('-', '--').replace(' ', '-s') return key.replace('-', '--').replace(' ', '-s')
return cache_with_key(keyfunc)(func) return cache_with_key(keyfunc)(func)
def display_recipient_cache_key(recipient_id): def display_recipient_cache_key(recipient_id: int) -> Text:
# type: (int) -> Text
return u"display_recipient_dict:%d" % (recipient_id,) return u"display_recipient_dict:%d" % (recipient_id,)
def user_profile_by_email_cache_key(email): def user_profile_by_email_cache_key(email: Text) -> Text:
# type: (Text) -> Text
# See the comment in zerver/lib/avatar_hash.py:gravatar_hash for why we # See the comment in zerver/lib/avatar_hash.py:gravatar_hash for why we
# are proactively encoding email addresses even though they will # are proactively encoding email addresses even though they will
# with high likelihood be ASCII-only for the foreseeable future. # with high likelihood be ASCII-only for the foreseeable future.
@@ -326,16 +305,13 @@ def user_profile_cache_key(email, realm):
# type: (Text, Realm) -> Text # type: (Text, Realm) -> Text
return user_profile_cache_key_id(email, realm.id) return user_profile_cache_key_id(email, realm.id)
def bot_profile_cache_key(email): def bot_profile_cache_key(email: Text) -> Text:
# type: (Text) -> Text
return u"bot_profile:%s" % (make_safe_digest(email.strip())) return u"bot_profile:%s" % (make_safe_digest(email.strip()))
def user_profile_by_id_cache_key(user_profile_id): def user_profile_by_id_cache_key(user_profile_id: int) -> Text:
# type: (int) -> Text
return u"user_profile_by_id:%s" % (user_profile_id,) return u"user_profile_by_id:%s" % (user_profile_id,)
def user_profile_by_api_key_cache_key(api_key): def user_profile_by_api_key_cache_key(api_key: Text) -> Text:
# type: (Text) -> Text
return u"user_profile_by_api_key:%s" % (api_key,) return u"user_profile_by_api_key:%s" % (api_key,)
# TODO: Refactor these cache helpers into another file that can import # TODO: Refactor these cache helpers into another file that can import
@@ -346,12 +322,10 @@ realm_user_dict_fields = [
'avatar_source', 'avatar_version', 'is_active', 'avatar_source', 'avatar_version', 'is_active',
'is_realm_admin', 'is_bot', 'realm_id', 'timezone'] # type: List[str] 'is_realm_admin', 'is_bot', 'realm_id', 'timezone'] # type: List[str]
def realm_user_dicts_cache_key(realm_id): def realm_user_dicts_cache_key(realm_id: int) -> Text:
# type: (int) -> Text
return u"realm_user_dicts:%s" % (realm_id,) return u"realm_user_dicts:%s" % (realm_id,)
def active_user_ids_cache_key(realm_id): def active_user_ids_cache_key(realm_id: int) -> Text:
# type: (int) -> Text
return u"active_user_ids:%s" % (realm_id,) return u"active_user_ids:%s" % (realm_id,)
bot_dict_fields = ['id', 'full_name', 'short_name', 'bot_type', 'email', bot_dict_fields = ['id', 'full_name', 'short_name', 'bot_type', 'email',
@@ -366,8 +340,7 @@ def bot_dicts_in_realm_cache_key(realm):
# type: (Realm) -> Text # type: (Realm) -> Text
return u"bot_dicts_in_realm:%s" % (realm.id,) return u"bot_dicts_in_realm:%s" % (realm.id,)
def get_stream_cache_key(stream_name, realm_id): def get_stream_cache_key(stream_name: Text, realm_id: int) -> Text:
# type: (Text, int) -> Text
return u"stream_by_realm_and_name:%s:%s" % ( return u"stream_by_realm_and_name:%s:%s" % (
realm_id, make_safe_digest(stream_name.strip().lower())) realm_id, make_safe_digest(stream_name.strip().lower()))
@@ -392,13 +365,11 @@ def delete_display_recipient_cache(user_profile):
# Called by models.py to flush the user_profile cache whenever we save # Called by models.py to flush the user_profile cache whenever we save
# a user_profile object # a user_profile object
def flush_user_profile(sender, **kwargs): def flush_user_profile(sender: Any, **kwargs: Any) -> None:
# type: (Any, **Any) -> None
user_profile = kwargs['instance'] user_profile = kwargs['instance']
delete_user_profile_caches([user_profile]) delete_user_profile_caches([user_profile])
def changed(fields): def changed(fields: List[str]) -> bool:
# type: (List[str]) -> bool
if kwargs.get('update_fields') is None: if kwargs.get('update_fields') is None:
# adds/deletes should invalidate the cache # adds/deletes should invalidate the cache
return True return True
@@ -434,8 +405,7 @@ def flush_user_profile(sender, **kwargs):
# Called by models.py to flush various caches whenever we save # Called by models.py to flush various caches whenever we save
# a Realm object. The main tricky thing here is that Realm info is # a Realm object. The main tricky thing here is that Realm info is
# generally cached indirectly through user_profile objects. # generally cached indirectly through user_profile objects.
def flush_realm(sender, **kwargs): def flush_realm(sender: Any, **kwargs: Any) -> None:
# type: (Any, **Any) -> None
realm = kwargs['instance'] realm = kwargs['instance']
users = realm.get_active_users() users = realm.get_active_users()
delete_user_profile_caches(users) delete_user_profile_caches(users)
@@ -452,8 +422,7 @@ def realm_alert_words_cache_key(realm):
# Called by models.py to flush the stream cache whenever we save a stream # Called by models.py to flush the stream cache whenever we save a stream
# object. # object.
def flush_stream(sender, **kwargs): def flush_stream(sender: Any, **kwargs: Any) -> None:
# type: (Any, **Any) -> None
from zerver.models import UserProfile from zerver.models import UserProfile
stream = kwargs['instance'] stream = kwargs['instance']
items_for_remote_cache = {} items_for_remote_cache = {}
@@ -466,15 +435,13 @@ def flush_stream(sender, **kwargs):
Q(default_events_register_stream=stream)).exists(): Q(default_events_register_stream=stream)).exists():
cache_delete(bot_dicts_in_realm_cache_key(stream.realm)) cache_delete(bot_dicts_in_realm_cache_key(stream.realm))
def to_dict_cache_key_id(message_id): def to_dict_cache_key_id(message_id: int) -> Text:
# type: (int) -> Text
return 'message_dict:%d' % (message_id,) return 'message_dict:%d' % (message_id,)
def to_dict_cache_key(message): def to_dict_cache_key(message):
# type: (Message) -> Text # type: (Message) -> Text
return to_dict_cache_key_id(message.id) return to_dict_cache_key_id(message.id)
def flush_message(sender, **kwargs): def flush_message(sender: Any, **kwargs: Any) -> None:
# type: (Any, **Any) -> None
message = kwargs['instance'] message = kwargs['instance']
cache_delete(to_dict_cache_key_id(message.id)) cache_delete(to_dict_cache_key_id(message.id))

View File

@@ -32,8 +32,7 @@ import ujson
import urllib import urllib
from collections import defaultdict from collections import defaultdict
def one_click_unsubscribe_link(user_profile, email_type): def one_click_unsubscribe_link(user_profile: UserProfile, email_type: str) -> str:
# type: (UserProfile, str) -> str
""" """
Generate a unique link that a logged-out user can visit to unsubscribe from Generate a unique link that a logged-out user can visit to unsubscribe from
Zulip e-mails without having to first log in. Zulip e-mails without having to first log in.
@@ -42,33 +41,28 @@ def one_click_unsubscribe_link(user_profile, email_type):
Confirmation.UNSUBSCRIBE, Confirmation.UNSUBSCRIBE,
url_args = {'email_type': email_type}) url_args = {'email_type': email_type})
def hash_util_encode(string): def hash_util_encode(string: Text) -> Text:
# type: (Text) -> Text
# Do the same encoding operation as hash_util.encodeHashComponent on the # Do the same encoding operation as hash_util.encodeHashComponent on the
# frontend. # frontend.
# `safe` has a default value of "/", but we want those encoded, too. # `safe` has a default value of "/", but we want those encoded, too.
return urllib.parse.quote( return urllib.parse.quote(
string.encode("utf-8"), safe=b"").replace(".", "%2E").replace("%", ".") string.encode("utf-8"), safe=b"").replace(".", "%2E").replace("%", ".")
def pm_narrow_url(realm, participants): def pm_narrow_url(realm: Realm, participants: List[Text]) -> Text:
# type: (Realm, List[Text]) -> Text
participants.sort() participants.sort()
base_url = u"%s/#narrow/pm-with/" % (realm.uri,) base_url = u"%s/#narrow/pm-with/" % (realm.uri,)
return base_url + hash_util_encode(",".join(participants)) return base_url + hash_util_encode(",".join(participants))
def stream_narrow_url(realm, stream): def stream_narrow_url(realm: Realm, stream: Text) -> Text:
# type: (Realm, Text) -> Text
base_url = u"%s/#narrow/stream/" % (realm.uri,) base_url = u"%s/#narrow/stream/" % (realm.uri,)
return base_url + hash_util_encode(stream) return base_url + hash_util_encode(stream)
def topic_narrow_url(realm, stream, topic): def topic_narrow_url(realm: Realm, stream: Text, topic: Text) -> Text:
# type: (Realm, Text, Text) -> Text
base_url = u"%s/#narrow/stream/" % (realm.uri,) base_url = u"%s/#narrow/stream/" % (realm.uri,)
return u"%s%s/topic/%s" % (base_url, hash_util_encode(stream), return u"%s%s/topic/%s" % (base_url, hash_util_encode(stream),
hash_util_encode(topic)) hash_util_encode(topic))
def relative_to_full_url(base_url, content): def relative_to_full_url(base_url: Text, content: Text) -> Text:
# type: (Text, Text) -> Text
# Convert relative URLs to absolute URLs. # Convert relative URLs to absolute URLs.
fragment = lxml.html.fromstring(content) fragment = lxml.html.fromstring(content)
@@ -101,10 +95,8 @@ def relative_to_full_url(base_url, content):
return content return content
def fix_emojis(content, base_url, emojiset): def fix_emojis(content: Text, base_url: Text, emojiset: Text) -> Text:
# type: (Text, Text, Text) -> Text def make_emoji_img_elem(emoji_span_elem: Any) -> Dict[str, Any]:
def make_emoji_img_elem(emoji_span_elem):
# type: (Any) -> Dict[str, Any]
# Convert the emoji spans to img tags. # Convert the emoji spans to img tags.
classes = emoji_span_elem.get('class') classes = emoji_span_elem.get('class')
match = re.search('emoji-(?P<emoji_code>\S+)', classes) match = re.search('emoji-(?P<emoji_code>\S+)', classes)
@@ -138,8 +130,7 @@ def fix_emojis(content, base_url, emojiset):
content = lxml.html.tostring(fragment).decode('utf-8') content = lxml.html.tostring(fragment).decode('utf-8')
return content return content
def build_message_list(user_profile, messages): def build_message_list(user_profile: UserProfile, messages: List[Message]) -> List[Dict[str, Any]]:
# type: (UserProfile, List[Message]) -> List[Dict[str, Any]]
""" """
Builds the message list object for the missed message email template. Builds the message list object for the missed message email template.
The messages are collapsed into per-recipient and per-sender blocks, like The messages are collapsed into per-recipient and per-sender blocks, like
@@ -147,22 +138,19 @@ def build_message_list(user_profile, messages):
""" """
messages_to_render = [] # type: List[Dict[str, Any]] messages_to_render = [] # type: List[Dict[str, Any]]
def sender_string(message): def sender_string(message: Message) -> Text:
# type: (Message) -> Text
if message.recipient.type in (Recipient.STREAM, Recipient.HUDDLE): if message.recipient.type in (Recipient.STREAM, Recipient.HUDDLE):
return message.sender.full_name return message.sender.full_name
else: else:
return '' return ''
def fix_plaintext_image_urls(content): def fix_plaintext_image_urls(content: Text) -> Text:
# type: (Text) -> Text
# Replace image URLs in plaintext content of the form # Replace image URLs in plaintext content of the form
# [image name](image url) # [image name](image url)
# with a simple hyperlink. # with a simple hyperlink.
return re.sub(r"\[(\S*)\]\((\S*)\)", r"\2", content) return re.sub(r"\[(\S*)\]\((\S*)\)", r"\2", content)
def build_message_payload(message): def build_message_payload(message: Message) -> Dict[str, Text]:
# type: (Message) -> Dict[str, Text]
plain = message.content plain = message.content
plain = fix_plaintext_image_urls(plain) plain = fix_plaintext_image_urls(plain)
# There's a small chance of colliding with non-Zulip URLs containing # There's a small chance of colliding with non-Zulip URLs containing
@@ -181,14 +169,12 @@ def build_message_list(user_profile, messages):
return {'plain': plain, 'html': html} return {'plain': plain, 'html': html}
def build_sender_payload(message): def build_sender_payload(message: Message) -> Dict[str, Any]:
# type: (Message) -> Dict[str, Any]
sender = sender_string(message) sender = sender_string(message)
return {'sender': sender, return {'sender': sender,
'content': [build_message_payload(message)]} 'content': [build_message_payload(message)]}
def message_header(user_profile, message): def message_header(user_profile: UserProfile, message: Message) -> Dict[str, Any]:
# type: (UserProfile, Message) -> Dict[str, Any]
disp_recipient = get_display_recipient(message.recipient) disp_recipient = get_display_recipient(message.recipient)
if message.recipient.type == Recipient.PERSONAL: if message.recipient.type == Recipient.PERSONAL:
header = u"You and %s" % (message.sender.full_name,) header = u"You and %s" % (message.sender.full_name,)
@@ -264,8 +250,9 @@ def build_message_list(user_profile, messages):
return messages_to_render return messages_to_render
@statsd_increment("missed_message_reminders") @statsd_increment("missed_message_reminders")
def do_send_missedmessage_events_reply_in_zulip(user_profile, missed_messages, message_count): def do_send_missedmessage_events_reply_in_zulip(user_profile: UserProfile,
# type: (UserProfile, List[Message], int) -> None missed_messages: List[Message],
message_count: int) -> None:
""" """
Send a reminder email to a user if she's missed some PMs by being offline. Send a reminder email to a user if she's missed some PMs by being offline.
@@ -384,8 +371,7 @@ def do_send_missedmessage_events_reply_in_zulip(user_profile, missed_messages, m
user_profile.last_reminder = timezone_now() user_profile.last_reminder = timezone_now()
user_profile.save(update_fields=['last_reminder']) user_profile.save(update_fields=['last_reminder'])
def handle_missedmessage_emails(user_profile_id, missed_email_events): def handle_missedmessage_emails(user_profile_id: int, missed_email_events: Iterable[Dict[str, Any]]) -> None:
# type: (int, Iterable[Dict[str, Any]]) -> None
message_ids = [event.get('message_id') for event in missed_email_events] message_ids = [event.get('message_id') for event in missed_email_events]
user_profile = get_user_profile_by_id(user_profile_id) user_profile = get_user_profile_by_id(user_profile_id)
@@ -429,29 +415,25 @@ def handle_missedmessage_emails(user_profile_id, missed_email_events):
message_count_by_recipient_subject[recipient_subject], message_count_by_recipient_subject[recipient_subject],
) )
def clear_scheduled_invitation_emails(email): def clear_scheduled_invitation_emails(email: str) -> None:
# type: (str) -> None
"""Unlike most scheduled emails, invitation emails don't have an """Unlike most scheduled emails, invitation emails don't have an
existing user object to key off of, so we filter by address here.""" existing user object to key off of, so we filter by address here."""
items = ScheduledEmail.objects.filter(address__iexact=email, items = ScheduledEmail.objects.filter(address__iexact=email,
type=ScheduledEmail.INVITATION_REMINDER) type=ScheduledEmail.INVITATION_REMINDER)
items.delete() items.delete()
def clear_scheduled_emails(user_id, email_type=None): def clear_scheduled_emails(user_id: int, email_type: Optional[int]=None) -> None:
# type: (int, Optional[int]) -> None
items = ScheduledEmail.objects.filter(user_id=user_id) items = ScheduledEmail.objects.filter(user_id=user_id)
if email_type is not None: if email_type is not None:
items = items.filter(type=email_type) items = items.filter(type=email_type)
items.delete() items.delete()
def log_digest_event(msg): def log_digest_event(msg: Text) -> None:
# type: (Text) -> None
import logging import logging
logging.basicConfig(filename=settings.DIGEST_LOG_PATH, level=logging.INFO) logging.basicConfig(filename=settings.DIGEST_LOG_PATH, level=logging.INFO)
logging.info(msg) logging.info(msg)
def enqueue_welcome_emails(user): def enqueue_welcome_emails(user: UserProfile) -> None:
# type: (UserProfile) -> None
from zerver.context_processors import common_context from zerver.context_processors import common_context
if settings.WELCOME_EMAIL_SENDER is not None: if settings.WELCOME_EMAIL_SENDER is not None:
# line break to avoid triggering lint rule # line break to avoid triggering lint rule
@@ -476,8 +458,7 @@ def enqueue_welcome_emails(user):
"zerver/emails/followup_day2", to_user_id=user.id, from_name=from_name, "zerver/emails/followup_day2", to_user_id=user.id, from_name=from_name,
from_address=from_address, context=context, delay=datetime.timedelta(days=1)) from_address=from_address, context=context, delay=datetime.timedelta(days=1))
def convert_html_to_markdown(html): def convert_html_to_markdown(html: Text) -> Text:
# type: (Text) -> Text
# On Linux, the tool installs as html2markdown, and there's a command called # On Linux, the tool installs as html2markdown, and there's a command called
# html2text that does something totally different. On OSX, the tool installs # html2text that does something totally different. On OSX, the tool installs
# as html2text. # as html2text.

View File

@@ -44,12 +44,10 @@ else: # nocoverage -- Not convenient to add test for this.
DeviceToken = Union[PushDeviceToken, RemotePushDeviceToken] DeviceToken = Union[PushDeviceToken, RemotePushDeviceToken]
# We store the token as b64, but apns-client wants hex strings # We store the token as b64, but apns-client wants hex strings
def b64_to_hex(data): def b64_to_hex(data: bytes) -> Text:
# type: (bytes) -> Text
return binascii.hexlify(base64.b64decode(data)).decode('utf-8') return binascii.hexlify(base64.b64decode(data)).decode('utf-8')
def hex_to_b64(data): def hex_to_b64(data: Text) -> bytes:
# type: (Text) -> bytes
return base64.b64encode(binascii.unhexlify(data.encode('utf-8'))) return base64.b64encode(binascii.unhexlify(data.encode('utf-8')))
# #
@@ -58,8 +56,7 @@ def hex_to_b64(data):
_apns_client = None # type: APNsClient _apns_client = None # type: APNsClient
def get_apns_client(): def get_apns_client() -> APNsClient:
# type: () -> APNsClient
global _apns_client global _apns_client
if _apns_client is None: if _apns_client is None:
# NB if called concurrently, this will make excess connections. # NB if called concurrently, this will make excess connections.
@@ -69,8 +66,7 @@ def get_apns_client():
use_sandbox=settings.APNS_SANDBOX) use_sandbox=settings.APNS_SANDBOX)
return _apns_client return _apns_client
def modernize_apns_payload(data): def modernize_apns_payload(data: Dict[str, Any]) -> Dict[str, Any]:
# type: (Dict[str, Any]) -> Dict[str, Any]
'''Take a payload in an unknown Zulip version's format, and return in current format.''' '''Take a payload in an unknown Zulip version's format, and return in current format.'''
# TODO this isn't super robust as is -- if a buggy remote server # TODO this isn't super robust as is -- if a buggy remote server
# sends a malformed payload, we are likely to raise an exception. # sends a malformed payload, we are likely to raise an exception.
@@ -96,8 +92,8 @@ def modernize_apns_payload(data):
APNS_MAX_RETRIES = 3 APNS_MAX_RETRIES = 3
@statsd_increment("apple_push_notification") @statsd_increment("apple_push_notification")
def send_apple_push_notification(user_id, devices, payload_data): def send_apple_push_notification(user_id: int, devices: List[DeviceToken],
# type: (int, List[DeviceToken], Dict[str, Any]) -> None payload_data: Dict[str, Any]) -> None:
logging.info("APNs: Sending notification for user %d to %d devices", logging.info("APNs: Sending notification for user %d to %d devices",
user_id, len(devices)) user_id, len(devices))
payload = APNsPayload(**modernize_apns_payload(payload_data)) payload = APNsPayload(**modernize_apns_payload(payload_data))
@@ -107,8 +103,7 @@ def send_apple_push_notification(user_id, devices, payload_data):
for device in devices: for device in devices:
# TODO obviously this should be made to actually use the async # TODO obviously this should be made to actually use the async
def attempt_send(): def attempt_send() -> Optional[str]:
# type: () -> Optional[str]
stream_id = client.send_notification_async( stream_id = client.send_notification_async(
device.token, payload, topic='org.zulip.Zulip', device.token, payload, topic='org.zulip.Zulip',
expiration=expiration) expiration=expiration)
@@ -144,15 +139,14 @@ if settings.ANDROID_GCM_API_KEY: # nocoverage
else: else:
gcm = None gcm = None
def send_android_push_notification_to_user(user_profile, data): def send_android_push_notification_to_user(user_profile: UserProfile, data: Dict[str, Any]) -> None:
# type: (UserProfile, Dict[str, Any]) -> None
devices = list(PushDeviceToken.objects.filter(user=user_profile, devices = list(PushDeviceToken.objects.filter(user=user_profile,
kind=PushDeviceToken.GCM)) kind=PushDeviceToken.GCM))
send_android_push_notification(devices, data) send_android_push_notification(devices, data)
@statsd_increment("android_push_notification") @statsd_increment("android_push_notification")
def send_android_push_notification(devices, data, remote=False): def send_android_push_notification(devices: List[DeviceToken], data: Dict[str, Any],
# type: (List[DeviceToken], Dict[str, Any], bool) -> None remote: bool=False) -> None:
if not gcm: if not gcm:
logging.warning("Skipping sending a GCM push notification since " logging.warning("Skipping sending a GCM push notification since "
"PUSH_NOTIFICATION_BOUNCER_URL and ANDROID_GCM_API_KEY are both unset") "PUSH_NOTIFICATION_BOUNCER_URL and ANDROID_GCM_API_KEY are both unset")
@@ -218,12 +212,12 @@ def send_android_push_notification(devices, data, remote=False):
# Sending to a bouncer # Sending to a bouncer
# #
def uses_notification_bouncer(): def uses_notification_bouncer() -> bool:
# type: () -> bool
return settings.PUSH_NOTIFICATION_BOUNCER_URL is not None return settings.PUSH_NOTIFICATION_BOUNCER_URL is not None
def send_notifications_to_bouncer(user_profile_id, apns_payload, gcm_payload): def send_notifications_to_bouncer(user_profile_id: int,
# type: (int, Dict[str, Any], Dict[str, Any]) -> None apns_payload: Dict[str, Any],
gcm_payload: Dict[str, Any]) -> None:
post_data = { post_data = {
'user_id': user_profile_id, 'user_id': user_profile_id,
'apns_payload': apns_payload, 'apns_payload': apns_payload,
@@ -231,8 +225,7 @@ def send_notifications_to_bouncer(user_profile_id, apns_payload, gcm_payload):
} }
send_json_to_push_bouncer('POST', 'notify', post_data) send_json_to_push_bouncer('POST', 'notify', post_data)
def send_json_to_push_bouncer(method, endpoint, post_data): def send_json_to_push_bouncer(method: str, endpoint: str, post_data: Dict[str, Any]) -> None:
# type: (str, str, Dict[str, Any]) -> None
send_to_push_bouncer( send_to_push_bouncer(
method, method,
endpoint, endpoint,
@@ -243,8 +236,10 @@ def send_json_to_push_bouncer(method, endpoint, post_data):
class PushNotificationBouncerException(Exception): class PushNotificationBouncerException(Exception):
pass pass
def send_to_push_bouncer(method, endpoint, post_data, extra_headers=None): def send_to_push_bouncer(method: str,
# type: (str, str, Union[Text, Dict[str, Any]], Optional[Dict[str, Any]]) -> None endpoint: str,
post_data: Union[Text, Dict[str, Any]],
extra_headers: Optional[Dict[str, Any]]=None) -> None:
"""While it does actually send the notice, this function has a lot of """While it does actually send the notice, this function has a lot of
code and comments around error handling for the push notifications code and comments around error handling for the push notifications
bouncer. There are several classes of failures, each with its own bouncer. There are several classes of failures, each with its own
@@ -310,15 +305,16 @@ def send_to_push_bouncer(method, endpoint, post_data, extra_headers=None):
# Managing device tokens # Managing device tokens
# #
def num_push_devices_for_user(user_profile, kind = None): def num_push_devices_for_user(user_profile: UserProfile, kind: Optional[int]=None) -> PushDeviceToken:
# type: (UserProfile, Optional[int]) -> PushDeviceToken
if kind is None: if kind is None:
return PushDeviceToken.objects.filter(user=user_profile).count() return PushDeviceToken.objects.filter(user=user_profile).count()
else: else:
return PushDeviceToken.objects.filter(user=user_profile, kind=kind).count() return PushDeviceToken.objects.filter(user=user_profile, kind=kind).count()
def add_push_device_token(user_profile, token_str, kind, ios_app_id=None): def add_push_device_token(user_profile: UserProfile,
# type: (UserProfile, bytes, int, Optional[str]) -> None token_str: bytes,
kind: int,
ios_app_id: Optional[str]=None) -> None:
logging.info("New push device: %d %r %d %r", logging.info("New push device: %d %r %d %r",
user_profile.id, token_str, kind, ios_app_id) user_profile.id, token_str, kind, ios_app_id)
@@ -357,8 +353,7 @@ def add_push_device_token(user_profile, token_str, kind, ios_app_id=None):
else: else:
logging.info("New push device created.") logging.info("New push device created.")
def remove_push_device_token(user_profile, token_str, kind): def remove_push_device_token(user_profile: UserProfile, token_str: bytes, kind: int) -> None:
# type: (UserProfile, bytes, int) -> None
# If we're sending things to the push notification bouncer # If we're sending things to the push notification bouncer
# register this user with them here # register this user with them here
@@ -383,8 +378,7 @@ def remove_push_device_token(user_profile, token_str, kind):
# Push notifications in general # Push notifications in general
# #
def get_alert_from_message(message): def get_alert_from_message(message: Message) -> Text:
# type: (Message) -> Text
""" """
Determine what alert string to display based on the missed messages. Determine what alert string to display based on the missed messages.
""" """
@@ -401,10 +395,8 @@ def get_alert_from_message(message):
else: else:
return "New Zulip mentions and private messages from %s" % (sender_str,) return "New Zulip mentions and private messages from %s" % (sender_str,)
def get_mobile_push_content(rendered_content): def get_mobile_push_content(rendered_content: Text) -> Text:
# type: (Text) -> Text def get_text(elem: LH.HtmlElement) -> Text:
def get_text(elem):
# type: (LH.HtmlElement) -> Text
# Convert default emojis to their unicode equivalent. # Convert default emojis to their unicode equivalent.
classes = elem.get("class", "") classes = elem.get("class", "")
if "emoji" in classes: if "emoji" in classes:
@@ -421,8 +413,7 @@ def get_mobile_push_content(rendered_content):
return elem.text or "" return elem.text or ""
def process(elem): def process(elem: LH.HtmlElement) -> Text:
# type: (LH.HtmlElement) -> Text
plain_text = get_text(elem) plain_text = get_text(elem)
for child in elem: for child in elem:
plain_text += process(child) plain_text += process(child)
@@ -436,8 +427,7 @@ def get_mobile_push_content(rendered_content):
plain_text = process(elem) plain_text = process(elem)
return plain_text return plain_text
def truncate_content(content): def truncate_content(content: Text) -> Text:
# type: (Text) -> Text
# We use unicode character 'HORIZONTAL ELLIPSIS' (U+2026) instead # We use unicode character 'HORIZONTAL ELLIPSIS' (U+2026) instead
# of three dots as this saves two extra characters for textual # of three dots as this saves two extra characters for textual
# content. This function will need to be updated to handle unicode # content. This function will need to be updated to handle unicode
@@ -446,8 +436,7 @@ def truncate_content(content):
return content return content
return content[:200] + "" return content[:200] + ""
def get_apns_payload(message): def get_apns_payload(message: Message) -> Dict[str, Any]:
# type: (Message) -> Dict[str, Any]
text_content = get_mobile_push_content(message.rendered_content) text_content = get_mobile_push_content(message.rendered_content)
truncated_content = truncate_content(text_content) truncated_content = truncate_content(text_content)
return { return {
@@ -464,8 +453,7 @@ def get_apns_payload(message):
} }
} }
def get_gcm_payload(user_profile, message): def get_gcm_payload(user_profile: UserProfile, message: Message) -> Dict[str, Any]:
# type: (UserProfile, Message) -> Dict[str, Any]
text_content = get_mobile_push_content(message.rendered_content) text_content = get_mobile_push_content(message.rendered_content)
truncated_content = truncate_content(text_content) truncated_content = truncate_content(text_content)
@@ -492,8 +480,7 @@ def get_gcm_payload(user_profile, message):
return android_data return android_data
@statsd_increment("push_notifications") @statsd_increment("push_notifications")
def handle_push_notification(user_profile_id, missed_message): def handle_push_notification(user_profile_id: int, missed_message: Dict[str, Any]) -> None:
# type: (int, Dict[str, Any]) -> None
""" """
missed_message is the event received by the missed_message is the event received by the
zerver.worker.queue_processors.PushNotificationWorker.consume function. zerver.worker.queue_processors.PushNotificationWorker.consume function.

View File

@@ -10,8 +10,7 @@ from zerver.models import Realm, Message, UserMessage, ArchivedMessage, Archived
from typing import Any, Dict, Optional, Generator from typing import Any, Dict, Optional, Generator
def get_realm_expired_messages(realm): def get_realm_expired_messages(realm: Any) -> Optional[Dict[str, Any]]:
# type: (Any) -> Optional[Dict[str, Any]]
expired_date = timezone_now() - timedelta(days=realm.message_retention_days) expired_date = timezone_now() - timedelta(days=realm.message_retention_days)
expired_messages = Message.objects.order_by('id').filter(sender__realm=realm, expired_messages = Message.objects.order_by('id').filter(sender__realm=realm,
pub_date__lt=expired_date) pub_date__lt=expired_date)
@@ -20,8 +19,7 @@ def get_realm_expired_messages(realm):
return {'realm_id': realm.id, 'expired_messages': expired_messages} return {'realm_id': realm.id, 'expired_messages': expired_messages}
def get_expired_messages(): def get_expired_messages() -> Generator[Any, None, None]:
# type: () -> Generator[Any, None, None]
# Get all expired messages by Realm. # Get all expired messages by Realm.
realms = Realm.objects.order_by('string_id').filter( realms = Realm.objects.order_by('string_id').filter(
deactivated=False, message_retention_days__isnull=False) deactivated=False, message_retention_days__isnull=False)
@@ -31,8 +29,7 @@ def get_expired_messages():
yield realm_expired_messages yield realm_expired_messages
def move_attachment_message_to_archive_by_message(message_id): def move_attachment_message_to_archive_by_message(message_id: int) -> None:
# type: (int) -> None
# Move attachments messages relation table data to archive. # Move attachments messages relation table data to archive.
query = """ query = """
INSERT INTO zerver_archivedattachment_messages (id, archivedattachment_id, INSERT INTO zerver_archivedattachment_messages (id, archivedattachment_id,
@@ -50,8 +47,7 @@ def move_attachment_message_to_archive_by_message(message_id):
@transaction.atomic @transaction.atomic
def move_message_to_archive(message_id): def move_message_to_archive(message_id: int) -> None:
# type: (int) -> None
msg = list(Message.objects.filter(id=message_id).values()) msg = list(Message.objects.filter(id=message_id).values())
if not msg: if not msg:
raise Message.DoesNotExist raise Message.DoesNotExist