Skip to content

Commit c6ad3fc

Browse files
jakepetroulesLuke Daley
authored and
Luke Daley
committed
Add a helper function for avoiding certain race conditions in Win32 API calls
For many Win32 APIs returning strings, the expected usage is to first call the API with a nil buffer and 0 capacity to receive the buffer count, then call it again with a buffer of the right size. However, some APIs refer to external state which can change between calls (such as GetEnvironmentVariableW and GetCurrentDirectoryW). This can lead to race conditions where the buffer doesn't end up being of sufficient size to hold the result, and the call fails. To protect against this, add a helper function with a reusable algorithm that continually doubles the buffer size until it's large enough to hold the result, up to a specified maximum to prevent denial of service attacks.
1 parent e430fac commit c6ad3fc

File tree

5 files changed

+72
-59
lines changed

5 files changed

+72
-59
lines changed

Sources/SWBUtil/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ add_library(SWBUtil
101101
VFS.swift
102102
WaitCondition.swift
103103
WeakRef.swift
104+
Win32.swift
104105
Win32Error.swift
105106
XCBuildDataArchive.swift
106107
Xcode.swift)

Sources/SWBUtil/Library.swift

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -117,32 +117,3 @@ public struct LibraryHandle: @unchecked Sendable {
117117
self.rawValue = rawValue
118118
}
119119
}
120-
121-
#if os(Windows)
122-
@_spi(Testing) public func SWB_GetModuleFileNameW(_ hModule: HMODULE?) throws -> String {
123-
#if DEBUG
124-
var bufferCount = Int(1) // force looping
125-
#else
126-
var bufferCount = Int(MAX_PATH)
127-
#endif
128-
while true {
129-
if let result = try withUnsafeTemporaryAllocation(of: WCHAR.self, capacity: bufferCount, { buffer in
130-
switch (GetModuleFileNameW(hModule, buffer.baseAddress!, DWORD(buffer.count)), GetLastError()) {
131-
case (1..<DWORD(bufferCount), DWORD(ERROR_SUCCESS)):
132-
guard let result = String.decodeCString(buffer.baseAddress!, as: UTF16.self)?.result else {
133-
throw Win32Error(DWORD(ERROR_ILLEGAL_CHARACTER))
134-
}
135-
return result
136-
case (DWORD(bufferCount), DWORD(ERROR_INSUFFICIENT_BUFFER)):
137-
bufferCount += Int(MAX_PATH)
138-
return nil
139-
case (_, let errorCode):
140-
throw Win32Error(errorCode)
141-
}
142-
}) {
143-
return result
144-
}
145-
}
146-
preconditionFailure("unreachable")
147-
}
148-
#endif

Sources/SWBUtil/POSIX.swift

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -26,22 +26,10 @@ public enum POSIX: Sendable {
2626
public static func getenv(_ name: String) throws -> String? {
2727
#if os(Windows)
2828
try name.withCString(encodedAs: CInterop.PlatformUnicodeEncoding.self) { wName in
29-
let dwLength: DWORD = GetEnvironmentVariableW(wName, nil, 0)
30-
if dwLength == 0 {
31-
if GetLastError() == ERROR_ENVVAR_NOT_FOUND {
32-
return nil
33-
}
34-
throw POSIXError(errno, context: "GetEnvironmentVariableW", name)
35-
}
36-
return try withUnsafeTemporaryAllocation(of: WCHAR.self, capacity: Int(dwLength)) {
37-
switch GetEnvironmentVariableW(wName, $0.baseAddress!, DWORD($0.count)) {
38-
case 1..<dwLength:
39-
return String(decodingCString: $0.baseAddress!, as: CInterop.PlatformUnicodeEncoding.self)
40-
case 0 where GetLastError() == ERROR_ENVVAR_NOT_FOUND:
41-
return nil
42-
default:
43-
throw POSIXError(errno, context: "GetEnvironmentVariableW", name)
44-
}
29+
do {
30+
return try SWB_GetEnvironmentVariableW(wName)
31+
} catch let error as Win32Error where error.error == ERROR_ENVVAR_NOT_FOUND {
32+
return nil
4533
}
4634
}
4735
#else

Sources/SWBUtil/Win32.swift

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// This source file is part of the Swift open source project
4+
//
5+
// Copyright (c) 2025 Apple Inc. and the Swift project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See http://swift.org/LICENSE.txt for license information
9+
// See http://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#if os(Windows)
14+
public import WinSDK
15+
16+
/// Calls a Win32 API function that fills a (potentially long path) null-terminated string buffer by continually attempting to allocate more memory up until the true max path is reached.
17+
/// This is especially useful for protecting against race conditions like with GetCurrentDirectoryW where the measured length may no longer be valid on subsequent calls.
18+
/// - parameter initialSize: Initial size of the buffer (including the null terminator) to allocate to hold the returned string.
19+
/// - parameter maxSize: Maximum size of the buffer (including the null terminator) to allocate to hold the returned string.
20+
/// - parameter body: Closure to call the Win32 API function to populate the provided buffer.
21+
/// Should return the number of UTF-16 code units (not including the null terminator) copied, 0 to indicate an error.
22+
/// If the buffer is not of sufficient size, should return a value greater than or equal to the size of the buffer.
23+
private func FillNullTerminatedWideStringBuffer(initialSize: DWORD, maxSize: DWORD, _ body: (UnsafeMutableBufferPointer<WCHAR>) throws -> DWORD) throws -> String {
24+
var bufferCount = max(1, min(initialSize, maxSize))
25+
while bufferCount <= maxSize {
26+
if let result = try withUnsafeTemporaryAllocation(of: WCHAR.self, capacity: Int(bufferCount), { buffer in
27+
let count = try body(buffer)
28+
switch count {
29+
case 0:
30+
throw Win32Error(GetLastError())
31+
case 1..<DWORD(buffer.count):
32+
let result = String(decodingCString: buffer.baseAddress!, as: UTF16.self)
33+
assert(result.utf16.count == count, "Parsed UTF-16 count \(result.utf16.count) != reported UTF-16 count \(count)")
34+
return result
35+
default:
36+
bufferCount *= 2
37+
return nil
38+
}
39+
}) {
40+
return result
41+
}
42+
}
43+
throw Win32Error(DWORD(ERROR_INSUFFICIENT_BUFFER))
44+
}
45+
46+
private let maxPathLength = DWORD(Int16.max) // https://learn.microsoft.com/en-us/windows/win32/fileio/maximum-file-path-limitation
47+
private let maxEnvVarLength = DWORD(Int16.max) // https://devblogs.microsoft.com/oldnewthing/20100203-00/
48+
49+
@_spi(Testing) public func SWB_GetModuleFileNameW(_ hModule: HMODULE?) throws -> String {
50+
try FillNullTerminatedWideStringBuffer(initialSize: DWORD(MAX_PATH), maxSize: maxPathLength) {
51+
GetModuleFileNameW(hModule, $0.baseAddress!, DWORD($0.count))
52+
}
53+
}
54+
55+
public func SWB_GetEnvironmentVariableW(_ wName: LPCWSTR) throws -> String {
56+
try FillNullTerminatedWideStringBuffer(initialSize: 1024, maxSize: maxEnvVarLength) {
57+
GetEnvironmentVariableW(wName, $0.baseAddress!, DWORD($0.count))
58+
}
59+
}
60+
61+
public func SWB_GetWindowsDirectoryW() throws -> String {
62+
try FillNullTerminatedWideStringBuffer(initialSize: DWORD(MAX_PATH), maxSize: maxPathLength) {
63+
GetWindowsDirectoryW($0.baseAddress!, DWORD($0.count))
64+
}
65+
}
66+
#endif

Tests/SwiftBuildTests/ConsoleCommands/CLIConnection.swift

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -328,20 +328,7 @@ fileprivate func swiftRuntimePath() throws -> Path? {
328328

329329
fileprivate func systemRoot() throws -> Path? {
330330
#if os(Windows)
331-
let dwLength: DWORD = GetWindowsDirectoryW(nil, 0)
332-
if dwLength == 0 {
333-
throw Win32Error(GetLastError())
334-
}
335-
return try withUnsafeTemporaryAllocation(of: WCHAR.self, capacity: Int(dwLength)) {
336-
switch GetWindowsDirectoryW($0.baseAddress!, DWORD($0.count)) {
337-
case 1..<dwLength:
338-
return Path(String(decodingCString: $0.baseAddress!, as: CInterop.PlatformUnicodeEncoding.self))
339-
case 0:
340-
throw Win32Error(GetLastError())
341-
default:
342-
throw Win32Error(DWORD(ERROR_INSUFFICIENT_BUFFER))
343-
}
344-
}
331+
return try Path(SWB_GetWindowsDirectoryW())
345332
#else
346333
return nil
347334
#endif

0 commit comments

Comments
 (0)