Skip to content

refactor node registrar #38

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 14 additions & 14 deletions node-registrar/cmds/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@ import (
"fmt"
"net"
"os"
"os/signal"
"strings"
"syscall"

"github.com/pkg/errors"
"github.com/rs/zerolog"
Expand Down Expand Up @@ -57,7 +55,7 @@ func Run() error {
flag.UintVar(&f.serverPort, "server-port", 8080, "server port")
flag.StringVar(&f.domain, "domain", "", "domain on which the server will be served")
flag.StringVar(&f.network, "network", "dev", "the registrar network")
flag.Uint64Var(&f.adminTwinID, "admin-twin-id", 0, "admin twin ID")
flag.Uint64Var(&f.adminTwinID, "admin-twin-id", 1, "admin twin ID")

flag.Parse()
f.SqlLogLevel = logger.LogLevel(sqlLogLevel)
Expand All @@ -71,11 +69,11 @@ func Run() error {
return err
}

log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr})
zerolog.SetGlobalLevel(zerolog.InfoLevel)
logLevel := zerolog.InfoLevel
if f.debug {
zerolog.SetGlobalLevel(zerolog.DebugLevel)
logLevel = zerolog.DebugLevel
}
log.Logger = zerolog.New(zerolog.ConsoleWriter{Out: os.Stderr}).Level(logLevel).With().Timestamp().Logger()

db, err := db.NewDB(f.Config)
if err != nil {
Expand All @@ -89,17 +87,11 @@ func Run() error {
}
}()

s, err := server.NewServer(db, f.network, f.adminTwinID)
if err != nil {
return errors.Wrap(err, "failed to start gin server")
}

quit := make(chan os.Signal, 1)
signal.Notify(quit, os.Interrupt, syscall.SIGTERM)
s := server.NewServer(db, f.network, f.adminTwinID)

log.Info().Msgf("server is running on port :%d", f.serverPort)

err = s.Run(quit, fmt.Sprintf("%s:%d", f.domain, f.serverPort))
err = s.Run(fmt.Sprintf("%s:%d", f.domain, f.serverPort))
if err != nil {
return errors.Wrap(err, "failed to run gin server")
}
Expand All @@ -115,6 +107,14 @@ func (f flags) validate() error {
if strings.TrimSpace(f.domain) == "" {
return errors.New("invalid domain name, domain name should not be empty")
}

if f.SqlLogLevel < 1 || f.SqlLogLevel > 4 {
return errors.Errorf("invalid sql log level %d, sql log level should be in the range 1-4", f.SqlLogLevel)
}
if f.adminTwinID == 0 {
return errors.Errorf("invalid admin twin id %d, admin twin id should not be 0", f.adminTwinID)
}

if _, err := net.LookupHost(f.domain); err != nil {
return errors.Wrapf(err, "invalid domain %s", f.domain)
}
Expand Down
10 changes: 5 additions & 5 deletions node-registrar/pkg/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,24 +90,24 @@ func (c Config) Validate() error {
}

if net.ParseIP(c.PostgresHost) == nil {
if _, err := net.LookupHost(c.PostgresHost); err == nil {
if _, err := net.LookupHost(c.PostgresHost); err != nil {
return errors.Wrapf(err, "invalid postgres host %s, failed to parse or lookup host", c.PostgresHost)
}
}

if c.PostgresPort < 1 && c.PostgresPort > 65535 {
if c.PostgresPort < 1 || c.PostgresPort > 65535 {
return errors.Errorf("invalid postgres port %d, postgres port should be in the valid port range 1–65535", c.PostgresPort)
}

if strings.TrimSpace(c.DBName) == "" {
if len(strings.TrimSpace(c.DBName)) == 0 {
return errors.New("invalid database name, database name should not be empty")
}

if strings.TrimSpace(c.PostgresUser) == "" {
if len(strings.TrimSpace(c.PostgresUser)) == 0 {
return errors.New("invalid postgres user, postgres user should not be empty")
}

if strings.TrimSpace(c.PostgresPassword) == "" {
if len(strings.TrimSpace(c.PostgresPassword)) == 0 {
return errors.New("invalid postgres password, postgres password should not be empty")
}

Expand Down
2 changes: 1 addition & 1 deletion node-registrar/pkg/db/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ type Account struct {
type Farm struct {
FarmID uint64 `gorm:"primaryKey;autoIncrement" json:"farm_id"`
FarmName string `gorm:"size:40;not null;unique;check:farm_name <> ''" json:"farm_name" binding:"alphanum,required"`
TwinID uint64 `json:"twin_id" gorm:"not null;check:twin_id > 0"` // Farmer account reference
TwinID uint64 `json:"twin_id" binding:"required" gorm:"not null;check:twin_id > 0"` // Farmer account reference
StellarAddress string `json:"stellar_address" binding:"required,startswith=G,len=56,alphanum,uppercase"`
Dedicated bool `json:"dedicated"`
CreatedAt time.Time `json:"created_at"`
Expand Down
37 changes: 15 additions & 22 deletions node-registrar/pkg/server/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,10 @@ import (

"github.com/gin-gonic/gin"
"github.com/lib/pq"
"github.com/rs/zerolog/log"
"github.com/threefoldtech/tfgrid4-sdk-go/node-registrar/pkg/db"
)

const (
PubKeySize = 32
MaxTimestampDelta = 2 * time.Second
)

Expand Down Expand Up @@ -78,13 +76,14 @@ func (s Server) getFarmHandler(c *gin.Context) {

farm, err := s.db.GetFarm(id)
if err != nil {
status := http.StatusBadRequest
status := http.StatusInternalServerError

if errors.Is(err, db.ErrRecordNotFound) {
status = http.StatusNotFound
}

c.JSON(status, gin.H{"error": err.Error()})
return
}

c.JSON(http.StatusOK, farm)
Expand Down Expand Up @@ -117,7 +116,7 @@ func (s Server) createFarmHandler(c *gin.Context) {

farmID, err := s.db.CreateFarm(farm)
if err != nil {
status := http.StatusBadRequest
status := http.StatusInternalServerError

if errors.Is(err, db.ErrRecordAlreadyExists) {
status = http.StatusConflict
Expand Down Expand Up @@ -148,7 +147,7 @@ type UpdateFarmRequest struct {
// @Failure 401 {object} map[string]any "Unauthorized"
// @Failure 404 {object} map[string]any "Farm not found"
// @Router /farms/{farm_id} [patch]
func (s Server) updateFarmsHandler(c *gin.Context) {
func (s Server) updateFarmHandler(c *gin.Context) {
var req UpdateFarmRequest
farmID := c.Param("farm_id")

Expand Down Expand Up @@ -186,7 +185,7 @@ func (s Server) updateFarmsHandler(c *gin.Context) {
(len(req.StellarAddress) != 0 && existingFarm.StellarAddress != req.StellarAddress) {
err = s.db.UpdateFarm(id, req.FarmName, req.StellarAddress)
if err != nil {
status := http.StatusBadRequest
status := http.StatusInternalServerError

if errors.Is(err, db.ErrRecordNotFound) {
status = http.StatusNotFound
Expand Down Expand Up @@ -229,7 +228,7 @@ func (s Server) listNodesHandler(c *gin.Context) {

nodes, err := s.db.ListNodes(filter, limit)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}

Expand Down Expand Up @@ -262,7 +261,7 @@ func (s Server) getNodeHandler(c *gin.Context) {
return
}

c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}

Expand Down Expand Up @@ -319,7 +318,7 @@ func (s Server) registerNodeHandler(c *gin.Context) {

nodeID, err := s.db.RegisterNode(node)
if err != nil {
status := http.StatusBadRequest
status := http.StatusInternalServerError

if errors.Is(err, db.ErrRecordAlreadyExists) {
status = http.StatusConflict
Expand Down Expand Up @@ -382,7 +381,6 @@ func (s *Server) updateNodeHandler(c *gin.Context) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid request body"})
return
}
log.Debug().Any("req", req).Send()

updatedNode := db.Node{
FarmID: req.FarmID,
Expand Down Expand Up @@ -540,11 +538,13 @@ func (s *Server) createAccountHandler(c *gin.Context) {
publicKeyBytes, err := base64.StdEncoding.DecodeString(req.PublicKey)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid public key format"})
return
}
// Decode signature from base64
signatureBytes, err := base64.StdEncoding.DecodeString(req.Signature)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("invalid signature format: %v", err)})
return
}
// Verify signature of the challenge
err = verifySignature(publicKeyBytes, challenge, signatureBytes)
Expand Down Expand Up @@ -572,13 +572,6 @@ func (s *Server) createAccountHandler(c *gin.Context) {
c.JSON(http.StatusCreated, account)
}

/* // verifySignature verifies an ED25519 signature
func verifySignature(publicKey, chalange, signature []byte) (bool, error) {

// Verify the signature
return ed25519.Verify(publicKey, chalange, signature), nil
} */

type UpdateAccountRequest struct {
Relays pq.StringArray `json:"relays"`
RMBEncKey string `json:"rmb_enc_key"`
Expand Down Expand Up @@ -669,7 +662,7 @@ func (s *Server) getAccountHandler(c *gin.Context) {

account, err := s.db.GetAccount(twinID)
if err != nil {
if err == db.ErrRecordNotFound {
if errors.Is(err, db.ErrRecordNotFound) {
c.JSON(http.StatusNotFound, gin.H{"error": "account not found"})
return
}
Expand All @@ -684,7 +677,7 @@ func (s *Server) getAccountHandler(c *gin.Context) {
if publicKeyParam != "" {
account, err := s.db.GetAccountByPublicKey(publicKeyParam)
if err != nil {
if err == db.ErrRecordNotFound {
if errors.Is(err, db.ErrRecordNotFound) {
c.JSON(http.StatusNotFound, gin.H{"error": "account not found"})
return
}
Expand Down Expand Up @@ -759,7 +752,7 @@ func (s *Server) getZOSVersionHandler(c *gin.Context) {
c.JSON(http.StatusOK, version)
}

// Helper function to validate public key format
// Helper function to validate public key length
func isValidPublicKey(publicKeyBase64 string) bool {
publicKeyBytes, err := base64.StdEncoding.DecodeString(publicKeyBase64)
if err != nil {
Expand All @@ -771,8 +764,8 @@ func isValidPublicKey(publicKeyBase64 string) bool {
// Helper function to ensure the request is from the owner
func ensureOwner(c *gin.Context, twinID uint64) {
// Retrieve twinID set by the authMiddleware
authTwinID, exists := c.Get("twinID")
if !exists {
authTwinID := c.Request.Context().Value(twinIDKey{})
if authTwinID == nil {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "not authorized"})
return
}
Expand Down
17 changes: 14 additions & 3 deletions node-registrar/pkg/server/middlewares.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package server

import (
"context"
"encoding/base64"
"errors"
"fmt"
Expand All @@ -10,12 +11,16 @@ import (
"time"

"github.com/gin-gonic/gin"
"github.com/rs/zerolog/log"
"github.com/threefoldtech/tfgrid4-sdk-go/node-registrar/pkg/db"
)

// twinKeyID is where the twin key is stored
type twinIDKey struct{}

const (
AuthHeader = "X-Auth"
ChallengeValidity = 5 * time.Minute
ChallengeValidity = 1 * time.Minute
)

// AuthMiddleware is a middleware function that authenticates incoming requests based on the X-Auth header.
Expand All @@ -26,7 +31,6 @@ const (
// header format `Challenge:Signature`
// - chalange format: base64(message) where the message is `timestampStr:twinIDStr`
// - signature format: base64(ed25519_or_sr22519_signature)
// TODO: do we need to support both? Maybe if only ed25519 needed we can rely on crypto pkg instead of using go-subkey
func (s *Server) AuthMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
// Extract and validate headers
Expand All @@ -47,6 +51,7 @@ func (s *Server) AuthMiddleware() gin.HandlerFunc {
// Decode and validate challenge
challenge, err := base64.StdEncoding.DecodeString(challengeB64)
if err != nil {
log.Debug().Err(err).Msg("failed to deconde challenge")
abortWithError(c, http.StatusBadRequest, "Invalid challenge encoding")
return
}
Expand All @@ -62,6 +67,7 @@ func (s *Server) AuthMiddleware() gin.HandlerFunc {
// Validate timestamp
timestamp, err := strconv.ParseInt(timestampStr, 10, 64)
if err != nil {
log.Debug().Err(err).Msg("invalid timestamp")
abortWithError(c, http.StatusBadRequest, "Invalid timestamp")
return
}
Expand All @@ -73,36 +79,41 @@ func (s *Server) AuthMiddleware() gin.HandlerFunc {

twinID, err := strconv.ParseUint(twinIDStr, 10, 64)
if err != nil {
log.Debug().Err(err).Msg("invalid twin id format")
abortWithError(c, http.StatusBadRequest, "Invalid twin ID format")
return
}

account, err := s.db.GetAccount(twinID)
if err != nil {
log.Debug().Err(err).Uint64("twinID", twinID).Msg("failed to get account")
handleDatabaseError(c, err)
return
}

storedPK, err := base64.StdEncoding.DecodeString(account.PublicKey)
if err != nil {
log.Debug().Err(err).Msg("failed to get invalid stored public key")
abortWithError(c, http.StatusBadRequest, fmt.Sprintf("invalid stored public key: %v", err))
return
}

sig, err := base64.StdEncoding.DecodeString(signatureB64)
if err != nil {
log.Debug().Err(err).Msg("invalid signature encoding")
abortWithError(c, http.StatusBadRequest, "Invalid signature encoding")
return
}

// Verify signature (supports both ED25519 and SR25519)
if err := verifySignature(storedPK, challenge, sig); err != nil {
log.Debug().Err(err).Msg("signature verification failed")
abortWithError(c, http.StatusUnauthorized, fmt.Sprintf("Signature verification failed: %v", err))
return
}

// Store verified twin ID in context, must be checked form the handlers to ensure altred resources belongs to same user
c.Set("twinID", twinID)
c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), twinIDKey{}, twinID))
c.Next()
}
}
Expand Down
Loading
Loading