cache: Fix typing for generic_bulk_cached_fetch.

The typing for generic_bulk_cached_fetch is complicated, and was
recorded incorrectly previously for the case where a cache_transformer
function is required.  We fix this by adding the new CacheItemT, and
additionally add comments explaining what's going on with these types
for future reference.

Thanks to Mateusz Mandera for raising this issue.
This commit is contained in:
Tim Abbott
2019-08-08 12:34:06 -07:00
parent 2ada0a9bad
commit 27a0e307b6
4 changed files with 45 additions and 20 deletions

View File

@@ -216,9 +216,21 @@ def cache_delete_many(items: Iterable[str], cache_name: Optional[str]=None) -> N
KEY_PREFIX + item for item in items) KEY_PREFIX + item for item in items)
remote_cache_stats_finish() remote_cache_stats_finish()
# Generic_bulk_cached fetch and its helpers # Generic_bulk_cached fetch and its helpers. We start with declaring
# a few type variables that help define its interface.
# Type for the cache's keys; will typically be int or str.
ObjKT = TypeVar('ObjKT') ObjKT = TypeVar('ObjKT')
# Type for items to be fetched from the database (e.g. a Django model object)
ItemT = TypeVar('ItemT') ItemT = TypeVar('ItemT')
# Type for items to be stored in the cache (e.g. a dictionary serialization).
# Will equal ItemT unless a cache_transformer is specified.
CacheItemT = TypeVar('CacheItemT')
# Type for compressed items for storage in the cache. For
# serializable objects, will be the object; if encoded, bytes.
CompressedItemT = TypeVar('CompressedItemT') CompressedItemT = TypeVar('CompressedItemT')
def default_extractor(obj: CompressedItemT) -> ItemT: def default_extractor(obj: CompressedItemT) -> ItemT:
@@ -230,8 +242,8 @@ def default_setter(obj: ItemT) -> CompressedItemT:
def default_id_fetcher(obj: ItemT) -> ObjKT: def default_id_fetcher(obj: ItemT) -> ObjKT:
return obj.id # type: ignore # Need ItemT/CompressedItemT typevars to be a Django protocol return obj.id # type: ignore # Need ItemT/CompressedItemT typevars to be a Django protocol
def default_cache_transformer(obj: ItemT) -> ItemT: def default_cache_transformer(obj: ItemT) -> CacheItemT:
return obj return obj # type: ignore # Need a type assert that ItemT=CacheItemT
# Required Arguments are as follows: # Required Arguments are as follows:
# * object_ids: The list of object ids to look up # * object_ids: The list of object ids to look up
@@ -249,19 +261,19 @@ def default_cache_transformer(obj: ItemT) -> ItemT:
# function of the objects, not the objects themselves) # function of the objects, not the objects themselves)
def generic_bulk_cached_fetch( def generic_bulk_cached_fetch(
cache_key_function: Callable[[ObjKT], str], cache_key_function: Callable[[ObjKT], str],
query_function: Callable[[List[ObjKT]], Iterable[Any]], query_function: Callable[[List[ObjKT]], Iterable[ItemT]],
object_ids: Iterable[ObjKT], object_ids: Iterable[ObjKT],
extractor: Callable[[CompressedItemT], ItemT] = default_extractor, extractor: Callable[[CompressedItemT], CacheItemT] = default_extractor,
setter: Callable[[ItemT], CompressedItemT] = default_setter, setter: Callable[[CacheItemT], CompressedItemT] = default_setter,
id_fetcher: Callable[[ItemT], ObjKT] = default_id_fetcher, id_fetcher: Callable[[ItemT], ObjKT] = default_id_fetcher,
cache_transformer: Callable[[ItemT], ItemT] = default_cache_transformer cache_transformer: Callable[[ItemT], CacheItemT] = default_cache_transformer,
) -> Dict[ObjKT, ItemT]: ) -> Dict[ObjKT, CacheItemT]:
cache_keys = {} # type: Dict[ObjKT, str] cache_keys = {} # type: Dict[ObjKT, str]
for object_id in object_ids: for object_id in object_ids:
cache_keys[object_id] = cache_key_function(object_id) cache_keys[object_id] = cache_key_function(object_id)
cached_objects_compressed = cache_get_many([cache_keys[object_id] cached_objects_compressed = cache_get_many([cache_keys[object_id]
for object_id in object_ids]) # type: Dict[str, Tuple[CompressedItemT]] for object_id in object_ids]) # type: Dict[str, Tuple[CompressedItemT]]
cached_objects = {} # type: Dict[str, ItemT] cached_objects = {} # type: Dict[str, CacheItemT]
for (key, val) in cached_objects_compressed.items(): for (key, val) in cached_objects_compressed.items():
cached_objects[key] = extractor(cached_objects_compressed[key][0]) cached_objects[key] = extractor(cached_objects_compressed[key][0])
needed_ids = [object_id for object_id in object_ids if needed_ids = [object_id for object_id in object_ids if

View File

@@ -88,13 +88,14 @@ def messages_for_ids(message_ids: List[int],
cache_transformer = MessageDict.build_dict_from_raw_db_row cache_transformer = MessageDict.build_dict_from_raw_db_row
id_fetcher = lambda row: row['id'] id_fetcher = lambda row: row['id']
message_dicts = generic_bulk_cached_fetch(to_dict_cache_key_id, message_dicts = generic_bulk_cached_fetch(
MessageDict.get_raw_db_rows, to_dict_cache_key_id,
message_ids, MessageDict.get_raw_db_rows,
id_fetcher=id_fetcher, message_ids,
cache_transformer=cache_transformer, id_fetcher=id_fetcher,
extractor=extract_message_dict, cache_transformer=cache_transformer,
setter=stringify_message_dict) extractor=extract_message_dict,
setter=stringify_message_dict)
message_list = [] # type: List[Dict[str, Any]] message_list = [] # type: List[Dict[str, Any]]

View File

@@ -124,13 +124,16 @@ def bulk_get_users(emails: List[str], realm: Optional[Realm],
where=[where_clause], where=[where_clause],
params=emails) params=emails)
def user_to_email(user_profile: UserProfile) -> str:
return user_profile.email.lower()
return generic_bulk_cached_fetch( return generic_bulk_cached_fetch(
# Use a separate cache key to protect us from conflicts with # Use a separate cache key to protect us from conflicts with
# the get_user cache. # the get_user cache.
lambda email: 'bulk_get_users:' + user_profile_cache_key_id(email, realm_id), lambda email: 'bulk_get_users:' + user_profile_cache_key_id(email, realm_id),
fetch_users_by_email, fetch_users_by_email,
[email.lower() for email in emails], [email.lower() for email in emails],
id_fetcher=lambda user_profile: user_profile.email.lower() id_fetcher=user_to_email,
) )
def user_ids_to_users(user_ids: List[int], realm: Realm) -> List[UserProfile]: def user_ids_to_users(user_ids: List[int], realm: Realm) -> List[UserProfile]:

View File

@@ -1435,10 +1435,16 @@ def bulk_get_streams(realm: Realm, stream_names: STREAM_NAMES) -> Dict[str, Any]
where=[where_clause], where=[where_clause],
params=stream_names) params=stream_names)
return generic_bulk_cached_fetch(lambda stream_name: get_stream_cache_key(stream_name, realm.id), def stream_name_to_cache_key(stream_name: str) -> str:
return get_stream_cache_key(stream_name, realm.id)
def stream_to_lower_name(stream: Stream) -> str:
return stream.name.lower()
return generic_bulk_cached_fetch(stream_name_to_cache_key,
fetch_streams_by_name, fetch_streams_by_name,
[stream_name.lower() for stream_name in stream_names], [stream_name.lower() for stream_name in stream_names],
id_fetcher=lambda stream: stream.name.lower()) id_fetcher=stream_to_lower_name)
def get_recipient_cache_key(type: int, type_id: int) -> str: def get_recipient_cache_key(type: int, type_id: int) -> str:
return u"%s:get_recipient:%s:%s" % (cache.KEY_PREFIX, type, type_id,) return u"%s:get_recipient:%s:%s" % (cache.KEY_PREFIX, type, type_id,)
@@ -1477,8 +1483,11 @@ def bulk_get_recipients(type: int, type_ids: List[int]) -> Dict[int, Any]:
# TODO: Change return type to QuerySet[Recipient] # TODO: Change return type to QuerySet[Recipient]
return Recipient.objects.filter(type=type, type_id__in=type_ids) return Recipient.objects.filter(type=type, type_id__in=type_ids)
def recipient_to_type_id(recipient: Recipient) -> int:
return recipient.type_id
return generic_bulk_cached_fetch(cache_key_function, query_function, type_ids, return generic_bulk_cached_fetch(cache_key_function, query_function, type_ids,
id_fetcher=lambda recipient: recipient.type_id) id_fetcher=recipient_to_type_id)
def get_stream_recipients(stream_ids: List[int]) -> List[Recipient]: def get_stream_recipients(stream_ids: List[int]) -> List[Recipient]: