finish up tests and some code rework

This commit is contained in:
sadnub
2023-09-15 16:17:28 -04:00
parent bd19c4e2bd
commit 0b7eb41049
14 changed files with 3071 additions and 370 deletions

View File

@@ -0,0 +1,147 @@
import pytest
from model_bakery import baker
from rest_framework.test import APIClient
from rest_framework import status
from ..models import ReportHTMLTemplate
@pytest.fixture
def authenticated_client():
client = APIClient()
user = baker.make("accounts.User")
client.force_authenticate(user=user)
return client
@pytest.fixture
def unauthenticated_client():
return APIClient()
@pytest.fixture
def report_html_template():
return baker.make("reporting.ReportHTMLTemplate")
@pytest.fixture
def report_html_template_data():
return {"name": "Test Template", "html": "<div>Test HTML</div>"}
@pytest.mark.django_db
class TestGetAddReportHTMLTemplate:
def test_get_all_report_html_templates(
self, authenticated_client, report_html_template
):
response = authenticated_client.get("/reporting/htmltemplates/")
assert response.status_code == status.HTTP_200_OK
assert len(response.data) == 1
assert response.data[0]["name"] == report_html_template.name
def test_post_new_report_html_template(
self, authenticated_client, report_html_template_data
):
response = authenticated_client.post(
"/reporting/htmltemplates/", data=report_html_template_data
)
assert response.status_code == status.HTTP_200_OK
assert ReportHTMLTemplate.objects.filter(
name=report_html_template_data["name"]
).exists()
def test_post_invalid_data(self, authenticated_client):
response = authenticated_client.post(
"/reporting/htmltemplates/", data={"name": ""}
)
assert response.status_code == status.HTTP_400_BAD_REQUEST
def test_unauthenticated_get_html_templates_view(self, unauthenticated_client):
response = unauthenticated_client.get("/reporting/htmltemplates/")
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_unauthenticated_add_html_template_view(self, unauthenticated_client):
response = unauthenticated_client.post("/reporting/htmltemplates/")
assert response.status_code == status.HTTP_401_UNAUTHORIZED
@pytest.mark.django_db
class TestGetEditDeleteReportHTMLTemplate:
def test_get_specific_report_html_template(
self, authenticated_client, report_html_template
):
response = authenticated_client.get(
f"/reporting/htmltemplates/{report_html_template.id}/"
)
assert response.status_code == status.HTTP_200_OK
assert response.data["name"] == report_html_template.name
def test_get_non_existent_template(self, authenticated_client):
response = authenticated_client.get("/reporting/htmltemplates/999/")
assert response.status_code == status.HTTP_404_NOT_FOUND
def test_put_update_report_html_template(
self, authenticated_client, report_html_template, report_html_template_data
):
response = authenticated_client.put(
f"/reporting/htmltemplates/{report_html_template.id}/",
data=report_html_template_data,
)
report_html_template.refresh_from_db()
assert response.status_code == status.HTTP_200_OK
assert report_html_template.name == report_html_template_data["name"]
def test_put_invalid_data(self, authenticated_client, report_html_template):
response = authenticated_client.put(
f"/reporting/htmltemplates/{report_html_template.id}/", data={"name": ""}
)
assert response.status_code == status.HTTP_400_BAD_REQUEST
def test_delete_report_html_template(
self, authenticated_client, report_html_template
):
response = authenticated_client.delete(
f"/reporting/htmltemplates/{report_html_template.id}/"
)
assert response.status_code == status.HTTP_200_OK
assert not ReportHTMLTemplate.objects.filter(
id=report_html_template.id
).exists()
def test_delete_non_existent_template(self, authenticated_client):
response = authenticated_client.delete("/reporting/htmltemplates/999/")
assert response.status_code == status.HTTP_404_NOT_FOUND
def test_unauthenticated_get_html_template_view(
self, unauthenticated_client, report_html_template
):
response = unauthenticated_client.get(
f"/reporting/htmltemplates/{report_html_template.id}/"
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_unauthenticated_edit_html_template_view(
self, unauthenticated_client, report_html_template
):
response = unauthenticated_client.put(
f"/reporting/htmltemplates/{report_html_template.id}/"
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_unauthenticated_delete_html_template_view(
self, unauthenticated_client, report_html_template
):
response = unauthenticated_client.delete(
f"/reporting/htmltemplates/{report_html_template.id}/"
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED

View File

@@ -0,0 +1,513 @@
import pytest
import yaml
from unittest.mock import patch
from model_bakery import baker
from django.apps import apps
from ..utils import (
add_custom_fields,
make_dataqueries_inline,
build_queryset,
resolve_model,
ResolveModelException,
InvalidDBOperationException,
)
from ..constants import REPORTING_MODELS
from agents.models import Agent
@pytest.mark.django_db
class TestMakeVariablesInline:
def test_make_dataqueries_inline_valid_reference(self):
data_query = baker.make(
"reporting.ReportDataQuery", name="test_query", json_query={"test": "query"}
)
variables = yaml.dump({"data_sources": {"source1": "test_query"}})
result = make_dataqueries_inline(variables=variables)
assert yaml.safe_load(result) == {
"data_sources": {"source1": {"test": "query"}}
}
def test_make_dataqueries_inline_invalid_reference(self):
variables = yaml.dump({"data_sources": {"source1": "nonexistent_query"}})
result = make_dataqueries_inline(variables=variables)
assert yaml.safe_load(result) == {
"data_sources": {"source1": "nonexistent_query"}
}
def test_make_dataqueries_inline_no_reference(self):
variables = yaml.dump({"key": "value"})
result = make_dataqueries_inline(variables=variables)
assert yaml.safe_load(result) == {"key": "value"}
def test_make_dataqueries_inline_invalid_yaml(self):
variables = "{some: invalid: yaml}"
result = make_dataqueries_inline(variables=variables)
assert yaml.safe_load(result) == {}
class TestResolvingModels:
def test_all_reporting_models_valid(self):
for model_name, app_name in REPORTING_MODELS:
try:
model = apps.get_model(app_name, model_name)
except LookupError:
pytest.fail(f"Model: {model_name} does not exist in app: {app_name}")
def test_resolve_model_valid_model(self):
data_source = {"model": "Agent"}
result = resolve_model(data_source=data_source)
# Assuming 'agents.Agent' is a valid model in your Django app.
from agents.models import Agent
assert result["model"] == Agent
def test_resolve_model_invalid_model_name(self):
data_source = {"model": "InvalidModel"}
with pytest.raises(
ResolveModelException, match="Model lookup failed for InvalidModel"
):
resolve_model(data_source=data_source)
def test_resolve_model_no_model_key(self):
data_source = {"key": "value"}
with pytest.raises(
ResolveModelException, match="Model key must be present on data_source"
):
resolve_model(data_source=data_source)
@patch("agents.models.Agent.objects.using", return_value=Agent.objects.using("default"))
@pytest.mark.django_db()
class TestBuildingQueryset:
@pytest.fixture
def setup_agents(self):
agent1 = baker.make("agents.Agent", hostname="ZAgent1", plat="windows")
agent2 = baker.make("agents.Agent", hostname="Agent2", plat="windows")
return [agent1, agent2]
def test_build_queryset_with_valid_model(self, mock, setup_agents):
data_source = {
"model": Agent,
}
result = build_queryset(data_source=data_source)
assert result is not None, "Queryset should not be None for a valid model."
def test_build_queryset_invalid_operation(self, mock, setup_agents):
data_source = {
"model": Agent,
"invalid_op": "value",
}
with pytest.raises(InvalidDBOperationException):
build_queryset(data_source=data_source)
def test_build_queryset_without_model(self, mock, setup_agents):
data_source = {}
with pytest.raises(
Exception
): # This could be a more specific exception if you expect one.
build_queryset(data_source=data_source)
def test_build_queryset_only_operation(self, mock, setup_agents):
data_source = {"model": Agent, "only": ["hostname", "operating_system"]}
result = build_queryset(data_source=data_source)
assert len(result) == 2
for agent_data in result:
assert "hostname" in agent_data
assert "operating_system" in agent_data
assert "plat" not in agent_data
def test_build_queryset_id_is_appended_if_only_exists(self, mock, setup_agents):
data_source = {"model": Agent, "only": ["hostname"]}
result = build_queryset(data_source=data_source)
assert len(result) == 2
for agent_data in result:
assert "id" in agent_data
def test_build_queryset_filter_operation(self, mock, setup_agents):
data_source = {
"model": Agent,
"filter": {"hostname": setup_agents[0].hostname},
}
result = build_queryset(data_source=data_source)
assert len(result) == 1
assert result[0]["hostname"] == setup_agents[0].hostname
def test_filtering_operation_with_multiple_fields(self, mock, setup_agents):
data_source = {
"model": Agent,
"filter": {
"hostname": setup_agents[0].hostname,
"operating_system": setup_agents[0].operating_system,
},
}
result = build_queryset(data_source=data_source)
assert len(result) == 1
assert result[0]["hostname"] == setup_agents[0].hostname
def test_filtering_with_non_existing_condition(self, mock, setup_agents):
data_source = {"model": Agent, "filter": {"hostname": "doesn't exist"}}
result = build_queryset(data_source=data_source)
assert len(result) == 0
def test_build_queryset_exclude_operation(self, mock, setup_agents):
data_source = {
"model": Agent,
"exclude": {"hostname": setup_agents[0].hostname},
}
results = build_queryset(data_source=data_source)
assert len(results) == 1
assert results[0]["hostname"] != setup_agents[0].hostname
def test_build_query_get_operation(self, mock, setup_agents):
data_source = {"model": Agent, "get": {"agent_id": setup_agents[0].agent_id}}
agent = build_queryset(data_source=data_source)
assert agent["hostname"] == setup_agents[0].hostname
def test_build_queryset_all_operation(self, mock, setup_agents):
data_source = {
"model": Agent,
"all": True,
}
result = build_queryset(data_source=data_source)
assert len(result) == 2
# test filter and only
def test_build_queryset_filter_only_operation(self, mock, setup_agents):
data_source = {
"model": Agent,
"filter": {"hostname": setup_agents[0].hostname},
"only": ["agent_id", "hostname"],
}
result = build_queryset(data_source=data_source)
# should only return 1 result
assert len(result) == 1
assert result[0]["hostname"] == setup_agents[0].hostname
assert "plat" not in result[0]
def test_build_queryset_limit_operation(self, mock, setup_agents):
data_source = {
"model": Agent,
"limit": 1,
}
result = build_queryset(data_source=data_source)
assert len(result) == 1
def test_build_queryset_field_defer_operation(self, mock, setup_agents):
data_source = {"model": Agent, "defer": ["wmi_detail", "services"]}
result = build_queryset(data_source=data_source)
assert "wmi_detail" not in result[0]
assert "services" not in result[0]
assert result[0]["hostname"] == setup_agents[0].hostname
def test_build_queryset_first_operation(self, mock, setup_agents):
data_source = {"model": Agent, "first": True}
result = build_queryset(data_source=data_source)
assert result["hostname"] == setup_agents[0].hostname
def test_build_queryset_count_operation(self, mock, setup_agents):
data_source = {"model": Agent, "count": True}
count = build_queryset(data_source=data_source)
assert count == 2
def test_build_queryset_order_by_operation(self, mock, setup_agents):
data_source = {"model": Agent, "order_by": ["hostname"]}
result = build_queryset(data_source=data_source)
assert len(result) == 2
assert result[0]["hostname"] == setup_agents[1].hostname
assert result[1]["hostname"] == setup_agents[0].hostname
def test_build_queryset_json_presentation(self, mock, setup_agents):
import json
data_source = {"model": Agent, "json": True}
result = build_queryset(data_source=data_source)
# Deserializing the result to check the content.
result_data = json.loads(result)
assert result_data[0]["hostname"] == setup_agents[0].hostname
def test_build_queryset_custom_fields(self, mock, setup_agents):
default_value = "Default Value"
field1 = baker.make(
"core.CustomField", name="custom_1", model="agent", type="text"
)
baker.make(
"core.CustomField",
name="custom_2",
model="agent",
type="text",
default_value_string=default_value,
)
baker.make(
"agents.AgentCustomField",
agent=setup_agents[0],
field=field1,
string_value="Agent1",
)
baker.make(
"agents.AgentCustomField",
agent=setup_agents[1],
field=field1,
string_value="Agent2",
)
data_source = {"model": Agent, "custom_fields": ["custom_1", "custom_2"]}
result = build_queryset(data_source=data_source)
assert len(result) == 2
# check agent 1
assert result[0]["custom_fields"]["custom_1"] == "Agent1"
assert result[0]["custom_fields"]["custom_2"] == default_value
# check agent 2
assert result[1]["custom_fields"]["custom_1"] == "Agent2"
assert result[1]["custom_fields"]["custom_2"] == default_value
def test_build_queryset_filter_only_json_combination(self, mock, setup_agents):
import json
data_source = {
"model": Agent,
"filter": {"agent_id": setup_agents[0].agent_id},
"only": ["hostname", "agent_id"],
"json": True,
}
result_json = build_queryset(data_source=data_source)
result = json.loads(result_json)
assert len(result) == 1
assert "operating_system" not in result[0]
assert result[0]["hostname"] == setup_agents[0].hostname
def test_build_queryset_get_only_json_combination(self, mock, setup_agents):
import json
data_source = {
"model": Agent,
"get": {"agent_id": setup_agents[0].agent_id},
"only": ["hostname", "agent_id"],
"json": True,
}
result_json = build_queryset(data_source=data_source)
result = json.loads(result_json)
assert isinstance(result, dict)
assert "operating_system" not in result
assert result["hostname"] == setup_agents[0].hostname
def test_build_queryset_filter_order_by_combination(self, mock, setup_agents):
data_source = {
"model": Agent,
"filter": {"plat": "windows"},
"order_by": ["hostname"],
}
result = build_queryset(data_source=data_source)
assert len(result) == 2
assert result[0]["hostname"] == setup_agents[1].hostname
assert result[1]["hostname"] == setup_agents[0].hostname
def test_build_queryset_defer_used_over_only(self, mock, setup_agents):
data_source = {
"model": Agent,
"only": ["hostname", "operating_system"],
"defer": ["operating_system"],
}
result = build_queryset(data_source=data_source)[0]
assert "hostname" in result
assert "operating_system" not in result
def test_build_queryset_limit_ignored_with_first_or_get(self, mock, setup_agents):
data_source = {"model": Agent, "limit": 1, "first": True}
result_first = build_queryset(data_source=data_source)
assert isinstance(result_first, dict)
data_source = {
"model": Agent,
"limit": 1,
"get": {"agent_id": setup_agents[0].agent_id},
}
result_get = build_queryset(data_source=data_source)
assert isinstance(result_get, dict)
def test_build_queryset_result_type_with_get_or_first(self, mock, setup_agents):
# Test with "get"
data_source_get = {
"model": Agent,
"get": {"hostname": setup_agents[0].hostname},
}
result_get = build_queryset(data_source=data_source_get)
# Test with "first"
data_source_first = {"model": Agent, "first": True}
result_first = build_queryset(data_source=data_source_first)
assert not isinstance(result_get, list)
assert not isinstance(result_first, list)
def test_build_queryset_result_type_without_get_or_first(self, mock, setup_agents):
data_source = {
"model": Agent,
}
result = build_queryset(data_source=data_source)
assert isinstance(result, list)
def test_build_queryset_result_in_json_format(self, mock, setup_agents):
import json
data_source = {"model": Agent, "json": True}
result = build_queryset(data_source=data_source)
try:
parsed_result = json.loads(result)
except json.JSONDecodeError:
assert False
assert isinstance(parsed_result, list)
@pytest.mark.django_db
class TestAddingCustomFields:
@pytest.mark.parametrize(
"model_name,custom_field_model",
[
("agent", "agents.AgentCustomField"),
("client", "clients.ClientCustomField"),
("site", "clients.SiteCustomField"),
],
)
def test_add_custom_fields_with_list_of_dicts(self, model_name, custom_field_model):
custom_field = baker.make("core.CustomField", name="field1", model=model_name)
default_value = "Default Value"
custom_field2 = baker.make(
"core.CustomField",
name="field2",
model=model_name,
default_value_string=default_value,
)
custom_model_instance1 = baker.make(
custom_field_model, field=custom_field, string_value="Value"
)
custom_model_instance2 = baker.make(
custom_field_model, field=custom_field, string_value="Value"
)
data = [
{"id": getattr(custom_model_instance1, f"{model_name}_id")},
{"id": getattr(custom_model_instance2, f"{model_name}_id")},
]
fields_to_add = ["field1", "field2"]
result = add_custom_fields(
data=data, fields_to_add=fields_to_add, model_name=model_name
)
# Assert logic here based on what you expect the result to be
assert result[0]["custom_fields"]["field1"] == custom_model_instance1.value
assert result[1]["custom_fields"]["field1"] == custom_model_instance2.value
assert result[0]["custom_fields"]["field2"] == default_value
assert result[1]["custom_fields"]["field2"] == default_value
@pytest.mark.parametrize(
"model_name,custom_field_model",
[
("agent", "agents.AgentCustomField"),
("client", "clients.ClientCustomField"),
("site", "clients.SiteCustomField"),
],
)
def test_add_custom_fields_to_dictionary(self, model_name, custom_field_model):
custom_field = baker.make("core.CustomField", name="field1", model=model_name)
custom_model_instance = baker.make(
custom_field_model, field=custom_field, string_value="default_value"
)
data = {"id": getattr(custom_model_instance, f"{model_name}_id")}
fields_to_add = ["field1"]
result = add_custom_fields(
data=data,
fields_to_add=fields_to_add,
model_name=model_name,
dict_value=True,
)
# Assert logic here based on what you expect the result to be
assert result["custom_fields"]["field1"] == custom_model_instance.value
@pytest.mark.parametrize(
"model_name",
[
"agent",
"client",
"site",
],
)
def test_add_custom_fields_with_default_value(self, model_name):
default_value = "default_value"
custom_field = baker.make(
"core.CustomField",
name="field1",
model=model_name,
default_value_string=default_value,
)
# Note: Not creating an instance of the custom_field_model here to ensure the default value is used
data = {"id": 999} # ID not associated with any custom field model instance
fields_to_add = ["field1"]
result = add_custom_fields(
data=data,
fields_to_add=fields_to_add,
model_name=model_name,
dict_value=True,
)
# Assert that the default value is used
assert result["custom_fields"]["field1"] == default_value

View File

@@ -0,0 +1,184 @@
import pytest
import json
from model_bakery import baker
from rest_framework.test import APIClient
from rest_framework import status
from unittest.mock import patch, mock_open, MagicMock
from ..models import ReportDataQuery
@pytest.fixture
def authenticated_client():
client = APIClient()
user = baker.make("accounts.User")
client.force_authenticate(user=user)
return client
@pytest.fixture
def unauthenticated_client():
return APIClient()
@pytest.fixture
def report_data_query():
return baker.make("reporting.ReportDataQuery")
@pytest.fixture
def report_data_query_data():
return {"name": "Test Data Query", "json_query": {"test": "value"}}
@pytest.mark.django_db
class TestGetAddReportDataQuery:
def test_get_all_report_data_queries(self, authenticated_client, report_data_query):
url = "/reporting/dataqueries/"
response = authenticated_client.get(url)
assert response.status_code == status.HTTP_200_OK
assert len(response.data) == 1
assert response.data[0]["name"] == report_data_query.name
def test_post_new_report_data_query(
self, authenticated_client, report_data_query_data
):
url = "/reporting/dataqueries/"
response = authenticated_client.post(
url, data=report_data_query_data, format="json"
)
assert response.status_code == status.HTTP_200_OK
assert ReportDataQuery.objects.filter(
name=report_data_query_data["name"]
).exists()
def test_post_invalid_data(self, authenticated_client):
url = "/reporting/dataqueries/"
response = authenticated_client.post(url, data={"name": ""})
assert response.status_code == status.HTTP_400_BAD_REQUEST
def test_unauthenticated_get_data_queries_view(self, unauthenticated_client):
response = unauthenticated_client.get("/reporting/dataqueries/")
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_unauthenticated_add_data_query_view(self, unauthenticated_client):
response = unauthenticated_client.post("/reporting/dataqueries/")
assert response.status_code == status.HTTP_401_UNAUTHORIZED
@pytest.mark.django_db
class TestGetEditDeleteReportDataQuery:
def test_get_specific_report_data_query(
self, authenticated_client, report_data_query
):
url = f"/reporting/dataqueries/{report_data_query.id}/"
response = authenticated_client.get(url)
assert response.status_code == status.HTTP_200_OK
assert response.data["name"] == report_data_query.name
def test_get_non_existent_data_query(self, authenticated_client):
url = "/reporting/dataqueries/9999/"
response = authenticated_client.get(url)
assert response.status_code == status.HTTP_404_NOT_FOUND
def test_put_update_report_data_query(
self, authenticated_client, report_data_query, report_data_query_data
):
url = f"/reporting/dataqueries/{report_data_query.id}/"
response = authenticated_client.put(
url, data=report_data_query_data, format="json"
)
report_data_query.refresh_from_db()
assert response.status_code == status.HTTP_200_OK
assert report_data_query.name == report_data_query_data["name"]
def test_put_invalid_data(self, authenticated_client, report_data_query):
url = f"/reporting/dataqueries/{report_data_query.id}/"
response = authenticated_client.put(url, data={"name": ""})
assert response.status_code == status.HTTP_400_BAD_REQUEST
def test_delete_report_data_query(self, authenticated_client, report_data_query):
url = f"/reporting/dataqueries/{report_data_query.id}/"
response = authenticated_client.delete(url)
assert response.status_code == status.HTTP_200_OK
assert not ReportDataQuery.objects.filter(id=report_data_query.id).exists()
def test_delete_non_existent_data_query(self, authenticated_client):
url = "/reporting/dataqueries/9999/"
response = authenticated_client.delete(url)
assert response.status_code == status.HTTP_404_NOT_FOUND
def test_unauthenticated_get_data_queries_view(
self, unauthenticated_client, report_data_query
):
response = unauthenticated_client.get(
f"/reporting/dataqueries/{report_data_query.id}/"
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_unauthenticated_edit_html_template_view(
self, unauthenticated_client, report_data_query
):
response = unauthenticated_client.put(
f"/reporting/dataqueries/{report_data_query.id}/"
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_unauthenticated_delete_html_template_view(
self, unauthenticated_client, report_data_query
):
response = unauthenticated_client.delete(
f"/reporting/dataqueries/{report_data_query.id}/"
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
@pytest.mark.django_db
class TestQuerySchema:
def test_get_query_schema_in_debug_mode(self, settings, authenticated_client):
# Set DEBUG mode
settings.DEBUG = True
expected_data = {"sample": "json"}
# mock the file
mopen = mock_open(read_data=json.dumps({"sample": "json"}))
with patch("builtins.open", mopen):
response = authenticated_client.get("/reporting/queryschema/")
assert response.status_code == status.HTTP_200_OK
assert response.json() == expected_data
def test_get_query_schema_in_production_mode(self, settings, authenticated_client):
# Set production mode (DEBUG = False)
settings.DEBUG = False
response = authenticated_client.get("/reporting/queryschema/")
assert response.status_code == status.HTTP_200_OK
# Check that the X-Accel-Redirect header is set correctly
assert (
response["X-Accel-Redirect"]
== "/static/reporting/schemas/query_schema.json"
)
def test_get_query_schema_file_missing(self, settings, authenticated_client):
# Set DEBUG mode
settings.DEBUG = True
with patch("builtins.open", side_effect=FileNotFoundError):
response = authenticated_client.get("/reporting/queryschema/")
assert response.status_code == status.HTTP_400_BAD_REQUEST
def test_unauthenticated_query_schema_view(self, unauthenticated_client):
response = unauthenticated_client.delete(f"/reporting/queryschema/")
assert response.status_code == status.HTTP_401_UNAUTHORIZED

View File

@@ -0,0 +1,249 @@
import pytest
import uuid
import base64
import json
from model_bakery import baker
from rest_framework.test import APIClient
from unittest.mock import patch
from rest_framework import status
from ..models import ReportAsset, ReportTemplate, ReportHTMLTemplate
@pytest.fixture
def authenticated_client():
client = APIClient()
user = baker.make("accounts.User")
client.force_authenticate(user=user)
return client
@pytest.fixture
def unauthenticated_client():
return APIClient()
@pytest.fixture
def report_template():
return baker.make(
"reporting.ReportTemplate",
name="test_template",
template_md="# Test MD",
template_css="body { color: red; }",
type="markdown",
depends_on=["some_dependency"],
)
@pytest.fixture
def report_template_with_base_template():
base_template = baker.make("reporting.ReportHTMLTemplate")
return baker.make(
"reporting.ReportTemplate",
name="test_template",
template_md="# Test MD",
template_css="body { color: red; }",
template_html=base_template,
type="markdown",
depends_on=["some_dependency"],
)
@pytest.mark.django_db
class TestExportReportTemplate:
@patch(
"ee.reporting.views.base64_encode_assets", return_value="some_encoded_assets"
)
def test_export_report_template_with_base_template(
self,
mock_encode_assets,
authenticated_client,
report_template_with_base_template,
):
url = f"/reporting/templates/{report_template_with_base_template.id}/export/"
response = authenticated_client.post(url)
expected_response = {
"base_template": {
"name": report_template_with_base_template.template_html.name,
"html": report_template_with_base_template.template_html.html,
},
"template": {
"name": report_template_with_base_template.name,
"template_css": report_template_with_base_template.template_css,
"template_md": report_template_with_base_template.template_md,
"type": report_template_with_base_template.type,
"depends_on": report_template_with_base_template.depends_on,
"template_variables": "{}\n",
},
"assets": "some_encoded_assets",
}
assert response.status_code == status.HTTP_200_OK
assert response.data == expected_response
mock_encode_assets.assert_called()
@patch(
"ee.reporting.views.base64_encode_assets", return_value="some_encoded_assets"
)
def test_export_report_template_without_base_template(
self,
mock_encode_assets,
authenticated_client,
report_template,
):
url = f"/reporting/templates/{report_template.id}/export/"
response = authenticated_client.post(url)
expected_response = {
"base_template": None,
"template": {
"name": report_template.name,
"template_css": report_template.template_css,
"template_md": report_template.template_md,
"type": report_template.type,
"depends_on": report_template.depends_on,
"template_variables": "{}\n",
},
"assets": "some_encoded_assets",
}
assert response.status_code == status.HTTP_200_OK
assert response.data == expected_response
mock_encode_assets.assert_called()
def test_unauthenticated_export_report_template_view(
self, unauthenticated_client, report_template
):
response = unauthenticated_client.post(
f"/reporting/templates/{report_template.id}/export/"
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
@pytest.mark.django_db
class TestImportReportTemplate:
@pytest.fixture
def valid_template_data(self):
"""Returns a sample valid template data."""
return {
"template": {
"name": "test_template",
"template_md": "# Test MD",
"type": "markdown",
"depends_on": ["some_dependency"],
}
}
@pytest.fixture
def valid_base_template_data(self):
"""Returns a sample valid base template data."""
return {"name": "base_test_template", "html": "<div>Test</div>"}
@pytest.fixture
def valid_assets_data(self):
"""Returns a sample valid assets data."""
return [
{
"id": str(uuid.uuid4()),
"name": "asset1.png",
"file": base64.b64encode(b"mock_content1").decode("utf-8"),
},
{
"id": str(uuid.uuid4()),
"name": "asset2.png",
"file": base64.b64encode(b"mock_content2").decode("utf-8"),
},
]
def test_basic_import(self, authenticated_client, valid_template_data):
url = "/reporting/templates/import/"
response = authenticated_client.post(
url, data={"template": json.dumps(valid_template_data)}
)
assert response.status_code == status.HTTP_200_OK
assert ReportTemplate.objects.filter(name="test_template").exists()
def test_import_with_base_template(
self, authenticated_client, valid_template_data, valid_base_template_data
):
valid_template_data["base_template"] = valid_base_template_data
url = "/reporting/templates/import/"
response = authenticated_client.post(
url, data={"template": json.dumps(valid_template_data)}
)
assert response.status_code == status.HTTP_200_OK
assert ReportHTMLTemplate.objects.filter(name="base_test_template").exists()
assert ReportTemplate.objects.filter(name="test_template").exists()
def test_import_with_assets(
self, authenticated_client, valid_template_data, valid_assets_data
):
valid_template_data["assets"] = valid_assets_data
url = "/reporting/templates/import/"
response = authenticated_client.post(
url, data={"template": json.dumps(valid_template_data)}
)
assert response.status_code == status.HTTP_200_OK
assert ReportAsset.objects.filter(id=valid_assets_data[0]["id"]).exists()
assert ReportAsset.objects.filter(id=valid_assets_data[1]["id"]).exists()
@patch(
"ee.reporting.views.ImportReportTemplate._generate_random_string",
return_value="_randomized",
)
def test_name_conflict_on_repeated_calls(
self, generate_random_string_mock, authenticated_client, valid_template_data
):
url = "/reporting/templates/import/"
response = authenticated_client.post(
url, data={"template": json.dumps(valid_template_data)}
)
assert ReportTemplate.objects.filter(name="test_template").exists()
response = authenticated_client.post(
url, data={"template": json.dumps(valid_template_data)}
)
assert response.status_code == status.HTTP_200_OK
assert ReportTemplate.objects.filter(name="test_template_randomized").exists()
def test_invalid_data(self, authenticated_client, valid_template_data):
valid_template_data["template"].pop("name")
url = "/reporting/templates/import/"
response = authenticated_client.post(
url, data={"template": json.dumps(valid_template_data)}
)
assert response.status_code == status.HTTP_400_BAD_REQUEST
assert "name" in response.data
def test_import_with_assets_with_conflicting_paths(
self, authenticated_client, valid_template_data, valid_assets_data
):
conflicting_asset = {
"id": str(uuid.uuid4()),
"name": valid_assets_data[0]["name"],
"file": base64.b64encode(b"mock_content1").decode("utf-8"),
}
valid_assets_data.append(conflicting_asset)
valid_template_data["assets"] = valid_assets_data
url = "/reporting/templates/import/"
response = authenticated_client.post(
url, data={"template": json.dumps(valid_template_data)}
)
assert response.status_code == status.HTTP_200_OK
assert ReportAsset.objects.filter(id=valid_assets_data[0]["id"]).exists()
assert ReportAsset.objects.filter(id=valid_assets_data[1]["id"]).exists()
assert ReportAsset.objects.filter(id=conflicting_asset["id"]).exists()
# check if the renaming logic is working
asset = ReportAsset.objects.get(id=conflicting_asset["id"])
assert asset.file.name != valid_assets_data[0]["name"]
def test_unauthenticated_import_report_template_view(self, unauthenticated_client):
response = unauthenticated_client.post(f"/reporting/templates/import/")
assert response.status_code == status.HTTP_401_UNAUTHORIZED

View File

@@ -0,0 +1,17 @@
import pytest
import os
from django.core import management
@pytest.mark.django_db
class TestSchemaGeneration:
def test_generate_json_schema(self, settings):
management.call_command("generate_json_schemas")
schema_path = (
f"{settings.STATICFILES_DIRS[0]}reporting/schemas/query_schema.json"
)
assert os.path.exists(schema_path)
os.remove(schema_path)

View File

@@ -0,0 +1,534 @@
import pytest
import uuid
import os
from rest_framework.test import APIClient
from unittest.mock import patch, mock_open
from model_bakery import baker
from rest_framework import status
from django.core.exceptions import SuspiciousFileOperation
from django.core.files.uploadedfile import SimpleUploadedFile
from ..models import ReportAsset
@pytest.fixture
def authenticated_client():
client = APIClient()
user = baker.make("accounts.User")
client.force_authenticate(user=user)
return client
@pytest.fixture
def unauthenticated_client():
return APIClient()
@pytest.mark.django_db
class TestGetReportAssets:
@patch("ee.reporting.views.report_assets_fs")
def test_valid_path_with_dir_and_files(
self, mock_report_assets_fs, authenticated_client
):
# Set up the mock behavior for report_assets_fs
mock_report_assets_fs.listdir.return_value = (["folder1"], ["file1.txt"])
mock_report_assets_fs.size.return_value = 100
mock_report_assets_fs.url.return_value = "/mocked/url/to/resource"
path = "some/valid/path"
url = f"/reporting/assets/?path={path}"
expected_response_data = [
{
"name": "folder1",
"path": os.path.join(path, "folder1"),
"type": "folder",
"size": None,
"url": "/mocked/url/to/resource",
},
{
"name": "file1.txt",
"path": os.path.join(path, "file1.txt"),
"type": "file",
"size": "100",
"url": "/mocked/url/to/resource",
},
]
response = authenticated_client.get(url)
assert response.status_code == 200
assert response.data == expected_response_data
@patch("ee.reporting.views.report_assets_fs")
def test_no_path(self, mock_report_assets_fs, authenticated_client):
# Set up the mock behavior for report_assets_fs
mock_report_assets_fs.listdir.return_value = (["folder1"], ["file1.txt"])
mock_report_assets_fs.size.return_value = 100
mock_report_assets_fs.url.return_value = "/mocked/url/to/resource"
url = "/reporting/assets/"
expected_response_data = [
{
"name": "folder1",
"path": "folder1",
"type": "folder",
"size": None,
"url": "/mocked/url/to/resource",
},
{
"name": "file1.txt",
"path": "file1.txt",
"type": "file",
"size": "100",
"url": "/mocked/url/to/resource",
},
]
response = authenticated_client.get(url)
assert response.status_code == 200
assert response.data == expected_response_data
def test_invalid_path(self, authenticated_client):
url = "/reporting/assets/?path=some/invalid/path"
response = authenticated_client.get(url)
assert response.status_code == 400
def test_unauthenticated_get_report_assets_view(self, unauthenticated_client):
response = unauthenticated_client.post("/reporting/assets/")
assert response.status_code == status.HTTP_401_UNAUTHORIZED
@pytest.mark.django_db
class TestGetAllAssets:
@patch("ee.reporting.views.os.walk")
def test_general_functionality(self, mock_os_walk, authenticated_client):
mock_os_walk.return_value = iter(
[(".", ["subdir"], ["file1.txt"]), ("./subdir", [], ["subdirfile.txt"])]
)
asset1 = baker.make("reporting.ReportAsset", file="file1.txt")
asset2 = baker.make("reporting.ReportAsset", file="subdir/subdirfile.txt")
expected_data = [
{
"type": "folder",
"name": "Report Assets",
"path": ".",
"children": [
{
"type": "folder",
"name": "subdir",
"path": "subdir",
"children": [
{
"id": asset2.id,
"type": "file",
"name": "subdirfile.txt",
"path": "subdir/subdirfile.txt",
"icon": "description",
}
],
"selectable": False,
"icon": "folder",
"iconColor": "yellow-9",
},
{
"id": asset1.id,
"type": "file",
"name": "file1.txt",
"path": "file1.txt",
"icon": "description",
},
],
"selectable": False,
"icon": "folder",
"iconColor": "yellow-9",
}
]
response = authenticated_client.get("/reporting/assets/all/")
assert response.status_code == 200
assert expected_data == response.data
@patch("ee.reporting.views.os.chdir", side_effect=FileNotFoundError)
def test_invalid_report_assets_dir(self, mock_os_walk, authenticated_client):
response = authenticated_client.get("/reporting/assets/all/")
assert response.status_code == 400
@patch("ee.reporting.views.os.walk")
def test_only_folders(self, mock_os_walk, authenticated_client):
mock_os_walk.return_value = iter(
[(".", ["subdir"], ["file1.txt"]), ("./subdir", [], ["subdirfile.txt"])]
)
asset1 = baker.make("reporting.ReportAsset", file="file1.txt")
asset2 = baker.make("reporting.ReportAsset", file="subdir/subdirfile.txt")
response = authenticated_client.get("/reporting/assets/all/?onlyFolders=true")
assert response.status_code == 200
for node in response.data:
assert node["type"] != "file"
def test_unauthenticated_get_report_assets_all_view(self, unauthenticated_client):
response = unauthenticated_client.post("/reporting/assets/all/")
assert response.status_code == status.HTTP_401_UNAUTHORIZED
@pytest.mark.django_db
class TestRenameAsset:
def test_rename_file(self, authenticated_client):
with patch(
"ee.reporting.views.report_assets_fs.rename",
return_value="path/to/newname.txt",
) as mock_rename, patch(
"ee.reporting.views.report_assets_fs.isfile", return_value=True
) as mock_isfile, patch(
"ee.reporting.views.report_assets_fs.exists", return_value=True
) as mock_exists:
asset = baker.make("reporting.ReportAsset", file="path/to/file.txt")
response = authenticated_client.put(
"/reporting/assets/rename/",
data={"path": "path/to/file.txt", "newName": "newname.txt"},
)
mock_rename.assert_called_with(
path="path/to/file.txt", new_name="newname.txt"
)
assert response.status_code == 200
assert response.data == "path/to/newname.txt"
asset.refresh_from_db()
assert asset.file.name == "path/to/newname.txt"
def test_rename_folder(self, authenticated_client):
with patch(
"ee.reporting.views.report_assets_fs.rename",
return_value="path/to/newfolder",
) as mock_rename, patch(
"ee.reporting.views.report_assets_fs.isfile", return_value=False
) as mock_isfile, patch(
"ee.reporting.views.report_assets_fs.exists", return_value=True
) as mock_exists:
response = authenticated_client.put(
"/reporting/assets/rename/",
data={"path": "path/to/folder", "newName": "newfolder"},
)
mock_rename.assert_called_with(path="path/to/folder", new_name="newfolder")
assert response.status_code == 200
assert response.data == "path/to/newfolder"
def test_rename_non_existent_file(self, authenticated_client):
with patch(
"ee.reporting.views.report_assets_fs.rename",
side_effect=OSError("File not found"),
) as mock_rename, patch(
"ee.reporting.views.report_assets_fs.exists", return_value=True
) as mock_exists:
response = authenticated_client.put(
"/reporting/assets/rename/",
data={
"path": "non_existent_path/to/file.txt",
"newName": "newname.txt",
},
)
mock_rename.assert_called_with(
path="non_existent_path/to/file.txt", new_name="newname.txt"
)
assert response.status_code == 400
def test_suspicious_operation(self, authenticated_client):
response = authenticated_client.put(
"/reporting/assets/rename/",
data={"path": "../outside/path", "newName": "newname.txt"},
)
assert response.status_code == 400
def test_unauthenticated_rename_report_asset_view(self, unauthenticated_client):
response = unauthenticated_client.post("/reporting/assets/rename/")
assert response.status_code == status.HTTP_401_UNAUTHORIZED
@pytest.mark.django_db
class TestNewFolder:
def test_create_folder_success(self, authenticated_client):
with patch(
"ee.reporting.views.report_assets_fs.createfolder",
return_value="new/path/to/folder",
) as mock_createfolder:
response = authenticated_client.post(
"/reporting/assets/newfolder/", data={"path": "new/path/to/folder"}
)
mock_createfolder.assert_called_with(path="new/path/to/folder")
assert response.status_code == 200
assert response.data == "new/path/to/folder"
def test_create_folder_missing_path(self, authenticated_client):
response = authenticated_client.post("/reporting/assets/newfolder/")
assert response.status_code == 400
def test_create_folder_os_error(self, authenticated_client):
response = authenticated_client.post(
"/reporting/assets/newfolder/", data={"path": "invalid/path"}
)
assert response.status_code == 400
def test_create_folder_suspicious_operation(self, authenticated_client):
with patch(
"ee.reporting.views.report_assets_fs.createfolder",
side_effect=SuspiciousFileOperation("Invalid path"),
) as mock_createfolder:
response = authenticated_client.post(
"/reporting/assets/newfolder/", data={"path": "../outside/path"}
)
mock_createfolder.assert_called_with(path="../outside/path")
assert response.status_code == 400
def test_unauthenticated_new_folder_view(self, unauthenticated_client):
response = unauthenticated_client.post("/reporting/assets/newfolder/")
assert response.status_code == status.HTTP_401_UNAUTHORIZED
@pytest.mark.django_db
class TestDeleteAssets:
def test_delete_directory_success(self, authenticated_client):
from ..settings import settings
with patch(
"ee.reporting.views.report_assets_fs.isdir", return_value=True
), patch("ee.reporting.views.shutil.rmtree") as mock_rmtree:
asset1 = baker.make("reporting.ReportAsset", file="path/to/dir/file1.txt")
asset2 = baker.make("reporting.ReportAsset", file="path/to/dir/file2.txt")
response = authenticated_client.post(
"/reporting/assets/delete/", data={"paths": ["path/to/dir"]}
)
mock_rmtree.assert_called_with(
f"{settings.REPORTING_ASSETS_BASE_PATH}/path/to/dir"
)
assert response.status_code == 200
# make sure the assets within the deleted folder also got deleted
assert not ReportAsset.objects.filter(id=asset1.id).exists()
assert not ReportAsset.objects.filter(id=asset2.id).exists()
def test_delete_file_success(self, authenticated_client):
with patch("ee.reporting.views.report_assets_fs.isdir", return_value=False):
asset = baker.make("reporting.ReportAsset", file="path/to/file.txt")
response = authenticated_client.post(
"/reporting/assets/delete/", data={"paths": ["path/to/file.txt"]}
)
assert response.status_code == 200
assert not ReportAsset.objects.filter(id=asset.id).exists()
def test_delete_os_error(self, authenticated_client):
with patch(
"ee.reporting.views.report_assets_fs.isdir", return_value=True
), patch(
"ee.reporting.views.shutil.rmtree", side_effect=OSError("Unable to delete")
):
response = authenticated_client.post(
"/reporting/assets/delete/", data={"paths": ["invalid/path"]}
)
assert response.status_code == 400
def test_delete_suspicious_operation(self, authenticated_client):
with patch(
"ee.reporting.views.report_assets_fs.isdir", return_value=True
), patch(
"ee.reporting.views.shutil.rmtree",
side_effect=SuspiciousFileOperation("Invalid path"),
):
response = authenticated_client.post(
"/reporting/assets/delete/", data={"paths": ["../outside/path"]}
)
assert response.status_code == 400
def test_delete_asset_not_in_db_but_on_fs(self, authenticated_client):
with patch(
"ee.reporting.views.report_assets_fs.isdir", return_value=False
), patch("ee.reporting.views.report_assets_fs.delete") as mock_fs_delete:
response = authenticated_client.post(
"/reporting/assets/delete/", data={"paths": ["path/to/file.txt"]}
)
mock_fs_delete.assert_called_with("path/to/file.txt")
assert response.status_code == 200
def test_unauthenticated_delete_assets_view(self, unauthenticated_client):
response = unauthenticated_client.post("/reporting/assets/delete/")
assert response.status_code == status.HTTP_401_UNAUTHORIZED
@pytest.mark.django_db
class TestUploadAssets:
def test_upload_success(self, authenticated_client):
mock_file = SimpleUploadedFile("test123.txt", b"file_content")
with patch("ee.reporting.views.report_assets_fs.isdir", return_value=True):
response = authenticated_client.post(
"/reporting/assets/upload/",
data={"parentPath": "path", "test123.txt": mock_file},
)
assert response.data["test123.txt"]["filename"] == "path/test123.txt"
assert response.status_code == 200
assert ReportAsset.objects.filter(file="path/test123.txt").exists()
# cleanup file so future tests don't break
asset = ReportAsset.objects.get(file="path/test123.txt")
asset.file.delete()
asset.delete()
def test_upload_invalid_directory(self, authenticated_client):
with patch("ee.reporting.views.report_assets_fs.isdir", return_value=False):
response = authenticated_client.post(
"/reporting/assets/upload/",
data={"parentPath": "invalid_path"},
)
assert response.status_code == 400
def test_upload_suspicious_operation(self, authenticated_client):
mock_file = SimpleUploadedFile("test2.txt", b"file_content")
with patch("ee.reporting.views.report_assets_fs.isdir", return_value=True):
response = authenticated_client.post(
"/reporting/assets/upload/",
data={"parentPath": "../path", "test2.txt": mock_file},
)
assert response.status_code == 400
def test_unauthenticated_uploads_assets_view(self, unauthenticated_client):
response = unauthenticated_client.post("/reporting/assets/upload/")
assert response.status_code == status.HTTP_401_UNAUTHORIZED
@pytest.mark.django_db
class TestAssetDownload:
def test_download_file_success(self, authenticated_client):
m = mock_open()
with patch(
"ee.reporting.views.report_assets_fs.isdir", return_value=False
), patch("builtins.open", m), patch(
"ee.reporting.views.report_assets_fs.path", return_value="path/test.txt"
):
response = authenticated_client.get(
"/reporting/assets/download/", {"path": "test.txt"}
)
assert response.status_code == 200
def test_download_directory_success(self, authenticated_client):
m = mock_open()
with patch(
"ee.reporting.views.report_assets_fs.isdir", return_value=True
), patch(
"ee.reporting.views.shutil.make_archive", return_value="path/test.zip"
), patch(
"builtins.open", m
), patch(
"ee.reporting.views.report_assets_fs.path", return_value="path/test"
), patch(
"ee.reporting.views.os.remove", return_value=None
):
response = authenticated_client.get(
"/reporting/assets/download/", {"path": "test"}
)
assert response.status_code == 200
m.assert_called_once_with("path/test.zip", "rb")
def test_download_invalid_path(self, authenticated_client):
with patch(
"ee.reporting.views.report_assets_fs.isdir",
side_effect=OSError("Path does not exist"),
):
response = authenticated_client.get(
"/reporting/assets/download/", {"path": "invalid_path"}
)
assert response.status_code == 400
def test_download_os_error(self, authenticated_client):
with patch(
"ee.reporting.views.report_assets_fs.isdir",
side_effect=OSError("Download failed"),
):
response = authenticated_client.get(
"/reporting/assets/download/", {"path": "test.txt"}
)
assert response.status_code == 400
def test_download_suspicious_operation(self, authenticated_client):
with patch(
"ee.reporting.views.report_assets_fs.isdir",
side_effect=SuspiciousFileOperation("Suspicious path"),
):
response = authenticated_client.get(
"/reporting/assets/download/", {"path": "test.txt"}
)
assert response.status_code == 400
def test_unauthenticated_uploads_assets_view(self, unauthenticated_client):
response = unauthenticated_client.post("/reporting/assets/download/")
assert response.status_code == status.HTTP_401_UNAUTHORIZED
@pytest.mark.django_db
class TestAssetNginxRedirect:
@pytest.fixture
def asset(self):
return baker.make("reporting.ReportAsset", file="test_asset.txt")
def test_valid_uuid_and_path(self, unauthenticated_client, asset):
response = unauthenticated_client.get(
f"/reporting/assets/test_asset.txt?id={asset.id}"
)
assert response.status_code == 200
assert response["X-Accel-Redirect"] == "/assets/test_asset.txt"
def test_valid_uuid_wrong_path(self, unauthenticated_client, asset):
response = unauthenticated_client.get(
f"/reporting/assets/wrong_path.txt?id={asset.id}"
)
assert response.status_code == 403
def test_valid_uuid_no_asset(self, unauthenticated_client):
non_existent_uuid = uuid.uuid4()
url = f"/reporting/assets/test_asset.txt?id={non_existent_uuid}"
response = unauthenticated_client.get(url)
assert response.status_code == 404
def test_invalid_uuid(self, unauthenticated_client):
response = unauthenticated_client.get(
"/reporting/assets/test_asset.txt?id=invalid_uuid"
)
assert response.status_code == 400
assert "There was a error processing the request" in response.content.decode(
"utf-8"
)
def test_no_id(self, unauthenticated_client):
response = unauthenticated_client.get("/reporting/assets/test_asset.txt")
assert response.status_code == 400
assert "There was a error processing the request" in response.content.decode(
"utf-8"
)

View File

@@ -0,0 +1,368 @@
import pytest
from rest_framework.test import APIClient
from unittest.mock import patch
from model_bakery import baker
from rest_framework import status
from jinja2.exceptions import TemplateError
@pytest.fixture
def authenticated_client():
client = APIClient()
user = baker.make("accounts.User")
client.force_authenticate(user=user)
return client
@pytest.fixture
def unauthenticated_client():
return APIClient()
@pytest.fixture
def report_template():
return baker.make("reporting.ReportTemplate")
@pytest.mark.django_db
class TestReportTemplateViews:
def test_get_all_report_templates_empty_db(self, authenticated_client):
response = authenticated_client.get("/reporting/templates/")
assert response.status_code == status.HTTP_200_OK
assert len(response.data) == 0
def test_get_all_report_templates(self, authenticated_client):
# Create some sample ReportTemplates using model_bakery
baker.make("reporting.ReportTemplate", _quantity=5)
response = authenticated_client.get("/reporting/templates/")
assert response.status_code == status.HTTP_200_OK
assert len(response.data) == 5
def test_get_report_templates_with_filter(self, authenticated_client):
# Create templates with specific dependencies
baker.make("reporting.ReportTemplate", depends_on=["agent"], _quantity=3)
baker.make("reporting.ReportTemplate", depends_on=["client"], _quantity=2)
response = authenticated_client.get(
"/reporting/templates/", {"dependsOn[]": ["agent"]}
)
assert response.status_code == status.HTTP_200_OK
assert len(response.data) == 3
def test_post_report_template_valid_data(self, authenticated_client):
valid_data = {"name": "Test Template", "template_md": "Template Text"}
response = authenticated_client.post("/reporting/templates/", valid_data)
assert response.status_code == status.HTTP_200_OK
assert response.data["name"] == "Test Template"
def test_post_report_template_invalid_data(self, authenticated_client):
invalid_data = {"name": "Test Template"}
response = authenticated_client.post("/reporting/templates/", invalid_data)
assert response.status_code == status.HTTP_400_BAD_REQUEST
def test_get_report_template(self, authenticated_client, report_template):
url = f"/reporting/templates/{report_template.pk}/"
response = authenticated_client.get(url)
assert response.status_code == status.HTTP_200_OK
assert response.data["id"] == report_template.id
def test_edit_report_template(self, authenticated_client, report_template):
url = f"/reporting/templates/{report_template.pk}/"
updated_name = "Updated name"
response = authenticated_client.put(url, {"name": updated_name}, format="json")
assert response.status_code == status.HTTP_200_OK
report_template.refresh_from_db()
assert report_template.name == updated_name
def test_delete_report_template(self, authenticated_client, report_template):
url = f"/reporting/templates/{report_template.pk}/"
response = authenticated_client.delete(url)
assert response.status_code == status.HTTP_200_OK
with pytest.raises(report_template.DoesNotExist):
report_template.refresh_from_db()
# test unauthorized access
def test_unauthorized_get_report_templates_view(self, unauthenticated_client):
url = f"/reporting/templates/"
response = unauthenticated_client.get(url)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_unauthorized_add_report_template_view(self, unauthenticated_client):
url = f"/reporting/templates/"
response = unauthenticated_client.post(url)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_unauthorized_get_report_template_view(
self, unauthenticated_client, report_template
):
url = f"/reporting/templates/{report_template.pk}/"
response = unauthenticated_client.get(url)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_unauthorized_edit_report_template_view(
self, unauthenticated_client, report_template
):
url = f"/reporting/templates/{report_template.pk}/"
updated_name = "Updated name"
response = unauthenticated_client.put(
url, {"name": updated_name}, format="json"
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_unauthorized_delete_report_template_view(
self, unauthenticated_client, report_template
):
url = f"/reporting/templates/{report_template.pk}/"
response = unauthenticated_client.delete(url)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
@pytest.mark.django_db
class TestReportTemplateGenerateView:
def test_generate_html_report(self, authenticated_client, report_template):
data = {"format": "html", "dependencies": {}}
response = authenticated_client.post(
f"/reporting/templates/{report_template.id}/run/", data=data, format="json"
)
assert response.status_code == status.HTTP_200_OK
assert report_template.template_md in response.data
def test_generate_pdf_report(self, authenticated_client, report_template):
data = {"format": "pdf", "dependencies": {}}
response = authenticated_client.post(
f"/reporting/templates/{report_template.id}/run/", data=data, format="json"
)
assert response.status_code == status.HTTP_200_OK
assert response["content-type"] == "application/pdf"
def test_generate_invalid_format_report(
self, authenticated_client, report_template
):
data = {"format": "invalid_format", "dependencies": {}}
response = authenticated_client.post(
f"/reporting/templates/{report_template.id}/run/", data=data, format="json"
)
assert response.status_code == status.HTTP_400_BAD_REQUEST
def test_generate_report_template_error(self, authenticated_client):
template = baker.make("reporting.ReportTemplate", template_md="{{invalid}")
data = {"format": "html", "dependencies": {}}
response = authenticated_client.post(
f"/reporting/templates/{template.id}/run/", data=data, format="json"
)
assert response.status_code == status.HTTP_400_BAD_REQUEST
def test_generate_report_with_dependencies(
self, authenticated_client, report_template
):
sample_html = "<html><body>Sample Report</body></html>"
data = {"format": "html", "dependencies": {"client": 1}}
with patch(
"ee.reporting.views.generate_html", return_value=(sample_html, None)
) as mock_generate_html:
url = f"/reporting/templates/{report_template.id}/run/"
response = authenticated_client.post(url, data, format="json")
assert response.status_code == status.HTTP_200_OK
assert response.data == sample_html
mock_generate_html.assert_called_with(
template=report_template.template_md,
template_type=report_template.type,
css=report_template.template_css if report_template.template_css else "",
html_template=report_template.template_html.id
if report_template.template_html
else None,
variables=report_template.template_variables,
dependencies={"client": 1},
)
def test_unauthenticated_generate_report_view(
self, unauthenticated_client, report_template
):
response = unauthenticated_client.post(
f"/reporting/templates/{report_template.id}/run/"
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
@pytest.mark.django_db
class TestGenerateReportPreview:
def test_generate_report_preview_html_format(self, authenticated_client):
data = {
"template_md": "some template md",
"type": "some type",
"template_css": "some css",
"template_variables": {},
"dependencies": {},
"format": "html",
"debug": False,
}
with patch(
"ee.reporting.views.generate_html", return_value=("<html></html>", {})
) as mock_generate_html:
response = authenticated_client.post(
"/reporting/templates/preview/", data, format="json"
)
assert response.status_code == status.HTTP_200_OK
assert response.data == "<html></html>"
mock_generate_html.assert_called()
@patch("ee.reporting.views.generate_html", return_value=("<html></html>", {}))
@patch("ee.reporting.views.generate_pdf", return_value=b"some_pdf_bytes")
def test_generate_report_preview_pdf_format(
self, mock_generate_html, mock_generate_pdf, authenticated_client
):
data = {
"template_md": "some template md",
"type": "some type",
"template_css": "some css",
"template_variables": {},
"dependencies": {},
"format": "pdf",
"debug": False,
}
response = authenticated_client.post(
"/reporting/templates/preview/", data, format="json"
)
assert response.status_code == status.HTTP_200_OK
assert response["Content-Type"] == "application/pdf"
mock_generate_html.assert_called()
mock_generate_pdf.assert_called()
def test_generate_report_preview_debug(self, authenticated_client):
data = {
"template_md": "some template md",
"type": "markdown",
"template_css": "some css",
"template_variables": {},
"dependencies": {},
"format": "html",
"debug": True,
}
with patch(
"ee.reporting.views.generate_html",
return_value=("<html></html>", {"agent": baker.prepare("agents.Agent")}),
) as mock_generate_html:
response = authenticated_client.post(
"/reporting/templates/preview/", data, format="json"
)
assert response.status_code == status.HTTP_200_OK
assert "template" in response.data
assert "variables" in response.data
mock_generate_html.assert_called()
def test_generate_report_preview_invalid_data(self, authenticated_client):
data = {
"template_md": "some template md",
# Missing 'type'
"template_css": "some css",
"template_variables": {},
"dependencies": {},
"format": "invalid_format",
"debug": True,
}
response = authenticated_client.post(
"/reporting/templates/preview/", data, format="json"
)
assert response.status_code == status.HTTP_400_BAD_REQUEST
def test_generate_report_preview_template_error(self, authenticated_client):
data = {
"template_md": "some template md",
"type": "some type",
"template_css": "some css",
"template_variables": {},
"dependencies": {},
"format": "html",
"debug": True,
}
with patch(
"ee.reporting.views.generate_html",
side_effect=TemplateError("Some template error"),
):
response = authenticated_client.post(
"/reporting/templates/preview/", data, format="json"
)
assert response.status_code == status.HTTP_400_BAD_REQUEST
assert "Some template error" in response.data
def test_unauthenticated_generate_report_preview_view(self, unauthenticated_client):
response = unauthenticated_client.post("/reporting/templates/preview/")
assert response.status_code == status.HTTP_401_UNAUTHORIZED
@pytest.mark.django_db
class TestGetAllowedValues:
def test_valid_input(self, authenticated_client):
data = {
"variables": {
"user": {
"name": "Alice",
"roles": ["admin", "user"],
},
"count": 1,
},
"dependencies": {},
}
expected_response_data = {
"user": "Object",
"user.name": "str",
"user.roles": "Array (2 Results)",
"user.roles[0]": "str",
"count": "int",
}
with patch(
"ee.reporting.views.prep_variables_for_template",
return_value=data["variables"],
):
response = authenticated_client.post(
"/reporting/templates/preview/analysis/", data, format="json"
)
assert response.status_code == 200
assert response.data == expected_response_data
def test_empty_variables(self, authenticated_client):
data = {"variables": {}, "dependencies": {}}
with patch(
"ee.reporting.views.prep_variables_for_template",
return_value=data["variables"],
):
response = authenticated_client.post(
"/reporting/templates/preview/analysis/", data, format="json"
)
assert response.status_code == 200
assert response.data == None
def test_invalid_input(self, authenticated_client):
data = {"invalidKey": {}}
response = authenticated_client.post(
"/reporting/templates/preview/analysis/", data, format="json"
)
assert response.status_code == 400
def test_unauthenticated_variable_analysis_view(self, unauthenticated_client):
response = unauthenticated_client.post("/reporting/templates/preview/analysis/")
assert response.status_code == status.HTTP_401_UNAUTHORIZED

View File

@@ -257,4 +257,12 @@ def test_move_folder_with_name_conflict(tmp_path: Path) -> None:
def test_move_directory_traversal(tmp_path: Path) -> None:
assert False
storage = ReportAssetStorage(location=tmp_path)
with pytest.raises(SuspiciousFileOperation):
# relative
storage.move(source="../../file", destination="../..")
with pytest.raises(SuspiciousFileOperation):
# absolute
storage.move(source="/etc", destination="/newpath")

View File

@@ -0,0 +1,107 @@
import pytest
from model_bakery import baker
from ..utils import generate_html, db_template_loader
@pytest.mark.django_db
class TestGenerateHTML:
@pytest.fixture
def base_template(self):
html = "<html>{% block content %}{% endblock %}</html>"
return baker.make(
"reporting.ReportHTMLTemplate", name="Base Template", html=html
)
def test_markdown_conversion(self):
template = "# This is a header"
result, _ = generate_html(
template=template, template_type="markdown", css="", variables=""
)
assert "<h1>This is a header</h1>" in result
def test_html_unchanged(self):
template = "<h1>This is a header</h1>"
result, _ = generate_html(
template=template, template_type="html", css="", variables=""
)
assert "<h1>This is a header</h1>" in result
def test_html_template_exists(self, base_template):
template = "{% block content %}<h1>This is a header</h1>{% endblock %}"
result, _ = generate_html(
template=template,
template_type="html",
html_template=base_template.id,
css="",
variables="",
)
assert "<html><h1>This is a header</h1></html>" == result
def test_html_template_does_not_exist(self):
template = "<h1>This is a header</h1>"
# check it doesn't raise an error.
generate_html(
template=template,
template_type="html",
html_template=999,
css="",
variables="",
)
def test_variables_processing(self):
template = "Hello {{ name }}"
variables = "name: John"
result, _ = generate_html(
template=template, template_type="html", css="", variables=variables
)
assert "Hello John" in result
def test_css_incorporation(self):
template = "<html><head><style>{{css}}</style></head></html>"
css = ".my-class { color: red; }"
result, _ = generate_html(
template=template, template_type="html", css=css, variables=""
)
assert css in result
@pytest.mark.django_db
class TestJinjaDBLoader:
@pytest.fixture
def report_base_template(self):
return baker.make(
"reporting.ReportHTMLTemplate", name="test_html_template", html="Test HTML"
)
@pytest.fixture
def report_template(self):
return baker.make(
"reporting.ReportTemplate", name="test_md_template", template_md="Test MD"
)
def test_load_base_template(self, report_base_template):
result = db_template_loader(report_base_template.name)
assert result == "Test HTML"
def test_fallback_to_md_template(self, report_template):
result = db_template_loader(report_template.name)
assert result == "Test MD"
def test_no_template_found(self):
# Will return None
result = db_template_loader("nonexistent_template")
assert result is None
def test_html_template_priority(self):
# Create both a ReportHTMLTemplate and a ReportTemplate with the same name
template_name = "common_template"
baker.make("reporting.ReportHTMLTemplate", name=template_name, html="Test HTML")
baker.make(
"reporting.ReportTemplate", name=template_name, template_md="Test MD"
)
result = db_template_loader(template_name)
assert result == "Test HTML" # HTML has priority

View File

@@ -0,0 +1,318 @@
import pytest
from unittest.mock import patch
from model_bakery import baker
from ..utils import (
process_dependencies,
process_data_sources,
process_chart_variables,
prep_variables_for_template,
)
@pytest.mark.django_db
class TestProcessingDependencies:
# Fixtures for creating model instances
@pytest.fixture
def test_client(db):
return baker.make("clients.Client", name="Test Client Name")
@pytest.fixture
def test_site(db):
return baker.make("clients.Site", name="Test Site Name")
@pytest.fixture
def test_agent(db):
return baker.make(
"agents.Agent", agent_id="Test Agent ID", hostname="Test Agent Name"
)
@pytest.fixture
def test_global(db):
return baker.make(
"core.GlobalKVStore", name="some_global_value", value="Some Global Value"
)
def test_replace_with_client_db_value(self, test_client):
variables = """
some_field:
client_field: {{client.name}}
"""
dependencies = {"client": test_client.id}
result = process_dependencies(variables=variables, dependencies=dependencies)
assert result["some_field"]["client_field"] == test_client.name
def test_replace_with_site_db_value(self, test_site):
variables = """
some_field:
site_field: {{site.name}}
"""
dependencies = {"site": test_site.id}
result = process_dependencies(variables=variables, dependencies=dependencies)
assert result["some_field"]["site_field"] == test_site.name
def test_replace_with_agent_db_value(self, test_agent):
variables = """
some_field:
agent_field: {{agent.hostname}}
"""
dependencies = {"agent": test_agent.agent_id}
result = process_dependencies(variables=variables, dependencies=dependencies)
assert result["some_field"]["agent_field"] == test_agent.hostname
def test_replace_with_global_value(self, test_global):
variables = """
some_field:
global_field: {{global.some_global_value}}
"""
result = process_dependencies(variables=variables, dependencies={})
# Assuming you have a global value with key 'some_global_value' set to 'Some Global Value'
assert result["some_field"]["global_field"] == test_global.value
def test_replace_with_non_db_dependencies(self):
variables = """
some_field:
dependency_field: {{some_dependency}}
"""
dependencies = {"some_dependency": "Some Value"}
result = process_dependencies(variables=variables, dependencies=dependencies)
assert result["some_field"]["dependency_field"] == "Some Value"
def test_missing_non_db_dependencies(self):
variables = """
some_field:
missing_dependency: "{{missing_dependency}}"
"""
dependencies = {} # Empty dependencies, simulating a missing dependency
result = process_dependencies(variables=variables, dependencies=dependencies)
assert result["some_field"]["missing_dependency"] == "{{missing_dependency}}"
def test_variables_dependencies_merge(self, test_agent):
variables = """
some_field:
agent_field: {{agent.agent_id}}
dependency_field: {{some_dependency}}
"""
dependencies = {"agent": test_agent.agent_id, "some_dependency": "Some Value"}
result = process_dependencies(variables=variables, dependencies=dependencies)
assert result["some_field"]["agent_field"] == "Test Agent ID"
assert result["some_field"]["dependency_field"] == "Some Value"
# Additionally, assert the merged structure has both processed variables and dependencies
assert "agent" in result
assert isinstance(result["agent"], type(test_agent))
assert "some_dependency" in result
assert result["some_dependency"] == "Some Value"
def test_multiple_replacements(self, test_agent, test_client, test_site):
variables = """
fields:
agent_name: {{agent.hostname}}
client_name: {{client.name}}
site_name: {{site.name}}
dependency_1: {{dep_1}}
dependency_2: {{dep_2}}
"""
dependencies = {
"agent": test_agent.agent_id,
"client": test_client.id,
"site": test_site.id,
"dep_1": "Dependency Value 1",
"dep_2": "Dependency Value 2",
}
result = process_dependencies(variables=variables, dependencies=dependencies)
assert result["fields"]["agent_name"] == test_agent.hostname
assert result["fields"]["client_name"] == test_client.name
assert result["fields"]["site_name"] == test_site.name
assert result["fields"]["dependency_1"] == "Dependency Value 1"
assert result["fields"]["dependency_2"] == "Dependency Value 2"
# Also verify that the non-replaced fields from dependencies are present in the result.
assert "agent" in result
assert "client" in result
assert "site" in result
assert result["dep_1"] == "Dependency Value 1"
assert result["dep_2"] == "Dependency Value 2"
@pytest.mark.django_db
class TestProcessDataSourceVariables:
def test_process_data_sources_without_data_sources(self):
variables = {"other_key": "some_value"}
result = process_data_sources(variables=variables)
assert result == variables
def test_process_data_sources_with_non_dict_data_sources(self):
variables = {"data_sources": "some_string_value"}
result = process_data_sources(variables=variables)
assert result == variables
def test_process_data_sources_with_dict_data_sources(self):
variables = {
"data_sources": {
"source1": {"model": "agent", "other_field": "value"},
"source2": "some_string_value",
}
}
mock_queryset = {"data": "sample_data"}
# Mock build_queryset to return the mock_queryset
with patch("ee.reporting.utils.build_queryset", return_value=mock_queryset):
result = process_data_sources(variables=variables)
# Assert that the data_sources for "source1" is replaced with mock_queryset
assert result["data_sources"]["source1"] == mock_queryset
# Assert that the "source2" data remains unchanged
assert result["data_sources"]["source2"] == "some_string_value"
def test_process_data_sources_with_non_dict_data_sources(self):
variables = {
"data_sources": {
"source1": {"model": "agent", "other_field": "value"},
"source2": "some_string_value",
}
}
mock_queryset = 5
# Mock build_queryset to return the mock_queryset
with patch("ee.reporting.utils.build_queryset", return_value=mock_queryset):
result = process_data_sources(variables=variables)
# Assert that the data_sources for "source1" is replaced with mock_queryset
assert result["data_sources"]["source1"] == mock_queryset
# Assert that the "source2" data remains unchanged
assert result["data_sources"]["source2"] == "some_string_value"
class TestProcessChartVariables:
def test_process_chart_no_replace_data_frame(self):
# Scenario where path doesn't exist in variables
variables = {
"charts": {
"chart1": {
"chartType": "type1",
"outputType": "html",
"options": {"data_frame": "path.to.nonexistent"},
}
}
}
assert process_chart_variables(variables=variables) == variables
def test_process_chart_generate_chart_invocation(self):
# Ensure generate_chart is invoked with expected parameters
variables = {
"charts": {
"chart1": {
"chartType": "type1",
"outputType": "html",
"options": {},
"traces": "Some Traces",
"layout": "Some Layout",
}
}
}
with patch("ee.reporting.utils.generate_chart") as mock_generate_chart:
mock_generate_chart.return_value = "<html>Chart Here</html>"
_ = process_chart_variables(variables=variables)
mock_generate_chart.assert_called_once_with(
type="type1",
format="html",
options={},
traces="Some Traces",
layout="Some Layout",
)
def test_process_chart_missing_keys(self):
# Scenario where necessary keys are missing
variables = {"charts": {"chart1": {}}}
result = process_chart_variables(variables=variables)
assert result == variables # Expecting unchanged variables
def test_process_chart_no_charts(self):
# Scenario with no charts key or charts not a dict
variables1 = {}
variables2 = {"charts": "Not a dict"}
assert process_chart_variables(variables=variables1) == variables1
assert process_chart_variables(variables=variables2) == variables2
def test_process_chart_replaces_data_frame(self):
# Sample input
variables = {
"charts": {
"myChart": {
"chartType": "bar",
"outputType": "html",
"options": {"data_frame": "data_sources.sample_data"},
}
},
"data_sources": {"sample_data": [{"x": 1, "y": 2}, {"x": 2, "y": 3}]},
}
with patch("ee.reporting.utils.generate_chart") as mock_generate_chart:
mock_generate_chart.return_value = "<html>Chart Here</html>"
result = process_chart_variables(variables=variables)
# Check if the generate_chart function was called correctly
mock_generate_chart.assert_called_once_with(
type="bar",
format="html",
options={"data_frame": [{"x": 1, "y": 2}, {"x": 2, "y": 3}]},
traces=None,
layout=None,
)
# Check if the returned data has the chart in place
assert "<html>Chart Here</html>" in result["charts"]["myChart"]
class TestPrepVariablesFunction:
def test_prep_variables_base(self):
result = prep_variables_for_template(variables="")
assert isinstance(result, dict)
assert not result
def test_prep_variables_with_dataqueries(self):
with patch(
"ee.reporting.utils.make_dataqueries_inline", return_value="test_yaml: true"
):
result = prep_variables_for_template(variables="dataquery: test")
assert result == {"test_yaml": True}
def test_prep_variables_with_dependencies(self):
with patch(
"ee.reporting.utils.process_dependencies",
return_value={"dependency_key": "dependency_value"},
):
result = prep_variables_for_template(
variables="", dependencies={"some_dependency": "value"}
)
assert "dependency_key" in result
assert result["dependency_key"] == "dependency_value"
def test_prep_variables_with_data_sources(self):
with patch(
"ee.reporting.utils.process_data_sources",
return_value={"data_source_key": "data_value"},
):
result = prep_variables_for_template(variables="data_sources: some_data")
assert "data_source_key" in result
assert result["data_source_key"] == "data_value"
def test_prep_variables_with_charts(self):
with patch(
"ee.reporting.utils.process_chart_variables",
return_value={"chart_key": "chart_value"},
):
result = prep_variables_for_template(variables="charts: some_chart")
assert "chart_key" in result
assert result["chart_key"] == "chart_value"

View File

@@ -0,0 +1,108 @@
import pytest
from model_bakery import baker
from ..utils import (
normalize_asset_url,
base64_encode_assets,
decode_base64_asset,
)
import base64
@pytest.mark.django_db
class TestBase64EncodeDecodeAssets:
def test_base64_encode_assets_multiple_valid_urls(self):
asset1 = baker.make("reporting.ReportAsset", _create_files=True)
asset2 = baker.make("reporting.ReportAsset", _create_files=True)
template = f"Some content with link asset://{asset1.id} and another link asset://{asset2.id}"
result = base64_encode_assets(template)
assert len(result) == 2
# check asset 1
assert result[0]["id"] == asset1.id
# Checking if the file content is correctly encoded to base64
encoded_content = base64.b64encode(asset1.file.file.read()).decode("utf-8")
assert result[0]["file"] == encoded_content
# check asset 2
assert result[1]["id"] == asset2.id
# Checking if the file content is correctly encoded to base64
encoded_content = base64.b64encode(asset2.file.file.read()).decode("utf-8")
assert result[1]["file"] == encoded_content
def test_base64_encode_assets_some_invalid_urls(self):
asset = baker.make("reporting.ReportAsset", _create_files=True)
invalid_id = "11111111-1111-1111-1111-111111111111"
template = f"Some content with link asset://{asset.id} and invalid link asset://{invalid_id}"
result = base64_encode_assets(template)
assert len(result) == 1
assert result[0]["id"] == asset.id
def test_base64_encode_assets_no_urls(self):
template = "Some content with no assets"
result = base64_encode_assets(template)
assert result == []
def test_base64_encode_assets_duplicate_urls(self):
asset = baker.make("reporting.ReportAsset", _create_files=True)
template = f"Some content with link asset://{asset.id} and another link asset://{asset.id}"
result = base64_encode_assets(template)
assert len(result) == 1
assert result[0]["id"] == asset.id
def test_decode_base64_asset_valid_input(self):
original_data = b"Hello, world!"
encoded_data = base64.b64encode(original_data).decode("utf-8")
result = decode_base64_asset(encoded_data)
assert result == original_data
def test_decode_base64_asset_invalid_input(self):
invalid_data = "Not a base64 encoded string."
with pytest.raises(base64.binascii.Error):
decode_base64_asset(invalid_data)
@pytest.mark.django_db
class TestNormalizeAssets:
def test_normalize_asset_url_valid_html_type(self):
asset = baker.make("reporting.ReportAsset", _create_files=True)
text = f"Some content with link asset://{asset.id} and more content"
result = normalize_asset_url(text, "html")
assert f"{asset.file.url}?id={asset.id}" in result
assert f"asset://{asset.id}" not in result
def test_normalize_asset_url_valid_pdf_type(self):
asset = baker.make("reporting.ReportAsset", _create_files=True)
text = f"Some content with link asset://{asset.id} and more content"
result = normalize_asset_url(text, "pdf")
assert f"file://{asset.file.path}" in result
assert f"asset://{asset.id}" not in result
def test_normalize_asset_url_invalid_asset(self):
invalid_id = "11111111-1111-1111-1111-111111111111" # UUID that's not in the DB
text = f"Some content with link asset://{invalid_id} and more content"
result = normalize_asset_url(text, "html")
# Since the asset doesn't exist, the URL should remain unchanged
assert f"asset://{invalid_id}" in result
def test_normalize_asset_url_no_asset(self):
text = "Some content with no assets"
result = normalize_asset_url(text, "html")
# The text remains unchanged
assert text == result

View File

@@ -18,7 +18,7 @@ from weasyprint.text.fonts import FontConfiguration
from .constants import REPORTING_MODELS
from .markdown.config import Markdown
from .models import ReportAsset, ReportHTMLTemplate, ReportTemplate
from .models import ReportAsset, ReportHTMLTemplate, ReportTemplate, ReportDataQuery
# regex for db data replacement
@@ -30,6 +30,8 @@ RE_ASSET_URL = re.compile(
r"(asset://([0-9a-f]{8}-[0-9a-f]{4}-[0-5][0-9a-f]{3}-[089ab][0-9a-f]{3}-[0-9a-f]{12}))"
)
RE_DEPENDENCY_VALUE = re.compile(r"(\{\{\s*(.*)\s*\}\})")
# this will lookup the Jinja parent template in the DB
# Example: {% extends "MASTER_TEMPLATE_NAME or REPORT_TEMPLATE_NAME" %}
@@ -43,7 +45,7 @@ def db_template_loader(template_name: str) -> Optional[str]:
try:
template = ReportTemplate.objects.get(name=template_name)
return template.template_md
except ReportHTMLTemplate.DoesNotExist:
except ReportTemplate.DoesNotExist:
pass
return None
@@ -75,12 +77,15 @@ def generate_html(
css: str = "",
html_template: Optional[int] = None,
variables: str = "",
dependencies: Dict[str, int] = {},
) -> Tuple[str, Optional[Dict[str, Any]]]:
# validate the template before doing anything. This will throw a TemplateError exception
dependencies: Optional[Dict[str, int]] = None,
) -> Tuple[str, Dict[str, Any]]:
if dependencies is None:
dependencies = {}
# validate the template
env.parse(template)
# convert template from markdown to html if type is markdown
# convert template
template_string = (
Markdown.convert(template) if template_type == "markdown" else template
)
@@ -89,7 +94,6 @@ def generate_html(
if html_template:
try:
html_template_name = ReportHTMLTemplate.objects.get(pk=html_template).name
template_string = (
f"""{{% extends "{html_template_name}" %}}\n{template_string}"""
)
@@ -98,28 +102,26 @@ def generate_html(
tm = env.from_string(template_string)
variables = prep_variables_for_template(
variables_dict = prep_variables_for_template(
variables=variables, dependencies=dependencies
)
if variables:
return (tm.render(css=css, **variables), variables)
else:
return (tm.render(css=css), None)
return (tm.render(css=css, **variables_dict), variables_dict)
def make_dataqueries_inline(*, variables: str) -> str:
variables_obj = yaml.safe_load(variables) or {}
if "data_sources" in variables_obj.keys() and isinstance(
variables_obj["data_sources"], dict
):
for key, value in variables_obj["data_sources"].items():
# data_source is referencing a saved data query
try:
variables_obj = yaml.safe_load(variables) or {}
except (yaml.parser.ParserError, yaml.YAMLError):
variables_obj = {}
data_sources = variables_obj.get("data_sources", {})
if isinstance(data_sources, dict):
for key, value in data_sources.items():
if isinstance(value, str):
ReportDataQuery = apps.get_model("reporting", "ReportDataQuery")
try:
variables_obj["data_sources"][key] = ReportDataQuery.objects.get(
name=value
).json_query
query = ReportDataQuery.objects.get(name=value).json_query
variables_obj["data_sources"][key] = query
except ReportDataQuery.DoesNotExist:
continue
@@ -129,101 +131,33 @@ def make_dataqueries_inline(*, variables: str) -> str:
def prep_variables_for_template(
*,
variables: str,
dependencies: Dict[str, Any] = {},
dependencies: Optional[Dict[str, Any]] = None,
limit_query_results: Optional[int] = None,
) -> Dict[str, Any]:
if not dependencies:
dependencies = {}
if not variables:
variables = ""
# replace any data queries in data_sources with the yaml
variables = make_dataqueries_inline(variables=variables)
# resolve dependencies that are agent, site, or client
if "client" in dependencies.keys():
Model = apps.get_model("clients", "Client")
dependencies["client"] = Model.objects.get(id=dependencies["client"])
elif "site" in dependencies.keys():
Model = apps.get_model("clients", "Site")
dependencies["site"] = Model.objects.get(id=dependencies["site"])
elif "agent" in dependencies.keys():
Model = apps.get_model("agents", "Agent")
dependencies["agent"] = Model.objects.get(agent_id=dependencies["agent"])
# check for variables that need to be replaced with the database values ({{client.name}}, {{agent.hostname}}, etc)
if variables and isinstance(variables, str):
# returns {{ model.prop }}, prop, model
for string, model, prop in re.findall(RE_DB_VALUE, variables):
value: Any = ""
# will be agent, site, client, or global
if model == "global":
value = get_db_value(string=f"{model}.{prop}")
elif model in dependencies.keys():
instance = dependencies[model]
value = (
get_db_value(string=prop, instance=instance) if instance else None
)
if value:
variables = variables.replace(string, str(value))
# check for any non-database dependencies and replace in variables
if variables and isinstance(variables, str):
RE_DEP_VALUE = re.compile(r"(\{\{\s*(.*)\s*\}\})")
for string, dep in re.findall(RE_DEP_VALUE, variables):
if dep in dependencies.keys():
variables = variables.replace(string, str(dependencies[dep]))
# load yaml variables if they exist
variables = yaml.safe_load(variables) or {}
# process report dependencies
variables_dict = process_dependencies(
variables=variables, dependencies=dependencies
)
# replace the data_sources with the actual data from DB. This will be passed to the template
# in the form of {{data_sources.data_source_name}}
if "data_sources" in variables.keys() and isinstance(
variables["data_sources"], dict
):
for key, value in variables["data_sources"].items():
if isinstance(value, dict):
data_source = value
_ = data_source.pop("meta") if "meta" in data_source.keys() else None
modified_datasource = resolve_model(data_source=data_source)
queryset = build_queryset(
data_source=modified_datasource, limit=limit_query_results
)
variables["data_sources"][key] = queryset
variables_dict = process_data_sources(
variables=variables_dict, limit_query_results=limit_query_results
)
# generate and replace charts in the variables
if "charts" in variables.keys() and isinstance(variables["charts"], dict):
for key, chart in variables["charts"].items():
# make sure chart options are present and a dict
if "options" not in chart.keys() and not isinstance(chart["options"], dict):
break
variables_dict = process_chart_variables(variables=variables_dict)
options = chart["options"]
# if data_frame is present and a str that means we need to replace it with a value from variables
if "data_frame" in options.keys() and isinstance(
options["data_frame"], str
):
# dot dotation to datasource if exists
data_source = options["data_frame"].split(".")
data = variables
for path in data_source:
if path in data.keys():
data = data[path]
else:
break
if data:
chart["options"]["data_frame"] = data
variables["charts"][key] = generate_chart(
type=chart["chartType"],
format=chart["outputType"],
options=chart["options"],
traces=chart["traces"] if "traces" in chart.keys() else None,
layout=chart["layout"] if "layout" in chart.keys() else None,
)
return {**variables, **dependencies}
return variables_dict
class ResolveModelException(Exception):
@@ -250,31 +184,33 @@ def resolve_model(*, data_source: Dict[str, Any]) -> Dict[str, Any]:
raise ResolveModelException("Model key must be present on data_source")
ALLOWED_OPERATIONS = (
from enum import Enum
class AllowedOperations(Enum):
# filtering
"only",
"defer",
"filter",
"exclude",
"limit",
"get",
"first",
"all",
ONLY = "only"
DEFER = "defer"
FILTER = "filter"
EXCLUDE = "exclude"
LIMIT = "limit"
GET = "get"
FIRST = "first"
ALL = "all"
# custom fields
"custom_fields",
CUSTOM_FIELDS = "custom_fields"
# presentation
"json",
JSON = "json"
# relations
"select_related",
"prefetch_related",
SELECT_RELATED = "select_related"
PREFETCH_RELATED = "prefetch_related"
# operations
"aggregate",
"annotate",
"count",
"values",
AGGREGATE = "aggregate"
ANNOTATE = "annotate"
COUNT = "count"
VALUES = "values"
# ordering
"order_by",
)
ORDER_BY = "order_by"
class InvalidDBOperationException(Exception):
@@ -284,22 +220,23 @@ class InvalidDBOperationException(Exception):
def build_queryset(*, data_source: Dict[str, Any], limit: Optional[int] = None) -> Any:
local_data_source = data_source
Model = local_data_source.pop("model")
limit = limit
count = False
get = False
first = False
all = False
isJson = False
columns = local_data_source["only"] if "only" in local_data_source.keys() else None
defer = local_data_source.get("defer", None)
columns = local_data_source.get("only", None)
fields_to_add = []
# create a base reporting queryset
queryset = Model.objects.using("reporting")
model_name = Model.__name__.lower()
for operation, values in local_data_source.items():
if operation not in ALLOWED_OPERATIONS:
# Usage in the build_queryset function:
if operation not in [op.value for op in AllowedOperations]:
raise InvalidDBOperationException(
f"DB operation: {operation} not allowed. Supported operations: only, defer, filter, get, first, all, custom_fields, exclude, limit, select_related, prefetch_related, annotate, aggregate, order_by, count"
f"DB operation: {operation} not allowed. Supported operations: {', '.join(op.value for op in AllowedOperations)}"
)
if operation == "meta":
@@ -308,10 +245,9 @@ def build_queryset(*, data_source: Dict[str, Any], limit: Optional[int] = None)
from core.models import CustomField
if model_name in ["client", "site", "agent"]:
fields = CustomField.objects.filter(model=model_name)
fields_to_add = [
field
for field in values
if CustomField.objects.filter(model=model_name, name=field).exists()
field for field in values if fields.filter(name=field).exists()
]
elif operation == "limit":
@@ -319,6 +255,9 @@ def build_queryset(*, data_source: Dict[str, Any], limit: Optional[int] = None)
elif operation == "count":
count = True
elif operation == "get":
# need to add a filter for the get if present
if isinstance(values, dict):
queryset = queryset.filter(**values)
get = True
elif operation == "first":
first = True
@@ -343,9 +282,20 @@ def build_queryset(*, data_source: Dict[str, Any], limit: Optional[int] = None)
queryset = queryset[:limit]
if columns:
# remove columns from only if defer is also present
if defer:
columns = [column for column in columns if column not in defer]
if "id" not in columns:
columns.append("id")
queryset = queryset.values(*columns)
elif defer:
# Since values seems to ignore only and defer, we need to get all columns and remove the defered ones.
# Then we can pass the rest of the columns in
included_fields = [
field.name for field in Model._meta.local_fields if field.name not in defer
]
queryset = queryset.values(*included_fields)
else:
queryset = queryset.values()
@@ -386,7 +336,7 @@ def add_custom_fields(
fields_to_add: List[str],
model_name: Literal["client", "site", "agent"],
dict_value: bool = False,
):
) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
from core.models import CustomField
from agents.models import AgentCustomField
from clients.models import ClientCustomField, SiteCustomField
@@ -454,6 +404,8 @@ def add_custom_fields(
else:
row["custom_fields"][custom_field.name] = custom_field.default_value
return data
def normalize_asset_url(text: str, type: Literal["pdf", "html"]) -> str:
RE_ASSET_URL = re.compile(
@@ -496,7 +448,9 @@ def base64_encode_assets(template: str) -> List[Dict[str, Any]]:
"file": encoded_base64_str,
}
)
added_ids.append(asset.id)
added_ids.append(
str(asset.id)
) # need to convert uuid to str for easy comparison
except ReportAsset.DoesNotExist:
continue
@@ -509,6 +463,106 @@ def decode_base64_asset(asset: str) -> bytes:
return base64.b64decode(asset.encode("utf-8"))
def process_data_sources(
*, variables: Dict[str, Any], limit_query_results: Optional[int] = None
) -> Dict[str, Any]:
data_sources = variables.get("data_sources")
if isinstance(data_sources, dict):
for key, value in data_sources.items():
if isinstance(value, dict):
modified_datasource = resolve_model(data_source=value)
queryset = build_queryset(
data_source=modified_datasource, limit=limit_query_results
)
data_sources[key] = queryset
return variables
def process_dependencies(
*, variables: str, dependencies: Dict[str, Any]
) -> Dict[str, Any]:
DEPENDENCY_MODELS = {
"client": ("clients", "Client"),
"site": ("clients", "Site"),
"agent": ("agents", "Agent"),
}
# Resolve dependencies that are agent, site, or client
for dep, (app_label, model_name) in DEPENDENCY_MODELS.items():
if dep in dependencies:
Model = apps.get_model(app_label, model_name)
# Assumes each model has a unique lookup mechanism
lookup_param = "agent_id" if dep == "agent" else "id"
dependencies[dep] = Model.objects.get(**{lookup_param: dependencies[dep]})
# Handle database value placeholders
for string, model, prop in re.findall(RE_DB_VALUE, variables):
value = get_value_for_model(model, prop, dependencies)
if value:
variables = variables.replace(string, str(value))
# Handle non-database dependencies
for string, dep in re.findall(RE_DEPENDENCY_VALUE, variables):
value = dependencies.get(dep)
if value:
variables = variables.replace(string, str(value))
# Load yaml variables if they exist
variables = yaml.safe_load(variables) or {}
return {**variables, **dependencies}
def get_value_for_model(model: str, prop: str, dependencies: Dict[str, Any]) -> Any:
if model == "global":
return get_db_value(string=f"{model}.{prop}")
instance = dependencies.get(model)
return get_db_value(string=prop, instance=instance) if instance else None
def process_chart_variables(*, variables: Dict[str, Any]) -> Dict[str, Any]:
charts = variables.get("charts")
if not isinstance(charts, dict):
return variables
for key, chart in charts.items():
options = chart.get("options")
if not isinstance(options, dict):
continue
data_frame = options.get("data_frame")
if isinstance(data_frame, str):
data_source = data_frame.split(".")
data = variables
path_exists = True
for path in data_source:
data = data.get(path)
if data is None:
path_exists = False
break
if not path_exists:
continue
options["data_frame"] = data
traces = chart.get("traces")
layout = chart.get("layout")
charts[key] = generate_chart(
type=chart["chartType"],
format=chart["outputType"],
options=options,
traces=traces,
layout=layout,
)
return variables
def generate_chart(
*,
type: Literal["pie", "bar", "line"],

View File

@@ -8,7 +8,7 @@ import json
import os
import shutil
import uuid
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Literal, Optional, Union
from django.conf import settings as djangosettings
from django.core.exceptions import (
@@ -16,6 +16,7 @@ from django.core.exceptions import (
PermissionDenied,
SuspiciousFileOperation,
)
from django.db import transaction
from django.core.files.base import ContentFile
from django.http import FileResponse, HttpResponse, JsonResponse
from django.shortcuts import get_object_or_404
@@ -26,6 +27,10 @@ from rest_framework.response import Response
from rest_framework.serializers import (
CharField,
ListField,
IntegerField,
JSONField,
ChoiceField,
BooleanField,
ModelSerializer,
Serializer,
ValidationError,
@@ -35,7 +40,6 @@ from rest_framework.views import APIView
from tacticalrmm.utils import notify_error
from .models import ReportAsset, ReportDataQuery, ReportHTMLTemplate, ReportTemplate
from .settings import settings
from .storage import report_assets_fs
from .utils import (
base64_encode_assets,
@@ -112,6 +116,9 @@ class GenerateReport(APIView):
format = request.data["format"]
if format not in ["pdf", "html"]:
return notify_error("Report format is incorrect.")
try:
html_report, _ = generate_html(
template=template.template_md,
@@ -128,7 +135,7 @@ class GenerateReport(APIView):
if format == "html":
return Response(html_report)
elif format == "pdf":
else:
pdf_bytes = generate_pdf(html=html_report)
return FileResponse(
@@ -136,8 +143,7 @@ class GenerateReport(APIView):
content_type="application/pdf",
filename=f"{template.name}.pdf",
)
else:
return notify_error("Report format is incorrect.")
except TemplateError as error:
if hasattr(error, "lineno"):
return notify_error(f"Line {error.lineno}: {error.message}")
@@ -148,75 +154,94 @@ class GenerateReport(APIView):
class GenerateReportPreview(APIView):
class InputRequest:
template_md: str
type: Literal["markdown", "html"]
template_css: str
template_html: int
template_variables: Dict[str, Any]
dependencies: Dict[str, Any]
format: Literal["html", "pdf"]
debug: bool
class InputSerializer(Serializer[InputRequest]):
template_md = CharField()
type = CharField()
template_css = CharField(required=False)
template_html = IntegerField(required=False)
template_variables = JSONField()
dependencies = JSONField()
format = ChoiceField(choices=["html", "pdf"])
debug = BooleanField(default=False)
def post(self, request: Request) -> Union[FileResponse, Response]:
try:
debug = request.data["debug"]
report_data = self._parse_and_validate_request_data(request.data)
html_report, variables = generate_html(
template=request.data["template_md"],
template_type=request.data["type"],
css=request.data["template_css"],
html_template=(
request.data["template_html"]
if "template_html" in request.data.keys()
else None
),
variables=request.data["template_variables"],
dependencies=request.data["dependencies"],
template=report_data["template_md"],
template_type=report_data["type"],
css=report_data.get("template_css", ""),
html_template=report_data.get("template_html"),
variables=report_data["template_variables"],
dependencies=report_data["dependencies"],
)
response_format = request.data["format"]
debug = request.data["debug"]
html_report = normalize_asset_url(html_report, response_format)
if debug:
# need to serialize the models if an agent, site, or client is specified
if variables:
from django.forms.models import model_to_dict
if "agent" in variables.keys():
variables["agent"] = model_to_dict(
variables["agent"],
fields=[
field.name for field in variables["agent"]._meta.fields
],
)
if "site" in variables.keys():
variables["site"] = model_to_dict(
variables["site"],
fields=[
field.name for field in variables["site"]._meta.fields
],
)
if "client" in variables.keys():
variables["client"] = model_to_dict(
variables["client"],
fields=[
field.name for field in variables["client"]._meta.fields
],
)
return Response({"template": html_report, "variables": variables})
elif response_format == "html":
return Response(html_report)
else:
pdf_bytes = generate_pdf(html=html_report)
return FileResponse(
ContentFile(pdf_bytes),
content_type="application/pdf",
filename=f"preview.pdf",
)
if report_data["debug"]:
return self._process_debug_response(html_report, variables)
return self._generate_response_based_on_format(
html_report, report_data["format"]
)
except TemplateError as error:
if hasattr(error, "lineno"):
return notify_error(f"Line {error.lineno}: {error.message}")
else:
return notify_error(str(error))
return self._handle_template_error(error)
except Exception as error:
return notify_error(str(error))
def _parse_and_validate_request_data(self, data: Dict[str, Any]) -> Any:
serializer = self.InputSerializer(data=data)
serializer.is_valid(raise_exception=True)
return serializer.validated_data
def _process_debug_response(
self, html_report: str, variables: Dict[str, Any]
) -> Response:
if variables:
from django.forms.models import model_to_dict
# serialize any model instances provided
for model_name in ["agent", "site", "client"]:
if model_name in variables:
model_instance = variables[model_name]
serialized_model = model_to_dict(
model_instance,
fields=[field.name for field in model_instance._meta.fields],
)
variables[model_name] = serialized_model
return Response({"template": html_report, "variables": variables})
def _generate_response_based_on_format(
self, html_report: str, format: Literal["html", "pdf"]
) -> Union[Response, FileResponse]:
html_report = normalize_asset_url(html_report, format)
if format == "html":
return Response(html_report)
else:
pdf_bytes = generate_pdf(html=html_report)
return FileResponse(
ContentFile(pdf_bytes),
content_type="application/pdf",
filename="preview.pdf",
)
def _handle_template_error(self, error: TemplateError) -> Response:
if hasattr(error, "lineno"):
error_message = f"Line {error.lineno}: {error.message}"
else:
error_message = str(error)
return notify_error(error_message)
class ExportReportTemplate(APIView):
def post(self, request: Request, pk: int) -> Response:
@@ -253,121 +278,156 @@ class ExportReportTemplate(APIView):
class ImportReportTemplate(APIView):
@transaction.atomic
def post(self, request: Request) -> Response:
import random
import string
base_template = None
report_template = None
try:
template_obj = json.loads(request.data["template"])
if "template" not in template_obj.keys():
return notify_error("Missing template information")
# import base template if exists
base_template_id = self._import_base_template(
template_obj.get("base_template")
)
# create base template
if "base_template" in template_obj.keys() and template_obj["base_template"]:
# check if there is a name conflict and append some characters to the name if so
if (
"name" in template_obj["base_template"].keys()
and ReportHTMLTemplate.objects.filter(
name=template_obj["base_template"]["name"]
).exists()
):
template_obj["base_template"]["name"] += "".join(
random.choice(string.ascii_lowercase) for i in range(6)
)
base_template = ReportHTMLTemplate.objects.create(
**template_obj["base_template"]
)
base_template.refresh_from_db()
# import base template if exists
report_template = self._import_report_template(
template_obj.get("template"), base_template_id
)
# create template
if "template" in template_obj.keys() and template_obj["template"]:
# check if there is a name conflict and append some characters to the name if so
if (
"name" in template_obj["template"].keys()
and ReportTemplate.objects.filter(
name=template_obj["template"]["name"]
).exists()
):
template_obj["template"]["name"] += "".join(
random.choice(string.ascii_lowercase) for i in range(6)
)
report_template = ReportTemplate.objects.create(
**template_obj["template"],
template_html=base_template if base_template else None,
)
# import assets
if "assets" in template_obj.keys() and isinstance(
template_obj["assets"], list
):
for asset in template_obj["assets"]:
# asset should have id, name, and file fields
try:
asset = ReportAsset(
id=asset["id"], file=decode_base64_asset(asset["file"])
)
asset.file.name = os.path.join(
settings.REPORTING_ASSETS_BASE_PATH, asset["name"]
)
asset.save()
except:
pass
# import assets if exists
self._import_assets(template_obj.get("assets"))
return Response(ReportTemplateSerializer(report_template).data)
except:
base_template.delete() if base_template else None
report_template.delete() if report_template else None
return notify_error("There was an error with the request")
except Exception as e:
# rollback db transaction if any exception occurs
transaction.set_rollback(True)
return notify_error(str(e))
def _import_base_template(
self, base_template_data: Optional[Dict[str, Any]] = None
) -> Optional[int]:
if base_template_data:
# Check name conflict and modify name if necessary
name = base_template_data.get("name")
html = base_template_data.get("html")
if not name:
raise ValidationError("base_template is missing 'name' key")
if not html:
raise ValidationError("base_template is missing 'html' field")
if ReportHTMLTemplate.objects.filter(name=name).exists():
name += self._generate_random_string()
base_template = ReportHTMLTemplate.objects.create(name=name, html=html)
base_template.refresh_from_db()
return base_template.id
return None
def _import_report_template(
self,
report_template_data: Dict[str, Any],
base_template_id: Optional[int] = None,
) -> "ReportTemplate":
if report_template_data:
name = report_template_data.pop("name", None)
template_md = report_template_data.get("template_md")
if not name:
raise ValidationError("template requires a 'name' key")
if not template_md:
raise ValidationError("template requires a 'template_md' field")
if ReportTemplate.objects.filter(name=name).exists():
name += self._generate_random_string()
report_template = ReportTemplate.objects.create(
name=name, template_html_id=base_template_id, **report_template_data
)
return report_template
else:
raise ValidationError("'template' key is required in input")
def _import_assets(self, assets: List[Dict[str, Any]]) -> None:
from django.core.files import File
import io
from .storage import report_assets_fs
if isinstance(assets, list):
for asset in assets:
parent_folder = report_assets_fs.getreldir(path=asset["name"])
path = report_assets_fs.get_available_name(
os.path.join(parent_folder, asset["name"])
)
asset_obj = ReportAsset(
id=asset["id"],
file=File(
io.BytesIO(decode_base64_asset(asset["file"])),
name=path,
),
)
asset_obj.save()
@staticmethod
def _generate_random_string(length: int = 6) -> str:
import random
import string
return "".join(random.choice(string.ascii_lowercase) for i in range(length))
class GetAllowedValues(APIView):
def post(self, request: Request) -> Response:
# pass in blank template. We are just interested in variables
variables = request.data.get("variables", None)
if variables is None:
return notify_error("'variables' is required")
dependencies = request.data.get("dependencies", None)
# process variables and dependencies
variables = prep_variables_for_template(
variables=request.data["variables"],
dependencies=request.data["dependencies"],
variables=variables,
dependencies=dependencies,
limit_query_results=1, # only get first item for querysets
)
# recursive function to get properties on any embedded objects
def get_dot_notation(
d: Dict[str, Any], parent_key: str = "", path: Optional[str] = None
) -> Dict[str, Any]:
items = {}
for k, v in d.items():
new_key = f"{parent_key}.{k}" if parent_key else k
if isinstance(v, dict):
items[new_key] = "Object"
items.update(get_dot_notation(v, new_key, path=path))
elif isinstance(v, list) or type(v).__name__ == "PermissionQuerySet":
items[new_key] = f"Array ({len(v)} Results)"
if v: # Ensure the list is not empty
item = v[0]
if isinstance(item, dict):
items.update(
get_dot_notation(item, f"{new_key}[0]", path=path)
)
else:
items[f"{new_key}[0]"] = type(item).__name__
else:
items[new_key] = type(v).__name__
return items
if variables:
return Response(get_dot_notation(variables))
return Response(self._get_dot_notation(variables))
else:
return Response()
# recursive function to get properties on any embedded objects
def _get_dot_notation(
self, d: Dict[str, Any], parent_key: str = "", path: Optional[str] = None
) -> Dict[str, Any]:
items = {}
for k, v in d.items():
new_key = f"{parent_key}.{k}" if parent_key else k
if isinstance(v, dict):
items[new_key] = "Object"
items.update(self._get_dot_notation(v, new_key, path=path))
elif isinstance(v, list) or type(v).__name__ == "PermissionQuerySet":
items[new_key] = f"Array ({len(v)} Results)"
if v: # Ensure the list is not empty
item = v[0]
if isinstance(item, dict):
items.update(
self._get_dot_notation(item, f"{new_key}[0]", path=path)
)
else:
items[f"{new_key}[0]"] = type(item).__name__
else:
items[new_key] = type(v).__name__
return items
class GetReportAssets(APIView):
def get(self, request: Request) -> Response:
path = request.query_params.get("path", "").lstrip("/")
directories, files = report_assets_fs.listdir(path)
try:
directories, files = report_assets_fs.listdir(path)
except FileNotFoundError:
return notify_error("The path is invalid")
response = list()
# parse directories
@@ -401,62 +461,61 @@ class GetReportAssets(APIView):
class GetAllAssets(APIView):
def get(self, request: Request) -> Response:
only_folders = request.query_params.get("OnlyFolders", None)
only_folders = request.query_params.get("onlyFolders", None)
only_folders = True if only_folders and only_folders == "true" else False
# pull report assets from the database so we can pair with the file system assets
assets = ReportAsset.objects.all()
# TODO: define a Type for file node
def walk_folder_and_return_node(path: str):
for current_dir, subdirs, files in os.walk(path):
print(current_dir, subdirs, files)
current_dir = "Report Assets" if current_dir == "." else current_dir
node = {
"type": "folder",
"name": current_dir.replace("./", ""),
"path": path.replace("./", ""),
"children": list(),
"selectable": False,
"icon": "folder",
"iconColor": "yellow-9",
}
for dirname in subdirs:
dirpath = f"{path}/{dirname}"
node["children"].append(
walk_folder_and_return_node(dirpath) # recursively call
)
if not only_folders:
for filename in files:
print(current_dir, filename)
path = f"{current_dir}/{filename}".replace("./", "").replace(
"Report Assets/", ""
)
try:
# need to remove the relative path
id = assets.get(file=path).id
node["children"].append(
{
"id": id,
"type": "file",
"name": filename,
"path": path,
"icon": "description",
}
)
except ReportAsset.DoesNotExist:
pass
return node
try:
os.chdir(report_assets_fs.base_location)
response = [walk_folder_and_return_node(".")]
response = [self._walk_folder_and_return_node(".", only_folders)]
return Response(response)
except FileNotFoundError:
return notify_error("Unable to process request")
# TODO: define a Type for file node
def _walk_folder_and_return_node(self, path: str, only_folders: bool = False):
# pull report assets from the database so we can pair with the file system assets
assets = ReportAsset.objects.all()
for current_dir, subdirs, files in os.walk(path):
current_dir = "Report Assets" if current_dir == "." else current_dir
node = {
"type": "folder",
"name": current_dir.replace("./", ""),
"path": path.replace("./", ""),
"children": list(),
"selectable": False,
"icon": "folder",
"iconColor": "yellow-9",
}
for dirname in subdirs:
dirpath = f"{path}/{dirname}"
node["children"].append(
# recursively call
self._walk_folder_and_return_node(dirpath, only_folders)
)
if not only_folders:
for filename in files:
path = f"{current_dir}/{filename}".replace("./", "").replace(
"Report Assets/", ""
)
try:
# need to remove the relative path
id = assets.get(file=path).id
node["children"].append(
{
"id": id,
"type": "file",
"name": filename,
"path": path,
"icon": "description",
}
)
except ReportAsset.DoesNotExist:
pass
return node
class RenameReportAsset(APIView):
class InputRequest:
@@ -496,6 +555,8 @@ class CreateAssetFolder(APIView):
def post(self, request: Request) -> Response:
path = request.data["path"].lstrip("/") if "path" in request.data else ""
if not path:
return notify_error("'path' in required.")
try:
new_path = report_assets_fs.createfolder(path=path)
return Response(new_path)
@@ -752,10 +813,13 @@ class QuerySchema(APIView):
schema_path = "static/reporting/schemas/query_schema.json"
if djangosettings.DEBUG:
with open(djangosettings.BASE_DIR / schema_path, "r") as f:
data = json.load(f)
try:
with open(djangosettings.BASE_DIR / schema_path, "r") as f:
data = json.load(f)
return JsonResponse(data)
return JsonResponse(data)
except FileNotFoundError:
return notify_error("There was an error getting the file")
else:
response = HttpResponse()
response["X-Accel-Redirect"] = f"/{schema_path}"

View File

@@ -6,7 +6,7 @@ import time
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from functools import wraps
from typing import List, Optional, Union
from typing import List, Optional, Union, Literal, TYPE_CHECKING
from zoneinfo import ZoneInfo
import requests
@@ -41,6 +41,11 @@ from tacticalrmm.helpers import (
notify_error,
)
if TYPE_CHECKING:
from agents.models import Agent
from clients.models import Site, Client
def generate_winagent_exe(
*,
client: int,
@@ -286,35 +291,75 @@ def get_latest_trmm_ver() -> str:
return "error"
# Receives something like {{ client.name }} and a Model instance of Client, Site, or Agent. If an
# Receives something like {{ client.name }} and a Model instance of Client, Site, or Agent. If an
# agent instance is passed it will resolve the value of agent.client.name and return the agent's client name.
#
#
# You can query custom fields by using their name. {{ site.Custom Field Name }}
#
# This will recursively lookup values for relations. {{ client.site.id }}
#
#
# You can also use {{ global.value }} without an obj instance to use the global key store
def get_db_value(*, string: str, instance=None) -> Union[str, List, True, False, None]:
def get_db_value(
*, string: str, instance: Optional[Union["Agent", "Client", "Site"]] = None
) -> Union[str, List[str], Literal[True], Literal[False], None]:
from core.models import CustomField, GlobalKVStore
# get properties into an array
props = string.strip().split(".")
model = props[0]
# value is in the global keystore and replace value
if props[0] == "global" and len(props) == 2:
try:
try:
return GlobalKVStore.objects.get(name=props[1]).value
except GlobalKVStore.DoesNotExist:
DebugLog.error(
log_type=DebugLogType.SCRIPTING,
message=f"Couldn't lookup value for: {string}. Make sure it exists in CoreSettings > Key Store", # type:ignore
message=f"Couldn't lookup value for: {string}. Make sure it exists in CoreSettings > Key Store",
)
return None
if not instance:
# instance must be set if not global property
return None
# custom field lookup
try:
# looking up custom field directly on this instance
if len(props) == 2:
field = CustomField.objects.get(model=props[0], name=props[1])
model_fields = getattr(field, f"{props[0]}_fields")
try:
# resolve the correct model id
if props[0] != instance.__class__.__name__.lower():
value = model_fields.get(
**{props[0]: getattr(instance, props[0])}
).value
else:
value = model_fields.get(**{f"{props[0]}_id": instance.id}).value
if field.type != CustomFieldType.CHECKBOX:
if value:
return value
else:
return field.default_value
else:
return bool(value)
except:
return (
field.default_value
if field.type != CustomFieldType.CHECKBOX
else bool(field.default_value)
)
except CustomField.DoesNotExist:
pass
# if the instance is the same as the first prop. We remove it.
if props[0] == instance.__class__.__name__.lower():
del props[0]
instance_value = instance
# look through all properties and return the value
@@ -324,27 +369,16 @@ def get_db_value(*, string: str, instance=None) -> Union[str, List, True, False,
if callable(value):
return None
instance_value = value
else:
try:
field = CustomField.objects.get(model=props[0], name=prop)
model_fields = getattr(field, f"{props[0]}_fields")
try:
value = model_fields.get(**{props[0]: instance}).value
return value if field.type != CustomFieldType.CHECKBOX else bool(field.default_value)
except:
return field.default_value if field.type != CustomFieldType.CHECKBOX else bool(field.default_value)
except CustomField.DoesNotExist:
return None
if not instance_value:
return None
return instance_value
def replace_arg_db_values(
string: str, instance=None, shell: str = None, quotes=True # type:ignore
) -> Union[str, None]:
# resolve the value
value = get_db_value(string=string, instance=instance)
@@ -365,12 +399,8 @@ def replace_arg_db_values(
# format args for list
elif isinstance(value, list):
return (
f"'{format_shell_array(value)}'"
if quotes
else format_shell_array(value)
)
return f"'{format_shell_array(value)}'" if quotes else format_shell_array(value)
# format args for bool
elif value is True or value is False:
return format_shell_bool(value, shell)