diff --git a/zerver/lib/cache.py b/zerver/lib/cache.py index 0c4abfaeb5..856f3fe5ca 100644 --- a/zerver/lib/cache.py +++ b/zerver/lib/cache.py @@ -1,6 +1,7 @@ from functools import wraps +from django.utils.lru_cache import lru_cache from django.core.cache import cache as djcache from django.core.cache import caches from django.conf import settings @@ -454,3 +455,47 @@ def to_dict_cache_key(message): def flush_message(sender: Any, **kwargs: Any) -> None: message = kwargs['instance'] cache_delete(to_dict_cache_key_id(message.id)) + +DECORATOR = Callable[[Callable[..., Any]], Callable[..., Any]] + +def ignore_unhashable_lru_cache(maxsize: int=128, typed: bool=False) -> DECORATOR: + """ + This is a wrapper over lru_cache function. It adds following features on + top of lru_cache: + + * It will not cache result of functions with unhashable arguments. + * It will clear cache whenever zerver.lib.cache.KEY_PREFIX changes. + """ + internal_decorator = lru_cache(maxsize=maxsize, typed=typed) + + def decorator(user_function: Callable[..., Any]) -> Callable[..., Any]: + cache_enabled_user_function = internal_decorator(user_function) + + def wrapper(*args: Any, **kwargs: Any) -> Any: + if not hasattr(cache_enabled_user_function, 'key_prefix'): + cache_enabled_user_function.key_prefix = KEY_PREFIX + + if cache_enabled_user_function.key_prefix != KEY_PREFIX: + # Clear cache when cache.KEY_PREFIX changes. This is used in + # tests. + cache_enabled_user_function.cache_clear() + cache_enabled_user_function.key_prefix = KEY_PREFIX + + try: + return cache_enabled_user_function(*args, **kwargs) + except TypeError: + # args or kwargs contains an element which is unhashable. In + # this case we don't cache the result. + pass + + # Deliberately calling this function from outside of exception + # handler to get a more descriptive traceback. Otherise traceback + # can include the exception from cached_enabled_user_function as + # well. + return user_function(*args, **kwargs) + + setattr(wrapper, 'cache_info', cache_enabled_user_function.cache_info) + setattr(wrapper, 'cache_clear', cache_enabled_user_function.cache_clear) + return wrapper + + return decorator diff --git a/zerver/tests/test_decorators.py b/zerver/tests/test_decorators.py index e44fbc9374..f15a896b46 100644 --- a/zerver/tests/test_decorators.py +++ b/zerver/tests/test_decorators.py @@ -34,6 +34,7 @@ from zerver.decorator import ( rate_limit, validate_api_key, logged_in_and_active, return_success_on_head_request ) +from zerver.lib.cache import ignore_unhashable_lru_cache from zerver.lib.validator import ( check_string, check_dict, check_dict_only, check_bool, check_float, check_int, check_list, Validator, check_variable_type, equals, check_none_or, check_url, check_short_string @@ -1285,3 +1286,52 @@ class TestUserAgentParsing(ZulipTestCase): user_agents_parsed[ret["name"]] += int(count) self.assertEqual(len(parse_errors), 0) + +class TestIgnoreUnhashableLRUCache(ZulipTestCase): + def test_cache_hit(self) -> None: + @ignore_unhashable_lru_cache() + def f(arg: Any) -> Any: + return arg + + def get_cache_info() -> Tuple[int, int, int]: + info = getattr(f, 'cache_info')() + hits = getattr(info, 'hits') + misses = getattr(info, 'misses') + currsize = getattr(info, 'currsize') + return hits, misses, currsize + + def clear_cache() -> None: + getattr(f, 'cache_clear')() + + # Check hashable argument. + result = f(1) + hits, misses, currsize = get_cache_info() + # First one should be a miss. + self.assertEqual(hits, 0) + self.assertEqual(misses, 1) + self.assertEqual(currsize, 1) + self.assertEqual(result, 1) + + result = f(1) + hits, misses, currsize = get_cache_info() + # Second one should be a hit. + self.assertEqual(hits, 1) + self.assertEqual(misses, 1) + self.assertEqual(currsize, 1) + self.assertEqual(result, 1) + + # Check unhashable argument. + result = f([1]) + hits, misses, currsize = get_cache_info() + # Cache should not be used. + self.assertEqual(hits, 1) + self.assertEqual(misses, 1) + self.assertEqual(currsize, 1) + self.assertEqual(result, [1]) + + # Clear cache. + clear_cache() + hits, misses, currsize = get_cache_info() + self.assertEqual(hits, 0) + self.assertEqual(misses, 0) + self.assertEqual(currsize, 0)