Skip to content

Commit

Permalink
Add view
Browse files Browse the repository at this point in the history
  • Loading branch information
giopaglia committed Jul 31, 2024
1 parent f7a4778 commit 3bb9dbe
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/LabeledMultiDataset.jl
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ end

function SoleBase.instances(
lmd::LabeledMultiDataset,
inds::AbstractVector{<:Integer},
inds::AbstractVector,
return_view::Union{Val{true},Val{false}} = Val(false)
)
LabeledMultiDataset(
Expand Down
12 changes: 11 additions & 1 deletion src/MultiDataset.jl
Original file line number Diff line number Diff line change
Expand Up @@ -255,13 +255,23 @@ end

function SoleBase.instances(
md::MultiDataset,
inds::AbstractVector{<:Integer},
inds::AbstractVector,
return_view::Union{Val{true},Val{false}} = Val(false),
)
@assert return_view == Val(false)
@assert all(i->i<=ninstances(md), inds) "Cannot slice MultiDataset of $(ninstances(md)) instances with indices $(inds)."
MultiDataset(data(md)[inds,:], grouped_variables(md))
end

import Base: view
Base.@propagate_inbounds function view(md::MultiDataset, inds...)
MultiDataset(view(data(md), inds...), grouped_variables(md))
end
Base.@propagate_inbounds function view(md::MultiDataset, inds::Integer, ::Colon)
MultiDataset(view(data(md), [inds], :), grouped_variables(md))
end


function vcat(mds::MultiDataset...)
MultiDataset(vcat((data.(mds)...)), grouped_variables(first(mds)))
end
Expand Down
2 changes: 1 addition & 1 deletion src/dimensional-data.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ function checknvariables(d::AbstractDimensionalDataset{T,D}) where {T<:Number,D}
end
nvariables(d::AbstractDimensionalDataset{T,D}) where {T<:Number,D} = size(first(eachinstance(d)), D)

function instances(d::AbstractDimensionalDataset, inds::AbstractVector{<:Integer}, return_view::Union{Val{true},Val{false}} = Val(false))
function instances(d::AbstractDimensionalDataset, inds::AbstractVector, return_view::Union{Val{true},Val{false}} = Val(false))
if return_view == Val(true) @views d[inds] else d[inds] end
end

Expand Down

0 comments on commit 3bb9dbe

Please sign in to comment.