Skip to content

Commit

Permalink
Using existing TaskQueue, refactoring it to support non-throwing clos…
Browse files Browse the repository at this point in the history
…ures.
  • Loading branch information
sebaland committed Jan 9, 2024
1 parent 49e5692 commit 5cfcce1
Show file tree
Hide file tree
Showing 8 changed files with 79 additions and 112 deletions.
46 changes: 38 additions & 8 deletions Amplify/Core/Support/TaskQueue.swift
Original file line number Diff line number Diff line change
Expand Up @@ -8,34 +8,64 @@
import Foundation

/// A helper for executing asynchronous work serially.
public actor TaskQueue<Success> {
private var previousTask: Task<Success, Error>?

public actor TaskQueue<Success, Failure> where Failure: Error {
private var previousTask: Task<Success, Failure>?

Check warning on line 13 in Amplify/Core/Support/TaskQueue.swift

View workflow job for this annotation

GitHub Actions / run-swiftlint

Lines should not have trailing whitespace (trailing_whitespace)
public init() {}
}

public extension TaskQueue where Failure == any Error {
/// Serializes asynchronous requests made from an async context
///
/// Given an invocation like
/// ```swift
/// let tq = TaskQueue<Int>()
/// let tq = TaskQueue<Int, Error>()
/// let v1 = try await tq.sync { try await doAsync1() }
/// let v2 = try await tq.sync { try await doAsync2() }
/// let v3 = try await tq.sync { try await doAsync3() }
/// ```
/// TaskQueue serializes this work so that `doAsync1` is performed before `doAsync2`,
/// which is performed before `doAsync3`.
public func sync(block: @Sendable @escaping () async throws -> Success) async throws -> Success {
let currentTask: Task<Success, Error> = Task { [previousTask] in
func sync(block: @Sendable @escaping () async throws -> Success) async throws -> Success {
let currentTask: Task<Success, Failure> = Task { [previousTask] in
_ = await previousTask?.result
return try await block()
}
previousTask = currentTask
return try await currentTask.value
}

public nonisolated func async(block: @Sendable @escaping () async throws -> Success) rethrows {

Check warning on line 37 in Amplify/Core/Support/TaskQueue.swift

View workflow job for this annotation

GitHub Actions / run-swiftlint

Lines should not have trailing whitespace (trailing_whitespace)
nonisolated func async(block: @Sendable @escaping () async throws -> Success) rethrows {
Task {
try await sync(block: block)
}
}
}

public extension TaskQueue where Failure == Never {
/// Serializes asynchronous requests made from an async context
///
/// Given an invocation like
/// ```swift
/// let tq = TaskQueue<Int, Never>()
/// let v1 = await tq.sync { await doAsync1() }
/// let v2 = await tq.sync { await doAsync2() }
/// let v3 = await tq.sync { await doAsync3() }
/// ```
/// TaskQueue serializes this work so that `doAsync1` is performed before `doAsync2`,
/// which is performed before `doAsync3`.
func sync(block: @Sendable @escaping () async -> Success) async -> Success {
let currentTask: Task<Success, Failure> = Task { [previousTask] in
_ = await previousTask?.result
return await block()
}
previousTask = currentTask
return await currentTask.value
}

Check warning on line 65 in Amplify/Core/Support/TaskQueue.swift

View workflow job for this annotation

GitHub Actions / run-swiftlint

Lines should not have trailing whitespace (trailing_whitespace)
nonisolated func async(block: @Sendable @escaping () async -> Success) {
Task {
await sync(block: block)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public final class AWSCognitoAuthPlugin: AWSCognitoAuthPluginBehavior {

var analyticsHandler: UserPoolAnalyticsBehavior!

var taskQueue: TaskQueue<Any>!
var taskQueue: TaskQueue<Any, Error>!

var httpClientEngineProxy: HttpClientEngineProxy?

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class CredentialStoreOperationClient: CredentialStoreStateBehavior {

// Task queue is being used to manage CRUD operations to the credential store synchronously
// This will help us keeping the CRUD methods atomic
private let taskQueue = TaskQueue<CredentialStoreData?>()
private let taskQueue = TaskQueue<CredentialStoreData?, Error>()

init(credentialStoreStateMachine: CredentialStoreStateMachine) {
self.credentialStoreStateMachine = credentialStoreStateMachine
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class RemoteSyncEngine: RemoteSyncEngineBehavior {
}

/// Synchronizes startup operations
let taskQueue = TaskQueue<Void>()
let taskQueue = TaskQueue<Void, Error>()

// Assigned at `setUpCloudSubscriptions`
var reconciliationQueue: IncomingEventReconciliationQueue?
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,41 +11,56 @@ import Foundation
public struct PinpointSession: Codable {
private enum State: Codable {
case active
case paused
case stopped
case paused(date: Date)
case stopped(date: Date)
}
typealias SessionId = String

let sessionId: SessionId
let startTime: Date
private(set) var stopTime: Date?
var stopTime: Date? {
switch state {
case .active:
return nil
case .paused(let stopTime),
.stopped(let stopTime):
return stopTime
}
}

private var state: State = .active

init(appId: String,
uniqueId: String) {
sessionId = Self.generateSessionId(appId: appId,
uniqueId: uniqueId)
startTime = Date()
stopTime = nil
}

init(sessionId: SessionId,
startTime: Date,
stopTime: Date?) {
self.sessionId = sessionId
self.startTime = startTime
self.stopTime = stopTime
if stopTime != nil {
state = .stopped
if let stopTime {
self.state = .stopped(date: stopTime)
}
}

var isPaused: Bool {
return stopTime != nil && state == .paused
if case .paused = state {
return true
}

return false
}

var isStopped: Bool {
return stopTime != nil && state == .stopped
if case .stopped = state {
return true
}

return false
}

var duration: Date.Millisecond? {
Expand All @@ -56,18 +71,15 @@ public struct PinpointSession: Codable {

mutating func stop() {
guard !isStopped else { return }
stopTime = stopTime ?? Date()
state = .stopped
state = .stopped(date: stopTime ?? Date())
}

mutating func pause() {
guard !isPaused else { return }
stopTime = Date()
state = .paused
state = .paused(date: Date())
}

mutating func resume() {
stopTime = nil
state = .active
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
//

import Amplify
import AWSPluginsCore
import Foundation

@_spi(InternalAWSPinpoint)
Expand All @@ -14,7 +15,6 @@ public protocol SessionClientBehaviour: AnyObject {
var analyticsClient: AnalyticsClientBehaviour? { get set }

func startPinpointSession()
func validateOrRetrieveSession(_ session: PinpointSession?) -> PinpointSession
func startTrackingSessions(backgroundTimeout: TimeInterval)
}

Expand All @@ -34,7 +34,7 @@ class SessionClient: SessionClientBehaviour {
private let configuration: SessionClientConfiguration
private let sessionClientQueue = DispatchQueue(label: Constants.queue,
attributes: .concurrent)
private let analyticsTaskQueue = TaskQueue()
private let analyticsTaskQueue = TaskQueue<Void, Never>()
private let userDefaults: UserDefaultsBehaviour
private var sessionBackgroundTimeout: TimeInterval = .zero

Expand Down Expand Up @@ -82,28 +82,14 @@ class SessionClient: SessionClientBehaviour {
sessionBackgroundTimeout = backgroundTimeout
activityTracker.backgroundTrackingTimeout = backgroundTimeout
activityTracker.beginActivityTracking { [weak self] newState in
guard let self = self else { return }
guard let self else { return }
self.log.verbose("New state received: \(newState)")
self.sessionClientQueue.sync(flags: .barrier) {
self.respond(to: newState)
}
}
}

func validateOrRetrieveSession(_ session: PinpointSession?) -> PinpointSession {
if let session = session, !session.sessionId.isEmpty {
return session
}

if let storedSession = Self.retrieveStoredSession(from: userDefaults, using: archiver) {
return storedSession
}

return PinpointSession(sessionId: PinpointSession.Constants.defaultSessionId,
startTime: Date(),
stopTime: Date())
}

private static func retrieveStoredSession(from userDefaults: UserDefaultsBehaviour,
using archiver: AmplifyArchiverBehaviour) -> PinpointSession? {
guard let sessionData = userDefaults.data(forKey: Constants.sessionKey),
Expand All @@ -122,8 +108,8 @@ class SessionClient: SessionClientBehaviour {
log.info("Session Started.")

// Update Endpoint and record Session Start event
analyticsTaskQueue.task { [weak self] in
guard let self = self else { return }
analyticsTaskQueue.async { [weak self] in
guard let self else { return }
try? await self.endpointClient.updateEndpointProfile()
self.log.verbose("Firing Session Event: Start")
await self.record(eventType: Constants.Events.start)
Expand All @@ -144,8 +130,8 @@ class SessionClient: SessionClientBehaviour {
session.pause()
saveSession()
log.info("Session Paused.")
analyticsTaskQueue.task { [weak self] in
guard let self = self else { return }
analyticsTaskQueue.async { [weak self] in
guard let self else { return }
self.log.verbose("Firing Session Event: Pause")
await self.record(eventType: Constants.Events.pause)
}
Expand Down Expand Up @@ -174,8 +160,8 @@ class SessionClient: SessionClientBehaviour {
session.resume()
saveSession()
log.info("Session Resumed.")
analyticsTaskQueue.task { [weak self] in
guard let self = self else { return }
analyticsTaskQueue.async { [weak self] in
guard let self else { return }
self.log.verbose("Firing Session Event: Resume")
await self.record(eventType: Constants.Events.resume)
}
Expand All @@ -189,7 +175,7 @@ class SessionClient: SessionClientBehaviour {
}
session.stop()
log.info("Session Stopped.")
analyticsTaskQueue.task { [weak self, session] in
analyticsTaskQueue.async { [weak self, session] in
guard let self = self,
let analyticsClient = self.analyticsClient else {
return
Expand Down Expand Up @@ -236,7 +222,7 @@ class SessionClient: SessionClientBehaviour {
case .runningInBackground(let isStale):
if isStale {
endSession()
analyticsTaskQueue.task { [weak self] in
analyticsTaskQueue.async { [weak self] in
_ = try? await self?.analyticsClient?.submitEvents()
}
} else {
Expand Down Expand Up @@ -278,23 +264,3 @@ extension SessionClient {
extension PinpointSession {
static var none = PinpointSession(sessionId: "InvalidId", startTime: Date(), stopTime: nil)
}

/// This actor allows to queue async operations to only run one at a time.
private actor TaskQueue {
private var currentTask: Task<Void, Never>?

nonisolated func task(_ closure: @escaping () async -> ()) {
Task {
await addToQueue(closure)
}
}

private func addToQueue(_ closure: @escaping () async -> ()) async {
let newTask = Task { [currentTask] in
await currentTask?.value
await closure()
}
currentTask = newTask
await newTask.value
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -114,47 +114,6 @@ class SessionClientTests: XCTestCase {
XCTAssertEqual(userDefaults.saveCount, 0)
}

func testValidateSession_withValidSession_andStoredSession_shouldReturnValidSession() async {
storeSession()
await resetCounters()
let session = PinpointSession(sessionId: "valid", startTime: Date(), stopTime: nil)
let retrievedSession = client.validateOrRetrieveSession(session)

XCTAssertEqual(userDefaults.dataForKeyCount, 0)
XCTAssertEqual(archiver.decodeCount, 0)
XCTAssertEqual(retrievedSession.sessionId, "valid")
}

func testValidateSession_withInvalidSession_andStoredSession_shouldReturnStoredSession() async {
storeSession()
await resetCounters()
let session = PinpointSession(sessionId: "", startTime: Date(), stopTime: nil)
let retrievedSession = client.validateOrRetrieveSession(session)

XCTAssertEqual(userDefaults.dataForKeyCount, 1)
XCTAssertEqual(archiver.decodeCount, 1)
XCTAssertEqual(retrievedSession.sessionId, "stored")
}

func testValidateSession_withInvalidSession_andWithoutStoredSession_shouldCreateDefaultSession() async {
await resetCounters()
let session = PinpointSession(sessionId: "", startTime: Date(), stopTime: nil)
let retrievedSession = client.validateOrRetrieveSession(session)

XCTAssertEqual(userDefaults.dataForKeyCount, 1)
XCTAssertEqual(archiver.decodeCount, 0)
XCTAssertEqual(retrievedSession.sessionId, PinpointSession.Constants.defaultSessionId)
}

func testValidateSession_withNilSession_andWithoutStoredSession_shouldCreateDefaultSession() async {
await resetCounters()
let retrievedSession = client.validateOrRetrieveSession(nil)

XCTAssertEqual(userDefaults.dataForKeyCount, 1)
XCTAssertEqual(archiver.decodeCount, 0)
XCTAssertEqual(retrievedSession.sessionId, PinpointSession.Constants.defaultSessionId)
}

func testStartPinpointSession_shouldRecordStartEvent() async {
await resetCounters()
let expectationStartSession = expectation(description: "Start event for new session")
Expand Down
2 changes: 1 addition & 1 deletion AmplifyTests/CoreTests/AmplifyTaskQueueTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ final class AmplifyTaskQueueTests: XCTestCase {
let expectation2 = expectation(description: "expectation2")
let expectation3 = expectation(description: "expectation3")

let taskQueue = TaskQueue<Void>()
let taskQueue = TaskQueue<Void, Error>()
try await taskQueue.sync {
try await Task.sleep(nanoseconds: 1)
expectation1.fulfill()
Expand Down

0 comments on commit 5cfcce1

Please sign in to comment.