make the policy automated tasks check assignment work correctly and add tests

This commit is contained in:
sadnub
2021-05-10 20:35:38 -04:00
parent 83a19e005b
commit f39cd5ae2f
3 changed files with 35 additions and 33 deletions

View File

@@ -54,6 +54,8 @@ class TestPolicyViews(TacticalTestCase):
@patch("autotasks.models.AutomatedTask.create_task_on_agent")
def test_add_policy(self, create_task):
from automation.models import Policy
url = "/automation/policies/"
data = {
@@ -72,8 +74,12 @@ class TestPolicyViews(TacticalTestCase):
# create policy with tasks and checks
policy = baker.make("automation.Policy")
self.create_checks(policy=policy)
baker.make("autotasks.AutomatedTask", policy=policy, _quantity=3)
checks = self.create_checks(policy=policy)
tasks = baker.make("autotasks.AutomatedTask", policy=policy, _quantity=3)
# assign a task to a check
tasks[0].assigned_check = checks[0] # type: ignore
tasks[0].save() # type: ignore
# test copy tasks and checks to another policy
data = {
@@ -86,8 +92,16 @@ class TestPolicyViews(TacticalTestCase):
resp = self.client.post(f"/automation/policies/", data, format="json")
self.assertEqual(resp.status_code, 200)
self.assertEqual(policy.autotasks.count(), 3) # type: ignore
self.assertEqual(policy.policychecks.count(), 7) # type: ignore
copied_policy = Policy.objects.get(name=data["name"])
self.assertEqual(copied_policy.autotasks.count(), 3) # type: ignore
self.assertEqual(copied_policy.policychecks.count(), 7) # type: ignore
# make sure correct task was assign to the check
self.assertEqual(copied_policy.autotasks.get(name=tasks[0].name).assigned_check.check_type, checks[0].check_type) # type: ignore
create_task.assert_not_called()
self.check_not_authenticated("post", url)

View File

@@ -5,16 +5,15 @@ import string
from typing import List
import pytz
from alerts.models import SEVERITY_CHOICES
from django.conf import settings
from django.contrib.postgres.fields import ArrayField
from django.db import models
from django.db.models.fields import DateTimeField
from django.utils import timezone as djangotime
from logs.models import BaseAuditModel
from loguru import logger
from packaging import version as pyver
from alerts.models import SEVERITY_CHOICES
from logs.models import BaseAuditModel
from tacticalrmm.utils import bitdays_to_string
logger.configure(**settings.LOG_CONFIG)
@@ -197,33 +196,19 @@ class AutomatedTask(BaseAuditModel):
return TaskSerializer(task).data
def create_policy_task(self, agent=None, policy=None):
def create_policy_task(self, agent=None, policy=None, assigned_check=None):
# if policy is present, then this task is being copied to another policy
# if agent is present, then this task is being created on an agent from a policy
# exit if neither are set or if both are set
if not agent and not policy or agent and policy:
# also exit if assigned_check is set because this task will be created when the check is
if (
(not agent and not policy)
or (agent and policy)
or (self.assigned_check and not assigned_check)
):
return
assigned_check = None
# get correct assigned check to task if set
if agent and self.assigned_check:
# check if there is a matching check on the agent
if agent.agentchecks.filter(parent_check=self.assigned_check.pk).exists():
assigned_check = agent.agentchecks.filter(
parent_check=self.assigned_check.pk
).first()
elif policy and self.assigned_check:
if policy.policychecks.filter(name=self.assigned_check.name).exists():
assigned_check = policy.policychecks.filter(
name=self.assigned_check.name
).first()
else:
assigned_check = policy.policychecks.filter(
check_type=self.assigned_check.check_type
).first()
task = AutomatedTask.objects.create(
agent=agent,
policy=policy,
@@ -233,7 +218,8 @@ class AutomatedTask(BaseAuditModel):
)
for field in self.policy_fields_to_copy:
setattr(task, field, getattr(self, field))
if field != "assigned_check":
setattr(task, field, getattr(self, field))
task.save()

View File

@@ -6,17 +6,16 @@ from statistics import mean
from typing import Any
import pytz
from alerts.models import SEVERITY_CHOICES
from core.models import CoreSettings
from django.conf import settings
from django.contrib.postgres.fields import ArrayField
from django.core.validators import MaxValueValidator, MinValueValidator
from django.db import models
from logs.models import BaseAuditModel
from loguru import logger
from packaging import version as pyver
from alerts.models import SEVERITY_CHOICES
from core.models import CoreSettings
from logs.models import BaseAuditModel
from .utils import bytes2human
logger.configure(**settings.LOG_CONFIG)
@@ -604,6 +603,9 @@ class Check(BaseAuditModel):
script=self.script,
)
for task in self.assignedtask.all(): # type: ignore
task.create_policy_task(agent=agent, policy=policy, assigned_check=check)
for field in self.policy_fields_to_copy:
setattr(check, field, getattr(self, field))