Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

atls: add handshake timeout to aTLS servers and clients #1255

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 14 additions & 9 deletions coordinator/internal/authority/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package authority

import (
"context"
"crypto/tls"
"errors"
"fmt"
"log/slog"
Expand All @@ -17,6 +18,7 @@ import (
"github.com/edgelesssys/contrast/internal/attestation/certcache"
"github.com/edgelesssys/contrast/internal/attestation/snp"
"github.com/edgelesssys/contrast/internal/attestation/tdx"
"github.com/edgelesssys/contrast/internal/constants"
"github.com/edgelesssys/contrast/internal/logger"
"github.com/edgelesssys/contrast/internal/memstore"
"github.com/prometheus/client_golang/prometheus"
Expand Down Expand Up @@ -112,16 +114,19 @@ func (c *Credentials) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.A
return nil, nil, err
}

conn, info, err := credentials.NewTLS(serverCfg).ServerHandshake(rawConn)
if err != nil {
log.Error("ServerHandshake failed", "error", err)
return nil, nil, err
ctx, cancel := context.WithTimeout(context.Background(), constants.ATLSServerTimeout)
defer cancel()

conn := tls.Server(rawConn, serverCfg)
if err := conn.HandshakeContext(ctx); err != nil {
return nil, nil, fmt.Errorf("handshake error: %w", err)
}
tlsInfo, ok := info.(credentials.TLSInfo)
if ok {
authInfo.TLSInfo = tlsInfo
} else {
log.Error("credentials.NewTLS returned unexpected AuthInfo", "obj", info)

authInfo.TLSInfo = credentials.TLSInfo{
State: conn.ConnectionState(),
CommonAuthInfo: credentials.CommonAuthInfo{
SecurityLevel: credentials.PrivacyAndIntegrity,
},
}

return conn, authInfo, nil
Expand Down
2 changes: 1 addition & 1 deletion coordinator/internal/authority/userapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ func getPeerPublicKey(ctx context.Context) ([]byte, error) {
}
tlsInfo, ok := peer.AuthInfo.(credentials.TLSInfo)
if !ok {
return nil, errors.New("peer auth info is not of type TLSInfo")
return nil, fmt.Errorf("peer auth info is not of type TLSInfo: got %T", peer.AuthInfo)
}
if len(tlsInfo.State.PeerCertificates) == 0 || tlsInfo.State.PeerCertificates[0] == nil {
return nil, errors.New("no peer certificates found")
Expand Down
17 changes: 16 additions & 1 deletion internal/constants/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@

package constants

import "github.com/google/go-sev-guest/abi"
import (
"time"

"github.com/google/go-sev-guest/abi"
)

// Version value is injected at build time.
var (
Expand All @@ -26,3 +30,14 @@ var SNPPolicy = abi.SnpPolicy{
SMT: true,
Debug: false,
}

const (
// ATLSClientTimeout is the maximal amount of time spent by Coordinator clients for issuing
// and validation of attestation docs.
ATLSClientTimeout = 30 * time.Second

// ATLSServerTimeout is the maximal amount of time that the Coordinator can spend for issuing
// attestation docs. It's deliberately smaller than ATLSClientTimeout to allow proper error
// propagation.
ATLSServerTimeout = ATLSClientTimeout - 5*time.Second
)
20 changes: 19 additions & 1 deletion internal/grpc/atlscredentials/atlscredentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,13 @@ package atlscredentials
import (
"context"
"crypto"
"crypto/tls"
"errors"
"fmt"
"net"

"github.com/edgelesssys/contrast/internal/atls"
"github.com/edgelesssys/contrast/internal/constants"
"github.com/prometheus/client_golang/prometheus"
"google.golang.org/grpc/credentials"
)
Expand Down Expand Up @@ -58,7 +61,22 @@ func (c *Credentials) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.A
return nil, nil, err
}

return credentials.NewTLS(serverCfg).ServerHandshake(rawConn)
ctx, cancel := context.WithTimeout(context.Background(), constants.ATLSServerTimeout)
defer cancel()

conn := tls.Server(rawConn, serverCfg)
if err := conn.HandshakeContext(ctx); err != nil {
return nil, nil, fmt.Errorf("handshake error: %w", err)
}

info := credentials.TLSInfo{
State: conn.ConnectionState(),
CommonAuthInfo: credentials.CommonAuthInfo{
SecurityLevel: credentials.PrivacyAndIntegrity,
},
}

return conn, info, nil
}

// Info provides information about the protocol.
Expand Down
4 changes: 2 additions & 2 deletions internal/grpc/dialer/dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ import (
"context"
"crypto"
"net"
"time"

"github.com/edgelesssys/contrast/internal/atls"
"github.com/edgelesssys/contrast/internal/constants"
"github.com/edgelesssys/contrast/internal/grpc/atlscredentials"
"github.com/prometheus/client_golang/prometheus"
"google.golang.org/grpc"
Expand Down Expand Up @@ -57,7 +57,7 @@ func (d *Dialer) Dial(_ context.Context, target string) (*grpc.ClientConn, error
grpc.WithConnectParams(grpc.ConnectParams{
// We need a high initial timeout, because otherwise the client will get stuck in a reconnect loop
// where the timeout is too low to get a full handshake done.
MinConnectTimeout: 30 * time.Second,
MinConnectTimeout: constants.ATLSClientTimeout,
}),
)
}
Expand Down