Skip to content

Commit af93358

Browse files
committed
Merge branch 'extending-tests' into 'master'
Extend Test Cases Phase 1 See merge request JDOsborne1/db_to_d2!2
2 parents f2b5381 + 1069b65 commit af93358

6 files changed

+367
-18
lines changed

cmd/configuration.go

+34-5
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"core"
55
"encoding/json"
66
"fmt"
7+
"io"
78
"os"
89
"virtual"
910
)
@@ -12,30 +13,58 @@ import (
1213
// These are currently set by a file specified by the VIRTUAL_LINKS_PATH environment variable.
1314
// The file should be a json array of virtual links. See the virtual package for more information.
1415
func get_virtual_links() []virtual.VirtualLink {
15-
links := []virtual.VirtualLink{}
16-
links_json, err := os.ReadFile(os.Getenv("VIRTUAL_LINKS_PATH"))
16+
links_reader, err := os.Open(os.Getenv("VIRTUAL_LINKS_PATH"))
1717
if err != nil {
1818
//TODO: Log error, or bubble up instead of printing to console
19+
fmt.Println("Failed to open virtual links file")
20+
}
21+
links, err := read_virtual_links(links_reader)
22+
if err != nil {
1923
fmt.Println("Failed to read virtual links file")
2024
}
21-
json.Unmarshal(links_json, &links)
25+
2226
return links
2327
}
2428

29+
func read_virtual_links(_input io.Reader) ([]virtual.VirtualLink, error) {
30+
links := []virtual.VirtualLink{}
31+
links_json, err := io.ReadAll(_input)
32+
if err != nil {
33+
return links, err
34+
}
35+
err = json.Unmarshal(links_json, &links)
36+
return links, err
37+
}
38+
2539
// get_table_groups returns the table groups for the program.
2640
// These are currently set by a file specified by the TABLE_GROUPS_PATH environment variable.
2741
// The file should be a json array of table groups. See the core package for more information.
2842
func get_table_groups() []core.TableGroup {
2943
table_groups := []core.TableGroup{}
30-
table_groups_json, err := os.ReadFile(os.Getenv("TABLE_GROUPS_PATH"))
44+
table_groups_reader, err := os.Open(os.Getenv("TABLE_GROUPS_PATH"))
3145
if err != nil {
3246
//TODO: Log error, or bubble up instead of printing to console
3347
fmt.Println("Failed to read table groups file")
3448
}
35-
json.Unmarshal(table_groups_json, &table_groups)
49+
50+
table_groups, err = read_table_groups(table_groups_reader)
51+
52+
if err != nil {
53+
fmt.Println("Failed to read table groups file")
54+
}
3655
return table_groups
3756
}
3857

58+
func read_table_groups(_input io.Reader) ([]core.TableGroup, error) {
59+
table_groups := []core.TableGroup{}
60+
table_groups_json, err := io.ReadAll(_input)
61+
if err != nil {
62+
return table_groups, err
63+
}
64+
err = json.Unmarshal(table_groups_json, &table_groups)
65+
return table_groups, err
66+
}
67+
3968
// get_designated_user returns the designated user for the program. Set by the DESIGNATED_USER environment variable.
4069
// This is used to restrict the schema to the tables that the designated user has access to. See the mysql package for more information.
4170
func get_designated_user() string {

pkg/core/d2_generator_test.go

+74
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
package core
2+
3+
// The d2 tests can be done by using strings.TrimSpace() to remove the whitespace from the output of the d2 generator.
4+
// This should remove the influence of formatting differences between the two strings. This is a simple way to test that that two are equal.
5+
6+
import (
7+
"strings"
8+
"testing"
9+
10+
)
11+
12+
func simple_formatter(_input string) string {
13+
_input = strings.Replace(_input, "\t", "", -1)
14+
_input = strings.Replace(_input, "\n\n", "\n", -1)
15+
_input = strings.TrimSpace(_input)
16+
return _input
17+
}
18+
19+
func Test_d2_generator(t *testing.T) {
20+
schema := Schema{
21+
Tables: []Table{
22+
{
23+
Name: "table1",
24+
Columns: []Column{
25+
{
26+
Name: "col1",
27+
Type: "int",
28+
},
29+
{
30+
Name: "col2",
31+
Type: "varchar",
32+
},
33+
},
34+
},
35+
{
36+
Name: "table2",
37+
Columns: []Column{
38+
{
39+
Name: "col1",
40+
Type: "int",
41+
},
42+
{
43+
Name: "col2",
44+
Type: "varchar",
45+
},
46+
},
47+
},
48+
},
49+
}
50+
51+
expected := `
52+
table1: {
53+
shape: sql_table
54+
col1: int
55+
col2: varchar
56+
}
57+
58+
table2: {
59+
shape: sql_table
60+
col1: int
61+
62+
col2: varchar
63+
}
64+
`
65+
expected = simple_formatter(expected)
66+
67+
actual := simple_formatter(Schema_to_d2(schema, []TableGroup{}))
68+
69+
if actual != expected {
70+
t.Errorf("Expected \n%s\n, got \n%s", expected, actual)
71+
}
72+
73+
}
74+

pkg/core/go.sum

Whitespace-only changes.

pkg/core/restrictions_test.go

+159-13
Original file line numberDiff line numberDiff line change
@@ -5,29 +5,175 @@ import (
55
"testing"
66
)
77

8-
// TestRestrictionsIdentity is a basic test that checks that the identity function returns the same value as the input.
9-
func TestRestrictionsIdentity(t *testing.T) {
10-
//Test data
11-
input := Schema{
12-
Tables: []Table{
13-
{
14-
Name: "test",
15-
Columns: []Column{
16-
{
17-
Name: "test",
8+
var example_schema = Schema{
9+
Tables: []Table{
10+
{
11+
Name: "users",
12+
Columns: []Column{
13+
{
14+
Name: "id",
15+
Type: "int",
16+
Nullable: false,
17+
Key: "PRIMARY KEY",
18+
Extra: "AUTO_INCREMENT",
19+
},
20+
{
21+
Name: "username",
22+
Type: "varchar(255)",
23+
Nullable: false,
24+
Key: "UNIQUE KEY",
25+
Extra: "",
26+
},
27+
{
28+
Name: "email",
29+
Type: "varchar(255)",
30+
Nullable: false,
31+
Key: "UNIQUE KEY",
32+
Extra: "",
33+
},
34+
},
35+
},
36+
{
37+
Name: "posts",
38+
Columns: []Column{
39+
{
40+
Name: "id",
41+
Type: "int",
42+
Nullable: false,
43+
Key: "PRIMARY KEY",
44+
Extra: "AUTO_INCREMENT",
45+
},
46+
{
47+
Name: "title",
48+
Type: "varchar(255)",
49+
Nullable: false,
50+
Key: "",
51+
Extra: "",
52+
},
53+
{
54+
Name: "content",
55+
Type: "text",
56+
Nullable: true,
57+
Key: "",
58+
Extra: "",
59+
},
60+
{
61+
Name: "user_id",
62+
Type: "int",
63+
Nullable: false,
64+
Key: "FOREIGN KEY",
65+
Extra: "",
66+
Reference: &Reference{
67+
Table: "users",
68+
Column: "id",
69+
OnDelete: "CASCADE",
70+
OnUpdate: "CASCADE",
1871
},
1972
},
2073
},
2174
},
75+
{
76+
Name: "meta",
77+
Columns: []Column{
78+
{
79+
Name: "meta_key",
80+
Type: "varchar(255)",
81+
Nullable: false,
82+
Key: "",
83+
Extra: "",
84+
},
85+
{
86+
Name: "meta_value",
87+
Type: "varchar(255)",
88+
Nullable: false,
89+
Key: "",
90+
Extra: "",
91+
},
92+
},
93+
},
94+
},
95+
}
96+
97+
func get_table_names(_input Schema) []string {
98+
var output []string
99+
for _, table := range _input.Tables {
100+
output = append(output, table.Name)
101+
}
102+
return output
103+
}
104+
105+
func get_column_names(_input Table) []string {
106+
var output []string
107+
for _, column := range _input.Columns {
108+
output = append(output, column.Name)
22109
}
23-
expected := input
110+
return output
111+
}
24112

113+
// TestRestrictionsIdentity is a basic test that checks that the identity function returns the same value as the input.
114+
func TestRestrictionsIdentity(t *testing.T) {
25115
//Execute test
26-
actual := Restrict(input, Standard)
116+
actual := Restrict(example_schema, Standard)
27117

28118
//Compare actual to expected
29-
if !reflect.DeepEqual(actual, expected) {
119+
if !reflect.DeepEqual(actual, example_schema) {
30120
t.Log("Identity function failed to return the same value as the input.")
31121
t.Fail()
32122
}
33123
}
124+
125+
func TestRestrictionsMinimalist(t *testing.T) {
126+
actual := Restrict(example_schema, Minimalist)
127+
names := get_table_names(actual)
128+
if !(equal_set(names, []string{"users", "posts"})) {
129+
t.Log("Minimalist cleared inappropriate tables.")
130+
t.Fail()
131+
}
132+
133+
if !(equal_set(get_column_names(actual.Tables[0]), []string{"id", "username", "email"})) {
134+
t.Log("Minimalist cleared columns in a table where all the columns are keys.")
135+
t.Fail()
136+
}
137+
138+
if !(equal_set(get_column_names(actual.Tables[1]), []string{"id", "user_id"})) {
139+
t.Log("Minimalist didn't clear the right columns in a table which has partial keys.")
140+
t.Fail()
141+
}
142+
143+
}
144+
145+
// Simple example of a restriction function. This is intended to simulate the scenario where a user has been
146+
// restricted to not be able to see the PII (Personally Identifiable Information) of other users. But they can
147+
// still access the user_id fields for analysis purposes.
148+
func example_permission_restrictor_analyst(_table Table, _column Column) bool {
149+
return _column.Name == "username" || _column.Name == "email"
150+
}
151+
152+
// Another example of a restriction function. This is intended to simulate the scenario where a user has been
153+
// restricted to not be able to see the content generated by users, but can still edit the metadata associated
154+
// with the a user. This could be an example permission profile for a client account for the user management service.
155+
func example_permission_restrictor_user_profile(_table Table, _column Column) bool {
156+
return _table.Name == "posts"
157+
}
158+
159+
func TestRestrictionsCustom(t *testing.T) {
160+
analyst := Restrict(example_schema, example_permission_restrictor_analyst)
161+
if !(equal_set(get_table_names(analyst), []string{"users", "posts", "meta"})) {
162+
t.Log("Custom cleared inappropriate tables.")
163+
t.Fail()
164+
}
165+
166+
if !(equal_set(get_column_names(analyst.Tables[0]), []string{"id"})) {
167+
t.Log("Custom failed to clear the PII columns in the users table.")
168+
t.Fail()
169+
}
170+
171+
profile_service := Restrict(example_schema, example_permission_restrictor_user_profile)
172+
173+
if !(equal_set(get_table_names(profile_service), []string{"users","meta"})) {
174+
t.Log("Custom cleared inappropriate tables.")
175+
t.Fail()
176+
}
177+
178+
179+
}

pkg/core/utilities.go

+12
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,18 @@ func in_set(_element string, _set []string) bool {
99
return false
1010
}
1111

12+
func equal_set(_set1 []string, _set2 []string) bool {
13+
if len(_set1) != len(_set2) {
14+
return false
15+
}
16+
for _, element := range _set1 {
17+
if !in_set(element, _set2) {
18+
return false
19+
}
20+
}
21+
return true
22+
}
23+
1224
// Wraps a table name in a group tag if it is in a group. Otherwise, returns the table name.
1325
// This is used to ensure that links in the d2 graph are drawn correctly when table groups are used.
1426
func wrap_name_in_group(_table_name string, _grouping []TableGroup) string {

0 commit comments

Comments
 (0)