mirror of
https://github.com/zulip/zulip.git
synced 2025-11-15 11:22:04 +00:00
274 lines
9.2 KiB
Python
274 lines
9.2 KiB
Python
import glob
|
|
import logging
|
|
import os
|
|
import shutil
|
|
import tempfile
|
|
import time
|
|
from collections import Counter
|
|
from multiprocessing import current_process
|
|
|
|
from django.db import connection
|
|
|
|
from zerver.lib.parallel import _disconnect, run_parallel
|
|
from zerver.lib.partial import partial
|
|
from zerver.lib.test_classes import ZulipTestCase
|
|
from zerver.models import Realm
|
|
|
|
|
|
class RunNotParallelTest(ZulipTestCase):
    """Tests for run_parallel with processes=1, and for _disconnect."""

    def test_disconnect(self) -> None:
        # Before disconnecting, the connection is usable and serves queries.
        self.assertTrue(connection.is_usable())
        realm_count = Realm.objects.count()
        self.assertEqual(realm_count, 4)

        # After _disconnect(), the connection must report itself unusable.
        _disconnect()
        self.assertFalse(connection.is_usable())

    def test_not_parallel(self) -> None:
        # Nothing here is parallel, or forks at all
        recorded: list[str] = []

        def on_init(a: str, b: str) -> None:
            recorded.append(f"Init: {a}, {b}")

        def on_item(item: int) -> None:
            recorded.append(f"Item: {item}")

        def on_report(n: int) -> None:
            recorded.append(f"Completed {n}")

        run_parallel(
            on_item,
            range(100, 110),
            processes=1,
            initializer=on_init,
            initargs=("alpha", "bravo"),
            report_every=3,
            report=on_report,
        )

        # The initializer comes first, then the items in order, with a
        # report line after every third item.
        expected = [
            "Init: alpha, bravo",
            "Item: 100",
            "Item: 101",
            "Item: 102",
            "Completed 3",
            "Item: 103",
            "Item: 104",
            "Item: 105",
            "Completed 6",
            "Item: 106",
            "Item: 107",
            "Item: 108",
            "Completed 9",
            "Item: 109",
        ]
        self.assertEqual(recorded, expected)

    def test_not_parallel_throw(self) -> None:
        recorded: list[str] = []

        def on_report(n: int) -> None:
            recorded.append(f"Completed {n}")

        def picky_worker(item: int) -> None:
            if item == 103:
                raise Exception("I don't like threes")
            recorded.append(f"Item: {item}")

        # With catch=False the worker's exception propagates to the caller.
        with self.assertRaisesRegex(Exception, "I don't like threes"):
            run_parallel(
                picky_worker,
                range(100, 110),
                processes=1,
                report_every=5,
                report=on_report,
                catch=False,
            )

        # Only the items before the failing one were recorded, and no
        # report line was emitted.
        expected = [
            "Item: 100",
            "Item: 101",
            "Item: 102",
        ]
        self.assertEqual(recorded, expected)

    def test_not_parallel_catch(self) -> None:
        recorded: list[str] = []

        def on_report(n: int) -> None:
            recorded.append(f"Completed {n}")

        def picky_worker(item: int) -> None:
            if item == 103:
                raise Exception("I don't like threes")
            recorded.append(f"Item: {item}")

        # With catch=True the failure is logged at ERROR level instead of
        # being raised.
        with self.assertLogs(level="ERROR") as error_logs:
            run_parallel(
                picky_worker,
                range(100, 105),
                processes=1,
                report_every=5,
                report=on_report,
                catch=True,
            )

        self.assert_length(error_logs.output, 1)
        logged = error_logs.output[0]
        self.assertTrue(logged.startswith("ERROR:root:Error processing item: 103\nTraceback"))
        self.assertIn("I don't like threes", logged)

        expected = [
            "Item: 100",
            "Item: 101",
            "Item: 102",
            "Item: 104",
            # We "completed" the one which raised an exception,
            # despite it not having output
            "Completed 5",
        ]
        self.assertEqual(recorded, expected)
|
|
|
|
|
|
def write_number(
    output_dir: str, total_processes: int, fail: set[int], item: int
) -> None:  # nocoverage
    """Append ``item`` to this process's ``<pid>.output`` file in ``output_dir``.

    Raises ``Exception("Whoops")`` without touching the filesystem when
    ``item`` is in ``fail``.  After writing, waits (at most five one-second
    sleeps) until ``output_dir`` holds at least ``total_processes``
    ``*.output`` files.
    """
    if item in fail:
        raise Exception("Whoops")

    with open(f"{output_dir}/{os.getpid()}.output", "a") as fh:
        fh.write(f"{item}\n")

    # We wait to exit until we see total_processes unique files in the
    # output directory, so we ensure that every PID got a chance to
    # run.
    for _ in range(5):
        if len(glob.glob(f"{output_dir}/*.output")) >= total_processes:
            break
        time.sleep(1)
|
|
|
|
|
|
def db_query(output_dir: str, total_processes: int, item: int) -> None:  # nocoverage
    """Connect to the database and append the Realm count to ``<pid>.output``.

    ``item`` is accepted so this can be used as a per-item worker, but it
    is not otherwise used.  After writing, waits (at most five one-second
    sleeps) until ``output_dir`` holds at least ``total_processes``
    ``*.output`` files.
    """
    connection.connect()
    with open(f"{output_dir}/{os.getpid()}.output", "a") as fh:
        print(Realm.objects.count(), file=fh)

    for _ in range(5):
        if len(glob.glob(f"{output_dir}/*.output")) >= total_processes:
            break
        time.sleep(1)
|
|
|
|
|
|
class RunParallelTest(ZulipTestCase):
    """Tests for run_parallel using more than one worker process.

    Each test calls skip_in_parallel_harness() first, and has the workers
    write per-PID files into a temporary directory which the test then
    inspects and finally removes.
    """

    def skip_in_parallel_harness(self) -> None:
        """Skip the calling test when the current process is a daemon process."""
        if current_process().daemon:
            self.skipTest("Testing of parallel pool is skipped under the parallel test harness")

    def test_parallel(self) -> None:  # nocoverage
        """Ten items over four processes: every item is written exactly once."""
        self.skip_in_parallel_harness()

        output_dir = tempfile.mkdtemp()
        report_lines = []
        try:
            run_parallel(
                partial(write_number, output_dir, 4, set()),
                range(100, 110),
                processes=4,
                report_every=3,
                report=lambda n: report_lines.append(f"Completed {n}"),
            )

            # One output file per worker PID.
            files = glob.glob(f"{output_dir}/*.output")
            self.assert_length(files, 4)
            all_lines: Counter[str] = Counter()
            for output_path in files:
                with open(output_path) as output_file:
                    file_lines = output_file.readlines()
                # Every process wrote at least one of the 10 items, so no
                # single process can have written more than 10 - (4 - 1).
                self.assertGreater(len(file_lines), 0)
                self.assertLessEqual(len(file_lines), 10 - (4 - 1))
                # Within one process, items were written in ascending order.
                self.assertEqual(sorted(file_lines), file_lines)
                all_lines.update(file_lines)

            # Across all files: 10 lines in total, one for each item.
            self.assertEqual(all_lines.total(), 10)
            self.assertEqual(sorted(all_lines.keys()), [f"{n}\n" for n in range(100, 110)])

            self.assertEqual(report_lines, ["Completed 3", "Completed 6", "Completed 9"])
        finally:
            shutil.rmtree(output_dir)

    def test_parallel_throw(self) -> None:  # nocoverage
        """An exception raised in a worker propagates out of run_parallel."""
        self.skip_in_parallel_harness()
        output_dir = tempfile.mkdtemp()
        report_lines = []
        try:
            with self.assertRaisesMessage(Exception, "Whoops"):
                # NOTE(review): write_number is told to wait for 4 output
                # files, but only 2 processes are started here, so each
                # successful call presumably sleeps for its full 5-second
                # cap -- confirm whether this is intentional.
                run_parallel(
                    partial(write_number, output_dir, 4, {103}),
                    range(100, 105),
                    processes=2,
                    report_every=5,
                    report=lambda n: report_lines.append(f"Completed {n}"),
                )
            output_files = glob.glob(f"{output_dir}/*.output")
            self.assert_length(output_files, 2)
            all_lines: set[int] = set()
            for output_path in output_files:
                with open(output_path) as output_file:
                    all_lines.update(int(line) for line in output_file)
            self.assertIn(100, all_lines)
            self.assertIn(101, all_lines)
            self.assertNotIn(103, all_lines)
            self.assertEqual(report_lines, [])
        finally:
            shutil.rmtree(output_dir)

    def test_parallel_catch(self) -> None:  # nocoverage
        """With catch=True a failing item is logged and the rest still run."""
        self.skip_in_parallel_harness()
        output_dir = tempfile.mkdtemp()
        report_lines = []

        def set_file_logger(output_dir: str) -> None:
            # In each worker process, we set up the logger to write to
            # a (pid).error file.
            logging.basicConfig(
                filename=f"{output_dir}/{os.getpid()}.error",
                level=logging.INFO,
                filemode="w",
                force=True,
            )

        try:
            # NOTE(review): as in test_parallel_throw, write_number is told
            # to wait for 4 output files while only 2 processes run --
            # confirm whether this is intentional.
            run_parallel(
                partial(write_number, output_dir, 4, {103}),
                range(100, 105),
                processes=2,
                report_every=5,
                report=lambda n: report_lines.append(f"Completed {n}"),
                catch=True,
                initializer=set_file_logger,
                initargs=(output_dir,),
            )
            output_files = glob.glob(f"{output_dir}/*.output")
            self.assert_length(output_files, 2)
            all_lines: set[int] = set()
            for output_path in output_files:
                with open(output_path) as output_file:
                    all_lines.update(int(line) for line in output_file)
            # Everything except the failing item (103) was written.
            self.assertEqual(sorted(all_lines), [100, 101, 102, 104])
            self.assertEqual(report_lines, ["Completed 5"])

            # Both workers have an error file; the lines from all of them
            # are pooled, and the first pooled line must be the error
            # logged for item 103.
            error_files = glob.glob(f"{output_dir}/*.error")
            error_lines = []
            self.assert_length(error_files, 2)
            for error_path in error_files:
                with open(error_path) as error_file:
                    error_lines.extend(error_file.readlines())
            self.assertEqual(error_lines[0], "ERROR:root:Error processing item: 103\n")
        finally:
            shutil.rmtree(output_dir)

    def test_parallel_reconnect(self) -> None:  # nocoverage
        """Worker processes can connect to the database and run a query."""
        self.skip_in_parallel_harness()
        output_dir = tempfile.mkdtemp()
        # Remove the temporary directory even when an assertion fails, as
        # the other tests in this class do; mkdtemp() leaves deletion to
        # the caller.
        try:
            run_parallel(
                partial(db_query, output_dir, 2),
                range(100, 105),
                processes=2,
            )
            output_files = glob.glob(f"{output_dir}/*.output")
            self.assert_length(output_files, 2)
            all_lines: set[int] = set()
            for output_path in output_files:
                with open(output_path) as output_file:
                    all_lines.update(int(line) for line in output_file)
            # Every line written by db_query is the Realm count.
            self.assertEqual(all_lines, {4})
        finally:
            shutil.rmtree(output_dir)
|