Skip to content

Commit

Permalink
fix: minor test fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jan 24, 2025
1 parent f4e89a4 commit 1a7ae8a
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 2 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Boltz"
uuid = "4544d5e4-abc5-4dea-817f-29e4c205d9c8"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "1.1.0"
version = "1.1.1"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
4 changes: 3 additions & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using ForwardDiff: ForwardDiff
using GPUArraysCore: AnyGPUArray
using Statistics: mean

using MLDataDevices: get_device_type, get_device, CPUDevice, CUDADevice
using MLDataDevices: MLDataDevices, get_device_type, get_device, CPUDevice, CUDADevice

is_extension_loaded(::Val) = false

Expand Down Expand Up @@ -88,4 +88,6 @@ struct DataTransferBarrier{V}
val::V
end

MLDataDevices.isleaf(::DataTransferBarrier) = true

end
4 changes: 4 additions & 0 deletions test/vision_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ end

for (mode, aType, dev, ongpu) in MODES, depth in [18, 34, 50, 101, 152]
@testset for pretrained in [false, true]
pretrained && pkgversion(Metalhead) > v"0.9.4" && continue

model = Vision.ResNet(depth; pretrained)
ps, st = Lux.setup(Random.default_rng(), model) |> dev
st = Lux.testmode(st)
Expand All @@ -130,6 +132,7 @@ end
(50, 32, 4), (101, 32, 8), (101, 64, 4), (152, 64, 4)]
@testset for pretrained in [false, true]
depth == 152 && pretrained && continue
pretrained && pkgversion(Metalhead) > v"0.9.4" && continue

model = Vision.ResNeXt(depth; pretrained, cardinality, base_width)
ps, st = Lux.setup(Random.default_rng(), model) |> dev
Expand All @@ -155,6 +158,7 @@ end
for (mode, aType, dev, ongpu) in MODES, depth in [50, 101, 152]
@testset for pretrained in [false, true]
depth == 152 && pretrained && continue
pretrained && pkgversion(Metalhead) > v"0.9.4" && continue

model = Vision.WideResNet(depth; pretrained)
ps, st = Lux.setup(Random.default_rng(), model) |> dev
Expand Down

0 comments on commit 1a7ae8a

Please sign in to comment.