Skip to content

Add a helper function for avoiding certain race conditions in Win32 API calls #476

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Sources/SWBUtil/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ add_library(SWBUtil
VFS.swift
WaitCondition.swift
WeakRef.swift
Win32.swift
Win32Error.swift
XCBuildDataArchive.swift
Xcode.swift)
Expand Down
29 changes: 0 additions & 29 deletions Sources/SWBUtil/Library.swift
Original file line number Diff line number Diff line change
Expand Up @@ -117,32 +117,3 @@ public struct LibraryHandle: @unchecked Sendable {
self.rawValue = rawValue
}
}

#if os(Windows)
@_spi(Testing) public func SWB_GetModuleFileNameW(_ hModule: HMODULE?) throws -> String {
#if DEBUG
var bufferCount = Int(1) // force looping
#else
var bufferCount = Int(MAX_PATH)
#endif
while true {
if let result = try withUnsafeTemporaryAllocation(of: WCHAR.self, capacity: bufferCount, { buffer in
switch (GetModuleFileNameW(hModule, buffer.baseAddress!, DWORD(buffer.count)), GetLastError()) {
case (1..<DWORD(bufferCount), DWORD(ERROR_SUCCESS)):
guard let result = String.decodeCString(buffer.baseAddress!, as: UTF16.self)?.result else {
throw Win32Error(DWORD(ERROR_ILLEGAL_CHARACTER))
}
return result
case (DWORD(bufferCount), DWORD(ERROR_INSUFFICIENT_BUFFER)):
bufferCount += Int(MAX_PATH)
return nil
case (_, let errorCode):
throw Win32Error(errorCode)
}
}) {
return result
}
}
preconditionFailure("unreachable")
}
#endif
20 changes: 4 additions & 16 deletions Sources/SWBUtil/POSIX.swift
Original file line number Diff line number Diff line change
Expand Up @@ -26,22 +26,10 @@ public enum POSIX: Sendable {
public static func getenv(_ name: String) throws -> String? {
#if os(Windows)
try name.withCString(encodedAs: CInterop.PlatformUnicodeEncoding.self) { wName in
let dwLength: DWORD = GetEnvironmentVariableW(wName, nil, 0)
if dwLength == 0 {
if GetLastError() == ERROR_ENVVAR_NOT_FOUND {
return nil
}
throw POSIXError(errno, context: "GetEnvironmentVariableW", name)
}
return try withUnsafeTemporaryAllocation(of: WCHAR.self, capacity: Int(dwLength)) {
switch GetEnvironmentVariableW(wName, $0.baseAddress!, DWORD($0.count)) {
case 1..<dwLength:
return String(decodingCString: $0.baseAddress!, as: CInterop.PlatformUnicodeEncoding.self)
case 0 where GetLastError() == ERROR_ENVVAR_NOT_FOUND:
return nil
default:
throw POSIXError(errno, context: "GetEnvironmentVariableW", name)
}
do {
return try SWB_GetEnvironmentVariableW(wName)
} catch let error as Win32Error where error.error == ERROR_ENVVAR_NOT_FOUND {
return nil
}
}
#else
Expand Down
66 changes: 66 additions & 0 deletions Sources/SWBUtil/Win32.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the Swift open source project
//
// Copyright (c) 2025 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See http://swift.org/LICENSE.txt for license information
// See http://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//

#if os(Windows)
public import WinSDK

/// Calls a Win32 API function that fills a (potentially long path) null-terminated string buffer by continually attempting to allocate more memory up until the true max path is reached.
/// This is especially useful for protecting against race conditions like with GetCurrentDirectoryW where the measured length may no longer be valid on subsequent calls.
/// - parameter initialSize: Initial size of the buffer (including the null terminator) to allocate to hold the returned string.
/// - parameter maxSize: Maximum size of the buffer (including the null terminator) to allocate to hold the returned string.
/// - parameter body: Closure to call the Win32 API function to populate the provided buffer.
/// Should return the number of UTF-16 code units (not including the null terminator) copied, 0 to indicate an error.
/// If the buffer is not of sufficient size, should return a value greater than or equal to the size of the buffer.
private func FillNullTerminatedWideStringBuffer(initialSize: DWORD, maxSize: DWORD, _ body: (UnsafeMutableBufferPointer<WCHAR>) throws -> DWORD) throws -> String {
var bufferCount = max(1, min(initialSize, maxSize))
while bufferCount <= maxSize {
if let result = try withUnsafeTemporaryAllocation(of: WCHAR.self, capacity: Int(bufferCount), { buffer in
let count = try body(buffer)
switch count {
case 0:
throw Win32Error(GetLastError())
case 1..<DWORD(buffer.count):
let result = String(decodingCString: buffer.baseAddress!, as: UTF16.self)
assert(result.utf16.count == count, "Parsed UTF-16 count \(result.utf16.count) != reported UTF-16 count \(count)")
return result
default:
bufferCount *= 2
return nil
}
}) {
return result
}
}
throw Win32Error(DWORD(ERROR_INSUFFICIENT_BUFFER))
}

private let maxPathLength = DWORD(Int16.max) // https://learn.microsoft.com/en-us/windows/win32/fileio/maximum-file-path-limitation
private let maxEnvVarLength = DWORD(Int16.max) // https://devblogs.microsoft.com/oldnewthing/20100203-00/

@_spi(Testing) public func SWB_GetModuleFileNameW(_ hModule: HMODULE?) throws -> String {
try FillNullTerminatedWideStringBuffer(initialSize: DWORD(MAX_PATH), maxSize: maxPathLength) {
GetModuleFileNameW(hModule, $0.baseAddress!, DWORD($0.count))
}
}

public func SWB_GetEnvironmentVariableW(_ wName: LPCWSTR) throws -> String {
try FillNullTerminatedWideStringBuffer(initialSize: 1024, maxSize: maxEnvVarLength) {
GetEnvironmentVariableW(wName, $0.baseAddress!, DWORD($0.count))
}
}

public func SWB_GetWindowsDirectoryW() throws -> String {
try FillNullTerminatedWideStringBuffer(initialSize: DWORD(MAX_PATH), maxSize: maxPathLength) {
GetWindowsDirectoryW($0.baseAddress!, DWORD($0.count))
}
}
#endif
15 changes: 1 addition & 14 deletions Tests/SwiftBuildTests/ConsoleCommands/CLIConnection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -328,20 +328,7 @@ fileprivate func swiftRuntimePath() throws -> Path? {

fileprivate func systemRoot() throws -> Path? {
#if os(Windows)
let dwLength: DWORD = GetWindowsDirectoryW(nil, 0)
if dwLength == 0 {
throw Win32Error(GetLastError())
}
return try withUnsafeTemporaryAllocation(of: WCHAR.self, capacity: Int(dwLength)) {
switch GetWindowsDirectoryW($0.baseAddress!, DWORD($0.count)) {
case 1..<dwLength:
return Path(String(decodingCString: $0.baseAddress!, as: CInterop.PlatformUnicodeEncoding.self))
case 0:
throw Win32Error(GetLastError())
default:
throw Win32Error(DWORD(ERROR_INSUFFICIENT_BUFFER))
}
}
return try Path(SWB_GetWindowsDirectoryW())
#else
return nil
#endif
Expand Down