diff --git a/zerver/data_import/mattermost.py b/zerver/data_import/mattermost.py index 9f3e6a5fbf..1eed45ae59 100644 --- a/zerver/data_import/mattermost.py +++ b/zerver/data_import/mattermost.py @@ -408,7 +408,9 @@ def process_raw_message_batch( subscriber_map: dict[int, set[int]], user_id_mapper: IdMapper[str], user_handler: UserHandler, - get_recipient_id_from_receiver_name: Callable[[str, int], int], + get_recipient_id_from_channel_name: Callable[[str], int], + get_recipient_id_from_direct_message_group_name: Callable[[str], int], + get_recipient_id_from_username: Callable[[str], int], is_pm_data: bool, output_dir: str, zerver_realmemoji: list[dict[str, Any]], @@ -456,21 +458,19 @@ def process_raw_message_batch( date_sent = raw_message["date_sent"] sender_user_id = raw_message["sender_id"] if "channel_name" in raw_message: - recipient_id = get_recipient_id_from_receiver_name( - raw_message["channel_name"], Recipient.STREAM - ) + recipient_id = get_recipient_id_from_channel_name(raw_message["channel_name"]) elif "huddle_name" in raw_message: - recipient_id = get_recipient_id_from_receiver_name( - raw_message["huddle_name"], Recipient.DIRECT_MESSAGE_GROUP + recipient_id = get_recipient_id_from_direct_message_group_name( + raw_message["huddle_name"] ) elif "pm_members" in raw_message: members = raw_message["pm_members"] member_ids = {user_id_mapper.get(member) for member in members} pm_members[message_id] = member_ids if sender_user_id == user_id_mapper.get(members[0]): - recipient_id = get_recipient_id_from_receiver_name(members[1], Recipient.PERSONAL) + recipient_id = get_recipient_id_from_username(members[1]) else: - recipient_id = get_recipient_id_from_receiver_name(members[0], Recipient.PERSONAL) + recipient_id = get_recipient_id_from_username(members[0]) else: raise AssertionError("raw_message without channel_name, huddle_name or pm_members key") @@ -544,7 +544,9 @@ def process_posts( team_name: str, realm_id: int, post_data: list[dict[str, Any]], - get_recipient_id_from_receiver_name: Callable[[str, int], int], + get_recipient_id_from_channel_name: Callable[[str], int], + get_recipient_id_from_direct_message_group_name: Callable[[str], int], + get_recipient_id_from_username: Callable[[str], int], subscriber_map: dict[int, set[int]], output_dir: str, is_pm_data: bool, @@ -630,7 +632,9 @@ def process_posts( subscriber_map=subscriber_map, user_id_mapper=user_id_mapper, user_handler=user_handler, - get_recipient_id_from_receiver_name=get_recipient_id_from_receiver_name, + get_recipient_id_from_channel_name=get_recipient_id_from_channel_name, + get_recipient_id_from_direct_message_group_name=get_recipient_id_from_direct_message_group_name, + get_recipient_id_from_username=get_recipient_id_from_username, is_pm_data=is_pm_data, output_dir=output_dir, zerver_realmemoji=zerver_realmemoji, @@ -680,19 +684,17 @@ def write_message_data( if d["type"] == Recipient.PERSONAL: user_id_to_recipient_id[d["type_id"]] = d["id"] - def get_recipient_id_from_receiver_name(receiver_name: str, recipient_type: int) -> int: - if recipient_type == Recipient.STREAM: - receiver_id = stream_id_mapper.get(receiver_name) - recipient_id = stream_id_to_recipient_id[receiver_id] - elif recipient_type == Recipient.DIRECT_MESSAGE_GROUP: - receiver_id = huddle_id_mapper.get(receiver_name) - recipient_id = huddle_id_to_recipient_id[receiver_id] - elif recipient_type == Recipient.PERSONAL: - receiver_id = user_id_mapper.get(receiver_name) - recipient_id = user_id_to_recipient_id[receiver_id] - else: - raise AssertionError("Invalid recipient_type") - return recipient_id + def get_recipient_id_from_channel_name(channel_name: str) -> int: + receiver_id = stream_id_mapper.get(channel_name) + return stream_id_to_recipient_id[receiver_id] + + def get_recipient_id_from_direct_message_group_name(direct_message_group_name: str) -> int: + receiver_id = huddle_id_mapper.get(direct_message_group_name) + return huddle_id_to_recipient_id[receiver_id] + + def get_recipient_id_from_username(username: str) -> int: + receiver_id = user_id_mapper.get(username) + return user_id_to_recipient_id[receiver_id] if num_teams == 1: post_types = ["channel_post", "direct_post"] @@ -708,7 +710,9 @@ def write_message_data( team_name=team_name, realm_id=realm_id, post_data=post_data[post_type], - get_recipient_id_from_receiver_name=get_recipient_id_from_receiver_name, + get_recipient_id_from_channel_name=get_recipient_id_from_channel_name, + get_recipient_id_from_direct_message_group_name=get_recipient_id_from_direct_message_group_name, + get_recipient_id_from_username=get_recipient_id_from_username, subscriber_map=subscriber_map, output_dir=output_dir, is_pm_data=post_type == "direct_post",