Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BED-5249: feat: Get Tier/Label List Endpoint #1213

Draft
wants to merge 14 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cmd/api/src/api/registration/v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ func NewV2API(resources v2.Resources, routerInst *router.Router) {
routerInst.POST(fmt.Sprintf("/api/v2/asset-groups/{%s}/selectors", api.URIPathVariableAssetGroupID), resources.UpdateAssetGroupSelectors).RequirePermissions(permissions.GraphDBWrite),

// Asset group management API
routerInst.GET("/api/v2/asset-group-tags", resources.GetAssetGroupTags).CheckFeatureFlag(resources.DB, appcfg.FeatureTierManagement).RequirePermissions(permissions.GraphDBRead),
routerInst.POST(fmt.Sprintf("/api/v2/asset-group-tags/{%s}/selectors", api.URIPathVariableAssetGroupTagID), resources.CreateAssetGroupTagSelector).CheckFeatureFlag(resources.DB, appcfg.FeatureTierManagement).RequirePermissions(permissions.GraphDBWrite),
routerInst.GET(fmt.Sprintf("/api/v2/asset-group-tags/{%s}/selectors", api.URIPathVariableAssetGroupTagID), resources.GetAssetGroupTagSelectors).CheckFeatureFlag(resources.DB, appcfg.FeatureTierManagement).RequirePermissions(permissions.GraphDBRead),
routerInst.GET(fmt.Sprintf("/api/v2/asset-group-tags/{%s}", api.URIPathVariableAssetGroupTagID), resources.GetAssetGroupTag).CheckFeatureFlag(resources.DB, appcfg.FeatureTierManagement).RequirePermissions(permissions.GraphDBRead),
Expand Down
65 changes: 65 additions & 0 deletions cmd/api/src/api/v2/assetgrouptags.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package v2

import (
"encoding/json"
"errors"
"fmt"
"log/slog"
"net/http"
Expand All @@ -29,6 +30,7 @@ import (
"github.com/specterops/bloodhound/src/api"
"github.com/specterops/bloodhound/src/auth"
"github.com/specterops/bloodhound/src/ctx"
"github.com/specterops/bloodhound/src/database"
"github.com/specterops/bloodhound/src/model"
"github.com/specterops/bloodhound/src/queries"
"github.com/specterops/bloodhound/src/utils/validation"
Expand All @@ -39,6 +41,69 @@ const (
ErrInvalidSelectorType = "invalid selector type"
)

type GetAssetGroupTagsResponse struct {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this can probably be private?

AssetGroupTags model.AssetGroupTags `json:"asset_group_tags"`
Counts struct {
Selectors map[int]int `json:"selectors"`
Members map[int]int `json:"members"`
} `json:"counts,omitempty"`
}

func (s Resources) GetAssetGroupTags(response http.ResponseWriter, request *http.Request) {
const (
pnameTagType = "type"
pnameIncludeCounts = "includeCounts"
)
var (
pvalsTagType = map[string]model.AssetGroupTagType{
strconv.Itoa(int(model.AssetGroupTagTypeLabel)): model.AssetGroupTagTypeLabel,
strconv.Itoa(int(model.AssetGroupTagTypeTier)): model.AssetGroupTagTypeTier,
"": model.AssetGroupTagTypeAll, // default
}
pvalsIncludeCounts = map[string]bool{
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what does the "p" stand for in these const and variable names?

"false": false,
"true": true,
"": false, // default
}
)

var params = request.URL.Query()

if paramTagType, ok := pvalsTagType[params.Get(pnameTagType)]; !ok {
api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusBadRequest, "Invalid value specifed for tag type", request), response)
} else if paramIncludeCounts, ok := pvalsIncludeCounts[params.Get(pnameIncludeCounts)]; !ok {
api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusBadRequest, "Invalid value specifed for include counts", request), response)
} else if tags, err := s.DB.GetAssetGroupTags(request.Context(), paramTagType); err != nil && !errors.Is(err, database.ErrNotFound) {
api.HandleDatabaseError(request, response, err)
} else {
resp := GetAssetGroupTagsResponse{AssetGroupTags: tags}
if paramIncludeCounts {
ids := make([]int, 0, len(tags))
for i := range tags {
ids = append(ids, tags[i].ID)
}
if selectorCounts, err := s.DB.GetAssetGroupTagSelectorCounts(request.Context(), ids); err != nil {
api.HandleDatabaseError(request, response, err)
return
} else {
resp.Counts.Selectors = selectorCounts
}
memberCounts := make(map[int]int, len(tags))
for _, tag := range tags {
// TODO: use a more efficient query method
if nodelist, err := s.GraphQuery.GetNodesByKind(request.Context(), tag.ToKind()); err != nil {
Comment on lines +93 to +94
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we end up creating the method in the referenced comment we can use it here too instead of GetNodesByKind

https://github.com/SpecterOps/BloodHound/pull/1307/files#r2027599359

api.HandleDatabaseError(request, response, err)
return
} else {
memberCounts[tag.ID] = nodelist.Len()
}
}
resp.Counts.Members = memberCounts
}
api.WriteBasicResponse(request.Context(), resp, http.StatusOK, response)
}
}

// Checks that the selector seeds are valid.
func validateSelectorSeeds(graph queries.Graph, seeds []model.SelectorSeed) error {
// all seeds must be of the same type
Expand Down
248 changes: 248 additions & 0 deletions cmd/api/src/api/v2/assetgrouptags_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (

uuid2 "github.com/gofrs/uuid"
"github.com/gorilla/mux"
"github.com/specterops/bloodhound/dawgs/graph"
"github.com/specterops/bloodhound/headers"
"github.com/specterops/bloodhound/mediatypes"
"github.com/specterops/bloodhound/src/api"
Expand All @@ -43,6 +44,253 @@ import (
"go.uber.org/mock/gomock"
)

func TestResources_GetAssetGroupTags(t *testing.T) {
const (
queryParamTagType = "type"
queryParamIncludeCounts = "includeCounts"
)
var (
mockCtrl = gomock.NewController(t)
mockDB = mocks_db.NewMockDatabase(mockCtrl)
mockGraphDb = mocks_graph.NewMockGraph(mockCtrl)
resourcesInst = v2.Resources{
DB: mockDB,
GraphQuery: mockGraphDb,
}
)

defer mockCtrl.Finish()

apitest.
NewHarness(t, resourcesInst.GetAssetGroupTags).
Run([]apitest.Case{
{
Name: "InvalidTagType",
Input: func(input *apitest.Input) {
apitest.AddQueryParam(input, queryParamTagType, "123456")
},
Test: func(output apitest.Output) {
apitest.StatusCode(output, http.StatusBadRequest)
},
},
{
Name: "InvalidIncludeCounts",
Input: func(input *apitest.Input) {
apitest.AddQueryParam(input, queryParamIncludeCounts, "blah")
},
Test: func(output apitest.Output) {
apitest.StatusCode(output, http.StatusBadRequest)
},
},
{
Name: "DatabaseError",
Setup: func() {
mockDB.EXPECT().
GetAssetGroupTags(gomock.Any(), gomock.Any()).
Return(model.AssetGroupTags{}, errors.New("failure"))
},
Test: func(output apitest.Output) {
apitest.StatusCode(output, http.StatusInternalServerError)
apitest.BodyContains(output, api.ErrorResponseDetailsInternalServerError)
},
},
{
Name: "NoResults",
Setup: func() {
mockDB.EXPECT().
GetAssetGroupTags(gomock.Any(), gomock.Any()).
Return(model.AssetGroupTags{}, database.ErrNotFound)
},
Test: func(output apitest.Output) {
resp := v2.GetAssetGroupTagsResponse{}
apitest.StatusCode(output, http.StatusOK)
apitest.UnmarshalData(output, &resp)
apitest.Equal(output, model.AssetGroupTags{}, resp.AssetGroupTags)
},
},
{
Name: "TagTypeLabel",
Input: func(input *apitest.Input) {
apitest.AddQueryParam(input, queryParamTagType, "2") // model.AssetGroupTagTypeLabel
},
Setup: func() {
mockDB.EXPECT().
GetAssetGroupTags(gomock.Any(), model.AssetGroupTagTypeLabel).
Return(model.AssetGroupTags{
model.AssetGroupTag{ID: 1, Type: model.AssetGroupTagTypeLabel},
model.AssetGroupTag{ID: 2, Type: model.AssetGroupTagTypeLabel},
}, nil)
},
Test: func(output apitest.Output) {
resp := v2.GetAssetGroupTagsResponse{}
apitest.StatusCode(output, http.StatusOK)
apitest.UnmarshalData(output, &resp)
apitest.Equal(output, 2, len(resp.AssetGroupTags))
for _, t := range resp.AssetGroupTags {
apitest.Equal(output, model.AssetGroupTagTypeLabel, t.Type)
}
},
},
{
Name: "TagTypeTier",
Input: func(input *apitest.Input) {
apitest.AddQueryParam(input, queryParamTagType, "1") // model.AssetGroupTagTypeTier
},
Setup: func() {
mockDB.EXPECT().
GetAssetGroupTags(gomock.Any(), model.AssetGroupTagTypeTier).
Return(model.AssetGroupTags{
model.AssetGroupTag{ID: 1, Type: model.AssetGroupTagTypeTier},
model.AssetGroupTag{ID: 2, Type: model.AssetGroupTagTypeTier},
}, nil)
},
Test: func(output apitest.Output) {
resp := v2.GetAssetGroupTagsResponse{}
apitest.StatusCode(output, http.StatusOK)
apitest.UnmarshalData(output, &resp)
apitest.Equal(output, 2, len(resp.AssetGroupTags))
for _, t := range resp.AssetGroupTags {
apitest.Equal(output, model.AssetGroupTagTypeTier, t.Type)
}
},
},
{
Name: "TagTypeDefault",
Setup: func() {
mockDB.EXPECT().
GetAssetGroupTags(gomock.Any(), model.AssetGroupTagTypeAll).
Return(model.AssetGroupTags{
model.AssetGroupTag{ID: 1, Type: model.AssetGroupTagTypeLabel},
model.AssetGroupTag{ID: 2, Type: model.AssetGroupTagTypeLabel},
model.AssetGroupTag{ID: 3, Type: model.AssetGroupTagTypeTier},
model.AssetGroupTag{ID: 4, Type: model.AssetGroupTagTypeTier},
}, nil)
},
Test: func(output apitest.Output) {
resp := v2.GetAssetGroupTagsResponse{}
apitest.StatusCode(output, http.StatusOK)
apitest.UnmarshalData(output, &resp)
apitest.Equal(output, 4, len(resp.AssetGroupTags))
tierCount := 0
for _, t := range resp.AssetGroupTags {
if t.Type == model.AssetGroupTagTypeTier {
apitest.Equal(output, model.AssetGroupTagTypeTier, t.Type)
tierCount++
} else {
apitest.Equal(output, model.AssetGroupTagTypeLabel, t.Type)
}
}
apitest.Equal(output, 2, tierCount)
},
},
{
Name: "IncludeCounts Selector counts",
Input: func(input *apitest.Input) {
apitest.AddQueryParam(input, queryParamIncludeCounts, "true")
},
Setup: func() {
mockDB.EXPECT().
GetAssetGroupTags(gomock.Any(), model.AssetGroupTagTypeAll).
Return(model.AssetGroupTags{
model.AssetGroupTag{ID: 1, Type: model.AssetGroupTagTypeTier},
model.AssetGroupTag{ID: 2, Type: model.AssetGroupTagTypeTier},
model.AssetGroupTag{ID: 3, Type: model.AssetGroupTagTypeLabel},
model.AssetGroupTag{ID: 4, Type: model.AssetGroupTagTypeLabel},
}, nil)
mockDB.EXPECT().
GetAssetGroupTagSelectorCounts(gomock.Any(), []int{1, 2, 3, 4}).
Return(map[int]int{
1: 5,
2: 10,
3: 0,
4: 8,
}, nil)
mockGraphDb.EXPECT().
GetNodesByKind(gomock.Any(), gomock.Any()).
Return(graph.EmptyNodeSet(), nil).Times(4)
},
Test: func(output apitest.Output) {
expectedCounts := map[int]int{
1: 5,
2: 10,
3: 0,
4: 8,
}
resp := v2.GetAssetGroupTagsResponse{}
apitest.StatusCode(output, http.StatusOK)
apitest.UnmarshalData(output, &resp)
apitest.Equal(output, 4, len(resp.AssetGroupTags))
apitest.Equal(output, 4, len(resp.Counts.Selectors))
for _, t := range resp.AssetGroupTags {
expCount, ok := expectedCounts[t.ID]
apitest.Equal(output, true, ok)
_, ok = resp.Counts.Selectors[t.ID]
apitest.Equal(output, true, ok)
apitest.Equal(output, expCount, resp.Counts.Selectors[t.ID])
}
},
},
{
Name: "IncludeCounts member counts",
Input: func(input *apitest.Input) {
apitest.AddQueryParam(input, queryParamIncludeCounts, "true")
},
Setup: func() {
mockDB.EXPECT().
GetAssetGroupTags(gomock.Any(), model.AssetGroupTagTypeAll).
Return(model.AssetGroupTags{
model.AssetGroupTag{ID: 1, Name: "testlabel", Type: model.AssetGroupTagTypeLabel},
model.AssetGroupTag{ID: 2, Name: "testtier", Type: model.AssetGroupTagTypeTier},
}, nil)
mockDB.EXPECT().
GetAssetGroupTagSelectorCounts(gomock.Any(), []int{1, 2}).
Return(map[int]int{
1: 1,
2: 1,
}, nil)
mockGraphDb.EXPECT().
GetNodesByKind(gomock.Any(), gomock.Any()).
Return(graph.NewNodeSet(
graph.NewNode(graph.ID(1), graph.NewProperties()),
graph.NewNode(graph.ID(2), graph.NewProperties()),
graph.NewNode(graph.ID(3), graph.NewProperties()),
graph.NewNode(graph.ID(4), graph.NewProperties()),
graph.NewNode(graph.ID(5), graph.NewProperties()),
graph.NewNode(graph.ID(6), graph.NewProperties()),
), nil).
Times(1)
mockGraphDb.EXPECT().
GetNodesByKind(gomock.Any(), gomock.Any()).
Return(graph.NewNodeSet(
graph.NewNode(graph.ID(1), graph.NewProperties()),
graph.NewNode(graph.ID(2), graph.NewProperties()),
graph.NewNode(graph.ID(3), graph.NewProperties()),
graph.NewNode(graph.ID(4), graph.NewProperties()),
), nil).
Times(1)
},
Test: func(output apitest.Output) {
expectedMemberCounts := map[int]int{
1: 6,
2: 4,
}
resp := v2.GetAssetGroupTagsResponse{}
apitest.StatusCode(output, http.StatusOK)
apitest.UnmarshalData(output, &resp)
apitest.Equal(output, 2, len(resp.AssetGroupTags))
apitest.Equal(output, 2, len(resp.Counts.Members))
for _, t := range resp.AssetGroupTags {
expCount, ok := expectedMemberCounts[t.ID]
apitest.Equal(output, true, ok)
_, ok = resp.Counts.Members[t.ID]
apitest.Equal(output, true, ok)
apitest.Equal(output, expCount, resp.Counts.Members[t.ID])
}
},
},
})
}

func TestResources_CreateAssetGroupTagSelector(t *testing.T) {
var (
mockCtrl = gomock.NewController(t)
Expand Down
Loading
Loading