From 80b30e987a9089d38872e4a7095db55cf450c995 Mon Sep 17 00:00:00 2001 From: Robin Krahl Date: Mon, 11 Mar 2024 10:47:25 +0100 Subject: [PATCH 1/6] Update pipe visibility Previously, some internals of the pipe module were public, while parts of the public interface where only pub(crate). This patch fixes visibility in the pipe module so that internals are private and the public API is public. It also fixes some clippy lints on the way. --- src/class.rs | 6 ++--- src/pipe.rs | 69 ++++++++++++++++++++++++++++------------------------ 2 files changed, 40 insertions(+), 35 deletions(-) diff --git a/src/class.rs b/src/class.rs index 720a532..75ee101 100644 --- a/src/class.rs +++ b/src/class.rs @@ -95,19 +95,19 @@ where /// Indicate in INIT response that Wink command is implemented. pub fn implements_wink(mut self) -> Self { - self.pipe.implements |= 0x01; + self.pipe.set_implements(self.pipe.implements() | 0x01); self } /// Indicate in INIT response that RawMsg command is implemented. pub fn implements_ctap1(mut self) -> Self { - self.pipe.implements &= !0x80; + self.pipe.set_implements(self.pipe.implements() & !0x80); self } /// Indicate in INIT response that Cbor command is implemented. pub fn implements_ctap2(mut self) -> Self { - self.pipe.implements |= 0x04; + self.pipe.set_implements(self.pipe.implements() | 0x04); self } diff --git a/src/pipe.rs b/src/pipe.rs index a4b6834..701777a 100644 --- a/src/pipe.rs +++ b/src/pipe.rs @@ -12,14 +12,12 @@ receive busy errors). No state is maintained between transactions. */ -use core::convert::TryFrom; -use core::convert::TryInto; use core::sync::atomic::Ordering; // pub type ContactInterchange = usbd_ccid::types::ApduInterchange; // pub type ContactlessInterchange = iso14443::types::ApduInterchange; use ctaphid_dispatch::command::Command; -use ctaphid_dispatch::types::Requester; +use ctaphid_dispatch::types::{Error as DispatchError, Requester}; use ctap_types::Error as AuthenticatorError; use trussed::interrupt::InterruptFlag; @@ -41,11 +39,12 @@ use crate::{ PACKET_SIZE, }, types::KeepaliveStatus, + Version, }; /// The actual payload of given length is dealt with separately #[derive(Copy, Clone, Debug, Eq, PartialEq)] -pub struct Request { +struct Request { channel: u32, command: Command, length: u16, @@ -54,14 +53,14 @@ pub struct Request { /// The actual payload of given length is dealt with separately #[derive(Copy, Clone, Debug, Eq, PartialEq)] -pub struct Response { +struct Response { channel: u32, command: Command, length: u16, } impl Response { - pub fn from_request_and_size(request: Request, size: usize) -> Self { + fn from_request_and_size(request: Request, size: usize) -> Self { Self { channel: request.channel, command: request.command, @@ -69,21 +68,21 @@ impl Response { } } - pub fn error_from_request(request: Request) -> Self { + fn error_from_request(request: Request) -> Self { Self::error_on_channel(request.channel) } - pub fn error_on_channel(channel: u32) -> Self { + fn error_on_channel(channel: u32) -> Self { Self { channel, - command: ctaphid_dispatch::command::Command::Error, + command: Command::Error, length: 1, } } } #[derive(Copy, Clone, Debug, Eq, PartialEq)] -pub struct MessageState { +struct MessageState { // sequence number of next continuation packet next_sequence: u8, // number of bytes of message payload transmitted so far @@ -108,8 +107,7 @@ impl MessageState { } #[derive(Clone, Debug, Eq, PartialEq)] -#[allow(unused)] -pub enum State { +enum State { Idle, // if request payload data is larger than one packet @@ -144,21 +142,21 @@ pub struct Pipe<'alloc, 'pipe, 'interrupt, Bus: UsbBus> { last_channel: u32, // Indicator of implemented commands in INIT response. - pub(crate) implements: u8, + implements: u8, // timestamp that gets used for timing out CID's - pub(crate) last_milliseconds: u32, + last_milliseconds: u32, // a "read once" indicator if now we're waiting on the application processing started_processing: bool, needs_keepalive: bool, - pub(crate) version: crate::Version, + version: Version, } impl<'alloc, 'pipe, 'interrupt, Bus: UsbBus> Pipe<'alloc, 'pipe, 'interrupt, Bus> { - pub(crate) fn new( + pub fn new( read_endpoint: EndpointOut<'alloc, Bus>, write_endpoint: EndpointIn<'alloc, Bus>, interchange: Requester<'pipe>, @@ -187,7 +185,7 @@ impl<'alloc, 'pipe, 'interrupt, Bus: UsbBus> Pipe<'alloc, 'pipe, 'interrupt, Bus // &mut self.authenticator // } - pub(crate) fn with_interrupt( + pub fn with_interrupt( read_endpoint: EndpointOut<'alloc, Bus>, write_endpoint: EndpointIn<'alloc, Bus>, interchange: Requester<'pipe>, @@ -211,7 +209,15 @@ impl<'alloc, 'pipe, 'interrupt, Bus: UsbBus> Pipe<'alloc, 'pipe, 'interrupt, Bus } } - pub(crate) fn set_version(&mut self, version: crate::Version) { + pub fn implements(&self) -> u8 { + self.implements + } + + pub fn set_implements(&mut self, implements: u8) { + self.implements = implements; + } + + pub fn set_version(&mut self, version: Version) { self.version = version; } @@ -224,12 +230,12 @@ impl<'alloc, 'pipe, 'interrupt, Bus: UsbBus> Pipe<'alloc, 'pipe, 'interrupt, Bus } // used to generate the configuration descriptors - pub(crate) fn read_endpoint(&self) -> &EndpointOut<'alloc, Bus> { + pub fn read_endpoint(&self) -> &EndpointOut<'alloc, Bus> { &self.read_endpoint } // used to generate the configuration descriptors - pub(crate) fn write_endpoint(&self) -> &EndpointIn<'alloc, Bus> { + pub fn write_endpoint(&self) -> &EndpointIn<'alloc, Bus> { &self.write_endpoint } @@ -247,7 +253,7 @@ impl<'alloc, 'pipe, 'interrupt, Bus: UsbBus> Pipe<'alloc, 'pipe, 'interrupt, Bus /// a CTAP message, with which it then calls `dispatch_message`. /// /// During these calls, we can be in states: Idle, Receiving, Dispatching. - pub(crate) fn read_and_handle_packet(&mut self) { + pub fn read_and_handle_packet(&mut self) { // info_now!("got a packet!"); let mut packet = [0u8; PACKET_SIZE]; match self.read_endpoint.read(&mut packet) { @@ -293,7 +299,10 @@ impl<'alloc, 'pipe, 'interrupt, Bus: UsbBus> Pipe<'alloc, 'pipe, 'interrupt, Bus // `solo ls` crashes here as it uses command 0x86 Err(_) => { info!("Received invalid command."); - self.start_sending_error_on_channel(channel, AuthenticatorError::InvalidCommand); + self.start_sending_error_on_channel( + channel, + AuthenticatorError::InvalidCommand, + ); return; } }; @@ -503,11 +512,7 @@ impl<'alloc, 'pipe, 'interrupt, Bus: UsbBus> Pipe<'alloc, 'pipe, 'interrupt, Bus } _ => { - if request.command == Command::Cbor { - self.needs_keepalive = true; - } else { - self.needs_keepalive = false; - } + self.needs_keepalive = request.command == Command::Cbor; if self.interchange.state() == interchange::State::Responded { info!("dumping stale response"); self.interchange.take_response(); @@ -576,15 +581,15 @@ impl<'alloc, 'pipe, 'interrupt, Bus: UsbBus> Pipe<'alloc, 'pipe, 'interrupt, Bus if let State::WaitingOnAuthenticator(request) = self.state { if let Ok(response) = self.interchange.response() { match &response.0 { - Err(ctaphid_dispatch::app::Error::InvalidCommand) => { + Err(DispatchError::InvalidCommand) => { info!("Got waiting reply from authenticator??"); self.start_sending_error(request, AuthenticatorError::InvalidCommand); } - Err(ctaphid_dispatch::app::Error::InvalidLength) => { + Err(DispatchError::InvalidLength) => { info!("Error, payload needed app command."); self.start_sending_error(request, AuthenticatorError::InvalidLength); } - Err(ctaphid_dispatch::app::Error::NoResponse) => { + Err(DispatchError::NoResponse) => { info!("Got waiting noresponse from authenticator??"); } @@ -602,7 +607,7 @@ impl<'alloc, 'pipe, 'interrupt, Bus: UsbBus> Pipe<'alloc, 'pipe, 'interrupt, Bus message.len() ); let response = Response::from_request_and_size(request, message.len()); - self.buffer[..message.len()].copy_from_slice(&message); + self.buffer[..message.len()].copy_from_slice(message); self.start_sending(response); } } @@ -641,7 +646,7 @@ impl<'alloc, 'pipe, 'interrupt, Bus: UsbBus> Pipe<'alloc, 'pipe, 'interrupt, Bus // called from poll, and when a packet has been sent #[inline(never)] - pub(crate) fn maybe_write_packet(&mut self) { + pub fn maybe_write_packet(&mut self) { match self.state { State::WaitingToSend(response) => { // zeros leftover bytes From da5dd49e8a3329428e8befe86a238822ab9ac0b8 Mon Sep 17 00:00:00 2001 From: Robin Krahl Date: Mon, 18 Mar 2024 12:16:48 +0100 Subject: [PATCH 2/6] pipe: Move read and write handling into helper class --- src/pipe.rs | 303 ++++++++++++++++++++++++++-------------------------- 1 file changed, 153 insertions(+), 150 deletions(-) diff --git a/src/pipe.rs b/src/pipe.rs index 701777a..85b0b7a 100644 --- a/src/pipe.rs +++ b/src/pipe.rs @@ -100,9 +100,11 @@ impl Default for MessageState { impl MessageState { // update state due to receiving a full new continuation packet - pub fn absorb_packet(&mut self) { + #[must_use] + pub fn absorb_packet(mut self) -> Self { self.next_sequence += 1; self.transmitted += PACKET_SIZE - 5; + self } } @@ -127,8 +129,7 @@ enum State { } pub struct Pipe<'alloc, 'pipe, 'interrupt, Bus: UsbBus> { - read_endpoint: EndpointOut<'alloc, Bus>, - write_endpoint: EndpointIn<'alloc, Bus>, + endpoints: Endpoints<'alloc, Bus>, state: State, interchange: Requester<'pipe>, @@ -163,8 +164,7 @@ impl<'alloc, 'pipe, 'interrupt, Bus: UsbBus> Pipe<'alloc, 'pipe, 'interrupt, Bus initial_milliseconds: u32, ) -> Self { Self { - read_endpoint, - write_endpoint, + endpoints: Endpoints::new(read_endpoint, write_endpoint), state: State::Idle, interchange, buffer: [0u8; MESSAGE_SIZE], @@ -193,8 +193,7 @@ impl<'alloc, 'pipe, 'interrupt, Bus: UsbBus> Pipe<'alloc, 'pipe, 'interrupt, Bus initial_milliseconds: u32, ) -> Self { Self { - read_endpoint, - write_endpoint, + endpoints: Endpoints::new(read_endpoint, write_endpoint), state: State::Idle, interchange, buffer: [0u8; MESSAGE_SIZE], @@ -222,21 +221,21 @@ impl<'alloc, 'pipe, 'interrupt, Bus: UsbBus> Pipe<'alloc, 'pipe, 'interrupt, Bus } pub fn read_address(&self) -> EndpointAddress { - self.read_endpoint.address() + self.endpoints.read.address() } pub fn write_address(&self) -> EndpointAddress { - self.write_endpoint.address() + self.endpoints.write.address() } // used to generate the configuration descriptors pub fn read_endpoint(&self) -> &EndpointOut<'alloc, Bus> { - &self.read_endpoint + &self.endpoints.read } // used to generate the configuration descriptors pub fn write_endpoint(&self) -> &EndpointIn<'alloc, Bus> { - &self.write_endpoint + &self.endpoints.write } fn cancel_ongoing_activity(&mut self) { @@ -256,27 +255,9 @@ impl<'alloc, 'pipe, 'interrupt, Bus: UsbBus> Pipe<'alloc, 'pipe, 'interrupt, Bus pub fn read_and_handle_packet(&mut self) { // info_now!("got a packet!"); let mut packet = [0u8; PACKET_SIZE]; - match self.read_endpoint.read(&mut packet) { - Ok(PACKET_SIZE) => {} - Ok(_size) => { - // error handling? - // from spec: "Packets are always fixed size (defined by the endpoint and - // HID report descriptors) and although all bytes may not be needed in a - // particular packet, the full size always has to be sent. - // Unused bytes SHOULD be set to zero." - // !("OK but size {}", size); - info!("error unexpected size {}", _size); - return; - } - // usb-device lists WouldBlock or BufferOverflow as possible errors. - // both should not occur here, and we can't do anything anyway. - // Err(UsbError::WouldBlock) => { return; }, - // Err(UsbError::BufferOverflow) => { return; }, - Err(_error) => { - info!("error no {}", _error as i32); - return; - } - }; + if self.endpoints.read(&mut packet).is_err() { + return; + } info!(">> "); info!("{}", hex_str!(&packet[..16])); @@ -368,7 +349,7 @@ impl<'alloc, 'pipe, 'interrupt, Bus: UsbBus> Pipe<'alloc, 'pipe, 'interrupt, Bus } else { // case of continuation packet match self.state { - State::Receiving((request, mut message_state)) => { + State::Receiving((request, message_state)) => { let sequence = packet[4]; // info_now!("receiving continuation packet {}", sequence); if sequence != message_state.next_sequence { @@ -394,7 +375,7 @@ impl<'alloc, 'pipe, 'interrupt, Bus: UsbBus> Pipe<'alloc, 'pipe, 'interrupt, Bus // store received part of payload self.buffer[message_state.transmitted..][..PACKET_SIZE - 5] .copy_from_slice(&packet[5..]); - message_state.absorb_packet(); + let message_state = message_state.absorb_packet(); self.state = State::Receiving((request, message_state)); // info_now!("absorbed packet, awaiting next"); } else { @@ -554,19 +535,19 @@ impl<'alloc, 'pipe, 'interrupt, Bus: UsbBus> Pipe<'alloc, 'pipe, 'interrupt, Bus } else { info!("keepalive"); - let mut packet = [0u8; PACKET_SIZE]; - - packet[..4].copy_from_slice(&request.channel.to_be_bytes()); - packet[4] = 0x80 | 0x3B; - packet[5..7].copy_from_slice(&1u16.to_be_bytes()); - - if is_waiting_for_user_presence { - packet[7] = KeepaliveStatus::UpNeeded as u8; + let response = Response { + channel: request.channel, + command: Command::KeepAlive, + length: 1, + }; + let status = if is_waiting_for_user_presence { + KeepaliveStatus::UpNeeded } else { - packet[7] = KeepaliveStatus::Processing as u8; - } - - self.write_endpoint.write(&packet).ok(); + KeepaliveStatus::Processing + }; + self.endpoints + .write(Packet::init(response, &[status as u8])) + .ok(); true } @@ -647,116 +628,138 @@ impl<'alloc, 'pipe, 'interrupt, Bus: UsbBus> Pipe<'alloc, 'pipe, 'interrupt, Bus // called from poll, and when a packet has been sent #[inline(never)] pub fn maybe_write_packet(&mut self) { - match self.state { - State::WaitingToSend(response) => { - // zeros leftover bytes - let mut packet = [0u8; PACKET_SIZE]; - packet[..4].copy_from_slice(&response.channel.to_be_bytes()); - // packet[4] = response.command.into() | 0x80u8; - packet[4] = response.command.into_u8() | 0x80; - packet[5..7].copy_from_slice(&response.length.to_be_bytes()); - - let fits_in_one_packet = 7 + response.length as usize <= PACKET_SIZE; - if fits_in_one_packet { - packet[7..][..response.length as usize] - .copy_from_slice(&self.buffer[..response.length as usize]); - self.state = State::Idle; - } else { - packet[7..].copy_from_slice(&self.buffer[..PACKET_SIZE - 7]); - } + let packet = match self.state { + State::WaitingToSend(response) => Packet::init(response, &self.buffer), + State::Sending((response, message_state)) => { + Packet::cont(response, message_state, &self.buffer) + } + // nothing to send + _ => { + return; + } + }; + if self.endpoints.write(packet).is_ok() { + self.state = packet.next_state(); + } + } +} - // try actually sending - // info_now!("attempting to write init packet {:?}, {:?}", - // &packet[..32], &packet[32..]); - let result = self.write_endpoint.write(&packet); +#[derive(Clone, Copy, Debug)] +struct Packet<'a> { + response: Response, + message_state: Option, + buffer: &'a [u8], +} - match result { - Err(UsbError::WouldBlock) => { - // fine, can't write try later - // this shouldn't happen probably - info!("hid usb WouldBlock"); - } - Err(_) => { - // info_now!("weird USB errrorrr"); - panic!("unexpected error writing packet!"); - } - Ok(PACKET_SIZE) => { - // goodie, this worked - if fits_in_one_packet { - self.state = State::Idle; - // info_now!("StartSent {} bytes, idle again", response.length); - // info_now!("IDLE again"); - } else { - self.state = State::Sending((response, MessageState::default())); - // info_now!( - // "StartSent {} of {} bytes, waiting to send again", - // PACKET_SIZE - 7, response.length); - // info_now!("State: {:?}", &self.state); - } - } - Ok(_) => { - // info_now!("short write"); - panic!("unexpected size writing packet!"); - } - }; - } +impl<'a> Packet<'a> { + fn init(response: Response, buffer: &'a [u8]) -> Self { + Self { + response, + message_state: None, + buffer, + } + } - State::Sending((response, mut message_state)) => { - // info_now!("in StillSending"); - let mut packet = [0u8; PACKET_SIZE]; - packet[..4].copy_from_slice(&response.channel.to_be_bytes()); - packet[4] = message_state.next_sequence; - - let sent = message_state.transmitted; - let remaining = response.length as usize - sent; - let last_packet = 5 + remaining <= PACKET_SIZE; - if last_packet { - packet[5..][..remaining] - .copy_from_slice(&self.buffer[message_state.transmitted..][..remaining]); - } else { - packet[5..].copy_from_slice( - &self.buffer[message_state.transmitted..][..PACKET_SIZE - 5], - ); - } + fn cont(response: Response, message_state: MessageState, buffer: &'a [u8]) -> Self { + Self { + response, + message_state: Some(message_state), + buffer, + } + } - // try actually sending - // info_now!("attempting to write cont packet {:?}, {:?}", - // &packet[..32], &packet[32..]); - let result = self.write_endpoint.write(&packet); - - match result { - Err(UsbError::WouldBlock) => { - // fine, can't write try later - // this shouldn't happen probably - // info_now!("can't send seq {}, write endpoint busy", - // message_state.next_sequence); - } - Err(_) => { - // info_now!("weird USB error"); - panic!("unexpected error writing packet!"); - } - Ok(PACKET_SIZE) => { - // goodie, this worked - if last_packet { - self.state = State::Idle; - // info_now!("in IDLE state after {:?}", &message_state); - } else { - message_state.absorb_packet(); - // DANGER! destructuring in the match arm copies out - // message state, so need to update state - // info_now!("sent one more, now {:?}", &message_state); - self.state = State::Sending((response, message_state)); - } - } - Ok(_) => { - debug!("short write"); - panic!("unexpected size writing packet!"); - } - }; + fn has_more(&self) -> bool { + if let Some(message_state) = self.message_state { + let remaining = usize::from(self.response.length) - message_state.transmitted; + remaining > PACKET_SIZE - 5 + } else { + usize::from(self.response.length) > PACKET_SIZE - 7 + } + } + + fn next_state(&self) -> State { + if self.has_more() { + let message_state = self + .message_state + .map(MessageState::absorb_packet) + .unwrap_or_default(); + State::Sending((self.response, message_state)) + } else { + State::Idle + } + } + + fn serialize(&self, buffer: &mut [u8; PACKET_SIZE]) { + // buffer must be zeroed + buffer[..4].copy_from_slice(&self.response.channel.to_be_bytes()); + if let Some(message_state) = self.message_state { + buffer[4] = message_state.next_sequence; + let remaining = usize::from(self.response.length) - message_state.transmitted; + let n = remaining.min(PACKET_SIZE - 5); + buffer[5..][..n].copy_from_slice(&self.buffer[message_state.transmitted..][..n]); + } else { + buffer[4] = self.response.command.into_u8() | 0x80; + buffer[5..7].copy_from_slice(&self.response.length.to_be_bytes()); + let n = usize::from(self.response.length).min(PACKET_SIZE - 7); + buffer[7..][..n].copy_from_slice(&self.buffer[..n]); + } + } +} + +struct Endpoints<'a, Bus: UsbBus> { + read: EndpointOut<'a, Bus>, + write: EndpointIn<'a, Bus>, +} + +impl<'a, Bus: UsbBus> Endpoints<'a, Bus> { + fn new(read: EndpointOut<'a, Bus>, write: EndpointIn<'a, Bus>) -> Self { + Self { read, write } + } + + fn read(&mut self, packet: &mut [u8; PACKET_SIZE]) -> Result<(), ()> { + match self.read.read(packet) { + Ok(PACKET_SIZE) => Ok(()), + Ok(_size) => { + // error handling? + // from spec: "Packets are always fixed size (defined by the endpoint and + // HID report descriptors) and although all bytes may not be needed in a + // particular packet, the full size always has to be sent. + // Unused bytes SHOULD be set to zero." + // !("OK but size {}", size); + info!("error unexpected size {}", _size); + Err(()) } + // usb-device lists WouldBlock or BufferOverflow as possible errors. + // both should not occur here, and we can't do anything anyway. + // Err(UsbError::WouldBlock) => { return; }, + // Err(UsbError::BufferOverflow) => { return; }, + Err(_error) => { + info!("error no {}", _error as i32); + Err(()) + } + } + } - // nothing to send - _ => {} + fn write(&mut self, packet: Packet<'_>) -> Result<(), ()> { + // zeros leftover bytes + let mut buffer = [0u8; PACKET_SIZE]; + packet.serialize(&mut buffer); + match self.write.write(&buffer) { + Ok(PACKET_SIZE) => Ok(()), + Ok(_) => { + error!("short write"); + panic!("unexpected size writing packet!"); + } + Err(UsbError::WouldBlock) => { + // fine, can't write try later + // this shouldn't happen probably + info!("hid usb WouldBlock"); + Err(()) + } + Err(_) => { + // info_now!("weird USB error"); + panic!("unexpected error writing packet!"); + } } } } From 529fcfe2a96585e904ab7e1c3cb95496d69bf602 Mon Sep 17 00:00:00 2001 From: Robin Krahl Date: Mon, 18 Mar 2024 12:21:09 +0100 Subject: [PATCH 3/6] pipe: Directly send error in send_error_now Instead of resetting the pipe state after sending the error, we can direclty write it to the endpoint as it will always fit into one init packet. --- src/pipe.rs | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/src/pipe.rs b/src/pipe.rs index 85b0b7a..341b2d5 100644 --- a/src/pipe.rs +++ b/src/pipe.rs @@ -613,16 +613,11 @@ impl<'alloc, 'pipe, 'interrupt, Bus: UsbBus> Pipe<'alloc, 'pipe, 'interrupt, Bus } fn send_error_now(&mut self, request: Request, error: AuthenticatorError) { - let last_state = core::mem::replace(&mut self.state, State::Idle); - let last_first_byte = self.buffer[0]; - - self.buffer[0] = error as u8; let response = Response::error_from_request(request); - self.start_sending(response); - self.maybe_write_packet(); - - self.state = last_state; - self.buffer[0] = last_first_byte; + // TODO: should we block? + self.endpoints + .write(Packet::init(response, &[error as u8])) + .ok(); } // called from poll, and when a packet has been sent From 8e2c8b02a5ae4daed0b59766cfeb531ef5363025 Mon Sep 17 00:00:00 2001 From: Robin Krahl Date: Mon, 18 Mar 2024 14:19:29 +0100 Subject: [PATCH 4/6] Use Result to collect errors --- src/class.rs | 3 +- src/pipe.rs | 144 +++++++++++++++++++++++++++++++++------------------ 2 files changed, 96 insertions(+), 51 deletions(-) diff --git a/src/class.rs b/src/class.rs index 75ee101..79445be 100644 --- a/src/class.rs +++ b/src/class.rs @@ -253,8 +253,7 @@ where #[inline(never)] fn poll(&mut self) { // debug!("state = {:?}", self.pipe().state); - self.pipe.handle_response(); - self.pipe.maybe_write_packet(); + self.pipe.handle_and_write_response(); } // called when endpoint with given address received a packet diff --git a/src/pipe.rs b/src/pipe.rs index 341b2d5..1dbd064 100644 --- a/src/pipe.rs +++ b/src/pipe.rs @@ -51,6 +51,24 @@ struct Request { timestamp: u32, } +impl Request { + fn error(self, error: AuthenticatorError) -> PipeError { + PipeError { + channel: self.channel, + error, + keep_state: false, + } + } + + fn error_now(self, error: AuthenticatorError) -> PipeError { + PipeError { + channel: self.channel, + error, + keep_state: true, + } + } +} + /// The actual payload of given length is dealt with separately #[derive(Copy, Clone, Debug, Eq, PartialEq)] struct Response { @@ -68,10 +86,6 @@ impl Response { } } - fn error_from_request(request: Request) -> Self { - Self::error_on_channel(request.channel) - } - fn error_on_channel(channel: u32) -> Self { Self { channel, @@ -81,6 +95,22 @@ impl Response { } } +struct PipeError { + channel: u32, + error: AuthenticatorError, + keep_state: bool, +} + +impl PipeError { + fn on_channel(channel: u32, error: AuthenticatorError) -> Self { + Self { + channel, + error, + keep_state: false, + } + } +} + #[derive(Copy, Clone, Debug, Eq, PartialEq)] struct MessageState { // sequence number of next continuation packet @@ -255,9 +285,15 @@ impl<'alloc, 'pipe, 'interrupt, Bus: UsbBus> Pipe<'alloc, 'pipe, 'interrupt, Bus pub fn read_and_handle_packet(&mut self) { // info_now!("got a packet!"); let mut packet = [0u8; PACKET_SIZE]; - if self.endpoints.read(&mut packet).is_err() { - return; + if self.endpoints.read(&mut packet).is_ok() { + match self.handle_packet(&packet) { + Ok(()) => (), + Err(error) => self.send_error(error), + } } + } + + fn handle_packet(&mut self, packet: &[u8; 64]) -> Result<(), PipeError> { info!(">> "); info!("{}", hex_str!(&packet[..16])); @@ -280,11 +316,10 @@ impl<'alloc, 'pipe, 'interrupt, Bus: UsbBus> Pipe<'alloc, 'pipe, 'interrupt, Bus // `solo ls` crashes here as it uses command 0x86 Err(_) => { info!("Received invalid command."); - self.start_sending_error_on_channel( + return Err(PipeError::on_channel( channel, AuthenticatorError::InvalidCommand, - ); - return; + )); } }; @@ -305,34 +340,32 @@ impl<'alloc, 'pipe, 'interrupt, Bus: UsbBus> Pipe<'alloc, 'pipe, 'interrupt, Bus State::Receiving((request, _message_state)) => request, _ => { info_now!("Ignoring transaction as we're already transmitting."); - return; + return Ok(()); } }; if packet[4] == 0x86 { info_now!("Resyncing!"); self.cancel_ongoing_activity(); } else { - if channel == request.channel { + return if channel == request.channel { if command == Command::Cancel { info_now!("Cancelling"); self.cancel_ongoing_activity(); + Ok(()) } else { info_now!("Expected seq, {:?}", request.command); - self.start_sending_error(request, AuthenticatorError::InvalidSeq); + Err(request.error(AuthenticatorError::InvalidSeq)) } } else { info_now!("busy."); - self.send_error_now(current_request, AuthenticatorError::ChannelBusy); - } - - return; + Err(current_request.error_now(AuthenticatorError::ChannelBusy)) + }; } } if length > MESSAGE_SIZE as u16 { info!("Error message too big."); - self.send_error_now(current_request, AuthenticatorError::InvalidLength); - return; + return Err(current_request.error_now(AuthenticatorError::InvalidLength)); } if length > PACKET_SIZE as u16 - 7 { @@ -341,10 +374,11 @@ impl<'alloc, 'pipe, 'interrupt, Bus: UsbBus> Pipe<'alloc, 'pipe, 'interrupt, Bus self.buffer[..PACKET_SIZE - 7].copy_from_slice(&packet[7..]); self.state = State::Receiving((current_request, { MessageState::default() })); // we're done... wait for next packet + Ok(()) } else { // request fits in one packet self.buffer[..length as usize].copy_from_slice(&packet[7..][..length as usize]); - self.dispatch_request(current_request); + self.dispatch_request(current_request) } } else { // case of continuation packet @@ -357,15 +391,14 @@ impl<'alloc, 'pipe, 'interrupt, Bus: UsbBus> Pipe<'alloc, 'pipe, 'interrupt, Bus // info_now!("wrong sequence for continuation packet, expected {} received {}", // message_state.next_sequence, sequence); info!("Error invalid cont pkt"); - self.start_sending_error(request, AuthenticatorError::InvalidSeq); - return; + return Err(request.error(AuthenticatorError::InvalidSeq)); } if channel != request.channel { // error handling? // info_now!("wrong channel for continuation packet, expected {} received {}", // request.channel, channel); info!("Ignore invalid channel"); - return; + return Ok(()); } let payload_length = request.length as usize; @@ -378,16 +411,18 @@ impl<'alloc, 'pipe, 'interrupt, Bus: UsbBus> Pipe<'alloc, 'pipe, 'interrupt, Bus let message_state = message_state.absorb_packet(); self.state = State::Receiving((request, message_state)); // info_now!("absorbed packet, awaiting next"); + Ok(()) } else { let missing = request.length as usize - message_state.transmitted; self.buffer[message_state.transmitted..payload_length] .copy_from_slice(&packet[5..][..missing]); - self.dispatch_request(request); + self.dispatch_request(request) } } _ => { // unexpected continuation packet info!("Ignore unexpected cont pkt"); + Ok(()) } } } @@ -418,19 +453,18 @@ impl<'alloc, 'pipe, 'interrupt, Bus: UsbBus> Pipe<'alloc, 'pipe, 'interrupt, Bus request.timestamp, milliseconds, last ); let req = *request; - self.start_sending_error(req, AuthenticatorError::Timeout); + self.send_error(req.error(AuthenticatorError::Timeout)); } } } - fn dispatch_request(&mut self, request: Request) { + fn dispatch_request(&mut self, request: Request) -> Result<(), PipeError> { info!("Got request: {:?}", request.command); match request.command { Command::Init => {} _ => { if request.channel == 0xffffffff { - self.start_sending_error(request, AuthenticatorError::InvalidChannel); - return; + return Err(request.error(AuthenticatorError::InvalidChannel)); } } } @@ -442,7 +476,7 @@ impl<'alloc, 'pipe, 'interrupt, Bus: UsbBus> Pipe<'alloc, 'pipe, 'interrupt, Bus match request.channel { 0 => { // this is an error / reserved number - self.start_sending_error(request, AuthenticatorError::InvalidChannel); + Err(request.error(AuthenticatorError::InvalidChannel)) } // broadcast channel ID - request for assignment @@ -478,6 +512,7 @@ impl<'alloc, 'pipe, 'interrupt, Bus: UsbBus> Pipe<'alloc, 'pipe, 'interrupt, Bus self.buffer[16] = self.implements; self.start_sending(response); } + Ok(()) } } } @@ -485,11 +520,13 @@ impl<'alloc, 'pipe, 'interrupt, Bus: UsbBus> Pipe<'alloc, 'pipe, 'interrupt, Bus Command::Ping => { let response = Response::from_request_and_size(request, request.length as usize); self.start_sending(response); + Ok(()) } Command::Cancel => { info!("CTAPHID_CANCEL"); self.cancel_ongoing_activity(); + Ok(()) } _ => { @@ -505,12 +542,13 @@ impl<'alloc, 'pipe, 'interrupt, Bus: UsbBus> Pipe<'alloc, 'pipe, 'interrupt, Bus Ok(_) => { self.state = State::WaitingOnAuthenticator(request); self.started_processing = true; + Ok(()) } Err(_) => { // busy info_now!("STATE: {:?}", self.interchange.state()); info!("can't handle more than one authenticator request at a time."); - self.send_error_now(request, AuthenticatorError::ChannelBusy); + Err(request.error_now(AuthenticatorError::ChannelBusy)) } } } @@ -557,21 +595,29 @@ impl<'alloc, 'pipe, 'interrupt, Bus: UsbBus> Pipe<'alloc, 'pipe, 'interrupt, Bus } } + pub fn handle_and_write_response(&mut self) { + if let Err(err) = self.handle_response() { + self.send_error(err); + } + self.maybe_write_packet(); + } + #[inline(never)] - pub fn handle_response(&mut self) { + fn handle_response(&mut self) -> Result<(), PipeError> { if let State::WaitingOnAuthenticator(request) = self.state { if let Ok(response) = self.interchange.response() { match &response.0 { Err(DispatchError::InvalidCommand) => { info!("Got waiting reply from authenticator??"); - self.start_sending_error(request, AuthenticatorError::InvalidCommand); + Err(request.error(AuthenticatorError::InvalidCommand)) } Err(DispatchError::InvalidLength) => { info!("Error, payload needed app command."); - self.start_sending_error(request, AuthenticatorError::InvalidLength); + Err(request.error(AuthenticatorError::InvalidLength)) } Err(DispatchError::NoResponse) => { info!("Got waiting noresponse from authenticator??"); + Ok(()) } Ok(message) => { @@ -581,7 +627,7 @@ impl<'alloc, 'pipe, 'interrupt, Bus: UsbBus> Pipe<'alloc, 'pipe, 'interrupt, Bus message.len(), self.buffer.len(), ); - self.start_sending_error(request, AuthenticatorError::InvalidLength); + Err(request.error(AuthenticatorError::InvalidLength)) } else { info!( "Got {} bytes response from authenticator, starting send", @@ -590,10 +636,15 @@ impl<'alloc, 'pipe, 'interrupt, Bus: UsbBus> Pipe<'alloc, 'pipe, 'interrupt, Bus let response = Response::from_request_and_size(request, message.len()); self.buffer[..message.len()].copy_from_slice(message); self.start_sending(response); + Ok(()) } } } + } else { + Ok(()) } + } else { + Ok(()) } } @@ -602,22 +653,17 @@ impl<'alloc, 'pipe, 'interrupt, Bus: UsbBus> Pipe<'alloc, 'pipe, 'interrupt, Bus self.maybe_write_packet(); } - fn start_sending_error(&mut self, request: Request, error: AuthenticatorError) { - self.start_sending_error_on_channel(request.channel, error); - } - - fn start_sending_error_on_channel(&mut self, channel: u32, error: AuthenticatorError) { - self.buffer[0] = error as u8; - let response = Response::error_on_channel(channel); - self.start_sending(response); - } - - fn send_error_now(&mut self, request: Request, error: AuthenticatorError) { - let response = Response::error_from_request(request); - // TODO: should we block? - self.endpoints - .write(Packet::init(response, &[error as u8])) - .ok(); + fn send_error(&mut self, error: PipeError) { + let response = Response::error_on_channel(error.channel); + if error.keep_state { + // TODO: should we block? + self.endpoints + .write(Packet::init(response, &[error.error as u8])) + .ok(); + } else { + self.buffer[0] = error.error as u8; + self.start_sending(response); + } } // called from poll, and when a packet has been sent From 16200030313b92e0d0d0aa68ae452b3c51846dff Mon Sep 17 00:00:00 2001 From: Robin Krahl Date: Mon, 18 Mar 2024 15:20:31 +0100 Subject: [PATCH 5/6] pipe: Return response instead of sending directly --- src/pipe.rs | 48 ++++++++++++++++++++++++------------------------ 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/src/pipe.rs b/src/pipe.rs index 1dbd064..a3dc019 100644 --- a/src/pipe.rs +++ b/src/pipe.rs @@ -287,13 +287,14 @@ impl<'alloc, 'pipe, 'interrupt, Bus: UsbBus> Pipe<'alloc, 'pipe, 'interrupt, Bus let mut packet = [0u8; PACKET_SIZE]; if self.endpoints.read(&mut packet).is_ok() { match self.handle_packet(&packet) { - Ok(()) => (), + Ok(Some(response)) => self.start_sending(response), + Ok(None) => (), Err(error) => self.send_error(error), } } } - fn handle_packet(&mut self, packet: &[u8; 64]) -> Result<(), PipeError> { + fn handle_packet(&mut self, packet: &[u8; 64]) -> Result, PipeError> { info!(">> "); info!("{}", hex_str!(&packet[..16])); @@ -340,7 +341,7 @@ impl<'alloc, 'pipe, 'interrupt, Bus: UsbBus> Pipe<'alloc, 'pipe, 'interrupt, Bus State::Receiving((request, _message_state)) => request, _ => { info_now!("Ignoring transaction as we're already transmitting."); - return Ok(()); + return Ok(None); } }; if packet[4] == 0x86 { @@ -351,7 +352,7 @@ impl<'alloc, 'pipe, 'interrupt, Bus: UsbBus> Pipe<'alloc, 'pipe, 'interrupt, Bus if command == Command::Cancel { info_now!("Cancelling"); self.cancel_ongoing_activity(); - Ok(()) + Ok(None) } else { info_now!("Expected seq, {:?}", request.command); Err(request.error(AuthenticatorError::InvalidSeq)) @@ -374,7 +375,7 @@ impl<'alloc, 'pipe, 'interrupt, Bus: UsbBus> Pipe<'alloc, 'pipe, 'interrupt, Bus self.buffer[..PACKET_SIZE - 7].copy_from_slice(&packet[7..]); self.state = State::Receiving((current_request, { MessageState::default() })); // we're done... wait for next packet - Ok(()) + Ok(None) } else { // request fits in one packet self.buffer[..length as usize].copy_from_slice(&packet[7..][..length as usize]); @@ -398,7 +399,7 @@ impl<'alloc, 'pipe, 'interrupt, Bus: UsbBus> Pipe<'alloc, 'pipe, 'interrupt, Bus // info_now!("wrong channel for continuation packet, expected {} received {}", // request.channel, channel); info!("Ignore invalid channel"); - return Ok(()); + return Ok(None); } let payload_length = request.length as usize; @@ -411,7 +412,7 @@ impl<'alloc, 'pipe, 'interrupt, Bus: UsbBus> Pipe<'alloc, 'pipe, 'interrupt, Bus let message_state = message_state.absorb_packet(); self.state = State::Receiving((request, message_state)); // info_now!("absorbed packet, awaiting next"); - Ok(()) + Ok(None) } else { let missing = request.length as usize - message_state.transmitted; self.buffer[message_state.transmitted..payload_length] @@ -422,7 +423,7 @@ impl<'alloc, 'pipe, 'interrupt, Bus: UsbBus> Pipe<'alloc, 'pipe, 'interrupt, Bus _ => { // unexpected continuation packet info!("Ignore unexpected cont pkt"); - Ok(()) + Ok(None) } } } @@ -458,7 +459,7 @@ impl<'alloc, 'pipe, 'interrupt, Bus: UsbBus> Pipe<'alloc, 'pipe, 'interrupt, Bus } } - fn dispatch_request(&mut self, request: Request) -> Result<(), PipeError> { + fn dispatch_request(&mut self, request: Request) -> Result, PipeError> { info!("Got request: {:?}", request.command); match request.command { Command::Init => {} @@ -484,6 +485,7 @@ impl<'alloc, 'pipe, 'interrupt, Bus: UsbBus> Pipe<'alloc, 'pipe, 'interrupt, Bus if request.length != 8 { // error info!("Invalid length for init. ignore."); + Ok(None) } else { self.last_channel += 1; // info_now!( @@ -510,23 +512,21 @@ impl<'alloc, 'pipe, 'interrupt, Bus: UsbBus> Pipe<'alloc, 'pipe, 'interrupt, Bus // 0x8: does not implement MSG // self.buffer[16] = 0x01 | 0x08; self.buffer[16] = self.implements; - self.start_sending(response); + Ok(Some(response)) } - Ok(()) } } } Command::Ping => { let response = Response::from_request_and_size(request, request.length as usize); - self.start_sending(response); - Ok(()) + Ok(Some(response)) } Command::Cancel => { info!("CTAPHID_CANCEL"); self.cancel_ongoing_activity(); - Ok(()) + Ok(None) } _ => { @@ -542,7 +542,7 @@ impl<'alloc, 'pipe, 'interrupt, Bus: UsbBus> Pipe<'alloc, 'pipe, 'interrupt, Bus Ok(_) => { self.state = State::WaitingOnAuthenticator(request); self.started_processing = true; - Ok(()) + Ok(None) } Err(_) => { // busy @@ -596,14 +596,15 @@ impl<'alloc, 'pipe, 'interrupt, Bus: UsbBus> Pipe<'alloc, 'pipe, 'interrupt, Bus } pub fn handle_and_write_response(&mut self) { - if let Err(err) = self.handle_response() { - self.send_error(err); + match self.handle_response() { + Ok(Some(response)) => self.start_sending(response), + Ok(None) => (), + Err(error) => self.send_error(error), } - self.maybe_write_packet(); } #[inline(never)] - fn handle_response(&mut self) -> Result<(), PipeError> { + fn handle_response(&mut self) -> Result, PipeError> { if let State::WaitingOnAuthenticator(request) = self.state { if let Ok(response) = self.interchange.response() { match &response.0 { @@ -617,7 +618,7 @@ impl<'alloc, 'pipe, 'interrupt, Bus: UsbBus> Pipe<'alloc, 'pipe, 'interrupt, Bus } Err(DispatchError::NoResponse) => { info!("Got waiting noresponse from authenticator??"); - Ok(()) + Ok(None) } Ok(message) => { @@ -635,16 +636,15 @@ impl<'alloc, 'pipe, 'interrupt, Bus: UsbBus> Pipe<'alloc, 'pipe, 'interrupt, Bus ); let response = Response::from_request_and_size(request, message.len()); self.buffer[..message.len()].copy_from_slice(message); - self.start_sending(response); - Ok(()) + Ok(Some(response)) } } } } else { - Ok(()) + Ok(None) } } else { - Ok(()) + Ok(None) } } From 997311da4c370c347a7fc6ae3f5df2582383768e Mon Sep 17 00:00:00 2001 From: Robin Krahl Date: Mon, 18 Mar 2024 15:48:49 +0100 Subject: [PATCH 6/6] pipe: Introduce Buffer type for message buffering --- src/buffer.rs | 673 ++++++++++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 1 + src/pipe.rs | 662 +++---------------------------------------------- 3 files changed, 703 insertions(+), 633 deletions(-) create mode 100644 src/buffer.rs diff --git a/src/buffer.rs b/src/buffer.rs new file mode 100644 index 0000000..b5d7bba --- /dev/null +++ b/src/buffer.rs @@ -0,0 +1,673 @@ +use crate::{ + constants::{ + // 3072 + MESSAGE_SIZE, + // 64 + PACKET_SIZE, + }, + types::KeepaliveStatus, + Version, +}; +use core::sync::atomic::Ordering; +use ctap_types::Error as AuthenticatorError; +use ctaphid_dispatch::command::Command; +use ctaphid_dispatch::types::{Error as DispatchError, Requester}; +use ref_swap::OptionRefSwap; +use trussed::interrupt::InterruptFlag; + +/// The actual payload of given length is dealt with separately +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +struct Request { + channel: u32, + command: Command, + length: u16, + timestamp: u32, +} + +impl Request { + fn error(self, error: AuthenticatorError) -> PipeError { + PipeError { + channel: self.channel, + error, + keep_state: false, + } + } + + fn error_now(self, error: AuthenticatorError) -> PipeError { + PipeError { + channel: self.channel, + error, + keep_state: true, + } + } +} + +/// The actual payload of given length is dealt with separately +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +struct Response { + channel: u32, + command: Command, + length: u16, +} + +impl Response { + fn from_request_and_size(request: Request, size: usize) -> Self { + Self { + channel: request.channel, + command: request.command, + length: size as u16, + } + } + + fn error_on_channel(channel: u32) -> Self { + Self { + channel, + command: Command::Error, + length: 1, + } + } +} + +struct PipeError { + channel: u32, + error: AuthenticatorError, + keep_state: bool, +} + +impl PipeError { + fn on_channel(channel: u32, error: AuthenticatorError) -> Self { + Self { + channel, + error, + keep_state: false, + } + } +} + +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +struct MessageState { + // sequence number of next continuation packet + next_sequence: u8, + // number of bytes of message payload transmitted so far + transmitted: usize, +} + +impl Default for MessageState { + fn default() -> Self { + Self { + next_sequence: 0, + transmitted: PACKET_SIZE - 7, + } + } +} + +impl MessageState { + // update state due to receiving a full new continuation packet + #[must_use] + fn absorb_packet(mut self) -> Self { + self.next_sequence += 1; + self.transmitted += PACKET_SIZE - 5; + self + } +} + +#[derive(Clone, Debug, Eq, PartialEq)] +enum State { + Idle, + + // if request payload data is larger than one packet + Receiving((Request, MessageState)), + + // Processing(Request), + + // // the request message is ready, need to dispatch to authenticator + // Dispatching((Request, Ctap2Request)), + + // waiting for response from authenticator + WaitingOnAuthenticator(Request), + + WaitingToSend(Response), + + Sending((Response, MessageState)), +} + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +#[must_use] +pub enum BufferState { + Idle, + ResponseQueued, + Error(BufferError), +} + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub struct BufferError { + channel: u32, + error: u8, +} + +pub struct Buffer<'pipe, 'interrupt> { + state: State, + interchange: Requester<'pipe>, + interrupt: Option<&'interrupt OptionRefSwap<'interrupt, InterruptFlag>>, + // shared between requests and responses, due to size + buffer: [u8; MESSAGE_SIZE], + // we assign channel IDs one by one, this is the one last assigned + // TODO: move into "app" + last_channel: u32, + // Indicator of implemented commands in INIT response. + implements: u8, + // timestamp that gets used for timing out CID's + last_milliseconds: u32, + // a "read once" indicator if now we're waiting on the application processing + started_processing: bool, + needs_keepalive: bool, + version: Version, +} + +impl<'pipe, 'interrupt> Buffer<'pipe, 'interrupt> { + pub fn new( + interchange: Requester<'pipe>, + initial_milliseconds: u32, + interrupt: Option<&'interrupt OptionRefSwap<'interrupt, InterruptFlag>>, + ) -> Self { + Self { + state: State::Idle, + interchange, + interrupt, + buffer: [0; MESSAGE_SIZE], + last_channel: 0, + // Default to nothing implemented. + implements: 0x80, + last_milliseconds: initial_milliseconds, + started_processing: false, + needs_keepalive: false, + version: Default::default(), + } + } + + pub fn implements(&self) -> u8 { + self.implements + } + + pub fn set_implements(&mut self, implements: u8) { + self.implements = implements; + } + + pub fn set_version(&mut self, version: Version) { + self.version = version; + } + + fn cancel_ongoing_activity(&mut self) { + if matches!(self.state, State::WaitingOnAuthenticator(_)) { + info_now!("Interrupting request"); + if let Some(Some(i)) = self.interrupt.map(|i| i.load(Ordering::Relaxed)) { + info_now!("Loaded some interrupter"); + i.interrupt(); + } + } + } + + pub fn check_timeout(&mut self, milliseconds: u32) -> BufferState { + // At any point the RP application could crash or something, + // so its up to the device to timeout those transactions. + let last = core::mem::replace(&mut self.last_milliseconds, milliseconds); + if let State::Receiving((request, message_state)) = &self.state { + if (milliseconds - last) > 200 { + // If there's a lapse in `check_timeout(...)` getting called (e.g. due to logging), + // this could lead to inaccurate timestamps on requests. So we'll + // just "forgive" requests temporarily if this happens. + debug!( + "lapse in hid check.. {} {} {}", + request.timestamp, milliseconds, last + ); + let mut request = *request; + request.timestamp = milliseconds; + self.state = State::Receiving((request, *message_state)); + BufferState::Idle + } + // compare keeping in mind of possible overflow in timestamp. + else if (milliseconds > request.timestamp && (milliseconds - request.timestamp) > 550) + || (milliseconds < request.timestamp && milliseconds > 550) + { + debug!( + "Channel timeout. {}, {}, {}", + request.timestamp, milliseconds, last + ); + self.send_error(request.error(AuthenticatorError::Timeout)) + } else { + BufferState::Idle + } + } else { + BufferState::Idle + } + } + + #[must_use] + pub fn send_keepalive(&self, is_waiting_for_user_presence: bool) -> Option> { + if let State::WaitingOnAuthenticator(request) = &self.state { + if !self.needs_keepalive { + // let response go out normally in idle loop + info!("cmd does not need keepalive messages"); + None + } else { + info!("keepalive"); + + let response = Response { + channel: request.channel, + command: Command::KeepAlive, + length: 1, + }; + let status = if is_waiting_for_user_presence { + &(KeepaliveStatus::UpNeeded as u8) + } else { + &(KeepaliveStatus::Processing as u8) + }; + Some(Packet::init(response, core::slice::from_ref(status))) + } + } else { + info!("keepalive done"); + None + } + } + + pub fn try_send_packet) -> Result<(), ()>>(&mut self, f: F) { + if let Some(packet) = self.packet_to_send() { + if f(packet).is_ok() { + self.state = packet.next_state(); + } + } + } + + #[must_use] + fn packet_to_send(&self) -> Option> { + match self.state { + State::WaitingToSend(response) => Some(Packet::init(response, &self.buffer)), + State::Sending((response, message_state)) => { + Some(Packet::cont(response, message_state, &self.buffer)) + } + // nothing to send + _ => None, + } + } + + pub fn handle_packet(&mut self, packet: &[u8; 64]) -> BufferState { + match self.handle_packet_impl(packet) { + Ok(Some(response)) => self.send_response(response), + Ok(None) => BufferState::Idle, + Err(error) => self.send_error(error), + } + } + + fn handle_packet_impl(&mut self, packet: &[u8; 64]) -> Result, PipeError> { + info!(">> "); + info!("{}", hex_str!(&packet[..16])); + + // packet is 64 bytes, reading 4 will not panic + let channel = u32::from_be_bytes(packet[..4].try_into().unwrap()); + // info_now!("channel {}", channel); + + let is_initialization = (packet[4] >> 7) != 0; + // info_now!("is_initialization {}", is_initialization); + + if is_initialization { + // case of initialization packet + info!("init"); + + let command_number = packet[4] & !0x80; + // info_now!("command number {}", command_number); + + let command = match Command::try_from(command_number) { + Ok(command) => command, + // `solo ls` crashes here as it uses command 0x86 + Err(_) => { + info!("Received invalid command."); + return Err(PipeError::on_channel( + channel, + AuthenticatorError::InvalidCommand, + )); + } + }; + + // can't actually fail + let length = u16::from_be_bytes(packet[5..][..2].try_into().unwrap()); + + let timestamp = self.last_milliseconds; + let current_request = Request { + channel, + command, + length, + timestamp, + }; + + if !(self.state == State::Idle) { + let request = match self.state { + State::WaitingOnAuthenticator(request) => request, + State::Receiving((request, _message_state)) => request, + _ => { + info_now!("Ignoring transaction as we're already transmitting."); + return Ok(None); + } + }; + if packet[4] == 0x86 { + info_now!("Resyncing!"); + self.cancel_ongoing_activity(); + } else { + return if channel == request.channel { + if command == Command::Cancel { + info_now!("Cancelling"); + self.cancel_ongoing_activity(); + Ok(None) + } else { + info_now!("Expected seq, {:?}", request.command); + Err(request.error(AuthenticatorError::InvalidSeq)) + } + } else { + info_now!("busy."); + Err(current_request.error_now(AuthenticatorError::ChannelBusy)) + }; + } + } + + if length > MESSAGE_SIZE as u16 { + info!("Error message too big."); + return Err(current_request.error_now(AuthenticatorError::InvalidLength)); + } + + if length > PACKET_SIZE as u16 - 7 { + // store received part of payload, + // prepare for continuation packets + self.buffer[..PACKET_SIZE - 7].copy_from_slice(&packet[7..]); + self.state = State::Receiving((current_request, { MessageState::default() })); + // we're done... wait for next packet + Ok(None) + } else { + // request fits in one packet + self.buffer[..length as usize].copy_from_slice(&packet[7..][..length as usize]); + self.dispatch_request(current_request) + } + } else { + // case of continuation packet + match self.state { + State::Receiving((request, message_state)) => { + let sequence = packet[4]; + // info_now!("receiving continuation packet {}", sequence); + if sequence != message_state.next_sequence { + // error handling? + // info_now!("wrong sequence for continuation packet, expected {} received {}", + // message_state.next_sequence, sequence); + info!("Error invalid cont pkt"); + return Err(request.error(AuthenticatorError::InvalidSeq)); + } + if channel != request.channel { + // error handling? + // info_now!("wrong channel for continuation packet, expected {} received {}", + // request.channel, channel); + info!("Ignore invalid channel"); + return Ok(None); + } + + let payload_length = request.length as usize; + if message_state.transmitted + (PACKET_SIZE - 5) < payload_length { + // info_now!("transmitted {} + (PACKET_SIZE - 5) < {}", + // message_state.transmitted, payload_length); + // store received part of payload + self.buffer[message_state.transmitted..][..PACKET_SIZE - 5] + .copy_from_slice(&packet[5..]); + let message_state = message_state.absorb_packet(); + self.state = State::Receiving((request, message_state)); + // info_now!("absorbed packet, awaiting next"); + Ok(None) + } else { + let missing = request.length as usize - message_state.transmitted; + self.buffer[message_state.transmitted..payload_length] + .copy_from_slice(&packet[5..][..missing]); + self.dispatch_request(request) + } + } + _ => { + // unexpected continuation packet + info!("Ignore unexpected cont pkt"); + Ok(None) + } + } + } + } + + fn dispatch_request(&mut self, request: Request) -> Result, PipeError> { + info!("Got request: {:?}", request.command); + match request.command { + Command::Init => {} + _ => { + if request.channel == 0xffffffff { + return Err(request.error(AuthenticatorError::InvalidChannel)); + } + } + } + // dispatch request further + match request.command { + Command::Init => { + // info_now!("command INIT!"); + // info_now!("data: {:?}", &self.buffer[..request.length as usize]); + match request.channel { + 0 => { + // this is an error / reserved number + Err(request.error(AuthenticatorError::InvalidChannel)) + } + + // broadcast channel ID - request for assignment + cid => { + if request.length != 8 { + // error + info!("Invalid length for init. ignore."); + Ok(None) + } else { + self.last_channel += 1; + // info_now!( + // "assigned channel {}", self.last_channel); + let _nonce = &self.buffer[..8]; + let response = Response { + channel: cid, + command: request.command, + length: 17, + }; + + self.buffer[8..12].copy_from_slice(&self.last_channel.to_be_bytes()); + // CTAPHID protocol version + self.buffer[12] = 2; + // major device version number + self.buffer[13] = self.version.major; + // minor device version number + self.buffer[14] = self.version.minor; + // build device version number + self.buffer[15] = self.version.build; + // capabilities flags + // 0x1: implements WINK + // 0x4: implements CBOR + // 0x8: does not implement MSG + // self.buffer[16] = 0x01 | 0x08; + self.buffer[16] = self.implements; + Ok(Some(response)) + } + } + } + } + + Command::Ping => { + let response = Response::from_request_and_size(request, request.length as usize); + Ok(Some(response)) + } + + Command::Cancel => { + info!("CTAPHID_CANCEL"); + self.cancel_ongoing_activity(); + Ok(None) + } + + _ => { + self.needs_keepalive = request.command == Command::Cbor; + if self.interchange.state() == interchange::State::Responded { + info!("dumping stale response"); + self.interchange.take_response(); + } + match self.interchange.request(( + request.command, + heapless::Vec::from_slice(&self.buffer[..request.length as usize]).unwrap(), + )) { + Ok(_) => { + self.state = State::WaitingOnAuthenticator(request); + self.started_processing = true; + Ok(None) + } + Err(_) => { + // busy + info_now!("STATE: {:?}", self.interchange.state()); + info!("can't handle more than one authenticator request at a time."); + Err(request.error_now(AuthenticatorError::ChannelBusy)) + } + } + } + } + } + + #[inline(never)] + pub fn handle_response(&mut self) -> BufferState { + if let State::WaitingOnAuthenticator(request) = self.state { + if let Ok(response) = self.interchange.response() { + match &response.0 { + Err(DispatchError::InvalidCommand) => { + info!("Got waiting reply from authenticator??"); + self.send_error(request.error(AuthenticatorError::InvalidCommand)) + } + Err(DispatchError::InvalidLength) => { + info!("Error, payload needed app command."); + self.send_error(request.error(AuthenticatorError::InvalidLength)) + } + Err(DispatchError::NoResponse) => { + info!("Got waiting noresponse from authenticator??"); + BufferState::Idle + } + + Ok(message) => { + if message.len() > self.buffer.len() { + error!( + "Message is longer than buffer ({} > {})", + message.len(), + self.buffer.len(), + ); + self.send_error(request.error(AuthenticatorError::InvalidLength)) + } else { + info!( + "Got {} bytes response from authenticator, starting send", + message.len() + ); + let response = Response::from_request_and_size(request, message.len()); + self.buffer[..message.len()].copy_from_slice(message); + self.send_response(response) + } + } + } + } else { + BufferState::Idle + } + } else { + BufferState::Idle + } + } + + fn send_error(&mut self, error: PipeError) -> BufferState { + let response = Response::error_on_channel(error.channel); + if error.keep_state { + BufferState::Error(BufferError { + channel: error.channel, + error: error.error as u8, + }) + } else { + self.buffer[0] = error.error as u8; + self.state = State::WaitingToSend(response); + BufferState::ResponseQueued + } + } + + fn send_response(&mut self, response: Response) -> BufferState { + self.state = State::WaitingToSend(response); + BufferState::ResponseQueued + } + + pub fn did_start_processing(&mut self) -> bool { + if self.started_processing { + self.started_processing = false; + true + } else { + false + } + } +} + +#[derive(Clone, Copy, Debug)] +pub struct Packet<'a> { + response: Response, + message_state: Option, + buffer: &'a [u8], +} + +impl<'a> Packet<'a> { + fn init(response: Response, buffer: &'a [u8]) -> Self { + Self { + response, + message_state: None, + buffer, + } + } + + fn cont(response: Response, message_state: MessageState, buffer: &'a [u8]) -> Self { + Self { + response, + message_state: Some(message_state), + buffer, + } + } + + fn has_more(&self) -> bool { + if let Some(message_state) = self.message_state { + let remaining = usize::from(self.response.length) - message_state.transmitted; + remaining > PACKET_SIZE - 5 + } else { + usize::from(self.response.length) > PACKET_SIZE - 7 + } + } + + fn next_state(&self) -> State { + if self.has_more() { + let message_state = self + .message_state + .map(MessageState::absorb_packet) + .unwrap_or_default(); + State::Sending((self.response, message_state)) + } else { + State::Idle + } + } + + pub fn serialize(&self, buffer: &mut [u8; PACKET_SIZE]) { + // buffer must be zeroed + buffer[..4].copy_from_slice(&self.response.channel.to_be_bytes()); + if let Some(message_state) = self.message_state { + buffer[4] = message_state.next_sequence; + let remaining = usize::from(self.response.length) - message_state.transmitted; + let n = remaining.min(PACKET_SIZE - 5); + buffer[5..][..n].copy_from_slice(&self.buffer[message_state.transmitted..][..n]); + } else { + buffer[4] = self.response.command.into_u8() | 0x80; + buffer[5..7].copy_from_slice(&self.response.length.to_be_bytes()); + let n = usize::from(self.response.length).min(PACKET_SIZE - 7); + buffer[7..][..n].copy_from_slice(&self.buffer[..n]); + } + } +} + +impl<'a> From<&'a BufferError> for Packet<'a> { + fn from(error: &'a BufferError) -> Self { + let response = Response::error_on_channel(error.channel); + Self::init(response, core::slice::from_ref(&error.error)) + } +} diff --git a/src/lib.rs b/src/lib.rs index f57ec34..0922edd 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -16,6 +16,7 @@ generate_macros!(); // pub mod authenticator; +pub mod buffer; pub mod class; pub mod constants; pub use class::CtapHid; diff --git a/src/pipe.rs b/src/pipe.rs index a3dc019..b887268 100644 --- a/src/pipe.rs +++ b/src/pipe.rs @@ -12,178 +12,24 @@ receive busy errors). No state is maintained between transactions. */ -use core::sync::atomic::Ordering; -// pub type ContactInterchange = usbd_ccid::types::ApduInterchange; -// pub type ContactlessInterchange = iso14443::types::ApduInterchange; - -use ctaphid_dispatch::command::Command; -use ctaphid_dispatch::types::{Error as DispatchError, Requester}; - -use ctap_types::Error as AuthenticatorError; -use trussed::interrupt::InterruptFlag; - +use ctaphid_dispatch::types::Requester; use ref_swap::OptionRefSwap; -// use serde::Serialize; +use trussed::interrupt::InterruptFlag; use usb_device::{ bus::UsbBus, endpoint::{EndpointAddress, EndpointIn, EndpointOut}, UsbError, - // Result as UsbResult, }; use crate::{ - constants::{ - // 3072 - MESSAGE_SIZE, - // 64 - PACKET_SIZE, - }, - types::KeepaliveStatus, + buffer::{Buffer, BufferState, Packet}, + constants::PACKET_SIZE, Version, }; -/// The actual payload of given length is dealt with separately -#[derive(Copy, Clone, Debug, Eq, PartialEq)] -struct Request { - channel: u32, - command: Command, - length: u16, - timestamp: u32, -} - -impl Request { - fn error(self, error: AuthenticatorError) -> PipeError { - PipeError { - channel: self.channel, - error, - keep_state: false, - } - } - - fn error_now(self, error: AuthenticatorError) -> PipeError { - PipeError { - channel: self.channel, - error, - keep_state: true, - } - } -} - -/// The actual payload of given length is dealt with separately -#[derive(Copy, Clone, Debug, Eq, PartialEq)] -struct Response { - channel: u32, - command: Command, - length: u16, -} - -impl Response { - fn from_request_and_size(request: Request, size: usize) -> Self { - Self { - channel: request.channel, - command: request.command, - length: size as u16, - } - } - - fn error_on_channel(channel: u32) -> Self { - Self { - channel, - command: Command::Error, - length: 1, - } - } -} - -struct PipeError { - channel: u32, - error: AuthenticatorError, - keep_state: bool, -} - -impl PipeError { - fn on_channel(channel: u32, error: AuthenticatorError) -> Self { - Self { - channel, - error, - keep_state: false, - } - } -} - -#[derive(Copy, Clone, Debug, Eq, PartialEq)] -struct MessageState { - // sequence number of next continuation packet - next_sequence: u8, - // number of bytes of message payload transmitted so far - transmitted: usize, -} - -impl Default for MessageState { - fn default() -> Self { - Self { - next_sequence: 0, - transmitted: PACKET_SIZE - 7, - } - } -} - -impl MessageState { - // update state due to receiving a full new continuation packet - #[must_use] - pub fn absorb_packet(mut self) -> Self { - self.next_sequence += 1; - self.transmitted += PACKET_SIZE - 5; - self - } -} - -#[derive(Clone, Debug, Eq, PartialEq)] -enum State { - Idle, - - // if request payload data is larger than one packet - Receiving((Request, MessageState)), - - // Processing(Request), - - // // the request message is ready, need to dispatch to authenticator - // Dispatching((Request, Ctap2Request)), - - // waiting for response from authenticator - WaitingOnAuthenticator(Request), - - WaitingToSend(Response), - - Sending((Response, MessageState)), -} - pub struct Pipe<'alloc, 'pipe, 'interrupt, Bus: UsbBus> { endpoints: Endpoints<'alloc, Bus>, - state: State, - - interchange: Requester<'pipe>, - interrupt: Option<&'interrupt OptionRefSwap<'interrupt, InterruptFlag>>, - - // shared between requests and responses, due to size - buffer: [u8; MESSAGE_SIZE], - - // we assign channel IDs one by one, this is the one last assigned - // TODO: move into "app" - last_channel: u32, - - // Indicator of implemented commands in INIT response. - implements: u8, - - // timestamp that gets used for timing out CID's - last_milliseconds: u32, - - // a "read once" indicator if now we're waiting on the application processing - started_processing: bool, - - needs_keepalive: bool, - - version: Version, + buffer: Buffer<'pipe, 'interrupt>, } impl<'alloc, 'pipe, 'interrupt, Bus: UsbBus> Pipe<'alloc, 'pipe, 'interrupt, Bus> { @@ -195,25 +41,9 @@ impl<'alloc, 'pipe, 'interrupt, Bus: UsbBus> Pipe<'alloc, 'pipe, 'interrupt, Bus ) -> Self { Self { endpoints: Endpoints::new(read_endpoint, write_endpoint), - state: State::Idle, - interchange, - buffer: [0u8; MESSAGE_SIZE], - last_channel: 0, - interrupt: None, - // Default to nothing implemented. - implements: 0x80, - last_milliseconds: initial_milliseconds, - started_processing: false, - needs_keepalive: false, - version: Default::default(), + buffer: Buffer::new(interchange, initial_milliseconds, None), } } -} - -impl<'alloc, 'pipe, 'interrupt, Bus: UsbBus> Pipe<'alloc, 'pipe, 'interrupt, Bus> { - // pub fn borrow_mut_authenticator(&mut self) -> &mut Authenticator { - // &mut self.authenticator - // } pub fn with_interrupt( read_endpoint: EndpointOut<'alloc, Bus>, @@ -224,30 +54,20 @@ impl<'alloc, 'pipe, 'interrupt, Bus: UsbBus> Pipe<'alloc, 'pipe, 'interrupt, Bus ) -> Self { Self { endpoints: Endpoints::new(read_endpoint, write_endpoint), - state: State::Idle, - interchange, - buffer: [0u8; MESSAGE_SIZE], - last_channel: 0, - interrupt, - // Default to nothing implemented. - implements: 0x80, - last_milliseconds: initial_milliseconds, - started_processing: false, - needs_keepalive: false, - version: Default::default(), + buffer: Buffer::new(interchange, initial_milliseconds, interrupt), } } pub fn implements(&self) -> u8 { - self.implements + self.buffer.implements() } pub fn set_implements(&mut self, implements: u8) { - self.implements = implements; + self.buffer.set_implements(implements); } pub fn set_version(&mut self, version: Version) { - self.version = version; + self.buffer.set_version(version); } pub fn read_address(&self) -> EndpointAddress { @@ -268,16 +88,6 @@ impl<'alloc, 'pipe, 'interrupt, Bus: UsbBus> Pipe<'alloc, 'pipe, 'interrupt, Bus &self.endpoints.write } - fn cancel_ongoing_activity(&mut self) { - if matches!(self.state, State::WaitingOnAuthenticator(_)) { - info_now!("Interrupting request"); - if let Some(Some(i)) = self.interrupt.map(|i| i.load(Ordering::Relaxed)) { - info_now!("Loaded some interrupter"); - i.interrupt(); - } - } - } - /// This method handles CTAP packets (64 bytes), until it has assembled /// a CTAP message, with which it then calls `dispatch_message`. /// @@ -286,464 +96,50 @@ impl<'alloc, 'pipe, 'interrupt, Bus: UsbBus> Pipe<'alloc, 'pipe, 'interrupt, Bus // info_now!("got a packet!"); let mut packet = [0u8; PACKET_SIZE]; if self.endpoints.read(&mut packet).is_ok() { - match self.handle_packet(&packet) { - Ok(Some(response)) => self.start_sending(response), - Ok(None) => (), - Err(error) => self.send_error(error), - } - } - } - - fn handle_packet(&mut self, packet: &[u8; 64]) -> Result, PipeError> { - info!(">> "); - info!("{}", hex_str!(&packet[..16])); - - // packet is 64 bytes, reading 4 will not panic - let channel = u32::from_be_bytes(packet[..4].try_into().unwrap()); - // info_now!("channel {}", channel); - - let is_initialization = (packet[4] >> 7) != 0; - // info_now!("is_initialization {}", is_initialization); - - if is_initialization { - // case of initialization packet - info!("init"); - - let command_number = packet[4] & !0x80; - // info_now!("command number {}", command_number); - - let command = match Command::try_from(command_number) { - Ok(command) => command, - // `solo ls` crashes here as it uses command 0x86 - Err(_) => { - info!("Received invalid command."); - return Err(PipeError::on_channel( - channel, - AuthenticatorError::InvalidCommand, - )); - } - }; - - // can't actually fail - let length = u16::from_be_bytes(packet[5..][..2].try_into().unwrap()); - - let timestamp = self.last_milliseconds; - let current_request = Request { - channel, - command, - length, - timestamp, - }; - - if !(self.state == State::Idle) { - let request = match self.state { - State::WaitingOnAuthenticator(request) => request, - State::Receiving((request, _message_state)) => request, - _ => { - info_now!("Ignoring transaction as we're already transmitting."); - return Ok(None); - } - }; - if packet[4] == 0x86 { - info_now!("Resyncing!"); - self.cancel_ongoing_activity(); - } else { - return if channel == request.channel { - if command == Command::Cancel { - info_now!("Cancelling"); - self.cancel_ongoing_activity(); - Ok(None) - } else { - info_now!("Expected seq, {:?}", request.command); - Err(request.error(AuthenticatorError::InvalidSeq)) - } - } else { - info_now!("busy."); - Err(current_request.error_now(AuthenticatorError::ChannelBusy)) - }; - } - } - - if length > MESSAGE_SIZE as u16 { - info!("Error message too big."); - return Err(current_request.error_now(AuthenticatorError::InvalidLength)); - } - - if length > PACKET_SIZE as u16 - 7 { - // store received part of payload, - // prepare for continuation packets - self.buffer[..PACKET_SIZE - 7].copy_from_slice(&packet[7..]); - self.state = State::Receiving((current_request, { MessageState::default() })); - // we're done... wait for next packet - Ok(None) - } else { - // request fits in one packet - self.buffer[..length as usize].copy_from_slice(&packet[7..][..length as usize]); - self.dispatch_request(current_request) - } - } else { - // case of continuation packet - match self.state { - State::Receiving((request, message_state)) => { - let sequence = packet[4]; - // info_now!("receiving continuation packet {}", sequence); - if sequence != message_state.next_sequence { - // error handling? - // info_now!("wrong sequence for continuation packet, expected {} received {}", - // message_state.next_sequence, sequence); - info!("Error invalid cont pkt"); - return Err(request.error(AuthenticatorError::InvalidSeq)); - } - if channel != request.channel { - // error handling? - // info_now!("wrong channel for continuation packet, expected {} received {}", - // request.channel, channel); - info!("Ignore invalid channel"); - return Ok(None); - } - - let payload_length = request.length as usize; - if message_state.transmitted + (PACKET_SIZE - 5) < payload_length { - // info_now!("transmitted {} + (PACKET_SIZE - 5) < {}", - // message_state.transmitted, payload_length); - // store received part of payload - self.buffer[message_state.transmitted..][..PACKET_SIZE - 5] - .copy_from_slice(&packet[5..]); - let message_state = message_state.absorb_packet(); - self.state = State::Receiving((request, message_state)); - // info_now!("absorbed packet, awaiting next"); - Ok(None) - } else { - let missing = request.length as usize - message_state.transmitted; - self.buffer[message_state.transmitted..payload_length] - .copy_from_slice(&packet[5..][..missing]); - self.dispatch_request(request) - } - } - _ => { - // unexpected continuation packet - info!("Ignore unexpected cont pkt"); - Ok(None) - } - } + let state = self.buffer.handle_packet(&packet); + self.handle(state); } } pub fn check_timeout(&mut self, milliseconds: u32) { - // At any point the RP application could crash or something, - // so its up to the device to timeout those transactions. - let last = self.last_milliseconds; - self.last_milliseconds = milliseconds; - if let State::Receiving((request, _message_state)) = &mut self.state { - if (milliseconds - last) > 200 { - // If there's a lapse in `check_timeout(...)` getting called (e.g. due to logging), - // this could lead to inaccurate timestamps on requests. So we'll - // just "forgive" requests temporarily if this happens. - debug!( - "lapse in hid check.. {} {} {}", - request.timestamp, milliseconds, last - ); - request.timestamp = milliseconds; - } - // compare keeping in mind of possible overflow in timestamp. - else if (milliseconds > request.timestamp && (milliseconds - request.timestamp) > 550) - || (milliseconds < request.timestamp && milliseconds > 550) - { - debug!( - "Channel timeout. {}, {}, {}", - request.timestamp, milliseconds, last - ); - let req = *request; - self.send_error(req.error(AuthenticatorError::Timeout)); - } - } - } - - fn dispatch_request(&mut self, request: Request) -> Result, PipeError> { - info!("Got request: {:?}", request.command); - match request.command { - Command::Init => {} - _ => { - if request.channel == 0xffffffff { - return Err(request.error(AuthenticatorError::InvalidChannel)); - } - } - } - // dispatch request further - match request.command { - Command::Init => { - // info_now!("command INIT!"); - // info_now!("data: {:?}", &self.buffer[..request.length as usize]); - match request.channel { - 0 => { - // this is an error / reserved number - Err(request.error(AuthenticatorError::InvalidChannel)) - } - - // broadcast channel ID - request for assignment - cid => { - if request.length != 8 { - // error - info!("Invalid length for init. ignore."); - Ok(None) - } else { - self.last_channel += 1; - // info_now!( - // "assigned channel {}", self.last_channel); - let _nonce = &self.buffer[..8]; - let response = Response { - channel: cid, - command: request.command, - length: 17, - }; - - self.buffer[8..12].copy_from_slice(&self.last_channel.to_be_bytes()); - // CTAPHID protocol version - self.buffer[12] = 2; - // major device version number - self.buffer[13] = self.version.major; - // minor device version number - self.buffer[14] = self.version.minor; - // build device version number - self.buffer[15] = self.version.build; - // capabilities flags - // 0x1: implements WINK - // 0x4: implements CBOR - // 0x8: does not implement MSG - // self.buffer[16] = 0x01 | 0x08; - self.buffer[16] = self.implements; - Ok(Some(response)) - } - } - } - } - - Command::Ping => { - let response = Response::from_request_and_size(request, request.length as usize); - Ok(Some(response)) - } - - Command::Cancel => { - info!("CTAPHID_CANCEL"); - self.cancel_ongoing_activity(); - Ok(None) - } - - _ => { - self.needs_keepalive = request.command == Command::Cbor; - if self.interchange.state() == interchange::State::Responded { - info!("dumping stale response"); - self.interchange.take_response(); - } - match self.interchange.request(( - request.command, - heapless::Vec::from_slice(&self.buffer[..request.length as usize]).unwrap(), - )) { - Ok(_) => { - self.state = State::WaitingOnAuthenticator(request); - self.started_processing = true; - Ok(None) - } - Err(_) => { - // busy - info_now!("STATE: {:?}", self.interchange.state()); - info!("can't handle more than one authenticator request at a time."); - Err(request.error_now(AuthenticatorError::ChannelBusy)) - } - } - } - } + let state = self.buffer.check_timeout(milliseconds); + self.handle(state); } pub fn did_start_processing(&mut self) -> bool { - if self.started_processing { - self.started_processing = false; - true - } else { - false - } + self.buffer.did_start_processing() } pub fn send_keepalive(&mut self, is_waiting_for_user_presence: bool) -> bool { - if let State::WaitingOnAuthenticator(request) = &self.state { - if !self.needs_keepalive { - // let response go out normally in idle loop - info!("cmd does not need keepalive messages"); - false - } else { - info!("keepalive"); - - let response = Response { - channel: request.channel, - command: Command::KeepAlive, - length: 1, - }; - let status = if is_waiting_for_user_presence { - KeepaliveStatus::UpNeeded - } else { - KeepaliveStatus::Processing - }; - self.endpoints - .write(Packet::init(response, &[status as u8])) - .ok(); - - true - } + if let Some(packet) = self.buffer.send_keepalive(is_waiting_for_user_presence) { + self.endpoints.write(packet).ok(); + true } else { - info!("keepalive done"); false } } pub fn handle_and_write_response(&mut self) { - match self.handle_response() { - Ok(Some(response)) => self.start_sending(response), - Ok(None) => (), - Err(error) => self.send_error(error), - } + let state = self.buffer.handle_response(); + self.handle(state); } - #[inline(never)] - fn handle_response(&mut self) -> Result, PipeError> { - if let State::WaitingOnAuthenticator(request) = self.state { - if let Ok(response) = self.interchange.response() { - match &response.0 { - Err(DispatchError::InvalidCommand) => { - info!("Got waiting reply from authenticator??"); - Err(request.error(AuthenticatorError::InvalidCommand)) - } - Err(DispatchError::InvalidLength) => { - info!("Error, payload needed app command."); - Err(request.error(AuthenticatorError::InvalidLength)) - } - Err(DispatchError::NoResponse) => { - info!("Got waiting noresponse from authenticator??"); - Ok(None) - } - - Ok(message) => { - if message.len() > self.buffer.len() { - error!( - "Message is longer than buffer ({} > {})", - message.len(), - self.buffer.len(), - ); - Err(request.error(AuthenticatorError::InvalidLength)) - } else { - info!( - "Got {} bytes response from authenticator, starting send", - message.len() - ); - let response = Response::from_request_and_size(request, message.len()); - self.buffer[..message.len()].copy_from_slice(message); - Ok(Some(response)) - } - } - } - } else { - Ok(None) + fn handle(&mut self, state: BufferState) { + match state { + BufferState::Idle => (), + BufferState::ResponseQueued => self.maybe_write_packet(), + BufferState::Error(error) => { + // TODO: should we block? + self.endpoints.write(Packet::from(&error)).ok(); } - } else { - Ok(None) - } - } - - fn start_sending(&mut self, response: Response) { - self.state = State::WaitingToSend(response); - self.maybe_write_packet(); - } - - fn send_error(&mut self, error: PipeError) { - let response = Response::error_on_channel(error.channel); - if error.keep_state { - // TODO: should we block? - self.endpoints - .write(Packet::init(response, &[error.error as u8])) - .ok(); - } else { - self.buffer[0] = error.error as u8; - self.start_sending(response); } } // called from poll, and when a packet has been sent #[inline(never)] pub fn maybe_write_packet(&mut self) { - let packet = match self.state { - State::WaitingToSend(response) => Packet::init(response, &self.buffer), - State::Sending((response, message_state)) => { - Packet::cont(response, message_state, &self.buffer) - } - // nothing to send - _ => { - return; - } - }; - if self.endpoints.write(packet).is_ok() { - self.state = packet.next_state(); - } - } -} - -#[derive(Clone, Copy, Debug)] -struct Packet<'a> { - response: Response, - message_state: Option, - buffer: &'a [u8], -} - -impl<'a> Packet<'a> { - fn init(response: Response, buffer: &'a [u8]) -> Self { - Self { - response, - message_state: None, - buffer, - } - } - - fn cont(response: Response, message_state: MessageState, buffer: &'a [u8]) -> Self { - Self { - response, - message_state: Some(message_state), - buffer, - } - } - - fn has_more(&self) -> bool { - if let Some(message_state) = self.message_state { - let remaining = usize::from(self.response.length) - message_state.transmitted; - remaining > PACKET_SIZE - 5 - } else { - usize::from(self.response.length) > PACKET_SIZE - 7 - } - } - - fn next_state(&self) -> State { - if self.has_more() { - let message_state = self - .message_state - .map(MessageState::absorb_packet) - .unwrap_or_default(); - State::Sending((self.response, message_state)) - } else { - State::Idle - } - } - - fn serialize(&self, buffer: &mut [u8; PACKET_SIZE]) { - // buffer must be zeroed - buffer[..4].copy_from_slice(&self.response.channel.to_be_bytes()); - if let Some(message_state) = self.message_state { - buffer[4] = message_state.next_sequence; - let remaining = usize::from(self.response.length) - message_state.transmitted; - let n = remaining.min(PACKET_SIZE - 5); - buffer[5..][..n].copy_from_slice(&self.buffer[message_state.transmitted..][..n]); - } else { - buffer[4] = self.response.command.into_u8() | 0x80; - buffer[5..7].copy_from_slice(&self.response.length.to_be_bytes()); - let n = usize::from(self.response.length).min(PACKET_SIZE - 7); - buffer[7..][..n].copy_from_slice(&self.buffer[..n]); - } + self.buffer + .try_send_packet(|packet| self.endpoints.write(packet)); } }