Skip to content

Commit 0f1d1b6

Browse files
committed
feat: selector engine
1 parent 36ebbec commit 0f1d1b6

File tree

9 files changed

+634
-10
lines changed

9 files changed

+634
-10
lines changed

cmd/api/src/daemons/datapipe/agt.go

+381
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,381 @@
1+
// Copyright 2025 Specter Ops, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
//
15+
// SPDX-License-Identifier: Apache-2.0
16+
17+
package datapipe
18+
19+
import (
20+
"context"
21+
"fmt"
22+
"log/slog"
23+
"slices"
24+
"sync"
25+
26+
"github.com/specterops/bloodhound/analysis"
27+
"github.com/specterops/bloodhound/bhlog/measure"
28+
"github.com/specterops/bloodhound/dawgs/cardinality"
29+
"github.com/specterops/bloodhound/dawgs/graph"
30+
"github.com/specterops/bloodhound/dawgs/ops"
31+
"github.com/specterops/bloodhound/dawgs/query"
32+
"github.com/specterops/bloodhound/dawgs/traversal"
33+
"github.com/specterops/bloodhound/graphschema/ad"
34+
"github.com/specterops/bloodhound/graphschema/azure"
35+
"github.com/specterops/bloodhound/graphschema/common"
36+
"github.com/specterops/bloodhound/src/database"
37+
"github.com/specterops/bloodhound/src/database/types/null"
38+
"github.com/specterops/bloodhound/src/model"
39+
)
40+
41+
func fetchNodesFromSeeds(ctx context.Context, graphDb graph.Database, seeds []model.SelectorSeed) (graph.ThreadSafeNodeSet, error) {
42+
var (
43+
seedNodes = graph.NodeSet{}
44+
result = graph.NewThreadSafeNodeSet(graph.NodeSet{})
45+
err error
46+
)
47+
48+
if err = graphDb.ReadTransaction(ctx, func(tx graph.Transaction) error {
49+
// Then we grab the nodes that should be selected
50+
for _, seed := range seeds {
51+
switch seed.Type {
52+
case model.SelectorTypeObjectId:
53+
if seedNodes, err = ops.FetchNodeSet(tx.Nodes().Filter(query.Equals(query.NodeProperty(common.ObjectID.String()), seed.Value))); err != nil {
54+
return err
55+
} else {
56+
seedNodes.AddSet(seedNodes)
57+
}
58+
case model.SelectorTypeCypher:
59+
if seedNodes, err = ops.FetchNodesByQuery(tx, seed.Value); err != nil {
60+
return err
61+
} else {
62+
seedNodes.AddSet(seedNodes)
63+
}
64+
default:
65+
slog.WarnContext(ctx, fmt.Sprintf("AGT: Unsupported selector type: %d", seed.Type))
66+
}
67+
}
68+
return nil
69+
}); err != nil {
70+
return *result, err
71+
}
72+
73+
traversalInst := traversal.New(graphDb, analysis.MaximumDatabaseParallelWorkers)
74+
// Expand to child nodes as needed based on kind
75+
for _, node := range seedNodes {
76+
if err = expandNodes(ctx, traversalInst, node, result); err != nil {
77+
return *result, err
78+
}
79+
}
80+
81+
return *result, err
82+
}
83+
84+
func expandNodes(ctx context.Context, tx traversal.Traversal, node *graph.Node, result *graph.ThreadSafeNodeSet) error {
85+
var pattern traversal.PatternContinuation
86+
87+
// Add visited node to result set
88+
result.AddIfNotExists(node)
89+
90+
switch {
91+
case node.Kinds.ContainsOneOf(ad.Group, azure.Group):
92+
pattern = traversal.NewPattern().InboundWithDepth(0, 0, query.And(
93+
query.KindIn(query.Relationship(), ad.MemberOf, azure.MemberOf),
94+
query.KindIn(query.Start(), ad.Entity, azure.Entity),
95+
))
96+
case node.Kinds.ContainsOneOf(ad.OU, ad.Container, azure.ResourceGroup, azure.ManagementGroup, azure.Subscription):
97+
pattern = traversal.NewPattern().OutboundWithDepth(0, 0, query.And(
98+
query.KindIn(query.Relationship(), ad.Contains, azure.Contains),
99+
query.KindIn(query.Start(), ad.Entity, azure.Entity),
100+
))
101+
case node.Kinds.ContainsOneOf(azure.Role):
102+
pattern = traversal.NewPattern().InboundWithDepth(0, 0, query.And(
103+
query.KindIn(query.Relationship(), azure.HasRole),
104+
query.KindIn(query.Start(), azure.Entity),
105+
))
106+
default:
107+
// Skip any that do not need expanding
108+
return nil
109+
}
110+
111+
addedNodes := graph.NewThreadSafeNodeSet(graph.NodeSet{})
112+
if err := tx.BreadthFirst(ctx, traversal.Plan{
113+
Root: node,
114+
Driver: pattern.Do(func(path *graph.PathSegment) error {
115+
if path.Trunk != nil {
116+
if result.AddIfNotExists(path.Trunk.Node) {
117+
addedNodes.Add(path.Trunk.Node)
118+
}
119+
}
120+
if result.AddIfNotExists(path.Node) {
121+
addedNodes.Add(path.Node)
122+
}
123+
124+
return nil
125+
})}); err != nil {
126+
return err
127+
}
128+
129+
if addedNodes != nil && addedNodes.Len() > 0 {
130+
for _, node := range addedNodes.Slice() {
131+
// Expand to child nodes as needed based on kind
132+
if err := expandNodes(ctx, tx, node, result); err != nil {
133+
return err
134+
}
135+
}
136+
}
137+
138+
return nil
139+
}
140+
141+
// TODO Batching?
142+
func selectNodes(ctx context.Context, db database.Database, graphDb graph.Database, selector model.AssetGroupTagSelector) error {
143+
var (
144+
countInserted int
145+
146+
certified = model.AssetGroupCertificationNone
147+
certifiedBy null.String
148+
149+
oldSelectedNodes []model.AssetGroupSelectorNode
150+
151+
nodeIdsToDelete []graph.ID
152+
nodeIdsToUpdate []graph.ID
153+
)
154+
if selector.AutoCertify {
155+
certified = model.AssetGroupCertificationAuto
156+
certifiedBy = null.StringFrom(model.AssetGroupActorSystem)
157+
}
158+
159+
// 1. Grab the graph nodes
160+
if nodes, err := fetchNodesFromSeeds(ctx, graphDb, selector.Seeds); err != nil {
161+
return err
162+
// 2. Grab the already selected nodes
163+
} else if oldSelectedNodes, err = db.GetSelectorNodesBySelectorIds(ctx, selector.ID); err != nil {
164+
return err
165+
} else {
166+
oldSelectedNodesByNodeId := make(map[graph.ID]*model.AssetGroupSelectorNode)
167+
for _, node := range oldSelectedNodes {
168+
oldSelectedNodesByNodeId[node.NodeId] = &node
169+
}
170+
171+
// 3. Range the graph nodes and insert any that haven't been inserted yet, mark for update any that need updating, pare down the existing map for future deleting
172+
for _, id := range nodes.IDs() {
173+
// Missing, insert the record
174+
if oldSelectedNodesByNodeId[id] == nil {
175+
if err = db.InsertSelectorNode(ctx, selector.ID, id, certified, certifiedBy); err != nil {
176+
return err
177+
}
178+
countInserted++
179+
// Auto certify is enabled but this node hasn't been certified, certify it
180+
} else if selector.AutoCertify && oldSelectedNodesByNodeId[id].Certified == model.AssetGroupCertificationNone {
181+
nodeIdsToUpdate = append(nodeIdsToUpdate, id)
182+
delete(oldSelectedNodesByNodeId, id)
183+
} else {
184+
delete(oldSelectedNodesByNodeId, id)
185+
}
186+
}
187+
188+
// Update the selected nodes that need updating
189+
if len(nodeIdsToUpdate) > 0 {
190+
if err = db.UpdateSelectorNodesByNodeId(ctx, selector.ID, certified, certifiedBy, nodeIdsToUpdate...); err != nil {
191+
return err
192+
}
193+
}
194+
195+
// Delete the selected nodes that need to be deleted
196+
if len(oldSelectedNodesByNodeId) > 0 {
197+
for nodeId := range oldSelectedNodesByNodeId {
198+
nodeIdsToDelete = append(nodeIdsToDelete, nodeId)
199+
}
200+
if err = db.DeleteSelectorNodesByNodeId(ctx, selector.ID, nodeIdsToDelete...); err != nil {
201+
return err
202+
}
203+
}
204+
205+
slog.Info("AGT: Completed selecting", "selector", selector.Name, "countTotal", nodes.Len(), "countInserted", countInserted, "countUpdated", len(nodeIdsToUpdate), "countDeleted", len(nodeIdsToDelete))
206+
}
207+
return nil
208+
}
209+
210+
func SelectAssetGroupNodes(ctx context.Context, db database.Database, graphDb graph.Database) error {
211+
defer measure.ContextMeasure(ctx, slog.LevelInfo, "Finished selecting asset group nodes via new selectors")()
212+
213+
if tags, err := db.GetAssetGroupTagForSelection(ctx); err != nil {
214+
return err
215+
} else {
216+
for _, tag := range tags {
217+
if selectors, err := db.GetAssetGroupTagSelectorsByTagId(ctx, tag.ID); err != nil {
218+
return err
219+
} else {
220+
wg := sync.WaitGroup{}
221+
for _, selector := range selectors {
222+
if !selector.DisabledAt.IsZero() {
223+
continue
224+
}
225+
// Parallelize the selection of nodes
226+
go func() {
227+
defer wg.Done()
228+
if err = selectNodes(ctx, db, graphDb, selector); err != nil {
229+
slog.Error("AGT: Error selecting nodes", "selector", selector, "err", err)
230+
}
231+
}()
232+
wg.Add(1)
233+
}
234+
wg.Wait()
235+
}
236+
}
237+
}
238+
return nil
239+
}
240+
241+
// TODO Batching?
242+
func tagAssetGroupNodes(ctx context.Context, db database.Database, graphDb graph.Database, tag model.AssetGroupTag) error {
243+
if selectors, err := db.GetAssetGroupTagSelectorsByTagId(ctx, tag.ID); err != nil {
244+
return err
245+
} else {
246+
var (
247+
countTagged, countUntagged, countTotal int
248+
selectorIds []int
249+
selectedNodes []model.AssetGroupSelectorNode
250+
251+
tagKind = tag.ToKind()
252+
253+
nodesSeen = cardinality.NewBitmap64()
254+
oldTaggedNodes = cardinality.NewBitmap64()
255+
newTaggedNodes = cardinality.NewBitmap64()
256+
)
257+
258+
disabledSelectors := cardinality.NewBitmap32()
259+
for _, selector := range selectors {
260+
if !selector.DisabledAt.IsZero() {
261+
disabledSelectors.Add(uint32(selector.ID))
262+
}
263+
selectorIds = append(selectorIds, selector.ID)
264+
}
265+
266+
// 1. Fetch the selected nodes for this label
267+
if selectedNodes, err = db.GetSelectorNodesBySelectorIds(ctx, selectorIds...); err != nil {
268+
return err
269+
} else if err = graphDb.WriteTransaction(ctx, func(tx graph.Transaction) error {
270+
// 2. Fetch already tagged nodes
271+
if oldTaggedNodeSet, err := ops.FetchNodeSet(tx.Nodes().Filter(query.Kind(query.Node(), tagKind))); err != nil {
272+
return err
273+
} else {
274+
oldTaggedNodes = oldTaggedNodeSet.IDBitmap()
275+
}
276+
277+
// 3. Diff the sets filling the respective sets for later db updates
278+
for _, nodeDb := range selectedNodes {
279+
if !nodesSeen.Contains(nodeDb.NodeId.Uint64()) {
280+
// Skip any that are not certified when tag requires certification or are selected by disabled selectors
281+
if tag.RequireCertify.Bool && nodeDb.Certified <= 0 || disabledSelectors.Contains(uint32(nodeDb.SelectorId)) {
282+
continue
283+
}
284+
285+
// If the id is not present, we must queue it for tagging
286+
if !oldTaggedNodes.Contains(nodeDb.NodeId.Uint64()) {
287+
newTaggedNodes.Add(nodeDb.NodeId.Uint64())
288+
} else {
289+
// If it is present, we don't need to update anything and will remove tags from any nodes left in this bitmap
290+
oldTaggedNodes.Remove(nodeDb.NodeId.Uint64())
291+
}
292+
// Once a node is processed, we can skip future duplicates that might be selected by other selectors
293+
nodesSeen.Add(nodeDb.NodeId.Uint64())
294+
countTotal++
295+
}
296+
}
297+
298+
// 4. Tag the new nodes
299+
newTaggedNodes.Each(func(nodeId uint64) bool {
300+
node := &graph.Node{ID: graph.ID(nodeId), Properties: &graph.Properties{}}
301+
node.AddKinds(tagKind)
302+
err = tx.UpdateNode(node)
303+
countTagged++
304+
return err == nil
305+
})
306+
if err != nil {
307+
return err
308+
}
309+
310+
// 5. Remove the old nodes
311+
oldTaggedNodes.Each(func(nodeId uint64) bool {
312+
node := &graph.Node{ID: graph.ID(nodeId), Properties: &graph.Properties{}}
313+
node.DeleteKinds(tagKind)
314+
err = tx.UpdateNode(node)
315+
countUntagged++
316+
return err == nil
317+
})
318+
if err != nil {
319+
return err
320+
}
321+
322+
return nil
323+
}); err != nil {
324+
return err
325+
}
326+
327+
slog.Info("AGT: Completed tagging", tag.ToType(), tag.Name, "total", countTotal, "tagged", countTagged, "untagged", countUntagged)
328+
}
329+
return nil
330+
}
331+
332+
func TagAssetGroupNodes(ctx context.Context, db database.Database, graphDb graph.Database) error {
333+
if tags, err := db.GetAssetGroupTagForSelection(ctx); err != nil {
334+
return err
335+
} else {
336+
// Tiers are hierarchical and must be handled synchronously while labels can be tagged in parallel
337+
var (
338+
labels []model.AssetGroupTag
339+
tiersOrdered []model.AssetGroupTag
340+
)
341+
for _, tag := range tags {
342+
switch tag.Type {
343+
case model.AssetGroupTagTypeTier:
344+
tiersOrdered = append(tiersOrdered, tag)
345+
case model.AssetGroupTagTypeLabel:
346+
labels = append(labels, tag)
347+
default:
348+
slog.WarnContext(ctx, fmt.Sprintf("AGT: Tag type %d is not supported", tag.Type), "tag", tag)
349+
}
350+
}
351+
352+
// Order the tiers by position
353+
slices.SortFunc(tiersOrdered, func(i, j model.AssetGroupTag) int {
354+
return int(i.Position.Int32 - j.Position.Int32)
355+
})
356+
357+
// Fire off the label tagging
358+
wg := sync.WaitGroup{}
359+
for _, label := range labels {
360+
// Parallelize the tagging of label nodes
361+
go func() {
362+
defer wg.Done()
363+
if err = tagAssetGroupNodes(ctx, db, graphDb, label); err != nil {
364+
slog.Error("AGT: Error tagging nodes", "label", label, "err", err)
365+
}
366+
}()
367+
wg.Add(1)
368+
}
369+
370+
// Process the tier tagging synchronously
371+
for _, tier := range tiersOrdered {
372+
if err := tagAssetGroupNodes(ctx, db, graphDb, tier); err != nil {
373+
slog.Error("AGT: Error tagging nodes", "tier", tier, "err", err)
374+
}
375+
}
376+
377+
// Wait for labels to finish
378+
wg.Wait()
379+
}
380+
return nil
381+
}

0 commit comments

Comments
 (0)