diff --git a/zerver/lib/actions.py b/zerver/lib/actions.py
index 9b08a49808..1620e2e1c6 100644
--- a/zerver/lib/actions.py
+++ b/zerver/lib/actions.py
@@ -2438,6 +2438,10 @@ def do_update_message(user_profile, message_id, subject, propagate_mode, content
         event["content"] = content
         event["rendered_content"] = rendered_content
 
+        prev_content = edit_history_event['prev_content']
+        if Message.content_has_attachment(prev_content) or Message.content_has_attachment(message.content):
+            check_attachment_reference_change(prev_content, message)
+
     if subject is not None:
         orig_subject = message.subject
         subject = truncate_topic(subject)
@@ -3302,3 +3306,35 @@ def do_delete_old_unclaimed_attachments(weeks_ago):
     for attachment in old_unclaimed_attachments:
         delete_message_image(attachment.path_id)
         attachment.delete()
+
+def check_attachment_reference_change(prev_content, message):
+    """Sync Attachment/Message links after a message's content was edited.
+
+    Compares the upload URLs found in prev_content with those found in
+    message.content.  Attachments whose URL was dropped by the edit are
+    unlinked from the message; if the edit introduced any new upload URL,
+    do_claim_attachments(message) is called.
+    """
+    new_content = message.content
+    attachment_url_re = re.compile(u'[/\-]user[\-_]uploads[/\.-].*?(?=[ )]|\Z)')
+    prev_attachments = set(attachment_url_re.findall(prev_content))
+    new_attachments = set(attachment_url_re.findall(new_content))
+
+    path_ids = []
+    for url in prev_attachments - new_attachments:
+        path_id = re.sub(u'[/\-]user[\-_]uploads[/\.-]', u'', url)
+        # Remove any extra '.' after file extension. These are probably added by the user
+        path_id = re.sub(u'[.]+$', u'', path_id, flags=re.M)
+        path_ids.append(path_id)
+
+    # Path ids with no Attachment row are simply absent from the queryset,
+    # so each row can be used directly without a second lookup.
+    attachments_to_update = Attachment.objects.filter(path_id__in=path_ids).select_for_update()
+    for attachment in attachments_to_update:
+        attachment.messages.remove(message)
+        attachment.save()
+
+    # Claim as soon as the edit adds at least one new upload URL.
+    if new_attachments - prev_attachments:
+        do_claim_attachments(message)
+
diff --git a/zerver/tests/test_upload.py b/zerver/tests/test_upload.py
index 6a2646cd88..3994411b1e 100644
--- a/zerver/tests/test_upload.py
+++ b/zerver/tests/test_upload.py
@@ -11,7 +11,7 @@ from zerver.lib.upload import sanitize_name, S3UploadBackend, \
     upload_message_image, delete_message_image, LocalUploadBackend
 import zerver.lib.upload
 from zerver.models import Attachment, Recipient, get_user_profile_by_email, \
-    get_old_unclaimed_attachments
+    get_old_unclaimed_attachments, Message
 from zerver.lib.actions import do_delete_old_unclaimed_attachments
 
 import ujson
@@ -189,6 +189,53 @@ class FileUploadTest(AuthedTestCase):
 
         self.assertEquals(Attachment.objects.get(path_id=d1_path_id).messages.count(), 2)
 
+    def test_check_attachment_reference_update(self):
+        """Editing a message unlinks dropped uploads and links newly added ones."""
+        f1 = StringIO("file1")
+        f1.name = "file1.txt"
+        f2 = StringIO("file2")
+        f2.name = "file2.txt"
+        f3 = StringIO("file3")
+        f3.name = "file3.txt"
+
+        self.login("hamlet@zulip.com")
+        result = self.client.post("/json/upload_file", {'file': f1})
+        json = ujson.loads(result.content)
+        uri = json["uri"]
+        f1_path_id = re.sub('/user_uploads/', '', uri)
+
+        result = self.client.post("/json/upload_file", {'file': f2})
+        json = ujson.loads(result.content)
+        uri = json["uri"]
+        f2_path_id = re.sub('/user_uploads/', '', uri)
+
+        self.subscribe_to_stream("hamlet@zulip.com", "test")
+        body = ("[f1.txt](http://localhost:9991/user_uploads/" + f1_path_id + ")"
+                "[f2.txt](http://localhost:9991/user_uploads/" + f2_path_id + ")")
+        msg_id = self.send_message("hamlet@zulip.com", "test", Recipient.STREAM, body, "test")
+
+        result = self.client.post("/json/upload_file", {'file': f3})
+        json = ujson.loads(result.content)
+        uri = json["uri"]
+        f3_path_id = re.sub('/user_uploads/', '', uri)
+
+        new_body = ("[f3.txt](http://localhost:9991/user_uploads/" + f3_path_id + ")"
+                    "[f2.txt](http://localhost:9991/user_uploads/" + f2_path_id + ")")
+        result = self.client.post("/json/update_message", {
+            'message_id': msg_id,
+            'content': new_body
+        })
+        self.assert_json_success(result)
+
+        message = Message.objects.get(id=msg_id)
+        f1_attachment = Attachment.objects.get(path_id=f1_path_id)
+        f2_attachment = Attachment.objects.get(path_id=f2_path_id)
+        f3_attachment = Attachment.objects.get(path_id=f3_path_id)
+
+        self.assertTrue(message not in f1_attachment.messages.all())
+        self.assertTrue(message in f2_attachment.messages.all())
+        self.assertTrue(message in f3_attachment.messages.all())
+
     def tearDown(self):
         # type: () -> None
         destroy_uploads()