Skip to content

Commit 3b067d4

Browse files
authored
BED-5467 Allow DB injection into DAWGS (#1314)
* chore: combine driver and schemamanager into a hyperstruct * feat: move pgxpool creation outside of the driver with injection * fix: infinite loop * chore: add another lock to prevent deadlocks * chore: partial revert * feat: refactor schemamanager/driver to remove circular dependency This embeds schemamanager into the driver, so that the transaction methods can live on schemamanager but be accessed from the driver layer. This did necessitate moving batchWriteSize to a package level var instead of living on the driver. This ended up being the simplest change to allow injecting a pgxpool into dawgs, and additional work should be done to better unwind the driver and schemamanager responsibilities * chore: move of NewPool was too aggressive, moving back to pg driver package * chore: oops * fix: overapplied rename * chore: fix infertypeargs flags
1 parent 5425d6b commit 3b067d4

File tree

14 files changed

+200
-158
lines changed

14 files changed

+200
-158
lines changed

Diff for: cmd/api/src/api/auth_internal_test.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -333,15 +333,15 @@ func TestValidateRequestSignature(t *testing.T) {
333333

334334
tmpFiles, err := os.ReadDir(authenticator.cfg.TempDirectory())
335335
assert.NoError(t, err)
336-
assert.Len(t, slicesext.Filter[fs.DirEntry](tmpFiles, func(file fs.DirEntry) bool {
336+
assert.Len(t, slicesext.Filter(tmpFiles, func(file fs.DirEntry) bool {
337337
return strings.HasPrefix(file.Name(), "bh-request-")
338338
}), 1)
339339

340340
// Closing the body should remove the tmp file
341341
req.Body.Close()
342342
tmpFiles, err = os.ReadDir(os.TempDir())
343343
assert.NoError(t, err)
344-
assert.Len(t, slicesext.Filter[fs.DirEntry](tmpFiles, func(file fs.DirEntry) bool {
344+
assert.Len(t, slicesext.Filter(tmpFiles, func(file fs.DirEntry) bool {
345345
return strings.HasPrefix(file.Name(), "bh-request-")
346346
}), 0)
347347
})
@@ -376,7 +376,7 @@ func TestValidateRequestSignature(t *testing.T) {
376376
// "small" payloads should not create a tmp file
377377
tmpFiles, err := os.ReadDir(os.TempDir())
378378
assert.NoError(t, err)
379-
assert.Len(t, slicesext.Filter[fs.DirEntry](tmpFiles, func(file fs.DirEntry) bool {
379+
assert.Len(t, slicesext.Filter(tmpFiles, func(file fs.DirEntry) bool {
380380
return strings.HasPrefix(file.Name(), "bh-request-")
381381
}), 0)
382382
})

Diff for: cmd/api/src/api/tools/pg.go

+16-6
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,12 @@ func migrateTypes(ctx context.Context, neoDB, pgDB graph.Database) error {
8383
return err
8484
}
8585

86-
_, err := pgDB.(*pg.Driver).KindMapper().AssertKinds(ctx, append(neoNodeKinds, neoEdgeKinds...))
86+
driver, ok := pgDB.(*pg.Driver)
87+
if !ok {
88+
return fmt.Errorf("current graph database is not a pg driver")
89+
}
90+
91+
_, err := driver.KindMapper().AssertKinds(ctx, append(neoNodeKinds, neoEdgeKinds...))
8792
return err
8893
}
8994

@@ -443,15 +448,20 @@ func (s *PGMigrator) MigrationStatus(response http.ResponseWriter, request *http
443448
}
444449

445450
func (s *PGMigrator) OpenPostgresGraphConnection() (graph.Database, error) {
446-
return dawgs.Open(s.ServerCtx, pg.DriverName, dawgs.Config{
447-
GraphQueryMemoryLimit: size.Gibibyte,
448-
DriverCfg: s.Cfg.Database.PostgreSQLConnectionString(),
449-
})
451+
if pool, err := pg.NewPool(s.Cfg.Database.PostgreSQLConnectionString()); err != nil {
452+
return nil, err
453+
} else {
454+
return dawgs.Open(s.ServerCtx, pg.DriverName, dawgs.Config{
455+
GraphQueryMemoryLimit: size.Gibibyte,
456+
ConnectionString: s.Cfg.Database.PostgreSQLConnectionString(),
457+
Pool: pool,
458+
})
459+
}
450460
}
451461

452462
func (s *PGMigrator) OpenNeo4jGraphConnection() (graph.Database, error) {
453463
return dawgs.Open(s.ServerCtx, neo4j.DriverName, dawgs.Config{
454464
GraphQueryMemoryLimit: size.Gibibyte,
455-
DriverCfg: s.Cfg.Neo4J.Neo4jConnectionString(),
465+
ConnectionString: s.Cfg.Neo4J.Neo4jConnectionString(),
456466
})
457467
}

Diff for: cmd/api/src/bootstrap/util.go

+34-22
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import (
2222
"log/slog"
2323
"os"
2424

25+
"github.com/jackc/pgx/v5/pgxpool"
2526
"github.com/specterops/bloodhound/bhlog"
2627
"github.com/specterops/bloodhound/bhlog/level"
2728
"github.com/specterops/bloodhound/dawgs"
@@ -75,34 +76,45 @@ func DefaultConfigFilePath() string {
7576
}
7677

7778
func ConnectGraph(ctx context.Context, cfg config.Configuration) (*graph.DatabaseSwitch, error) {
78-
var connectionString string
79-
80-
if driverName, err := tools.LookupGraphDriver(ctx, cfg); err != nil {
79+
var (
80+
connectionString string
81+
pool *pgxpool.Pool
82+
err error
83+
)
84+
85+
driverName, err := tools.LookupGraphDriver(ctx, cfg)
86+
if err != nil {
8187
return nil, err
82-
} else {
83-
switch driverName {
84-
case neo4j.DriverName:
85-
slog.InfoContext(ctx, "Connecting to graph using Neo4j")
86-
connectionString = cfg.Neo4J.Neo4jConnectionString()
88+
}
8789

88-
case pg.DriverName:
89-
slog.InfoContext(ctx, "Connecting to graph using PostgreSQL")
90-
connectionString = cfg.Database.PostgreSQLConnectionString()
90+
switch driverName {
91+
case neo4j.DriverName:
92+
slog.InfoContext(ctx, "Connecting to graph using Neo4j")
93+
connectionString = cfg.Neo4J.Neo4jConnectionString()
9194

92-
default:
93-
return nil, fmt.Errorf("unknown graphdb driver name: %s", driverName)
94-
}
95+
case pg.DriverName:
96+
slog.InfoContext(ctx, "Connecting to graph using PostgreSQL")
97+
connectionString = cfg.Database.PostgreSQLConnectionString()
9598

96-
if connectionString == "" {
97-
return nil, fmt.Errorf("graph connection requires a connection url to be set")
98-
} else if graphDatabase, err := dawgs.Open(ctx, driverName, dawgs.Config{
99-
GraphQueryMemoryLimit: size.Size(cfg.GraphQueryMemoryLimit) * size.Gibibyte,
100-
DriverCfg: connectionString,
101-
}); err != nil {
99+
pool, err = pg.NewPool(connectionString)
100+
if err != nil {
102101
return nil, err
103-
} else {
104-
return graph.NewDatabaseSwitch(ctx, graphDatabase), nil
105102
}
103+
104+
default:
105+
return nil, fmt.Errorf("unknown graphdb driver name: %s", driverName)
106+
}
107+
108+
if connectionString == "" {
109+
return nil, fmt.Errorf("graph connection requires a connection url to be set")
110+
} else if graphDatabase, err := dawgs.Open(ctx, driverName, dawgs.Config{
111+
GraphQueryMemoryLimit: size.Size(cfg.GraphQueryMemoryLimit) * size.Gibibyte,
112+
ConnectionString: connectionString,
113+
Pool: pool,
114+
}); err != nil {
115+
return nil, err
116+
} else {
117+
return graph.NewDatabaseSwitch(ctx, graphDatabase), nil
106118
}
107119
}
108120

Diff for: cmd/api/src/cmd/dawgs-harness/main.go

+16-2
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import (
2727
"syscall"
2828
"time"
2929

30+
"github.com/jackc/pgx/v5/pgxpool"
3031
"github.com/jedib0t/go-pretty/v6/table"
3132
"github.com/specterops/bloodhound/bhlog"
3233
"github.com/specterops/bloodhound/dawgs"
@@ -44,9 +45,22 @@ func fatalf(format string, args ...any) {
4445
}
4546

4647
func RunTestSuite(ctx context.Context, connectionStr, driverName string) tests.TestSuite {
48+
var (
49+
pool *pgxpool.Pool
50+
err error
51+
)
52+
53+
if driverName == pg.DriverName {
54+
pool, err = pg.NewPool(connectionStr)
55+
if err != nil {
56+
fatalf("Failed creating a new pgxpool: %s", err)
57+
}
58+
}
59+
4760
if connection, err := dawgs.Open(context.TODO(), driverName, dawgs.Config{
4861
GraphQueryMemoryLimit: size.Gibibyte,
49-
DriverCfg: connectionStr,
62+
ConnectionString: connectionStr,
63+
Pool: pool,
5064
}); err != nil {
5165
fatalf("Failed opening %s database: %v", driverName, err)
5266
} else {
@@ -65,7 +79,7 @@ func RunTestSuite(ctx context.Context, connectionStr, driverName string) tests.T
6579
}
6680
}
6781

68-
panic(nil)
82+
panic("unexpected error")
6983
}
7084

7185
func newContext() context.Context {

Diff for: cmd/api/src/database/log.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ func (s *GormLogAdapter) Trace(ctx context.Context, begin time.Time, fc func() (
5353
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
5454
sql, _ := fc()
5555

56-
slog.ErrorContext(ctx, "Database error", "query", sql, "error", err)
56+
slog.ErrorContext(ctx, "Database error", "query", sql, "err", err)
5757
} else {
5858
elapsed := time.Since(begin)
5959

Diff for: cmd/api/src/test/integration/dawgs.go

+7-3
Original file line numberDiff line numberDiff line change
@@ -47,20 +47,24 @@ func OpenGraphDB(testCtrl test.Controller, schema graph.Schema) graph.Database {
4747

4848
switch cfg.GraphDriver {
4949
case pg.DriverName:
50+
pool, err := pg.NewPool(cfg.Database.PostgreSQLConnectionString())
51+
test.RequireNilErrf(testCtrl, err, "Failed to create new pgx pool: %v", err)
5052
graphDatabase, err = dawgs.Open(context.TODO(), cfg.GraphDriver, dawgs.Config{
51-
DriverCfg: cfg.Database.PostgreSQLConnectionString(),
53+
ConnectionString: cfg.Database.PostgreSQLConnectionString(),
54+
Pool: pool,
5255
})
56+
test.RequireNilErrf(testCtrl, err, "Failed connecting to graph database: %v", err)
5357

5458
case neo4j.DriverName:
5559
graphDatabase, err = dawgs.Open(context.TODO(), cfg.GraphDriver, dawgs.Config{
56-
DriverCfg: cfg.Neo4J.Neo4jConnectionString(),
60+
ConnectionString: cfg.Neo4J.Neo4jConnectionString(),
5761
})
62+
test.RequireNilErrf(testCtrl, err, "Failed connecting to graph database: %v", err)
5863

5964
default:
6065
testCtrl.Fatalf("unsupported graph driver name %s", cfg.GraphDriver)
6166
}
6267

63-
test.RequireNilErrf(testCtrl, err, "Failed connecting to graph database: %v", err)
6468
test.RequireNilErr(testCtrl, graphDatabase.AssertSchema(context.Background(), schema))
6569

6670
return graphDatabase

Diff for: packages/go/dawgs/dawgs.go

+3-1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"context"
2121
"errors"
2222

23+
"github.com/jackc/pgx/v5/pgxpool"
2324
"github.com/specterops/bloodhound/dawgs/graph"
2425
"github.com/specterops/bloodhound/dawgs/util/size"
2526
)
@@ -38,7 +39,8 @@ func Register(driverName string, constructor DriverConstructor) {
3839

3940
type Config struct {
4041
GraphQueryMemoryLimit size.Size
41-
DriverCfg any
42+
ConnectionString string
43+
Pool *pgxpool.Pool
4244
}
4345

4446
func Open(ctx context.Context, driverName string, config Config) (graph.Database, error) {

Diff for: packages/go/dawgs/drivers/neo4j/neo4j.go

+1-3
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,7 @@ const (
3535
)
3636

3737
func newNeo4jDB(_ context.Context, cfg dawgs.Config) (graph.Database, error) {
38-
if connectionURLStr, typeOK := cfg.DriverCfg.(string); !typeOK {
39-
return nil, fmt.Errorf("expected string for configuration type but got %T", cfg.DriverCfg)
40-
} else if connectionURL, err := url.Parse(connectionURLStr); err != nil {
38+
if connectionURL, err := url.Parse(cfg.ConnectionString); err != nil {
4139
return nil, err
4240
} else if connectionURL.Scheme != DriverName {
4341
return nil, fmt.Errorf("expected connection URL scheme %s for Neo4J but got %s", DriverName, connectionURL.Scheme)

Diff for: packages/go/dawgs/drivers/neo4j/result_internal_test.go

+14-14
Original file line numberDiff line numberDiff line change
@@ -50,28 +50,28 @@ func Test_mapValue(t *testing.T) {
5050
mapTestCase[uint32, uint32](t, 0, 0)
5151
mapTestCase[uint64, uint64](t, 0, 0)
5252

53-
mapTestCase[int, int](t, 0, 0)
53+
mapTestCase(t, 0, 0) // Inferred int
5454
mapTestCase[int8, int8](t, 0, 0)
5555
mapTestCase[int16, int16](t, 0, 0)
5656
mapTestCase[int32, int32](t, 0, 0)
5757
mapTestCase[int64, int64](t, 0, 0)
5858
mapTestCase[int64, graph.ID](t, 0, 0)
5959

6060
mapTestCase[float32, float32](t, 1.5, 1.5)
61-
mapTestCase[float64, float64](t, 1.5, 1.5)
61+
mapTestCase(t, 1.5, 1.5) // Inferred float64
6262

63-
mapTestCase[bool, bool](t, true, true)
64-
mapTestCase[string, string](t, "test", "test")
63+
mapTestCase(t, true, true)
64+
mapTestCase(t, "test", "test")
6565

66-
mapTestCase[time.Time, time.Time](t, utcNow, utcNow)
67-
mapTestCase[string, time.Time](t, utcNow.Format(time.RFC3339Nano), utcNow)
68-
mapTestCase[int64, time.Time](t, utcNow.Unix(), time.Unix(utcNow.Unix(), 0))
69-
mapTestCase[dbtype.Time, time.Time](t, dbtype.Time(utcNow), utcNow)
70-
mapTestCase[dbtype.LocalTime, time.Time](t, dbtype.LocalTime(utcNow), utcNow)
71-
mapTestCase[dbtype.Date, time.Time](t, dbtype.Date(utcNow), utcNow)
72-
mapTestCase[dbtype.LocalDateTime, time.Time](t, dbtype.LocalDateTime(utcNow), utcNow)
66+
mapTestCase(t, utcNow, utcNow)
67+
mapTestCase(t, utcNow.Format(time.RFC3339Nano), utcNow)
68+
mapTestCase(t, utcNow.Unix(), time.Unix(utcNow.Unix(), 0))
69+
mapTestCase(t, dbtype.Time(utcNow), utcNow)
70+
mapTestCase(t, dbtype.LocalTime(utcNow), utcNow)
71+
mapTestCase(t, dbtype.Date(utcNow), utcNow)
72+
mapTestCase(t, dbtype.LocalDateTime(utcNow), utcNow)
7373

74-
mapTestCase[[]any, []string](t, anyStringSlice, stringSlice)
75-
mapTestCase[[]any, []graph.Kind](t, anyStringSlice, kindSlice)
76-
mapTestCase[[]any, graph.Kinds](t, anyStringSlice, kinds)
74+
mapTestCase(t, anyStringSlice, stringSlice)
75+
mapTestCase(t, anyStringSlice, kindSlice)
76+
mapTestCase(t, anyStringSlice, kinds)
7777
}

0 commit comments

Comments
 (0)