Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
make
Base.reduced_indices
more type-stable (JuliaLang#52905)
This fixes JuliaLang#35199 by rewriting `Base.reduced_indices` to be type stable (and grounded). I was also able to remove a method since that case is covered by the general case. The changes are illustrated by the following quick benchmarks: ```julia julia> VERSION v"1.10.0" julia> using BenchmarkTools julia> M = [1 2; 3 4] 2×2 Matrix{Int64}: 1 2 3 4 julia> @Btime sum($M, dims=$(2)) 194.816 ns (5 allocations: 160 bytes) 2×1 Matrix{Int64}: 3 7 julia> @Btime sum($M, dims=$((2,))) 209.385 ns (5 allocations: 224 bytes) 2×1 Matrix{Int64}: 3 7 julia> function my_reduced_indices(inds::Base.Indices{N}, region) where N rinds = inds for i in region isa(i, Integer) || throw(ArgumentError("reduced dimension(s) must be integers")) d = Int(i) if d < 1 throw(ArgumentError("region dimension(s) must be ≥ 1, got $d")) elseif d <= N rinds = let rinds_=rinds ntuple(j -> j == d ? Base.reduced_index(rinds_[d]) : rinds_[j], Val(N)) end end end rinds end my_reduced_indices (generic function with 1 method) julia> Base.reduced_indices(inds::Base.Indices{N}, region::Int) where N = my_reduced_indices(inds, region) julia> Base.reduced_indices(inds::Base.Indices{N}, region) where N = my_reduced_indices(inds, region) julia> @Btime sum($M, dims=$(2)) 43.582 ns (1 allocation: 80 bytes) 2×1 Matrix{Int64}: 3 7 julia> @Btime sum($M, dims=$((2,))) 43.882 ns (1 allocation: 80 bytes) 2×1 Matrix{Int64}: 3 7 ``` I also rewrote `Base.reduced_indices0` in the same fashion. I wasn't sure how to add tests for this since the improvements are to type-groundedness. Since these changes affect all reductions I hope this solution is robust.
- Loading branch information