Skip to content

Commit

Permalink
Merge pull request #80 from numericalEFT/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
iintSjds authored Apr 24, 2023
2 parents c34a666 + 9c8813e commit 5ece9c5
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 48 deletions.
6 changes: 4 additions & 2 deletions src/BaseMesh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ export fractional_coordinates, cartesian_coordinates

function Base.getindex(mesh::AbstractUniformMesh{T,DIM}, inds...) where {T,DIM}
n = SVector{DIM,Int}(inds)
return mesh.origin + lattice_vector(mesh) * ((n .- 1 .+ mesh.shift) ./ mesh.size)
mshift = SVector{DIM,T}(mesh.shift)
return mesh.origin + lattice_vector(mesh) * ((n .- 1 .+ mshift) ./ mesh.size)
end

function Base.getindex(mesh::AbstractUniformMesh, I::Int)
Expand All @@ -50,7 +51,8 @@ end

function Base.getindex(mesh::AbstractUniformMesh{T,DIM}, ::Type{<:FracCoords}, I::Int) where {T,DIM}
n = SVector{DIM,Int}(AbstractMeshes._ind2inds(mesh.size, I))
return inv_lattice_vector(mesh) * mesh.origin + (n .- 1 .+ mesh.shift) ./ mesh.size
mshift = SVector{DIM,T}(mesh.shift)
return inv_lattice_vector(mesh) * mesh.origin + (n .- 1 .+ mshift) ./ mesh.size
end

"""
Expand Down
38 changes: 14 additions & 24 deletions test/AbstractMeshes.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,6 @@
@testset "AbstractMeshes" begin
using BrillouinZoneMeshes.AbstractMeshes

function test_func_not_implemented(func, obj)
# if a func required is not implemented for obj
# an error occur
try
func(obj)
catch e
@test e isa ErrorException
end
end

# create a random concrete mesh
DIM = 2
N1, N2 = 3, 5
Expand Down Expand Up @@ -38,24 +28,24 @@
# test error thrown from funcs not implemented
struct NotAMesh{T,DIM} <: AbstractMesh{T,DIM} end
notamesh = NotAMesh{Float64,3}()
test_func_not_implemented(println, notamesh)
@test_throws ErrorException println(notamesh)

test_func_not_implemented(x -> getindex(x, 1), notamesh)
test_func_not_implemented(x -> getindex(x, 1, 2, 3), notamesh)
test_func_not_implemented(x -> getindex(x, FracCoords, 1), notamesh)
test_func_not_implemented(x -> getindex(x, FracCoords, 1, 2, 3), notamesh)
@test_throws ErrorException notamesh[1]
@test_throws ErrorException notamesh[1, 2, 3]
@test_throws ErrorException notamesh[FracCoords, 1]
@test_throws ErrorException notamesh[FracCoords, 1, 2, 3]

test_func_not_implemented(x -> locate(x, 1), notamesh)
test_func_not_implemented(x -> volume(x, 1), notamesh)
test_func_not_implemented(volume, notamesh)
@test_throws ErrorException locate(notamesh, 1)
@test_throws ErrorException volume(notamesh, 1)
@test_throws ErrorException volume(notamesh)

test_func_not_implemented(lattice_vector, notamesh)
test_func_not_implemented(inv_lattice_vector, notamesh)
test_func_not_implemented(cell_volume, notamesh)
@test_throws ErrorException lattice_vector(notamesh)
@test_throws ErrorException inv_lattice_vector(notamesh)
@test_throws ErrorException cell_volume(notamesh)

test_func_not_implemented(x -> integrate([1,], x), notamesh)
test_func_not_implemented(x -> interp([1,], x, 1), notamesh)
@test_throws ErrorException integrate([1,], notamesh)
@test_throws ErrorException interp([1,], notamesh, 1)

test_func_not_implemented(x -> interval(x, 1), notamesh)
@test_throws ErrorException interval(notamesh, 1)

end
4 changes: 4 additions & 0 deletions test/BZMeshes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@
println(bzmesh)
display(bzmesh)

@inferred bzmesh[1]
@inferred bzmesh[1, 1]
@inferred bzmesh[AbstractMeshes.FracCoords, 1]

for (pi, p) in enumerate(bzmesh)
@test bzmesh[pi] p # linear index
inds = AbstractMeshes._ind2inds(size(bzmesh), pi)
Expand Down
23 changes: 12 additions & 11 deletions test/BaseMesh.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,6 @@
@testset "Base Mesh" begin
rng = MersenneTwister(1234)

function test_func_not_implemented(func, obj)
# if a func required is not implemented for obj
# an error occur
try
func(obj)
catch e
@test e isa ErrorException
end
end

@testset "UMesh" begin
DIM = 2
N1, N2 = 3, 5
Expand All @@ -19,6 +9,10 @@
cell = BZMeshes.Cell(lattice=lattice)
mesh = BaseMesh.UMesh(br=cell, origin=ones(DIM) ./ 2, size=(N1, N2), shift=zeros(DIM))

@inferred mesh[1]
@inferred mesh[1, 1]
@inferred mesh[AbstractMeshes.FracCoords, 1]

@test length(mesh) == N1 * N2
@test size(mesh) == (N1, N2)

Expand All @@ -45,7 +39,7 @@

# basics
struct NotAPM{T,DIM} <: BaseMesh.AbstractProdMesh{T,DIM} end
test_func_not_implemented(x -> BaseMesh._getgrid(x, 1), NotAPM{Int,3}())
@test_throws ErrorException BaseMesh._getgrid(NotAPM{Int,3}(), 1)

@testset "DirectProdMesh" begin
N, M = 3, 2
Expand All @@ -62,6 +56,8 @@
dpm = DirectProdMesh(r, theta, phi)
println(size(dpm))

@inferred dpm[1]

vol = 0.0
for (pi, p) in enumerate(dpm)
i, j, k = AbstractMeshes._ind2inds(size(dpm), pi)
Expand Down Expand Up @@ -91,6 +87,9 @@
cm = ProdMesh(grids, theta)
println([cm.grids[i].panel[2] for i in 1:length(theta)])
println(size(cm))

@inferred cm[1]

for j in 1:length(cm.mesh)
for i in 1:length(cm.grids[j])
p = cm[i, j]
Expand All @@ -117,6 +116,8 @@

cm2 = ChebMesh(origin, latvec, cm)

@inferred cm[1]

vol = 0.0
for (i, x) in enumerate(cm2)
@test cm2[i] x
Expand Down
12 changes: 1 addition & 11 deletions test/MeshMap.jl
Original file line number Diff line number Diff line change
@@ -1,22 +1,12 @@
@testset "MeshMaps" begin

function test_func_not_implemented(func, obj)
# if a func required is not implemented for obj
# an error occur
try
func(obj)
catch e
@test e isa ErrorException
end
end

locate, volume = MeshMaps.locate, MeshMaps.volume

@testset "MeshMap" begin

struct NotAMesh{T,DIM} <: AbstractMesh{T,DIM} end
notamesh = NotAMesh{Float64,3}()
test_func_not_implemented(MeshMap, notamesh)
@test_throws ErrorException MeshMap(notamesh)

# test MeshMap constructor
map = [1, 2, 2, 1, 2, 6, 6, 2, 2, 6, 6, 2, 1, 2, 2, 1]
Expand Down
14 changes: 14 additions & 0 deletions test/test_typestable.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@

using BrillouinZoneMeshes
using Test

@testset "Type Stable?" begin
DIM = 2
N1, N2 = 3, 5
lattice = Matrix([1/N1/2 0; 0 1.0/N2/2]') .* 2π
# so that bzmesh[i,j] = (2i-1,2j-1)
cell = BZMeshes.Cell(lattice=lattice)
mesh = BaseMesh.UMesh(br=cell, origin=ones(DIM) ./ 2, size=(N1, N2), shift=zeros(DIM))

@inferred mesh[1]
end

0 comments on commit 5ece9c5

Please sign in to comment.