diff --git a/zerver/lib/bot_lib.py b/zerver/lib/bot_lib.py index 202a957f2e..3b9a728691 100644 --- a/zerver/lib/bot_lib.py +++ b/zerver/lib/bot_lib.py @@ -49,9 +49,10 @@ class StateHandler: def put(self, key: Text, value: Text) -> None: set_bot_state(self.user_profile, key, self.marshal(value)) + set_bot_state(self.user_profile, [(key, self.marshal(value))]) def remove(self, key: Text) -> None: - remove_bot_state(self.user_profile, key) + remove_bot_state(self.user_profile, [key]) def contains(self, key: Text) -> bool: return is_key_in_bot_state(self.user_profile, key) diff --git a/zerver/lib/bot_storage.py b/zerver/lib/bot_storage.py index 79a4c1d0d0..7122c77efa 100644 --- a/zerver/lib/bot_storage.py +++ b/zerver/lib/bot_storage.py @@ -4,7 +4,7 @@ from django.db.models.query import F from django.db.models.functions import Length from zerver.models import BotUserStateData, UserProfile, Length -from typing import Text, Optional +from typing import Text, Optional, List, Tuple class StateError(Exception): pass @@ -28,33 +28,31 @@ def get_bot_state_size(bot_profile, key=None): except BotUserStateData.DoesNotExist: return 0 -def set_bot_state(bot_profile, key, value): - # type: (UserProfile, Text, Text) -> None +def set_bot_state(bot_profile, entries): + # type: (UserProfile, List[Tuple[str, str]]) -> None state_size_limit = settings.USER_STATE_SIZE_LIMIT - old_entry_size = get_bot_state_size(bot_profile, key) - new_entry_size = len(key) + len(value) - old_state_size = get_bot_state_size(bot_profile) - new_state_size = old_state_size + (new_entry_size - old_entry_size) + state_size_difference = 0 + for key, value in entries: + if type(key) is not str: + raise StateError("Key type is {}, but should be str.".format(type(key))) + if type(value) is not str: + raise StateError("Value type is {}, but should be str.".format(type(value))) + state_size_difference += (len(key) + len(value)) - get_bot_state_size(bot_profile, key) + new_state_size = get_bot_state_size(bot_profile) + state_size_difference if new_state_size > state_size_limit: raise StateError("Request exceeds storage limit by {} characters. The limit is {} characters." .format(new_state_size - state_size_limit, state_size_limit)) - elif type(key) is not str: - raise StateError("Key type is {}, but should be str.".format(type(key))) - elif type(value) is not str: - raise StateError("Value type is {}, but should be str.".format(type(value))) else: - obj, created = BotUserStateData.objects.get_or_create(bot_profile=bot_profile, key=key, - defaults={'value': value}) - if not created: - obj.value = value - obj.save() + for key, value in entries: + BotUserStateData.objects.update_or_create(bot_profile=bot_profile, key=key, + defaults={'value': value}) -def remove_bot_state(bot_profile, key): - # type: (UserProfile, Text) -> None - try: - BotUserStateData.objects.get(bot_profile=bot_profile, key=key).delete() - except BotUserStateData.DoesNotExist: - raise StateError("Key does not exist.".format(key)) +def remove_bot_state(bot_profile, keys): + # type: (UserProfile, List[Text]) -> None + queryset = BotUserStateData.objects.filter(bot_profile=bot_profile, key__in=keys) + if len(queryset) < len(keys): + raise StateError("Key does not exist.") + queryset.delete() def is_key_in_bot_state(bot_profile, key): # type: (UserProfile, Text) -> bool