diff --git a/Project.toml b/Project.toml index 83d676d..97141e8 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MCIntegration" uuid = "ea1e2de9-7db7-4b42-91ee-0cd1bf6df167" authors = ["Kun Chen", "Xiansheng Cai", "Pengcheng Hou"] -version = "0.3.4" +version = "0.3.5" [deps] Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" diff --git a/src/main.jl b/src/main.jl index 95423e6..733f590 100644 --- a/src/main.jl +++ b/src/main.jl @@ -37,7 +37,7 @@ - `block`: Number of blocks. Each block will be evaluated by about neval/block times. Each block is assumed to be statistically independent, and will be used to estimate the error. In MPI mode, the blocks are distributed among the workers. If the numebr of workers N is larger than block, then block will be set to be N. - `print`: -2 to not print anything; -1 to print minimal information; 0 to print the iteration history in the end; >0 to print MC configuration for every `print` seconds and print the iteration history in the end. -- `gamma`: Learning rate of the reweight factor after each iteraction. Note that alpha <=1, where alpha = 0 means no reweighting. +- `gamma`: Learning rate of the reweight factor after each iteration. Note that gamma <=1, where gamma = 0 means no reweighting. - `adapt`: Whether to adapt the grid and the reweight factor. - `debug`: Whether to print debug information (type instability, float overflow etc.) - `reweight_goal`: The expected distribution of visited times for each integrand after reweighting . If not set, then all factors will be initialized with one. Only useful for the :mcmc solver. 
@@ -158,11 +158,8 @@ function integrate(integrand::Function; # collect all statistics to summedConfig of the root worker MPIreduceConfig!(summedConfig[1]) - - if MCUtility.mpi_master() # only the master process will output results, no matter parallel = :mpi or :thread or :serial - ################### self-learning ########################################## - (solver == :mcmc || solver == :vegasmc) && doReweight!(summedConfig[1], gamma, reweight_goal) - end + ######################## self-learning ######################################### + (solver == :mcmc || solver == :vegasmc) && doReweightMPI!(summedConfig[1], gamma, reweight_goal, comm) ######################## syncronize between works ############################## @@ -304,8 +301,9 @@ function doReweight!(config, gamma, reweight_goal) end # println(config.visited) # println(config.reweight) - if isnothing(reweight_goal) == false - config.reweight .*= reweight_goal + if !isnothing(reweight_goal) # Apply reweight_goal if provided + # config.reweight .*= reweight_goal + config.reweight .*= reweight_goal ./ sum(reweight_goal) end # renoormalize all reweight to be (0.0, 1.0) config.reweight ./= sum(config.reweight) @@ -315,4 +313,14 @@ function doReweight!(config, gamma, reweight_goal) # Check Eq. (19) of https://arxiv.org/pdf/2009.05112.pdf for more detail # config.reweight = @. 
((1 - config.reweight) / log(1 / config.reweight))^beta # config.reweight ./= sum(config.reweight) +end + +function doReweightMPI!(config::Configuration, gamma, reweight_goal::Union{Vector{Float64},Nothing}, comm::MPI.Comm) + if MCUtility.mpi_master() + # only the master process will output results, no matter parallel = :mpi or :thread or :serial + doReweight!(config, gamma, reweight_goal) + end + reweight_array = Vector{Float64}(config.reweight) + MPI.Bcast!(reweight_array, 0, comm) + config.reweight .= reweight_array end \ No newline at end of file diff --git a/test/mpi_test.jl b/test/mpi_test.jl index da2bbbf..5a7fb8c 100644 --- a/test/mpi_test.jl +++ b/test/mpi_test.jl @@ -10,7 +10,7 @@ const MCUtility = MCIntegration.MCUtility rank = MPI.Comm_rank(comm) # rank of current MPI worker root = 0 # rank of the root worker - a = [1, 2, 3] + a = [1, 2, 3] aa = MCUtility.MPIreduce(a) if rank == root @test aa == [Nworker, 2Nworker, 3Nworker] @@ -29,7 +29,7 @@ const MCUtility = MCIntegration.MCUtility end # inplace - a = [1, 2, 3] + a = [1, 2, 3] MCUtility.MPIreduce!(a) if rank == root @test a == [Nworker, 2Nworker, 3Nworker] @@ -43,7 +43,7 @@ end rank = MPI.Comm_rank(comm) # rank of current MPI worker root = 0 # rank of the root worker - a = [1, 2, 3] .* rank + a = [1, 2, 3] .* rank aa = MCUtility.MPIbcast(a) if rank != root @test aa == [0, 0, 0] @@ -62,7 +62,7 @@ end end # inplace - a = [1, 2, 3] .* rank + a = [1, 2, 3] .* rank MCUtility.MPIbcast!(a) if rank != root @test a == [0, 0, 0] @@ -85,7 +85,7 @@ end Z.histogram[1] = 1.3 cvar = CompositeVar(Y, Z) obs = [1.0,] - config = Configuration(var = (X, cvar), dof=[[1, 1], ], obs=obs) + config = Configuration(var=(X, cvar), dof=[[1, 1],], obs=obs) config.neval = 101 config.normalization = 1.1 config.visited[1] = 1.2 @@ -95,16 +95,16 @@ end MCIntegration.MPIreduceConfig!(config) if rank == root @test config.observable[1] == Nworker - @test config.neval == Nworker*101 - @test config.normalization ≈ Nworker*1.1 - @test 
config.visited[1] ≈ Nworker*1.2 - @test config.propose[1, 1, 1] ≈ Nworker*1.3 - @test config.accept[1, 1, 1] ≈ Nworker*1.4 + @test config.neval == Nworker * 101 + @test config.normalization ≈ Nworker * 1.1 + @test config.visited[1] ≈ Nworker * 1.2 + @test config.propose[1, 1, 1] ≈ Nworker * 1.3 + @test config.accept[1, 1, 1] ≈ Nworker * 1.4 - @test config.var[1].histogram[1] ≈ Nworker*1.1 # X + @test config.var[1].histogram[1] ≈ Nworker * 1.1 # X cvar = config.var[2] #compositevar - @test cvar[1].histogram[1] ≈ Nworker *1.2 #Y - @test cvar[2].histogram[1] ≈ Nworker*1.3 #Z + @test cvar[1].histogram[1] ≈ Nworker * 1.2 #Y + @test cvar[2].histogram[1] ≈ Nworker * 1.3 #Z end end @@ -130,7 +130,7 @@ end Z.histogram[1] = rank end - config = Configuration(var = (X, cvar), dof=[[1, 1], ]) + config = Configuration(var=(X, cvar), dof=[[1, 1],]) config.reweight = [1.1, 1.2] MCIntegration.MPIbcastConfig!(config) @@ -143,4 +143,27 @@ end @test cvar[1].histogram[1] ≈ 1.2 #Y @test cvar[2].histogram[1] ≈ 1.3 #Z end +end + +@testset "MPI doReweight!" begin + (MPI.Initialized() == false) && MPI.Init() + comm = MPI.COMM_WORLD + Nworker = MPI.Comm_size(comm) # number of MPI workers + rank = MPI.Comm_rank(comm) + root = 0 + + X = Continuous(0.0, 1.0) + config = Configuration(var=(X,), dof=[[1], [1], [1]]) + config.visited = [1, 2, 3, 4] + @test config.reweight == [0.25, 0.25, 0.25, 0.25] + + gamma = 1.0 + reweight_goal = [1.0, 2.0, 3.0, 4.0] + n_iterations = 5 + expected_reweight = [0.25, 0.25, 0.25, 0.25] + + for _ in 1:n_iterations + MCIntegration.doReweightMPI!(config, gamma, reweight_goal, comm) + end + @test all(isapprox.(config.reweight, expected_reweight, rtol=1e-3)) end \ No newline at end of file