diff --git a/src/BaseMesh.jl b/src/BaseMesh.jl index 087a1c4..7840af2 100644 --- a/src/BaseMesh.jl +++ b/src/BaseMesh.jl @@ -41,7 +41,8 @@ export fractional_coordinates, cartesian_coordinates function Base.getindex(mesh::AbstractUniformMesh{T,DIM}, inds...) where {T,DIM} n = SVector{DIM,Int}(inds) - return mesh.origin + lattice_vector(mesh) * ((n .- 1 .+ mesh.shift) ./ mesh.size) + mshift = SVector{DIM,T}(mesh.shift) + return mesh.origin + lattice_vector(mesh) * ((n .- 1 .+ mshift) ./ mesh.size) end function Base.getindex(mesh::AbstractUniformMesh, I::Int) @@ -50,7 +51,8 @@ end function Base.getindex(mesh::AbstractUniformMesh{T,DIM}, ::Type{<:FracCoords}, I::Int) where {T,DIM} n = SVector{DIM,Int}(AbstractMeshes._ind2inds(mesh.size, I)) - return inv_lattice_vector(mesh) * mesh.origin + (n .- 1 .+ mesh.shift) ./ mesh.size + mshift = SVector{DIM,T}(mesh.shift) + return inv_lattice_vector(mesh) * mesh.origin + (n .- 1 .+ mshift) ./ mesh.size end """ diff --git a/test/AbstractMeshes.jl b/test/AbstractMeshes.jl index 68a1814..10104b0 100644 --- a/test/AbstractMeshes.jl +++ b/test/AbstractMeshes.jl @@ -1,16 +1,6 @@ @testset "AbstractMeshes" begin using BrillouinZoneMeshes.AbstractMeshes - function test_func_not_implemented(func, obj) - # if a func required is not implemented for obj - # an error occur - try - func(obj) - catch e - @test e isa ErrorException - end - end - # create a random concrete mesh DIM = 2 N1, N2 = 3, 5 @@ -38,24 +28,24 @@ # test error thrown from funcs not implemented struct NotAMesh{T,DIM} <: AbstractMesh{T,DIM} end notamesh = NotAMesh{Float64,3}() - test_func_not_implemented(println, notamesh) + @test_throws ErrorException println(notamesh) - test_func_not_implemented(x -> getindex(x, 1), notamesh) - test_func_not_implemented(x -> getindex(x, 1, 2, 3), notamesh) - test_func_not_implemented(x -> getindex(x, FracCoords, 1), notamesh) - test_func_not_implemented(x -> getindex(x, FracCoords, 1, 2, 3), notamesh) + @test_throws ErrorException notamesh[1] + @test_throws ErrorException notamesh[1, 2, 3] + @test_throws ErrorException notamesh[FracCoords, 1] + @test_throws ErrorException notamesh[FracCoords, 1, 2, 3] - test_func_not_implemented(x -> locate(x, 1), notamesh) - test_func_not_implemented(x -> volume(x, 1), notamesh) - test_func_not_implemented(volume, notamesh) + @test_throws ErrorException locate(notamesh, 1) + @test_throws ErrorException volume(notamesh, 1) + @test_throws ErrorException volume(notamesh) - test_func_not_implemented(lattice_vector, notamesh) - test_func_not_implemented(inv_lattice_vector, notamesh) - test_func_not_implemented(cell_volume, notamesh) + @test_throws ErrorException lattice_vector(notamesh) + @test_throws ErrorException inv_lattice_vector(notamesh) + @test_throws ErrorException cell_volume(notamesh) - test_func_not_implemented(x -> integrate([1,], x), notamesh) - test_func_not_implemented(x -> interp([1,], x, 1), notamesh) + @test_throws ErrorException integrate([1,], notamesh) + @test_throws ErrorException interp([1,], notamesh, 1) - test_func_not_implemented(x -> interval(x, 1), notamesh) + @test_throws ErrorException interval(notamesh, 1) end \ No newline at end of file diff --git a/test/BZMeshes.jl b/test/BZMeshes.jl index 6f472bb..bac77d4 100644 --- a/test/BZMeshes.jl +++ b/test/BZMeshes.jl @@ -28,6 +28,10 @@ println(bzmesh) display(bzmesh) + @inferred bzmesh[1] + @inferred bzmesh[1, 1] + @inferred bzmesh[AbstractMeshes.FracCoords, 1] + for (pi, p) in enumerate(bzmesh) @test bzmesh[pi] ≈ p # linear index inds = AbstractMeshes._ind2inds(size(bzmesh), pi) diff --git a/test/BaseMesh.jl b/test/BaseMesh.jl index 4a63f31..1d399fc 100644 --- a/test/BaseMesh.jl +++ b/test/BaseMesh.jl @@ -1,16 +1,6 @@ @testset "Base Mesh" begin rng = MersenneTwister(1234) - function test_func_not_implemented(func, obj) - # if a func required is not implemented for obj - # an error occur - try - func(obj) - catch e - @test e isa ErrorException - end - end - @testset "UMesh" begin DIM = 2 N1, N2 = 3, 5 @@ -19,6 +9,10 @@ cell = BZMeshes.Cell(lattice=lattice) mesh = BaseMesh.UMesh(br=cell, origin=ones(DIM) ./ 2, size=(N1, N2), shift=zeros(DIM)) + @inferred mesh[1] + @inferred mesh[1, 1] + @inferred mesh[AbstractMeshes.FracCoords, 1] + @test length(mesh) == N1 * N2 @test size(mesh) == (N1, N2) @@ -45,7 +39,7 @@ # basics struct NotAPM{T,DIM} <: BaseMesh.AbstractProdMesh{T,DIM} end - test_func_not_implemented(x -> BaseMesh._getgrid(x, 1), NotAPM{Int,3}()) + @test_throws ErrorException BaseMesh._getgrid(NotAPM{Int,3}(), 1) @testset "DirectProdMesh" begin N, M = 3, 2 @@ -62,6 +56,8 @@ dpm = DirectProdMesh(r, theta, phi) println(size(dpm)) + @inferred dpm[1] + vol = 0.0 for (pi, p) in enumerate(dpm) i, j, k = AbstractMeshes._ind2inds(size(dpm), pi) @@ -91,6 +87,9 @@ cm = ProdMesh(grids, theta) println([cm.grids[i].panel[2] for i in 1:length(theta)]) println(size(cm)) + + @inferred cm[1] + for j in 1:length(cm.mesh) for i in 1:length(cm.grids[j]) p = cm[i, j] @@ -117,6 +116,8 @@ cm2 = ChebMesh(origin, latvec, cm) + @inferred cm[1] + vol = 0.0 for (i, x) in enumerate(cm2) @test cm2[i] ≈ x diff --git a/test/MeshMap.jl b/test/MeshMap.jl index 40d2046..288c54a 100644 --- a/test/MeshMap.jl +++ b/test/MeshMap.jl @@ -1,22 +1,12 @@ @testset "MeshMaps" begin - function test_func_not_implemented(func, obj) - # if a func required is not implemented for obj - # an error occur - try - func(obj) - catch e - @test e isa ErrorException - end - end - locate, volume = MeshMaps.locate, MeshMaps.volume @testset "MeshMap" begin struct NotAMesh{T,DIM} <: AbstractMesh{T,DIM} end notamesh = NotAMesh{Float64,3}() - test_func_not_implemented(MeshMap, notamesh) + @test_throws ErrorException MeshMap(notamesh) # test MeshMap constructor map = [1, 2, 2, 1, 2, 6, 6, 2, 2, 6, 6, 2, 1, 2, 2, 1] diff --git a/test/test_typestable.jl b/test/test_typestable.jl new file mode 100644 index 0000000..3254241 --- /dev/null +++ b/test/test_typestable.jl @@ -0,0 +1,14 @@ + +using BrillouinZoneMeshes +using Test + +@testset "Type Stable?" begin + DIM = 2 + N1, N2 = 3, 5 + lattice = Matrix([1/N1/2 0; 0 1.0/N2/2]') .* 2π + # so that bzmesh[i,j] = (2i-1,2j-1) + cell = BZMeshes.Cell(lattice=lattice) + mesh = BaseMesh.UMesh(br=cell, origin=ones(DIM) ./ 2, size=(N1, N2), shift=zeros(DIM)) + + @inferred mesh[1] +end \ No newline at end of file