Skip to content

Commit 3dd12d5

Browse files
authored
supporting path mirror lookup with precedence during repo policy lookup (ee) (#1591)
* add project prefix directory lookups when loading policies from repo [ee]
1 parent 77ee0b1 commit 3dd12d5

File tree

7 files changed

+105
-47
lines changed

7 files changed

+105
-47
lines changed

cli/pkg/core/policy/policy.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,16 @@ import (
55
)
66

77
type Provider interface {
8-
GetAccessPolicy(organisation string, repository string, projectname string) (string, error)
9-
GetPlanPolicy(organisation string, repository string, projectname string) (string, error)
8+
GetAccessPolicy(organisation string, repository string, projectname string, projectDir string) (string, error)
9+
GetPlanPolicy(organisation string, repository string, projectname string, projectDir string) (string, error)
1010
GetDriftPolicy() (string, error)
1111
GetOrganisation() string // TODO: remove this method from here since out of place
1212
}
1313

1414
type Checker interface {
1515
// TODO refactor arguments - use AccessPolicyContext
16-
CheckAccessPolicy(ciService orchestrator.OrgService, prService *orchestrator.PullRequestService, SCMOrganisation string, SCMrepository string, projectName string, command string, prNumber *int, requestedBy string, planPolicyViolations []string) (bool, error)
17-
CheckPlanPolicy(SCMrepository string, SCMOrganisation string, projectname string, planOutput string) (bool, []string, error)
16+
CheckAccessPolicy(ciService orchestrator.OrgService, prService *orchestrator.PullRequestService, SCMOrganisation string, SCMrepository string, projectName string, projectDir string, command string, prNumber *int, requestedBy string, planPolicyViolations []string) (bool, error)
17+
CheckPlanPolicy(SCMrepository string, SCMOrganisation string, projectname string, projectDir string, planOutput string) (bool, []string, error)
1818
CheckDriftPolicy(SCMOrganisation string, SCMrepository string, projectname string) (bool, error)
1919
}
2020

cli/pkg/digger/digger.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ func RunJobs(jobs []orchestrator.Job, prService orchestrator.PullRequestService,
8080
SCMrepository := splits[1]
8181

8282
for _, command := range job.Commands {
83-
allowedToPerformCommand, err := policyChecker.CheckAccessPolicy(orgService, &prService, SCMOrganisation, SCMrepository, job.ProjectName, command, job.PullRequestNumber, job.RequestedBy, []string{})
83+
allowedToPerformCommand, err := policyChecker.CheckAccessPolicy(orgService, &prService, SCMOrganisation, SCMrepository, job.ProjectName, job.ProjectDir, command, job.PullRequestNumber, job.RequestedBy, []string{})
8484

8585
if err != nil {
8686
return false, false, fmt.Errorf("error checking policy: %v", err)
@@ -187,7 +187,7 @@ func reportPolicyError(projectName string, command string, requestedBy string, r
187187
func run(command string, job orchestrator.Job, policyChecker policy.Checker, orgService orchestrator.OrgService, SCMOrganisation string, SCMrepository string, PRNumber *int, requestedBy string, reporter reporting.Reporter, lock locking2.Lock, prService orchestrator.PullRequestService, projectNamespace string, workingDir string, planStorage storage.PlanStorage, appliesPerProject map[string]bool) (*execution.DiggerExecutorResult, string, error) {
188188
log.Printf("Running '%s' for project '%s' (workflow: %s)\n", command, job.ProjectName, job.ProjectWorkflow)
189189

190-
allowedToPerformCommand, err := policyChecker.CheckAccessPolicy(orgService, &prService, SCMOrganisation, SCMrepository, job.ProjectName, command, job.PullRequestNumber, requestedBy, []string{})
190+
allowedToPerformCommand, err := policyChecker.CheckAccessPolicy(orgService, &prService, SCMOrganisation, SCMrepository, job.ProjectName, job.ProjectDir, command, job.PullRequestNumber, requestedBy, []string{})
191191

192192
if err != nil {
193193
return nil, "error checking policy", fmt.Errorf("error checking policy: %v", err)
@@ -278,7 +278,7 @@ func run(command string, job orchestrator.Job, policyChecker policy.Checker, org
278278
} else if planPerformed {
279279
if isNonEmptyPlan {
280280
reportTerraformPlanOutput(reporter, projectLock.LockId(), plan)
281-
planIsAllowed, messages, err := policyChecker.CheckPlanPolicy(SCMrepository, SCMOrganisation, job.ProjectName, planJsonOutput)
281+
planIsAllowed, messages, err := policyChecker.CheckPlanPolicy(SCMrepository, SCMOrganisation, job.ProjectName, job.ProjectDir, planJsonOutput)
282282
if err != nil {
283283
msg := fmt.Sprintf("Failed to validate plan. %v", err)
284284
log.Printf(msg)
@@ -381,7 +381,7 @@ func run(command string, job orchestrator.Job, policyChecker policy.Checker, org
381381
return nil, msg, fmt.Errorf(msg)
382382
}
383383

384-
_, violations, err := policyChecker.CheckPlanPolicy(SCMrepository, SCMOrganisation, job.ProjectName, terraformPlanJsonStr)
384+
_, violations, err := policyChecker.CheckPlanPolicy(SCMrepository, SCMOrganisation, job.ProjectName, job.ProjectDir, terraformPlanJsonStr)
385385
if err != nil {
386386
msg := fmt.Sprintf("Failed to check plan policy. %v", err)
387387
log.Printf(msg)
@@ -393,7 +393,7 @@ func run(command string, job orchestrator.Job, policyChecker policy.Checker, org
393393
planPolicyViolations = []string{}
394394
}
395395

396-
allowedToApply, err := policyChecker.CheckAccessPolicy(orgService, &prService, SCMOrganisation, SCMrepository, job.ProjectName, command, job.PullRequestNumber, requestedBy, planPolicyViolations)
396+
allowedToApply, err := policyChecker.CheckAccessPolicy(orgService, &prService, SCMOrganisation, SCMrepository, job.ProjectName, job.ProjectDir, command, job.PullRequestNumber, requestedBy, planPolicyViolations)
397397
if err != nil {
398398
msg := fmt.Sprintf("Failed to run plan policy check before apply. %v", err)
399399
log.Printf(msg)
@@ -544,7 +544,7 @@ func RunJob(
544544

545545
for _, command := range job.Commands {
546546

547-
allowedToPerformCommand, err := policyChecker.CheckAccessPolicy(orgService, nil, SCMOrganisation, SCMrepository, job.ProjectName, command, nil, requestedBy, []string{})
547+
allowedToPerformCommand, err := policyChecker.CheckAccessPolicy(orgService, nil, SCMOrganisation, SCMrepository, job.ProjectName, job.ProjectDir, command, nil, requestedBy, []string{})
548548

549549
if err != nil {
550550
return fmt.Errorf("error checking policy: %v", err)
@@ -619,7 +619,7 @@ func RunJob(
619619
}
620620
return fmt.Errorf(msg)
621621
}
622-
planIsAllowed, messages, err := policyChecker.CheckPlanPolicy(SCMrepository, SCMOrganisation, job.ProjectName, planJsonOutput)
622+
planIsAllowed, messages, err := policyChecker.CheckPlanPolicy(SCMrepository, SCMOrganisation, job.ProjectName, job.ProjectDir, planJsonOutput)
623623
log.Printf(strings.Join(messages, "\n"))
624624
if err != nil {
625625
msg := fmt.Sprintf("Failed to validate plan %v", err)

cli/pkg/policy/policy.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,11 @@ type DiggerHttpPolicyProvider struct {
3535
type NoOpPolicyChecker struct {
3636
}
3737

38-
func (p NoOpPolicyChecker) CheckAccessPolicy(_ orchestrator.OrgService, _ *orchestrator.PullRequestService, _ string, _ string, _ string, _ string, _ *int, _ string, _ []string) (bool, error) {
38+
func (p NoOpPolicyChecker) CheckAccessPolicy(ciService orchestrator.OrgService, prService *orchestrator.PullRequestService, SCMOrganisation string, SCMrepository string, projectName string, projectDir string, command string, prNumber *int, requestedBy string, planPolicyViolations []string) (bool, error) {
3939
return true, nil
4040
}
4141

42-
func (p NoOpPolicyChecker) CheckPlanPolicy(_ string, _ string, _ string, _ string) (bool, []string, error) {
42+
func (p NoOpPolicyChecker) CheckPlanPolicy(SCMrepository string, SCMOrganisation string, projectname string, projectDir string, planOutput string) (bool, []string, error) {
4343
return true, nil, nil
4444
}
4545

@@ -181,7 +181,7 @@ func getPlanPolicyForNamespace(p *DiggerHttpPolicyProvider, namespace string, pr
181181
}
182182

183183
// GetPolicy fetches policy for particular project, if not found then it will fallback to org level policy
184-
func (p DiggerHttpPolicyProvider) GetAccessPolicy(organisation string, repo string, projectName string) (string, error) {
184+
func (p DiggerHttpPolicyProvider) GetAccessPolicy(organisation string, repo string, projectName string, projectDir string) (string, error) {
185185
namespace := fmt.Sprintf("%v-%v", organisation, repo)
186186
content, resp, err := getAccessPolicyForNamespace(&p, namespace, projectName)
187187
if err != nil {
@@ -211,7 +211,7 @@ func (p DiggerHttpPolicyProvider) GetAccessPolicy(organisation string, repo stri
211211
}
212212
}
213213

214-
func (p DiggerHttpPolicyProvider) GetPlanPolicy(organisation string, repo string, projectName string) (string, error) {
214+
func (p DiggerHttpPolicyProvider) GetPlanPolicy(organisation string, repo string, projectName string, projectDir string) (string, error) {
215215
namespace := fmt.Sprintf("%v-%v", organisation, repo)
216216
content, resp, err := getPlanPolicyForNamespace(&p, namespace, projectName)
217217
if err != nil {
@@ -264,9 +264,9 @@ type DiggerPolicyChecker struct {
264264
}
265265

266266
// TODO refactor to use AccessPolicyContext - too many arguments
267-
func (p DiggerPolicyChecker) CheckAccessPolicy(ciService orchestrator.OrgService, prService *orchestrator.PullRequestService, SCMOrganisation string, SCMrepository string, projectName string, command string, prNumber *int, requestedBy string, planPolicyViolations []string) (bool, error) {
267+
func (p DiggerPolicyChecker) CheckAccessPolicy(ciService orchestrator.OrgService, prService *orchestrator.PullRequestService, SCMOrganisation string, SCMrepository string, projectName string, projectDir string, command string, prNumber *int, requestedBy string, planPolicyViolations []string) (bool, error) {
268268

269-
policy, err := p.PolicyProvider.GetAccessPolicy(SCMOrganisation, SCMrepository, projectName)
269+
policy, err := p.PolicyProvider.GetAccessPolicy(SCMOrganisation, SCMrepository, projectName, projectDir)
270270

271271
if err != nil {
272272
log.Printf("Error while fetching policy: %v", err)
@@ -331,8 +331,8 @@ func (p DiggerPolicyChecker) CheckAccessPolicy(ciService orchestrator.OrgService
331331
return true, nil
332332
}
333333

334-
func (p DiggerPolicyChecker) CheckPlanPolicy(SCMrepository string, SCMOrganisation string, projectName string, planOutput string) (bool, []string, error) {
335-
policy, err := p.PolicyProvider.GetPlanPolicy(SCMOrganisation, SCMrepository, projectName)
334+
func (p DiggerPolicyChecker) CheckPlanPolicy(SCMrepository string, SCMOrganisation string, projectname string, projectDir string, planOutput string) (bool, []string, error) {
335+
policy, err := p.PolicyProvider.GetPlanPolicy(SCMOrganisation, SCMrepository, projectname, projectDir)
336336
if err != nil {
337337
return false, nil, fmt.Errorf("failed get plan policy: %v", err)
338338
}

cli/pkg/policy/policy_test.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,11 @@ func (s *OpaExamplePolicyProvider) GetOrganisation() string {
5050
type DiggerDefaultPolicyProvider struct {
5151
}
5252

53-
func (s *DiggerDefaultPolicyProvider) GetAccessPolicy(_ string, _ string, _ string) (string, error) {
53+
func (s *DiggerDefaultPolicyProvider) GetAccessPolicy(organisation string, repository string, projectname string, projectDir string) (string, error) {
5454
return DefaultAccessPolicy, nil
5555
}
5656

57-
func (s *DiggerDefaultPolicyProvider) GetPlanPolicy(_ string, _ string, _ string) (string, error) {
57+
func (s *DiggerDefaultPolicyProvider) GetPlanPolicy(organisation string, repository string, projectname string, projectDir string) (string, error) {
5858
return "package digger\n", nil
5959
}
6060

@@ -69,7 +69,7 @@ func (s *DiggerDefaultPolicyProvider) GetOrganisation() string {
6969
type DiggerExamplePolicyProvider struct {
7070
}
7171

72-
func (s *DiggerExamplePolicyProvider) GetAccessPolicy(_ string, _ string, _ string) (string, error) {
72+
func (s *DiggerExamplePolicyProvider) GetAccessPolicy(organisation string, repository string, projectname string, projectDir string) (string, error) {
7373
return "package digger\n" +
7474
"\n" +
7575
"user_permissions := {\n" +
@@ -85,7 +85,7 @@ func (s *DiggerExamplePolicyProvider) GetAccessPolicy(_ string, _ string, _ stri
8585
"", nil
8686
}
8787

88-
func (s *DiggerExamplePolicyProvider) GetPlanPolicy(_ string, _ string, _ string) (string, error) {
88+
func (s *DiggerExamplePolicyProvider) GetPlanPolicy(organisation string, repository string, projectname string, projectDir string) (string, error) {
8989
return "package digger\n", nil
9090
}
9191

@@ -100,7 +100,7 @@ func (s *DiggerExamplePolicyProvider) GetOrganisation() string {
100100
type DiggerExamplePolicyProvider2 struct {
101101
}
102102

103-
func (s *DiggerExamplePolicyProvider2) GetAccessPolicy(_ string, _ string, _ string) (string, error) {
103+
func (s *DiggerExamplePolicyProvider2) GetAccessPolicy(organisation string, repository string, projectname string, projectDir string) (string, error) {
104104
return "package digger\n" +
105105
"\n" +
106106
"user_permissions := {\n" +
@@ -119,7 +119,7 @@ func (s *DiggerExamplePolicyProvider2) GetAccessPolicy(_ string, _ string, _ str
119119
"", nil
120120
}
121121

122-
func (s *DiggerExamplePolicyProvider2) GetPlanPolicy(_ string, _ string, _ string) (string, error) {
122+
func (s *DiggerExamplePolicyProvider2) GetPlanPolicy(organisation string, repository string, projectname string, projectDir string) (string, error) {
123123
return "package digger\n\ndeny[sprintf(message, [resource.address])] {\n message := \"Cannot create EC2 instances!\"\n resource := input.terraform.resource_changes[_]\n resource.change.actions[_] == \"create\"\n resource[type] == \"aws_instance\"\n}\n", nil
124124
}
125125

@@ -223,7 +223,7 @@ func TestDiggerAccessPolicyChecker_Check(t *testing.T) {
223223
PolicyProvider: tt.fields.PolicyProvider,
224224
}
225225
ciService := utils.MockPullRequestManager{Teams: []string{"engineering"}}
226-
got, err := p.CheckAccessPolicy(ciService, nil, tt.organisation, tt.name, tt.name, tt.command, nil, tt.requestedBy, tt.planPolicyViolations)
226+
got, err := p.CheckAccessPolicy(ciService, nil, tt.organisation, tt.name, tt.name, "", tt.command, nil, tt.requestedBy, tt.planPolicyViolations)
227227
if (err != nil) != tt.wantErr {
228228
t.Errorf("DiggerPolicyChecker.CheckAccessPolicy() error = %v, wantErr %v", err, tt.wantErr)
229229
return
@@ -275,7 +275,7 @@ func TestDiggerPlanPolicyChecker_Check(t *testing.T) {
275275
var p = &DiggerPolicyChecker{
276276
PolicyProvider: tt.fields.PolicyProvider,
277277
}
278-
got, _, err := p.CheckPlanPolicy("", "", "", tt.planJsonOutput)
278+
got, _, err := p.CheckPlanPolicy("", "", "", "", tt.planJsonOutput)
279279
if (err != nil) != tt.wantErr {
280280
t.Errorf("DiggerPolicyChecker.CheckPlanPolicy() error = %v, wantErr %v", err, tt.wantErr)
281281
return

cli/pkg/utils/mocks.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,11 @@ func (tf *MockTerraform) Plan() (bool, string, string, error) {
2525
type MockPolicyChecker struct {
2626
}
2727

28-
func (t MockPolicyChecker) CheckAccessPolicy(ciService orchestrator.OrgService, prService *orchestrator.PullRequestService, SCMOrganisation string, SCMrepository string, projectName string, command string, ptr *int, requestedBy string, planPolicyViolations []string) (bool, error) {
28+
func (t MockPolicyChecker) CheckAccessPolicy(ciService orchestrator.OrgService, prService *orchestrator.PullRequestService, SCMOrganisation string, SCMrepository string, projectName string, projectDir string, command string, prNumber *int, requestedBy string, planPolicyViolations []string) (bool, error) {
2929
return false, nil
3030
}
3131

32-
func (t MockPolicyChecker) CheckPlanPolicy(projectName string, SCMOrganisation string, command string, requestedBy string) (bool, []string, error) {
32+
func (t MockPolicyChecker) CheckPlanPolicy(SCMrepository string, SCMOrganisation string, projectname string, projectDir string, planOutput string) (bool, []string, error) {
3333
return false, nil, nil
3434
}
3535

ee/cli/pkg/policy/policy.go

Lines changed: 51 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
package policy
22

33
import (
4-
"fmt"
54
"github.com/diggerhq/digger/ee/cli/pkg/utils"
5+
"github.com/samber/lo"
66
"log"
77
"os"
88
"path"
9+
"path/filepath"
10+
"slices"
11+
"strings"
912
)
1013

1114
const DefaultAccessPolicy = `
@@ -31,31 +34,61 @@ func getContents(filePath string) (string, error) {
3134
return string(contents), nil
3235
}
3336

34-
func (p DiggerRepoPolicyProvider) getPolicyFileContents(repo string, projectName string, fileName string) (string, error) {
37+
// GetPrefixesForPath
38+
// @path is the total path example /dev/vpc/subnets
39+
// @filename is the name of the file to search for example access.rego
40+
// returns the list of prefixes in priority order example:
41+
// /dev/vpc/subnets/access.rego
42+
// /dev/vpc/access.rego
43+
// /dev/access.rego
44+
func GetPrefixesForPath(path string, fileName string) []string {
45+
var prefixes []string
46+
parts := strings.Split(filepath.Clean(path), string(filepath.Separator))
47+
for i := range parts {
48+
prefixes = append(prefixes, filepath.Join(parts[:i+1]...))
49+
}
50+
51+
slices.Reverse(prefixes)
52+
prefixes = lo.FilterMap(prefixes, func(item string, index int) (string, bool) {
53+
// if input path was absolute then result should be absolute and ignore last item ""
54+
if parts[0] == "" {
55+
return string(filepath.Separator) + item + string(filepath.Separator) + fileName, index < len(prefixes)-1
56+
} else {
57+
return item + string(filepath.Separator) + fileName, index < len(prefixes)
58+
}
59+
})
60+
61+
return prefixes
62+
}
63+
64+
func (p DiggerRepoPolicyProvider) getPolicyFileContents(repo string, projectName string, projectDir string, fileName string) (string, error) {
3565
var contents string
3666
err := utils.CloneGitRepoAndDoAction(p.ManagementRepoUrl, "main", p.GitToken, func(basePath string) error {
67+
// we start with the project directory path prefixes as the highest priority
68+
prefixes := GetPrefixesForPath(path.Join(basePath, projectDir), fileName)
69+
70+
// we also add a known location as a least priority item
3771
orgAccesspath := path.Join(basePath, "policies", fileName)
3872
repoAccesspath := path.Join(basePath, "policies", repo, fileName)
3973
projectAccessPath := path.Join(basePath, "policies", repo, projectName, fileName)
74+
prefixes = append(prefixes, projectAccessPath)
75+
prefixes = append(prefixes, repoAccesspath)
76+
prefixes = append(prefixes, orgAccesspath)
4077

41-
log.Printf("loading repo orgAccess %v repoAccess %v projectAcces %v", orgAccesspath, repoAccesspath, projectAccessPath)
42-
var err error
43-
contents, err = getContents(projectAccessPath)
44-
if os.IsNotExist(err) {
45-
contents, err = getContents(repoAccesspath)
78+
for _, pathPrefix := range prefixes {
79+
var err error
80+
contents, err = getContents(pathPrefix)
81+
log.Printf("path: %v contents: %v, err: %v", pathPrefix, contents, err)
82+
if err == nil {
83+
return nil
84+
}
4685
if os.IsNotExist(err) {
47-
contents, err = getContents(orgAccesspath)
48-
if os.IsNotExist(err) {
49-
return nil
50-
} else {
51-
fmt.Errorf("could not find any matching policy for %v,%v", repo, projectName)
52-
}
86+
continue
5387
} else {
5488
return err
5589
}
56-
} else {
57-
return err
5890
}
91+
5992
return nil
6093
})
6194
if err != nil {
@@ -65,11 +98,11 @@ func (p DiggerRepoPolicyProvider) getPolicyFileContents(repo string, projectName
6598
}
6699

67100
// GetPolicy fetches policy for particular project, if not found then it will fallback to org level policy
68-
func (p DiggerRepoPolicyProvider) GetAccessPolicy(organisation string, repo string, projectName string) (string, error) {
69-
return p.getPolicyFileContents(repo, projectName, "access.rego")
101+
func (p DiggerRepoPolicyProvider) GetAccessPolicy(organisation string, repo string, projectName string, projectDir string) (string, error) {
102+
return p.getPolicyFileContents(repo, projectName, projectDir, "access.rego")
70103
}
71104

72-
func (p DiggerRepoPolicyProvider) GetPlanPolicy(organisation string, repo string, projectName string) (string, error) {
105+
func (p DiggerRepoPolicyProvider) GetPlanPolicy(organisation string, repository string, projectname string, projectDir string) (string, error) {
73106
return "", nil
74107
}
75108

ee/cli/pkg/policy/policy_test.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
package policy
2+
3+
import (
4+
"github.com/stretchr/testify/assert"
5+
"log"
6+
"os"
7+
"testing"
8+
)
9+
10+
func init() {
11+
log.SetOutput(os.Stdout)
12+
log.SetFlags(log.Ldate | log.Ltime)
13+
}
14+
15+
func TestGetPrefixesForPath(t *testing.T) {
16+
prefixes := GetPrefixesForPath("dev/vpc/subnets", "access.rego")
17+
assert.Equal(t, []string{"dev/vpc/subnets/access.rego", "dev/vpc/access.rego", "dev/access.rego"}, prefixes)
18+
log.Printf("%v", prefixes)
19+
}
20+
21+
func TestGetPrefixesForPathAbsolute(t *testing.T) {
22+
prefixes := GetPrefixesForPath("/dev/vpc/subnets", "access.rego")
23+
assert.Equal(t, []string{"/dev/vpc/subnets/access.rego", "/dev/vpc/access.rego", "/dev/access.rego"}, prefixes)
24+
log.Printf("%v", prefixes)
25+
}

0 commit comments

Comments
 (0)