diff --git a/zerver/lib/actions.py b/zerver/lib/actions.py index bf09bbf731..d3b718e560 100644 --- a/zerver/lib/actions.py +++ b/zerver/lib/actions.py @@ -257,13 +257,22 @@ def add_new_user_history(user_profile, streams): recipients = Recipient.objects.filter(type=Recipient.STREAM, type_id__in=[stream.id for stream in streams if not stream.invite_only]) - messages = Message.objects.filter(recipient_id__in=recipients, pub_date__gt=one_week_ago).order_by("-id")[0:100] - if len(messages) > 0: - ums_to_create = [UserMessage(user_profile=user_profile, message=message, - flags=UserMessage.flags.read) - for message in messages] + recent_messages = Message.objects.filter(recipient_id__in=recipients, + pub_date__gt=one_week_ago).order_by("-id") + message_ids_to_use = list(recent_messages.values_list('id', flat=True)[0:100]) + if len(message_ids_to_use) == 0: + return - UserMessage.objects.bulk_create(ums_to_create) + # Handle the race condition where a message arrives between + # bulk_add_subscriptions above and the Message query just above + already_ids = set(UserMessage.objects.filter(message_id__in=message_ids_to_use, + user_profile=user_profile).values_list("message_id", flat=True)) + ums_to_create = [UserMessage(user_profile=user_profile, message_id=message_id, + flags=UserMessage.flags.read) + for message_id in message_ids_to_use + if message_id not in already_ids] + + UserMessage.objects.bulk_create(ums_to_create) # Does the processing for a new user account: # * Subscribes to default/invitation streams diff --git a/zerver/tests/test_signup.py b/zerver/tests/test_signup.py index 6862967f4a..7949421921 100644 --- a/zerver/tests/test_signup.py +++ b/zerver/tests/test_signup.py @@ -14,6 +14,7 @@ from zerver.views.invite import get_invitee_emails_set from zerver.models import ( get_realm, get_realm_by_string_id, get_prereg_user_by_email, get_user_profile_by_email, PreregistrationUser, Realm, RealmAlias, Recipient, ScheduledJob, UserProfile, UserMessage, + Stream, Subscription, ) from zerver.lib.actions import ( @@ -22,7 +23,8 @@ from zerver.lib.actions import ( ) from zerver.lib.initial_password import initial_password -from zerver.lib.actions import do_deactivate_realm, do_set_realm_default_language +from zerver.lib.actions import do_deactivate_realm, do_set_realm_default_language, \ + add_new_user_history from zerver.lib.digest import send_digest_email from zerver.lib.notifications import enqueue_welcome_emails, one_click_unsubscribe_link from zerver.lib.test_helpers import find_key_by_email, queries_captured, \ @@ -124,6 +126,22 @@ class PublicURLTest(ZulipTestCase): self.assertEqual('success', data['result']) self.assertEqual('ABCD', data['google_client_id']) +class AddNewUserHistoryTest(ZulipTestCase): + def test_add_new_user_history_race(self): + # type: () -> None + """Sends a message during user creation""" + # Create a user who hasn't had historical messages added + set_default_streams(get_realm_by_string_id("zulip"), ["Denmark", "Verona"]) + with patch("zerver.lib.actions.add_new_user_history"): + self.register("test", "test") + user_profile = get_user_profile_by_email("test@zulip.com") + + subs = Subscription.objects.select_related("recipient").filter( + user_profile=user_profile, recipient__type=Recipient.STREAM) + streams = Stream.objects.filter(id__in=[sub.recipient.type_id for sub in subs]) + self.send_message("hamlet@zulip.com", streams[0].name, Recipient.STREAM, "test") + add_new_user_history(user_profile, streams) + class PasswordResetTest(ZulipTestCase): """ Log in, reset password, log out, log in with new password.