Updated to Flux 0.14, removed POMDPSimulators, updated CI #66

Merged · 10 commits · Dec 19, 2023

Changes from all commits
1 change: 0 additions & 1 deletion .codecov.yml

This file was deleted.

36 changes: 28 additions & 8 deletions .github/workflows/CI.yml
@@ -1,16 +1,36 @@
 name: CI
-
-on: [push, pull_request]
-
+on:
+  push:
+    branches:
+      - master
+    tags: '*'
+  pull_request:
 jobs:
   test:
-    runs-on: ubuntu-latest
-
+    name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }}
+    runs-on: ${{ matrix.os }}
+    strategy:
+      fail-fast: false
+      matrix:
+        version:
+          - "1"
+          - "1.9" # min supported version
+        os:
+          - ubuntu-latest
+          - macOS-latest
+          - windows-latest
+        arch:
+          - x64
     steps:
-      - uses: actions/checkout@v2
+      - uses: actions/checkout@v4
       - uses: julia-actions/setup-julia@v1
         with:
-          version: '1'
-          arch: x64
+          version: ${{ matrix.version }}
+          arch: ${{ matrix.arch }}
+      - uses: julia-actions/cache@v1
       - uses: julia-actions/julia-buildpkg@v1
      - uses: julia-actions/julia-runtest@v1
+      - uses: julia-actions/julia-processcoverage@v1
+      - uses: codecov/codecov-action@v3
+        with:
+          files: lcov.info
70 changes: 0 additions & 70 deletions .github/workflows/ci.yml

This file was deleted.

4 changes: 3 additions & 1 deletion .gitignore
@@ -4,4 +4,6 @@
 *log**/
 log*
 *.bson
-events.out.tfevents*
+events.out.tfevents*
+.vscode
+Manifest.toml
13 changes: 6 additions & 7 deletions Project.toml
@@ -1,7 +1,7 @@
 name = "DeepQLearning"
 uuid = "de0a67f4-c691-11e8-0034-5fc6e16e22d3"
 repo = "https://github.com/JuliaPOMDP/DeepQLearning.jl"
-version = "0.6.5"
+version = "0.7.0"

 [deps]
 BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
@@ -21,21 +21,20 @@ TensorBoardLogger = "899adc3e-224a-11e9-021f-63837185c80f"
 [compat]
 BSON = "0.2, 0.3"
 CommonRLInterface = "0.2, 0.3"
-EllipsisNotation = "0.4, 1.0"
-Flux = "0.10, 0.11, 0.12"
+EllipsisNotation = "1"
+Flux = "0.14"
 POMDPLinter = "0.1"
 POMDPTools = "0.1"
 POMDPs = "0.9"
 Parameters = "0.12"
-StatsBase = "0.32, 0.33"
+StatsBase = "0.32, 0.33, 0.34"
 TensorBoardLogger = "0.1"
-julia = "1"
+julia = "1.9"

 [extras]
 POMDPModels = "355abbd5-f08e-5560-ac9e-8b5f2592a0ca"
-POMDPSimulators = "e0d0a172-29c6-5d4e-96d0-f262df5d01fd"
 StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

 [targets]
-test = ["POMDPModels", "POMDPSimulators", "StaticArrays", "Test"]
+test = ["POMDPModels", "StaticArrays", "Test"]
10 changes: 2 additions & 8 deletions README.md
@@ -1,7 +1,7 @@
 # DeepQLearning

-[![Build status](https://github.com/JuliaPOMDP/DeepQLearning.jl/workflows/CI/badge.svg)](https://github.com/JuliaPOMDP/DeepQLearning.jl/actions)
-[![CodeCov](https://codecov.io/gh/JuliaPOMDP/DeepQLearning.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/JuliaPOMDP/DeepQLearning.jl)
+[![Build status](https://github.com/JuliaPOMDP/DeepQLearning.jl/workflows/CI/badge.svg)](https://github.com/JuliaPOMDP/DeepQLearning.jl/actions/workflows/CI.yml)
+[![codecov](https://codecov.io/github/JuliaPOMDP/DeepQLearning.jl/branch/master/graph/badge.svg?token=EfDZPMisVB)](https://codecov.io/github/JuliaPOMDP/DeepQLearning.jl)

 This package provides an implementation of the Deep Q learning algorithm for solving MDPs. For more information see https://arxiv.org/pdf/1312.5602.pdf.
 It uses POMDPs.jl and Flux.jl
@@ -17,12 +17,6 @@ It supports the following innovations:

 ```Julia
 using Pkg
-# Pkg.Registry.add("https://github.com/JuliaPOMDP/Registry) # for julia 1.1+
-
-# for julia 1.0 add the registry throught the POMDP package
-# Pkg.add("POMDPs")
-# using POMDPs
-# POMDPs.add_registry()
 Pkg.add("DeepQLearning")
 ```
47 changes: 0 additions & 47 deletions appveyor.yml

This file was deleted.

4 changes: 2 additions & 2 deletions src/dueling.jl
@@ -50,9 +50,9 @@ function create_dueling_network(m::Chain)
         @assert isa(l, Dense) error_str
     end
     nlayers = length(m.layers)
-    _, last_layer_size = size(m[end].W)
+    _, last_layer_size = size(m[end].weight)
     val = Chain([deepcopy(m[i]) for i=duel_layer+1:nlayers-1]..., Dense(last_layer_size, 1))
     adv = Chain([deepcopy(m[i]) for i=duel_layer+1:nlayers]...)
-    base = Chain([deepcopy(m[i]) for i=1:duel_layer+1-1]...)
+    base = Chain(identity, [deepcopy(m[i]) for i=1:duel_layer+1-1]...)
     return DuelingNetwork(base, val, adv)
 end
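The `m[end].W` → `m[end].weight` switch follows Flux's rename of the `Dense` parameter fields (`W`/`b` became `weight`/`bias`). A small illustrative sketch with a toy chain, not the solver's actual network:

```julia
using Flux

m = Chain(Dense(4 => 32, relu), Dense(32 => 2))

# Dense stores its weight as an (out, in) matrix, so the second dimension is
# the width of the layer feeding into it -- the size used for the value head.
out_dim, last_layer_size = size(m[end].weight)   # (2, 32)

# The dueling value head maps that hidden width to a single state value.
val_head = Dense(last_layer_size => 1)
```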
18 changes: 9 additions & 9 deletions src/solver.jl
@@ -37,7 +37,7 @@
     return solve(solver, env)
 end

-function POMDPs.solve(solver::DeepQLearningSolver, env::AbstractEnv)
+function POMDPs.solve(solver::DeepQLearningSolver, env::AbstractEnv)
     action_map = collect(actions(env))
     action_indices = Dict(a=>i for (i, a) in enumerate(action_map))

@@ -56,14 +56,14 @@
     return dqn_train!(solver, env, policy, replay)
 end

-function dqn_train!(solver::DeepQLearningSolver, env::AbstractEnv, policy::AbstractNNPolicy, replay)
+function dqn_train!(solver::DeepQLearningSolver, env::AbstractEnv, policy::AbstractNNPolicy, replay)
     if solver.logdir !== nothing
         logger = TBLogger(solver.logdir)
         solver.logdir = logger.logdir
     end
     active_q = getnetwork(policy) # shallow copy
     target_q = deepcopy(active_q)
-    optimizer = ADAM(solver.learning_rate)
+    optimizer = Adam(solver.learning_rate)
     # start training
     resetstate!(policy)
     reset!(env)
@@ -177,7 +177,7 @@
     return policy
 end

-function initialize_replay_buffer(solver::DeepQLearningSolver, env::AbstractEnv, action_indices)
+function initialize_replay_buffer(solver::DeepQLearningSolver, env::AbstractEnv, action_indices)
     # init and populate replay buffer
     if solver.recurrence
         replay = EpisodeReplayBuffer(env, solver.buffer_size, solver.batch_size, solver.trace_length)
@@ -200,7 +200,7 @@
     s_batch, a_batch, r_batch, sp_batch, done_batch, indices, importance_weights = sample(replay)

     active_q = getnetwork(policy)
-    p = params(active_q)
+    p = Flux.params(active_q)

     loss_val = nothing
     td_vals = nothing
@@ -237,7 +237,7 @@

 # for RNNs
 function batch_train!(solver::DeepQLearningSolver,
-                      env::AbstractEnv,
+                      env::AbstractEnv,
                       policy::AbstractNNPolicy,
                       optimizer,
                       target_q,
@@ -249,7 +249,7 @@
     Flux.reset!(active_q)
     Flux.reset!(target_q)

-    p = params(active_q)
+    p = Flux.params(active_q)

     loss_val = nothing
     td_vals = nothing
@@ -289,7 +289,7 @@

 function save_model(solver::DeepQLearningSolver, active_q, scores_eval::Float64, saved_mean_reward::Float64, model_saved::Bool)
     if scores_eval >= saved_mean_reward
-        bson(joinpath(solver.logdir, "qnetwork.bson"), qnetwork=[w for w in params(active_q)])
+        bson(joinpath(solver.logdir, "qnetwork.bson"), qnetwork=[w for w in Flux.params(active_q)])
         if solver.verbose
             @printf("Saving new model with eval reward %1.3f \n", scores_eval)
         end
@@ -304,7 +304,7 @@
         restore_best_model(solver, env)
 end

-function restore_best_model(solver::DeepQLearningSolver, env::AbstractEnv)
+function restore_best_model(solver::DeepQLearningSolver, env::AbstractEnv)

[Codecov / codecov/patch] Check warning on line 307 in src/solver.jl: Added line #L307 was not covered by tests.

     if solver.dueling
         active_q = create_dueling_network(solver.qnetwork)
     else
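The solver keeps Flux's implicit-parameter training style, so the relevant Flux 0.14 changes are the `ADAM` → `Adam` rename and writing `Flux.params` explicitly. A rough sketch of that pattern (not the solver's exact code), with a plain MSE loss standing in for the TD loss:

```julia
using Flux

active_q  = Chain(Dense(4 => 32, relu), Dense(32 => 2))   # stand-in Q-network
optimizer = Adam(1e-3)                                     # was ADAM(solver.learning_rate)
p = Flux.params(active_q)                                  # implicit parameters, still supported in 0.14

x = rand(Float32, 4, 16)                                   # batch of observations
y = rand(Float32, 2, 16)                                   # batch of TD targets

gs = Flux.gradient(p) do
    Flux.Losses.mse(active_q(x), y)                        # placeholder for the solver's TD/Huber loss
end
Flux.Optimise.update!(optimizer, p, gs)
```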
5 changes: 2 additions & 3 deletions test/flux_test.jl
@@ -3,7 +3,6 @@ using POMDPs
 using Random
 using DeepQLearning
 using POMDPModels
-using POMDPSimulators
 using POMDPTools
 using RLInterface
 using Test
@@ -114,8 +113,8 @@ l, td = loss(q_sa, q_targets)

 Flux.data(l)

-optimizer = ADAM(Flux.params(active_q), 1e-3)
+optimizer = Adam(Flux.params(active_q), 1e-3)

 # use deep copy to update the target network

-# use Flux.reset to reset RNN if necessary
+# use Flux.reset to reset RNN if necessary
5 changes: 4 additions & 1 deletion test/runtests.jl
@@ -1,6 +1,5 @@
 using DeepQLearning
 using POMDPModels
-using POMDPSimulators
 using POMDPTools
 using Flux
 using Random
@@ -203,6 +202,10 @@ end
 end

 RL.reset!(env::SimpleEnv) = env.s = 1
+RL.state(env::SimpleEnv) = env.s
+RL.setstate!(env::SimpleEnv, s::Int) = env.s = s
+RL.setstate!(env::SimpleEnv, s::Float32) = env.s = Int(s)
+RL.setstate!(env::SimpleEnv, s::Vector{Float32}) = env.s = Int(s[1])
 RL.actions(env::SimpleEnv) = [-1, 1]
 RL.observe(env::SimpleEnv) = Float32[env.s]
 RL.terminated(env::SimpleEnv) = env.s >= 3
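The added `RL.state` / `RL.setstate!` methods give the test environment the optional save/restore hooks from CommonRLInterface, with `Float32` conversions for states coming back from the network side. A self-contained sketch of an environment in that style (a toy chain, not an exact copy of the test's `SimpleEnv`):

```julia
using CommonRLInterface
const RL = CommonRLInterface

# Toy 1-D chain: the agent moves left or right and the episode ends at s >= 3.
mutable struct ChainEnv <: AbstractEnv
    s::Int
end

RL.reset!(env::ChainEnv) = (env.s = 1)
RL.actions(env::ChainEnv) = [-1, 1]
RL.observe(env::ChainEnv) = Float32[env.s]
RL.terminated(env::ChainEnv) = env.s >= 3
RL.act!(env::ChainEnv, a) = (env.s += a; env.s >= 3 ? 1.0 : 0.0)   # hypothetical reward scheme

# Optional state hooks, mirroring the methods added in the test file.
RL.state(env::ChainEnv) = env.s
RL.setstate!(env::ChainEnv, s::Int) = (env.s = s)
RL.setstate!(env::ChainEnv, s::AbstractVector{<:Real}) = (env.s = Int(s[1]))
```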