cache: Add ignore_unhashable_lru_cache function.

This is a wrapper over lru_cache function. It adds following features on
top of lru_cache:

    * It will not cache result of functions with unhashable arguments.
    * It will clear cache whenever zerver.lib.cache.KEY_PREFIX changes.
This commit is contained in:
Umair Khan
2018-01-12 12:57:10 +05:00
committed by Tim Abbott
parent 91ce455ac3
commit 0eca2e102d
2 changed files with 95 additions and 0 deletions

View File

@@ -1,6 +1,7 @@
from functools import wraps
from django.utils.lru_cache import lru_cache
from django.core.cache import cache as djcache
from django.core.cache import caches
from django.conf import settings
@@ -454,3 +455,47 @@ def to_dict_cache_key(message):
def flush_message(sender: Any, **kwargs: Any) -> None:
message = kwargs['instance']
cache_delete(to_dict_cache_key_id(message.id))
DECORATOR = Callable[[Callable[..., Any]], Callable[..., Any]]
def ignore_unhashable_lru_cache(maxsize: int=128, typed: bool=False) -> DECORATOR:
"""
This is a wrapper over lru_cache function. It adds following features on
top of lru_cache:
* It will not cache result of functions with unhashable arguments.
* It will clear cache whenever zerver.lib.cache.KEY_PREFIX changes.
"""
internal_decorator = lru_cache(maxsize=maxsize, typed=typed)
def decorator(user_function: Callable[..., Any]) -> Callable[..., Any]:
cache_enabled_user_function = internal_decorator(user_function)
def wrapper(*args: Any, **kwargs: Any) -> Any:
if not hasattr(cache_enabled_user_function, 'key_prefix'):
cache_enabled_user_function.key_prefix = KEY_PREFIX
if cache_enabled_user_function.key_prefix != KEY_PREFIX:
# Clear cache when cache.KEY_PREFIX changes. This is used in
# tests.
cache_enabled_user_function.cache_clear()
cache_enabled_user_function.key_prefix = KEY_PREFIX
try:
return cache_enabled_user_function(*args, **kwargs)
except TypeError:
# args or kwargs contains an element which is unhashable. In
# this case we don't cache the result.
pass
# Deliberately calling this function from outside of exception
# handler to get a more descriptive traceback. Otherise traceback
# can include the exception from cached_enabled_user_function as
# well.
return user_function(*args, **kwargs)
setattr(wrapper, 'cache_info', cache_enabled_user_function.cache_info)
setattr(wrapper, 'cache_clear', cache_enabled_user_function.cache_clear)
return wrapper
return decorator

View File

@@ -34,6 +34,7 @@ from zerver.decorator import (
rate_limit, validate_api_key, logged_in_and_active,
return_success_on_head_request
)
from zerver.lib.cache import ignore_unhashable_lru_cache
from zerver.lib.validator import (
check_string, check_dict, check_dict_only, check_bool, check_float, check_int, check_list, Validator,
check_variable_type, equals, check_none_or, check_url, check_short_string
@@ -1285,3 +1286,52 @@ class TestUserAgentParsing(ZulipTestCase):
user_agents_parsed[ret["name"]] += int(count)
self.assertEqual(len(parse_errors), 0)
class TestIgnoreUnhashableLRUCache(ZulipTestCase):
def test_cache_hit(self) -> None:
@ignore_unhashable_lru_cache()
def f(arg: Any) -> Any:
return arg
def get_cache_info() -> Tuple[int, int, int]:
info = getattr(f, 'cache_info')()
hits = getattr(info, 'hits')
misses = getattr(info, 'misses')
currsize = getattr(info, 'currsize')
return hits, misses, currsize
def clear_cache() -> None:
getattr(f, 'cache_clear')()
# Check hashable argument.
result = f(1)
hits, misses, currsize = get_cache_info()
# First one should be a miss.
self.assertEqual(hits, 0)
self.assertEqual(misses, 1)
self.assertEqual(currsize, 1)
self.assertEqual(result, 1)
result = f(1)
hits, misses, currsize = get_cache_info()
# Second one should be a hit.
self.assertEqual(hits, 1)
self.assertEqual(misses, 1)
self.assertEqual(currsize, 1)
self.assertEqual(result, 1)
# Check unhashable argument.
result = f([1])
hits, misses, currsize = get_cache_info()
# Cache should not be used.
self.assertEqual(hits, 1)
self.assertEqual(misses, 1)
self.assertEqual(currsize, 1)
self.assertEqual(result, [1])
# Clear cache.
clear_cache()
hits, misses, currsize = get_cache_info()
self.assertEqual(hits, 0)
self.assertEqual(misses, 0)
self.assertEqual(currsize, 0)