Revert "create_user: Use transaction.atomic decorator for do_create_user."

This reverts commit 851d68e0fc.

That commit widened how long the transaction is open, which made it
much more likely that after the user was created in the transaction,
and the memcached caches were flushed, some other request will fill
the `get_realm_user_dicts` cache with data which did not include the
new user (because it had not been committed yet).

If a user creation request lost this race, the user would, upon first
request to `/`, get a blank page and a Javascript error:

    Unknown user_id in get_by_user_id: 12345

...where 12345 was their own user-id.  This error would persist until
the cache expired (in 7 days) or something else expunged it.

Reverting this does not prevent the race, as the post_save hook's call
to flush_user_profile is still in a transaction (and has been since
168f241ff0), and thus leaves the potential race window open.
However, it much shortens the potential window of opportunity, and is
a reasonable short-term stopgap.
This commit is contained in:
Alex Vandiver
2023-02-17 20:44:51 -05:00
committed by Alex Vandiver
parent 7feda75c5f
commit 8998aa00cd
10 changed files with 145 additions and 199 deletions

View File

@@ -368,7 +368,6 @@ def notify_created_bot(user_profile: UserProfile) -> None:
)
@transaction.atomic(durable=True)
def do_create_user(
email: str,
password: Optional[str],
@@ -392,74 +391,75 @@ def do_create_user(
acting_user: Optional[UserProfile],
enable_marketing_emails: bool = True,
) -> UserProfile:
user_profile = create_user(
email=email,
password=password,
realm=realm,
full_name=full_name,
role=role,
bot_type=bot_type,
bot_owner=bot_owner,
tos_version=tos_version,
timezone=timezone,
avatar_source=avatar_source,
default_language=default_language,
default_sending_stream=default_sending_stream,
default_events_register_stream=default_events_register_stream,
default_all_public_streams=default_all_public_streams,
source_profile=source_profile,
enable_marketing_emails=enable_marketing_emails,
)
event_time = user_profile.date_joined
if not acting_user:
acting_user = user_profile
RealmAuditLog.objects.create(
realm=user_profile.realm,
acting_user=acting_user,
modified_user=user_profile,
event_type=RealmAuditLog.USER_CREATED,
event_time=event_time,
extra_data=orjson.dumps(
{
RealmAuditLog.ROLE_COUNT: realm_user_count_by_role(user_profile.realm),
}
).decode(),
)
if realm_creation:
# If this user just created a realm, make sure they are
# properly tagged as the creator of the realm.
realm_creation_audit_log = (
RealmAuditLog.objects.filter(event_type=RealmAuditLog.REALM_CREATED, realm=realm)
.order_by("id")
.last()
with transaction.atomic():
user_profile = create_user(
email=email,
password=password,
realm=realm,
full_name=full_name,
role=role,
bot_type=bot_type,
bot_owner=bot_owner,
tos_version=tos_version,
timezone=timezone,
avatar_source=avatar_source,
default_language=default_language,
default_sending_stream=default_sending_stream,
default_events_register_stream=default_events_register_stream,
default_all_public_streams=default_all_public_streams,
source_profile=source_profile,
enable_marketing_emails=enable_marketing_emails,
)
assert realm_creation_audit_log is not None
realm_creation_audit_log.acting_user = user_profile
realm_creation_audit_log.save(update_fields=["acting_user"])
do_increment_logging_stat(
user_profile.realm,
COUNT_STATS["active_users_log:is_bot:day"],
user_profile.is_bot,
event_time,
)
if settings.BILLING_ENABLED:
update_license_ledger_if_needed(user_profile.realm, event_time)
system_user_group = get_system_user_group_for_user(user_profile)
UserGroupMembership.objects.create(user_profile=user_profile, user_group=system_user_group)
if user_profile.role == UserProfile.ROLE_MEMBER and not user_profile.is_provisional_member:
full_members_system_group = UserGroup.objects.get(
name=UserGroup.FULL_MEMBERS_GROUP_NAME,
event_time = user_profile.date_joined
if not acting_user:
acting_user = user_profile
RealmAuditLog.objects.create(
realm=user_profile.realm,
is_system_group=True,
acting_user=acting_user,
modified_user=user_profile,
event_type=RealmAuditLog.USER_CREATED,
event_time=event_time,
extra_data=orjson.dumps(
{
RealmAuditLog.ROLE_COUNT: realm_user_count_by_role(user_profile.realm),
}
).decode(),
)
UserGroupMembership.objects.create(
user_profile=user_profile, user_group=full_members_system_group
if realm_creation:
# If this user just created a realm, make sure they are
# properly tagged as the creator of the realm.
realm_creation_audit_log = (
RealmAuditLog.objects.filter(event_type=RealmAuditLog.REALM_CREATED, realm=realm)
.order_by("id")
.last()
)
assert realm_creation_audit_log is not None
realm_creation_audit_log.acting_user = user_profile
realm_creation_audit_log.save(update_fields=["acting_user"])
do_increment_logging_stat(
user_profile.realm,
COUNT_STATS["active_users_log:is_bot:day"],
user_profile.is_bot,
event_time,
)
if settings.BILLING_ENABLED:
update_license_ledger_if_needed(user_profile.realm, event_time)
system_user_group = get_system_user_group_for_user(user_profile)
UserGroupMembership.objects.create(user_profile=user_profile, user_group=system_user_group)
if user_profile.role == UserProfile.ROLE_MEMBER and not user_profile.is_provisional_member:
full_members_system_group = UserGroup.objects.get(
name=UserGroup.FULL_MEMBERS_GROUP_NAME,
realm=user_profile.realm,
is_system_group=True,
)
UserGroupMembership.objects.create(
user_profile=user_profile, user_group=full_members_system_group
)
# Note that for bots, the caller will send an additional event
# with bot-specific info like services.

View File

@@ -36,7 +36,7 @@ from zerver.tornado.django_api import send_event
def notify_invites_changed(realm: Realm) -> None:
event = dict(type="invites_changed")
admin_ids = [user.id for user in realm.get_admin_users_and_bots()]
transaction.on_commit(lambda: send_event(realm, event, admin_ids))
send_event(realm, event, admin_ids)
def do_send_confirmation_email(

View File

@@ -941,11 +941,7 @@ def do_send_messages(
event["local_id"] = send_request.local_id
if send_request.sender_queue_id is not None:
event["sender_queue_id"] = send_request.sender_queue_id
transaction.on_commit(
lambda event=event, users=users, realm=send_request.realm: send_event(
realm, event, users
)
)
send_event(send_request.realm, event, users)
if send_request.links_for_embed:
event_data = {
@@ -954,9 +950,7 @@ def do_send_messages(
"message_realm_id": send_request.realm.id,
"urls": list(send_request.links_for_embed),
}
transaction.on_commit(
lambda event_data=event_data: queue_json_publish("embed_links", event_data)
)
queue_json_publish("embed_links", event_data)
if send_request.message.recipient.type == Recipient.PERSONAL:
welcome_bot_id = get_system_bot(
@@ -973,15 +967,13 @@ def do_send_messages(
assert send_request.service_queue_events is not None
for queue_name, events in send_request.service_queue_events.items():
for event in events:
transaction.on_commit(
lambda event=event, queue_name=queue_name, wide_message_dict=wide_message_dict: queue_json_publish(
queue_name,
{
"message": wide_message_dict,
"trigger": event["trigger"],
"user_profile_id": event["user_profile_id"],
},
)
queue_json_publish(
queue_name,
{
"message": wide_message_dict,
"trigger": event["trigger"],
"user_profile_id": event["user_profile_id"],
},
)
return [send_request.message.id for send_request in send_message_requests]

View File

@@ -297,9 +297,7 @@ def send_subscription_add_events(
# Send a notification to the user who subscribed.
event = dict(type="subscription", op="add", subscriptions=sub_dicts)
transaction.on_commit(
lambda event=event, user_id=user_id: send_event(realm, event, [user_id])
)
send_event(realm, event, [user_id])
# This function contains all the database changes as part of

View File

@@ -1,6 +1,7 @@
from typing import Dict, List
from django.conf import settings
from django.db import transaction
from django.db.models import Count
from django.utils.translation import gettext as _
from django.utils.translation import override as override_language
@@ -229,6 +230,7 @@ def send_welcome_bot_response(send_request: SendMessageRequest) -> None:
)
@transaction.atomic
def send_initial_realm_messages(realm: Realm) -> None:
welcome_bot = get_system_bot(settings.WELCOME_BOT, realm.id)
# Make sure each stream created in the realm creation process has at least one message below

View File

@@ -5,7 +5,7 @@ import shutil
import subprocess
import tempfile
import urllib
from contextlib import contextmanager, nullcontext
from contextlib import contextmanager
from datetime import timedelta
from typing import (
TYPE_CHECKING,
@@ -967,22 +967,18 @@ Output:
to_user: UserProfile,
content: str = "test content",
sending_client_name: str = "test suite",
capture_on_commit_callbacks: bool = True,
) -> int:
recipient_list = [to_user.id]
(sending_client, _) = Client.objects.get_or_create(name=sending_client_name)
with self.captureOnCommitCallbacks(
execute=True
) if capture_on_commit_callbacks else nullcontext():
return check_send_message(
from_user,
sending_client,
"private",
recipient_list,
None,
content,
)
return check_send_message(
from_user,
sending_client,
"private",
recipient_list,
None,
content,
)
def send_huddle_message(
self,
@@ -990,24 +986,20 @@ Output:
to_users: List[UserProfile],
content: str = "test content",
sending_client_name: str = "test suite",
capture_on_commit_callbacks: bool = True,
) -> int:
to_user_ids = [u.id for u in to_users]
assert len(to_user_ids) >= 2
(sending_client, _) = Client.objects.get_or_create(name=sending_client_name)
with self.captureOnCommitCallbacks(
execute=True
) if capture_on_commit_callbacks else nullcontext():
return check_send_message(
from_user,
sending_client,
"private",
to_user_ids,
None,
content,
)
return check_send_message(
from_user,
sending_client,
"private",
to_user_ids,
None,
content,
)
def send_stream_message(
self,
@@ -1018,21 +1010,17 @@ Output:
recipient_realm: Optional[Realm] = None,
sending_client_name: str = "test suite",
allow_unsubscribed_sender: bool = False,
capture_on_commit_callbacks: bool = True,
) -> int:
(sending_client, _) = Client.objects.get_or_create(name=sending_client_name)
with self.captureOnCommitCallbacks(
execute=True
) if capture_on_commit_callbacks else nullcontext():
message_id = check_send_stream_message(
sender=sender,
client=sending_client,
stream_name=stream_name,
topic=topic_name,
body=content,
realm=recipient_realm,
)
message_id = check_send_stream_message(
sender=sender,
client=sending_client,
stream_name=stream_name,
topic=topic_name,
body=content,
realm=recipient_realm,
)
if (
not UserMessage.objects.filter(user_profile=sender, message_id=message_id).exists()
and not sender.is_bot

View File

@@ -679,11 +679,9 @@ class MissedMessageHookTest(ZulipTestCase):
def test_disable_external_notifications(self) -> None:
# The disable_external_notifications parameter, used for messages sent by welcome bot,
# should result in no email/push notifications being sent regardless of the message type.
with self.captureOnCommitCallbacks(execute=True):
msg_id = internal_send_private_message(
self.iago, self.user_profile, "Test Content", disable_external_notifications=True
)
msg_id = internal_send_private_message(
self.iago, self.user_profile, "Test Content", disable_external_notifications=True
)
assert msg_id is not None
with mock.patch("zerver.tornado.event_queue.maybe_enqueue_notifications") as mock_enqueue:
missedmessage_hook(self.user_profile.id, self.client_descriptor, True)

View File

@@ -321,17 +321,16 @@ class GetEventsTest(ZulipTestCase):
self.assert_length(events, 0)
local_id = "10.01"
with self.captureOnCommitCallbacks(execute=True):
check_send_message(
sender=user_profile,
client=get_client("whatever"),
message_type_name="private",
message_to=[recipient_email],
topic_name=None,
message_content="hello",
local_id=local_id,
sender_queue_id=queue_id,
)
check_send_message(
sender=user_profile,
client=get_client("whatever"),
message_type_name="private",
message_to=[recipient_email],
topic_name=None,
message_content="hello",
local_id=local_id,
sender_queue_id=queue_id,
)
result = self.tornado_call(
get_events,
@@ -355,17 +354,16 @@ class GetEventsTest(ZulipTestCase):
last_event_id = events[0]["id"]
local_id = "10.02"
with self.captureOnCommitCallbacks(execute=True):
check_send_message(
sender=user_profile,
client=get_client("whatever"),
message_type_name="private",
message_to=[recipient_email],
topic_name=None,
message_content="hello",
local_id=local_id,
sender_queue_id=queue_id,
)
check_send_message(
sender=user_profile,
client=get_client("whatever"),
message_type_name="private",
message_to=[recipient_email],
topic_name=None,
message_content="hello",
local_id=local_id,
sender_queue_id=queue_id,
)
result = self.tornado_call(
get_events,

View File

@@ -440,33 +440,20 @@ class NormalActionsTest(BaseAction):
for i in range(3):
content = "mentioning... @**" + user.full_name + "** hello " + str(i)
self.verify_action(
lambda: self.send_stream_message(
self.example_user("cordelia"),
"Verona",
content,
capture_on_commit_callbacks=False,
),
lambda: self.send_stream_message(self.example_user("cordelia"), "Verona", content),
)
def test_wildcard_mentioned_send_message_events(self) -> None:
for i in range(3):
content = "mentioning... @**all** hello " + str(i)
self.verify_action(
lambda: self.send_stream_message(
self.example_user("cordelia"),
"Verona",
content,
capture_on_commit_callbacks=False,
),
lambda: self.send_stream_message(self.example_user("cordelia"), "Verona", content),
)
def test_pm_send_message_events(self) -> None:
self.verify_action(
lambda: self.send_personal_message(
self.example_user("cordelia"),
self.example_user("hamlet"),
"hola",
capture_on_commit_callbacks=False,
self.example_user("cordelia"), self.example_user("hamlet"), "hola"
),
)
@@ -513,16 +500,12 @@ class NormalActionsTest(BaseAction):
self.example_user("othello"),
]
self.verify_action(
lambda: self.send_huddle_message(
self.example_user("cordelia"), huddle, "hola", capture_on_commit_callbacks=False
),
lambda: self.send_huddle_message(self.example_user("cordelia"), huddle, "hola"),
)
def test_stream_send_message_events(self) -> None:
events = self.verify_action(
lambda: self.send_stream_message(
self.example_user("hamlet"), "Verona", "hello", capture_on_commit_callbacks=False
),
lambda: self.send_stream_message(self.example_user("hamlet"), "Verona", "hello"),
client_gravatar=False,
)
check_message("events[0]", events[0])
@@ -536,9 +519,7 @@ class NormalActionsTest(BaseAction):
)
events = self.verify_action(
lambda: self.send_stream_message(
self.example_user("hamlet"), "Verona", "hello", capture_on_commit_callbacks=False
),
lambda: self.send_stream_message(self.example_user("hamlet"), "Verona", "hello"),
client_gravatar=True,
)
check_message("events[0]", events[0])
@@ -749,16 +730,12 @@ class NormalActionsTest(BaseAction):
"hello 1",
)
self.verify_action(
lambda: self.send_stream_message(
sender, "Verona", "hello 2", capture_on_commit_callbacks=False
),
lambda: self.send_stream_message(sender, "Verona", "hello 2"),
state_change_expected=True,
)
def test_add_reaction(self) -> None:
message_id = self.send_stream_message(
self.example_user("hamlet"), "Verona", "hello", capture_on_commit_callbacks=False
)
message_id = self.send_stream_message(self.example_user("hamlet"), "Verona", "hello")
message = Message.objects.get(id=message_id)
events = self.verify_action(
lambda: do_add_reaction(self.user_profile, message, "tada", "1f389", "unicode_emoji"),
@@ -933,7 +910,7 @@ class NormalActionsTest(BaseAction):
num_events=7,
)
check_invites_changed("events[5]", events[5])
check_invites_changed("events[1]", events[1])
def test_typing_events(self) -> None:
events = self.verify_action(
@@ -1151,20 +1128,20 @@ class NormalActionsTest(BaseAction):
events = self.verify_action(lambda: self.register("test1@zulip.com", "test1"), num_events=5)
self.assert_length(events, 5)
check_realm_user_add("events[0]", events[0])
check_realm_user_add("events[1]", events[1])
new_user_profile = get_user_by_delivery_email("test1@zulip.com", self.user_profile.realm)
self.assertEqual(new_user_profile.delivery_email, "test1@zulip.com")
check_subscription_peer_add("events[3]", events[3])
check_subscription_peer_add("events[4]", events[4])
check_message("events[4]", events[4])
check_message("events[0]", events[0])
self.assertIn(
f'data-user-id="{new_user_profile.id}">test1_zulip.com</span> just signed up for Zulip',
events[4]["message"]["content"],
events[0]["message"]["content"],
)
check_user_group_add_members("events[1]", events[1])
check_user_group_add_members("events[2]", events[2])
check_user_group_add_members("events[3]", events[3])
def test_register_events_email_address_visibility(self) -> None:
realm_user_default = RealmUserDefault.objects.get(realm=self.user_profile.realm)
@@ -1177,20 +1154,20 @@ class NormalActionsTest(BaseAction):
events = self.verify_action(lambda: self.register("test1@zulip.com", "test1"), num_events=5)
self.assert_length(events, 5)
check_realm_user_add("events[0]", events[0])
check_realm_user_add("events[1]", events[1])
new_user_profile = get_user_by_delivery_email("test1@zulip.com", self.user_profile.realm)
self.assertEqual(new_user_profile.email, f"user{new_user_profile.id}@zulip.testserver")
check_subscription_peer_add("events[3]", events[3])
check_subscription_peer_add("events[4]", events[4])
check_message("events[4]", events[4])
check_message("events[0]", events[0])
self.assertIn(
f'data-user-id="{new_user_profile.id}">test1_zulip.com</span> just signed up for Zulip',
events[4]["message"]["content"],
events[0]["message"]["content"],
)
check_user_group_add_members("events[1]", events[1])
check_user_group_add_members("events[2]", events[2])
check_user_group_add_members("events[3]", events[3])
def test_alert_words_events(self) -> None:
events = self.verify_action(lambda: do_add_alert_words(self.user_profile, ["alert_word"]))
@@ -2428,13 +2405,7 @@ class NormalActionsTest(BaseAction):
assert uri is not None
body = f"First message ...[zulip.txt](http://{hamlet.realm.host}" + uri + ")"
events = self.verify_action(
lambda: self.send_stream_message(
self.example_user("hamlet"),
"Denmark",
body,
"test",
capture_on_commit_callbacks=False,
),
lambda: self.send_stream_message(self.example_user("hamlet"), "Denmark", body, "test"),
num_events=2,
)

View File

@@ -1695,7 +1695,6 @@ class StreamMessagesTest(ZulipTestCase):
user,
stream_name,
content=content,
capture_on_commit_callbacks=False,
)
users = events[0]["users"]
user_ids = {u["id"] for u in users}