diff --git a/go.mod b/go.mod index 5e355ae2..069a11aa 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,6 @@ require ( github.com/RobinUS2/golang-moving-average v1.0.0 github.com/gdamore/tcell/v2 v2.8.1 github.com/golang/mock v1.7.0-rc.1 - github.com/google/certificate-transparency-go v1.3.1 github.com/google/go-cmp v0.7.0 github.com/kylelemons/godebug v1.1.0 github.com/prometheus/client_golang v1.21.1 diff --git a/go.sum b/go.sum index 0faef939..73c44334 100644 --- a/go.sum +++ b/go.sum @@ -786,8 +786,6 @@ github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEW github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= -github.com/google/certificate-transparency-go v1.3.1 h1:akbcTfQg0iZlANZLn0L9xOeWtyCIdeoYhKrqi5iH3Go= -github.com/google/certificate-transparency-go v1.3.1/go.mod h1:gg+UQlx6caKEDQ9EElFOujyxEQEfOiQzAt6782Bvi8k= github.com/google/flatbuffers v2.0.8+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= @@ -976,12 +974,6 @@ github.com/transparency-dev/formats v0.0.0-20250127084410-134797944be6 h1:TVUG0R github.com/transparency-dev/formats v0.0.0-20250127084410-134797944be6/go.mod h1:tSjZBSQ1ZMxgaOMppnyw48SbTDL947PD/8KYbvrx+lE= github.com/transparency-dev/merkle v0.0.2 h1:Q9nBoQcZcgPamMkGn7ghV8XiTZ/kRxn1yCG81+twTK4= github.com/transparency-dev/merkle v0.0.2/go.mod h1:pqSy+OXefQ1EDUVmAJ8MUhHB9TXGuzVAT58PqBoHz1A= -github.com/transparency-dev/trillian-tessera v0.1.1-0.20250314143707-b7c8fb6d4491 h1:HZg3ZJqnMJ2X+t+2jMP2GY1vXrES6R1YtADIURAxb0o= -github.com/transparency-dev/trillian-tessera v0.1.1-0.20250314143707-b7c8fb6d4491/go.mod h1:uvyZ7WGpaRDPY+4Lme+s1vEUOluYevTYzrDg9j05cYU= -github.com/transparency-dev/trillian-tessera v0.1.1-0.20250317142235-ed0b4e51d70e h1:wFX4UhEkJEeId3obw8VH9soCyjmLSYzUeYDOsUSYKjE= -github.com/transparency-dev/trillian-tessera v0.1.1-0.20250317142235-ed0b4e51d70e/go.mod h1:uvyZ7WGpaRDPY+4Lme+s1vEUOluYevTYzrDg9j05cYU= -github.com/transparency-dev/trillian-tessera v0.1.1-0.20250317143907-6f25ba6ca3a6 h1:QRFg1XDm9MM1SRfMdUFonaU7NIMQgRjQpqke8UcGNmY= -github.com/transparency-dev/trillian-tessera v0.1.1-0.20250317143907-6f25ba6ca3a6/go.mod h1:uvyZ7WGpaRDPY+4Lme+s1vEUOluYevTYzrDg9j05cYU= github.com/transparency-dev/trillian-tessera v0.1.1-0.20250317151915-e61e2d86f685 h1:Cu1sKlj37BnhBybTQ5V1/zgqPQHlBEXIyLiAK+rVlvA= github.com/transparency-dev/trillian-tessera v0.1.1-0.20250317151915-e61e2d86f685/go.mod h1:uvyZ7WGpaRDPY+4Lme+s1vEUOluYevTYzrDg9j05cYU= github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= diff --git a/internal/scti/handlers.go b/internal/scti/handlers.go index 7e3f58ac..3cdea775 100644 --- a/internal/scti/handlers.go +++ b/internal/scti/handlers.go @@ -29,10 +29,10 @@ import ( "sync" "time" - "github.com/google/certificate-transparency-go/tls" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" "github.com/transparency-dev/static-ct/internal/types" + "github.com/transparency-dev/static-ct/internal/types/tls" "github.com/transparency-dev/static-ct/internal/x509util" "github.com/transparency-dev/static-ct/modules/dedup" tessera "github.com/transparency-dev/trillian-tessera" diff --git a/internal/scti/signatures.go b/internal/scti/signatures.go index 736c12c5..c4151309 100644 --- a/internal/scti/signatures.go +++ b/internal/scti/signatures.go @@ -23,9 +23,9 @@ import ( "fmt" "time" - "github.com/google/certificate-transparency-go/tls" tfl "github.com/transparency-dev/formats/log" "github.com/transparency-dev/static-ct/internal/types" + "github.com/transparency-dev/static-ct/internal/types/tls" "golang.org/x/mod/sumdb/note" ) diff --git a/internal/scti/signatures_test.go b/internal/scti/signatures_test.go index f4feb8e8..1a620cb4 100644 --- a/internal/scti/signatures_test.go +++ b/internal/scti/signatures_test.go @@ -24,10 +24,10 @@ import ( "testing" "time" - "github.com/google/certificate-transparency-go/tls" "github.com/kylelemons/godebug/pretty" "github.com/transparency-dev/static-ct/internal/testdata" "github.com/transparency-dev/static-ct/internal/types" + "github.com/transparency-dev/static-ct/internal/types/tls" "github.com/transparency-dev/static-ct/internal/x509util" ) diff --git a/internal/types/rfc6962.go b/internal/types/rfc6962.go index 1b2f21d9..2c50ec9a 100644 --- a/internal/types/rfc6962.go +++ b/internal/types/rfc6962.go @@ -7,7 +7,7 @@ import ( "encoding/base64" "fmt" - "github.com/google/certificate-transparency-go/tls" + "github.com/transparency-dev/static-ct/internal/types/tls" ) /////////////////////////////////////////////////////////////////////////////// diff --git a/internal/types/tls/tls.go b/internal/types/tls/tls.go new file mode 100644 index 00000000..15c50817 --- /dev/null +++ b/internal/types/tls/tls.go @@ -0,0 +1,724 @@ +// Copyright 2016 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package tls implements functionality for dealing with TLS-encoded data, +// as defined in RFC 5246. This includes parsing and generation of TLS-encoded +// data, together with utility functions for dealing with the DigitallySigned +// TLS type. +package tls + +import ( + "bytes" + "encoding/binary" + "fmt" + "math" + "reflect" + "strconv" + "strings" +) + +// This file holds utility functions for TLS encoding/decoding data +// as per RFC 5246 section 4. + +// A structuralError suggests that the TLS data is valid, but the Go type +// which is receiving it doesn't match. +type structuralError struct { + field string + msg string +} + +func (e structuralError) Error() string { + var prefix string + if e.field != "" { + prefix = e.field + ": " + } + return "tls: structure error: " + prefix + e.msg +} + +// A syntaxError suggests that the TLS data is invalid. +type syntaxError struct { + field string + msg string +} + +func (e syntaxError) Error() string { + var prefix string + if e.field != "" { + prefix = e.field + ": " + } + return "tls: syntax error: " + prefix + e.msg +} + +// Uint24 is an unsigned 3-byte integer. +type Uint24 uint32 + +// Enum is an unsigned integer. +type Enum uint64 + +var ( + uint8Type = reflect.TypeOf(uint8(0)) + uint16Type = reflect.TypeOf(uint16(0)) + uint24Type = reflect.TypeOf(Uint24(0)) + uint32Type = reflect.TypeOf(uint32(0)) + uint64Type = reflect.TypeOf(uint64(0)) + enumType = reflect.TypeOf(Enum(0)) +) + +// Unmarshal parses the TLS-encoded data in b and uses the reflect package to +// fill in an arbitrary value pointed at by val. Because Unmarshal uses the +// reflect package, the structs being written to must use exported fields +// (upper case names). +// +// The mappings between TLS types and Go types is as follows; some fields +// must have tags (to indicate their encoded size). +// +// TLS Go Required Tags +// opaque byte / uint8 +// uint8 byte / uint8 +// uint16 uint16 +// uint24 tls.Uint24 +// uint32 uint32 +// uint64 uint64 +// enum tls.Enum size:S or maxval:N +// Type []Type minlen:N,maxlen:M +// opaque[N] [N]byte / [N]uint8 +// uint8[N] [N]byte / [N]uint8 +// struct { } struct { } +// select(T) { +// case e1: Type *T selector:Field,val:e1 +// } +// +// TLS variants (RFC 5246 s4.6.1) are only supported when the value of the +// associated enumeration type is available earlier in the same enclosing +// struct, and each possible variant is marked with a selector tag (to +// indicate which field selects the variants) and a val tag (to indicate +// what value of the selector picks this particular field). +// +// For example, a TLS structure: +// +// enum { e1(1), e2(2) } EnumType; +// struct { +// EnumType sel; +// select(sel) { +// case e1: uint16 +// case e2: uint32 +// } data; +// } VariantItem; +// +// would have a corresponding Go type: +// +// type VariantItem struct { +// Sel tls.Enum `tls:"maxval:2"` +// Data16 *uint16 `tls:"selector:Sel,val:1"` +// Data32 *uint32 `tls:"selector:Sel,val:2"` +// } +// +// TLS fixed-length vectors of types other than opaque or uint8 are not supported. +// +// For TLS variable-length vectors that are themselves used in other vectors, +// create a single-field structure to represent the inner type. For example, for: +// +// opaque InnerType<1..65535>; +// struct { +// InnerType inners<1,65535>; +// } Something; +// +// convert to: +// +// type InnerType struct { +// Val []byte `tls:"minlen:1,maxlen:65535"` +// } +// type Something struct { +// Inners []InnerType `tls:"minlen:1,maxlen:65535"` +// } +// +// If the encoded value does not fit in the Go type, Unmarshal returns a parse error. +func Unmarshal(b []byte, val interface{}) ([]byte, error) { + return UnmarshalWithParams(b, val, "") +} + +// UnmarshalWithParams allows field parameters to be specified for the +// top-level element. The form of the params is the same as the field tags. +func UnmarshalWithParams(b []byte, val interface{}, params string) ([]byte, error) { + info, err := fieldTagToFieldInfo(params, "") + if err != nil { + return nil, err + } + // The passed in interface{} is a pointer (to allow the value to be written + // to); extract the pointed-to object as a reflect.Value, so parseField + // can do various introspection things. + v := reflect.ValueOf(val).Elem() + offset, err := parseField(v, b, 0, info) + if err != nil { + return nil, err + } + return b[offset:], nil +} + +// Return the number of bytes needed to encode values up to (and including) x. +func byteCount(x uint64) uint { + switch { + case x < 0x100: + return 1 + case x < 0x10000: + return 2 + case x < 0x1000000: + return 3 + case x < 0x100000000: + return 4 + case x < 0x10000000000: + return 5 + case x < 0x1000000000000: + return 6 + case x < 0x100000000000000: + return 7 + default: + return 8 + } +} + +type fieldInfo struct { + count uint // Number of bytes + countSet bool + minlen uint64 // Only relevant for slices + maxlen uint64 // Only relevant for slices + selector string // Only relevant for select sub-values + val uint64 // Only relevant for select sub-values + name string // Used for better error messages +} + +func (i *fieldInfo) fieldName() string { + if i == nil { + return "" + } + return i.name +} + +// Given a tag string, return a fieldInfo describing the field. +func fieldTagToFieldInfo(str string, name string) (*fieldInfo, error) { + var info *fieldInfo + // Iterate over clauses in the tag, ignoring any that don't parse properly. + for _, part := range strings.Split(str, ",") { + switch { + case strings.HasPrefix(part, "maxval:"): + if v, err := strconv.ParseUint(part[7:], 10, 64); err == nil { + info = &fieldInfo{count: byteCount(v), countSet: true} + } + case strings.HasPrefix(part, "size:"): + if sz, err := strconv.ParseUint(part[5:], 10, 32); err == nil { + info = &fieldInfo{count: uint(sz), countSet: true} + } + case strings.HasPrefix(part, "maxlen:"): + v, err := strconv.ParseUint(part[7:], 10, 64) + if err != nil { + continue + } + if info == nil { + info = &fieldInfo{} + } + info.count = byteCount(v) + info.countSet = true + info.maxlen = v + case strings.HasPrefix(part, "minlen:"): + v, err := strconv.ParseUint(part[7:], 10, 64) + if err != nil { + continue + } + if info == nil { + info = &fieldInfo{} + } + info.minlen = v + case strings.HasPrefix(part, "selector:"): + if info == nil { + info = &fieldInfo{} + } + info.selector = part[9:] + case strings.HasPrefix(part, "val:"): + v, err := strconv.ParseUint(part[4:], 10, 64) + if err != nil { + continue + } + if info == nil { + info = &fieldInfo{} + } + info.val = v + } + } + if info != nil { + info.name = name + if info.selector == "" { + if info.count < 1 { + return nil, structuralError{name, "field of unknown size in " + str} + } else if info.count > 8 { + return nil, structuralError{name, "specified size too large in " + str} + } else if info.minlen > info.maxlen { + return nil, structuralError{name, "specified length range inverted in " + str} + } else if info.val > 0 { + return nil, structuralError{name, "specified selector value but not field in " + str} + } + } + } else if name != "" { + info = &fieldInfo{name: name} + } + return info, nil +} + +// Check that a value fits into a field described by a fieldInfo structure. +func (i fieldInfo) check(val uint64, fldName string) error { + if val >= (1 << (8 * i.count)) { + return structuralError{fldName, fmt.Sprintf("value %d too large for size", val)} + } + if i.maxlen != 0 { + if val < i.minlen { + return structuralError{fldName, fmt.Sprintf("value %d too small for minimum %d", val, i.minlen)} + } + if val > i.maxlen { + return structuralError{fldName, fmt.Sprintf("value %d too large for maximum %d", val, i.maxlen)} + } + } + return nil +} + +// readVarUint reads an big-endian unsigned integer of the given size in +// bytes. +func readVarUint(data []byte, info *fieldInfo) (uint64, error) { + if info == nil || !info.countSet { + return 0, structuralError{info.fieldName(), "no field size information available"} + } + if info.count > math.MaxInt { + return 0, syntaxError{info.fieldName(), "count > math.MaxInt"} + } + if len(data) < int(info.count) { + return 0, syntaxError{info.fieldName(), "truncated variable-length integer"} + } + var result uint64 + for i := uint(0); i < info.count; i++ { + result = (result << 8) | uint64(data[i]) + } + if err := info.check(result, info.name); err != nil { + return 0, err + } + return result, nil +} + +// parseField is the main parsing function. Given a byte slice and an offset +// (in bytes) into the data, it will try to parse a suitable ASN.1 value out +// and store it in the given Value. +func parseField(v reflect.Value, data []byte, initOffset int, info *fieldInfo) (int, error) { + offset := initOffset + rest := data[offset:] + + fieldType := v.Type() + // First look for known fixed types. + switch fieldType { + case uint8Type: + if len(rest) < 1 { + return offset, syntaxError{info.fieldName(), "truncated uint8"} + } + v.SetUint(uint64(rest[0])) + offset++ + return offset, nil + case uint16Type: + if len(rest) < 2 { + return offset, syntaxError{info.fieldName(), "truncated uint16"} + } + v.SetUint(uint64(binary.BigEndian.Uint16(rest))) + offset += 2 + return offset, nil + case uint24Type: + if len(rest) < 3 { + return offset, syntaxError{info.fieldName(), "truncated uint24"} + } + v.SetUint(uint64(data[0])<<16 | uint64(data[1])<<8 | uint64(data[2])) + offset += 3 + return offset, nil + case uint32Type: + if len(rest) < 4 { + return offset, syntaxError{info.fieldName(), "truncated uint32"} + } + v.SetUint(uint64(binary.BigEndian.Uint32(rest))) + offset += 4 + return offset, nil + case uint64Type: + if len(rest) < 8 { + return offset, syntaxError{info.fieldName(), "truncated uint64"} + } + v.SetUint(uint64(binary.BigEndian.Uint64(rest))) + offset += 8 + return offset, nil + } + + // Now deal with user-defined types. + switch v.Kind() { + case enumType.Kind(): + // Assume that anything of the same kind as Enum is an Enum, so that + // users can alias types of their own to Enum. + val, err := readVarUint(rest, info) + if err != nil { + return offset, err + } + v.SetUint(val) + if info.count > math.MaxInt { + return offset, syntaxError{info.fieldName(), "count > math.MaxInt"} + } + offset += int(info.count) + return offset, nil + case reflect.Struct: + structType := fieldType + // TLS includes a select(Enum) {..} construct, where the value of an enum + // indicates which variant field is present (like a C union). We require + // that the enum value be an earlier field in the same structure (the selector), + // and that each of the possible variant destination fields be pointers. + // So the Go mapping looks like: + // type variantType struct { + // Which tls.Enum `tls:"size:1"` // this is the selector + // Val1 *type1 `tls:"selector:Which,val:1"` // this is a destination + // Val2 *type2 `tls:"selector:Which,val:1"` // this is a destination + // } + + // To deal with this, we track any enum-like fields and their values... + enums := make(map[string]uint64) + // .. and we track which selector names we've seen (in the destination field tags), + // and whether a destination for that selector has been chosen. + selectorSeen := make(map[string]bool) + for i := 0; i < structType.NumField(); i++ { + // Find information about this field. + tag := structType.Field(i).Tag.Get("tls") + fieldInfo, err := fieldTagToFieldInfo(tag, structType.Field(i).Name) + if err != nil { + return offset, err + } + + destination := v.Field(i) + if fieldInfo.selector != "" { + // This is a possible select(Enum) destination, so first check that the referenced + // selector field has already been seen earlier in the struct. + choice, ok := enums[fieldInfo.selector] + if !ok { + return offset, structuralError{fieldInfo.name, "selector not seen: " + fieldInfo.selector} + } + if structType.Field(i).Type.Kind() != reflect.Ptr { + return offset, structuralError{fieldInfo.name, "choice field not a pointer type"} + } + // Is this the first mention of the selector field name? If so, remember it. + seen, ok := selectorSeen[fieldInfo.selector] + if !ok { + selectorSeen[fieldInfo.selector] = false + } + if choice != fieldInfo.val { + // This destination field was not the chosen one, so make it nil (we checked + // it was a pointer above). + v.Field(i).Set(reflect.Zero(structType.Field(i).Type)) + continue + } + if seen { + // We already saw a different destination field receive the value for this + // selector value, which indicates a badly annotated structure. + return offset, structuralError{fieldInfo.name, "duplicate selector value for " + fieldInfo.selector} + } + selectorSeen[fieldInfo.selector] = true + // Make an object of the pointed-to type and parse into that. + v.Field(i).Set(reflect.New(structType.Field(i).Type.Elem())) + destination = v.Field(i).Elem() + } + offset, err = parseField(destination, data, offset, fieldInfo) + if err != nil { + return offset, err + } + + // Remember any possible tls.Enum values encountered in case they are selectors. + if structType.Field(i).Type.Kind() == enumType.Kind() { + enums[structType.Field(i).Name] = v.Field(i).Uint() + } + + } + + // Now we have seen all fields in the structure, check that all select(Enum) {..} selector + // fields found a destination to put their data in. + for selector, seen := range selectorSeen { + if !seen { + return offset, syntaxError{info.fieldName(), selector + ": unhandled value for selector"} + } + } + return offset, nil + case reflect.Array: + datalen := v.Len() + + if datalen > len(rest) { + return offset, syntaxError{info.fieldName(), "truncated array"} + } + inner := rest[:datalen] + offset += datalen + if fieldType.Elem().Kind() != reflect.Uint8 { + // Only byte/uint8 arrays are supported + return offset, structuralError{info.fieldName(), "unsupported array type: " + v.Type().String()} + } + reflect.Copy(v, reflect.ValueOf(inner)) + return offset, nil + + case reflect.Slice: + sliceType := fieldType + // Slices represent variable-length vectors, which are prefixed by a length field. + // The fieldInfo indicates the size of that length field. + varlen, err := readVarUint(rest, info) + if err != nil { + return offset, err + } + if varlen > math.MaxInt { + return offset, syntaxError{info.fieldName(), "varlen > math.MaxInt"} + } + datalen := int(varlen) + if info.count > math.MaxInt { + return offset, syntaxError{info.fieldName(), "count > math.MaxInt"} + } + offset += int(info.count) + rest = rest[info.count:] + + if datalen > len(rest) { + return offset, syntaxError{info.fieldName(), "truncated slice"} + } + inner := rest[:datalen] + offset += datalen + if fieldType.Elem().Kind() == reflect.Uint8 { + // Fast version for []byte + v.Set(reflect.MakeSlice(sliceType, datalen, datalen)) + reflect.Copy(v, reflect.ValueOf(inner)) + return offset, nil + } + + v.Set(reflect.MakeSlice(sliceType, 0, datalen)) + single := reflect.New(sliceType.Elem()) + for innerOffset := 0; innerOffset < len(inner); { + var err error + innerOffset, err = parseField(single.Elem(), inner, innerOffset, nil) + if err != nil { + return offset, err + } + v.Set(reflect.Append(v, single.Elem())) + } + return offset, nil + + default: + return offset, structuralError{info.fieldName(), fmt.Sprintf("unsupported type: %s of kind %s", fieldType, v.Kind())} + } +} + +// Marshal returns the TLS encoding of val. +func Marshal(val interface{}) ([]byte, error) { + return MarshalWithParams(val, "") +} + +// MarshalWithParams returns the TLS encoding of val, and allows field +// parameters to be specified for the top-level element. The form +// of the params is the same as the field tags. +func MarshalWithParams(val interface{}, params string) ([]byte, error) { + info, err := fieldTagToFieldInfo(params, "") + if err != nil { + return nil, err + } + var out bytes.Buffer + v := reflect.ValueOf(val) + if err := marshalField(&out, v, info); err != nil { + return nil, err + } + return out.Bytes(), err +} + +func marshalField(out *bytes.Buffer, v reflect.Value, info *fieldInfo) error { + var prefix string + if info != nil && len(info.name) > 0 { + prefix = info.name + ": " + } + fieldType := v.Type() + // First look for known fixed types. + switch fieldType { + case uint8Type: + out.WriteByte(byte(v.Uint())) + return nil + case uint16Type: + scratch := make([]byte, 2) + binary.BigEndian.PutUint16(scratch, uint16(v.Uint())) + out.Write(scratch) + return nil + case uint24Type: + i := v.Uint() + if i > 0xffffff { + return structuralError{info.fieldName(), fmt.Sprintf("uint24 overflow %d", i)} + } + scratch := make([]byte, 4) + binary.BigEndian.PutUint32(scratch, uint32(i)) + out.Write(scratch[1:]) + return nil + case uint32Type: + scratch := make([]byte, 4) + binary.BigEndian.PutUint32(scratch, uint32(v.Uint())) + out.Write(scratch) + return nil + case uint64Type: + scratch := make([]byte, 8) + binary.BigEndian.PutUint64(scratch, uint64(v.Uint())) + out.Write(scratch) + return nil + } + + // Now deal with user-defined types. + switch v.Kind() { + case enumType.Kind(): + i := v.Uint() + if info == nil { + return structuralError{info.fieldName(), "enum field tag missing"} + } + if err := info.check(i, prefix); err != nil { + return err + } + scratch := make([]byte, 8) + binary.BigEndian.PutUint64(scratch, uint64(i)) + out.Write(scratch[(8 - info.count):]) + return nil + case reflect.Struct: + structType := fieldType + enums := make(map[string]uint64) // Values of any Enum fields + // The comment parseField() describes the mapping of the TLS select(Enum) {..} construct; + // here we have selector and source (rather than destination) fields. + + // Track which selector names we've seen (in the source field tags), and whether a source + // value for that selector has been processed. + selectorSeen := make(map[string]bool) + for i := 0; i < structType.NumField(); i++ { + // Find information about this field. + tag := structType.Field(i).Tag.Get("tls") + fieldInfo, err := fieldTagToFieldInfo(tag, structType.Field(i).Name) + if err != nil { + return err + } + + source := v.Field(i) + if fieldInfo.selector != "" { + // This field is a possible source for a select(Enum) {..}. First check + // the selector field name has been seen. + choice, ok := enums[fieldInfo.selector] + if !ok { + return structuralError{fieldInfo.name, "selector not seen: " + fieldInfo.selector} + } + if structType.Field(i).Type.Kind() != reflect.Ptr { + return structuralError{fieldInfo.name, "choice field not a pointer type"} + } + // Is this the first mention of the selector field name? If so, remember it. + seen, ok := selectorSeen[fieldInfo.selector] + if !ok { + selectorSeen[fieldInfo.selector] = false + } + if choice != fieldInfo.val { + // This source was not chosen; police that it should be nil. + if v.Field(i).Pointer() != uintptr(0) { + return structuralError{fieldInfo.name, "unchosen field is non-nil"} + } + continue + } + if seen { + // We already saw a different source field generate the value for this + // selector value, which indicates a badly annotated structure. + return structuralError{fieldInfo.name, "duplicate selector value for " + fieldInfo.selector} + } + selectorSeen[fieldInfo.selector] = true + if v.Field(i).Pointer() == uintptr(0) { + return structuralError{fieldInfo.name, "chosen field is nil"} + } + // Marshal from the pointed-to source object. + source = v.Field(i).Elem() + } + + var fieldData bytes.Buffer + if err := marshalField(&fieldData, source, fieldInfo); err != nil { + return err + } + out.Write(fieldData.Bytes()) + + // Remember any tls.Enum values encountered in case they are selectors. + if structType.Field(i).Type.Kind() == enumType.Kind() { + enums[structType.Field(i).Name] = v.Field(i).Uint() + } + } + // Now we have seen all fields in the structure, check that all select(Enum) {..} selector + // fields found a source field to get their data from. + for selector, seen := range selectorSeen { + if !seen { + return syntaxError{info.fieldName(), selector + ": unhandled value for selector"} + } + } + return nil + + case reflect.Array: + datalen := v.Len() + arrayType := fieldType + if arrayType.Elem().Kind() != reflect.Uint8 { + // Only byte/uint8 arrays are supported + return structuralError{info.fieldName(), "unsupported array type"} + } + bytes := make([]byte, datalen) + for i := 0; i < datalen; i++ { + bytes[i] = uint8(v.Index(i).Uint()) + } + _, err := out.Write(bytes) + return err + + case reflect.Slice: + if info == nil { + return structuralError{info.fieldName(), "slice field tag missing"} + } + + sliceType := fieldType + if sliceType.Elem().Kind() == reflect.Uint8 { + // Fast version for []byte: first write the length as info.count bytes. + datalen := v.Len() + scratch := make([]byte, 8) + binary.BigEndian.PutUint64(scratch, uint64(datalen)) + out.Write(scratch[(8 - info.count):]) + + if err := info.check(uint64(datalen), prefix); err != nil { + return err + } + // Then just write the data. + bytes := make([]byte, datalen) + for i := 0; i < datalen; i++ { + bytes[i] = uint8(v.Index(i).Uint()) + } + _, err := out.Write(bytes) + return err + } + // General version: use a separate Buffer to write the slice entries into. + var innerBuf bytes.Buffer + for i := 0; i < v.Len(); i++ { + if err := marshalField(&innerBuf, v.Index(i), nil); err != nil { + return err + } + } + + // Now insert (and check) the size. + size := uint64(innerBuf.Len()) + if err := info.check(size, prefix); err != nil { + return err + } + scratch := make([]byte, 8) + binary.BigEndian.PutUint64(scratch, size) + out.Write(scratch[(8 - info.count):]) + + // Then copy the data. + _, err := out.Write(innerBuf.Bytes()) + return err + + default: + return structuralError{info.fieldName(), fmt.Sprintf("unsupported type: %s of kind %s", fieldType, v.Kind())} + } +} diff --git a/internal/types/tls/tls_test.go b/internal/types/tls/tls_test.go new file mode 100644 index 00000000..fd208744 --- /dev/null +++ b/internal/types/tls/tls_test.go @@ -0,0 +1,355 @@ +// Copyright 2016 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tls + +import ( + "bytes" + "encoding/hex" + "reflect" + "strings" + "testing" +) + +type testStruct struct { + Data []byte `tls:"minlen:2,maxlen:4"` + IntVal uint16 + Other [4]byte + Enum Enum `tls:"size:2"` +} + +type testVariant struct { + Which Enum `tls:"size:1"` + Val16 *uint16 `tls:"selector:Which,val:0"` + Val32 *uint32 `tls:"selector:Which,val:1"` +} + +type testTwoVariants struct { + Which Enum `tls:"size:1"` + Val16 *uint16 `tls:"selector:Which,val:0"` + Val32 *uint32 `tls:"selector:Which,val:1"` + Second Enum `tls:"size:1"` + Second16 *uint16 `tls:"selector:Second,val:0"` + Second32 *uint32 `tls:"selector:Second,val:1"` +} + +// Check that library users can define their own Enum types. +type aliasEnum Enum +type testAliasEnum struct { + Val aliasEnum `tls:"size:1"` + Val16 *uint16 `tls:"selector:Val,val:1"` + Val32 *uint32 `tls:"selector:Val,val:2"` +} + +type testNonByteSlice struct { + Vals []uint16 `tls:"minlen:2,maxlen:6"` +} + +type testSliceOfStructs struct { + Vals []testVariant `tls:"minlen:0,maxlen:100"` +} + +type testInnerType struct { + Val []byte `tls:"minlen:0,maxlen:65535"` +} + +type testSliceOfSlices struct { + Inners []testInnerType `tls:"minlen:0,maxlen:65535"` +} + +func TestMarshalUnmarshalRoundTrip(t *testing.T) { + thing := testStruct{Data: []byte{0x01, 0x02, 0x03}, IntVal: 42, Other: [4]byte{1, 2, 3, 4}, Enum: 17} + data, err := Marshal(thing) + if err != nil { + t.Fatalf("Failed to Marshal(%+v): %s", thing, err.Error()) + } + var other testStruct + rest, err := Unmarshal(data, &other) + if err != nil { + t.Fatalf("Failed to Unmarshal(%s)", hex.EncodeToString(data)) + } + if len(rest) > 0 { + t.Errorf("Data left over after Unmarshal(%s): %s", hex.EncodeToString(data), hex.EncodeToString(rest)) + } +} + +func TestFieldTagToFieldInfo(t *testing.T) { + var tests = []struct { + tag string + want *fieldInfo + errstr string + }{ + {"", nil, ""}, + {"bogus", nil, ""}, + {"also,bogus", nil, ""}, + {"also,bogus:99", nil, ""}, + {"maxval:1xyz", nil, ""}, + {"maxval:1", &fieldInfo{count: 1, countSet: true}, ""}, + {"maxval:255", &fieldInfo{count: 1, countSet: true}, ""}, + {"maxval:256", &fieldInfo{count: 2, countSet: true}, ""}, + {"maxval:65535", &fieldInfo{count: 2, countSet: true}, ""}, + {"maxval:65536", &fieldInfo{count: 3, countSet: true}, ""}, + {"maxval:16777215", &fieldInfo{count: 3, countSet: true}, ""}, + {"maxval:16777216", &fieldInfo{count: 4, countSet: true}, ""}, + {"maxval:16777216", &fieldInfo{count: 4, countSet: true}, ""}, + {"maxval:4294967295", &fieldInfo{count: 4, countSet: true}, ""}, + {"maxval:4294967296", &fieldInfo{count: 5, countSet: true}, ""}, + {"maxval:1099511627775", &fieldInfo{count: 5, countSet: true}, ""}, + {"maxval:1099511627776", &fieldInfo{count: 6, countSet: true}, ""}, + {"maxval:281474976710655", &fieldInfo{count: 6, countSet: true}, ""}, + {"maxval:281474976710656", &fieldInfo{count: 7, countSet: true}, ""}, + {"maxval:72057594037927935", &fieldInfo{count: 7, countSet: true}, ""}, + {"maxval:72057594037927936", &fieldInfo{count: 8, countSet: true}, ""}, + {"minlen:1x", nil, ""}, + {"maxlen:1x", nil, ""}, + {"maxlen:1", &fieldInfo{count: 1, countSet: true, maxlen: 1}, ""}, + {"maxlen:255", &fieldInfo{count: 1, countSet: true, maxlen: 255}, ""}, + {"maxlen:65535", &fieldInfo{count: 2, countSet: true, maxlen: 65535}, ""}, + {"minlen:65530,maxlen:65535", &fieldInfo{count: 2, countSet: true, minlen: 65530, maxlen: 65535}, ""}, + {"maxlen:65535,minlen:65530", &fieldInfo{count: 2, countSet: true, minlen: 65530, maxlen: 65535}, ""}, + {"minlen:65536,maxlen:65535", nil, "inverted"}, + {"maxlen:16777215", &fieldInfo{count: 3, countSet: true, maxlen: 16777215}, ""}, + {"maxlen:281474976710655", &fieldInfo{count: 6, countSet: true, maxlen: 281474976710655}, ""}, + {"maxlen:72057594037927936", &fieldInfo{count: 8, countSet: true, maxlen: 72057594037927936}, ""}, + {"size:0", nil, "unknown size"}, + {"size:1", &fieldInfo{count: 1, countSet: true}, ""}, + {"size:2", &fieldInfo{count: 2, countSet: true}, ""}, + {"size:3", &fieldInfo{count: 3, countSet: true}, ""}, + {"size:4", &fieldInfo{count: 4, countSet: true}, ""}, + {"size:5", &fieldInfo{count: 5, countSet: true}, ""}, + {"size:6", &fieldInfo{count: 6, countSet: true}, ""}, + {"size:7", &fieldInfo{count: 7, countSet: true}, ""}, + {"size:8", &fieldInfo{count: 8, countSet: true}, ""}, + {"size:9", nil, "too large"}, + {"size:1x", nil, ""}, + {"size:1,val:9", nil, "selector value"}, + {"selector:Bob,val:x9", &fieldInfo{selector: "Bob"}, ""}, + {"selector:Fred,val:1", &fieldInfo{selector: "Fred", val: 1}, ""}, + {"val:9,selector:Fred,val:1", &fieldInfo{selector: "Fred", val: 1}, ""}, + } + for _, test := range tests { + got, err := fieldTagToFieldInfo(test.tag, "") + if test.errstr != "" { + if err == nil { + t.Errorf("fieldTagToFieldInfo('%v')=%+v,nil; want error %q", test.tag, got, test.errstr) + } else if !strings.Contains(err.Error(), test.errstr) { + t.Errorf("fieldTagToFieldInfo('%v')=nil,%q; want error %q", test.tag, err.Error(), test.errstr) + } + continue + } + if err != nil { + t.Errorf("fieldTagToFieldInfo('%v')=nil,%q; want %+v", test.tag, err.Error(), test.want) + } else if !reflect.DeepEqual(got, test.want) { + t.Errorf("fieldTagToFieldInfo('%v')=%+v,nil; want %+v", test.tag, got, test.want) + } + } +} + +// Can't take the address of a numeric constant so use helper functions +func newByte(n byte) *byte { return &n } +func newUint8(n uint8) *uint8 { return &n } +func newUint16(n uint16) *uint16 { return &n } +func newUint24(n Uint24) *Uint24 { return &n } +func newUint32(n uint32) *uint32 { return &n } +func newUint64(n uint64) *uint64 { return &n } +func newInt16(n int16) *int16 { return &n } +func newEnum(n Enum) *Enum { return &n } + +func TestUnmarshalMarshalWithParamsRoundTrip(t *testing.T) { + var tests = []struct { + data string // hex encoded + params string + item interface{} + }{ + {"00", "", newUint8(0)}, + {"03", "", newByte(3)}, + {"0101", "", newUint16(0x0101)}, + {"010203", "", newUint24(0x010203)}, + {"000000", "", newUint24(0x00)}, + {"00000009", "", newUint32(0x09)}, + {"0000000901020304", "", newUint64(0x0901020304)}, + {"030405", "", &[3]byte{3, 4, 5}}, + {"03", "", &[1]byte{3}}, + {"0001", "size:2", newEnum(1)}, + {"0100000001", "size:5", newEnum(0x100000001)}, + {"12", "maxval:18", newEnum(18)}, + // Note that maxval is just used to give enum size; it's not policed + {"20", "maxval:18", newEnum(32)}, + {"020a0b", "minlen:1,maxlen:5", &[]byte{0xa, 0xb}}, + {"020a0b0101010203040011", "", &testStruct{Data: []byte{0xa, 0xb}, IntVal: 0x101, Other: [4]byte{1, 2, 3, 4}, Enum: 17}}, + {"000102", "", &testVariant{Which: 0, Val16: newUint16(0x0102)}}, + {"0101020304", "", &testVariant{Which: 1, Val32: newUint32(0x01020304)}}, + {"0001020104030201", "", &testTwoVariants{Which: 0, Val16: newUint16(0x0102), Second: 1, Second32: newUint32(0x04030201)}}, + {"06010102020303", "", &testNonByteSlice{Vals: []uint16{0x101, 0x202, 0x303}}}, + {"00", "", &testSliceOfStructs{Vals: []testVariant{}}}, + {"080001020101020304", "", + &testSliceOfStructs{ + Vals: []testVariant{ + {Which: 0, Val16: newUint16(0x0102)}, + {Which: 1, Val32: newUint32(0x01020304)}, + }, + }, + }, + {"000a00030102030003040506", "", + &testSliceOfSlices{ + Inners: []testInnerType{ + {Val: []byte{1, 2, 3}}, + {Val: []byte{4, 5, 6}}, + }, + }, + }, + {"011011", "", &testAliasEnum{Val: 1, Val16: newUint16(0x1011)}}, + {"0403", "", &SignatureAndHashAlgorithm{Hash: SHA256, Signature: ECDSA}}, + {"04030003010203", "", + &DigitallySigned{ + Algorithm: SignatureAndHashAlgorithm{Hash: SHA256, Signature: ECDSA}, + Signature: []byte{1, 2, 3}, + }, + }, + } + for _, test := range tests { + inVal := reflect.ValueOf(test.item).Elem() + pv := reflect.New(reflect.TypeOf(test.item).Elem()) + val := pv.Interface() + inData, _ := hex.DecodeString(test.data) + if _, err := UnmarshalWithParams(inData, val, test.params); err != nil { + t.Errorf("Unmarshal(%s)=nil,%q; want %+v", test.data, err.Error(), inVal) + } else if !reflect.DeepEqual(val, test.item) { + t.Errorf("Unmarshal(%s)=%+v,nil; want %+v", test.data, reflect.ValueOf(val).Elem(), inVal) + } + + if data, err := MarshalWithParams(inVal.Interface(), test.params); err != nil { + t.Errorf("Marshal(%+v)=nil,%q; want %s", inVal, err.Error(), test.data) + } else if !bytes.Equal(data, inData) { + t.Errorf("Marshal(%+v)=%s,nil; want %s", inVal, hex.EncodeToString(data), test.data) + } + } +} + +type testInvalidFieldTag struct { + Data []byte `tls:"minlen:3,maxlen:2"` +} + +type testDuplicateSelectorVal struct { + Which Enum `tls:"size:1"` + Val *uint16 `tls:"selector:Which,val:0"` + DupVal *uint32 `tls:"selector:Which"` // implicit val:0 +} + +type testMissingSelector struct { + Val *uint16 `tls:"selector:Missing,val:0"` +} + +type testChoiceNotPointer struct { + Which Enum `tls:"size:1"` + Val uint16 `tls:"selector:Which,val:0"` +} + +type nonEnumAlias uint16 + +func newNonEnumAlias(n nonEnumAlias) *nonEnumAlias { return &n } + +func TestUnmarshalWithParamsFailures(t *testing.T) { + var tests = []struct { + data string // hex encoded + params string + item interface{} + errstr string + }{ + {"", "", newUint8(0), "truncated"}, + {"0x01", "", newUint16(0x0101), "truncated"}, + {"0103", "", newUint24(0x010203), "truncated"}, + {"00", "", newUint24(0x00), "truncated"}, + {"000009", "", newUint32(0x09), "truncated"}, + {"00000901020304", "", newUint64(0x0901020304), "truncated"}, + {"0102", "", newInt16(0x0102), "unsupported type"}, // TLS encoding only supports unsigned integers + {"0607", "", &[3]byte{6, 7, 8}, "truncated array"}, + {"01010202", "", &[3]uint16{0x101, 0x202}, "unsupported array"}, + {"01", "", newEnum(1), "no field size"}, + {"00", "size:2", newEnum(0), "truncated"}, + {"00", "size:9", newEnum(0), "too large"}, + {"020a0b", "minlen:4,maxlen:8", &[]byte{0x0a, 0x0b}, "too small"}, + {"040a0b0c0d", "minlen:1,maxlen:3", &[]byte{0x0a, 0x0b, 0x0c, 0x0d}, "too large"}, + {"020a0b", "minlen:8,maxlen:6", &[]byte{0x0a, 0x0b}, "inverted"}, + {"020a", "minlen:0,maxlen:6", &[]byte{0x0a, 0x0b}, "truncated"}, + {"02", "minlen:0,maxlen:6", &[]byte{0x0a, 0x0b}, "truncated"}, + {"0001", "minlen:0,maxlen:256", &[]byte{0x0a, 0x0b}, "truncated"}, + {"020a", "minlen:0", &[]byte{0x0a, 0x0b}, "unknown size"}, + {"020a", "", &[]byte{0x0a, 0x0b}, "no field size information"}, + {"020a0b", "", &testInvalidFieldTag{}, "range inverted"}, + {"020a0b01010102030400", "", + &testStruct{Data: []byte{0xa, 0xb}, IntVal: 0x101, Other: [4]byte{1, 2, 3, 4}, Enum: 17}, "truncated"}, + {"010102", "", &testVariant{Which: 1, Val32: newUint32(0x01020304)}, "truncated"}, + {"092122", "", &testVariant{Which: 0, Val16: newUint16(0x2122)}, "unhandled value for selector"}, + {"0001020304", "", &testDuplicateSelectorVal{Which: 0, Val: newUint16(0x0102)}, "duplicate selector value"}, + {"0102", "", &testMissingSelector{Val: newUint16(1)}, "selector not seen"}, + {"000007", "", &testChoiceNotPointer{Which: 0, Val: 7}, "choice field not a pointer type"}, + {"05010102020303", "", &testNonByteSlice{Vals: []uint16{0x101, 0x202, 0x303}}, "truncated"}, + {"0101", "size:2", newNonEnumAlias(0x0102), "unsupported type"}, + {"0403010203", "", + &DigitallySigned{ + Algorithm: SignatureAndHashAlgorithm{Hash: SHA256, Signature: ECDSA}, + Signature: []byte{1, 2, 3}}, "truncated"}, + } + for _, test := range tests { + pv := reflect.New(reflect.TypeOf(test.item).Elem()) + val := pv.Interface() + in, _ := hex.DecodeString(test.data) + if _, err := UnmarshalWithParams(in, val, test.params); err == nil { + t.Errorf("Unmarshal(%s)=%+v,nil; want error %q", test.data, reflect.ValueOf(val).Elem(), test.errstr) + } else if !strings.Contains(err.Error(), test.errstr) { + t.Errorf("Unmarshal(%s)=nil,%q; want error %q", test.data, err.Error(), test.errstr) + } + } +} + +func TestMarshalWithParamsFailures(t *testing.T) { + var tests = []struct { + item interface{} + params string + errstr string + }{ + {Uint24(0x1000000), "", "overflow"}, + {int16(0x0102), "", "unsupported type"}, // All TLS ints are unsigned + {Enum(1), "", "field tag missing"}, + {Enum(256), "size:1", "too large"}, + {Enum(256), "maxval:255", "too large"}, + {Enum(2), "", "field tag missing"}, + {Enum(256), "size:9", "too large"}, + {[]byte{0xa, 0xb, 0xc, 0xd}, "minlen:1,maxlen:3", "too large"}, + {[]byte{0xa, 0xb, 0xc, 0xd}, "minlen:6,maxlen:13", "too small"}, + {[]byte{0xa, 0xb, 0xc, 0xd}, "minlen:6,maxlen:3", "inverted"}, + {[]byte{0xa, 0xb, 0xc, 0xd}, "minlen:6", "unknown size"}, + {[]byte{0xa, 0xb, 0xc, 0xd}, "", "field tag missing"}, + {[3]uint16{0x101, 0x202}, "", "unsupported array"}, + {testInvalidFieldTag{}, "", "inverted"}, + {testStruct{Data: []byte{0xa}, IntVal: 0x101, Other: [4]byte{1, 2, 3, 4}, Enum: 17}, "", "too small"}, + {testVariant{Which: 0, Val32: newUint32(0x01020304)}, "", "chosen field is nil"}, + {testVariant{Which: 0, Val16: newUint16(11), Val32: newUint32(0x01020304)}, "", "unchosen field is non-nil"}, + {testVariant{Which: 3}, "", "unhandled value for selector"}, + {testMissingSelector{Val: newUint16(1)}, "", "selector not seen"}, + {testChoiceNotPointer{Which: 0, Val: 7}, "", "choice field not a pointer"}, + {testDuplicateSelectorVal{Which: 0, Val: newUint16(1)}, "", "duplicate selector value"}, + {testNonByteSlice{Vals: []uint16{1, 2, 3, 4}}, "", "too large"}, + {testSliceOfStructs{[]testVariant{{Which: 3}}}, "", "unhandled value for selector"}, + {nonEnumAlias(0x0102), "", "unsupported type"}, + } + for _, test := range tests { + if data, err := MarshalWithParams(test.item, test.params); err == nil { + t.Errorf("Marshal(%+v)=%x,nil; want error %q", test.item, data, test.errstr) + } else if !strings.Contains(err.Error(), test.errstr) { + t.Errorf("Marshal(%+v)=nil,%q; want error %q", test.item, err.Error(), test.errstr) + } + } +} diff --git a/internal/types/tls/types.go b/internal/types/tls/types.go new file mode 100644 index 00000000..0652a053 --- /dev/null +++ b/internal/types/tls/types.go @@ -0,0 +1,93 @@ +// Copyright 2016 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tls + +import ( + "crypto" + "crypto/ecdsa" + "fmt" +) + +// DigitallySigned gives information about a signature, including the algorithm used +// and the signature value. Defined in RFC 5246 s4.7. +type DigitallySigned struct { + Algorithm SignatureAndHashAlgorithm + Signature []byte `tls:"minlen:0,maxlen:65535"` +} + +func (d DigitallySigned) String() string { + return fmt.Sprintf("Signature: HashAlgo=%v SignAlgo=%v Value=%x", d.Algorithm.Hash, d.Algorithm.Signature, d.Signature) +} + +// SignatureAndHashAlgorithm gives information about the algorithms used for a +// signature. Defined in RFC 5246 s7.4.1.4.1. +type SignatureAndHashAlgorithm struct { + Hash HashAlgorithm `tls:"maxval:255"` + Signature SignatureAlgorithm `tls:"maxval:255"` +} + +// HashAlgorithm enum from RFC 5246 s7.4.1.4.1. +type HashAlgorithm Enum + +// HashAlgorithm constants from RFC 5246 s7.4.1.4.1. +const ( + SHA256 HashAlgorithm = 4 + SHA384 HashAlgorithm = 5 + SHA512 HashAlgorithm = 6 +) + +func (h HashAlgorithm) String() string { + switch h { + case SHA256: + return "SHA256" + case SHA384: + return "SHA384" + case SHA512: + return "SHA512" + default: + return fmt.Sprintf("UNKNOWN(%d)", h) + } +} + +// SignatureAlgorithm enum from RFC 5246 s7.4.1.4.1. +type SignatureAlgorithm Enum + +// SignatureAlgorithm constants from RFC 5246 s7.4.1.4.1. +const ( + Anonymous SignatureAlgorithm = 0 + ECDSA SignatureAlgorithm = 3 +) + +func (s SignatureAlgorithm) String() string { + switch s { + case Anonymous: + return "Anonymous" + case ECDSA: + return "ECDSA" + default: + return fmt.Sprintf("UNKNOWN(%d)", s) + } +} + +// SignatureAlgorithmFromPubKey returns the algorithm used for this public key. +// ECDSA, RSA, and DSA keys are supported. Other key types will return Anonymous. +func SignatureAlgorithmFromPubKey(k crypto.PublicKey) SignatureAlgorithm { + switch k.(type) { + case *ecdsa.PublicKey: + return ECDSA + default: + return Anonymous + } +} diff --git a/internal/types/tls/types_test.go b/internal/types/tls/types_test.go new file mode 100644 index 00000000..af7231b0 --- /dev/null +++ b/internal/types/tls/types_test.go @@ -0,0 +1,90 @@ +// Copyright 2016 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tls + +import ( + "crypto" + "crypto/ecdsa" + "testing" +) + +func TestHashAlgorithmString(t *testing.T) { + var tests = []struct { + algo HashAlgorithm + want string + }{ + {SHA256, "SHA256"}, + {SHA384, "SHA384"}, + {SHA512, "SHA512"}, + {99, "UNKNOWN(99)"}, + } + for _, test := range tests { + if got := test.algo.String(); got != test.want { + t.Errorf("%v.String()=%q; want %q", test.algo, got, test.want) + } + } +} + +func TestSignatureAlgorithmString(t *testing.T) { + var tests = []struct { + algo SignatureAlgorithm + want string + }{ + {Anonymous, "Anonymous"}, + {ECDSA, "ECDSA"}, + {99, "UNKNOWN(99)"}, + } + for _, test := range tests { + if got := test.algo.String(); got != test.want { + t.Errorf("%v.String()=%q; want %q", test.algo, got, test.want) + } + } +} + +func TestDigitallySignedString(t *testing.T) { + var tests = []struct { + ds DigitallySigned + want string + }{ + { + ds: DigitallySigned{Algorithm: SignatureAndHashAlgorithm{Hash: SHA256, Signature: ECDSA}, Signature: []byte{0x01, 0x02}}, + want: "Signature: HashAlgo=SHA256 SignAlgo=ECDSA Value=0102", + }, + { + ds: DigitallySigned{Algorithm: SignatureAndHashAlgorithm{Hash: 99, Signature: 99}, Signature: []byte{0x03, 0x04}}, + want: "Signature: HashAlgo=UNKNOWN(99) SignAlgo=UNKNOWN(99) Value=0304", + }, + } + for _, test := range tests { + if got := test.ds.String(); got != test.want { + t.Errorf("%v.String()=%q; want %q", test.ds, got, test.want) + } + } +} + +func TestSignatureAlgorithm(t *testing.T) { + for _, test := range []struct { + name string + key crypto.PublicKey + want SignatureAlgorithm + }{ + {name: "ECDSA", key: new(ecdsa.PublicKey), want: ECDSA}, + {name: "Other", key: "foo", want: Anonymous}, + } { + if got := SignatureAlgorithmFromPubKey(test.key); got != test.want { + t.Errorf("%v: SignatureAlgorithm() = %v, want %v", test.name, got, test.want) + } + } +}