code formatting

This commit is contained in:
sadnub
2022-03-29 18:04:00 -04:00
parent 7897b0ebe9
commit d8ad6c0cb0
26 changed files with 418 additions and 238 deletions

View File

@@ -321,22 +321,32 @@ class Agent(BaseAuditModel):
def is_supported_script(self, platforms: list) -> bool: def is_supported_script(self, platforms: list) -> bool:
return self.plat.lower() in platforms if platforms else True return self.plat.lower() in platforms if platforms else True
def get_checks_with_policies(self, exclude_overridden: bool = False) -> 'List[Check]': def get_checks_with_policies(
self, exclude_overridden: bool = False
) -> "List[Check]":
if exclude_overridden: if exclude_overridden:
checks = list(self.agentchecks.filter(overridden_by_policy=False)) + self.get_checks_from_policies() # type: ignore checks = list(self.agentchecks.filter(overridden_by_policy=False)) + self.get_checks_from_policies() # type: ignore
else: else:
checks = list(self.agentchecks.all()) + self.get_checks_from_policies() # type: ignore checks = list(self.agentchecks.all()) + self.get_checks_from_policies() # type: ignore
return self.add_check_results(checks) return self.add_check_results(checks)
def get_tasks_with_policies(self, exclude_synced: bool = False) -> 'List[AutomatedTask]': def get_tasks_with_policies(
self, exclude_synced: bool = False
) -> "List[AutomatedTask]":
tasks = list(self.autotasks.all()) + self.get_tasks_from_policies() # type: ignore tasks = list(self.autotasks.all()) + self.get_tasks_from_policies() # type: ignore
if exclude_synced: if exclude_synced:
return [task for task in self.add_task_results(tasks) if not task.task_result or task.task_result and task.task_result.sync_status != "synced"] return [
task
for task in self.add_task_results(tasks)
if not task.task_result
or task.task_result
and task.task_result.sync_status != "synced"
]
else: else:
return self.add_task_results(tasks) return self.add_task_results(tasks)
def get_agent_policies(self) -> 'Dict[str, Policy]': def get_agent_policies(self) -> "Dict[str, Policy]":
site_policy = getattr(self.site, f"{self.monitoring_type}_policy", None) site_policy = getattr(self.site, f"{self.monitoring_type}_policy", None)
client_policy = getattr(self.client, f"{self.monitoring_type}_policy", None) client_policy = getattr(self.client, f"{self.monitoring_type}_policy", None)
default_policy = getattr( default_policy = getattr(
@@ -589,11 +599,16 @@ class Agent(BaseAuditModel):
return None return None
def get_or_create_alert_if_needed(self, alert_template: "Optional[AlertTemplate]") -> "Optional[Alert]": def get_or_create_alert_if_needed(
self, alert_template: "Optional[AlertTemplate]"
) -> "Optional[Alert]":
from alerts.models import Alert from alerts.models import Alert
return Alert.create_or_return_availability_alert(self, skip_create=self.should_create_alert(alert_template))
def add_task_results(self, tasks: 'List[AutomatedTask]') -> 'List[AutomatedTask]': return Alert.create_or_return_availability_alert(
self, skip_create=self.should_create_alert(alert_template)
)
def add_task_results(self, tasks: "List[AutomatedTask]") -> "List[AutomatedTask]":
results = self.taskresults.all() # type: ignore results = self.taskresults.all() # type: ignore
@@ -608,7 +623,7 @@ class Agent(BaseAuditModel):
return tasks return tasks
def add_check_results(self, checks: 'List[Check]') -> 'List[Check]': def add_check_results(self, checks: "List[Check]") -> "List[Check]":
results = self.checkresults.all() # type: ignore results = self.checkresults.all() # type: ignore

View File

@@ -179,6 +179,7 @@ class AgentHistorySerializer(serializers.ModelSerializer):
model = AgentHistory model = AgentHistory
fields = "__all__" fields = "__all__"
class AgentAuditSerializer(serializers.ModelSerializer): class AgentAuditSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = Agent model = Agent

View File

@@ -113,7 +113,9 @@ class Alert(models.Model):
self.save() self.save()
@classmethod @classmethod
def create_or_return_availability_alert(cls, agent: Agent, skip_create: bool = False) -> Optional[Alert]: def create_or_return_availability_alert(
cls, agent: Agent, skip_create: bool = False
) -> Optional[Alert]:
if not cls.objects.filter(agent=agent, resolved=False).exists(): if not cls.objects.filter(agent=agent, resolved=False).exists():
if skip_create: if skip_create:
return None return None
@@ -141,12 +143,15 @@ class Alert(models.Model):
except cls.DoesNotExist: except cls.DoesNotExist:
return None return None
@classmethod @classmethod
def create_or_return_check_alert(cls, check: Check, agent: Optional[Agent] = None, skip_create: bool = False) -> Optional[Alert]: def create_or_return_check_alert(
cls, check: Check, agent: Optional[Agent] = None, skip_create: bool = False
) -> Optional[Alert]:
# need to pass agent if the check is a policy # need to pass agent if the check is a policy
if not cls.objects.filter(assigned_check=check, agent=agent if check.policy else None, resolved=False).exists(): if not cls.objects.filter(
assigned_check=check, agent=agent if check.policy else None, resolved=False
).exists():
if skip_create: if skip_create:
return None return None
@@ -160,9 +165,17 @@ class Alert(models.Model):
) )
else: else:
try: try:
return cls.objects.get(assigned_check=check, agent=agent if check.policy else None, resolved=False) return cls.objects.get(
assigned_check=check,
agent=agent if check.policy else None,
resolved=False,
)
except cls.MultipleObjectsReturned: except cls.MultipleObjectsReturned:
alerts = cls.objects.filter(assigned_check=check, agent=agent if check.policy else None, resolved=False) alerts = cls.objects.filter(
assigned_check=check,
agent=agent if check.policy else None,
resolved=False,
)
last_alert = alerts[-1] last_alert = alerts[-1]
# cycle through other alerts and resolve # cycle through other alerts and resolve
@@ -175,9 +188,16 @@ class Alert(models.Model):
return None return None
@classmethod @classmethod
def create_or_return_task_alert(cls, task: AutomatedTask, agent: Optional[Agent] = None, skip_create: bool = False) -> Optional[Alert]: def create_or_return_task_alert(
cls,
task: AutomatedTask,
agent: Optional[Agent] = None,
skip_create: bool = False,
) -> Optional[Alert]:
if not cls.objects.filter(assigned_task=task, agent=agent if task.policy else None, resolved=False).exists(): if not cls.objects.filter(
assigned_task=task, agent=agent if task.policy else None, resolved=False
).exists():
if skip_create: if skip_create:
return None return None
@@ -191,9 +211,17 @@ class Alert(models.Model):
) )
else: else:
try: try:
return cls.objects.get(assigned_task=task, agent=agent if task.policy else None, resolved=False) return cls.objects.get(
assigned_task=task,
agent=agent if task.policy else None,
resolved=False,
)
except cls.MultipleObjectsReturned: except cls.MultipleObjectsReturned:
alerts = cls.objects.filter(assigned_task=task, agent=agent if task.policy else None, resolved=False) alerts = cls.objects.filter(
assigned_task=task,
agent=agent if task.policy else None,
resolved=False,
)
last_alert = alerts[-1] last_alert = alerts[-1]
# cycle through other alerts and resolve # cycle through other alerts and resolve
@@ -206,7 +234,9 @@ class Alert(models.Model):
return None return None
@classmethod @classmethod
def handle_alert_failure(cls, instance: Union[Agent, TaskResult, CheckResult]) -> None: def handle_alert_failure(
cls, instance: Union[Agent, TaskResult, CheckResult]
) -> None:
from agents.models import Agent from agents.models import Agent
from autotasks.models import TaskResult from autotasks.models import TaskResult
from checks.models import CheckResult from checks.models import CheckResult
@@ -262,7 +292,11 @@ class Alert(models.Model):
dashboard_alert = instance.assigned_check.dashboard_alert dashboard_alert = instance.assigned_check.dashboard_alert
alert_template = instance.agent.alert_template alert_template = instance.agent.alert_template
maintenance_mode = instance.agent.maintenance_mode maintenance_mode = instance.agent.maintenance_mode
alert_severity = instance.alert_severity if instance.assigned_check.check_type not in ["memcheck", "cpuload"] else instance.alert_severity alert_severity = (
instance.alert_severity
if instance.assigned_check.check_type not in ["memcheck", "cpuload"]
else instance.alert_severity
)
agent = instance.agent agent = instance.agent
# set alert_template settings # set alert_template settings
@@ -373,7 +407,9 @@ class Alert(models.Model):
) )
@classmethod @classmethod
def handle_alert_resolve(cls, instance: Union[Agent, TaskResult, CheckResult]) -> None: def handle_alert_resolve(
cls, instance: Union[Agent, TaskResult, CheckResult]
) -> None:
from agents.models import Agent from agents.models import Agent
from autotasks.models import TaskResult from autotasks.models import TaskResult
from checks.models import CheckResult from checks.models import CheckResult

View File

@@ -34,7 +34,11 @@ class AlertTemplateSerializer(ModelSerializer):
fields = "__all__" fields = "__all__"
def get_applied_count(self, instance): def get_applied_count(self, instance):
return instance.policies.count() + instance.clients.count() + instance.sites.count() return (
instance.policies.count()
+ instance.clients.count()
+ instance.sites.count()
)
class AlertTemplateRelationSerializer(ModelSerializer): class AlertTemplateRelationSerializer(ModelSerializer):

View File

@@ -18,7 +18,9 @@ def unsnooze_alerts() -> str:
def cache_agents_alert_template(): def cache_agents_alert_template():
from agents.models import Agent from agents.models import Agent
for agent in Agent.objects.only("pk", "site", "policy", "alert_template").select_related("site", "policy", "alert_template"): for agent in Agent.objects.only(
"pk", "site", "policy", "alert_template"
).select_related("site", "policy", "alert_template"):
agent.set_alert_template() agent.set_alert_template()
return "ok" return "ok"

View File

@@ -21,7 +21,9 @@ class TestAPIv3(TacticalTestCase):
# add a check # add a check
check1 = baker.make_recipe("checks.ping_check", agent=agent) check1 = baker.make_recipe("checks.ping_check", agent=agent)
check_result1 = baker.make("checks.CheckResult", agent=agent, assigned_check=check1) check_result1 = baker.make(
"checks.CheckResult", agent=agent, assigned_check=check1
)
r = self.client.get(url) r = self.client.get(url)
self.assertEqual(r.status_code, 200) self.assertEqual(r.status_code, 200)
self.assertEqual(r.data["check_interval"], self.agent.check_interval) # type: ignore self.assertEqual(r.data["check_interval"], self.agent.check_interval) # type: ignore
@@ -31,7 +33,9 @@ class TestAPIv3(TacticalTestCase):
check2 = baker.make_recipe( check2 = baker.make_recipe(
"checks.diskspace_check", agent=agent, run_interval=20 "checks.diskspace_check", agent=agent, run_interval=20
) )
check_result2 = baker.make("checks.CheckResult", agent=agent, assigned_check=check2) check_result2 = baker.make(
"checks.CheckResult", agent=agent, assigned_check=check2
)
r = self.client.get(url) r = self.client.get(url)
self.assertEqual(r.status_code, 200) self.assertEqual(r.status_code, 200)
@@ -143,7 +147,12 @@ class TestAPIv3(TacticalTestCase):
# setup data # setup data
task_actions = [ task_actions = [
{"type": "cmd", "command": "whoami", "timeout": 10, "shell": "cmd"}, {"type": "cmd", "command": "whoami", "timeout": 10, "shell": "cmd"},
{"type": "script", "script": script.id, "script_args": ["test"], "timeout": 30}, {
"type": "script",
"script": script.id,
"script_args": ["test"],
"timeout": 30,
},
{"type": "script", "script": 3, "script_args": [], "timeout": 30}, {"type": "script", "script": 3, "script_args": [], "timeout": 30},
] ]

View File

@@ -201,7 +201,12 @@ class CheckRunner(APIView):
# see if the correct amount of seconds have passed # see if the correct amount of seconds have passed
or ( or (
check.check_result.last_run # type: ignore check.check_result.last_run # type: ignore
< djangotime.now() - djangotime.timedelta(seconds=check.run_interval if check.run_interval else agent.check_interval) < djangotime.now()
- djangotime.timedelta(
seconds=check.run_interval
if check.run_interval
else agent.check_interval
)
) )
] ]
@@ -263,7 +268,9 @@ class TaskRunner(APIView):
# check check result or create if doesn't exist # check check result or create if doesn't exist
try: try:
task_result = TaskResult.objects.get(task=task, agent=agent) task_result = TaskResult.objects.get(task=task, agent=agent)
serializer = TaskResultSerializer(data=request.data, instance=task_result, partial=True) serializer = TaskResultSerializer(
data=request.data, instance=task_result, partial=True
)
except TaskResult.DoesNotExist: except TaskResult.DoesNotExist:
serializer = TaskResultSerializer(data=request.data, partial=True) serializer = TaskResultSerializer(data=request.data, partial=True)

View File

@@ -53,7 +53,9 @@ class Policy(BaseAuditModel):
def is_agent_excluded(self, agent): def is_agent_excluded(self, agent):
# will prefetch the many to many relations in a single query versus 3. esults are cached on the object # will prefetch the many to many relations in a single query versus 3. esults are cached on the object
models.prefetch_related_objects([self], "excluded_agents", "excluded_sites", "excluded_clients") models.prefetch_related_objects(
[self], "excluded_agents", "excluded_sites", "excluded_clients"
)
return ( return (
agent in self.excluded_agents.all() agent in self.excluded_agents.all()
@@ -61,8 +63,20 @@ class Policy(BaseAuditModel):
or agent.client in self.excluded_clients.all() or agent.client in self.excluded_clients.all()
) )
def related_agents(self, mon_type: Optional[str] = None) -> 'models.QuerySet[Agent]': def related_agents(
models.prefetch_related_objects([self], "excluded_agents", "excluded_sites", "excluded_clients", "workstation_clients", "server_clients", "workstation_sites", "server_sites", "agents") self, mon_type: Optional[str] = None
) -> "models.QuerySet[Agent]":
models.prefetch_related_objects(
[self],
"excluded_agents",
"excluded_sites",
"excluded_clients",
"workstation_clients",
"server_clients",
"workstation_sites",
"server_sites",
"agents",
)
agent_filter = {} agent_filter = {}
filtered_agents_ids = Agent.objects.none() filtered_agents_ids = Agent.objects.none()
@@ -70,9 +84,13 @@ class Policy(BaseAuditModel):
if mon_type: if mon_type:
agent_filter["monitoring_type"] = mon_type agent_filter["monitoring_type"] = mon_type
excluded_clients_ids = self.excluded_clients.only("pk").values_list("id", flat=True) excluded_clients_ids = self.excluded_clients.only("pk").values_list(
"id", flat=True
)
excluded_sites_ids = self.excluded_sites.only("pk").values_list("id", flat=True) excluded_sites_ids = self.excluded_sites.only("pk").values_list("id", flat=True)
excluded_agents_ids = self.excluded_agents.only("pk").values_list("id", flat=True) excluded_agents_ids = self.excluded_agents.only("pk").values_list(
"id", flat=True
)
if self.is_default_server_policy: if self.is_default_server_policy:
filtered_agents_ids |= ( filtered_agents_ids |= (
@@ -106,9 +124,7 @@ class Policy(BaseAuditModel):
explicit_agents = ( explicit_agents = (
self.agents.filter(**agent_filter) # type: ignore self.agents.filter(**agent_filter) # type: ignore
.exclude( .exclude(id__in=excluded_agents_ids)
id__in=excluded_agents_ids
)
.exclude(site_id__in=excluded_sites_ids) .exclude(site_id__in=excluded_sites_ids)
.exclude(site__client_id__in=excluded_clients_ids) .exclude(site__client_id__in=excluded_clients_ids)
) )

View File

@@ -1,8 +1,10 @@
from tacticalrmm.celery import app from tacticalrmm.celery import app
@app.task @app.task
def run_win_policy_autotasks_task(task: int) -> str: def run_win_policy_autotasks_task(task: int) -> str:
from autotasks.models import AutomatedTask from autotasks.models import AutomatedTask
try: try:
policy_task = AutomatedTask.objects.get(pk=task) policy_task = AutomatedTask.objects.get(pk=task)
except AutomatedTask.DoesNotExist: except AutomatedTask.DoesNotExist:

View File

@@ -118,7 +118,7 @@ class TestPolicyViews(TacticalTestCase):
data = { data = {
"name": "Test Policy Update", "name": "Test Policy Update",
"desc": "policy desc Update", "desc": "policy desc Update",
"alert_template": alert_template.pk "alert_template": alert_template.pk,
} }
resp = self.client.put(url, data, format="json") resp = self.client.put(url, data, format="json")
@@ -148,7 +148,9 @@ class TestPolicyViews(TacticalTestCase):
policy = baker.make("automation.Policy") policy = baker.make("automation.Policy")
agent = baker.make_recipe("agents.agent", policy=policy) agent = baker.make_recipe("agents.agent", policy=policy)
policy_diskcheck = baker.make_recipe("checks.diskspace_check", policy=policy) policy_diskcheck = baker.make_recipe("checks.diskspace_check", policy=policy)
result = baker.make("checks.CheckResult", agent=agent, assigned_check=policy_diskcheck) result = baker.make(
"checks.CheckResult", agent=agent, assigned_check=policy_diskcheck
)
url = f"/automation/checks/{policy_diskcheck.pk}/status/" url = f"/automation/checks/{policy_diskcheck.pk}/status/"
@@ -229,10 +231,7 @@ class TestPolicyViews(TacticalTestCase):
policy = baker.make("automation.Policy") policy = baker.make("automation.Policy")
# create managed policy tasks # create managed policy tasks
task = baker.make_recipe( task = baker.make_recipe("autotasks.task", policy=policy)
"autotasks.task",
policy=policy
)
url = f"/automation/tasks/{task.id}/run/" url = f"/automation/tasks/{task.id}/run/"
resp = self.client.post(url, format="json") resp = self.client.post(url, format="json")
@@ -490,27 +489,35 @@ class TestPolicyTasks(TacticalTestCase):
def test_update_policy_tasks(self): def test_update_policy_tasks(self):
from autotasks.models import TaskResult from autotasks.models import TaskResult
# setup data # setup data
policy = baker.make("automation.Policy", active=True) policy = baker.make("automation.Policy", active=True)
task = baker.make_recipe( task = baker.make_recipe("autotasks.task", enabled=True, policy=policy)
"autotasks.task",
enabled=True,
policy=policy
)
agent = baker.make_recipe("agents.server_agent", policy=policy) agent = baker.make_recipe("agents.server_agent", policy=policy)
task_result = baker.make("autotasks.TaskResult", task=task, agent=agent, sync_status="synced") task_result = baker.make(
"autotasks.TaskResult", task=task, agent=agent, sync_status="synced"
)
# this change shouldn't trigger the task_result field to sync_status = "notsynced" # this change shouldn't trigger the task_result field to sync_status = "notsynced"
task.actions = {"type": "cmd", "command": "whoami", "timeout": 90, "shell": "cmd"} task.actions = {
"type": "cmd",
"command": "whoami",
"timeout": 90,
"shell": "cmd",
}
task.save() task.save()
self.assertEqual(TaskResult.objects.get(pk=task_result.id).sync_status, "synced") self.assertEqual(
TaskResult.objects.get(pk=task_result.id).sync_status, "synced"
)
# task result should now be "notsynced" # task result should now be "notsynced"
task.enabled = False task.enabled = False
task.save() task.save()
self.assertEqual(TaskResult.objects.get(pk=task_result.id).sync_status, "notsynced") self.assertEqual(
TaskResult.objects.get(pk=task_result.id).sync_status, "notsynced"
)
def test_policy_exclusions(self): def test_policy_exclusions(self):

View File

@@ -80,6 +80,7 @@ class GetUpdateDeletePolicy(APIView):
return Response("ok") return Response("ok")
class PolicyAutoTask(APIView): class PolicyAutoTask(APIView):
# get status of all tasks # get status of all tasks

View File

@@ -149,9 +149,13 @@ class AutomatedTask(BaseAuditModel):
for field in self.fields_that_trigger_task_update_on_agent: for field in self.fields_that_trigger_task_update_on_agent:
if getattr(self, field) != getattr(old_task, field): if getattr(self, field) != getattr(old_task, field):
if self.policy: if self.policy:
TaskResult.objects.exclude(sync_status="inital").filter(task__policy_id=self.policy.id).update(sync_status="notsynced") TaskResult.objects.exclude(sync_status="inital").filter(
task__policy_id=self.policy.id
).update(sync_status="notsynced")
else: else:
TaskResult.objects.filter(agent=self.agent, task=self).update(sync_status="notsynced") TaskResult.objects.filter(agent=self.agent, task=self).update(
sync_status="notsynced"
)
@property @property
def schedule(self): def schedule(self):
@@ -220,8 +224,9 @@ class AutomatedTask(BaseAuditModel):
return TaskAuditSerializer(task).data return TaskAuditSerializer(task).data
def create_policy_task(
def create_policy_task(self, policy: 'Policy', assigned_check: 'Optional[Check]' = None) -> None: self, policy: "Policy", assigned_check: "Optional[Check]" = None
) -> None:
### Copies certain properties on this task (self) to a new task and sets it to the supplied Policy ### Copies certain properties on this task (self) to a new task and sets it to the supplied Policy
fields_to_copy = [ fields_to_copy = [
"alert_severity", "alert_severity",
@@ -337,7 +342,7 @@ class AutomatedTask(BaseAuditModel):
return task return task
def create_task_on_agent(self, agent: 'Optional[Agent]' = None) -> str: def create_task_on_agent(self, agent: "Optional[Agent]" = None) -> str:
if self.policy and not agent: if self.policy and not agent:
return "agent parameter needs to be passed with policy task" return "agent parameter needs to be passed with policy task"
else: else:
@@ -376,7 +381,7 @@ class AutomatedTask(BaseAuditModel):
return "ok" return "ok"
def modify_task_on_agent(self, agent: 'Optional[Agent]' = None) -> str: def modify_task_on_agent(self, agent: "Optional[Agent]" = None) -> str:
if self.policy and not agent: if self.policy and not agent:
return "agent parameter needs to be passed with policy task" return "agent parameter needs to be passed with policy task"
else: else:
@@ -415,7 +420,7 @@ class AutomatedTask(BaseAuditModel):
return "ok" return "ok"
def delete_task_on_agent(self, agent: 'Optional[Agent]' = None) -> str: def delete_task_on_agent(self, agent: "Optional[Agent]" = None) -> str:
if self.policy and not agent: if self.policy and not agent:
return "agent parameter needs to be passed with policy task" return "agent parameter needs to be passed with policy task"
else: else:
@@ -457,7 +462,7 @@ class AutomatedTask(BaseAuditModel):
return "ok" return "ok"
def run_win_task(self, agent: 'Optional[Agent]' = None) -> str: def run_win_task(self, agent: "Optional[Agent]" = None) -> str:
if self.policy and not agent: if self.policy and not agent:
return "agent parameter needs to be passed with policy task" return "agent parameter needs to be passed with policy task"
else: else:
@@ -469,7 +474,11 @@ class AutomatedTask(BaseAuditModel):
task_result = TaskResult(agent=agent, task=self) task_result = TaskResult(agent=agent, task=self)
task_result.save() task_result.save()
asyncio.run(task_result.agent.nats_cmd({"func": "runtask", "taskpk": self.pk}, wait=False)) asyncio.run(
task_result.agent.nats_cmd(
{"func": "runtask", "taskpk": self.pk}, wait=False
)
)
return "ok" return "ok"
def should_create_alert(self, alert_template=None): def should_create_alert(self, alert_template=None):
@@ -490,7 +499,7 @@ class AutomatedTask(BaseAuditModel):
class TaskResult(models.Model): class TaskResult(models.Model):
class Meta: class Meta:
unique_together = (('agent', 'task'),) unique_together = (("agent", "task"),)
objects = PermissionQuerySet.as_manager() objects = PermissionQuerySet.as_manager()
@@ -506,7 +515,7 @@ class TaskResult(models.Model):
related_name="taskresults", related_name="taskresults",
null=True, null=True,
blank=True, blank=True,
on_delete=models.CASCADE on_delete=models.CASCADE,
) )
retvalue = models.TextField(null=True, blank=True) retvalue = models.TextField(null=True, blank=True)
@@ -525,9 +534,16 @@ class TaskResult(models.Model):
def __str__(self): def __str__(self):
return f"{self.agent.hostname} - {self.task}" return f"{self.agent.hostname} - {self.task}"
def get_or_create_alert_if_needed(self, alert_template: "Optional[AlertTemplate]") -> "Optional[Alert]": def get_or_create_alert_if_needed(
self, alert_template: "Optional[AlertTemplate]"
) -> "Optional[Alert]":
from alerts.models import Alert from alerts.models import Alert
return Alert.create_or_return_task_alert(self.task, agent=self.agent, skip_create=self.task.should_create_alert(alert_template))
return Alert.create_or_return_task_alert(
self.task,
agent=self.agent,
skip_create=self.task.should_create_alert(alert_template),
)
def save_collector_results(self) -> None: def save_collector_results(self) -> None:

View File

@@ -5,7 +5,6 @@ from django.core.exceptions import ObjectDoesNotExist
from .models import AutomatedTask, TaskResult from .models import AutomatedTask, TaskResult
class TaskResultSerializer(serializers.ModelSerializer): class TaskResultSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = TaskResult model = TaskResult
@@ -18,14 +17,16 @@ class TaskSerializer(serializers.ModelSerializer):
schedule = serializers.ReadOnlyField() schedule = serializers.ReadOnlyField()
alert_template = serializers.SerializerMethodField() alert_template = serializers.SerializerMethodField()
run_time_date = serializers.DateTimeField(required=False) run_time_date = serializers.DateTimeField(required=False)
expire_date = serializers.DateTimeField( expire_date = serializers.DateTimeField(allow_null=True, required=False)
allow_null=True, required=False
)
task_result = serializers.SerializerMethodField() task_result = serializers.SerializerMethodField()
# use select_related("taskresults") on the query set to make this go faster # use select_related("taskresults") on the query set to make this go faster
def get_task_result(self, obj): def get_task_result(self, obj):
return TaskResultSerializer(obj.task_result).data if hasattr(obj, "task_result") else {} return (
TaskResultSerializer(obj.task_result).data
if hasattr(obj, "task_result")
else {}
)
def validate_actions(self, value): def validate_actions(self, value):

View File

@@ -11,6 +11,7 @@ from logs.models import DebugLog
from tacticalrmm.celery import app from tacticalrmm.celery import app
@app.task @app.task
def create_win_task_schedule(pk: int, agent_id: Optional[str] = None) -> str: def create_win_task_schedule(pk: int, agent_id: Optional[str] = None) -> str:
task = AutomatedTask.objects.get(pk=pk) task = AutomatedTask.objects.get(pk=pk)

View File

@@ -44,9 +44,7 @@ class TestAutotaskViews(TacticalTestCase):
self.assertEqual(len(resp.data), 4) # type: ignore self.assertEqual(len(resp.data), 4) # type: ignore
@patch("autotasks.tasks.create_win_task_schedule.delay") @patch("autotasks.tasks.create_win_task_schedule.delay")
def test_add_autotask( def test_add_autotask(self, create_win_task_schedule):
self, create_win_task_schedule
):
url = f"{base_url}/" url = f"{base_url}/"
# setup data # setup data
@@ -254,9 +252,7 @@ class TestAutotaskViews(TacticalTestCase):
self.check_not_authenticated("get", url) self.check_not_authenticated("get", url)
def test_update_autotask( def test_update_autotask(self):
self
):
# setup data # setup data
agent = baker.make_recipe("agents.agent") agent = baker.make_recipe("agents.agent")
agent_task = baker.make("autotasks.AutomatedTask", agent=agent) agent_task = baker.make("autotasks.AutomatedTask", agent=agent)
@@ -330,9 +326,7 @@ class TestAutotaskViews(TacticalTestCase):
self.check_not_authenticated("put", url) self.check_not_authenticated("put", url)
@patch("autotasks.tasks.delete_win_task_schedule.delay") @patch("autotasks.tasks.delete_win_task_schedule.delay")
def test_delete_autotask( def test_delete_autotask(self, delete_win_task_schedule):
self, delete_win_task_schedule
):
# setup data # setup data
agent = baker.make_recipe("agents.agent") agent = baker.make_recipe("agents.agent")
agent_task = baker.make("autotasks.AutomatedTask", agent=agent) agent_task = baker.make("autotasks.AutomatedTask", agent=agent)

View File

@@ -212,7 +212,7 @@ class Check(BaseAuditModel):
"modified_time", "modified_time",
] ]
def create_policy_check(self, policy: 'Policy') -> None: def create_policy_check(self, policy: "Policy") -> None:
fields_to_copy = [ fields_to_copy = [
"warning_threshold", "warning_threshold",
@@ -253,9 +253,7 @@ class Check(BaseAuditModel):
) )
for task in self.assignedtasks.all(): # type: ignore for task in self.assignedtasks.all(): # type: ignore
task.create_policy_task( task.create_policy_task(policy=policy, assigned_check=check)
policy=policy, assigned_check=check
)
for field in fields_to_copy: for field in fields_to_copy:
setattr(check, field, getattr(self, field)) setattr(check, field, getattr(self, field))
@@ -278,8 +276,12 @@ class Check(BaseAuditModel):
) )
) )
def add_check_history(self, value: int, agent_id: str, more_info: Any = None) -> None: def add_check_history(
CheckHistory.objects.create(check_id=self.pk, y=value, results=more_info, agent_id=agent_id) self, value: int, agent_id: str, more_info: Any = None
) -> None:
CheckHistory.objects.create(
check_id=self.pk, y=value, results=more_info, agent_id=agent_id
)
def handle_assigned_task(self) -> None: def handle_assigned_task(self) -> None:
for task in self.assignedtasks.all(): # type: ignore for task in self.assignedtasks.all(): # type: ignore
@@ -320,7 +322,7 @@ class CheckResult(models.Model):
objects = PermissionQuerySet.as_manager() objects = PermissionQuerySet.as_manager()
class Meta: class Meta:
unique_together = (('agent', 'assigned_check'),) unique_together = (("agent", "assigned_check"),)
agent = models.ForeignKey( agent = models.ForeignKey(
"agents.Agent", "agents.Agent",
@@ -335,7 +337,7 @@ class CheckResult(models.Model):
related_name="checkresults", related_name="checkresults",
null=True, null=True,
blank=True, blank=True,
on_delete=models.CASCADE on_delete=models.CASCADE,
) )
status = models.CharField( status = models.CharField(
max_length=100, choices=CHECK_STATUS_CHOICES, default="pending" max_length=100, choices=CHECK_STATUS_CHOICES, default="pending"
@@ -367,12 +369,22 @@ class CheckResult(models.Model):
@property @property
def history_info(self): def history_info(self):
if self.assigned_check.check_type == "cpuload" or self.assigned_check.check_type == "memory": if (
self.assigned_check.check_type == "cpuload"
or self.assigned_check.check_type == "memory"
):
return ", ".join(str(f"{x}%") for x in self.history[-6:]) return ", ".join(str(f"{x}%") for x in self.history[-6:])
def get_or_create_alert_if_needed(self, alert_template: "Optional[AlertTemplate]") -> "Optional[Alert]": def get_or_create_alert_if_needed(
self, alert_template: "Optional[AlertTemplate]"
) -> "Optional[Alert]":
from alerts.models import Alert from alerts.models import Alert
return Alert.create_or_return_check_alert(self.assigned_check, agent=self.agent, skip_create=self.assigned_check.should_create_alert(alert_template))
return Alert.create_or_return_check_alert(
self.assigned_check,
agent=self.agent,
skip_create=self.assigned_check.should_create_alert(alert_template),
)
def handle_check(self, data): def handle_check(self, data):
from alerts.models import Alert from alerts.models import Alert
@@ -407,7 +419,10 @@ class CheckResult(models.Model):
elif check.check_type == "diskspace": elif check.check_type == "diskspace":
if data["exists"]: if data["exists"]:
percent_used = round(data["percent_used"]) percent_used = round(data["percent_used"])
if check.error_threshold and (100 - percent_used) < check.error_threshold: if (
check.error_threshold
and (100 - percent_used) < check.error_threshold
):
self.status = "failing" self.status = "failing"
self.alert_severity = "error" self.alert_severity = "error"
elif ( elif (
@@ -478,7 +493,9 @@ class CheckResult(models.Model):
self.save(update_fields=["more_info"]) self.save(update_fields=["more_info"])
check.add_check_history( check.add_check_history(
1 if self.status == "failing" else 0, self.agent.agent_id, self.more_info[:60], 1 if self.status == "failing" else 0,
self.agent.agent_id,
self.more_info[:60],
) )
# windows service checks # windows service checks
@@ -488,7 +505,9 @@ class CheckResult(models.Model):
self.save(update_fields=["more_info"]) self.save(update_fields=["more_info"])
check.add_check_history( check.add_check_history(
1 if self.status == "failing" else 0, self.agent.agent_id, self.more_info[:60], 1 if self.status == "failing" else 0,
self.agent.agent_id,
self.more_info[:60],
) )
elif check.check_type == "eventlog": elif check.check_type == "eventlog":
@@ -549,7 +568,9 @@ class CheckResult(models.Model):
try: try:
percent_used = [ percent_used = [
d["percent"] for d in self.agent.disks if d["device"] == self.assigned_check.disk d["percent"]
for d in self.agent.disks
if d["device"] == self.assigned_check.disk
][0] ][0]
percent_free = 100 - percent_used percent_free = 100 - percent_used
@@ -568,7 +589,10 @@ class CheckResult(models.Model):
body = self.more_info body = self.more_info
elif self.assigned_check.check_type == "cpuload" or self.assigned_check.check_type == "memory": elif (
self.assigned_check.check_type == "cpuload"
or self.assigned_check.check_type == "memory"
):
text = "" text = ""
if self.assigned_check.warning_threshold: if self.assigned_check.warning_threshold:
text += f" Warning Threshold: {self.assigned_check.warning_threshold}%" text += f" Warning Threshold: {self.assigned_check.warning_threshold}%"
@@ -593,9 +617,7 @@ class CheckResult(models.Model):
elif self.assigned_check.event_source: elif self.assigned_check.event_source:
start = f"Event ID {self.assigned_check.event_id}, source {self.assigned_check.event_source} " start = f"Event ID {self.assigned_check.event_id}, source {self.assigned_check.event_source} "
elif self.assigned_check.event_message: elif self.assigned_check.event_message:
start = ( start = f"Event ID {self.assigned_check.event_id}, containing string {self.assigned_check.event_message} "
f"Event ID {self.assigned_check.event_id}, containing string {self.assigned_check.event_message} "
)
else: else:
start = f"Event ID {self.assigned_check.event_id} " start = f"Event ID {self.assigned_check.event_id} "
@@ -629,7 +651,9 @@ class CheckResult(models.Model):
try: try:
percent_used = [ percent_used = [
d["percent"] for d in self.agent.disks if d["device"] == self.assigned_check.disk d["percent"]
for d in self.agent.disks
if d["device"] == self.assigned_check.disk
][0] ][0]
percent_free = 100 - percent_used percent_free = 100 - percent_used
body = subject + f" - Free: {percent_free}%, {text}" body = subject + f" - Free: {percent_free}%, {text}"
@@ -640,7 +664,10 @@ class CheckResult(models.Model):
body = subject + f" - Return code: {self.retcode}" body = subject + f" - Return code: {self.retcode}"
elif self.assigned_check.check_type == "ping": elif self.assigned_check.check_type == "ping":
body = subject body = subject
elif self.assigned_check.check_type == "cpuload" or self.assigned_check.check_type == "memory": elif (
self.assigned_check.check_type == "cpuload"
or self.assigned_check.check_type == "memory"
):
text = "" text = ""
if self.assigned_check.warning_threshold: if self.assigned_check.warning_threshold:
text += f" Warning Threshold: {self.assigned_check.warning_threshold}%" text += f" Warning Threshold: {self.assigned_check.warning_threshold}%"
@@ -673,6 +700,7 @@ class CheckResult(models.Model):
subject = f"{self.agent.client.name}, {self.agent.site.name}, {self} Resolved" subject = f"{self.agent.client.name}, {self.agent.site.name}, {self} Resolved"
CORE.send_sms(subject, alert_template=self.agent.alert_template) # type: ignore CORE.send_sms(subject, alert_template=self.agent.alert_template) # type: ignore
class CheckHistory(models.Model): class CheckHistory(models.Model):
objects = PermissionQuerySet.as_manager() objects = PermissionQuerySet.as_manager()

View File

@@ -14,7 +14,6 @@ class AssignedTaskField(serializers.ModelSerializer):
class CheckResultSerializer(serializers.ModelSerializer): class CheckResultSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = CheckResult model = CheckResult
fields = "__all__" fields = "__all__"
@@ -28,7 +27,11 @@ class CheckSerializer(serializers.ModelSerializer):
check_result = serializers.SerializerMethodField() check_result = serializers.SerializerMethodField()
def get_check_result(self, obj): def get_check_result(self, obj):
return CheckResultSerializer(obj.check_result).data if hasattr(obj, "check_result") else {} return (
CheckResultSerializer(obj.check_result).data
if hasattr(obj, "check_result")
else {}
)
def get_alert_template(self, obj): def get_alert_template(self, obj):
if obj.agent: if obj.agent:
@@ -46,7 +49,6 @@ class CheckSerializer(serializers.ModelSerializer):
"always_alert": alert_template.check_always_alert, "always_alert": alert_template.check_always_alert,
} }
class Meta: class Meta:
model = Check model = Check
fields = "__all__" fields = "__all__"
@@ -67,10 +69,7 @@ class CheckSerializer(serializers.ModelSerializer):
# make sure no duplicate diskchecks exist for an agent/policy # make sure no duplicate diskchecks exist for an agent/policy
if check_type == "diskspace": if check_type == "diskspace":
if not self.instance: # only on create if not self.instance: # only on create
checks = ( checks = Check.objects.filter(**filter).filter(check_type="diskspace")
Check.objects.filter(**filter)
.filter(check_type="diskspace")
)
for check in checks: for check in checks:
if val["disk"] in check.disk: if val["disk"] in check.disk:
raise serializers.ValidationError( raise serializers.ValidationError(
@@ -103,10 +102,7 @@ class CheckSerializer(serializers.ModelSerializer):
) )
if check_type == "cpuload" and not self.instance: if check_type == "cpuload" and not self.instance:
if ( if Check.objects.filter(**filter, check_type="cpuload").exists():
Check.objects.filter(**filter, check_type="cpuload")
.exists()
):
raise serializers.ValidationError( raise serializers.ValidationError(
"A cpuload check for this agent already exists" "A cpuload check for this agent already exists"
) )
@@ -126,10 +122,7 @@ class CheckSerializer(serializers.ModelSerializer):
) )
if check_type == "memory" and not self.instance: if check_type == "memory" and not self.instance:
if ( if Check.objects.filter(**filter, check_type="memory").exists():
Check.objects.filter(**filter, check_type="memory")
.exists()
):
raise serializers.ValidationError( raise serializers.ValidationError(
"A memory check for this agent already exists" "A memory check for this agent already exists"
) )

View File

@@ -254,8 +254,15 @@ class TestCheckViews(TacticalTestCase):
# setup data # setup data
agent = baker.make_recipe("agents.agent") agent = baker.make_recipe("agents.agent")
check = baker.make_recipe("checks.diskspace_check", agent=agent) check = baker.make_recipe("checks.diskspace_check", agent=agent)
check_result = baker.make("checks.CheckResult", assigned_check=check, agent=agent) check_result = baker.make(
baker.make("checks.CheckHistory", check_id=check.id, agent_id=agent.agent_id, _quantity=30) "checks.CheckResult", assigned_check=check, agent=agent
)
baker.make(
"checks.CheckHistory",
check_id=check.id,
agent_id=agent.agent_id,
_quantity=30,
)
check_history_data = baker.make( check_history_data = baker.make(
"checks.CheckHistory", "checks.CheckHistory",
check_id=check.id, check_id=check.id,
@@ -689,7 +696,12 @@ class TestCheckTasks(TacticalTestCase):
) )
# test failing info # test failing info
data = {"id": check.id, "agent_id": self.agent.agent_id, "status": "failing", "output": "reply from a.com"} data = {
"id": check.id,
"agent_id": self.agent.agent_id,
"status": "failing",
"output": "reply from a.com",
}
resp = self.client.patch(url, data, format="json") resp = self.client.patch(url, data, format="json")
self.assertEqual(resp.status_code, 200) self.assertEqual(resp.status_code, 200)
@@ -729,7 +741,12 @@ class TestCheckTasks(TacticalTestCase):
self.assertEqual(check.alert_severity, "error") self.assertEqual(check.alert_severity, "error")
# test passing # test passing
data = {"id": check.id, "agent_id": self.agent.agent_id, "status": "passing", "output": "reply from a.com"} data = {
"id": check.id,
"agent_id": self.agent.agent_id,
"status": "passing",
"output": "reply from a.com",
}
resp = self.client.patch(url, data, format="json") resp = self.client.patch(url, data, format="json")
self.assertEqual(resp.status_code, 200) self.assertEqual(resp.status_code, 200)
@@ -746,7 +763,12 @@ class TestCheckTasks(TacticalTestCase):
) )
# test passing running # test passing running
data = {"id": check.id,"agent_id": self.agent.agent_id, "status": "passing", "more_info": "ok"} data = {
"id": check.id,
"agent_id": self.agent.agent_id,
"status": "passing",
"more_info": "ok",
}
resp = self.client.patch(url, data, format="json") resp = self.client.patch(url, data, format="json")
self.assertEqual(resp.status_code, 200) self.assertEqual(resp.status_code, 200)
@@ -755,7 +777,12 @@ class TestCheckTasks(TacticalTestCase):
self.assertEqual(check_result.status, "passing") self.assertEqual(check_result.status, "passing")
# test failing # test failing
data = {"id": check.id, "agent_id": self.agent.agent_id, "status": "failing", "more_info": "ok"} data = {
"id": check.id,
"agent_id": self.agent.agent_id,
"status": "failing",
"more_info": "ok",
}
resp = self.client.patch(url, data, format="json") resp = self.client.patch(url, data, format="json")
self.assertEqual(resp.status_code, 200) self.assertEqual(resp.status_code, 200)
@@ -1023,14 +1050,22 @@ class TestCheckPermissions(TacticalTestCase):
agent = baker.make_recipe("agents.agent") agent = baker.make_recipe("agents.agent")
unauthorized_agent = baker.make_recipe("agents.agent") unauthorized_agent = baker.make_recipe("agents.agent")
check = baker.make("checks.Check", agent=agent) check = baker.make("checks.Check", agent=agent)
check_result = baker.make("checks.CheckResult", agent=agent, assigned_check=check) check_result = baker.make(
"checks.CheckResult", agent=agent, assigned_check=check
)
unauthorized_check = baker.make("checks.Check", agent=unauthorized_agent) unauthorized_check = baker.make("checks.Check", agent=unauthorized_agent)
unauthorized_check_result = baker.make("checks.CheckResult", agent=unauthorized_agent, assigned_check=unauthorized_check) unauthorized_check_result = baker.make(
"checks.CheckResult",
agent=unauthorized_agent,
assigned_check=unauthorized_check,
)
for action in ["reset", "run"]: for action in ["reset", "run"]:
if action == "reset": if action == "reset":
url = f"{base_url}/{check_result.id}/{action}/" url = f"{base_url}/{check_result.id}/{action}/"
unauthorized_url = f"{base_url}/{unauthorized_check_result.id}/{action}/" unauthorized_url = (
f"{base_url}/{unauthorized_check_result.id}/{action}/"
)
else: else:
url = f"{base_url}/{agent.agent_id}/{action}/" url = f"{base_url}/{agent.agent_id}/{action}/"
unauthorized_url = f"{base_url}/{unauthorized_agent.agent_id}/{action}/" unauthorized_url = f"{base_url}/{unauthorized_agent.agent_id}/{action}/"
@@ -1067,9 +1102,15 @@ class TestCheckPermissions(TacticalTestCase):
agent = baker.make_recipe("agents.agent") agent = baker.make_recipe("agents.agent")
unauthorized_agent = baker.make_recipe("agents.agent") unauthorized_agent = baker.make_recipe("agents.agent")
check = baker.make("checks.Check", agent=agent) check = baker.make("checks.Check", agent=agent)
check_result = baker.make("checks.CheckResult", agent=agent, assigned_check=check) check_result = baker.make(
"checks.CheckResult", agent=agent, assigned_check=check
)
unauthorized_check = baker.make("checks.Check", agent=unauthorized_agent) unauthorized_check = baker.make("checks.Check", agent=unauthorized_agent)
unauthorized_check_result = baker.make("checks.CheckResult", agent=unauthorized_agent, assigned_check=unauthorized_check) unauthorized_check_result = baker.make(
"checks.CheckResult",
agent=unauthorized_agent,
assigned_check=unauthorized_check,
)
url = f"{base_url}/{check_result.id}/history/" url = f"{base_url}/{check_result.id}/history/"
unauthorized_url = f"{base_url}/{unauthorized_check_result.id}/history/" unauthorized_url = f"{base_url}/{unauthorized_check_result.id}/history/"

View File

@@ -120,7 +120,9 @@ class ResetCheck(APIView):
result.save() result.save()
# resolve any alerts that are open # resolve any alerts that are open
alert = Alert.create_or_return_check_alert(result.assigned_check, agent=result.agent, skip_create=True) alert = Alert.create_or_return_check_alert(
result.assigned_check, agent=result.agent, skip_create=True
)
if alert: if alert:
alert.resolve() alert.resolve()
@@ -148,11 +150,7 @@ class GetCheckHistory(APIView):
check_history = CheckHistory.objects.filter(check_id=result.assigned_check.id, agent_id=result.agent.agent_id).filter(timeFilter).order_by("-x") # type: ignore check_history = CheckHistory.objects.filter(check_id=result.assigned_check.id, agent_id=result.agent.agent_id).filter(timeFilter).order_by("-x") # type: ignore
return Response( return Response(CheckHistorySerializer(check_history, many=True).data)
CheckHistorySerializer(
check_history, many=True
).data
)
@api_view(["POST"]) @api_view(["POST"])

View File

@@ -85,7 +85,9 @@ class CoreSettings(BaseAuditModel):
null=True, null=True,
blank=True, blank=True,
) )
date_format = models.CharField(max_length=30, blank=True, default="MMM-DD-YYYY - HH:mm") date_format = models.CharField(
max_length=30, blank=True, default="MMM-DD-YYYY - HH:mm"
)
def save(self, *args, **kwargs): def save(self, *args, **kwargs):
from alerts.tasks import cache_agents_alert_template from alerts.tasks import cache_agents_alert_template

View File

@@ -69,7 +69,10 @@ def _get_failing_data(agents):
for task in agent.get_tasks_with_policies(): for task in agent.get_tasks_with_policies():
if not task.task_result: if not task.task_result:
continue continue
elif task.task_result.status == "failing" and task.task_result.alert_severity == "error": elif (
task.task_result.status == "failing"
and task.task_result.alert_severity == "error"
):
data["error"] = True data["error"] = True
break break
@@ -110,7 +113,10 @@ def cache_db_fields_task():
# sync scheduled tasks # sync scheduled tasks
for task in agent.get_tasks_with_policies(exclude_synced=True): for task in agent.get_tasks_with_policies(exclude_synced=True):
try: try:
if not task.task_result or task.task_result.sync_status == "initial": if (
not task.task_result
or task.task_result.sync_status == "initial"
):
task.create_task_on_agent(agent=agent if task.policy else None) task.create_task_on_agent(agent=agent if task.policy else None)
if task.task_result.sync_status == "pendingdeletion": if task.task_result.sync_status == "pendingdeletion":
task.delete_task_on_agent(agent=agent if task.policy else None) task.delete_task_on_agent(agent=agent if task.policy else None)

View File

@@ -77,7 +77,7 @@ def dashboard_info(request):
"loading_bar_color": request.user.loading_bar_color, "loading_bar_color": request.user.loading_bar_color,
"clear_search_when_switching": request.user.clear_search_when_switching, "clear_search_when_switching": request.user.clear_search_when_switching,
"hosted": getattr(settings, "HOSTED", False), "hosted": getattr(settings, "HOSTED", False),
"date_format": CoreSettings.objects.first().date_format "date_format": CoreSettings.objects.first().date_format,
} }
) )