Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: [POC] Refactor: port unittest to pytest #3361

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 31 additions & 32 deletions tests/base.py → tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,19 @@

"""Shared resources for tests."""

import unittest

from functools import lru_cache
from typing import Union

from detection_rules.rule import TOMLRule
from detection_rules.rule_loader import DeprecatedCollection, DeprecatedRule, RuleCollection, production_filter
import pytest

from detection_rules.rule import TOMLRule
from detection_rules.rule_loader import (
DeprecatedCollection,
DeprecatedRule,
RuleCollection,
production_filter,
)

RULE_LOADER_FAIL = False
RULE_LOADER_FAIL_MSG = None
Expand All @@ -28,48 +34,41 @@ def default_bbr() -> RuleCollection:
return RuleCollection.default_bbr()


class BaseRuleTest(unittest.TestCase):
@pytest.fixture(scope="class")
def rule_data(request):
global RULE_LOADER_FAIL, RULE_LOADER_FAIL_MSG
if not RULE_LOADER_FAIL:
try:
rc = default_rules()
rc_bbr = default_bbr()
request.cls.all_rules = rc.rules
request.cls.rule_lookup = rc.id_map
request.cls.production_rules = rc.filter(production_filter)
request.cls.bbr = rc_bbr.rules
request.cls.deprecated_rules: DeprecatedCollection = rc.deprecated
except Exception as e:
RULE_LOADER_FAIL = True
RULE_LOADER_FAIL_MSG = str(e)


@pytest.mark.usefixtures("rule_data")
class TestBaseRule:
"""Base class for shared test cases which need to load rules"""

RULE_LOADER_FAIL = False
RULE_LOADER_FAIL_MSG = None
RULE_LOADER_FAIL_RAISED = False

@classmethod
def setUpClass(cls):
global RULE_LOADER_FAIL, RULE_LOADER_FAIL_MSG

# too noisy; refactor
# os.environ["DR_NOTIFY_INTEGRATION_UPDATE_AVAILABLE"] = "1"

if not RULE_LOADER_FAIL:
try:
rc = default_rules()
rc_bbr = default_bbr()
cls.all_rules = rc.rules
cls.rule_lookup = rc.id_map
cls.production_rules = rc.filter(production_filter)
cls.bbr = rc_bbr.rules
cls.deprecated_rules: DeprecatedCollection = rc.deprecated
except Exception as e:
RULE_LOADER_FAIL = True
RULE_LOADER_FAIL_MSG = str(e)

@staticmethod
def rule_str(rule: Union[DeprecatedRule, TOMLRule], trailer=' ->') -> str:
return f'{rule.id} - {rule.name}{trailer or ""}'

def setUp(self) -> None:
def setup_method(self, method):
global RULE_LOADER_FAIL, RULE_LOADER_FAIL_MSG, RULE_LOADER_FAIL_RAISED

if RULE_LOADER_FAIL:
# limit the loader failure to just one run
# raise a dedicated test failure for the loader
if not RULE_LOADER_FAIL_RAISED:
RULE_LOADER_FAIL_RAISED = True
with self.subTest('Test that the rule loader loaded with no validation or other failures.'):
self.fail(f'Rule loader failure: \n{RULE_LOADER_FAIL_MSG}')

self.skipTest('Rule loader failure')
else:
super().setUp()
pytest.fail(f'Rule loader failure: \n{RULE_LOADER_FAIL_MSG}')
pytest.skip('Rule loader failure')
7 changes: 3 additions & 4 deletions tests/kuery/test_dsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,14 @@
# 2.0; you may not use this file except in compliance with the Elastic License
# 2.0.

import unittest
import kql


class TestKQLtoDSL(unittest.TestCase):
class TestKQLtoDSL:
def validate(self, kql_source, dsl, **kwargs):
actual_dsl = kql.to_dsl(kql_source, **kwargs)
self.assertListEqual(list(actual_dsl), ["bool"])
self.assertDictEqual(actual_dsl["bool"], dsl)
assert list(actual_dsl) == ["bool"]
assert actual_dsl["bool"] == dsl

def test_field_match(self):
def match(**kv):
Expand Down
10 changes: 6 additions & 4 deletions tests/kuery/test_eql2kql.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@
# 2.0.

import eql
import unittest
import pytest

import kql


class TestEql2Kql(unittest.TestCase):
class TestEql2Kql:

def validate(self, kql_source, eql_source):
self.assertEqual(kql_source, str(kql.from_eql(eql_source)))
assert kql_source == str(kql.from_eql(eql_source))

def test_field_equals(self):
self.validate("field:value", "field == 'value'")
Expand Down Expand Up @@ -58,6 +59,7 @@ def test_wildcard_field(self):
self.validate('field:value-*', 'field : "value-*"')
self.validate('field:value-?', 'field : "value-?"')

with eql.parser.elasticsearch_validate_optional_fields, self.assertRaises(AssertionError):
# pytest.raises is used to handle exceptions in pytest
with eql.parser.elasticsearch_validate_optional_fields, pytest.raises(AssertionError):
self.validate('field:"value-*"', 'field == "value-*"')
self.validate('field:"value-?"', 'field == "value-?"')
122 changes: 55 additions & 67 deletions tests/kuery/test_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,29 +3,19 @@
# 2.0; you may not use this file except in compliance with the Elastic License
# 2.0.

import unittest

import kql


class EvaluatorTests(unittest.TestCase):

class TestEvaluator:
document = {
"number": 1,
"boolean": True,
"ip": "192.168.16.3",
"string": "hello world",

"string_list": ["hello world", "example"],
"number_list": [1, 2, 3],
"boolean_list": [True, False],
"structured": [
{
"a": [
{"b": 1}
]
}
],
"structured": [{"a": [{"b": 1}]}],
}

def evaluate(self, source_text, document=None):
Expand All @@ -36,89 +26,87 @@ def evaluate(self, source_text, document=None):
return evaluator(document)

def test_single_value(self):
self.assertTrue(self.evaluate('number:1'))
self.assertTrue(self.evaluate('number:"1"'))
self.assertTrue(self.evaluate('boolean:true'))
self.assertTrue(self.evaluate('string:"hello world"'))
assert self.evaluate('number:1')
assert self.evaluate('number:"1"')
assert self.evaluate('boolean:true')
assert self.evaluate('string:"hello world"')

self.assertFalse(self.evaluate('number:0'))
self.assertFalse(self.evaluate('boolean:false'))
self.assertFalse(self.evaluate('string:"missing"'))
assert not self.evaluate('number:0')
assert not self.evaluate('boolean:false')
assert not self.evaluate('string:"missing"')

def test_list_value(self):
self.assertTrue(self.evaluate('number_list:1'))
self.assertTrue(self.evaluate('number_list:2'))
self.assertTrue(self.evaluate('number_list:3'))
assert self.evaluate('number_list:1')
assert self.evaluate('number_list:2')
assert self.evaluate('number_list:3')

self.assertTrue(self.evaluate('boolean_list:true'))
self.assertTrue(self.evaluate('boolean_list:false'))
assert self.evaluate('boolean_list:true')
assert self.evaluate('boolean_list:false')

self.assertTrue(self.evaluate('string_list:"hello world"'))
self.assertTrue(self.evaluate('string_list:example'))
assert self.evaluate('string_list:"hello world"')
assert self.evaluate('string_list:example')

self.assertFalse(self.evaluate('number_list:4'))
self.assertFalse(self.evaluate('string_list:"missing"'))
assert not self.evaluate('number_list:4')
assert not self.evaluate('string_list:"missing"')

def test_and_values(self):
self.assertTrue(self.evaluate('number_list:(1 and 2)'))
self.assertTrue(self.evaluate('boolean_list:(false and true)'))
self.assertFalse(self.evaluate('string:("missing" and "hello world")'))
assert self.evaluate('number_list:(1 and 2)')
assert self.evaluate('boolean_list:(false and true)')
assert not self.evaluate('string:("missing" and "hello world")')

self.assertFalse(self.evaluate('number:(0 and 1)'))
self.assertFalse(self.evaluate('boolean:(false and true)'))
assert not self.evaluate('number:(0 and 1)')
assert not self.evaluate('boolean:(false and true)')

def test_not_value(self):
self.assertTrue(self.evaluate('number_list:1'))
self.assertFalse(self.evaluate('not number_list:1'))
self.assertFalse(self.evaluate('number_list:(not 1)'))
assert self.evaluate('number_list:1')
assert not self.evaluate('not number_list:1')
assert not self.evaluate('number_list:(not 1)')

def test_or_values(self):
self.assertTrue(self.evaluate('number:(0 or 1)'))
self.assertTrue(self.evaluate('number:(1 or 2)'))
self.assertTrue(self.evaluate('boolean:(false or true)'))
self.assertTrue(self.evaluate('string:("missing" or "hello world")'))
assert self.evaluate('number:(0 or 1)')
assert self.evaluate('number:(1 or 2)')
assert self.evaluate('boolean:(false or true)')
assert self.evaluate('string:("missing" or "hello world")')

self.assertFalse(self.evaluate('number:(0 or 3)'))
assert not self.evaluate('number:(0 or 3)')

def test_and_expr(self):
self.assertTrue(self.evaluate('number:1 and boolean:true'))

self.assertFalse(self.evaluate('number:1 and boolean:false'))
assert self.evaluate('number:1 and boolean:true')
assert not self.evaluate('number:1 and boolean:false')

def test_or_expr(self):
self.assertTrue(self.evaluate('number:1 or boolean:false'))
self.assertFalse(self.evaluate('number:0 or boolean:false'))
assert self.evaluate('number:1 or boolean:false')
assert not self.evaluate('number:0 or boolean:false')

def test_range(self):
self.assertTrue(self.evaluate('number < 2'))
self.assertFalse(self.evaluate('number > 2'))
assert self.evaluate('number < 2')
assert not self.evaluate('number > 2')

def test_cidr_match(self):
self.assertTrue(self.evaluate('ip:192.168.0.0/16'))

self.assertFalse(self.evaluate('ip:10.0.0.0/8'))
assert self.evaluate('ip:192.168.0.0/16')
assert not self.evaluate('ip:10.0.0.0/8')

def test_quoted_wildcard(self):
self.assertFalse(self.evaluate("string:'*'"))
self.assertFalse(self.evaluate("string:'?'"))
assert not self.evaluate("string:'*'")
assert not self.evaluate("string:'?'")

def test_wildcard(self):
self.assertTrue(self.evaluate('string:hello*'))
self.assertTrue(self.evaluate('string:*world'))
self.assertFalse(self.evaluate('string:foobar*'))
assert self.evaluate('string:hello*')
assert self.evaluate('string:*world')
assert not self.evaluate('string:foobar*')

def test_field_exists(self):
self.assertTrue(self.evaluate('number:*'))
self.assertTrue(self.evaluate('boolean:*'))
self.assertTrue(self.evaluate('ip:*'))
self.assertTrue(self.evaluate('string:*'))
self.assertTrue(self.evaluate('string_list:*'))
self.assertTrue(self.evaluate('number_list:*'))
self.assertTrue(self.evaluate('boolean_list:*'))
assert self.evaluate('number:*')
assert self.evaluate('boolean:*')
assert self.evaluate('ip:*')
assert self.evaluate('string:*')
assert self.evaluate('string_list:*')
assert self.evaluate('number_list:*')
assert self.evaluate('boolean_list:*')

self.assertFalse(self.evaluate('a:*'))
assert not self.evaluate('a:*')

def test_flattening(self):
self.assertTrue(self.evaluate("structured.a.b:*"))
self.assertTrue(self.evaluate("structured.a.b:1"))
self.assertFalse(self.evaluate("structured.a.b:2"))
assert self.evaluate("structured.a.b:*")
assert self.evaluate("structured.a.b:1")
assert not self.evaluate("structured.a.b:2")
24 changes: 12 additions & 12 deletions tests/kuery/test_kql2eql.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,16 @@
# 2.0; you may not use this file except in compliance with the Elastic License
# 2.0.

import unittest
import eql
import pytest

import kql


class TestKql2Eql(unittest.TestCase):
class TestKql2Eql:

def validate(self, kql_source, eql_source, schema=None):
self.assertEqual(kql.to_eql(kql_source, schema=schema), eql.parse_expression(eql_source))
assert kql.to_eql(kql_source, schema=schema) == eql.parse_expression(eql_source)

def test_field_equals(self):
self.validate("field:value", "field == 'value'")
Expand All @@ -36,7 +36,7 @@ def test_and_query(self):
self.validate("field:value and field2:value2", "field == 'value' and field2 == 'value2'")

def test_nested_query(self):
with self.assertRaisesRegex(kql.KqlParseError, "Unable to convert nested query to EQL"):
with pytest.raises(kql.KqlParseError, match="Unable to convert nested query to EQL"):
kql.to_eql("field:{outer:1 and middle:{inner:2}}")

def test_not_query(self):
Expand All @@ -55,7 +55,7 @@ def test_list_of_values(self):

def test_lone_value(self):
for value in ["1", "-1.4", "true", "\"string test\""]:
with self.assertRaisesRegex(kql.KqlParseError, "Value not tied to field"):
with pytest.raises(kql.KqlParseError, match="Value not tied to field"):
kql.to_eql(value)

def test_schema(self):
Expand All @@ -78,22 +78,22 @@ def test_schema(self):
self.validate("dest:192.168.0.0/16", "cidrMatch(dest, '192.168.0.0/16')", schema=schema)
self.validate("dest:\"192.168.0.0/16\"", "cidrMatch(dest, '192.168.0.0/16')", schema=schema)

with self.assertRaises(eql.EqlSemanticError):
with pytest.raises(eql.EqlSemanticError):
self.validate("top.text : \"hello\"", "top.text == 'hello'", schema=schema)

with self.assertRaises(eql.EqlSemanticError):
with pytest.raises(eql.EqlSemanticError):
self.validate("top.text : 1 ", "top.text == '1'", schema=schema)

with self.assertRaisesRegex(kql.KqlParseError, r"Value doesn't match top.middle's type: nested"):
with pytest.raises(kql.KqlParseError, match=r"Value doesn't match top.middle's type: nested"):
kql.to_eql("top.middle : 1", schema=schema)

with self.assertRaisesRegex(kql.KqlParseError, "Unable to convert nested query to EQL"):
with pytest.raises(kql.KqlParseError, match="Unable to convert nested query to EQL"):
kql.to_eql("top:{keyword : 1}", schema=schema)

with self.assertRaisesRegex(kql.KqlParseError, "Unable to convert nested query to EQL"):
with pytest.raises(kql.KqlParseError, match="Unable to convert nested query to EQL"):
kql.to_eql("top:{middle:{bool: true}}", schema=schema)

invalid_ips = ["192.168.0.256", "192.168.0.256/33", "1", "\"1\""]
for ip in invalid_ips:
with self.assertRaisesRegex(kql.KqlParseError, r"Value doesn't match dest's type: ip"):
kql.to_eql("dest:{ip}".format(ip=ip), schema=schema)
with pytest.raises(kql.KqlParseError, match=r"Value doesn't match dest's type: ip"):
kql.to_eql(f"dest:{ip}", schema=schema)
Loading