mirror of
				https://github.com/zulip/zulip.git
				synced 2025-11-03 21:43:21 +00:00 
			
		
		
		
	cache: Fix type: ignore issues.
This was hiding an actual type error in test_cache: a mismatch between the object ID type, which is str, and the default id_fetcher, which returns int. Mypy’s insufficient support for default generic arguments basically means we can’t use them without a lot of overloading, and there are not enough callers here to justify that. https://github.com/python/mypy/issues/3737 We avoid this being super messy where the code calls this by adding some less generic wrappers for generic_bulk_cached_fetch. Signed-off-by: Anders Kaseorg <anders@zulip.com>
This commit is contained in:
		
				
					committed by
					
						
						Tim Abbott
					
				
			
			
				
	
			
			
			
						parent
						
							bb8dcb9b1e
						
					
				
				
					commit
					1b96af2987
				
			@@ -348,23 +348,10 @@ CacheItemT = TypeVar('CacheItemT')
 | 
			
		||||
# serializable objects, will be the object; if encoded, bytes.
 | 
			
		||||
CompressedItemT = TypeVar('CompressedItemT')
 | 
			
		||||
 | 
			
		||||
def default_extractor(obj: CompressedItemT) -> ItemT:
 | 
			
		||||
    return obj  # type: ignore[return-value] # Need a type assert that ItemT=CompressedItemT
 | 
			
		||||
 | 
			
		||||
def default_setter(obj: ItemT) -> CompressedItemT:
 | 
			
		||||
    return obj  # type: ignore[return-value] # Need a type assert that ItemT=CompressedItemT
 | 
			
		||||
 | 
			
		||||
def default_id_fetcher(obj: ItemT) -> ObjKT:
 | 
			
		||||
    return obj.id  # type: ignore[attr-defined] # Need ItemT/CompressedItemT typevars to be a Django protocol
 | 
			
		||||
 | 
			
		||||
def default_cache_transformer(obj: ItemT) -> CacheItemT:
 | 
			
		||||
    return obj  # type: ignore[return-value] # Need a type assert that ItemT=CacheItemT
 | 
			
		||||
 | 
			
		||||
# Required Arguments are as follows:
 | 
			
		||||
# * object_ids: The list of object ids to look up
 | 
			
		||||
# * cache_key_function: object_id => cache key
 | 
			
		||||
# * query_function: [object_ids] => [objects from database]
 | 
			
		||||
# Optional keyword arguments:
 | 
			
		||||
# * setter: Function to call before storing items to cache (e.g. compression)
 | 
			
		||||
# * extractor: Function to call on items returned from cache
 | 
			
		||||
#   (e.g. decompression).  Should be the inverse of the setter
 | 
			
		||||
@@ -378,10 +365,11 @@ def generic_bulk_cached_fetch(
 | 
			
		||||
        cache_key_function: Callable[[ObjKT], str],
 | 
			
		||||
        query_function: Callable[[List[ObjKT]], Iterable[ItemT]],
 | 
			
		||||
        object_ids: Sequence[ObjKT],
 | 
			
		||||
        extractor: Callable[[CompressedItemT], CacheItemT] = default_extractor,
 | 
			
		||||
        setter: Callable[[CacheItemT], CompressedItemT] = default_setter,
 | 
			
		||||
        id_fetcher: Callable[[ItemT], ObjKT] = default_id_fetcher,
 | 
			
		||||
        cache_transformer: Callable[[ItemT], CacheItemT] = default_cache_transformer,
 | 
			
		||||
        *,
 | 
			
		||||
        extractor: Callable[[CompressedItemT], CacheItemT],
 | 
			
		||||
        setter: Callable[[CacheItemT], CompressedItemT],
 | 
			
		||||
        id_fetcher: Callable[[ItemT], ObjKT],
 | 
			
		||||
        cache_transformer: Callable[[ItemT], CacheItemT],
 | 
			
		||||
) -> Dict[ObjKT, CacheItemT]:
 | 
			
		||||
    if len(object_ids) == 0:
 | 
			
		||||
        # Nothing to fetch.
 | 
			
		||||
@@ -418,6 +406,39 @@ def generic_bulk_cached_fetch(
 | 
			
		||||
    return {object_id: cached_objects[cache_keys[object_id]] for object_id in object_ids
 | 
			
		||||
            if cache_keys[object_id] in cached_objects}
 | 
			
		||||
 | 
			
		||||
def transformed_bulk_cached_fetch(
 | 
			
		||||
    cache_key_function: Callable[[ObjKT], str],
 | 
			
		||||
    query_function: Callable[[List[ObjKT]], Iterable[ItemT]],
 | 
			
		||||
    object_ids: Sequence[ObjKT],
 | 
			
		||||
    *,
 | 
			
		||||
    id_fetcher: Callable[[ItemT], ObjKT],
 | 
			
		||||
    cache_transformer: Callable[[ItemT], CacheItemT],
 | 
			
		||||
) -> Dict[ObjKT, CacheItemT]:
 | 
			
		||||
    return generic_bulk_cached_fetch(
 | 
			
		||||
        cache_key_function,
 | 
			
		||||
        query_function,
 | 
			
		||||
        object_ids,
 | 
			
		||||
        extractor=lambda obj: obj,
 | 
			
		||||
        setter=lambda obj: obj,
 | 
			
		||||
        id_fetcher=id_fetcher,
 | 
			
		||||
        cache_transformer=cache_transformer,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
def bulk_cached_fetch(
 | 
			
		||||
    cache_key_function: Callable[[ObjKT], str],
 | 
			
		||||
    query_function: Callable[[List[ObjKT]], Iterable[ItemT]],
 | 
			
		||||
    object_ids: Sequence[ObjKT],
 | 
			
		||||
    *,
 | 
			
		||||
    id_fetcher: Callable[[ItemT], ObjKT],
 | 
			
		||||
) -> Dict[ObjKT, ItemT]:
 | 
			
		||||
    return transformed_bulk_cached_fetch(
 | 
			
		||||
        cache_key_function,
 | 
			
		||||
        query_function,
 | 
			
		||||
        object_ids,
 | 
			
		||||
        id_fetcher=id_fetcher,
 | 
			
		||||
        cache_transformer=lambda obj: obj,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
def preview_url_cache_key(url: str) -> str:
 | 
			
		||||
    return f"preview_url:{make_safe_digest(url)}"
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -3,10 +3,11 @@ from typing import Dict, List, Optional, Set, Tuple
 | 
			
		||||
from typing_extensions import TypedDict
 | 
			
		||||
 | 
			
		||||
from zerver.lib.cache import (
 | 
			
		||||
    bulk_cached_fetch,
 | 
			
		||||
    cache_with_key,
 | 
			
		||||
    display_recipient_bulk_get_users_by_id_cache_key,
 | 
			
		||||
    display_recipient_cache_key,
 | 
			
		||||
    generic_bulk_cached_fetch,
 | 
			
		||||
    transformed_bulk_cached_fetch,
 | 
			
		||||
)
 | 
			
		||||
from zerver.lib.types import DisplayRecipientT, UserDisplayRecipient
 | 
			
		||||
from zerver.models import Recipient, Stream, UserProfile, bulk_get_huddle_user_ids
 | 
			
		||||
@@ -49,7 +50,7 @@ def user_dict_id_fetcher(user_dict: UserDisplayRecipient) -> int:
 | 
			
		||||
    return user_dict['id']
 | 
			
		||||
 | 
			
		||||
def bulk_get_user_profile_by_id(uids: List[int]) -> Dict[int, UserDisplayRecipient]:
 | 
			
		||||
    return generic_bulk_cached_fetch(
 | 
			
		||||
    return bulk_cached_fetch(
 | 
			
		||||
        # Use a separate cache key to protect us from conflicts with
 | 
			
		||||
        # the get_user_profile_by_id cache.
 | 
			
		||||
        # (Since we fetch only several fields here)
 | 
			
		||||
@@ -96,7 +97,7 @@ def bulk_fetch_display_recipients(recipient_tuples: Set[Tuple[int, int, int]],
 | 
			
		||||
        return stream['name']
 | 
			
		||||
 | 
			
		||||
    # ItemT = Stream, CacheItemT = str (name), ObjKT = int (recipient_id)
 | 
			
		||||
    stream_display_recipients: Dict[int, str] = generic_bulk_cached_fetch(
 | 
			
		||||
    stream_display_recipients: Dict[int, str] = transformed_bulk_cached_fetch(
 | 
			
		||||
        cache_key_function=display_recipient_cache_key,
 | 
			
		||||
        query_function=stream_query_function,
 | 
			
		||||
        object_ids=[recipient[0] for recipient in stream_recipients],
 | 
			
		||||
@@ -167,7 +168,7 @@ def bulk_fetch_display_recipients(recipient_tuples: Set[Tuple[int, int, int]],
 | 
			
		||||
    # ItemT = Tuple[int, List[UserDisplayRecipient]] (recipient_id, list of corresponding users)
 | 
			
		||||
    # CacheItemT = List[UserDisplayRecipient] (display_recipient list)
 | 
			
		||||
    # ObjKT = int (recipient_id)
 | 
			
		||||
    personal_and_huddle_display_recipients = generic_bulk_cached_fetch(
 | 
			
		||||
    personal_and_huddle_display_recipients: Dict[int, List[UserDisplayRecipient]] = transformed_bulk_cached_fetch(
 | 
			
		||||
        cache_key_function=display_recipient_cache_key,
 | 
			
		||||
        query_function=personal_and_huddle_query_function,
 | 
			
		||||
        object_ids=[recipient[0] for recipient in personal_and_huddle_recipients],
 | 
			
		||||
 
 | 
			
		||||
@@ -12,7 +12,7 @@ from zulip_bots.custom_exceptions import ConfigValidationError
 | 
			
		||||
 | 
			
		||||
from zerver.lib.avatar import avatar_url, get_avatar_field
 | 
			
		||||
from zerver.lib.cache import (
 | 
			
		||||
    generic_bulk_cached_fetch,
 | 
			
		||||
    bulk_cached_fetch,
 | 
			
		||||
    realm_user_dict_fields,
 | 
			
		||||
    user_profile_by_id_cache_key,
 | 
			
		||||
    user_profile_cache_key_id,
 | 
			
		||||
@@ -173,7 +173,7 @@ def bulk_get_users(emails: List[str], realm: Optional[Realm],
 | 
			
		||||
    def user_to_email(user_profile: UserProfile) -> str:
 | 
			
		||||
        return user_profile.email.lower()
 | 
			
		||||
 | 
			
		||||
    return generic_bulk_cached_fetch(
 | 
			
		||||
    return bulk_cached_fetch(
 | 
			
		||||
        # Use a separate cache key to protect us from conflicts with
 | 
			
		||||
        # the get_user cache.
 | 
			
		||||
        lambda email: 'bulk_get_users:' + user_profile_cache_key_id(email, realm_id),
 | 
			
		||||
@@ -182,6 +182,9 @@ def bulk_get_users(emails: List[str], realm: Optional[Realm],
 | 
			
		||||
        id_fetcher=user_to_email,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
def get_user_id(user: UserProfile) -> int:
 | 
			
		||||
    return user.id
 | 
			
		||||
 | 
			
		||||
def user_ids_to_users(user_ids: Sequence[int], realm: Realm) -> List[UserProfile]:
 | 
			
		||||
    # TODO: Consider adding a flag to control whether deactivated
 | 
			
		||||
    # users should be included.
 | 
			
		||||
@@ -189,10 +192,11 @@ def user_ids_to_users(user_ids: Sequence[int], realm: Realm) -> List[UserProfile
 | 
			
		||||
    def fetch_users_by_id(user_ids: List[int]) -> List[UserProfile]:
 | 
			
		||||
        return list(UserProfile.objects.filter(id__in=user_ids).select_related())
 | 
			
		||||
 | 
			
		||||
    user_profiles_by_id: Dict[int, UserProfile] = generic_bulk_cached_fetch(
 | 
			
		||||
    user_profiles_by_id: Dict[int, UserProfile] = bulk_cached_fetch(
 | 
			
		||||
        cache_key_function=user_profile_by_id_cache_key,
 | 
			
		||||
        query_function=fetch_users_by_id,
 | 
			
		||||
        object_ids=user_ids,
 | 
			
		||||
        id_fetcher=get_user_id,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    found_user_ids = user_profiles_by_id.keys()
 | 
			
		||||
 
 | 
			
		||||
@@ -43,6 +43,7 @@ from zerver.lib.cache import (
 | 
			
		||||
    bot_dict_fields,
 | 
			
		||||
    bot_dicts_in_realm_cache_key,
 | 
			
		||||
    bot_profile_cache_key,
 | 
			
		||||
    bulk_cached_fetch,
 | 
			
		||||
    cache_delete,
 | 
			
		||||
    cache_set,
 | 
			
		||||
    cache_with_key,
 | 
			
		||||
@@ -52,7 +53,6 @@ from zerver.lib.cache import (
 | 
			
		||||
    flush_submessage,
 | 
			
		||||
    flush_used_upload_space_cache,
 | 
			
		||||
    flush_user_profile,
 | 
			
		||||
    generic_bulk_cached_fetch,
 | 
			
		||||
    get_realm_used_upload_space_cache_key,
 | 
			
		||||
    get_stream_cache_key,
 | 
			
		||||
    realm_alert_words_automaton_cache_key,
 | 
			
		||||
@@ -1653,10 +1653,12 @@ def bulk_get_streams(realm: Realm, stream_names: STREAM_NAMES) -> Dict[str, Any]
 | 
			
		||||
    def stream_to_lower_name(stream: Stream) -> str:
 | 
			
		||||
        return stream.name.lower()
 | 
			
		||||
 | 
			
		||||
    return generic_bulk_cached_fetch(stream_name_to_cache_key,
 | 
			
		||||
                                     fetch_streams_by_name,
 | 
			
		||||
                                     [stream_name.lower() for stream_name in stream_names],
 | 
			
		||||
                                     id_fetcher=stream_to_lower_name)
 | 
			
		||||
    return bulk_cached_fetch(
 | 
			
		||||
        stream_name_to_cache_key,
 | 
			
		||||
        fetch_streams_by_name,
 | 
			
		||||
        [stream_name.lower() for stream_name in stream_names],
 | 
			
		||||
        id_fetcher=stream_to_lower_name,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
def get_huddle_recipient(user_profile_ids: Set[int]) -> Recipient:
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -8,6 +8,7 @@ from zerver.lib.cache import (
 | 
			
		||||
    MEMCACHED_MAX_KEY_LENGTH,
 | 
			
		||||
    InvalidCacheKeyException,
 | 
			
		||||
    NotFoundInCache,
 | 
			
		||||
    bulk_cached_fetch,
 | 
			
		||||
    cache_delete,
 | 
			
		||||
    cache_delete_many,
 | 
			
		||||
    cache_get,
 | 
			
		||||
@@ -15,7 +16,6 @@ from zerver.lib.cache import (
 | 
			
		||||
    cache_set,
 | 
			
		||||
    cache_set_many,
 | 
			
		||||
    cache_with_key,
 | 
			
		||||
    generic_bulk_cached_fetch,
 | 
			
		||||
    get_cache_with_key,
 | 
			
		||||
    safe_cache_get_many,
 | 
			
		||||
    safe_cache_set_many,
 | 
			
		||||
@@ -272,6 +272,9 @@ class BotCacheKeyTest(ZulipTestCase):
 | 
			
		||||
        user_profile2 = get_user_profile_by_email(settings.EMAIL_GATEWAY_BOT)
 | 
			
		||||
        self.assertEqual(user_profile2.is_api_super_user, flipped_setting)
 | 
			
		||||
 | 
			
		||||
def get_user_email(user: UserProfile) -> str:
 | 
			
		||||
    return user.email  # nocoverage
 | 
			
		||||
 | 
			
		||||
class GenericBulkCachedFetchTest(ZulipTestCase):
 | 
			
		||||
    def test_query_function_called_only_if_needed(self) -> None:
 | 
			
		||||
        # Get the user cached:
 | 
			
		||||
@@ -285,20 +288,22 @@ class GenericBulkCachedFetchTest(ZulipTestCase):
 | 
			
		||||
 | 
			
		||||
        # query_function shouldn't be called, because the only requested object
 | 
			
		||||
        # is already cached:
 | 
			
		||||
        result: Dict[str, UserProfile] = generic_bulk_cached_fetch(
 | 
			
		||||
        result: Dict[str, UserProfile] = bulk_cached_fetch(
 | 
			
		||||
            cache_key_function=user_profile_by_email_cache_key,
 | 
			
		||||
            query_function=query_function,
 | 
			
		||||
            object_ids=[self.example_email("hamlet")],
 | 
			
		||||
            id_fetcher=get_user_email,
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(result, {hamlet.delivery_email: hamlet})
 | 
			
		||||
 | 
			
		||||
        flush_cache(Mock())
 | 
			
		||||
        # With the cache flushed, the query_function should get called:
 | 
			
		||||
        with self.assertRaises(CustomException):
 | 
			
		||||
            generic_bulk_cached_fetch(
 | 
			
		||||
            result = bulk_cached_fetch(
 | 
			
		||||
                cache_key_function=user_profile_by_email_cache_key,
 | 
			
		||||
                query_function=query_function,
 | 
			
		||||
                object_ids=[self.example_email("hamlet")],
 | 
			
		||||
                id_fetcher=get_user_email,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    def test_empty_object_ids_list(self) -> None:
 | 
			
		||||
@@ -313,9 +318,10 @@ class GenericBulkCachedFetchTest(ZulipTestCase):
 | 
			
		||||
 | 
			
		||||
        # query_function and cache_key_function shouldn't be called, because
 | 
			
		||||
        # objects_ids is empty, so there's nothing to do.
 | 
			
		||||
        result: Dict[str, UserProfile] = generic_bulk_cached_fetch(
 | 
			
		||||
        result: Dict[str, UserProfile] = bulk_cached_fetch(
 | 
			
		||||
            cache_key_function=cache_key_function,
 | 
			
		||||
            query_function=query_function,
 | 
			
		||||
            object_ids=[],
 | 
			
		||||
            id_fetcher=get_user_email,
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(result, {})
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user