Skip to content

BED-5467 Allow DB injection into DAWGS #1314

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Apr 8, 2025
6 changes: 3 additions & 3 deletions cmd/api/src/api/auth_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -333,15 +333,15 @@ func TestValidateRequestSignature(t *testing.T) {

tmpFiles, err := os.ReadDir(authenticator.cfg.TempDirectory())
assert.NoError(t, err)
assert.Len(t, slicesext.Filter[fs.DirEntry](tmpFiles, func(file fs.DirEntry) bool {
assert.Len(t, slicesext.Filter(tmpFiles, func(file fs.DirEntry) bool {
return strings.HasPrefix(file.Name(), "bh-request-")
}), 1)

// Closing the body should remove the tmp file
req.Body.Close()
tmpFiles, err = os.ReadDir(os.TempDir())
assert.NoError(t, err)
assert.Len(t, slicesext.Filter[fs.DirEntry](tmpFiles, func(file fs.DirEntry) bool {
assert.Len(t, slicesext.Filter(tmpFiles, func(file fs.DirEntry) bool {
return strings.HasPrefix(file.Name(), "bh-request-")
}), 0)
})
Expand Down Expand Up @@ -376,7 +376,7 @@ func TestValidateRequestSignature(t *testing.T) {
// "small" payloads should not create a tmp file
tmpFiles, err := os.ReadDir(os.TempDir())
assert.NoError(t, err)
assert.Len(t, slicesext.Filter[fs.DirEntry](tmpFiles, func(file fs.DirEntry) bool {
assert.Len(t, slicesext.Filter(tmpFiles, func(file fs.DirEntry) bool {
return strings.HasPrefix(file.Name(), "bh-request-")
}), 0)
})
Expand Down
22 changes: 16 additions & 6 deletions cmd/api/src/api/tools/pg.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,12 @@ func migrateTypes(ctx context.Context, neoDB, pgDB graph.Database) error {
return err
}

_, err := pgDB.(*pg.Driver).KindMapper().AssertKinds(ctx, append(neoNodeKinds, neoEdgeKinds...))
driver, ok := pgDB.(*pg.Driver)
if !ok {
return fmt.Errorf("current graph database is not a pg driver")
}

_, err := driver.KindMapper().AssertKinds(ctx, append(neoNodeKinds, neoEdgeKinds...))
return err
}

Expand Down Expand Up @@ -443,15 +448,20 @@ func (s *PGMigrator) MigrationStatus(response http.ResponseWriter, request *http
}

func (s *PGMigrator) OpenPostgresGraphConnection() (graph.Database, error) {
return dawgs.Open(s.ServerCtx, pg.DriverName, dawgs.Config{
GraphQueryMemoryLimit: size.Gibibyte,
DriverCfg: s.Cfg.Database.PostgreSQLConnectionString(),
})
if pool, err := pg.NewPool(s.Cfg.Database.PostgreSQLConnectionString()); err != nil {
return nil, err
} else {
return dawgs.Open(s.ServerCtx, pg.DriverName, dawgs.Config{
GraphQueryMemoryLimit: size.Gibibyte,
ConnectionString: s.Cfg.Database.PostgreSQLConnectionString(),
Pool: pool,
})
}
}

func (s *PGMigrator) OpenNeo4jGraphConnection() (graph.Database, error) {
return dawgs.Open(s.ServerCtx, neo4j.DriverName, dawgs.Config{
GraphQueryMemoryLimit: size.Gibibyte,
DriverCfg: s.Cfg.Neo4J.Neo4jConnectionString(),
ConnectionString: s.Cfg.Neo4J.Neo4jConnectionString(),
})
}
56 changes: 34 additions & 22 deletions cmd/api/src/bootstrap/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"log/slog"
"os"

"github.com/jackc/pgx/v5/pgxpool"
"github.com/specterops/bloodhound/bhlog"
"github.com/specterops/bloodhound/bhlog/level"
"github.com/specterops/bloodhound/dawgs"
Expand Down Expand Up @@ -75,34 +76,45 @@ func DefaultConfigFilePath() string {
}

func ConnectGraph(ctx context.Context, cfg config.Configuration) (*graph.DatabaseSwitch, error) {
var connectionString string

if driverName, err := tools.LookupGraphDriver(ctx, cfg); err != nil {
var (
connectionString string
pool *pgxpool.Pool
err error
)

driverName, err := tools.LookupGraphDriver(ctx, cfg)
if err != nil {
return nil, err
} else {
switch driverName {
case neo4j.DriverName:
slog.InfoContext(ctx, "Connecting to graph using Neo4j")
connectionString = cfg.Neo4J.Neo4jConnectionString()
}

case pg.DriverName:
slog.InfoContext(ctx, "Connecting to graph using PostgreSQL")
connectionString = cfg.Database.PostgreSQLConnectionString()
switch driverName {
case neo4j.DriverName:
slog.InfoContext(ctx, "Connecting to graph using Neo4j")
connectionString = cfg.Neo4J.Neo4jConnectionString()

default:
return nil, fmt.Errorf("unknown graphdb driver name: %s", driverName)
}
case pg.DriverName:
slog.InfoContext(ctx, "Connecting to graph using PostgreSQL")
connectionString = cfg.Database.PostgreSQLConnectionString()

if connectionString == "" {
return nil, fmt.Errorf("graph connection requires a connection url to be set")
} else if graphDatabase, err := dawgs.Open(ctx, driverName, dawgs.Config{
GraphQueryMemoryLimit: size.Size(cfg.GraphQueryMemoryLimit) * size.Gibibyte,
DriverCfg: connectionString,
}); err != nil {
pool, err = pg.NewPool(connectionString)
if err != nil {
return nil, err
} else {
return graph.NewDatabaseSwitch(ctx, graphDatabase), nil
}

default:
return nil, fmt.Errorf("unknown graphdb driver name: %s", driverName)
}

if connectionString == "" {
return nil, fmt.Errorf("graph connection requires a connection url to be set")
} else if graphDatabase, err := dawgs.Open(ctx, driverName, dawgs.Config{
GraphQueryMemoryLimit: size.Size(cfg.GraphQueryMemoryLimit) * size.Gibibyte,
ConnectionString: connectionString,
Pool: pool,
}); err != nil {
return nil, err
} else {
return graph.NewDatabaseSwitch(ctx, graphDatabase), nil
}
}

Expand Down
18 changes: 16 additions & 2 deletions cmd/api/src/cmd/dawgs-harness/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"syscall"
"time"

"github.com/jackc/pgx/v5/pgxpool"
"github.com/jedib0t/go-pretty/v6/table"
"github.com/specterops/bloodhound/bhlog"
"github.com/specterops/bloodhound/dawgs"
Expand All @@ -44,9 +45,22 @@ func fatalf(format string, args ...any) {
}

func RunTestSuite(ctx context.Context, connectionStr, driverName string) tests.TestSuite {
var (
pool *pgxpool.Pool
err error
)

if driverName == pg.DriverName {
pool, err = pg.NewPool(connectionStr)
if err != nil {
fatalf("Failed creating a new pgxpool: %s", err)
}
}

if connection, err := dawgs.Open(context.TODO(), driverName, dawgs.Config{
GraphQueryMemoryLimit: size.Gibibyte,
DriverCfg: connectionStr,
ConnectionString: connectionStr,
Pool: pool,
}); err != nil {
fatalf("Failed opening %s database: %v", driverName, err)
} else {
Expand All @@ -65,7 +79,7 @@ func RunTestSuite(ctx context.Context, connectionStr, driverName string) tests.T
}
}

panic(nil)
panic("unexpected error")
}

func newContext() context.Context {
Expand Down
2 changes: 1 addition & 1 deletion cmd/api/src/database/log.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func (s *GormLogAdapter) Trace(ctx context.Context, begin time.Time, fc func() (
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
sql, _ := fc()

slog.ErrorContext(ctx, "Database error", "query", sql, "error", err)
slog.ErrorContext(ctx, "Database error", "query", sql, "err", err)
} else {
elapsed := time.Since(begin)

Expand Down
10 changes: 7 additions & 3 deletions cmd/api/src/test/integration/dawgs.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,20 +47,24 @@ func OpenGraphDB(testCtrl test.Controller, schema graph.Schema) graph.Database {

switch cfg.GraphDriver {
case pg.DriverName:
pool, err := pg.NewPool(cfg.Database.PostgreSQLConnectionString())
test.RequireNilErrf(testCtrl, err, "Failed to create new pgx pool: %v", err)
graphDatabase, err = dawgs.Open(context.TODO(), cfg.GraphDriver, dawgs.Config{
DriverCfg: cfg.Database.PostgreSQLConnectionString(),
ConnectionString: cfg.Database.PostgreSQLConnectionString(),
Pool: pool,
})
test.RequireNilErrf(testCtrl, err, "Failed connecting to graph database: %v", err)

case neo4j.DriverName:
graphDatabase, err = dawgs.Open(context.TODO(), cfg.GraphDriver, dawgs.Config{
DriverCfg: cfg.Neo4J.Neo4jConnectionString(),
ConnectionString: cfg.Neo4J.Neo4jConnectionString(),
})
test.RequireNilErrf(testCtrl, err, "Failed connecting to graph database: %v", err)

default:
testCtrl.Fatalf("unsupported graph driver name %s", cfg.GraphDriver)
}

test.RequireNilErrf(testCtrl, err, "Failed connecting to graph database: %v", err)
test.RequireNilErr(testCtrl, graphDatabase.AssertSchema(context.Background(), schema))

return graphDatabase
Expand Down
4 changes: 3 additions & 1 deletion packages/go/dawgs/dawgs.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"context"
"errors"

"github.com/jackc/pgx/v5/pgxpool"
"github.com/specterops/bloodhound/dawgs/graph"
"github.com/specterops/bloodhound/dawgs/util/size"
)
Expand All @@ -38,7 +39,8 @@ func Register(driverName string, constructor DriverConstructor) {

type Config struct {
GraphQueryMemoryLimit size.Size
DriverCfg any
ConnectionString string
Pool *pgxpool.Pool
}

func Open(ctx context.Context, driverName string, config Config) (graph.Database, error) {
Expand Down
4 changes: 1 addition & 3 deletions packages/go/dawgs/drivers/neo4j/neo4j.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,7 @@ const (
)

func newNeo4jDB(_ context.Context, cfg dawgs.Config) (graph.Database, error) {
if connectionURLStr, typeOK := cfg.DriverCfg.(string); !typeOK {
return nil, fmt.Errorf("expected string for configuration type but got %T", cfg.DriverCfg)
} else if connectionURL, err := url.Parse(connectionURLStr); err != nil {
if connectionURL, err := url.Parse(cfg.ConnectionString); err != nil {
return nil, err
} else if connectionURL.Scheme != DriverName {
return nil, fmt.Errorf("expected connection URL scheme %s for Neo4J but got %s", DriverName, connectionURL.Scheme)
Expand Down
28 changes: 14 additions & 14 deletions packages/go/dawgs/drivers/neo4j/result_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,28 +50,28 @@ func Test_mapValue(t *testing.T) {
mapTestCase[uint32, uint32](t, 0, 0)
mapTestCase[uint64, uint64](t, 0, 0)

mapTestCase[int, int](t, 0, 0)
mapTestCase(t, 0, 0) // Inferred int
mapTestCase[int8, int8](t, 0, 0)
mapTestCase[int16, int16](t, 0, 0)
mapTestCase[int32, int32](t, 0, 0)
mapTestCase[int64, int64](t, 0, 0)
mapTestCase[int64, graph.ID](t, 0, 0)

mapTestCase[float32, float32](t, 1.5, 1.5)
mapTestCase[float64, float64](t, 1.5, 1.5)
mapTestCase(t, 1.5, 1.5) // Inferred float64

mapTestCase[bool, bool](t, true, true)
mapTestCase[string, string](t, "test", "test")
mapTestCase(t, true, true)
mapTestCase(t, "test", "test")

mapTestCase[time.Time, time.Time](t, utcNow, utcNow)
mapTestCase[string, time.Time](t, utcNow.Format(time.RFC3339Nano), utcNow)
mapTestCase[int64, time.Time](t, utcNow.Unix(), time.Unix(utcNow.Unix(), 0))
mapTestCase[dbtype.Time, time.Time](t, dbtype.Time(utcNow), utcNow)
mapTestCase[dbtype.LocalTime, time.Time](t, dbtype.LocalTime(utcNow), utcNow)
mapTestCase[dbtype.Date, time.Time](t, dbtype.Date(utcNow), utcNow)
mapTestCase[dbtype.LocalDateTime, time.Time](t, dbtype.LocalDateTime(utcNow), utcNow)
mapTestCase(t, utcNow, utcNow)
mapTestCase(t, utcNow.Format(time.RFC3339Nano), utcNow)
mapTestCase(t, utcNow.Unix(), time.Unix(utcNow.Unix(), 0))
mapTestCase(t, dbtype.Time(utcNow), utcNow)
mapTestCase(t, dbtype.LocalTime(utcNow), utcNow)
mapTestCase(t, dbtype.Date(utcNow), utcNow)
mapTestCase(t, dbtype.LocalDateTime(utcNow), utcNow)

mapTestCase[[]any, []string](t, anyStringSlice, stringSlice)
mapTestCase[[]any, []graph.Kind](t, anyStringSlice, kindSlice)
mapTestCase[[]any, graph.Kinds](t, anyStringSlice, kinds)
mapTestCase(t, anyStringSlice, stringSlice)
mapTestCase(t, anyStringSlice, kindSlice)
mapTestCase(t, anyStringSlice, kinds)
}
Loading
Loading