diff --git a/tests/base.py b/tests/conftest.py similarity index 51% rename from tests/base.py rename to tests/conftest.py index 3e7567a5673..07340b50532 100644 --- a/tests/base.py +++ b/tests/conftest.py @@ -5,13 +5,19 @@ """Shared resources for tests.""" -import unittest + from functools import lru_cache from typing import Union -from detection_rules.rule import TOMLRule -from detection_rules.rule_loader import DeprecatedCollection, DeprecatedRule, RuleCollection, production_filter +import pytest +from detection_rules.rule import TOMLRule +from detection_rules.rule_loader import ( + DeprecatedCollection, + DeprecatedRule, + RuleCollection, + production_filter, +) RULE_LOADER_FAIL = False RULE_LOADER_FAIL_MSG = None @@ -28,48 +34,41 @@ def default_bbr() -> RuleCollection: return RuleCollection.default_bbr() -class BaseRuleTest(unittest.TestCase): +@pytest.fixture(scope="class") +def rule_data(request): + global RULE_LOADER_FAIL, RULE_LOADER_FAIL_MSG + if not RULE_LOADER_FAIL: + try: + rc = default_rules() + rc_bbr = default_bbr() + request.cls.all_rules = rc.rules + request.cls.rule_lookup = rc.id_map + request.cls.production_rules = rc.filter(production_filter) + request.cls.bbr = rc_bbr.rules + request.cls.deprecated_rules: DeprecatedCollection = rc.deprecated + except Exception as e: + RULE_LOADER_FAIL = True + RULE_LOADER_FAIL_MSG = str(e) + + +@pytest.mark.usefixtures("rule_data") +class TestBaseRule: """Base class for shared test cases which need to load rules""" RULE_LOADER_FAIL = False RULE_LOADER_FAIL_MSG = None RULE_LOADER_FAIL_RAISED = False - @classmethod - def setUpClass(cls): - global RULE_LOADER_FAIL, RULE_LOADER_FAIL_MSG - - # too noisy; refactor - # os.environ["DR_NOTIFY_INTEGRATION_UPDATE_AVAILABLE"] = "1" - - if not RULE_LOADER_FAIL: - try: - rc = default_rules() - rc_bbr = default_bbr() - cls.all_rules = rc.rules - cls.rule_lookup = rc.id_map - cls.production_rules = rc.filter(production_filter) - cls.bbr = rc_bbr.rules - cls.deprecated_rules: 
DeprecatedCollection = rc.deprecated - except Exception as e: - RULE_LOADER_FAIL = True - RULE_LOADER_FAIL_MSG = str(e) - @staticmethod def rule_str(rule: Union[DeprecatedRule, TOMLRule], trailer=' ->') -> str: return f'{rule.id} - {rule.name}{trailer or ""}' - def setUp(self) -> None: + def setup_method(self, method): global RULE_LOADER_FAIL, RULE_LOADER_FAIL_MSG, RULE_LOADER_FAIL_RAISED - if RULE_LOADER_FAIL: # limit the loader failure to just one run # raise a dedicated test failure for the loader if not RULE_LOADER_FAIL_RAISED: RULE_LOADER_FAIL_RAISED = True - with self.subTest('Test that the rule loader loaded with no validation or other failures.'): - self.fail(f'Rule loader failure: \n{RULE_LOADER_FAIL_MSG}') - - self.skipTest('Rule loader failure') - else: - super().setUp() + pytest.fail(f'Rule loader failure: \n{RULE_LOADER_FAIL_MSG}') + pytest.skip('Rule loader failure') diff --git a/tests/kuery/test_dsl.py b/tests/kuery/test_dsl.py index 4af3217ebc0..8595ed53114 100644 --- a/tests/kuery/test_dsl.py +++ b/tests/kuery/test_dsl.py @@ -3,15 +3,14 @@ # 2.0; you may not use this file except in compliance with the Elastic License # 2.0. -import unittest import kql -class TestKQLtoDSL(unittest.TestCase): +class TestKQLtoDSL: def validate(self, kql_source, dsl, **kwargs): actual_dsl = kql.to_dsl(kql_source, **kwargs) - self.assertListEqual(list(actual_dsl), ["bool"]) - self.assertDictEqual(actual_dsl["bool"], dsl) + assert list(actual_dsl) == ["bool"] + assert actual_dsl["bool"] == dsl def test_field_match(self): def match(**kv): diff --git a/tests/kuery/test_eql2kql.py b/tests/kuery/test_eql2kql.py index de0e4404cb7..f346a555209 100644 --- a/tests/kuery/test_eql2kql.py +++ b/tests/kuery/test_eql2kql.py @@ -4,14 +4,15 @@ # 2.0. 
import eql -import unittest +import pytest + import kql -class TestEql2Kql(unittest.TestCase): +class TestEql2Kql: def validate(self, kql_source, eql_source): - self.assertEqual(kql_source, str(kql.from_eql(eql_source))) + assert kql_source == str(kql.from_eql(eql_source)) def test_field_equals(self): self.validate("field:value", "field == 'value'") @@ -58,6 +59,7 @@ def test_wildcard_field(self): self.validate('field:value-*', 'field : "value-*"') self.validate('field:value-?', 'field : "value-?"') - with eql.parser.elasticsearch_validate_optional_fields, self.assertRaises(AssertionError): + # pytest.raises is used to handle exceptions in pytest + with eql.parser.elasticsearch_validate_optional_fields, pytest.raises(AssertionError): self.validate('field:"value-*"', 'field == "value-*"') self.validate('field:"value-?"', 'field == "value-?"') diff --git a/tests/kuery/test_evaluator.py b/tests/kuery/test_evaluator.py index 97033e97b03..08e605a433d 100644 --- a/tests/kuery/test_evaluator.py +++ b/tests/kuery/test_evaluator.py @@ -3,29 +3,19 @@ # 2.0; you may not use this file except in compliance with the Elastic License # 2.0. 
-import unittest - import kql -class EvaluatorTests(unittest.TestCase): - +class TestEvaluator: document = { "number": 1, "boolean": True, "ip": "192.168.16.3", "string": "hello world", - "string_list": ["hello world", "example"], "number_list": [1, 2, 3], "boolean_list": [True, False], - "structured": [ - { - "a": [ - {"b": 1} - ] - } - ], + "structured": [{"a": [{"b": 1}]}], } def evaluate(self, source_text, document=None): @@ -36,89 +26,87 @@ def evaluate(self, source_text, document=None): return evaluator(document) def test_single_value(self): - self.assertTrue(self.evaluate('number:1')) - self.assertTrue(self.evaluate('number:"1"')) - self.assertTrue(self.evaluate('boolean:true')) - self.assertTrue(self.evaluate('string:"hello world"')) + assert self.evaluate('number:1') + assert self.evaluate('number:"1"') + assert self.evaluate('boolean:true') + assert self.evaluate('string:"hello world"') - self.assertFalse(self.evaluate('number:0')) - self.assertFalse(self.evaluate('boolean:false')) - self.assertFalse(self.evaluate('string:"missing"')) + assert not self.evaluate('number:0') + assert not self.evaluate('boolean:false') + assert not self.evaluate('string:"missing"') def test_list_value(self): - self.assertTrue(self.evaluate('number_list:1')) - self.assertTrue(self.evaluate('number_list:2')) - self.assertTrue(self.evaluate('number_list:3')) + assert self.evaluate('number_list:1') + assert self.evaluate('number_list:2') + assert self.evaluate('number_list:3') - self.assertTrue(self.evaluate('boolean_list:true')) - self.assertTrue(self.evaluate('boolean_list:false')) + assert self.evaluate('boolean_list:true') + assert self.evaluate('boolean_list:false') - self.assertTrue(self.evaluate('string_list:"hello world"')) - self.assertTrue(self.evaluate('string_list:example')) + assert self.evaluate('string_list:"hello world"') + assert self.evaluate('string_list:example') - self.assertFalse(self.evaluate('number_list:4')) - 
self.assertFalse(self.evaluate('string_list:"missing"')) + assert not self.evaluate('number_list:4') + assert not self.evaluate('string_list:"missing"') def test_and_values(self): - self.assertTrue(self.evaluate('number_list:(1 and 2)')) - self.assertTrue(self.evaluate('boolean_list:(false and true)')) - self.assertFalse(self.evaluate('string:("missing" and "hello world")')) + assert self.evaluate('number_list:(1 and 2)') + assert self.evaluate('boolean_list:(false and true)') + assert not self.evaluate('string:("missing" and "hello world")') - self.assertFalse(self.evaluate('number:(0 and 1)')) - self.assertFalse(self.evaluate('boolean:(false and true)')) + assert not self.evaluate('number:(0 and 1)') + assert not self.evaluate('boolean:(false and true)') def test_not_value(self): - self.assertTrue(self.evaluate('number_list:1')) - self.assertFalse(self.evaluate('not number_list:1')) - self.assertFalse(self.evaluate('number_list:(not 1)')) + assert self.evaluate('number_list:1') + assert not self.evaluate('not number_list:1') + assert not self.evaluate('number_list:(not 1)') def test_or_values(self): - self.assertTrue(self.evaluate('number:(0 or 1)')) - self.assertTrue(self.evaluate('number:(1 or 2)')) - self.assertTrue(self.evaluate('boolean:(false or true)')) - self.assertTrue(self.evaluate('string:("missing" or "hello world")')) + assert self.evaluate('number:(0 or 1)') + assert self.evaluate('number:(1 or 2)') + assert self.evaluate('boolean:(false or true)') + assert self.evaluate('string:("missing" or "hello world")') - self.assertFalse(self.evaluate('number:(0 or 3)')) + assert not self.evaluate('number:(0 or 3)') def test_and_expr(self): - self.assertTrue(self.evaluate('number:1 and boolean:true')) - - self.assertFalse(self.evaluate('number:1 and boolean:false')) + assert self.evaluate('number:1 and boolean:true') + assert not self.evaluate('number:1 and boolean:false') def test_or_expr(self): - self.assertTrue(self.evaluate('number:1 or boolean:false')) - 
self.assertFalse(self.evaluate('number:0 or boolean:false')) + assert self.evaluate('number:1 or boolean:false') + assert not self.evaluate('number:0 or boolean:false') def test_range(self): - self.assertTrue(self.evaluate('number < 2')) - self.assertFalse(self.evaluate('number > 2')) + assert self.evaluate('number < 2') + assert not self.evaluate('number > 2') def test_cidr_match(self): - self.assertTrue(self.evaluate('ip:192.168.0.0/16')) - - self.assertFalse(self.evaluate('ip:10.0.0.0/8')) + assert self.evaluate('ip:192.168.0.0/16') + assert not self.evaluate('ip:10.0.0.0/8') def test_quoted_wildcard(self): - self.assertFalse(self.evaluate("string:'*'")) - self.assertFalse(self.evaluate("string:'?'")) + assert not self.evaluate("string:'*'") + assert not self.evaluate("string:'?'") def test_wildcard(self): - self.assertTrue(self.evaluate('string:hello*')) - self.assertTrue(self.evaluate('string:*world')) - self.assertFalse(self.evaluate('string:foobar*')) + assert self.evaluate('string:hello*') + assert self.evaluate('string:*world') + assert not self.evaluate('string:foobar*') def test_field_exists(self): - self.assertTrue(self.evaluate('number:*')) - self.assertTrue(self.evaluate('boolean:*')) - self.assertTrue(self.evaluate('ip:*')) - self.assertTrue(self.evaluate('string:*')) - self.assertTrue(self.evaluate('string_list:*')) - self.assertTrue(self.evaluate('number_list:*')) - self.assertTrue(self.evaluate('boolean_list:*')) + assert self.evaluate('number:*') + assert self.evaluate('boolean:*') + assert self.evaluate('ip:*') + assert self.evaluate('string:*') + assert self.evaluate('string_list:*') + assert self.evaluate('number_list:*') + assert self.evaluate('boolean_list:*') - self.assertFalse(self.evaluate('a:*')) + assert not self.evaluate('a:*') def test_flattening(self): - self.assertTrue(self.evaluate("structured.a.b:*")) - self.assertTrue(self.evaluate("structured.a.b:1")) - self.assertFalse(self.evaluate("structured.a.b:2")) + assert 
self.evaluate("structured.a.b:*") + assert self.evaluate("structured.a.b:1") + assert not self.evaluate("structured.a.b:2") diff --git a/tests/kuery/test_kql2eql.py b/tests/kuery/test_kql2eql.py index bfa9a242589..29867d63477 100644 --- a/tests/kuery/test_kql2eql.py +++ b/tests/kuery/test_kql2eql.py @@ -3,16 +3,16 @@ # 2.0; you may not use this file except in compliance with the Elastic License # 2.0. -import unittest import eql +import pytest import kql -class TestKql2Eql(unittest.TestCase): +class TestKql2Eql: def validate(self, kql_source, eql_source, schema=None): - self.assertEqual(kql.to_eql(kql_source, schema=schema), eql.parse_expression(eql_source)) + assert kql.to_eql(kql_source, schema=schema) == eql.parse_expression(eql_source) def test_field_equals(self): self.validate("field:value", "field == 'value'") @@ -36,7 +36,7 @@ def test_and_query(self): self.validate("field:value and field2:value2", "field == 'value' and field2 == 'value2'") def test_nested_query(self): - with self.assertRaisesRegex(kql.KqlParseError, "Unable to convert nested query to EQL"): + with pytest.raises(kql.KqlParseError, match="Unable to convert nested query to EQL"): kql.to_eql("field:{outer:1 and middle:{inner:2}}") def test_not_query(self): @@ -55,7 +55,7 @@ def test_list_of_values(self): def test_lone_value(self): for value in ["1", "-1.4", "true", "\"string test\""]: - with self.assertRaisesRegex(kql.KqlParseError, "Value not tied to field"): + with pytest.raises(kql.KqlParseError, match="Value not tied to field"): kql.to_eql(value) def test_schema(self): @@ -78,22 +78,22 @@ def test_schema(self): self.validate("dest:192.168.0.0/16", "cidrMatch(dest, '192.168.0.0/16')", schema=schema) self.validate("dest:\"192.168.0.0/16\"", "cidrMatch(dest, '192.168.0.0/16')", schema=schema) - with self.assertRaises(eql.EqlSemanticError): + with pytest.raises(eql.EqlSemanticError): self.validate("top.text : \"hello\"", "top.text == 'hello'", schema=schema) - with 
self.assertRaises(eql.EqlSemanticError): + with pytest.raises(eql.EqlSemanticError): self.validate("top.text : 1 ", "top.text == '1'", schema=schema) - with self.assertRaisesRegex(kql.KqlParseError, r"Value doesn't match top.middle's type: nested"): + with pytest.raises(kql.KqlParseError, match=r"Value doesn't match top.middle's type: nested"): kql.to_eql("top.middle : 1", schema=schema) - with self.assertRaisesRegex(kql.KqlParseError, "Unable to convert nested query to EQL"): + with pytest.raises(kql.KqlParseError, match="Unable to convert nested query to EQL"): kql.to_eql("top:{keyword : 1}", schema=schema) - with self.assertRaisesRegex(kql.KqlParseError, "Unable to convert nested query to EQL"): + with pytest.raises(kql.KqlParseError, match="Unable to convert nested query to EQL"): kql.to_eql("top:{middle:{bool: true}}", schema=schema) invalid_ips = ["192.168.0.256", "192.168.0.256/33", "1", "\"1\""] for ip in invalid_ips: - with self.assertRaisesRegex(kql.KqlParseError, r"Value doesn't match dest's type: ip"): - kql.to_eql("dest:{ip}".format(ip=ip), schema=schema) + with pytest.raises(kql.KqlParseError, match=r"Value doesn't match dest's type: ip"): + kql.to_eql(f"dest:{ip}", schema=schema) diff --git a/tests/kuery/test_lint.py b/tests/kuery/test_lint.py index 7f0e97bd195..5b7b01b729e 100644 --- a/tests/kuery/test_lint.py +++ b/tests/kuery/test_lint.py @@ -3,14 +3,16 @@ # 2.0; you may not use this file except in compliance with the Elastic License # 2.0. 
-import unittest +import pytest + import kql -class LintTests(unittest.TestCase): +class TestLint: - def validate(self, source, linted, *args): - self.assertEqual(kql.lint(source), linted, *args) + def validate(self, source, linted, msg=None): + actual = kql.lint(source) + assert actual == linted, msg if msg is not None else f"Expected {linted}, got {actual}" def test_lint_field(self): self.validate("a : b", "a:b") @@ -31,7 +33,7 @@ def test_upper_tokens(self): ] for q in queries: - with self.assertRaises(kql.KqlParseError): + with pytest.raises(kql.KqlParseError): kql.parse(q) def test_lint_precedence(self): diff --git a/tests/kuery/test_parser.py b/tests/kuery/test_parser.py index 444d55f1bd1..5d086fc1f9d 100644 --- a/tests/kuery/test_parser.py +++ b/tests/kuery/test_parser.py @@ -3,23 +3,18 @@ # 2.0; you may not use this file except in compliance with the Elastic License # 2.0. -import unittest +import pytest + import kql -from kql.ast import ( - Field, - FieldComparison, - FieldRange, - String, - Number, - Exists, -) +from kql.ast import Exists, Field, FieldComparison, FieldRange, Number, String -class ParserTests(unittest.TestCase): +class TestParser: def validate(self, source, tree, *args, **kwargs): kwargs.setdefault("optimize", False) - self.assertEqual(kql.parse(source, *args, **kwargs), tree) + actual = kql.parse(source, *args, **kwargs) + assert actual == tree, f"Expected {tree}, got {actual}" def test_keyword(self): schema = { @@ -50,24 +45,24 @@ def test_conversion(self): self.validate('text:"1"', FieldComparison(Field("text"), String("1")), schema=schema) def test_list_equals(self): - self.assertEqual(kql.parse("a:(1 or 2)", optimize=False), kql.parse("a:(2 or 1)", optimize=False)) + assert kql.parse("a:(1 or 2)", optimize=False) == kql.parse("a:(2 or 1)", optimize=False) def test_number_exists(self): - self.assertEqual(kql.parse("foo:*", schema={"foo": "long"}), FieldComparison(Field("foo"), Exists())) + assert kql.parse("foo:*", schema={"foo": 
"long"}) == FieldComparison(Field("foo"), Exists()) def test_multiple_types_success(self): schema = {"common.a": "keyword", "common.b": "keyword"} self.validate("common.* : \"hello\"", FieldComparison(Field("common.*"), String("hello")), schema=schema) def test_multiple_types_fail(self): - with self.assertRaises(kql.KqlParseError): + with pytest.raises(kql.KqlParseError): kql.parse("common.* : \"hello\"", schema={"common.a": "keyword", "common.b": "ip"}) def test_number_wildcard_fail(self): - with self.assertRaises(kql.KqlParseError): + with pytest.raises(kql.KqlParseError): kql.parse("foo:*wc", schema={"foo": "long"}) - with self.assertRaises(kql.KqlParseError): + with pytest.raises(kql.KqlParseError): kql.parse("foo:wc*", schema={"foo": "long"}) def test_type_family_success(self): @@ -76,12 +71,12 @@ def test_type_family_success(self): kql.parse("abc >= now-30d", schema={"abc": "date_nanos"}) def test_type_family_fail(self): - with self.assertRaises(kql.KqlParseError): + with pytest.raises(kql.KqlParseError): kql.parse('foo : "hello world"', schema={"foo": "scaled_float"}) def test_date(self): schema = {"@time": "date"} self.validate('@time <= now-10d', FieldRange(Field("@time"), "<=", String("now-10d")), schema=schema) - with self.assertRaises(kql.KqlParseError): + with pytest.raises(kql.KqlParseError): kql.parse("@time > 5", schema=schema) diff --git a/tests/test_all_rules.py b/tests/test_all_rules.py index b63d393f146..fd54e99eb93 100644 --- a/tests/test_all_rules.py +++ b/tests/test_all_rules.py @@ -5,8 +5,8 @@ """Test that all rules have valid metadata and syntax.""" import os +import pytest import re -import unittest import uuid import warnings from collections import defaultdict @@ -34,30 +34,30 @@ from detection_rules.version_lock import default_version_lock from rta import get_available_tests -from .base import BaseRuleTest +from tests.conftest import TestBaseRule PACKAGE_STACK_VERSION = Version.parse(current_stack_version(), 
optional_minor_and_patch=True) -class TestValidRules(BaseRuleTest): +class TestValidRules(TestBaseRule): """Test that all detection rules load properly without duplicates.""" def test_schema_and_dupes(self): """Ensure that every rule matches the schema and there are no duplicates.""" - self.assertGreaterEqual(len(self.all_rules), 1, 'No rules were loaded from rules directory!') + assert len(self.all_rules) >= 1, 'No rules were loaded from rules directory!' def test_file_names(self): """Test that the file names meet the requirement.""" file_pattern = FILE_PATTERN - self.assertIsNone(re.match(file_pattern, 'NotValidRuleFile.toml'), - f'Incorrect pattern for verifying rule names: {file_pattern}') - self.assertIsNone(re.match(file_pattern, 'still_not_a_valid_file_name.not_json'), - f'Incorrect pattern for verifying rule names: {file_pattern}') + assert re.match(file_pattern, 'NotValidRuleFile.toml') is None, \ + f'Incorrect pattern for verifying rule names: {file_pattern}' + assert re.match(file_pattern, 'still_not_a_valid_file_name.not_json') is None, \ + f'Incorrect pattern for verifying rule names: {file_pattern}' for rule in self.all_rules: file_name = str(rule.path.name) - self.assertIsNotNone(re.match(file_pattern, file_name), f'Invalid file name for {rule.path}') + assert re.match(file_pattern, file_name) is not None, f'Invalid file name for {rule.path}' def test_all_rule_queries_optimized(self): """Ensure that every rule query is in optimized form.""" @@ -72,7 +72,7 @@ def test_all_rule_queries_optimized(self): optimized = tree.optimize(recursive=True) err_message = f'\n{self.rule_str(rule)} Query not optimized for rule\n' \ f'Expected: {optimized}\nActual: {source}' - self.assertEqual(tree, optimized, err_message) + assert tree == optimized, err_message def test_production_rules_have_rta(self): """Ensure that all production rules have RTAs.""" @@ -83,11 +83,11 @@ def test_production_rules_have_rta(self): if isinstance(rule.contents.data, QueryRuleData) and 
rule.id in mappings: matching_rta = mappings[rule.id].get('rta_name') - self.assertIsNotNone(matching_rta, f'{self.rule_str(rule)} does not have RTAs') + assert matching_rta is not None, f'{self.rule_str(rule)} does not have RTAs' rta_name, ext = os.path.splitext(matching_rta) if rta_name not in ttp_names: - self.fail(f'{self.rule_str(rule)} references unknown RTA: {rta_name}') + pytest.fail(f'{self.rule_str(rule)} references unknown RTA: {rta_name}') def test_duplicate_file_names(self): """Test that no file names are duplicated.""" @@ -98,7 +98,7 @@ def test_duplicate_file_names(self): duplicates = {name: paths for name, paths in name_map.items() if len(paths) > 1} if duplicates: - self.fail(f"Found duplicated file names: {duplicates}") + pytest.fail(f"Found duplicated file names: {duplicates}") def test_rule_type_changes(self): """Test that a rule type did not change for a locked version""" @@ -142,14 +142,14 @@ def build_rule(query, bbr_type="default", from_field="now-120m", interval="60m") build_rule(query=query) - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): build_rule(query=query, bbr_type="invalid") - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): build_rule(query=query, from_field="now-10m", interval="10m") -class TestThreatMappings(BaseRuleTest): +class TestThreatMappings(TestBaseRule): """Test threat mapping data for rules.""" def test_technique_deprecations(self): @@ -171,7 +171,7 @@ def test_technique_deprecations(self): if revoked_techniques: old_new_mapping = "\n".join(f'Actual: {k} -> Expected {v}' for k, v in revoked_techniques.items()) - self.fail(f'{self.rule_str(rule)} Using deprecated ATT&CK techniques: \n{old_new_mapping}') + pytest.fail(f'{self.rule_str(rule)} Using deprecated ATT&CK techniques: \n{old_new_mapping}') def test_tactic_to_technique_correlations(self): """Ensure rule threat info is properly related to a single tactic and technique.""" @@ -184,52 +184,52 @@ def 
test_tactic_to_technique_correlations(self): mismatched = [t.id for t in techniques if t.id not in attack.matrix[tactic.name]] if mismatched: - self.fail(f'mismatched ATT&CK techniques for rule: {self.rule_str(rule)} ' - f'{", ".join(mismatched)} not under: {tactic["name"]}') + pytest.fail(f'mismatched ATT&CK techniques for rule: {self.rule_str(rule)} ' + f'{", ".join(mismatched)} not under: {tactic["name"]}') # tactic expected_tactic = attack.tactics_map[tactic.name] - self.assertEqual(expected_tactic, tactic.id, - f'ATT&CK tactic mapping error for rule: {self.rule_str(rule)}\n' - f'expected: {expected_tactic} for {tactic.name}\n' - f'actual: {tactic.id}') + assert expected_tactic == tactic.id, \ + f'ATT&CK tactic mapping error for rule: {self.rule_str(rule)}\n' \ + f'expected: {expected_tactic} for {tactic.name}\n' \ + f'actual: {tactic.id}' tactic_reference_id = tactic.reference.rstrip('/').split('/')[-1] - self.assertEqual(tactic.id, tactic_reference_id, - f'ATT&CK tactic mapping error for rule: {self.rule_str(rule)}\n' - f'tactic ID {tactic.id} does not match the reference URL ID ' - f'{tactic.reference}') + assert tactic.id == tactic_reference_id, \ + f'ATT&CK tactic mapping error for rule: {self.rule_str(rule)}\n' \ + f'tactic ID {tactic.id} does not match the reference URL ID ' \ + f'{tactic.reference}' # techniques for technique in techniques: expected_technique = attack.technique_lookup[technique.id]['name'] - self.assertEqual(expected_technique, technique.name, - f'ATT&CK technique mapping error for rule: {self.rule_str(rule)}\n' - f'expected: {expected_technique} for {technique.id}\n' - f'actual: {technique.name}') + assert expected_technique == technique.name, \ + f'ATT&CK technique mapping error for rule: {self.rule_str(rule)}\n' \ + f'expected: {expected_technique} for {technique.id}\n' \ + f'actual: {technique.name}' technique_reference_id = technique.reference.rstrip('/').split('/')[-1] - self.assertEqual(technique.id, technique_reference_id, - 
f'ATT&CK technique mapping error for rule: {self.rule_str(rule)}\n' - f'technique ID {technique.id} does not match the reference URL ID ' - f'{technique.reference}') + assert technique.id == technique_reference_id, \ + f'ATT&CK technique mapping error for rule: {self.rule_str(rule)}\n' \ + f'technique ID {technique.id} does not match the reference URL ID ' \ + f'{technique.reference}' # sub-techniques sub_techniques = technique.subtechnique or [] if sub_techniques: for sub_technique in sub_techniques: expected_sub_technique = attack.technique_lookup[sub_technique.id]['name'] - self.assertEqual(expected_sub_technique, sub_technique.name, - f'ATT&CK sub-technique mapping error for rule: {self.rule_str(rule)}\n' - f'expected: {expected_sub_technique} for {sub_technique.id}\n' - f'actual: {sub_technique.name}') + assert expected_sub_technique == sub_technique.name, \ + f'ATT&CK sub-technique mapping error for rule: {self.rule_str(rule)}\n' \ + f'expected: {expected_sub_technique} for {sub_technique.id}\n' \ + f'actual: {sub_technique.name}' sub_technique_reference_id = '.'.join( sub_technique.reference.rstrip('/').split('/')[-2:]) - self.assertEqual(sub_technique.id, sub_technique_reference_id, - f'ATT&CK sub-technique mapping error for rule: {self.rule_str(rule)}\n' - f'sub-technique ID {sub_technique.id} does not match the reference URL ID ' # noqa: E501 - f'{sub_technique.reference}') + assert sub_technique.id == sub_technique_reference_id, \ + f'ATT&CK sub-technique mapping error for rule: {self.rule_str(rule)}\n' \ + f'sub-technique ID {sub_technique.id} does not match the reference URL ID ' \ + f'{sub_technique.reference}' def test_duplicated_tactics(self): """Check that a tactic is only defined once.""" @@ -239,12 +239,12 @@ def test_duplicated_tactics(self): duplicates = sorted(set(t for t in tactics if tactics.count(t) > 1)) if duplicates: - self.fail(f'{self.rule_str(rule)} duplicate tactics defined for {duplicates}. 
' - f'Flatten to a single entry per tactic') + pytest.fail(f'{self.rule_str(rule)} duplicate tactics defined for {duplicates}. ' + f'Flatten to a single entry per tactic') -@unittest.skipIf(os.environ.get('DR_BYPASS_TAGS_VALIDATION') is not None, "Skipping tag validation") -class TestRuleTags(BaseRuleTest): +@pytest.mark.skipif(os.environ.get('DR_BYPASS_TAGS_VALIDATION') is not None, reason="Skipping tag validation") +class TestRuleTags(TestBaseRule): """Test tags data for rules.""" def test_casing_and_spacing(self): @@ -263,7 +263,7 @@ def test_casing_and_spacing(self): error_msg = f'{self.rule_str(rule)} Invalid casing for expected tags\n' error_msg += f'Actual tags: {", ".join(invalid_tags)}\n' error_msg += f'Expected tags: {", ".join(invalid_tags.values())}' - self.fail(error_msg) + pytest.fail(error_msg) def test_required_tags(self): """Test that expected tags are present within rules.""" @@ -314,7 +314,7 @@ def test_required_tags(self): error_msg += f'\nMissing any of: {", " .join(consolidated_optional_tags)}' if is_missing_any_tags else '' if missing_required_tags or is_missing_any_tags: - self.fail(error_msg) + pytest.fail(error_msg) def test_primary_tactic_as_tag(self): """Test that the primary tactic is present as a tag.""" @@ -354,7 +354,7 @@ def test_primary_tactic_as_tag(self): if invalid: err_msg = '\n'.join(invalid) - self.fail(f'Rules with misaligned tags and tactics:\n{err_msg}') + pytest.fail(f'Rules with misaligned tags and tactics:\n{err_msg}') def test_os_tags(self): """Test that OS tags are present within rules.""" @@ -376,7 +376,7 @@ def test_os_tags(self): if invalid: err_msg = '\n'.join(invalid) - self.fail(f'Rules with missing OS tags:\n{err_msg}') + pytest.fail(f'Rules with missing OS tags:\n{err_msg}') def test_ml_rule_type_tags(self): """Test that ML rule type tags are present within rules.""" @@ -397,9 +397,9 @@ def test_ml_rule_type_tags(self): if invalid: err_msg = '\n'.join(invalid) - self.fail(f'Rules with misaligned ML rule type 
tags:\n{err_msg}') + pytest.fail(f'Rules with misaligned ML rule type tags:\n{err_msg}') - @unittest.skip("Skipping until all Investigation Guides follow the proper format.") + @pytest.mark.skip(reason="Skipping until all Investigation Guides follow the proper format.") def test_investigation_guide_tag(self): """Test that investigation guide tags are present within rules.""" invalid = [] @@ -415,7 +415,7 @@ def test_investigation_guide_tag(self): invalid.append(err_msg) if invalid: err_msg = '\n'.join(invalid) - self.fail(f'Rules with missing Investigation tag:\n{err_msg}') + pytest.fail(f'Rules with missing Investigation tag:\n{err_msg}') def test_tag_prefix(self): """Ensure all tags have a prefix from an expected list.""" @@ -427,7 +427,7 @@ def test_tag_prefix(self): [invalid.append(f"{self.rule_str(rule)}-{tag}") for tag in rule_tags if not any(prefix in tag for prefix in expected_prefixes)] if invalid: - self.fail(f'Rules with invalid tags:\n{invalid}') + pytest.fail(f'Rules with invalid tags:\n{invalid}') def test_no_duplicate_tags(self): """Ensure no rules have duplicate tags.""" @@ -439,10 +439,10 @@ def test_no_duplicate_tags(self): invalid.append(self.rule_str(rule)) if invalid: - self.fail(f'Rules with duplicate tags:\n{invalid}') + pytest.fail(f'Rules with duplicate tags:\n{invalid}') -class TestRuleTimelines(BaseRuleTest): +class TestRuleTimelines(TestBaseRule): """Test timelines in rules are valid.""" def test_timeline_has_title(self): @@ -455,21 +455,21 @@ def test_timeline_has_title(self): if (timeline_title or timeline_id) and not (timeline_title and timeline_id): missing_err = f'{self.rule_str(rule)} timeline "title" and "id" required when timelines are defined' - self.fail(missing_err) + pytest.fail(missing_err) if timeline_id: unknown_id = f'{self.rule_str(rule)} Unknown timeline_id: {timeline_id}.' 
unknown_id += f' replace with {", ".join(TIMELINE_TEMPLATES)} ' \ f'or update this unit test with acceptable ids' - self.assertIn(timeline_id, list(TIMELINE_TEMPLATES), unknown_id) + assert timeline_id in list(TIMELINE_TEMPLATES), unknown_id unknown_title = f'{self.rule_str(rule)} unknown timeline_title: {timeline_title}' unknown_title += f' replace with {", ".join(TIMELINE_TEMPLATES.values())}' unknown_title += ' or update this unit test with acceptable titles' - self.assertEqual(timeline_title, TIMELINE_TEMPLATES[timeline_id], unknown_title) + assert timeline_title == TIMELINE_TEMPLATES[timeline_id], unknown_title -class TestRuleFiles(BaseRuleTest): +class TestRuleFiles(TestBaseRule): """Test the expected file names.""" def test_rule_file_name_tactic(self): @@ -497,32 +497,32 @@ def test_rule_file_name_tactic(self): if bad_name_rules: error_msg = 'filename does not start with the primary tactic - update the tactic or the rule filename' rule_err_str = '\n'.join(bad_name_rules) - self.fail(f'{error_msg}:\n{rule_err_str}') + pytest.fail(f'{error_msg}:\n{rule_err_str}') def test_bbr_in_correct_dir(self): """Ensure that BBR are in the correct directory.""" for rule in self.bbr: # Is the rule a BBR - self.assertEqual(rule.contents.data.building_block_type, 'default', - f'{self.rule_str(rule)} should have building_block_type = "default"') + assert rule.contents.data.building_block_type == 'default', \ + f'{self.rule_str(rule)} should have building_block_type = "default"' # Is the rule in the rules_building_block directory - self.assertEqual(rule.path.parent.name, 'rules_building_block', - f'{self.rule_str(rule)} should be in the rules_building_block directory') + assert rule.path.parent.name == 'rules_building_block', \ + f'{self.rule_str(rule)} should be in the rules_building_block directory' def test_non_bbr_in_correct_dir(self): """Ensure that non-BBR are not in BBR directory.""" proper_directory = 'rules_building_block' for rule in self.all_rules: if 
rule.path.parent.name == 'rules_building_block': - self.assertIn(rule, self.bbr, f'{self.rule_str(rule)} should be in the {proper_directory}') + assert rule in self.bbr, f'{self.rule_str(rule)} should be in the {proper_directory}' else: # Is the rule of type BBR and not in the correct directory - self.assertEqual(rule.contents.data.building_block_type, None, - f'{self.rule_str(rule)} should be in {proper_directory}') + assert rule.contents.data.building_block_type is None, \ + f'{self.rule_str(rule)} should be in {proper_directory}' -class TestRuleMetadata(BaseRuleTest): +class TestRuleMetadata(TestBaseRule): """Test the metadata of rules.""" def test_updated_date_newer_than_creation(self): @@ -538,7 +538,7 @@ def test_updated_date_newer_than_creation(self): if invalid: rules_str = '\n '.join(self.rule_str(r, trailer=None) for r in invalid) err_msg = f'The following rules have an updated_date older than the creation_date\n {rules_str}' - self.fail(err_msg) + pytest.fail(err_msg) def test_deprecated_rules(self): """Test that deprecated rules are properly handled.""" @@ -559,7 +559,7 @@ def test_deprecated_rules(self): misplaced = '\n'.join(f'{self.rule_str(r)} {r.contents.metadata.maturity}' for r in misplaced_rules) err_str = f'The following rules are stored in {deprecated_path} but are not marked as deprecated:\n{misplaced}' - self.assertListEqual(misplaced_rules, [], err_str) + assert misplaced_rules == [], err_str for rule in self.deprecated_rules: meta = rule.contents.metadata @@ -567,18 +567,18 @@ def test_deprecated_rules(self): deprecated_rules[rule.id] = rule err_msg = f'{self.rule_str(rule)} cannot be deprecated if it has not been version locked. 
' \ f'Convert to `development` or delete the rule file instead' - self.assertIn(rule.id, versions, err_msg) + assert rule.id in versions, err_msg rule_path = rule.path.relative_to(rules_path) err_msg = f'{self.rule_str(rule)} deprecated rules should be stored in ' \ f'"{deprecated_path}" folder' - self.assertEqual('_deprecated', rule_path.parts[-2], err_msg) + assert '_deprecated' == rule_path.parts[-2], err_msg err_msg = f'{self.rule_str(rule)} missing deprecation date' - self.assertIsNotNone(meta['deprecation_date'], err_msg) + assert meta['deprecation_date'] is not None, err_msg err_msg = f'{self.rule_str(rule)} deprecation_date and updated_date should match' - self.assertEqual(meta['deprecation_date'], meta['updated_date'], err_msg) + assert meta['deprecation_date'] == meta['updated_date'], err_msg # skip this so the lock file can be shared across branches # @@ -587,7 +587,7 @@ def test_deprecated_rules(self): # err_msg = f'Deprecated rules should not be removed, but moved to the rules/_deprecated folder instead. ' \ # f'The following rules have been version locked and are missing. 
' \ # f'Re-add to the deprecated folder and update maturity to "deprecated": \n {missing_rule_strings}' - # self.assertEqual([], missing_rules, err_msg) + # assert missing_rules == [], err_msg for rule_id, entry in deprecations.items(): # if a rule is deprecated and not backported in order to keep the rule active in older branches, then it @@ -598,10 +598,10 @@ def test_deprecated_rules(self): continue rule_str = f'{rule_id} - {entry["rule_name"]} ->' - self.assertIn(rule_id, deprecated_rules, f'{rule_str} is logged in "deprecated_rules.json" but is missing') + assert rule_id in deprecated_rules, f'{rule_str} is logged in "deprecated_rules.json" but is missing' - @unittest.skipIf(PACKAGE_STACK_VERSION < Version.parse("8.3.0"), - "Test only applicable to 8.3+ stacks regarding related integrations build time field.") + @pytest.mark.skipif(PACKAGE_STACK_VERSION < Version.parse("8.3.0"), + reason="Test only applicable to 8.3+ stacks regarding related integrations build time field.") def test_integration_tag(self): """Test integration rules defined by metadata tag.""" failures = [] @@ -682,7 +682,7 @@ def test_integration_tag(self): Try updating the integrations manifest file: - `python -m detection_rules dev integrations build-manifests`\n """ - self.fail(err_msg + '\n'.join(failures)) + pytest.fail(err_msg + '\n'.join(failures)) def test_invalid_queries(self): invalid_queries_eql = [ @@ -797,22 +797,22 @@ def build_rule(query: str, query_language: str): build_rule(query, "eql") for query in invalid_queries_eql: - with self.assertRaises(eql.EqlSchemaError): + with pytest.raises(eql.EqlSchemaError): build_rule(query, "eql") for query in invalid_integration_queries_eql: - with self.assertRaises(ValueError): + with pytest.raises(ValueError): build_rule(query, "eql") # kql for query in valid_queries_kql: build_rule(query, "kuery") for query in invalid_queries_kql: - with self.assertRaises(kql.KqlParseError): + with pytest.raises(kql.KqlParseError): build_rule(query, 
"kuery") for query in invalid_integration_queries_kql: - with self.assertRaises(ValueError): + with pytest.raises(ValueError): build_rule(query, "kuery") def test_event_dataset(self): @@ -844,10 +844,10 @@ def test_event_dataset(self): raise validation_integrations_check -class TestIntegrationRules(BaseRuleTest): +class TestIntegrationRules(TestBaseRule): """Test integration rules.""" - @unittest.skip("8.3+ Stacks Have Related Integrations Feature") + @pytest.mark.skip(reason="8.3+ Stacks Have Related Integrations Feature") def test_integration_guide(self): """Test that rules which require a config note are using standard verbiage.""" config = '## Setup\n\n' @@ -869,12 +869,12 @@ def test_integration_guide(self): note_str = integration_notes.get(integration) if note_str: - self.assert_(rule.contents.data.note, f'{self.rule_str(rule)} note required for config information') + assert rule.contents.data.note, f'{self.rule_str(rule)} note required for config information' if note_str not in rule.contents.data.note: - self.fail(f'{self.rule_str(rule)} expected {integration} config missing\n\n' - f'Expected: {note_str}\n\n' - f'Actual: {rule.contents.data.note}') + pytest.fail(f'{self.rule_str(rule)} expected {integration} config missing\n\n' + f'Expected: {note_str}\n\n' + f'Actual: {rule.contents.data.note}') def test_rule_demotions(self): """Test to ensure a locked rule is not dropped to development, only deprecated""" @@ -889,7 +889,7 @@ def test_rule_demotions(self): if failures: err_msg = '\n'.join(failures) - self.fail(f'The following rules have been improperly demoted:\n{err_msg}') + pytest.fail(f'The following rules have been improperly demoted:\n{err_msg}') def test_all_min_stack_rules_have_comment(self): failures = [] @@ -901,8 +901,8 @@ def test_all_min_stack_rules_have_comment(self): if failures: err_msg = '\n'.join(failures) - self.fail(f'The following ({len(failures)}) rules have a `min_stack_version` defined but missing comments:' - f'\n{err_msg}') + 
pytest.fail(f'The following ({len(failures)}) rules have a `min_stack_version` defined but ' + f'missing comments: \n{err_msg}') def test_ml_integration_jobs_exist(self): """Test that machine learning jobs exist in the integration.""" @@ -942,12 +942,12 @@ def test_ml_integration_jobs_exist(self): if failures: err_msg = '\n'.join(failures) - self.fail( + pytest.fail( f'The following ({len(failures)}) rules are missing a valid `machine_learning_job_id`:\n{err_msg}' ) -class TestRuleTiming(BaseRuleTest): +class TestRuleTiming(TestBaseRule): """Test rule timing and timestamps.""" def test_event_override(self): @@ -1036,7 +1036,7 @@ def test_event_override(self): continue err_strings.append(f'({len(type_errors)}) {errors_by_type["msg"]}') err_strings.extend([f' - {e}' for e in type_errors]) - self.fail('\n'.join(err_strings)) + pytest.fail('\n'.join(err_strings)) def test_required_lookback(self): """Ensure endpoint rules have the proper lookback time.""" @@ -1053,7 +1053,7 @@ def test_required_lookback(self): if missing: rules_str = '\n '.join(self.rule_str(r, trailer=None) for r in missing) err_msg = f'The following rules should have a longer `from` defined, due to indexes used\n {rules_str}' - self.fail(err_msg) + pytest.fail(err_msg) def test_eql_lookback(self): """Ensure EQL rules lookback => max_span, when defined.""" @@ -1080,7 +1080,7 @@ def test_eql_lookback(self): if invalids: invalids_str = '\n'.join(invalids) - self.fail(f'The following rules have longer max_spans than lookbacks:\n{invalids_str}') + pytest.fail(f'The following rules have longer max_spans than lookbacks:\n{invalids_str}') def test_eql_interval_to_maxspan(self): """Check the ratio of interval to maxspan for eql rules.""" @@ -1102,10 +1102,10 @@ def test_eql_interval_to_maxspan(self): if invalids: invalids_str = '\n'.join(invalids) - self.fail(f'The following rules have intervals too short for their given max_spans (ms):\n{invalids_str}') + pytest.fail(f'The following rules have intervals too 
short for their given max_spans (ms):\n{invalids_str}') -class TestLicense(BaseRuleTest): +class TestLicense(TestBaseRule): """Test rule license.""" def test_elastic_license_only_v2(self): @@ -1114,10 +1114,10 @@ def test_elastic_license_only_v2(self): rule_license = rule.contents.data.license if 'elastic license' in rule_license.lower(): err_msg = f'{self.rule_str(rule)} If Elastic License is used, only v2 should be used' - self.assertEqual(rule_license, 'Elastic License v2', err_msg) + assert rule_license == 'Elastic License v2', err_msg -class TestIncompatibleFields(BaseRuleTest): +class TestIncompatibleFields(TestBaseRule): """Test stack restricted fields do not backport beyond allowable limits.""" def test_rule_backports_for_restricted_fields(self): @@ -1133,10 +1133,10 @@ def test_rule_backports_for_restricted_fields(self): invalid_str = '\n'.join(invalid_rules) err_msg = 'The following rules have min_stack_versions lower than allowed for restricted fields:\n' err_msg += invalid_str - self.fail(err_msg) + pytest.fail(err_msg) -class TestBuildTimeFields(BaseRuleTest): +class TestBuildTimeFields(TestBaseRule): """Test validity of build-time fields.""" def test_build_fields_min_stack(self): @@ -1161,10 +1161,10 @@ def test_build_fields_min_stack(self): f' to be set: {err_str}') if invalids: - self.fail(invalids) + pytest.fail(invalids) -class TestRiskScoreMismatch(BaseRuleTest): +class TestRiskScoreMismatch(TestBaseRule): """Test that severity and risk_score fields contain corresponding values""" def test_rule_risk_score_severity_mismatch(self): @@ -1188,10 +1188,10 @@ def test_rule_risk_score_severity_mismatch(self): invalid_str = '\n'.join(invalid_list) err_msg = 'The following rules have mismatches between Severity and Risk Score field values:\n' err_msg += invalid_str - self.fail(err_msg) + pytest.fail(err_msg) -class TestNoteMarkdownPlugins(BaseRuleTest): +class TestNoteMarkdownPlugins(TestBaseRule): """Test if a guide containing Osquery Plugin syntax 
contains the version note.""" def test_note_has_osquery_warning(self): @@ -1211,13 +1211,13 @@ def test_note_has_osquery_warning(self): osquery = rule.contents.transform.get('osquery') if osquery and osquery_note_pattern not in rule.contents.data.note: - self.fail(f'{self.rule_str(rule)} Investigation guides using the Osquery Markdown must contain ' - f'the following note:\n{osquery_note_pattern}') + pytest.fail(f'{self.rule_str(rule)} Investigation guides using the Osquery Markdown must contain ' + f'the following note:\n{osquery_note_pattern}') investigate = rule.contents.transform.get('investigate') if investigate and invest_note_pattern not in rule.contents.data.note: - self.fail(f'{self.rule_str(rule)} Investigation guides using the Investigate Markdown must contain ' - f'the following note:\n{invest_note_pattern}') + pytest.fail(f'{self.rule_str(rule)} Investigation guides using the Investigate Markdown must contain ' + f'the following note:\n{invest_note_pattern}') def test_plugin_placeholders_match_entries(self): """Test that the number of plugin entries match their respective placeholders in note.""" @@ -1228,7 +1228,7 @@ def test_plugin_placeholders_match_entries(self): if has_transform: if not has_note: - self.fail(f'{self.rule_str(rule)} transformed defined with no note') + pytest.fail(f'{self.rule_str(rule)} transformed defined with no note') else: if not has_note: continue @@ -1238,7 +1238,7 @@ def test_plugin_placeholders_match_entries(self): if not has_transform: if identifiers: - self.fail(f'{self.rule_str(rule)} note contains plugin placeholders with no transform entries') + pytest.fail(f'{self.rule_str(rule)} note contains plugin placeholders with no transform entries') else: continue @@ -1259,7 +1259,7 @@ def test_plugin_placeholders_match_entries(self): note_counts[plugin] += 1 err_msg = f'{self.rule_str(rule)} plugin entry count mismatch between transform and note' - self.assertDictEqual(transform_counts, note_counts, err_msg) + assert 
transform_counts == note_counts, err_msg def test_if_plugins_explicitly_defined(self): """Check if plugins are explicitly defined with the pattern in note vs using transform.""" @@ -1268,24 +1268,24 @@ def test_if_plugins_explicitly_defined(self): if note is not None: results = re.search(r'(!{osquery|!{investigate)', note, re.I | re.M) err_msg = f'{self.rule_str(rule)} investigation guide plugin pattern detected! Use Transform' - self.assertIsNone(results, err_msg) + assert results is None, err_msg -class TestAlertSuppression(BaseRuleTest): +class TestAlertSuppression(TestBaseRule): """Test rule alert suppression.""" - @unittest.skipIf(PACKAGE_STACK_VERSION < Version.parse("8.8.0"), - "Test only applicable to 8.6+ stacks for rule alert suppression feature.") + @pytest.mark.skipif(PACKAGE_STACK_VERSION < Version.parse("8.8.0"), + reason="Test only applicable to 8.6+ stacks for rule alert suppression feature.") def test_group_length(self): """Test to ensure the rule alert suppression group_by does not exceed 3 elements.""" for rule in self.production_rules: if rule.contents.data.get('alert_suppression'): group_length = len(rule.contents.data.alert_suppression.group_by) if group_length > 3: - self.fail(f'{self.rule_str(rule)} has rule alert suppression with more than 3 elements.') + pytest.fail(f'{self.rule_str(rule)} has rule alert suppression with more than 3 elements.') - @unittest.skipIf(PACKAGE_STACK_VERSION < Version.parse("8.8.0"), - "Test only applicable to 8.8+ stacks for rule alert suppression feature.") + @pytest.mark.skipif(PACKAGE_STACK_VERSION < Version.parse("8.8.0"), + reason="Test only applicable to 8.8+ stacks for rule alert suppression feature.") def test_group_field_in_schemas(self): """Test to ensure the fields are defined is in ECS/Beats/Integrations schema.""" for rule in self.production_rules: @@ -1312,5 +1312,5 @@ def test_group_field_in_schemas(self): schema.update(**int_schema[data_source]) for fld in group_by_fields: if fld not in 
schema.keys(): - self.fail(f"{self.rule_str(rule)} alert suppression field {fld} not \ + pytest.fail(f"{self.rule_str(rule)} alert suppression field {fld} not \ found in ECS, Beats, or non-ecs schemas") diff --git a/tests/test_gh_workflows.py b/tests/test_gh_workflows.py index 0aee5322d98..1a8ac3a3ec8 100644 --- a/tests/test_gh_workflows.py +++ b/tests/test_gh_workflows.py @@ -5,7 +5,6 @@ """Tests for GitHub workflow functionality.""" -import unittest from pathlib import Path import yaml @@ -17,7 +16,7 @@ GITHUB_WORKFLOWS = GITHUB_FILES / 'workflows' -class TestWorkflows(unittest.TestCase): +class TestWorkflows: """Test GitHub workflow functionality.""" def test_matrix_to_lock_version_defaults(self): @@ -28,4 +27,4 @@ def test_matrix_to_lock_version_defaults(self): matrix_versions = get_stack_versions(drop_patch=True) err_msg = 'lock-versions workflow default does not match current matrix in stack-schema-map' - self.assertListEqual(lock_versions, matrix_versions[:-1], err_msg) + assert lock_versions == matrix_versions[:-1], err_msg diff --git a/tests/test_mappings.py b/tests/test_mappings.py index 77a1b05ba87..20c3e14d0d2 100644 --- a/tests/test_mappings.py +++ b/tests/test_mappings.py @@ -5,16 +5,16 @@ """Test that all rules appropriately match against expected data sets.""" import copy -import unittest import warnings -from . import get_data_files, get_fp_data_files from detection_rules.utils import combine_sources, evaluate, load_etc_dump from rta import get_available_tests -from .base import BaseRuleTest +from tests.conftest import TestBaseRule + +from . 
import get_data_files, get_fp_data_files -class TestMappings(BaseRuleTest): +class TestMappings(TestBaseRule): """Test that all rules appropriately match against expected data sets.""" FP_FILES = get_fp_data_files() @@ -22,7 +22,7 @@ class TestMappings(BaseRuleTest): def evaluate(self, documents, rule, expected, msg): """KQL engine to evaluate.""" filtered = evaluate(rule, documents) - self.assertEqual(expected, len(filtered), msg) + assert expected == len(filtered), msg return filtered def test_true_positives(self): @@ -41,7 +41,7 @@ def test_true_positives(self): rta_file = mapping['rta_name'] # ensure sources is defined and not empty; schema allows it to not be set since 'pending' bypasses - self.assertTrue(sources, 'No sources defined for: {} - {} '.format(rule.id, rule.name)) + assert sources, f'No sources defined for: {rule.id} - {rule.name}' msg = 'Expected TP results did not match for: {} - {}'.format(rule.id, rule.name) data_files = [get_data_files('true_positives', rta_file).get(s) for s in sources] @@ -70,7 +70,7 @@ def test_false_positives(self): self.evaluate(copy.deepcopy(merged_data), rule, 0, msg) -class TestRTAs(unittest.TestCase): +class TestRTAs: """Test that all RTAs have appropriate fields added.""" def test_rtas_with_triggered_rules_have_uuid(self): @@ -79,8 +79,7 @@ def test_rtas_with_triggered_rules_have_uuid(self): rule_keys = ["rule_id", "rule_name"] for rta_test in sorted(get_available_tests().values(), key=lambda r: r['name']): - self.assertIsNotNone(rta_test.get("uuid"), f'RTA {rta_test.get("name")} missing uuid') + assert rta_test.get("uuid") is not None, f'RTA {rta_test.get("name")} missing uuid' for rule_info in rta_test.get("siem"): for rule_key in rule_keys: - self.assertIsNotNone(rule_info.get(rule_key), - f'RTA {rta_test.get("name")} - {rta_test.get("uuid")} missing {rule_key}') + assert rule_info.get(rule_key) is not None, f'RTA {rta_test.get("name")} - {rta_test.get("uuid")} missing {rule_key}' # noqa: E501 diff --git 
a/tests/test_packages.py b/tests/test_packages.py index ca014cbee61..442723395c0 100644 --- a/tests/test_packages.py +++ b/tests/test_packages.py @@ -4,23 +4,25 @@ # 2.0. """Test that the packages are built correctly.""" -import unittest import uuid -from semver import Version + +import pytest from marshmallow import ValidationError +from semver import Version from detection_rules import rule_loader -from detection_rules.schemas.registry_package import (RegistryPackageManifestV1, - RegistryPackageManifestV3) from detection_rules.packaging import PACKAGE_FILE, Package from detection_rules.rule_loader import RuleCollection - -from tests.base import BaseRuleTest +from detection_rules.schemas.registry_package import ( + RegistryPackageManifestV1, + RegistryPackageManifestV3, +) +from tests.conftest import TestBaseRule package_configs = Package.load_configs() -class TestPackages(BaseRuleTest): +class TestPackages(TestBaseRule): """Test package building and saving.""" @staticmethod @@ -72,8 +74,8 @@ def test_rule_versioning(self): post_bump_hashes = [] # test that no rules have versions defined - for rule in rules: - self.assertGreaterEqual(rule.contents.autobumped_version, 1, '{} - {}: version is not being set in package') + for rule in rules.rules: + assert rule.contents.autobumped_version >= 1, f'{self.rule_str(rule)} version is not being set in package' original_hashes.append(rule.contents.sha256()) package = Package(rules, 'test-package') @@ -81,22 +83,21 @@ def test_rule_versioning(self): # test that all rules have versions defined # package.bump_versions(save_changes=False) for rule in package.rules: - self.assertGreaterEqual(rule.contents.autobumped_version, 1, '{} - {}: version is not being set in package') + assert rule.contents.autobumped_version >= 1, f'{self.rule_str(rule)} version is not being set in package' # test that rules validate with version for rule in package.rules: post_bump_hashes.append(rule.contents.sha256()) # test that no hashes changed as a result of the version bumps 
- self.assertListEqual(original_hashes, post_bump_hashes, 'Version bumping modified the hash of a rule') + assert original_hashes == post_bump_hashes, 'Version bumping modified the hash of a rule' -class TestRegistryPackage(unittest.TestCase): +class TestRegistryPackage: """Test the OOB registry package.""" @classmethod - def setUpClass(cls) -> None: - + def setup_class(cls): assert 'registry_data' in package_configs, f'Missing registry_data in {PACKAGE_FILE}' cls.registry_config = package_configs['registry_data'] stack_version = Version.parse(cls.registry_config['conditions']['kibana.version'].strip("^"), @@ -111,5 +112,5 @@ def test_registry_package_config(self): registry_config = self.registry_config.copy() registry_config['version'] += '7.1.1.' - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): RegistryPackageManifestV1.from_dict(registry_config) diff --git a/tests/test_rules_remote.py b/tests/test_rules_remote.py index e422239ce62..f00b732d12e 100644 --- a/tests/test_rules_remote.py +++ b/tests/test_rules_remote.py @@ -3,15 +3,17 @@ # 2.0; you may not use this file except in compliance with the Elastic License # 2.0. 
-import unittest +import pytest -from .base import BaseRuleTest from detection_rules.misc import get_default_config + +from tests.conftest import TestBaseRule + # from detection_rules.remote_validation import RemoteValidator -@unittest.skipIf(get_default_config() is None, 'Skipping remote validation due to missing config') -class TestRemoteRules(BaseRuleTest): +@pytest.mark.skipif(get_default_config() is None, reason='Skipping remote validation due to missing config') +class TestRemoteRules(TestBaseRule): """Test rules against a remote Elastic stack instance.""" # def test_esql_rules(self): diff --git a/tests/test_schemas.py b/tests/test_schemas.py index 2ac7fd84535..b45cf5e44f8 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -5,24 +5,25 @@ """Test stack versioned schemas.""" import copy -import unittest import uuid -from semver import Version import eql +import pytest +from marshmallow import ValidationError +from semver import Version + from detection_rules import utils from detection_rules.misc import load_current_package_version from detection_rules.rule import TOMLRuleContents from detection_rules.schemas import downgrade from detection_rules.version_lock import VersionLockFile -from marshmallow import ValidationError -class TestSchemas(unittest.TestCase): +class TestSchemas: """Test schemas and downgrade functions.""" @classmethod - def setUpClass(cls): + def setup_class(cls): cls.current_version = load_current_package_version() # expected contents for a downgraded rule @@ -102,20 +103,20 @@ def test_query_downgrade_7_x(self): if Version.parse(self.current_version, optional_minor_and_patch=True).major > 7: return - self.assertDictEqual(downgrade(self.v711_kql, "7.11"), self.v711_kql) - self.assertDictEqual(downgrade(self.v711_kql, "7.9"), self.v79_kql) - self.assertDictEqual(downgrade(self.v711_kql, "7.9.2"), self.v79_kql) - self.assertDictEqual(downgrade(self.v711_kql, "7.8.1"), self.v78_kql) - self.assertDictEqual(downgrade(self.v79_kql, 
"7.8"), self.v78_kql) - self.assertDictEqual(downgrade(self.v79_kql, "7.8"), self.v78_kql) + assert downgrade(self.v711_kql, "7.11") == self.v711_kql + assert downgrade(self.v711_kql, "7.9") == self.v79_kql + assert downgrade(self.v711_kql, "7.9.2") == self.v79_kql + assert downgrade(self.v711_kql, "7.8.1") == self.v78_kql + assert downgrade(self.v79_kql, "7.8") == self.v78_kql + assert downgrade(self.v79_kql, "7.8") == self.v78_kql - with self.assertRaises(ValueError): + with pytest.raises(ValueError): downgrade(self.v711_kql, "7.7") - with self.assertRaises(ValueError): + with pytest.raises(ValueError): downgrade(self.v79_kql, "7.7") - with self.assertRaises(ValueError): + with pytest.raises(ValueError): downgrade(self.v78_kql, "7.7", current_version="7.8") def test_versioned_downgrade_7_x(self): @@ -124,16 +125,16 @@ def test_versioned_downgrade_7_x(self): return api_contents = self.v79_kql - self.assertDictEqual(downgrade(api_contents, "7.9"), api_contents) - self.assertDictEqual(downgrade(api_contents, "7.9.2"), api_contents) + assert downgrade(api_contents, "7.9") == api_contents + assert downgrade(api_contents, "7.9.2") == api_contents api_contents78 = api_contents.copy() api_contents78.pop("author") api_contents78.pop("license") - self.assertDictEqual(downgrade(api_contents, "7.8"), api_contents78) + assert downgrade(api_contents, "7.8") == api_contents78 - with self.assertRaises(ValueError): + with pytest.raises(ValueError): downgrade(api_contents, "7.7") def test_threshold_downgrade_7_x(self): @@ -142,27 +143,27 @@ def test_threshold_downgrade_7_x(self): return api_contents = self.v712_threshold_rule - self.assertDictEqual(downgrade(api_contents, '7.13'), api_contents) - self.assertDictEqual(downgrade(api_contents, '7.13.1'), api_contents) + assert downgrade(api_contents, '7.13') == api_contents + assert downgrade(api_contents, '7.13.1') == api_contents exc_msg = 'Cannot downgrade a threshold rule that has multiple threshold fields defined' - with 
self.assertRaisesRegex(ValueError, exc_msg): + with pytest.raises(ValueError, match=exc_msg): downgrade(api_contents, '7.9') v712_threshold_contents_single_field = copy.deepcopy(api_contents) v712_threshold_contents_single_field['threshold']['field'].pop() - with self.assertRaisesRegex(ValueError, "Cannot downgrade a threshold rule that has a defined cardinality"): + with pytest.raises(ValueError, match="Cannot downgrade a threshold rule that has a defined cardinality"): downgrade(v712_threshold_contents_single_field, "7.9") v712_no_cardinality = copy.deepcopy(v712_threshold_contents_single_field) v712_no_cardinality['threshold'].pop('cardinality') - self.assertEqual(downgrade(v712_no_cardinality, "7.9"), self.v79_threshold_contents) + assert downgrade(v712_no_cardinality, "7.9") == self.v79_threshold_contents - with self.assertRaises(ValueError): + with pytest.raises(ValueError): downgrade(v712_no_cardinality, "7.7") - with self.assertRaisesRegex(ValueError, "Unsupported rule type"): + with pytest.raises(ValueError, match="Unsupported rule type"): downgrade(v712_no_cardinality, "7.8") def test_query_downgrade_8_x(self): @@ -216,32 +217,32 @@ def build_rule(query): 'file.target_path.text', 'host.os.full.text', 'host.os.name.text', 'host.user.full_name.text', 'host.user.name.text'] for text_field in example_text_fields: - with self.assertRaises(eql.parser.EqlSchemaError): + with pytest.raises(eql.parser.EqlSchemaError): build_rule(f""" any where {text_field} == "some string field" """) - with self.assertRaises(eql.EqlSyntaxError): + with pytest.raises(eql.EqlSyntaxError): build_rule(""" process where process.name == this!is$not#v@lid """) - with self.assertRaises(eql.EqlSemanticError): + with pytest.raises(eql.EqlSemanticError): build_rule(""" process where process.invalid_field == "hello world" """) - with self.assertRaises(eql.EqlTypeMismatchError): + with pytest.raises(eql.EqlTypeMismatchError): build_rule(""" process where process.pid == "some string field" """) -class 
TestVersionLockSchema(unittest.TestCase): +class TestVersionLockSchema: """Test that the version lock has proper entries.""" @classmethod - def setUpClass(cls): + def setup_class(cls): cls.version_lock_contents = { "33f306e8-417c-411b-965c-c2812d6d3f4d": { "rule_name": "Remote File Download via PowerShell", @@ -274,13 +275,13 @@ def test_version_lock_no_previous(self): def test_version_lock_has_nested_previous(self): """Fail field validation on version lock with nested previous fields""" version_lock_contents = copy.deepcopy(self.version_lock_contents) - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): previous = version_lock_contents["34fde489-94b0-4500-a76f-b8a157cf9269"]["previous"] version_lock_contents["34fde489-94b0-4500-a76f-b8a157cf9269"]["previous"]["previous"] = previous VersionLockFile.from_dict(dict(data=version_lock_contents)) -class TestVersions(unittest.TestCase): +class TestVersions: """Test that schema versioning aligns.""" def test_stack_schema_map(self): @@ -288,4 +289,4 @@ def test_stack_schema_map(self): package_version = Version.parse(load_current_package_version(), optional_minor_and_patch=True) stack_map = utils.load_etc_dump('stack-schema-map.yaml') err_msg = f'There is no entry defined for the current package ({package_version}) in the stack-schema-map' - self.assertIn(package_version, [Version.parse(v) for v in stack_map], err_msg) + assert package_version in [Version.parse(v) for v in stack_map], err_msg diff --git a/tests/test_specific_rules.py b/tests/test_specific_rules.py index f844f89f4aa..48fc68b7dbe 100644 --- a/tests/test_specific_rules.py +++ b/tests/test_specific_rules.py @@ -3,17 +3,18 @@ # 2.0; you may not use this file except in compliance with the Elastic License # 2.0. 
-import unittest from copy import deepcopy from pathlib import Path import eql.ast - +import pytest from semver import Version import kql from detection_rules.integrations import ( - find_latest_compatible_version, load_integrations_manifests, load_integrations_schemas + find_latest_compatible_version, + load_integrations_manifests, + load_integrations_schemas, ) from detection_rules.misc import load_current_package_version from detection_rules.packaging import current_stack_version @@ -22,15 +23,16 @@ from detection_rules.schemas import get_stack_schemas from detection_rules.utils import get_path, load_rule_contents -from .base import BaseRuleTest +from tests.conftest import TestBaseRule + PACKAGE_STACK_VERSION = Version.parse(current_stack_version(), optional_minor_and_patch=True) -class TestEndpointQuery(BaseRuleTest): +class TestEndpointQuery(TestBaseRule): """Test endpoint-specific rules.""" - @unittest.skipIf(PACKAGE_STACK_VERSION < Version.parse("8.3.0"), - "Test only applicable to 8.3+ stacks since query updates are min_stacked at 8.3.0") + @pytest.mark.skipif(PACKAGE_STACK_VERSION < Version.parse("8.3.0"), + reason="Test only applicable to 8.3+ stacks since query updates are min_stacked at 8.3.0") def test_os_and_platform_in_query(self): """Test that all endpoint rules have an os defined and linux includes platform.""" for rule in self.production_rules: @@ -47,19 +49,19 @@ def test_os_and_platform_in_query(self): if 'host.os.type' not in fields: # Exception for Forwarded Events which contain Windows-only fields. 
if rule.path.parent.name == 'windows' and not any(field.startswith('winlog.') for field in fields): - self.assertIn('host.os.type', fields, err_msg) + assert 'host.os.type' in fields, err_msg # going to bypass this for now # if rule.path.parent.name == 'linux': # err_msg = f'{self.rule_str(rule)} missing required field for linux endpoint rule' - # self.assertIn('host.os.platform', fields, err_msg) + # assert 'host.os.platform' in fields, err_msg -class TestNewTerms(BaseRuleTest): +class TestNewTerms(TestBaseRule): """Test new term rules.""" - @unittest.skipIf(PACKAGE_STACK_VERSION < Version.parse("8.4.0"), - "Test only applicable to 8.4+ stacks for new terms feature.") + @pytest.mark.skipif(PACKAGE_STACK_VERSION < Version.parse("8.4.0"), + reason="Test only applicable to 8.4+ stacks for new terms feature.") def test_history_window_start(self): """Test new terms history window start field.""" @@ -72,8 +74,8 @@ def test_history_window_start(self): assert rule.contents.data.new_terms.history_window_start[0].field == "history_window_start", \ f"{rule.contents.data.new_terms.history_window_start} should be 'history_window_start'" - @unittest.skipIf(PACKAGE_STACK_VERSION < Version.parse("8.4.0"), - "Test only applicable to 8.4+ stacks for new terms feature.") + @pytest.mark.skipif(PACKAGE_STACK_VERSION < Version.parse("8.4.0"), + reason="Test only applicable to 8.4+ stacks for new terms feature.") def test_new_terms_field_exists(self): # validate new terms and history window start fields are correct for rule in self.production_rules: @@ -81,8 +83,8 @@ def test_new_terms_field_exists(self): assert rule.contents.data.new_terms.field == "new_terms_fields", \ f"{rule.contents.data.new_terms.field} should be 'new_terms_fields' for new_terms rule type" - @unittest.skipIf(PACKAGE_STACK_VERSION < Version.parse("8.4.0"), - "Test only applicable to 8.4+ stacks for new terms feature.") + @pytest.mark.skipif(PACKAGE_STACK_VERSION < Version.parse("8.4.0"), + reason="Test only 
applicable to 8.4+ stacks for new terms feature.") def test_new_terms_fields(self): """Test new terms fields are schema validated.""" # ecs validation @@ -122,8 +124,8 @@ def test_new_terms_fields(self): assert new_terms_field in schema.keys(), \ f"{new_terms_field} not found in ECS, Beats, or non-ecs schemas" - @unittest.skipIf(PACKAGE_STACK_VERSION < Version.parse("8.4.0"), - "Test only applicable to 8.4+ stacks for new terms feature.") + @pytest.mark.skipif(PACKAGE_STACK_VERSION < Version.parse("8.4.0"), + reason="Test only applicable to 8.4+ stacks for new terms feature.") def test_new_terms_max_limit(self): """Test new terms max limit.""" # validates length of new_terms to stack version - https://github.com/elastic/kibana/issues/142862 @@ -141,8 +143,8 @@ def test_new_terms_max_limit(self): assert len(rule.contents.data.new_terms.value) == 1, \ f"new terms have a max limit of 1 for stack versions below {feature_min_stack_extended_fields}" - @unittest.skipIf(PACKAGE_STACK_VERSION < Version.parse("8.6.0"), - "Test only applicable to 8.4+ stacks for new terms feature.") + @pytest.mark.skipif(PACKAGE_STACK_VERSION < Version.parse("8.6.0"), + reason="Test only applicable to 8.4+ stacks for new terms feature.") def test_new_terms_fields_unique(self): """Test new terms fields are unique.""" # validate fields are unique @@ -152,7 +154,7 @@ def test_new_terms_fields_unique(self): f"new terms fields values are not unique - {rule.contents.data.new_terms.value}" -class TestESQLRules(BaseRuleTest): +class TestESQLRules(TestBaseRule): """Test ESQL Rules.""" def run_esql_test(self, esql_query, expectation, message): @@ -182,5 +184,6 @@ def test_esql_queries(self): # ('from .ds-logs-endpoint.events.process-default-* | where process.name like "Microsoft*"', # does_not_raise(), None), # ] + # TODO: Refactor to use pytest parameterization when we want to enable these tests # for esql_query, expectation, message in test_cases: # self.run_esql_test(esql_query, expectation, 
message) diff --git a/tests/test_toml_formatter.py b/tests/test_toml_formatter.py index 4787354fa83..aa984878207 100644 --- a/tests/test_toml_formatter.py +++ b/tests/test_toml_formatter.py @@ -5,20 +5,20 @@ import copy import json -import os -import unittest +from pathlib import Path import pytoml from detection_rules.rule_formatter import nested_normalize, toml_write from detection_rules.utils import get_etc_path -tmp_file = 'tmp_file.toml' +tmp_file = Path('tmp_file.toml') -class TestRuleTomlFormatter(unittest.TestCase): +class TestRuleTomlFormatter: """Test that the custom toml formatting is not compromising the integrity of the data.""" - with open(get_etc_path("test_toml.json"), "r") as f: + + with open(get_etc_path("test_toml.json"), "r", encoding='utf-8') as f: test_data = json.load(f) def compare_formatted(self, data, callback=None, kwargs=None): @@ -26,7 +26,7 @@ def compare_formatted(self, data, callback=None, kwargs=None): try: toml_write(copy.deepcopy(data), tmp_file) - with open(tmp_file, 'r') as f: + with tmp_file.open('r', encoding='utf-8') as f: formatted_contents = pytoml.load(f) # callbacks such as nested normalize leave in line breaks, so this must be manually done @@ -46,24 +46,24 @@ def compare_formatted(self, data, callback=None, kwargs=None): formatted_contents['rule']['query'] = query.strip() formatted = json.dumps(formatted_contents, sort_keys=True) - self.assertEqual(original, formatted, 'Formatting may be modifying contents') + assert original == formatted, 'Formatting may be modifying contents' finally: - os.remove(tmp_file) + tmp_file.unlink(missing_ok=True) def compare_test_data(self, test_dicts, callback=None): """Compare test data against expected.""" for data in test_dicts: self.compare_formatted(data, callback=callback) - def test_normalization(self): - """Test that normalization does not change the rule contents.""" - self.compare_test_data([nested_normalize(self.test_data[0])], callback=nested_normalize) - def 
test_formatter_rule(self): """Test that formatter and encoder do not change the rule contents.""" self.compare_test_data([self.test_data[0]]) + def test_normalization(self): + """Test that normalization does not change the rule contents.""" + self.compare_test_data([nested_normalize(self.test_data[0])], callback=nested_normalize) + def test_formatter_deep(self): """Test that the data remains unchanged from formatting.""" self.compare_test_data(self.test_data[1:]) diff --git a/tests/test_transform_fields.py b/tests/test_transform_fields.py index a2b69cbcc6d..597ea669714 100644 --- a/tests/test_transform_fields.py +++ b/tests/test_transform_fields.py @@ -5,7 +5,6 @@ """Test fields in TOML [transform].""" import copy -import unittest from pathlib import Path from textwrap import dedent @@ -18,11 +17,11 @@ RULES_DIR = Path(__file__).parent.parent / 'rules' -class TestGuideMarkdownPlugins(unittest.TestCase): +class TestGuideMarkdownPlugins: """Test the Markdown plugin features within the investigation guide.""" @classmethod - def setUpClass(cls) -> None: + def setup_class(cls) -> None: cls.osquery_patterns = [ """!{osquery{"label":"Osquery - Retrieve DNS Cache","query":"SELECT * FROM dns_cache"}}""", """!{osquery{"label":"Osquery - Retrieve All Services","query":"SELECT description, display_name, name, path, pid, service_type, start_type, status, user_account FROM services"}}""", # noqa: E501 @@ -95,7 +94,7 @@ def test_transform_guide_markdown_plugins(self) -> None: rendered_note = new_rule.contents.to_api_format()['note'] for pattern in self.osquery_patterns: - self.assertIn(pattern, rendered_note) + assert pattern in rendered_note def test_plugin_conversion(self): """Test the conversion function to ensure parsing is correct.""" @@ -111,4 +110,4 @@ def test_plugin_conversion(self): new_rule = TOMLRule(path=sample_rule.path, contents=new_rule_contents) rendered_note = new_rule.contents.to_api_format()['note'] - self.assertIn(pattern, rendered_note) + assert pattern in 
rendered_note diff --git a/tests/test_utils.py b/tests/test_utils.py index 1d86e01db7b..3801b4f0d00 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -6,14 +6,13 @@ """Test util time functions.""" import random import time -import unittest from detection_rules.utils import normalize_timing_and_sort, cached from detection_rules.eswrap import RtaEvents from detection_rules.ecs import get_kql_schema -class TestTimeUtils(unittest.TestCase): +class TestTimeUtils: """Test util time functions.""" @staticmethod @@ -41,10 +40,11 @@ def _get_data(func): return {fmt: _get_data(func) for fmt, func in date_formats.items()} - def assert_sort(self, normalized_events, date_format): + @staticmethod + def assert_sort(normalized_events, date_format): """Assert normalize and sort.""" order = [e['id'] for e in normalized_events] - self.assertListEqual([1, 2, 3, 4, 5, 6], order, 'Sorting failed for date_format: {}'.format(date_format)) + assert order == [1, 2, 3, 4, 5, 6], f'Sorting failed for date_format: {date_format}' def test_time_normalize(self): """Test normalize_timing_from_date_format.""" @@ -63,8 +63,8 @@ def test_event_class_normalization(self): def test_schema_multifields(self): """Tests that schemas are loading multifields correctly.""" schema = get_kql_schema(version="1.4.0") - self.assertEqual(schema.get("process.name"), "keyword") - self.assertEqual(schema.get("process.name.text"), "text") + assert schema.get("process.name") == "keyword" + assert schema.get("process.name.text") == "text" def test_caching(self): """Test that caching is working.""" @@ -77,27 +77,27 @@ def increment(*args, **kwargs): counter += 1 return counter - self.assertEqual(increment(), 1) - self.assertEqual(increment(), 1) - self.assertEqual(increment(), 1) + assert increment() == 1 + assert increment() == 1 + assert increment() == 1 - self.assertEqual(increment(["hello", "world"]), 2) - self.assertEqual(increment(["hello", "world"]), 2) - self.assertEqual(increment(["hello", "world"]), 2) 
+ assert increment(["hello", "world"]) == 2 + assert increment(["hello", "world"]) == 2 + assert increment(["hello", "world"]) == 2 - self.assertEqual(increment(), 1) - self.assertEqual(increment(["hello", "world"]), 2) + assert increment() == 1 + assert increment(["hello", "world"]) == 2 - self.assertEqual(increment({"hello": [("world", )]}), 3) - self.assertEqual(increment({"hello": [("world", )]}), 3) + assert increment({"hello": [("world", )]}) == 3 + assert increment({"hello": [("world", )]}) == 3 - self.assertEqual(increment(), 1) - self.assertEqual(increment(["hello", "world"]), 2) - self.assertEqual(increment({"hello": [("world", )]}), 3) + assert increment() == 1 + assert increment(["hello", "world"]) == 2 + assert increment({"hello": [("world", )]}) == 3 increment.clear() - self.assertEqual(increment({"hello": [("world", )]}), 4) - self.assertEqual(increment(["hello", "world"]), 5) - self.assertEqual(increment(), 6) - self.assertEqual(increment(None), 7) - self.assertEqual(increment(1), 8) + assert increment({"hello": [("world", )]}) == 4 + assert increment(["hello", "world"]) == 5 + assert increment() == 6 + assert increment(None) == 7 + assert increment(1) == 8 diff --git a/tests/test_version_locking.py b/tests/test_version_locking.py index 0e2ffa77745..d77609c5adb 100644 --- a/tests/test_version_locking.py +++ b/tests/test_version_locking.py @@ -5,15 +5,14 @@ """Test version locking of rules.""" -import unittest - +import pytest from semver import Version from detection_rules.schemas import get_min_supported_stack_version from detection_rules.version_lock import default_version_lock -class TestVersionLock(unittest.TestCase): +class TestVersionLock: """Test version locking.""" def test_previous_entries_gte_current_min_stack(self): @@ -31,6 +30,6 @@ def test_previous_entries_gte_current_min_stack(self): # stack-schema-map if errors: err_str = '\n'.join(f'{k}: {", ".join(v)}' for k, v in errors.items()) - self.fail(f'The following version.lock entries 
have previous locked versions which are lower than the ' - f'currently supported min_stack ({min_version}). To address this, run the ' - f'`dev trim-version-lock {min_version}` command.\n\n{err_str}') + pytest.fail(f'The following version.lock entries have previous locked versions which are lower than the ' + f'currently supported min_stack ({min_version}). To address this, run the ' + f'`dev trim-version-lock {min_version}` command.\n\n{err_str}')