Skip to content

Commit 599298a

Browse files
committed
Add JLD2 extension
1 parent 602976f commit 599298a

File tree

5 files changed

+46
-0
lines changed

5 files changed

+46
-0
lines changed

Project.toml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,16 @@ ScopedValues = "7e506255-f358-4e82-b7e4-beb19740aa63"
1515
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
1616
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1717

18+
[weakdeps]
19+
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
20+
21+
[extensions]
22+
JLD2Ext = "JLD2"
23+
1824
[compat]
1925
Adapt = "4.0"
2026
GPUArraysCore = "= 0.2.0"
27+
JLD2 = "0.4, 0.5"
2128
KernelAbstractions = "0.9.28"
2229
LLVM = "3.9, 4, 5, 6, 7, 8, 9"
2330
LinearAlgebra = "1"

ext/JLD2Ext.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
module JLD2Ext
2+
3+
using GPUArrays: AbstractGPUArray
4+
using JLD2: JLD2
5+
6+
JLD2.writeas(::Type{<:AbstractGPUArray{T, N}}) where {T, N} = Array{T, N}
7+
8+
end

test/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
33
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
44
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
5+
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
56
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
7+
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
68
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
79
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
810
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"

test/testsuite.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ include("testsuite/random.jl")
9494
include("testsuite/uniformscaling.jl")
9595
include("testsuite/statistics.jl")
9696
include("testsuite/alloc_cache.jl")
97+
include("testsuite/jld2ext.jl")
9798

9899
"""
99100
Runs the entire GPUArrays test suite on array type `AT`

test/testsuite/jld2ext.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
using JLD2
2+
using Test
3+
4+
@testsuite "ext/jld2" (AT, eltypes) -> begin
5+
for ET in eltypes
6+
@testset "$ET" begin
7+
# Test with different array sizes and dimensions
8+
for dims in ((2,), (2, 2), (2, 2, 2))
9+
# Create a random array
10+
x = AT(rand(ET, dims...))
11+
12+
# Save to a temporary file
13+
mktempdir() do dir
14+
file = joinpath(dir, "test.jld2")
15+
16+
# Save and load
17+
JLD2.save_object(file, x)
18+
y = JLD2.load_object(file)
19+
20+
# Verify the loaded array matches the original
21+
@test y isa Array{ET, length(dims)}
22+
@test size(y) == size(x)
23+
@test Array(x) y
24+
end
25+
end
26+
end
27+
end
28+
end

0 commit comments

Comments
 (0)