diff --git a/zerver/lib/cache.py b/zerver/lib/cache.py index f6470d5e7c..9cc6b9a163 100644 --- a/zerver/lib/cache.py +++ b/zerver/lib/cache.py @@ -534,3 +534,33 @@ def ignore_unhashable_lru_cache(maxsize: int=128, typed: bool=False) -> DECORATO return wrapper return decorator + +def dict_to_items_tuple(user_function: Callable[..., Any]) -> Callable[..., Any]: + """Wrapper that converts any dict args to dict item tuples.""" + def dict_to_tuple(arg: Any) -> Any: + if isinstance(arg, dict): + return tuple(sorted(arg.items())) + return arg + + def wrapper(*args: Any, **kwargs: Any) -> Any: + new_args = (dict_to_tuple(arg) for arg in args) + return user_function(*new_args, **kwargs) + + return wrapper + +def items_tuple_to_dict(user_function: Callable[..., Any]) -> Callable[..., Any]: + """Wrapper that converts any dict items tuple args to dicts.""" + def dict_items_to_dict(arg: Any) -> Any: + if isinstance(arg, tuple): + try: + return dict(arg) + except TypeError: + pass + return arg + + def wrapper(*args: Any, **kwargs: Any) -> Any: + new_args = (dict_items_to_dict(arg) for arg in args) + new_kwargs = {key: dict_items_to_dict(val) for key, val in kwargs.items()} + return user_function(*new_args, **new_kwargs) + + return wrapper diff --git a/zerver/templatetags/app_filters.py b/zerver/templatetags/app_filters.py index faaf0685a4..b10498f5cd 100644 --- a/zerver/templatetags/app_filters.py +++ b/zerver/templatetags/app_filters.py @@ -19,7 +19,7 @@ import zerver.lib.bugdown.help_settings_links import zerver.lib.bugdown.help_relative_links import zerver.lib.bugdown.help_emoticon_translations_table import zerver.lib.bugdown.include -from zerver.lib.cache import ignore_unhashable_lru_cache +from zerver.lib.cache import ignore_unhashable_lru_cache, dict_to_items_tuple, items_tuple_to_dict register = Library() @@ -67,10 +67,12 @@ docs_without_macros = [ "incoming-webhooks-walkthrough.md", ] -# Much of the time, render_markdown_path is called with hashable -# arguments, so this decorator is effective even though it only caches -# the results when called if none of the arguments are unhashable. +# render_markdown_path is passed a context dictionary (unhashable), which +# results in the calls not being cached. To work around this, we convert the +# dict to a tuple of dict items to cache the results. +@dict_to_items_tuple @ignore_unhashable_lru_cache(512) +@items_tuple_to_dict @register.filter(name='render_markdown_path', is_safe=True) def render_markdown_path(markdown_file_path: str, context: Optional[Dict[Any, Any]]=None, diff --git a/zerver/tests/test_decorators.py b/zerver/tests/test_decorators.py index ad9d60db62..6621b62bca 100644 --- a/zerver/tests/test_decorators.py +++ b/zerver/tests/test_decorators.py @@ -40,7 +40,7 @@ from zerver.decorator import ( return_success_on_head_request, to_not_negative_int_or_none, zulip_login_required ) -from zerver.lib.cache import ignore_unhashable_lru_cache +from zerver.lib.cache import ignore_unhashable_lru_cache, dict_to_items_tuple, items_tuple_to_dict from zerver.lib.validator import ( check_string, check_dict, check_dict_only, check_bool, check_float, check_int, check_list, Validator, check_variable_type, equals, check_none_or, check_url, check_short_string, @@ -1759,13 +1759,74 @@ class TestIgnoreUnhashableLRUCache(ZulipTestCase): self.assertEqual(result, 1) # Check unhashable argument. - result = f([1]) + result = f({1: 2}) hits, misses, currsize = get_cache_info() # Cache should not be used. self.assertEqual(hits, 1) self.assertEqual(misses, 1) self.assertEqual(currsize, 1) - self.assertEqual(result, [1]) + self.assertEqual(result, {1: 2}) + + # Clear cache. + clear_cache() + hits, misses, currsize = get_cache_info() + self.assertEqual(hits, 0) + self.assertEqual(misses, 0) + self.assertEqual(currsize, 0) + + def test_cache_hit_dict_args(self) -> None: + @ignore_unhashable_lru_cache() + @items_tuple_to_dict + def g(arg: Any) -> Any: + return arg + + def get_cache_info() -> Tuple[int, int, int]: + info = getattr(g, 'cache_info')() + hits = getattr(info, 'hits') + misses = getattr(info, 'misses') + currsize = getattr(info, 'currsize') + return hits, misses, currsize + + def clear_cache() -> None: + getattr(g, 'cache_clear')() + + # Not used as a decorator on the definition to allow defining + # get_cache_info and clear_cache + f = dict_to_items_tuple(g) + + # Check hashable argument. + result = f(1) + hits, misses, currsize = get_cache_info() + # First one should be a miss. + self.assertEqual(hits, 0) + self.assertEqual(misses, 1) + self.assertEqual(currsize, 1) + self.assertEqual(result, 1) + + result = f(1) + hits, misses, currsize = get_cache_info() + # Second one should be a hit. + self.assertEqual(hits, 1) + self.assertEqual(misses, 1) + self.assertEqual(currsize, 1) + self.assertEqual(result, 1) + + # Check dict argument. + result = f({1: 2}) + hits, misses, currsize = get_cache_info() + # First one is a miss + self.assertEqual(hits, 1) + self.assertEqual(misses, 2) + self.assertEqual(currsize, 2) + self.assertEqual(result, {1: 2}) + + result = f({1: 2}) + hits, misses, currsize = get_cache_info() + # Second one should be a hit. + self.assertEqual(hits, 2) + self.assertEqual(misses, 2) + self.assertEqual(currsize, 2) + self.assertEqual(result, {1: 2}) # Clear cache. clear_cache()