diff --git a/piv/pcsc_test.go b/piv/pcsc_test.go index b361753..cfb6bfe 100644 --- a/piv/pcsc_test.go +++ b/piv/pcsc_test.go @@ -21,7 +21,7 @@ import ( ) func runContextTest(t *testing.T, f func(t *testing.T, c *scContext)) { - ctx, err := newSCContext() + ctx, err := newSCContext(nil) if err != nil { t.Fatalf("creating context: %v", err) } diff --git a/piv/pcsc_trace.go b/piv/pcsc_trace.go new file mode 100644 index 0000000..b40f4fe --- /dev/null +++ b/piv/pcsc_trace.go @@ -0,0 +1,108 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package piv + +import ( + "context" + "reflect" +) + +// ClientTrace is a set of hooks to run at various stages pcsc calls. +// Any particular hook may be nil. Functions may be +// called concurrently from different goroutines and some may be called +// after the request has completed or failed. +// +// ClientTrace is adapted from httptrace.ClientTrace. +// ClientTrace currently traces a single pcsc call, providing the apdus +// that were sent. +type ClientTrace struct { + // Transmit is called before an APDU is transmitted to the card. + // The byte array is the complete contents of the request being sent to + // SCardTransmit. + // Transmit is called from scTx. + Transmit func(req []byte) + + // TransmitResult is called afterr an APDU is transmitted to the card. + // req is the contents of the request. + // resp is the contents of the response. + // respN is the number of bytes returned in the response. + // s1,sw2 are the last 2 bytes of the response. + // sw1,sw2 are the contents of last 2 bytes of the response. + // an apduErr contains sw1,sw2. + // if sw1==0x61, there is more data. + // TransmitResult is called from scTx. + TransmitResult func(req, resp []byte, respN int, sw1, sw2 byte) +} + +// unique type to prevent assignment. +type clientEventContextKey struct{} + +// ContextClientTrace returns the [ClientTrace] associated with the +// provided context. If none, it returns nil. +func ContextClientTrace(ctx context.Context) *ClientTrace { + trace, _ := ctx.Value(clientEventContextKey{}).(*ClientTrace) + return trace +} + +// compose modifies t such that it respects the previously-registered hooks in old, +// subject to the composition policy requested in t.Compose. +func (t *ClientTrace) compose(old *ClientTrace) { + if old == nil { + return + } + tv := reflect.ValueOf(t).Elem() + ov := reflect.ValueOf(old).Elem() + structType := tv.Type() + for i := 0; i < structType.NumField(); i++ { + tf := tv.Field(i) + hookType := tf.Type() + if hookType.Kind() != reflect.Func { + continue + } + of := ov.Field(i) + if of.IsNil() { + continue + } + if tf.IsNil() { + tf.Set(of) + continue + } + + // Make a copy of tf for tf to call. (Otherwise it + // creates a recursive call cycle and stack overflows) + tfCopy := reflect.ValueOf(tf.Interface()) + + // We need to call both tf and of in some order. + newFunc := reflect.MakeFunc(hookType, func(args []reflect.Value) []reflect.Value { + tfCopy.Call(args) + return of.Call(args) + }) + tv.Field(i).Set(newFunc) + } +} + +// WithClientTrace returns a new context based on the provided parent +// ctx. HTTP client requests made with the returned context will use +// the provided trace hooks, in addition to any previous hooks +// registered with ctx. Any hooks defined in the provided trace will +// be called first. +func WithClientTrace(ctx context.Context, trace *ClientTrace) context.Context { + if trace == nil { + panic("nil trace") + } + old := ContextClientTrace(ctx) + trace.compose(old) + + ctx = context.WithValue(ctx, clientEventContextKey{}, trace) + + return ctx +} diff --git a/piv/pcsc_unix.go b/piv/pcsc_unix.go index a43d259..9c432e5 100644 --- a/piv/pcsc_unix.go +++ b/piv/pcsc_unix.go @@ -42,16 +42,17 @@ import ( const rcSuccess = C.SCARD_S_SUCCESS type scContext struct { - ctx C.SCARDCONTEXT + ctx C.SCARDCONTEXT + trace *ClientTrace } -func newSCContext() (*scContext, error) { +func newSCContext(trace *ClientTrace) (*scContext, error) { var ctx C.SCARDCONTEXT rc := C.SCardEstablishContext(C.SCARD_SCOPE_SYSTEM, nil, nil, &ctx) if err := scCheck(rc); err != nil { return nil, err } - return &scContext{ctx: ctx}, nil + return &scContext{ctx: ctx, trace: trace}, nil } func (c *scContext) Close() error { @@ -88,8 +89,14 @@ func (c *scContext) ListReaders() ([]string, error) { return readers, nil } +// WithClientTrace can be passed an instance of ClientTrace to trace the apdu's sent. +func (c *scContext) WithClientTrace(clientTrace *ClientTrace) { + c.trace = clientTrace +} + type scHandle struct { - h C.SCARDHANDLE + h C.SCARDHANDLE + trace *ClientTrace } func (c *scContext) Connect(reader string) (*scHandle, error) { @@ -103,7 +110,12 @@ func (c *scContext) Connect(reader string) (*scHandle, error) { if err := scCheck(rc); err != nil { return nil, err } - return &scHandle{handle}, nil + return &scHandle{h: handle, trace: c.trace}, nil +} + +// WithClientTrace can be passed an instance of ClientTrace to trace the apdu's sent. +func (h *scHandle) WithClientTrace(clientTrace *ClientTrace) { + h.trace = clientTrace } func (h *scHandle) Close() error { @@ -112,23 +124,35 @@ func (h *scHandle) Close() error { type scTx struct { h C.SCARDHANDLE + // If trace is not nil, then trace.Transmit and trace.TransmitResult will be called. + trace *ClientTrace } func (h *scHandle) Begin() (*scTx, error) { if err := scCheck(C.SCardBeginTransaction(h.h)); err != nil { return nil, err } - return &scTx{h.h}, nil + return &scTx{h.h, nil}, nil } func (t *scTx) Close() error { return scCheck(C.SCardEndTransaction(t.h, C.SCARD_LEAVE_CARD)) } +// WithClientTrace can be passed an instance of ClientTrace to trace the apdu's sent. +func (t *scTx) WithClientTrace(clientTrace *ClientTrace) { + t.trace = clientTrace +} + func (t *scTx) transmit(req []byte) (more bool, b []byte, err error) { var resp [C.MAX_BUFFER_SIZE_EXTENDED]byte reqN := C.DWORD(len(req)) respN := C.DWORD(len(resp)) + + if t.trace != nil && t.trace.Transmit != nil { + t.trace.Transmit(req[:]) + } + rc := C.SCardTransmit( t.h, C.SCARD_PCI_T1, @@ -142,6 +166,11 @@ func (t *scTx) transmit(req []byte) (more bool, b []byte, err error) { } sw1 := resp[respN-2] sw2 := resp[respN-1] + + if t.trace != nil && t.trace.TransmitResult != nil { + t.trace.TransmitResult(req[:], resp[:respN], int(respN), sw1, sw2) + } + if sw1 == 0x90 && sw2 == 0x00 { return false, resp[:respN-2], nil } diff --git a/piv/pcsc_windows.go b/piv/pcsc_windows.go index 845194f..00d984b 100644 --- a/piv/pcsc_windows.go +++ b/piv/pcsc_windows.go @@ -54,10 +54,11 @@ func isRCNoReaders(rc uintptr) bool { } type scContext struct { - ctx syscall.Handle + ctx syscall.Handle + trace *ClientTrace } -func newSCContext() (*scContext, error) { +func newSCContext(trace *ClientTrace) (*scContext, error) { var ctx syscall.Handle r0, _, _ := procSCardEstablishContext.Call( @@ -69,7 +70,7 @@ func newSCContext() (*scContext, error) { if err := scCheck(r0); err != nil { return nil, err } - return &scContext{ctx: ctx}, nil + return &scContext{ctx: ctx, trace: trace}, nil } func (c *scContext) Close() error { @@ -142,11 +143,17 @@ func (c *scContext) Connect(reader string) (*scHandle, error) { if err := scCheck(r0); err != nil { return nil, err } - return &scHandle{handle}, nil + return &scHandle{handle: handle, trace: c.trace}, nil +} + +// WithClientTrace can be passed an instance of ClientTrace to trace the apdu's sent. +func (c *scContext) WithClientTrace(clientTrace *ClientTrace) { + c.trace = clientTrace } type scHandle struct { handle syscall.Handle + trace *ClientTrace } func (h *scHandle) Close() error { @@ -159,7 +166,7 @@ func (h *scHandle) Begin() (*scTx, error) { if err := scCheck(r0); err != nil { return nil, err } - return &scTx{h.handle}, nil + return &scTx{h.handle, nil}, nil } func (t *scTx) Close() error { @@ -167,14 +174,31 @@ func (t *scTx) Close() error { return scCheck(r0) } +// WithClientTrace can be passed an instance of ClientTrace to trace the apdu's sent. +func (h *scHandle) WithClientTrace(clientTrace *ClientTrace) { + h.trace = clientTrace +} + type scTx struct { handle syscall.Handle + // If trace is not nil, then trace.Transmit and trace.TransmitResult will be called. + trace *ClientTrace +} + +// WithClientTrace can be passed an instance of ClientTrace to trace the apdu's sent. +func (t *scTx) WithClientTrace(clientTrace *ClientTrace) { + t.trace = clientTrace } func (t *scTx) transmit(req []byte) (more bool, b []byte, err error) { var resp [maxBufferSizeExtended]byte reqN := len(req) respN := len(resp) + + if t.trace != nil && t.trace.Transmit != nil { + t.trace.Transmit(req[:]) + } + r0, _, _ := procSCardTransmit.Call( uintptr(t.handle), uintptr(scardPCIT1), @@ -193,6 +217,11 @@ func (t *scTx) transmit(req []byte) (more bool, b []byte, err error) { } sw1 := resp[respN-2] sw2 := resp[respN-1] + + if t.trace != nil && t.trace.TransmitResult != nil { + t.trace.TransmitResult(req[:], resp[:respN], int(respN), sw1, sw2) + } + if sw1 == 0x90 && sw2 == 0x00 { return false, resp[:respN-2], nil } diff --git a/piv/piv.go b/piv/piv.go index a351665..4bd6dee 100644 --- a/piv/piv.go +++ b/piv/piv.go @@ -113,6 +113,8 @@ type YubiKey struct { // YubiKey's version or PIV version? A NEO reports v1.0.4. Figure this out // before exposing an API. version *version + + trace *ClientTrace } // Close releases the connection to the smart card. @@ -125,7 +127,28 @@ func (yk *YubiKey) Close() error { return err1 } -// Open connects to a YubiKey smart card. +// WithClientTrace can be passed an instance of ClientTrace to trace the apdu's sent. +func (yk *YubiKey) WithClientTrace(clientTrace *ClientTrace) { + yk.trace = clientTrace + yk.ctx.WithClientTrace(clientTrace) + yk.h.WithClientTrace(clientTrace) + yk.tx.WithClientTrace(clientTrace) +} + +// Client allows a yubikey to be opened with tracing. +type Client struct { + Trace *ClientTrace +} + +func (p *Client) Open(card string) (*YubiKey, error) { + c := client{ + Rand: nil, + trace: p.Trace, + } + + return c.Open(card) +} + func Open(card string) (*YubiKey, error) { var c client return c.Open(card) @@ -137,11 +160,12 @@ type client struct { // Rand is a cryptographic source of randomness used for card challenges. // // If nil, defaults to crypto.Rand. - Rand io.Reader + Rand io.Reader + trace *ClientTrace } func (c *client) Cards() ([]string, error) { - ctx, err := newSCContext() + ctx, err := newSCContext(c.trace) if err != nil { return nil, fmt.Errorf("connecting to pcsc: %w", err) } @@ -150,7 +174,7 @@ func (c *client) Cards() ([]string, error) { } func (c *client) Open(card string) (*YubiKey, error) { - ctx, err := newSCContext() + ctx, err := newSCContext(c.trace) if err != nil { return nil, fmt.Errorf("connecting to smart card daemon: %w", err) } @@ -164,6 +188,7 @@ func (c *client) Open(card string) (*YubiKey, error) { if err != nil { return nil, fmt.Errorf("beginning smart card transaction: %w", err) } + if err := ykSelectApplication(tx, aidPIV[:]); err != nil { tx.Close() return nil, fmt.Errorf("selecting piv applet: %w", err)