diff --git a/src/discretize.jl b/src/discretize.jl index 43653ba7c..ffd73770c 100644 --- a/src/discretize.jl +++ b/src/discretize.jl @@ -321,7 +321,6 @@ function get_numeric_integral(pinnrep::PINNRepresentation) return integration_arr end end - """ prob = symbolic_discretize(pde_system::PDESystem, discretization::AbstractPINN) diff --git a/src/eltype_matching.jl b/src/eltype_matching.jl index ca5aeecd7..d178488db 100644 --- a/src/eltype_matching.jl +++ b/src/eltype_matching.jl @@ -1,8 +1,17 @@ struct EltypeAdaptor{T} end -(l::EltypeAdaptor)(x) = fmap(Adapt.adapt(l), x) +function ensure_same_device(x, device) + if (typeof(x) != device) && !(x isa Number) + error("Device mismatch detected. Ensure all data is on the same device.") + end + return x +end + + +(l::EltypeAdaptor)(x) = fmap(y -> ensure_same_device(y, l), x) + function (l::EltypeAdaptor)(x::AbstractArray{T}) where {T} - return (isbitstype(T) || T <: Number) ? Adapt.adapt(l, x) : map(l, x) + return (isbitstype(T) || T <: Number) ? x : map(y -> ensure_same_device(y, l), x) end function Adapt.adapt_storage(::EltypeAdaptor{T}, x::AbstractArray) where {T}