-
Notifications
You must be signed in to change notification settings - Fork 710
/
Copy pathprotovalidate.go
92 lines (82 loc) · 2.62 KB
/
protovalidate.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
// Copyright (c) The go-grpc-middleware Authors.
// Licensed under the Apache License 2.0.
package protovalidate
import (
"context"
"errors"
"github.com/bufbuild/protovalidate-go"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
)
// UnaryServerInterceptor returns a new unary server interceptor that validates incoming messages.
// If the request is invalid, clients may access a structured representation of the validation failure as an error detail.
//
// The options are evaluated once, when the interceptor is built, and shared by every call.
// A request that fails validation never reaches the handler; the validation error is
// returned to the client instead.
func UnaryServerInterceptor(validator protovalidate.Validator, opts ...Option) grpc.UnaryServerInterceptor {
	cfg := evaluateOpts(opts)

	intercept := func(
		ctx context.Context,
		req interface{},
		_ *grpc.UnaryServerInfo,
		next grpc.UnaryHandler,
	) (interface{}, error) {
		// Reject before any handler work is done.
		if verr := validateMsg(req, validator, cfg); verr != nil {
			return nil, verr
		}
		return next(ctx, req)
	}
	return intercept
}
// StreamServerInterceptor returns a new streaming server interceptor that validates incoming messages.
// If the request is invalid, clients may access a structured representation of the validation failure as an error detail.
//
// The options are evaluated once, when the interceptor is built. Each stream is wrapped so
// that every message the handler receives through RecvMsg is validated after it is read.
func StreamServerInterceptor(validator protovalidate.Validator, opts ...Option) grpc.StreamServerInterceptor {
	cfg := evaluateOpts(opts)

	return func(
		srv interface{},
		ss grpc.ServerStream,
		_ *grpc.StreamServerInfo,
		next grpc.StreamHandler,
	) error {
		validating := &wrappedServerStream{
			ServerStream: ss,
			validator:    validator,
			options:      cfg,
		}
		return next(srv, validating)
	}
}
// wrappedServerStream is a thin wrapper around grpc.ServerStream that validates every
// message read through RecvMsg. All other stream methods, including Context, are
// promoted unchanged from the embedded ServerStream.
type wrappedServerStream struct {
	grpc.ServerStream
	// validator checks each received message.
	validator protovalidate.Validator
	// options holds the interceptor configuration, e.g. which message types to skip.
	options *options
}

// Compile-time check that the wrapper can stand in for the stream it wraps.
var _ grpc.ServerStream = (*wrappedServerStream)(nil)
// RecvMsg reads the next message from the underlying stream into m and then validates it.
// Any error from the underlying RecvMsg is returned unchanged, and no validation is
// attempted in that case. Otherwise the result of validateMsg is returned.
func (w *wrappedServerStream) RecvMsg(m interface{}) error {
	recvErr := w.ServerStream.RecvMsg(m)
	if recvErr != nil {
		return recvErr
	}
	return validateMsg(m, w.validator, w.options)
}
// validateMsg checks m with validator and converts the outcome into a gRPC status error
// that can be returned directly to the client:
//   - m is not a proto.Message: codes.Internal.
//   - opts says to ignore m's full message name: nil, and no validation is performed.
//   - validation passes: nil.
//   - validation fails with a *protovalidate.ValidationError: codes.InvalidArgument. The
//     violations are attached as a status detail when they can be encoded; otherwise the
//     status is returned without details.
//   - the validator returns any other error: codes.Internal.
func validateMsg(m interface{}, validator protovalidate.Validator, opts *options) error {
	msg, isProto := m.(proto.Message)
	if !isProto {
		return status.Errorf(codes.Internal, "unsupported message type: %T", m)
	}

	name := msg.ProtoReflect().Descriptor().FullName()
	if opts.shouldIgnoreMessage(name) {
		return nil
	}

	validationErr := validator.Validate(msg)
	var violations *protovalidate.ValidationError
	switch {
	case validationErr == nil:
		return nil
	case errors.As(validationErr, &violations):
		// The request itself is invalid.
		invalid := status.New(codes.InvalidArgument, validationErr.Error())
		if detailed, detailErr := invalid.WithDetails(violations.ToProto()); detailErr == nil {
			return detailed.Err()
		}
		// Attaching details is best-effort; fall back to the bare status.
		return invalid.Err()
	default:
		// Not a verdict on the request: the validator itself failed, e.g. a CEL
		// expression that doesn't compile or type-check.
		return status.Error(codes.Internal, validationErr.Error())
	}
}