From 58ef31926f4cde766d99955ee5e509f529c40d13 Mon Sep 17 00:00:00 2001 From: David Josephs <42522233+josephsdavid@users.noreply.github.com> Date: Mon, 27 Jun 2022 15:55:02 -0500 Subject: [PATCH 1/7] Add docstrings, todo line formatting --- src/MLJFlux.jl | 898 ++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 897 insertions(+), 1 deletion(-) diff --git a/src/MLJFlux.jl b/src/MLJFlux.jl index 84bce73f..34c7d95d 100644 --- a/src/MLJFlux.jl +++ b/src/MLJFlux.jl @@ -1,4 +1,4 @@ -module MLJFlux +module MLJFlux export CUDALibs, CPU1 @@ -37,4 +37,900 @@ MLJModelInterface.metadata_pkg.((NeuralNetworkRegressor, export NeuralNetworkRegressor, MultitargetNeuralNetworkRegressor export NeuralNetworkClassifier, ImageClassifier +""" +$(MMI.doc_header(NeuralNetworkRegressor)) + +`NeuralNetworkRegressor`: A neural network model for making deterministic +predictions of a `Continuous` target, given a table of `Continuous` features. + +# Training data + +In MLJ or MLJBase, bind an instance `model` to data with + mach = machine(model, X, y) + +Where + +- `X`: is any table of input features (eg, a `DataFrame`) whose columns + are of scitype `Continuous`; check the scitype with `schema(X)` +- `y`: is the target, which can be any `AbstractVector` whose element + scitype is `Continuous`; check the scitype with `scitype(y)` + + +# Hyper-parameters + +- `builder=MLJFlux.Linear(σ=Flux.relu)`: An MLJFlux builder that constructs a neural network. + Possible `builders` include: `Linear`, `Short`, and `MLP`. You can construct your own builder + using the `@builder` macro, see examples for further information. +- `optimiser::Flux.ADAM()`: A `Flux.Optimise` optimiser. The optimiser performs the updating + of the weights of the network. For further reference, see either the examples or + [the Flux optimiser documentation](https://fluxml.ai/Flux.jl/stable/training/optimisers/). + To choose a learning rate (the update rate of the optimizer), a good rule of thumb is to + start out at `10e-3`, and tune using powers of 10 between `1` and `1e-7`. +- `loss=Flux.mse`: The loss function which the network will optimize. Should be a function + which can be called in the form `loss(yhat, y)`. + Possible loss functions are listed in [the Flux loss function documentation](https://fluxml.ai/Flux.jl/stable/models/losses/). + For a regression task, the most natural loss functions are: + - `Flux.mse` + - `Flux.mae` + - `Flux.msle` + - `Flux.huber_loss` +- `epochs::Int=10`: The number of epochs to train for. Typically, one epoch represents + one pass through the entirety of the training dataset. +- `batch_size::Int=1`: The batch size to be used for training. The batch size represents + the number of samples per update of the networks weights. Typcally, batch size should be + somewhere between 8 and 512. Smaller batch sizes lead to noisier training loss curves, + while larger batch sizes lead towards smoother training loss curves. + In general, it is a good idea to pick one fairly large batch size (e.g. 32, 64, 128), + and stick with it, and only tune the learning rate. In most examples, batch size is set + in powers of twos, but this is fairly arbitrary. +- `lambda::Float64=0`: The stregth of the regularization used during training. Can be any value + in the range `[0, ∞)`. +- `alpha::Float64=0`: The L2/L1 mix of regularization, in the range `[0, 1]`. + A value of 0 represents L2 regularization, and a value of 1 represents L1 regularization. +- `rng::Union{AbstractRNG, Int64}`: The random number generator/seed used during training. +- `optimizer_changes_trigger_retraining::Bool=false`: Defines what happens when fitting a + machine if the associated optimiser has changed. If true, the associated machine will + retrain from scratch on `fit`, otherwise it will not. +- `acceleration::AbstractResource=CPU1()`: Defines on what hardware training is done. + For training on GPU, use `CudaLibs()`, otherwise defaults to `CPU`()`. +- `finaliser=Flux.softmax`: The final activation function of the neural network. + Defaults to `Flux.softmax`. For a regression task, reasonable alternatives include `Flux.sigmoid` and the identity function (otherwise known as "linear activation"). + + +# Operations + +- `predict(mach, Xnew)`: return predictions of the target given new + features `Xnew` having the same Scitype as `X` above. Predictions are + deterministic. + + +# Fitted parameters + +The fields of `fitted_params(mach)` are: + +- `chain`: The trained "chain", or series of layers, functions, and activations which make up the neural network. + + +# Report + +The fields of `report(mach)` are: + +- `training_losses`: The history of training losses, a vector containing the history of all the losses during training. The first element of the vector is the initial penalized loss. After the first element, the nth element corresponds to the loss of epoch n-1. + +# Examples + +In this example we build a regression model using the Boston house price dataset +```julia + using MLJ + using MLJFlux + using Flux + using Plots +``` +First, we load in the data, with target `:MEDV`. We load in all features except `:CHAS`: +```julia +data = OpenML.load(531); # Loads from https://www.openml.org/d/531 + +y, X = unpack(data, ==(:MEDV), !=(:CHAS); rng=123); + +scitype(y) +schema(X) +``` +Since MLJFlux models do not handle ordered factos, we can treat `:RAD` as `Continuous`: +```julia +X = coerce(X, :RAD=>Continuous) +``` +Lets also make a test set: +```julia +(X, Xtest), (y, ytest) = partition((X, y), 0.7, multi=true); +``` +Next, we can define a `builder`. In the following macro call, `n_in` is the number of expected input features, and rng is a RNG. `init` is the function used to generate the random initial weights of the network. +```julia +builder = MLJFlux.@builder begin + init=Flux.glorot_uniform(rng) + Chain(Dense(n_in, 64, relu, init=init), + Dense(64, 32, relu, init=init), + Dense(32, 1, init=init)) +end +``` +Finally, we can define the model! +```julia +NeuralNetworkRegressor = @load NeuralNetworkRegressor + model = NeuralNetworkRegressor(builder=builder, + rng=123, + epochs=20) +``` +For our neural network, since different features likely have different scales, if we do not standardize the network may be implicitly biased towards features with higher magnitudes, or may have [saturated neurons](https://www.informit.com/articles/article.aspx%3fp=3131594&seqNum=2) and not train well. Therefore, standardization is key! +```julia +pipe = Standardizer |> TransformedTargetModel(model, target=Standardizer) +``` +If we fit with a high verbosity (>1), we will see the losses during training. We can also see the losses in the output of `report(mach)` + +```julia +mach = machine(pipe, X, y) +fit!(mach, verbosity=2) + +# first element initial loss, 2:end per epoch training losses +report(mach).transformed_target_model_deterministic.training_losses + +``` + +## Experimenting with learning rate + +We can visually compare how the learning rate affects the predictions: +```julia +plt = plot() + +rates = 10. .^ (-5:0) + +foreach(rates) do η + pipe.transformed_target_model_deterministic.model.optimiser.eta = η + fit!(mach, force=true, verbosity=0) + losses = + report(mach).transformed_target_model_deterministic.model.training_losses[3:end] + plot!(1:length(losses), losses, label=η) +end +plt #!md + +savefig(joinpath("assets", "learning_rate.png")) + +pipe.transformed_target_model_deterministic.model.optimiser.eta = 0.0001 +``` + +## Using Iteration Controls + +We can also wrap the model with MLJ Iteration controls. Suppose we want a model that trains until the out of sample loss does not improve for 6 epochs. We can use the `NumberSinceBest(6)` stopping criterion. We can also add some extra stopping criterion, `InvalidValue` and `Timelimit(1/60)`, as well as some controls to print traces of the losses. First we can define some methods to initialize or clear the traces as well as updte the traces. +```julia +# For initializing or clearing the traces: + +clear() = begin + global losses = [] + global training_losses = [] + global epochs = [] + return nothing +end + + # And to update the traces: + +update_loss(loss) = push!(losses, loss) +update_training_loss(report) = + push!(training_losses, + report.transformed_target_model_deterministic.model.training_losses[end]) +update_epochs(epoch) = push!(epochs, epoch) +``` +For further reference of controls, see [the documentation](https://alan-turing-institute.github.io/MLJ.jl/dev/controlling_iterative_models/%23Controls-provided). To apply the controls, we simply stack them in a vector and then make an `IteratedModel`: +```julia +controls=[Step(1), + NumberSinceBest(6), + InvalidValue(), + TimeLimit(1/60), + WithLossDo(update_loss), + WithReportDo(update_training_loss), +WithIterationsDo(update_epochs)] + + +iterated_pipe = + IteratedModel(model=pipe, + controls=controls, + resampling=Holdout(fraction_train=0.8), + measure = l2) +``` +Next, we can clear the traces, fit the model, and plot the traces: +```julia +clear() +mach = machine(iterated_pipe, X, y) +fit!(mach) + +plot(epochs, losses, + xlab = "epoch", + ylab = "mean sum of squares error", + label="out-of-sample", + legend = :topleft); +scatter!(twinx(), epochs, training_losses, label="training", color=:red) #!md + +savefig(joinpath("assets", "loss.png")) +``` + +### Brief note on iterated models + +Training an `IteratedModel` means holding out some data (80% in this case) so an out-of-sample loss can be tracked and used in the specified stopping criterion, `NumberSinceBest(4)`. However, once the stop is triggered, the model wrapped by `IteratedModel` (our pipeline model) is retrained on all data for the same number of iterations. Calling `predict(mach, Xnew)` on new data uses the updated learned parameters. + +## Evaluating Iterated Models + +We can evaluate our model with the `evaluate!` function: +```julia +e = evaluate!(mach, + resampling=CV(nfolds=8), + measures=[l1, l2]) + +using Measurements +l1_loss = e.measurement[1] ± std(e.per_fold[1])/sqrt(7) +@show l1_loss +``` +We take this estimate of the uncertainty of the generalization error with a [grain of salt](https://direct.mit.edu/neco/article-abstract/10/7/1895/6224/Approximate-Statistical-Tests-for-Comparing)). + +## Comparison with other models on the test set + +Although we cannot assign them statistical significance, here are comparisons, on the untouched test set, of the eror of our self-iterating neural network regressor with a couple of other models trained on the same data (using default hyperparameters): +```julia +function performance(model) + mach = machine(model, X, y) |> fit! + yhat = predict(mach, Xtest) + l1(yhat, ytest) |> mean +end +performance(iterated_pipe) + +three_models = [(@load EvoTreeRegressor)(), # tree boosting model + (@load LinearRegressor pkg=MLJLinearModels)(), + iterated_pipe] + +errs = performance.(three_models) + +(models=MLJ.name.(three_models), mean_square_errors=errs) |> pretty +``` + +See also +[`MultitargetNeuralNetworkRegressor`](@ref) +""" +NeuralNetworkRegressor + +""" +$(MMI.doc_header(MultitargetNeuralNetworkRegressor)) + +`MultitargetNeuralNetworkRegressor`: A neural network model for making deterministic +predictions of a `Continuous` multi-target, presented as a table, given a table of `Continuous` features. + +# Training data + +In MLJ or MLJBase, bind an instance `model` to data with + mach = machine(model, X, y) + +Where + +- `X`: is any table of input features (eg, a `DataFrame`) whose columns + are of scitype `Continuous`; check the scitype with `schema(X)` +- `y`: is the target, which can be any table of output targets whose element + scitype is `Continuous`; check the scitype with `schema(y)` + + +# Hyper-parameters + +- `builder=MLJFlux.Linear(σ=Flux.relu)`: An MLJFlux builder that constructs a neural network. Possible `builders` include: `Linear`, `Short`, and `MLP`. You can construct your own builder using the `@builder` macro, see examples for further information. +- `optimiser::Flux.ADAM()`: A `Flux.Optimise` optimiser. The optimiser performs the updating of the weights of the network. For further reference, see either the examples or [the Flux optimiser documentation](https://fluxml.ai/Flux.jl/stable/training/optimisers/). To choose a learning rate (the update rate of the optimizer), a good rule of thumb is to start out at `10e-3`, and tune using powers of 10 between `1` and `1e-7`. +- `loss=Flux.mse`: The loss function which the network will optimize. Should be a function which can be called in the form `loss(yhat, y)`. Possible loss functions are listed in [the Flux loss function documentation](https://fluxml.ai/Flux.jl/stable/models/losses/). For a regression task, the most natural loss functions are: + - `Flux.mse` + - `Flux.mae` + - `Flux.msle` + - `Flux.huber_loss` +- `epochs::Int=10`: The number of epochs to train for. Typically, one epoch represents one pass through the entirety of the training dataset. +- `batch_size::Int=1`: The batch size to be used for training. The batch size represents the number of samples per update of the networks weights. Typcally, batch size should be somewhere between 8 and 512. Smaller batch sizes lead to noisier training loss curves, while larger batch sizes lead towards smoother training loss curves. In general, it is a good idea to pick one fairly large batch size (e.g. 32, 64, 128), and stick with it, and only tune the learning rate. In most literature, batch size is set in powers of twos, but this is fairly arbitrary. +- `lambda::Float64=0`: The stregth of the regularization used during training. Can be any value in the range `[0, ∞)`. +- `alpha::Float64=0`: The L2/L1 mix of regularization, in the range `[0, 1]`. A value of 0 represents L2 regularization, and a value of 1 represents L1 regularization. +- `rng::Union{AbstractRNG, Int64}`: The random number generator/seed used during training. +- `optimizer_changes_trigger_retraining::Bool=false`: Defines what happens when fitting a machine if the associated optimiser has changed. If true, the associated machine will retrain from scratch on `fit`, otherwise it will not. +- `acceleration::AbstractResource=CPU1()`: Defines on what hardware training is done. For Training on GPU, use `CudaLibs()`, otherwise defaults to `CPU`()`. +- `finaliser=Flux.softmax`: The final activation function of the neural network. Defaults to `Flux.softmax`. For a regression task, reasonable alternatives include `Flux.sigmoid` and the identity function (otherwise known as "linear activation"). + + +# Operations + +- `predict(mach, Xnew)`: return predictions of the target given new + features `Xnew` having the same Scitype as `X` above. Predictions are + deterministic. + + +# Fitted parameters + +The fields of `fitted_params(mach)` are: + +- `chain`: The trained "chain", or series of layers, functions, and activations which make up the neural network. + + +# Report + +The fields of `report(mach)` are: + +- `training_losses`: The history of training losses, a vector containing the history of all the losses during training. The first element of the vector is the initial penalized loss. After the first element, the nth element corresponds to the loss of epoch n-1. + +# Examples + +In this example we build a regression model using the Boston house price dataset. +```julia +using MLJ +using MLJFlux +using Flux +using Plots +using MLJBase: augment_X +``` +First, we generate some data: +```julia +X = augment_X(randn(10000, 8), true); +θ = randn((9,2)); +y = X * θ; +X = MLJ.table(X) +y = MLJ.table(y) + +schema(y) +schema(X) +``` +Lets also make a test set: +```julia +(X, Xtest), (y, ytest) = partition((X, y), 0.7, multi=true); +``` +Next, we can define a `builder`. In the following macro call, `n_in` is the number of expected input features, and rng is a RNG. `init` is the function used to generate the random initial weights of the network. +```julia +builder = MLJFlux.@builder begin + init=Flux.glorot_uniform(rng) + Chain(Dense(n_in, 64, relu, init=init), + Dense(64, 32, relu, init=init), + Dense(32, 1, init=init)) +end +``` +Finally, we can define the model! +```julia +MultitargetNeuralNetworkRegressor = @load MultitargetNeuralNetworkRegressor + model = MultitargetNeuralNetworkRegressor(builder=builder, + rng=123, + epochs=20) +``` +For our neural network, since different features likely have different scales, if we do not standardize the network may be implicitly biased towards features with higher magnitudes, or may have [saturated neurons](https://www.informit.com/articles/article.aspx%3fp=3131594&seqNum=2) and not train well. Therefore, standardization is key! +```julia +pipe = Standardizer |> TransformedTargetModel(model, target=Standardizer) +``` +If we fit with a high verbosity (>1), we will see the losses during training. We can also see the losses in the output of `report(mach)` + +```julia +mach = machine(pipe, X, y) +fit!(mach, verbosity=2) + +# first element initial loss, 2:end per epoch training losses +report(mach).transformed_target_model_deterministic.training_losses + +``` + +## Experimenting with learning rate + +We can visually compare how the learning rate affects the predictions: +```julia +plt = plot() + +rates = 10. .^ (-5:0) + +foreach(rates) do η + pipe.transformed_target_model_deterministic.model.optimiser.eta = η + fit!(mach, force=true, verbosity=0) + losses = + report(mach).transformed_target_model_deterministic.model.training_losses[3:end] + plot!(1:length(losses), losses, label=η) +end +plt #!md + +savefig(joinpath("assets", "learning_rate.png")) + + +pipe.transformed_target_model_deterministic.model.optimiser.eta = 0.0001 + +``` + +## Using Iteration Controls + +We can also wrap the model with MLJ Iteration controls. Suppose we want a model that trains until the out of sample loss does not improve for 6 epochs. We can use the `NumberSinceBest(6)` stopping criterion. We can also add some extra stopping criterion, `InvalidValue` and `Timelimit(1/60)`, as well as some controls to print traces of the losses. First we can define some methods to initialize or clear the traces as well as updte the traces. +```julia +# For initializing or clearing the traces: + +clear() = begin + global losses = [] + global training_losses = [] + global epochs = [] + return nothing +end + +# And to update the traces: + +update_loss(loss) = push!(losses, loss) +update_training_loss(report) = + push!(training_losses, + report.transformed_target_model_deterministic.model.training_losses[end]) +update_epochs(epoch) = push!(epochs, epoch) +``` +For further reference of controls, see [the documentation](https://alan-turing-institute.github.io/MLJ.jl/dev/controlling_iterative_models/%23Controls-provided). To apply the controls, we simply stack them in a vector and then make an `IteratedModel`: +```julia +controls=[Step(1), + NumberSinceBest(6), + InvalidValue(), + TimeLimit(1/60), + WithLossDo(update_loss), + WithReportDo(update_training_loss), +WithIterationsDo(update_epochs)] + +iterated_pipe = + IteratedModel(model=pipe, + controls=controls, + resampling=Holdout(fraction_train=0.8), + measure = l2) +``` +Next, we can clear the traces, fit the model, and plot the traces: +```julia +clear() +mach = machine(iterated_pipe, X, y) +fit!(mach) + +plot(epochs, losses, + xlab = "epoch", + ylab = "mean sum of squares error", + label="out-of-sample", + legend = :topleft); +scatter!(twinx(), epochs, training_losses, label="training", color=:red) #!md + +savefig(joinpath("assets", "loss.png")) +``` + +### Brief note on iterated models + +Training an `IteratedModel` means holding out some data (80% in this case) so an out-of-sample loss can be tracked and used in the specified stopping criterion, `NumberSinceBest(4)`. However, once the stop is triggered, the model wrapped by `IteratedModel` (our pipeline model) is retrained on all data for the same number of iterations. Calling `predict(mach, Xnew)` on new data uses the updated learned parameters. + +## Evaluating Iterated Models + +We can evaluate our model with the `evaluate!` function: +```julia +e = evaluate!(mach, + resampling=CV(nfolds=8), + measures=[l1, l2]) + +using Measurements +l1_loss = e.measurement[1] ± std(e.per_fold[1])/sqrt(7) +@show l1_loss +``` +We take this estimate of the uncertainty of the generalization error with a [grain of salt](https://direct.mit.edu/neco/article-abstract/10/7/1895/6224/Approximate-Statistical-Tests-for-Comparing)). + +## Comparison with other models on the test set + +Although we cannot assign them statistical significance, here are comparisons, on the untouched test set, of the eror of our self-iterating neural network regressor with a couple of other models trained on the same data (using default hyperparameters): +```julia + +function performance(model) + mach = machine(model, X, y) |> fit! + yhat = predict(mach, Xtest) + l1(yhat, ytest) |> mean +end +performance(iterated_pipe) + +three_models = [(@load EvoTreeRegressor)(), # tree boosting model + (@load LinearRegressor pkg=MLJLinearModels)(), + iterated_pipe] + +errs = performance.(three_models) + +(models=MLJ.name.(three_models), mean_square_errors=errs) |> pretty + + +``` +See also +[`NeuralNetworkRegressor`](@ref) +""" +MultitargetNeuralNetworkRegressor +""" +$(MMI.doc_header(NeuralNetworkClassifier)) + +`NeuralNetworkClassifier`: a neural network model for making probabilistic predictions +of a Multiclass or OrderedFactor target, given a table of Continuous features. ) + TODO: + +# Training data + +In MLJ or MLJBase, bind an instance `model` to data with + mach = machine(model, X, y) + +Where + +- `X`: is any table of input features (eg, a `DataFrame`) whose columns + are of scitype `Continuous`; check the scitype with `schema(X)` +- `y`: is the target, which can be any `AbstractVector` whose element + scitype is `Multiclass` or `OrderedFactor` with `n_out` classes; + check the scitype with `scitype(y)` + + +# Hyper-parameters + +- `builder=MLJFlux.Short()`: An MLJFlux builder that constructs a neural network. Possible `builders` include: `Linear`, `Short`, and `MLP`. You can construct your own builder using the `@builder` macro, see examples for further information. +- `optimiser::Flux.ADAM()`: A `Flux.Optimise` optimiser. The optimiser performs the updating of the weights of the network. For further reference, see either the examples or [the Flux optimiser documentation](https://fluxml.ai/Flux.jl/stable/training/optimisers/). To choose a learning rate (the update rate of the optimizer), a good rule of thumb is to start out at `10e-3`, and tune using powers of 10 between `1` and `1e-7`. +- `loss=Flux.crossentropy`: The loss function which the network will optimize. Should be a function which can be called in the form `loss(yhat, y)`. Possible loss functions are listed in [the Flux loss function documentation](https://fluxml.ai/Flux.jl/stable/models/losses/). For a classification task, the most natural loss functions are: + - `Flux.crossentropy`: Typically used as loss in multiclass classification, with labels in a 1-hot encoded format. + - `Flux.logitcrossentopy`: Mathematically equal to crossentropy, but computationally more numerically stable than finalising the outputs with `softmax` and then calculating crossentropy. + - `Flux.binarycrossentropy`: Typically used as loss in binary classification, with labels in a 1-hot encoded format. + - `Flux.logitbinarycrossentopy`: Mathematically equal to crossentropy, but computationally more numerically stable than finalising the outputs with `sigmoid` and then calculating binary crossentropy. + - `Flux.tversky_loss`: Used with imbalanced data to give more weight to false negatives. + - `Flux.focal_loss`: Used with highly imbalanced data. Weights harder examples more than easier examples. + - `Flux.binary_focal_loss`: Binary version of the above +- `epochs::Int=10`: The number of epochs to train for. Typically, one epoch represents one pass through the entirety of the training dataset. +- `batch_size::Int=1`: The batch size to be used for training. The batch size represents the number of samples per update of the networks weights. Typcally, batch size should be somewhere between 8 and 512. Smaller batch sizes lead to noisier training loss curves, while larger batch sizes lead towards smoother training loss curves. In general, it is a good idea to pick one fairly large batch size (e.g. 32, 64, 128), and stick with it, and only tune the learning rate. In most literature, batch size is set in powers of twos, but this is fairly arbitrary. +- `lambda::Float64=0`: The stregth of the regularization used during training. Can be any value in the range `[0, ∞)`. +- `alpha::Float64=0`: The L2/L1 mix of regularization, in the range `[0, 1]`. A value of 0 represents L2 regularization, and a value of 1 represents L1 regularization. +- `rng::Union{AbstractRNG, Int64}`: The random number generator/seed used during training. +- `optimizer_changes_trigger_retraining::Bool=false`: Defines what happens when fitting a machine if the associated optimiser has changed. If true, the associated machine will retrain from scratch on `fit`, otherwise it will not. +- `acceleration::AbstractResource=CPU1()`: Defines on what hardware training is done. For Training on GPU, use `CudaLibs()`, otherwise defaults to `CPU`()`. +- `finaliser=Flux.softmax`: The final activation function of the neural network. Defaults to `Flux.softmax`. For a classification task, `softmax` is used for multiclass, single label regression, `sigmoid` is used for either binary classification or multi label classification (when there are multiple possible labels for a given sample). + + +# Operations + +- `predict(mach, Xnew)`: return predictions of the target given new + features `Xnew` having the same Scitype as `X` above. Predictions are + probabilistic. +- `predict_mode(mach, Xnew)`: Return the modes of the probabilistic predictions + returned above. + + +# Fitted parameters + +The fields of `fitted_params(mach)` are: + +- `chain`: The trained "chain", or series of layers, functions, and activations which make up the neural network. + + +# Report + +The fields of `report(mach)` are: + +- `training_losses`: The history of training losses, a vector containing the history of all the losses during training. The first element of the vector is the initial penalized loss. After the first element, the nth element corresponds to the loss of epoch n-1. + +# Examples + +In this example we build a classification model using the Iris dataset. +```julia +using MLJ +using Flux +import RDatasets + +using Random +Random.seed!(123) + +MLJ.color_off() + +using Plots +pyplot(size=(600, 300*(sqrt(5)-1))); +``` +This is a very basic example, using a default builder and no standardization. +For a more advance illustration, see [`NeuralNetworkRegressor`](@ref) or [`ImageClassifier`](@ref). First, we can load the data: +```julia +iris = RDatasets.dataset("datasets", "iris"); +y, X = unpack(iris, ==(:Species), colname -> true, rng=123); +NeuralNetworkClassifier = @load NeuralNetworkClassifier +clf = NeuralNetworkClassifier() +``` +Next, we can train the model: +```julia +import Random.seed!; seed!(123) +mach = machine(clf, X, y) +fit!(mach) +``` +We can train the model in an incremental fashion with the `optimizer_changes_trigger_retraining` flag set to false (which is by default). Here, we change the number of iterations and the learning rate of the optimiser: +```julia +clf.optimiser.eta = clf.optimiser.eta * 2 +clf.epochs = clf.epochs + 5 + +# note that if the optimizer_changes_trigger_retraining flag was set to true +# the model would be completely retrained from scratch because the optimizer was +# updated +fit!(mach, verbosity=2); +``` +We can inspect the mean training loss using the `cross_entropy` function: +```julia + +training_loss = cross_entropy(predict(mach, X), y) |> mean + +``` +And we can access the Flux chain (model) using `fitted_params`: +```julia +chain = fitted_params(mach).chain +``` +Finally, we can see how the out-of-sample performance changes over time, using the `learning_curve` function +```julia +r = range(clf, :epochs, lower=1, upper=200, scale=:log10) +curve = learning_curve(clf, X, y, + range=r, + resampling=Holdout(fraction_train=0.7), + measure=cross_entropy) +using Plots +plot(curve.parameter_values, + curve.measurements, + xlab=curve.parameter_name, + xscale=curve.parameter_scale, + ylab = "Cross Entropy") + +savefig("iris_history.png") +``` +See also +[`ImageClassifier`](@ref) +""" +NeuralNetworkClassifier +""" +$(MMI.doc_header(ImageClassifier)) + +`ImageClassifier`: A neural network model for making probabilistic +"predictions of a `GrayImage` target, given a table of `Continuous` features. + +# Training data + +In MLJ or MLJBase, bind an instance `model` to data with +mach = machine(model, X, y) +Where +- `X`: is any `AbstractVector` of input features (eg, a `DataFrame`) whose items + are of scitype `GrayImage`; check the scitype with `scitype(X)` +- `y`: is the target, which can be any `AbstractVector` whose element + scitype is `Multiclass` or `OrderedFactor` with `n_out` classes; + check the scitype with `scitype(y)` + + +# Hyper-parameters + +- `builder=MLJFlux.Short()`: An MLJFlux builder that constructs a neural network. Possible `builders` include: `Linear`, `Short`, and `MLP`. You can construct your own builder using the `@builder` macro, see examples for further information. +- `optimiser::Flux.ADAM()`: A `Flux.Optimise` optimiser. The optimiser performs the updating of the weights of the network. For further reference, see either the examples or [the Flux optimiser documentation](https://fluxml.ai/Flux.jl/stable/training/optimisers/). To choose a learning rate (the update rate of the optimizer), a good rule of thumb is to start out at `10e-3`, and tune using powers of 10 between `1` and `1e-7`. +- `loss=Flux.crossentropy`: The loss function which the network will optimize. Should be a function which can be called in the form `loss(yhat, y)`. Possible loss functions are listed in [the Flux loss function documentation](https://fluxml.ai/Flux.jl/stable/models/losses/). For a classification task, the most natural loss functions are: + - `Flux.crossentropy`: Typically used as loss in multiclass classification, with labels in a 1-hot encoded format. + - `Flux.logitcrossentopy`: Mathematically equal to crossentropy, but computationally more numerically stable than finalising the outputs with `softmax` and then calculating crossentropy. + - `Flux.binarycrossentropy`: Typically used as loss in binary classification, with labels in a 1-hot encoded format. + - `Flux.logitbinarycrossentopy`: Mathematically equal to crossentropy, but computationally more numerically stable than finalising the outputs with `sigmoid` and then calculating binary crossentropy. + - `Flux.tversky_loss`: Used with imbalanced data to give more weight to false negatives. + - `Flux.focal_loss`: Used with highly imbalanced data. Weights harder examples more than easier examples. + - `Flux.binary_focal_loss`: Binary version of the above +- `epochs::Int=10`: The number of epochs to train for. Typically, one epoch represents one pass through the entirety of the training dataset. +- `batch_size::Int=1`: The batch size to be used for training. The batch size represents the number of samples per update of the networks weights. Typcally, batch size should be somewhere between 8 and 512. Smaller batch sizes lead to noisier training loss curves, while larger batch sizes lead towards smoother training loss curves. In general, it is a good idea to pick one fairly large batch size (e.g. 32, 64, 128), and stick with it, and only tune the learning rate. In most literature, batch size is set in powers of twos, but this is fairly arbitrary. +- `lambda::Float64=0`: The stregth of the regularization used during training. Can be any value in the range `[0, ∞)`. +- `alpha::Float64=0`: The L2/L1 mix of regularization, in the range `[0, 1]`. A value of 0 represents L2 regularization, and a value of 1 represents L1 regularization. +- `rng::Union{AbstractRNG, Int64}`: The random number generator/seed used during training. +- `optimizer_changes_trigger_retraining::Bool=false`: Defines what happens when fitting a machine if the associated optimiser has changed. If true, the associated machine will retrain from scratch on `fit`, otherwise it will not. +- `acceleration::AbstractResource=CPU1()`: Defines on what hardware training is done. For Training on GPU, use `CudaLibs()`, otherwise defaults to `CPU`()`. +- `finaliser=Flux.softmax`: The final activation function of the neural network. Defaults to `Flux.softmax`. For a regression task, reasonable alternatives include `Flux.sigmoid` and the identity function (otherwise known as "linear activation"). + + +# Operations + +- `predict(mach, Xnew)`: return predictions of the target given new + features `Xnew` having the same Scitype as `X` above. Predictions are + probabilistic. +- `predict_mode(mach, Xnew)`: Return the modes of the probabilistic predictions + returned above. + + +# Fitted parameters + +The fields of `fitted_params(mach)` are: +- `chain`: The trained "chain", or series of layers, functions, and activations which make up the neural network. + + +# Report + +The fields of `report(mach)` are: +- `training_losses`: The history of training losses, a vector containing the history of all the losses during training. The first element of the vector is the initial penalized loss. After the first element, the nth element corresponds to the loss of epoch n-1. + +# Examples + +In this example we use MLJ to classify the MNIST image dataset +```julia +using MLJ +using Flux +import MLJFlux +import MLJIteration # for `skip` + +MLJ.color_off() + +using Plots +pyplot(size=(600, 300*(sqrt(5)-1))); +``` +First we want to download the MNIST dataset, and unpack into images and labels +```julia +import MLDatasets: MNIST + +ENV["DATADEPS_ALWAYS_ACCEPT"] = true +images, labels = MNIST.traindata(); +``` +In MLJ, integers cannot be used for encoding categorical data, so we must coerce them into the `Multiclass` [scientific type](https://juliaai.github.io/ScientificTypes.jl/dev/). For more in this, see [Working with Categorical Data](https://alan-turing-institute.github.io/MLJ.jl/dev/working_with_categorical_data/): +```julia +labels = coerce(labels, Multiclass); +images = coerce(images, GrayImage); + +# Checking scientific types: + +@assert scitype(images) <: AbstractVector{<:Image} +@assert scitype(labels) <: AbstractVector{<:Finite} + +images[1] +``` +For general instructions on coercing image data, see [type coercion for image data](https://alan-turing-institute.github.io/ScientificTypes.jl/dev/%23Type-coercion-for-image-data-1) +We start by defining a suitable `builder` object. This is a recipe +for building the neural network. Our builder will work for images of +any (constant) size, whether they be color or black and white (ie, +single or multi-channel). The architecture always consists of six +alternating convolution and max-pool layers, and a final dense +layer; the filter size and the number of channels after each +convolution layer is customisable. +```julia +import MLJFlux + +struct MyConvBuilder + filter_size::Int + channels1::Int + channels2::Int + channels3::Int +end + +make2d(x::AbstractArray) = reshape(x, :, size(x)[end]) + +function MLJFlux.build(b::MyConvBuilder, rng, n_in, n_out, n_channels) + k, c1, c2, c3 = b.filter_size, b.channels1, b.channels2, b.channels3 + mod(k, 2) == 1 || error("`filter_size` must be odd. ") + p = div(k - 1, 2) # padding to preserve image size + init = Flux.glorot_uniform(rng) + front = Chain( + Conv((k, k), n_channels => c1, pad=(p, p), relu, init=init), + MaxPool((2, 2)), + Conv((k, k), c1 => c2, pad=(p, p), relu, init=init), + MaxPool((2, 2)), + Conv((k, k), c2 => c3, pad=(p, p), relu, init=init), + MaxPool((2 ,2)), + make2d) + d = Flux.outputsize(front, (n_in..., n_channels, 1)) |> first + return Chain(front, Dense(d, n_out, init=init)) +end +``` +It is important to note that in our `build` function, there is no final softmax. This is applie by default in all MLJFlux classifiers, using the `finaliser` hyperparameter of the classifier. Now that we have our builder defined, we can define the actual moel. If you have a GPU, you can substitute in `acceleration=CudaLibs()` below. Note that in the case of convolutions, this will **greatly** increase the speed of training. +```julia +ImageClassifier = @load ImageClassifier +clf = ImageClassifier(builder=MyConvBuilder(3, 16, 32, 32), + batch_size=50, + epochs=10, + rng=123) +``` +You can add flux options such as `optimiser` and `loss` in the snippet above. Currently, `loss` must be a flux-compatible loss, and not an MLJ measure. +Next, we can bind the model with the data in a machine, and fit the first 500 or so images: +```julia +mach = machine(clf, images, labels); + +fit!(mach, rows=1:500, verbosity=2); + +report(mach) + +chain = fitted_params(mach) + +Flux.params(chain)[2] +``` +We can tack on 20 more epochs by modifying the `epochs` field, and iteratively fit some more: +```julia +clf.epochs = clf.epochs + 20 +fit!(mach, rows=1:500); +``` +We can also make predictions and calculate an out-of-sample loss estimate, in two ways! +```julia +predicted_labels = predict(mach, rows=501:1000); +cross_entropy(predicted_labels, labels[501:1000]) |> mean +# alternative one liner! +evaluate!(mach, + resampling=Holdout(fraction_train=0.5), + measure=cross_entropy, + rows=1:1000, + verbosity=0) +``` + +## Wrapping in iteration controls + +Any iterative MLJFlux model can be wrapped in **iteration controls**, as we demonstrate next. For more on MLJ's `IteratedModel` wrapper, see the [MLJ documentation](https://alan-turing-institute.github.io/MLJ.jl/dev/controlling_iterative_models/). +The "self-iterating" classifier (`iterated_clf` below) is for iterating the image classifier defined above until a stopping criterion is hit. We use the following stopping criterion: +- `Patience(3)`: 3 consecutive increases in the loss +- `InvalidValue()`: an out-of-sample loss or a training loss that is `NaN` or `±Inf` +- `TimeLimit(t=5/60)`: training time has exceeded 5 minutes. +We can specify how often these checks (and other controls) are applied using the `Step` control. Additionally, we can define controls to +- save a snapshot of the machine every N control cycles (`save_control`) +- record traces of the out-of-sample loss and training losses for plotting (`WithLossDo`) +- record mean value traces of each Flux parameter for plotting (`Callback`) +And other controls. For a full list, see [the documentation](https://alan-turing-institute.github.io/MLJ.jl/dev/controlling_iterative_models/%23Controls-provided). +First, we define some helper functions and some empty vectors to store traces: +```julia +make2d(x::AbstractArray) = reshape(x, :, size(x)[end]) +make1d(x::AbstractArray) = reshape(x, length(x)); + +# to extract the flux parameters from a machine +parameters(mach) = make1d.(Flux.params(fitted_params(mach))); + +# trace storage +losses = [] +training_losses = [] +parameter_means = Float32[]; +epochs = [] + +# to update traces +update_loss(loss) = push!(losses, loss) +update_training_loss(losses) = push!(training_losses, losses[end]) +update_means(mach) = append!(parameter_means, mean.(parameters(mach))); +update_epochs(epoch) = push!(epochs, epoch) +``` +Next, we can define our controls! We store them in a simple vector: +```julia +save_control = + MLJIteration.skip(Save(joinpath(DIR, "mnist.jlso")), predicate=3) + +controls=[Step(2), + Patience(3), + InvalidValue(), + TimeLimit(5/60), + save_control, + WithLossDo(), + WithLossDo(update_loss), + WithTrainingLossesDo(update_training_loss), + Callback(update_means), + WithIterationsDo(update_epochs) +``` +Once the controls are defined, we can instantiate and fit our "self-iterating" classifier: +```julia +iterated_clf = IteratedModel(model=clf, + controls=controls, + resampling=Holdout(fraction_train=0.7), + measure=log_loss) + +mach = machine(iterated_clf, images, labels); +fit!(mach, rows=1:500); +``` +Next we can compare the training and out-of-sample losses, as well as view the evolution of the weights: +```julia +plot(epochs, losses, + xlab = "epoch", + ylab = "root squared error", + label="out-of-sample") +plot!(epochs, training_losses, label="training") + +savefig(joinpath(DIR, "loss.png")) + +n_epochs = length(losses) +n_parameters = div(length(parameter_means), n_epochs) +parameter_means2 = reshape(copy(parameter_means), n_parameters, n_epochs)' +plot(epochs, parameter_means2, + title="Flux parameter mean weights", + xlab = "epoch") +# **Note.** The the higher the number, the deeper the chain parameter. +savefig(joinpath(DIR, "weights.png")) +``` +Since we saved our model every few epochs, we can retrieve the snapshots so we can make predictions! +```julia +mach2 = machine(joinpath(DIR, "mnist3.jlso")) +predict_mode(mach2, images[501:503]) +``` + +## Resuming training + +If we change `iterated_clf.controls` or `clf.epochs`, we can resume training from where it left off. This is very useful for long-running training sessions, where you may be interrupted by for example a bad connection or computer hibernation. +```julia +iterated_clf.controls[2] = Patience(4) +fit!(mach, rows=1:500) + +plot(epochs, losses, + xlab = "epoch", + ylab = "root squared error", + label="out-of-sample") +plot!(epochs, training_losses, label="training") +``` +See also +[`NeuralNetworkClassifier`](@ref) +""" +ImageClassifier + + end #module From 387f88c4f17077be3bc647c960fca7e9399520cc Mon Sep 17 00:00:00 2001 From: David Josephs <42522233+josephsdavid@users.noreply.github.com> Date: Mon, 27 Jun 2022 16:56:33 -0500 Subject: [PATCH 2/7] First model properly indented --- src/MLJFlux.jl | 97 +++++++++++++++++++++++++++++++++++++------------- 1 file changed, 73 insertions(+), 24 deletions(-) diff --git a/src/MLJFlux.jl b/src/MLJFlux.jl index 34c7d95d..4fe7d3bd 100644 --- a/src/MLJFlux.jl +++ b/src/MLJFlux.jl @@ -94,7 +94,8 @@ Where - `acceleration::AbstractResource=CPU1()`: Defines on what hardware training is done. For training on GPU, use `CudaLibs()`, otherwise defaults to `CPU`()`. - `finaliser=Flux.softmax`: The final activation function of the neural network. - Defaults to `Flux.softmax`. For a regression task, reasonable alternatives include `Flux.sigmoid` and the identity function (otherwise known as "linear activation"). + Defaults to `Flux.softmax`. For a regression task, reasonable alternatives include + `Flux.sigmoid` and the identity function (otherwise known as "linear activation"). # Operations @@ -108,7 +109,8 @@ Where The fields of `fitted_params(mach)` are: -- `chain`: The trained "chain", or series of layers, functions, and activations which make up the neural network. +- `chain`: The trained "chain", or series of layers, functions, and activations which + make up the neural network. # Report @@ -116,7 +118,9 @@ The fields of `fitted_params(mach)` are: The fields of `report(mach)` are: - `training_losses`: The history of training losses, a vector containing the history of all the losses during training. The first element of the vector is the initial penalized loss. After the first element, the nth element corresponds to the loss of epoch n-1. - + all the losses during training. The first element of the vector is the initial penalized loss. After the first element, the nth element corresponds to the loss of epoch n-1. + penalized loss. After the first element, the nth element corresponds to the loss of epoch n-1. + epoch n-1. # Examples In this example we build a regression model using the Boston house price dataset @@ -135,7 +139,7 @@ y, X = unpack(data, ==(:MEDV), !=(:CHAS); rng=123); scitype(y) schema(X) ``` -Since MLJFlux models do not handle ordered factos, we can treat `:RAD` as `Continuous`: +Since MLJFlux models do not handle ordered factors, we can treat `:RAD` as `Continuous`: ```julia X = coerce(X, :RAD=>Continuous) ``` @@ -144,6 +148,8 @@ Lets also make a test set: (X, Xtest), (y, ytest) = partition((X, y), 0.7, multi=true); ``` Next, we can define a `builder`. In the following macro call, `n_in` is the number of expected input features, and rng is a RNG. `init` is the function used to generate the random initial weights of the network. +expected input features, and rng is a RNG. `init` is the function used to generate the random initial weights of the network. +random initial weights of the network. ```julia builder = MLJFlux.@builder begin init=Flux.glorot_uniform(rng) @@ -160,11 +166,14 @@ NeuralNetworkRegressor = @load NeuralNetworkRegressor epochs=20) ``` For our neural network, since different features likely have different scales, if we do not standardize the network may be implicitly biased towards features with higher magnitudes, or may have [saturated neurons](https://www.informit.com/articles/article.aspx%3fp=3131594&seqNum=2) and not train well. Therefore, standardization is key! +not standardize the network may be implicitly biased towards features with higher magnitudes, or may have [saturated neurons](https://www.informit.com/articles/article.aspx%3fp=3131594&seqNum=2) and not train well. Therefore, standardization is key! +magnitudes, or may have [saturated neurons](https://www.informit.com/articles/article.aspx%3fp=3131594&seqNum=2) and not train well. Therefore, standardization is key! +neurons](https://www.informit.com/articles/article.aspx%3fp=3131594&seqNum=2) and not train well. Therefore, standardization is key! ```julia pipe = Standardizer |> TransformedTargetModel(model, target=Standardizer) ``` If we fit with a high verbosity (>1), we will see the losses during training. We can also see the losses in the output of `report(mach)` - +also see the losses in the output of `report(mach)` ```julia mach = machine(pipe, X, y) fit!(mach, verbosity=2) @@ -199,6 +208,8 @@ pipe.transformed_target_model_deterministic.model.optimiser.eta = 0.0001 ## Using Iteration Controls We can also wrap the model with MLJ Iteration controls. Suppose we want a model that trains until the out of sample loss does not improve for 6 epochs. We can use the `NumberSinceBest(6)` stopping criterion. We can also add some extra stopping criterion, `InvalidValue` and `Timelimit(1/60)`, as well as some controls to print traces of the losses. First we can define some methods to initialize or clear the traces as well as updte the traces. +trains until the out of sample loss does not improve for 6 epochs. We can use the `NumberSinceBest(6)` stopping criterion. We can also add some extra stopping criterion, `InvalidValue` and `Timelimit(1/60)`, as well as some controls to print traces of the losses. First we can define some methods to initialize or clear the traces as well as update the traces. +`NumberSinceBest(6)` stopping criterion. We can also add some extra stopping criterion, `InvalidValue` and `Timelimit(1/60)`, as well as some controls to print traces of the losses. First we can define some methods to initialize or clear the traces as well as update the traces. ```julia # For initializing or clearing the traces: @@ -252,7 +263,12 @@ savefig(joinpath("assets", "loss.png")) ### Brief note on iterated models -Training an `IteratedModel` means holding out some data (80% in this case) so an out-of-sample loss can be tracked and used in the specified stopping criterion, `NumberSinceBest(4)`. However, once the stop is triggered, the model wrapped by `IteratedModel` (our pipeline model) is retrained on all data for the same number of iterations. Calling `predict(mach, Xnew)` on new data uses the updated learned parameters. +Training an `IteratedModel` means holding out some data (80% in this case) so an +out-of-sample loss can be tracked and used in the specified stopping criterion, +`NumberSinceBest(4)`. However, once the stop is triggered, the model wrapped by +`IteratedModel` (our pipeline model) is retrained on all data for the same number of +iterations. Calling `predict(mach, Xnew)` on new data uses the updated learned +parameters. ## Evaluating Iterated Models @@ -266,11 +282,14 @@ using Measurements l1_loss = e.measurement[1] ± std(e.per_fold[1])/sqrt(7) @show l1_loss ``` -We take this estimate of the uncertainty of the generalization error with a [grain of salt](https://direct.mit.edu/neco/article-abstract/10/7/1895/6224/Approximate-Statistical-Tests-for-Comparing)). +We take this estimate of the uncertainty of the generalization error with a [grain of +salt](https://direct.mit.edu/neco/article-abstract/10/7/1895/6224/Approximate-Statistical-Tests-for-Comparing)). ## Comparison with other models on the test set -Although we cannot assign them statistical significance, here are comparisons, on the untouched test set, of the eror of our self-iterating neural network regressor with a couple of other models trained on the same data (using default hyperparameters): +Although we cannot assign them statistical significance, here are comparisons, on the +untouched test set, of the eror of our self-iterating neural network regressor with a +couple of other models trained on the same data (using default hyperparameters): ```julia function performance(model) mach = machine(model, X, y) |> fit! @@ -297,7 +316,8 @@ NeuralNetworkRegressor $(MMI.doc_header(MultitargetNeuralNetworkRegressor)) `MultitargetNeuralNetworkRegressor`: A neural network model for making deterministic -predictions of a `Continuous` multi-target, presented as a table, given a table of `Continuous` features. +predictions of a `Continuous` multi-target, presented as a table, given a table of +`Continuous` features. # Training data @@ -314,22 +334,47 @@ Where # Hyper-parameters -- `builder=MLJFlux.Linear(σ=Flux.relu)`: An MLJFlux builder that constructs a neural network. Possible `builders` include: `Linear`, `Short`, and `MLP`. You can construct your own builder using the `@builder` macro, see examples for further information. -- `optimiser::Flux.ADAM()`: A `Flux.Optimise` optimiser. The optimiser performs the updating of the weights of the network. For further reference, see either the examples or [the Flux optimiser documentation](https://fluxml.ai/Flux.jl/stable/training/optimisers/). To choose a learning rate (the update rate of the optimizer), a good rule of thumb is to start out at `10e-3`, and tune using powers of 10 between `1` and `1e-7`. -- `loss=Flux.mse`: The loss function which the network will optimize. Should be a function which can be called in the form `loss(yhat, y)`. Possible loss functions are listed in [the Flux loss function documentation](https://fluxml.ai/Flux.jl/stable/models/losses/). For a regression task, the most natural loss functions are: +- `builder=MLJFlux.Linear(σ=Flux.relu)`: An MLJFlux builder that constructs a neural + network. Possible `builders` include: `Linear`, `Short`, and `MLP`. You can construct + your own builder using the `@builder` macro, see examples for further information. +- `optimiser::Flux.ADAM()`: A `Flux.Optimise` optimiser. The optimiser performs the + updating of the weights of the network. For further reference, see either the examples + or [the Flux optimiser + documentation](https://fluxml.ai/Flux.jl/stable/training/optimisers/). To choose a + learning rate (the update rate of the optimizer), a good rule of thumb is to start out + at `10e-3`, and tune using powers of 10 between `1` and `1e-7`. +- `loss=Flux.mse`: The loss function which the network will optimize. Should be a + function which can be called in the form `loss(yhat, y)`. Possible loss functions are + listed in [the Flux loss function + documentation](https://fluxml.ai/Flux.jl/stable/models/losses/). For a regression task, + the most natural loss functions are: - `Flux.mse` - `Flux.mae` - `Flux.msle` - `Flux.huber_loss` -- `epochs::Int=10`: The number of epochs to train for. Typically, one epoch represents one pass through the entirety of the training dataset. -- `batch_size::Int=1`: The batch size to be used for training. The batch size represents the number of samples per update of the networks weights. Typcally, batch size should be somewhere between 8 and 512. Smaller batch sizes lead to noisier training loss curves, while larger batch sizes lead towards smoother training loss curves. In general, it is a good idea to pick one fairly large batch size (e.g. 32, 64, 128), and stick with it, and only tune the learning rate. In most literature, batch size is set in powers of twos, but this is fairly arbitrary. -- `lambda::Float64=0`: The stregth of the regularization used during training. Can be any value in the range `[0, ∞)`. -- `alpha::Float64=0`: The L2/L1 mix of regularization, in the range `[0, 1]`. A value of 0 represents L2 regularization, and a value of 1 represents L1 regularization. -- `rng::Union{AbstractRNG, Int64}`: The random number generator/seed used during training. -- `optimizer_changes_trigger_retraining::Bool=false`: Defines what happens when fitting a machine if the associated optimiser has changed. If true, the associated machine will retrain from scratch on `fit`, otherwise it will not. -- `acceleration::AbstractResource=CPU1()`: Defines on what hardware training is done. For Training on GPU, use `CudaLibs()`, otherwise defaults to `CPU`()`. -- `finaliser=Flux.softmax`: The final activation function of the neural network. Defaults to `Flux.softmax`. For a regression task, reasonable alternatives include `Flux.sigmoid` and the identity function (otherwise known as "linear activation"). - +- `epochs::Int=10`: The number of epochs to train for. Typically, one epoch represents + one pass through the entirety of the training dataset. +- `batch_size::Int=1`: The batch size to be used for training. The batch size represents + the number of samples per update of the networks weights. Typcally, batch size should be + somewhere between 8 and 512. Smaller batch sizes lead to noisier training loss curves, + while larger batch sizes lead towards smoother training loss curves. In general, it is a + good idea to pick one fairly large batch size (e.g. 32, 64, 128), and stick with it, and + only tune the learning rate. In most literature, batch size is set in powers of twos, + but this is fairly arbitrary. +- `lambda::Float64=0`: The stregth of the regularization used during training. Can be + any value in the range `[0, ∞)`. +- `alpha::Float64=0`: The L2/L1 mix of regularization, in the range `[0, 1]`. A value of + 0 represents L2 regularization, and a value of 1 represents L1 regularization. +- `rng::Union{AbstractRNG, Int64}`: The random number generator/seed used during + training. +- `optimizer_changes_trigger_retraining::Bool=false`: Defines what happens when fitting + a machine if the associated optimiser has changed. If true, the associated machine will + retrain from scratch on `fit`, otherwise it will not. +- `acceleration::AbstractResource=CPU1()`: Defines on what hardware training is done. + For Training on GPU, use `CudaLibs()`, otherwise defaults to `CPU`()`. +- `finaliser=Flux.softmax`: The final activation function of the neural network. +Defaults to `Flux.softmax`. For a regression task, reasonable alternatives include +`Flux.sigmoid` and the identity function (otherwise known as "linear activation"). # Operations @@ -342,18 +387,22 @@ Where The fields of `fitted_params(mach)` are: -- `chain`: The trained "chain", or series of layers, functions, and activations which make up the neural network. +- `chain`: The trained "chain", or series of layers, functions, and activations which + make up the neural network. # Report The fields of `report(mach)` are: -- `training_losses`: The history of training losses, a vector containing the history of all the losses during training. The first element of the vector is the initial penalized loss. After the first element, the nth element corresponds to the loss of epoch n-1. +- `training_losses`: The history of training losses, a vector containing the history of + all the losses during training. The first element of the vector is the initial + penalized loss. After the first element, the nth element corresponds to the loss of + epoch n-1. # Examples -In this example we build a regression model using the Boston house price dataset. +In this example we build a regression model using a toy dataset. ```julia using MLJ using MLJFlux From 0b999ceb43141adb79295491a8584f1982503277 Mon Sep 17 00:00:00 2001 From: josephsdavid Date: Mon, 27 Jun 2022 14:59:11 -0500 Subject: [PATCH 3/7] wip --- nn.md | 245 ++++++++++++++++++++++++++++++++++++++++++ nnc.md | 128 ++++++++++++++++++++++ nnclassif.norg | 148 +++++++++++++++++++++++++ nnm.md | 247 ++++++++++++++++++++++++++++++++++++++++++ nnregressor.norg | 273 +++++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 1041 insertions(+) create mode 100644 nn.md create mode 100644 nnc.md create mode 100644 nnclassif.norg create mode 100644 nnm.md create mode 100644 nnregressor.norg diff --git a/nn.md b/nn.md new file mode 100644 index 00000000..46641a33 --- /dev/null +++ b/nn.md @@ -0,0 +1,245 @@ +# NeuralNetworkRegressor + +`NeuralNetworkRegressor`: A neural network model for making deterministic +predictions of a `Continuous` target, given a table of `Continuous` features. + +# Training data + +In MLJ or MLJBase, bind an instance `model` to data with +mach = machine(model, X, y) +Where +- `X`: is any table of input features (eg, a `DataFrame`) whose columns + are of scitype `Continuous`; check the scitype with `schema(X)` +- `y`: is the target, which can be any `AbstractVector` whose element + scitype is `Continuous`; check the scitype with `scitype(y)` + + +# Hyper-parameters + +- `builder=MLJFlux.Linear(σ=Flux.relu)`: An MLJFlux builder that constructs a neural network. Possible `builders` include: `Linear`, `Short`, and `MLP`. You can construct your own builder using the `@builder` macro, see examples for further information. +- `optimiser::Flux.ADAM()`: A `Flux.Optimise` optimiser. The optimiser performs the updating of the weights of the network. For further reference, see either the examples or [the Flux optimiser documentation](https://fluxml.ai/Flux.jl/stable/training/optimisers/). To choose a learning rate (the update rate of the optimizer), a good rule of thumb is to start out at `10e-3`, and tune using powers of 10 between `1` and `1e-7`. +- `loss=Flux.mse`: The loss function which the network will optimize. Should be a function which can be called in the form `loss(yhat, y)`. Possible loss functions are listed in [the Flux loss function documentation](https://fluxml.ai/Flux.jl/stable/models/losses/). For a regression task, the most natural loss functions are: + - `Flux.mse` + - `Flux.mae` + - `Flux.msle` + - `Flux.huber_loss` +- `epochs::Int=10`: The number of epochs to train for. Typically, one epoch represents one pass through the entirety of the training dataset. +- `batch_size::Int=1`: The batch size to be used for training. The batch size represents the number of samples per update of the networks weights. Typcally, batch size should be somewhere between 8 and 512. Smaller batch sizes lead to noisier training loss curves, while larger batch sizes lead towards smoother training loss curves. In general, it is a good idea to pick one fairly large batch size (e.g. 32, 64, 128), and stick with it, and only tune the learning rate. In most literature, batch size is set in powers of twos, but this is fairly arbitrary. +- `lambda::Float64=0`: The stregth of the regularization used during training. Can be any value in the range `[0, ∞)`. +- `alpha::Float64=0`: The L2/L1 mix of regularization, in the range `[0, 1]`. A value of 0 represents L2 regularization, and a value of 1 represents L1 regularization. +- `rng::Union{AbstractRNG, Int64}`: The random number generator/seed used during training. +- `optimizer_changes_trigger_retraining::Bool=false`: Defines what happens when fitting a machine if the associated optimiser has changed. If true, the associated machine will retrain from scratch on `fit`, otherwise it will not. +- `acceleration::AbstractResource=CPU1()`: Defines on what hardware training is done. For Training on GPU, use `CudaLibs()`, otherwise defaults to `CPU`()`. +- `finaliser=Flux.softmax`: The final activation function of the neural network. Defaults to `Flux.softmax`. For a regression task, reasonable alternatives include `Flux.sigmoid` and the identity function (otherwise known as "linear activation"). + + +# Operations + +- `predict(mach, Xnew)`: return predictions of the target given new + features `Xnew` having the same Scitype as `X` above. Predictions are + deterministic. + + +# Fitted parameters + +The fields of `fitted_params(mach)` are: +- `chain`: The trained "chain", or series of layers, functions, and activations which make up the neural network. + + +# Report + +The fields of `report(mach)` are: +- `training_losses`: The history of training losses, a vector containing the history of all the losses during training. The first element of the vector is the initial penalized loss. After the first element, the nth element corresponds to the loss of epoch n-1. + +# Examples + +In this example we build a regression model using the Boston house price dataset +```julia + + using MLJ + using MLJFlux + using Flux + using Plots + +``` +First, we load in the data, with target `:MEDV`. We load in all features except `:CHAS`: +```julia + + data = OpenML.load(531); # Loads from https://www.openml.org/d/531 + + y, X = unpack(data, ==(:MEDV), !=(:CHAS); rng=123); + + scitype(y) + schema(X) + +``` +Since MLJFlux models do not handle ordered factos, we can treat `:RAD` as `Continuous`: +```julia +X = coerce(X, :RAD=>Continuous) +``` +Lets also make a test set: +```julia + + (X, Xtest), (y, ytest) = partition((X, y), 0.7, multi=true); + +``` +Next, we can define a `builder`. In the following macro call, `n_in` is the number of expected input features, and rng is a RNG. `init` is the function used to generate the random initial weights of the network. +```julia +builder = MLJFlux.@builder begin + init=Flux.glorot_uniform(rng) + Chain(Dense(n_in, 64, relu, init=init), + Dense(64, 32, relu, init=init), + Dense(32, 1, init=init)) + end +``` +Finally, we can define the model! +```julia + + NeuralNetworkRegressor = @load NeuralNetworkRegressor + model = NeuralNetworkRegressor(builder=builder, + rng=123, + epochs=20) +``` +For our neural network, since different features likely have different scales, if we do not standardize the network may be implicitly biased towards features with higher magnitudes, or may have [saturated neurons](https://www.informit.com/articles/article.aspx%3fp=3131594&seqNum=2) and not train well. Therefore, standardization is key! +```julia +pipe = Standardizer |> TransformedTargetModel(model, target=Standardizer) +``` +If we fit with a high verbosity (>1), we will see the losses during training. We can also see the losses in the output of `report(mach)` + +```julia +mach = machine(pipe, X, y) + fit!(mach, verbosity=2) + + # first element initial loss, 2:end per epoch training losses + report(mach).transformed_target_model_deterministic.training_losses + +``` + +## Experimenting with learning rate + +We can visually compare how the learning rate affects the predictions: +```julia +plt = plot() + + rates = 10. .^ (-5:0) + + foreach(rates) do η + pipe.transformed_target_model_deterministic.model.optimiser.eta = η + fit!(mach, force=true, verbosity=0) + losses = + report(mach).transformed_target_model_deterministic.model.training_losses[3:end] + plot!(1:length(losses), losses, label=η) + end + plt #!md + + savefig(joinpath("assets", "learning_rate.png")) + + + pipe.transformed_target_model_deterministic.model.optimiser.eta = 0.0001 + +``` + +## Using Iteration Controls + +We can also wrap the model with MLJ Iteration controls. Suppose we want a model that trains until the out of sample loss does not improve for 6 epochs. We can use the `NumberSinceBest(6)` stopping criterion. We can also add some extra stopping criterion, `InvalidValue` and `Timelimit(1/60)`, as well as some controls to print traces of the losses. First we can define some methods to initialize or clear the traces as well as updte the traces. +```julia + + # For initializing or clearing the traces: + + clear() = begin + global losses = [] + global training_losses = [] + global epochs = [] + return nothing + end + + # And to update the traces: + + update_loss(loss) = push!(losses, loss) + update_training_loss(report) = + push!(training_losses, + report.transformed_target_model_deterministic.model.training_losses[end]) + update_epochs(epoch) = push!(epochs, epoch) + +``` +For further reference of controls, see [the documentation](https://alan-turing-institute.github.io/MLJ.jl/dev/controlling_iterative_models/%23Controls-provided). To apply the controls, we simply stack them in a vector and then make an `IteratedModel`: +```julia + + controls=[Step(1), + NumberSinceBest(6), + InvalidValue(), + TimeLimit(1/60), + WithLossDo(update_loss), + WithReportDo(update_training_loss), + WithIterationsDo(update_epochs)] + + + iterated_pipe = + IteratedModel(model=pipe, + controls=controls, + resampling=Holdout(fraction_train=0.8), + measure = l2) + +``` +Next, we can clear the traces, fit the model, and plot the traces: +```julia + + + clear() + mach = machine(iterated_pipe, X, y) + fit!(mach) + + plot(epochs, losses, + xlab = "epoch", + ylab = "mean sum of squares error", + label="out-of-sample", + legend = :topleft); + scatter!(twinx(), epochs, training_losses, label="training", color=:red) #!md + + savefig(joinpath("assets", "loss.png")) +``` + +### Brief note on iterated models + +Training an `IteratedModel` means holding out some data (80% in this case) so an out-of-sample loss can be tracked and used in the specified stopping criterion, `NumberSinceBest(4)`. However, once the stop is triggered, the model wrapped by `IteratedModel` (our pipeline model) is retrained on all data for the same number of iterations. Calling `predict(mach, Xnew)` on new data uses the updated learned parameters. + +## Evaluating Iterated Models + +We can evaluate our model with the `evaluate!` function: +```julia + + e = evaluate!(mach, + resampling=CV(nfolds=8), + measures=[l1, l2]) + +#- + + using Measurements + l1_loss = e.measurement[1] ± std(e.per_fold[1])/sqrt(7) + @show l1_loss + +``` +We take this estimate of the uncertainty of the generalization error with a [grain of salt](https://direct.mit.edu/neco/article-abstract/10/7/1895/6224/Approximate-Statistical-Tests-for-Comparing)). + +## Comparison with other models on the test set + +Although we cannot assign them statistical significance, here are comparisons, on the untouched test set, of the eror of our self-iterating neural network regressor with a couple of other models trained on the same data (using default hyperparameters): +```julia + + function performance(model) + mach = machine(model, X, y) |> fit! + yhat = predict(mach, Xtest) + l1(yhat, ytest) |> mean + end + performance(iterated_pipe) + + three_models = [(@load EvoTreeRegressor)(), # tree boosting model + (@load LinearRegressor pkg=MLJLinearModels)(), + iterated_pipe] + + errs = performance.(three_models) + + (models=MLJ.name.(three_models), mean_square_errors=errs) |> pretty + + +``` diff --git a/nnc.md b/nnc.md new file mode 100644 index 00000000..a14f2ecc --- /dev/null +++ b/nnc.md @@ -0,0 +1,128 @@ +# NeuralNetworkClassifier + +`NeuralNetworkClassifier`: +- TODO + +# Training data + +In MLJ or MLJBase, bind an instance `model` to data with +mach = machine(model, X, y) +Where +- `X`: is any table of input features (eg, a `DataFrame`) whose columns + are of scitype `Continuous`; check the scitype with `schema(X)` +- `y`: is the target, which can be any `AbstractVector` whose element + scitype is `Finite` with `n_out` classes; check the scitype with `scitype(y)` + + +# Hyper-parameters + +- `builder=MLJFlux.Short()`: An MLJFlux builder that constructs a neural network. Possible `builders` include: `Linear`, `Short`, and `MLP`. You can construct your own builder using the `@builder` macro, see examples for further information. +- `optimiser::Flux.ADAM()`: A `Flux.Optimise` optimiser. The optimiser performs the updating of the weights of the network. For further reference, see either the examples or [the Flux optimiser documentation](https://fluxml.ai/Flux.jl/stable/training/optimisers/). To choose a learning rate (the update rate of the optimizer), a good rule of thumb is to start out at `10e-3`, and tune using powers of 10 between `1` and `1e-7`. +- `loss=Flux.crossentropy`: The loss function which the network will optimize. Should be a function which can be called in the form `loss(yhat, y)`. Possible loss functions are listed in [the Flux loss function documentation](https://fluxml.ai/Flux.jl/stable/models/losses/). For a classification task, the most natural loss functions are: + - `Flux.crossentropy`: Typically used as loss in multiclass classification, with labels in a 1-hot encoded format. + - `Flux.logitcrossentopy`: Mathematically equal to crossentropy, but computationally more numerically stable than finalising the outputs with `softmax` and then calculating crossentropy. + - `Flux.binarycrossentropy`: Typically used as loss in binary classification, with labels in a 1-hot encoded format. + - `Flux.logitbinarycrossentopy`: Mathematically equal to crossentropy, but computationally more numerically stable than finalising the outputs with `sigmoid` and then calculating binary crossentropy. + - `Flux.tversky_loss`: Used with imbalanced data to give more weight to false negatives. + - `Flux.focal_loss`: Used with highly imbalanced data. Weights harder examples more than easier examples. + - `Flux.binary_focal_loss`: Binary version of the above +- `epochs::Int=10`: The number of epochs to train for. Typically, one epoch represents one pass through the entirety of the training dataset. +- `batch_size::Int=1`: The batch size to be used for training. The batch size represents the number of samples per update of the networks weights. Typcally, batch size should be somewhere between 8 and 512. Smaller batch sizes lead to noisier training loss curves, while larger batch sizes lead towards smoother training loss curves. In general, it is a good idea to pick one fairly large batch size (e.g. 32, 64, 128), and stick with it, and only tune the learning rate. In most literature, batch size is set in powers of twos, but this is fairly arbitrary. +- `lambda::Float64=0`: The stregth of the regularization used during training. Can be any value in the range `[0, ∞)`. +- `alpha::Float64=0`: The L2/L1 mix of regularization, in the range `[0, 1]`. A value of 0 represents L2 regularization, and a value of 1 represents L1 regularization. +- `rng::Union{AbstractRNG, Int64}`: The random number generator/seed used during training. +- `optimizer_changes_trigger_retraining::Bool=false`: Defines what happens when fitting a machine if the associated optimiser has changed. If true, the associated machine will retrain from scratch on `fit`, otherwise it will not. +- `acceleration::AbstractResource=CPU1()`: Defines on what hardware training is done. For Training on GPU, use `CudaLibs()`, otherwise defaults to `CPU`()`. +- `finaliser=Flux.softmax`: The final activation function of the neural network. Defaults to `Flux.softmax`. For a regression task, reasonable alternatives include `Flux.sigmoid` and the identity function (otherwise known as "linear activation"). + + +# Operations + +- `predict(mach, Xnew)`: return predictions of the target given new + features `Xnew` having the same Scitype as `X` above. Predictions are + probabilistic. +- `predict_mode(mach, Xnew)`: Return the modes of the probabilistic predictions + returned above. + + +# Fitted parameters + +The fields of `fitted_params(mach)` are: +- `chain`: The trained "chain", or series of layers, functions, and activations which make up the neural network. + + +# Report + +The fields of `report(mach)` are: +- `training_losses`: The history of training losses, a vector containing the history of all the losses during training. The first element of the vector is the initial penalized loss. After the first element, the nth element corresponds to the loss of epoch n-1. + +# Examples + +In this example we build a classification model using the Iris dataset. +```julia + + using MLJ + using Flux + import RDatasets + + using Random + Random.seed!(123) + + MLJ.color_off() + + using Plots + pyplot(size=(600, 300*(sqrt(5)-1))); + +``` +This is a very basic example, using a default builder and no standardization. +For a more advance illustration, see [`NeuralNetworkRegressor`](@ref) or [`ImageClassifier`](@ref). First, we can load the data: +```julia + + iris = RDatasets.dataset("datasets", "iris"); + y, X = unpack(iris, ==(:Species), colname -> true, rng=123); + NeuralNetworkClassifier = @load NeuralNetworkClassifier + clf = NeuralNetworkClassifier() + +``` +Next, we can train the model: +```julia +import Random.seed!; seed!(123) + mach = machine(clf, X, y) + fit!(mach) +``` +We can train the model in an incremental fashion with the `optimizer_changes_trigger_retraining` flag set to false (which is by default). Here, we change the number of iterations and the learning rate of the optimiser: +```julia +clf.optimiser.eta = clf.optimiser.eta * 2 + clf.epochs = clf.epochs + 5 + + # note that if the optimizer_changes_trigger_retraining flag was set to true + # the model would be completely retrained from scratch because the optimizer was + # updated + fit!(mach, verbosity=2); +``` +We can inspect the mean training loss using the `cross_entropy` function: +```julia + + training_loss = cross_entropy(predict(mach, X), y) |> mean + +``` +And we can access the Flux chain (model) using `fitted_params`: +```julia +training_loss = cross_entropy(predict(mach, X), y) |> mean +``` +Finally, we can see how the out-of-sample performance changes over time, using the `learning_curve` function +```julia +r = range(clf, :epochs, lower=1, upper=200, scale=:log10) + curve = learning_curve(clf, X, y, + range=r, + resampling=Holdout(fraction_train=0.7), + measure=cross_entropy) + using Plots + plot(curve.parameter_values, + curve.measurements, + xlab=curve.parameter_name, + xscale=curve.parameter_scale, + ylab = "Cross Entropy") + + savefig("iris_history.png") +``` diff --git a/nnclassif.norg b/nnclassif.norg new file mode 100644 index 00000000..25ba3847 --- /dev/null +++ b/nnclassif.norg @@ -0,0 +1,148 @@ +* NeuralNetworkClassifier + + `NeuralNetworkClassifier`: + - [ ] TODO + +* Training data + + In MLJ or MLJBase, bind an instance `model` to data with + + mach = machine(model, X, y) + + Where + + - `X`: is any table of input features (eg, a `DataFrame`) whose columns + are of scitype `Continuous`; check the scitype with `schema(X)` + + - `y`: is the target, which can be any `AbstractVector` whose element + scitype is `Finite` with `n_out` classes; check the scitype with `scitype(y)` + + +* Hyper-parameters + + - `builder=MLJFlux.Short()`: An MLJFlux builder that constructs a neural network. Possible `builders` include: `Linear`, `Short`, and `MLP`. You can construct your own builder using the `@builder` macro, see examples for further information. + - `optimiser::Flux.ADAM()`: A `Flux.Optimise` optimiser. The optimiser performs the updating of the weights of the network. For further reference, see either the examples or {https://fluxml.ai/Flux.jl/stable/training/optimisers/}[the Flux optimiser documentation]. To choose a learning rate (the update rate of the optimizer), a good rule of thumb is to start out at `10e-3`, and tune using powers of 10 between `1` and `1e-7`. + - `loss=Flux.crossentropy`: The loss function which the network will optimize. Should be a function which can be called in the form `loss(yhat, y)`. Possible loss functions are listed in {https://fluxml.ai/Flux.jl/stable/models/losses/}[the Flux loss function documentation]. For a classification task, the most natural loss functions are: + -- `Flux.crossentropy`: Typically used as loss in multiclass classification, with labels in a 1-hot encoded format. + -- `Flux.logitcrossentopy`: Mathematically equal to crossentropy, but computationally more numerically stable than finalising the outputs with `softmax` and then calculating crossentropy. + -- `Flux.binarycrossentropy`: Typically used as loss in binary classification, with labels in a 1-hot encoded format. + -- `Flux.logitbinarycrossentopy`: Mathematically equal to crossentropy, but computationally more numerically stable than finalising the outputs with `sigmoid` and then calculating binary crossentropy. + -- `Flux.tversky_loss`: Used with imbalanced data to give more weight to false negatives. + -- `Flux.focal_loss`: Used with highly imbalanced data. Weights harder examples more than easier examples. + -- `Flux.binary_focal_loss`: Binary version of the above + - `epochs::Int=10`: The number of epochs to train for. Typically, one epoch represents one pass through the entirety of the training dataset. + - `batch_size::Int=1`: The batch size to be used for training. The batch size represents the number of samples per update of the networks weights. Typcally, batch size should be somewhere between 8 and 512. Smaller batch sizes lead to noisier training loss curves, while larger batch sizes lead towards smoother training loss curves. In general, it is a good idea to pick one fairly large batch size (e.g. 32, 64, 128), and stick with it, and only tune the learning rate. In most literature, batch size is set in powers of twos, but this is fairly arbitrary. + - `lambda::Float64=0`: The stregth of the regularization used during training. Can be any value in the range `[0, ∞)`. + - `alpha::Float64=0`: The L2/L1 mix of regularization, in the range `[0, 1]`. A value of 0 represents L2 regularization, and a value of 1 represents L1 regularization. + - `rng::Union{AbstractRNG, Int64}`: The random number generator/seed used during training. + - `optimizer_changes_trigger_retraining::Bool=false`: Defines what happens when fitting a machine if the associated optimiser has changed. If true, the associated machine will retrain from scratch on `fit`, otherwise it will not. + - `acceleration::AbstractResource=CPU1()`: Defines on what hardware training is done. For Training on GPU, use `CudaLibs()`, otherwise defaults to `CPU`()`. + - `finaliser=Flux.softmax`: The final activation function of the neural network. Defaults to `Flux.softmax`. For a regression task, reasonable alternatives include `Flux.sigmoid` and the identity function (otherwise known as "linear activation"). + + +* Operations + + - `predict(mach, Xnew)`: return predictions of the target given new + features `Xnew` having the same Scitype as `X` above. Predictions are + probabilistic. + - `predict_mode(mach, Xnew)`: Return the modes of the probabilistic predictions + returned above. + + + +* Fitted parameters + + The fields of `fitted_params(mach)` are: + + - `chain`: The trained "chain", or series of layers, functions, and activations which make up the neural network. + + +* Report + + The fields of `report(mach)` are: + + - `training_losses`: The history of training losses, a vector containing the history of all the losses during training. The first element of the vector is the initial penalized loss. After the first element, the nth element corresponds to the loss of epoch $n-1$. + +* Examples + + In this example we build a classification model using the Iris dataset. + + @code julia + + using MLJ + using Flux + import RDatasets + + using Random + Random.seed!(123) + + MLJ.color_off() + + using Plots + pyplot(size=(600, 300*(sqrt(5)-1))); + + @end + + This is a very basic example, using a default builder and no standardization. + For a more advance illustration, see [`NeuralNetworkRegressor`](@ref) or [`ImageClassifier`](@ref). First, we can load the data: + + @code julia + + iris = RDatasets.dataset("datasets", "iris"); + y, X = unpack(iris, ==(:Species), colname -> true, rng=123); + NeuralNetworkClassifier = @load NeuralNetworkClassifier + clf = NeuralNetworkClassifier() + + @end + + Next, we can train the model: + @code julia + import Random.seed!; seed!(123) + mach = machine(clf, X, y) + fit!(mach) + @end + + We can train the model in an incremental fashion with the `optimizer_changes_trigger_retraining` flag set to false (which is by default). Here, we change the number of iterations and the learning rate of the optimiser: + + @code julia + clf.optimiser.eta = clf.optimiser.eta * 2 + clf.epochs = clf.epochs + 5 + + # note that if the optimizer_changes_trigger_retraining flag was set to true + # the model would be completely retrained from scratch because the optimizer was + # updated + fit!(mach, verbosity=2); + @end + + We can inspect the mean training loss using the `cross_entropy` function: + + @code julia + + training_loss = cross_entropy(predict(mach, X), y) |> mean + + @end + + And we can access the Flux chain (model) using `fitted_params`: + + @code julia + training_loss = cross_entropy(predict(mach, X), y) |> mean + @end + + Finally, we can see how the out-of-sample performance changes over time, using the `learning_curve` function + + @code julia + r = range(clf, :epochs, lower=1, upper=200, scale=:log10) + curve = learning_curve(clf, X, y, + range=r, + resampling=Holdout(fraction_train=0.7), + measure=cross_entropy) + using Plots + plot(curve.parameter_values, + curve.measurements, + xlab=curve.parameter_name, + xscale=curve.parameter_scale, + ylab = "Cross Entropy") + + savefig("iris_history.png") + @end + diff --git a/nnm.md b/nnm.md new file mode 100644 index 00000000..5c0234dc --- /dev/null +++ b/nnm.md @@ -0,0 +1,247 @@ +# MultitargetNeuralNetworkRegressor + +`MultitargetNeuralNetworkRegressor`: A neural network model for making deterministic +predictions of a `Continuous` multi-target, presented as a table, given a table of `Continuous` features. + +# Training data + +In MLJ or MLJBase, bind an instance `model` to data with +mach = machine(model, X, y) +Where +- `X`: is any table of input features (eg, a `DataFrame`) whose columns + are of scitype `Continuous`; check the scitype with `schema(X)` +- `y`: is the target, which can be any table of output targets whose element + scitype is `Continuous`; check the scitype with `schema(y)` + + +# Hyper-parameters + +- `builder=MLJFlux.Linear(σ=Flux.relu)`: An MLJFlux builder that constructs a neural network. Possible `builders` include: `Linear`, `Short`, and `MLP`. You can construct your own builder using the `@builder` macro, see examples for further information. +- `optimiser::Flux.ADAM()`: A `Flux.Optimise` optimiser. The optimiser performs the updating of the weights of the network. For further reference, see either the examples or [the Flux optimiser documentation](https://fluxml.ai/Flux.jl/stable/training/optimisers/). To choose a learning rate (the update rate of the optimizer), a good rule of thumb is to start out at `10e-3`, and tune using powers of 10 between `1` and `1e-7`. +- `loss=Flux.mse`: The loss function which the network will optimize. Should be a function which can be called in the form `loss(yhat, y)`. Possible loss functions are listed in [the Flux loss function documentation](https://fluxml.ai/Flux.jl/stable/models/losses/). For a regression task, the most natural loss functions are: + - `Flux.mse` + - `Flux.mae` + - `Flux.msle` + - `Flux.huber_loss` +- `epochs::Int=10`: The number of epochs to train for. Typically, one epoch represents one pass through the entirety of the training dataset. +- `batch_size::Int=1`: The batch size to be used for training. The batch size represents the number of samples per update of the networks weights. Typcally, batch size should be somewhere between 8 and 512. Smaller batch sizes lead to noisier training loss curves, while larger batch sizes lead towards smoother training loss curves. In general, it is a good idea to pick one fairly large batch size (e.g. 32, 64, 128), and stick with it, and only tune the learning rate. In most literature, batch size is set in powers of twos, but this is fairly arbitrary. +- `lambda::Float64=0`: The stregth of the regularization used during training. Can be any value in the range `[0, ∞)`. +- `alpha::Float64=0`: The L2/L1 mix of regularization, in the range `[0, 1]`. A value of 0 represents L2 regularization, and a value of 1 represents L1 regularization. +- `rng::Union{AbstractRNG, Int64}`: The random number generator/seed used during training. +- `optimizer_changes_trigger_retraining::Bool=false`: Defines what happens when fitting a machine if the associated optimiser has changed. If true, the associated machine will retrain from scratch on `fit`, otherwise it will not. +- `acceleration::AbstractResource=CPU1()`: Defines on what hardware training is done. For Training on GPU, use `CudaLibs()`, otherwise defaults to `CPU`()`. +- `finaliser=Flux.softmax`: The final activation function of the neural network. Defaults to `Flux.softmax`. For a regression task, reasonable alternatives include `Flux.sigmoid` and the identity function (otherwise known as "linear activation"). + + +# Operations + +- `predict(mach, Xnew)`: return predictions of the target given new + features `Xnew` having the same Scitype as `X` above. Predictions are + deterministic. + + +# Fitted parameters + +The fields of `fitted_params(mach)` are: +- `chain`: The trained "chain", or series of layers, functions, and activations which make up the neural network. + + +# Report + +The fields of `report(mach)` are: +- `training_losses`: The history of training losses, a vector containing the history of all the losses during training. The first element of the vector is the initial penalized loss. After the first element, the nth element corresponds to the loss of epoch n-1. + +# Examples + +In this example we build a regression model using the Boston house price dataset +```julia + + using MLJ + using MLJFlux + using Flux + using Plots + using MLJBase: augment_X + +``` +First, we generate some data: +```julia + + X = augment_X(randn(10000, 8), true); + θ = randn((9,2)); + y = X * θ; + X = MLJ.table(X) + y = MLJ.table(y) + + + + + schema(y) + schema(X) + +``` +Lets also make a test set: +```julia + + (X, Xtest), (y, ytest) = partition((X, y), 0.7, multi=true); + +``` +Next, we can define a `builder`. In the following macro call, `n_in` is the number of expected input features, and rng is a RNG. `init` is the function used to generate the random initial weights of the network. +```julia +builder = MLJFlux.@builder begin + init=Flux.glorot_uniform(rng) + Chain(Dense(n_in, 64, relu, init=init), + Dense(64, 32, relu, init=init), + Dense(32, 1, init=init)) + end +``` +Finally, we can define the model! +```julia + + MultitargetNeuralNetworkRegressor = @load MultitargetNeuralNetworkRegressor + model = MultitargetNeuralNetworkRegressor(builder=builder, + rng=123, + epochs=20) +``` +For our neural network, since different features likely have different scales, if we do not standardize the network may be implicitly biased towards features with higher magnitudes, or may have [saturated neurons](https://www.informit.com/articles/article.aspx%3fp=3131594&seqNum=2) and not train well. Therefore, standardization is key! +```julia +pipe = Standardizer |> TransformedTargetModel(model, target=Standardizer) +``` +If we fit with a high verbosity (>1), we will see the losses during training. We can also see the losses in the output of `report(mach)` + +```julia +mach = machine(pipe, X, y) + fit!(mach, verbosity=2) + + # first element initial loss, 2:end per epoch training losses + report(mach).transformed_target_model_deterministic.training_losses + +``` + +## Experimenting with learning rate + +We can visually compare how the learning rate affects the predictions: +```julia +plt = plot() + + rates = 10. .^ (-5:0) + + foreach(rates) do η + pipe.transformed_target_model_deterministic.model.optimiser.eta = η + fit!(mach, force=true, verbosity=0) + losses = + report(mach).transformed_target_model_deterministic.model.training_losses[3:end] + plot!(1:length(losses), losses, label=η) + end + plt #!md + + savefig(joinpath("assets", "learning_rate.png")) + + + pipe.transformed_target_model_deterministic.model.optimiser.eta = 0.0001 + +``` + +## Using Iteration Controls + +We can also wrap the model with MLJ Iteration controls. Suppose we want a model that trains until the out of sample loss does not improve for 6 epochs. We can use the `NumberSinceBest(6)` stopping criterion. We can also add some extra stopping criterion, `InvalidValue` and `Timelimit(1/60)`, as well as some controls to print traces of the losses. First we can define some methods to initialize or clear the traces as well as updte the traces. +```julia + + # For initializing or clearing the traces: + + clear() = begin + global losses = [] + global training_losses = [] + global epochs = [] + return nothing + end + + # And to update the traces: + + update_loss(loss) = push!(losses, loss) + update_training_loss(report) = + push!(training_losses, + report.transformed_target_model_deterministic.model.training_losses[end]) + update_epochs(epoch) = push!(epochs, epoch) + +``` +For further reference of controls, see [the documentation](https://alan-turing-institute.github.io/MLJ.jl/dev/controlling_iterative_models/%23Controls-provided). To apply the controls, we simply stack them in a vector and then make an `IteratedModel`: +```julia + + controls=[Step(1), + NumberSinceBest(6), + InvalidValue(), + TimeLimit(1/60), + WithLossDo(update_loss), + WithReportDo(update_training_loss), + WithIterationsDo(update_epochs)] + + + iterated_pipe = + IteratedModel(model=pipe, + controls=controls, + resampling=Holdout(fraction_train=0.8), + measure = l2) + +``` +Next, we can clear the traces, fit the model, and plot the traces: +```julia + + + clear() + mach = machine(iterated_pipe, X, y) + fit!(mach) + + plot(epochs, losses, + xlab = "epoch", + ylab = "mean sum of squares error", + label="out-of-sample", + legend = :topleft); + scatter!(twinx(), epochs, training_losses, label="training", color=:red) #!md + + savefig(joinpath("assets", "loss.png")) +``` + +### Brief note on iterated models + +Training an `IteratedModel` means holding out some data (80% in this case) so an out-of-sample loss can be tracked and used in the specified stopping criterion, `NumberSinceBest(4)`. However, once the stop is triggered, the model wrapped by `IteratedModel` (our pipeline model) is retrained on all data for the same number of iterations. Calling `predict(mach, Xnew)` on new data uses the updated learned parameters. + +## Evaluating Iterated Models + +We can evaluate our model with the `evaluate!` function: +```julia + + e = evaluate!(mach, + resampling=CV(nfolds=8), + measures=[l1, l2]) + +#- + + using Measurements + l1_loss = e.measurement[1] ± std(e.per_fold[1])/sqrt(7) + @show l1_loss + +``` +We take this estimate of the uncertainty of the generalization error with a [grain of salt](https://direct.mit.edu/neco/article-abstract/10/7/1895/6224/Approximate-Statistical-Tests-for-Comparing)). + +## Comparison with other models on the test set + +Although we cannot assign them statistical significance, here are comparisons, on the untouched test set, of the eror of our self-iterating neural network regressor with a couple of other models trained on the same data (using default hyperparameters): +```julia + + function performance(model) + mach = machine(model, X, y) |> fit! + yhat = predict(mach, Xtest) + l1(yhat, ytest) |> mean + end + performance(iterated_pipe) + + three_models = [(@load EvoTreeRegressor)(), # tree boosting model + (@load LinearRegressor pkg=MLJLinearModels)(), + iterated_pipe] + + errs = performance.(three_models) + + (models=MLJ.name.(three_models), mean_square_errors=errs) |> pretty + + +``` diff --git a/nnregressor.norg b/nnregressor.norg new file mode 100644 index 00000000..cdeb2277 --- /dev/null +++ b/nnregressor.norg @@ -0,0 +1,273 @@ +* MultitargetNeuralNetworkRegressor + + `MultitargetNeuralNetworkRegressor`: A neural network model for making deterministic + predictions of a `Continuous` multi-target, presented as a table, given a table of `Continuous` features. + +* Training data + + In MLJ or MLJBase, bind an instance `model` to data with + + mach = machine(model, X, y) + + Where + + - `X`: is any table of input features (eg, a `DataFrame`) whose columns + are of scitype `Continuous`; check the scitype with `schema(X)` + + - `y`: is the target, which can be any table of output targets whose element + scitype is `Continuous`; check the scitype with `schema(y)` + + +* Hyper-parameters + + - `builder=MLJFlux.Linear(σ=Flux.relu)`: An MLJFlux builder that constructs a neural network. Possible `builders` include: `Linear`, `Short`, and `MLP`. You can construct your own builder using the `@builder` macro, see examples for further information. + - `optimiser::Flux.ADAM()`: A `Flux.Optimise` optimiser. The optimiser performs the updating of the weights of the network. For further reference, see either the examples or {https://fluxml.ai/Flux.jl/stable/training/optimisers/}[the Flux optimiser documentation]. To choose a learning rate (the update rate of the optimizer), a good rule of thumb is to start out at `10e-3`, and tune using powers of 10 between `1` and `1e-7`. + - `loss=Flux.mse`: The loss function which the network will optimize. Should be a function which can be called in the form `loss(yhat, y)`. Possible loss functions are listed in {https://fluxml.ai/Flux.jl/stable/models/losses/}[the Flux loss function documentation]. For a regression task, the most natural loss functions are: + -- `Flux.mse` + -- `Flux.mae` + -- `Flux.msle` + -- `Flux.huber_loss` + - `epochs::Int=10`: The number of epochs to train for. Typically, one epoch represents one pass through the entirety of the training dataset. + - `batch_size::Int=1`: The batch size to be used for training. The batch size represents the number of samples per update of the networks weights. Typcally, batch size should be somewhere between 8 and 512. Smaller batch sizes lead to noisier training loss curves, while larger batch sizes lead towards smoother training loss curves. In general, it is a good idea to pick one fairly large batch size (e.g. 32, 64, 128), and stick with it, and only tune the learning rate. In most literature, batch size is set in powers of twos, but this is fairly arbitrary. + - `lambda::Float64=0`: The stregth of the regularization used during training. Can be any value in the range `[0, ∞)`. + - `alpha::Float64=0`: The L2/L1 mix of regularization, in the range `[0, 1]`. A value of 0 represents L2 regularization, and a value of 1 represents L1 regularization. + - `rng::Union{AbstractRNG, Int64}`: The random number generator/seed used during training. + - `optimizer_changes_trigger_retraining::Bool=false`: Defines what happens when fitting a machine if the associated optimiser has changed. If true, the associated machine will retrain from scratch on `fit`, otherwise it will not. + - `acceleration::AbstractResource=CPU1()`: Defines on what hardware training is done. For Training on GPU, use `CudaLibs()`, otherwise defaults to `CPU`()`. + - `finaliser=Flux.softmax`: The final activation function of the neural network. Defaults to `Flux.softmax`. For a regression task, reasonable alternatives include `Flux.sigmoid` and the identity function (otherwise known as "linear activation"). + + +* Operations + + - `predict(mach, Xnew)`: return predictions of the target given new + features `Xnew` having the same Scitype as `X` above. Predictions are + deterministic. + + +* Fitted parameters + + The fields of `fitted_params(mach)` are: + + - `chain`: The trained "chain", or series of layers, functions, and activations which make up the neural network. + + +* Report + + The fields of `report(mach)` are: + + - `training_losses`: The history of training losses, a vector containing the history of all the losses during training. The first element of the vector is the initial penalized loss. After the first element, the nth element corresponds to the loss of epoch $n-1$. + +* Examples + +In this example we build a regression model using the Boston house price dataset + + @code julia + + using MLJ + using MLJFlux + using Flux + using Plots + using MLJBase: augment_X + + @end + + First, we generate some data: + + @code julia + + X = augment_X(randn(10000, 8), true); + θ = randn((9,2)); + y = X * θ; + X = MLJ.table(X) + y = MLJ.table(y) + + + + + schema(y) + schema(X) + + @end + + Lets also make a test set: + + @code julia + + (X, Xtest), (y, ytest) = partition((X, y), 0.7, multi=true); + + @end + + Next, we can define a `builder`. In the following macro call, `n_in` is the number of expected input features, and rng is a RNG. `init` is the function used to generate the random initial weights of the network. + + @code julia + builder = MLJFlux.@builder begin + init=Flux.glorot_uniform(rng) + Chain(Dense(n_in, 64, relu, init=init), + Dense(64, 32, relu, init=init), + Dense(32, 1, init=init)) + end + @end + + Finally, we can define the model! + + @code julia + + MultitargetNeuralNetworkRegressor = @load MultitargetNeuralNetworkRegressor + model = MultitargetNeuralNetworkRegressor(builder=builder, + rng=123, + epochs=20) + @end + + For our neural network, since different features likely have different scales, if we do not standardize the network may be implicitly biased towards features with higher magnitudes, or may have {https://www.informit.com/articles/article.aspx?p=3131594&seqNum=2}[saturated neurons] and not train well. Therefore, standardization is key! + + @code julia + pipe = Standardizer |> TransformedTargetModel(model, target=Standardizer) + @end + + If we fit with a high verbosity ($>1$), we will see the losses during training. We can also see the losses in the output of `report(mach)` + + + @code julia + mach = machine(pipe, X, y) + fit!(mach, verbosity=2) + + # first element initial loss, 2:end per epoch training losses + report(mach).transformed_target_model_deterministic.training_losses + + @end + +** Experimenting with learning rate + + We can visually compare how the learning rate affects the predictions: + + @code julia + plt = plot() + + rates = 10. .^ (-5:0) + + foreach(rates) do η + pipe.transformed_target_model_deterministic.model.optimiser.eta = η + fit!(mach, force=true, verbosity=0) + losses = + report(mach).transformed_target_model_deterministic.model.training_losses[3:end] + plot!(1:length(losses), losses, label=η) + end + plt #!md + + savefig(joinpath("assets", "learning_rate.png")) + + + pipe.transformed_target_model_deterministic.model.optimiser.eta = 0.0001 + + @end + +** Using Iteration Controls + We can also wrap the model with MLJ Iteration controls. Suppose we want a model that trains until the out of sample loss does not improve for 6 epochs. We can use the `NumberSinceBest(6)` stopping criterion. We can also add some extra stopping criterion, `InvalidValue` and `Timelimit(1/60)`, as well as some controls to print traces of the losses. First we can define some methods to initialize or clear the traces as well as updte the traces. + + @code julia + + # For initializing or clearing the traces: + + clear() = begin + global losses = [] + global training_losses = [] + global epochs = [] + return nothing + end + + # And to update the traces: + + update_loss(loss) = push!(losses, loss) + update_training_loss(report) = + push!(training_losses, + report.transformed_target_model_deterministic.model.training_losses[end]) + update_epochs(epoch) = push!(epochs, epoch) + + @end + + For further reference of controls, see {https://alan-turing-institute.github.io/MLJ.jl/dev/controlling_iterative_models/#Controls-provided}[the documentation]. To apply the controls, we simply stack them in a vector and then make an `IteratedModel`: + + @code julia + + controls=[Step(1), + NumberSinceBest(6), + InvalidValue(), + TimeLimit(1/60), + WithLossDo(update_loss), + WithReportDo(update_training_loss), + WithIterationsDo(update_epochs)] + + + iterated_pipe = + IteratedModel(model=pipe, + controls=controls, + resampling=Holdout(fraction_train=0.8), + measure = l2) + + @end + + Next, we can clear the traces, fit the model, and plot the traces: + + @code julia + + + clear() + mach = machine(iterated_pipe, X, y) + fit!(mach) + + plot(epochs, losses, + xlab = "epoch", + ylab = "mean sum of squares error", + label="out-of-sample", + legend = :topleft); + scatter!(twinx(), epochs, training_losses, label="training", color=:red) #!md + + savefig(joinpath("assets", "loss.png")) + @end + +*** Brief note on iterated models + Training an `IteratedModel` means holding out some data (80% in this case) so an out-of-sample loss can be tracked and used in the specified stopping criterion, `NumberSinceBest(4)`. However, once the stop is triggered, the model wrapped by `IteratedModel` (our pipeline model) is retrained on all data for the same number of iterations. Calling `predict(mach, Xnew)` on new data uses the updated learned parameters. + +** Evaluating Iterated Models + We can evaluate our model with the `evaluate!` function: + + @code julia + + e = evaluate!(mach, + resampling=CV(nfolds=8), + measures=[l1, l2]) + +#- + + using Measurements + l1_loss = e.measurement[1] ± std(e.per_fold[1])/sqrt(7) + @show l1_loss + + @end + +We take this estimate of the uncertainty of the generalization error with a [grain of salt](https://direct.mit.edu/neco/article-abstract/10/7/1895/6224/Approximate-Statistical-Tests-for-Comparing)). + +** Comparison with other models on the test set + + Although we cannot assign them statistical significance, here are comparisons, on the untouched test set, of the eror of our self-iterating neural network regressor with a couple of other models trained on the same data (using default hyperparameters): + + @code julia + + function performance(model) + mach = machine(model, X, y) |> fit! + yhat = predict(mach, Xtest) + l1(yhat, ytest) |> mean + end + performance(iterated_pipe) + + three_models = [(@load EvoTreeRegressor)(), # tree boosting model + (@load LinearRegressor pkg=MLJLinearModels)(), + iterated_pipe] + + errs = performance.(three_models) + + (models=MLJ.name.(three_models), mean_square_errors=errs) |> pretty + + + @end + From adb2cb8410dceabeb054ad61eb53a90f23de0fce Mon Sep 17 00:00:00 2001 From: josephsdavid Date: Tue, 12 Jul 2022 16:36:33 -0500 Subject: [PATCH 4/7] update with code review suggestions --- src/MLJFlux.jl | 943 ------------------------------------------- src/classifier.jl | 6 +- src/image.jl | 5 +- src/regressor.jl | 13 +- src/types.jl | 996 +++++++++++++++++++++++++++++++++++++++++++--- 5 files changed, 951 insertions(+), 1012 deletions(-) diff --git a/src/MLJFlux.jl b/src/MLJFlux.jl index 4fe7d3bd..d2e63add 100644 --- a/src/MLJFlux.jl +++ b/src/MLJFlux.jl @@ -37,949 +37,6 @@ MLJModelInterface.metadata_pkg.((NeuralNetworkRegressor, export NeuralNetworkRegressor, MultitargetNeuralNetworkRegressor export NeuralNetworkClassifier, ImageClassifier -""" -$(MMI.doc_header(NeuralNetworkRegressor)) - -`NeuralNetworkRegressor`: A neural network model for making deterministic -predictions of a `Continuous` target, given a table of `Continuous` features. - -# Training data - -In MLJ or MLJBase, bind an instance `model` to data with - mach = machine(model, X, y) - -Where - -- `X`: is any table of input features (eg, a `DataFrame`) whose columns - are of scitype `Continuous`; check the scitype with `schema(X)` -- `y`: is the target, which can be any `AbstractVector` whose element - scitype is `Continuous`; check the scitype with `scitype(y)` - - -# Hyper-parameters - -- `builder=MLJFlux.Linear(σ=Flux.relu)`: An MLJFlux builder that constructs a neural network. - Possible `builders` include: `Linear`, `Short`, and `MLP`. You can construct your own builder - using the `@builder` macro, see examples for further information. -- `optimiser::Flux.ADAM()`: A `Flux.Optimise` optimiser. The optimiser performs the updating - of the weights of the network. For further reference, see either the examples or - [the Flux optimiser documentation](https://fluxml.ai/Flux.jl/stable/training/optimisers/). - To choose a learning rate (the update rate of the optimizer), a good rule of thumb is to - start out at `10e-3`, and tune using powers of 10 between `1` and `1e-7`. -- `loss=Flux.mse`: The loss function which the network will optimize. Should be a function - which can be called in the form `loss(yhat, y)`. - Possible loss functions are listed in [the Flux loss function documentation](https://fluxml.ai/Flux.jl/stable/models/losses/). - For a regression task, the most natural loss functions are: - - `Flux.mse` - - `Flux.mae` - - `Flux.msle` - - `Flux.huber_loss` -- `epochs::Int=10`: The number of epochs to train for. Typically, one epoch represents - one pass through the entirety of the training dataset. -- `batch_size::Int=1`: The batch size to be used for training. The batch size represents - the number of samples per update of the networks weights. Typcally, batch size should be - somewhere between 8 and 512. Smaller batch sizes lead to noisier training loss curves, - while larger batch sizes lead towards smoother training loss curves. - In general, it is a good idea to pick one fairly large batch size (e.g. 32, 64, 128), - and stick with it, and only tune the learning rate. In most examples, batch size is set - in powers of twos, but this is fairly arbitrary. -- `lambda::Float64=0`: The stregth of the regularization used during training. Can be any value - in the range `[0, ∞)`. -- `alpha::Float64=0`: The L2/L1 mix of regularization, in the range `[0, 1]`. - A value of 0 represents L2 regularization, and a value of 1 represents L1 regularization. -- `rng::Union{AbstractRNG, Int64}`: The random number generator/seed used during training. -- `optimizer_changes_trigger_retraining::Bool=false`: Defines what happens when fitting a - machine if the associated optimiser has changed. If true, the associated machine will - retrain from scratch on `fit`, otherwise it will not. -- `acceleration::AbstractResource=CPU1()`: Defines on what hardware training is done. - For training on GPU, use `CudaLibs()`, otherwise defaults to `CPU`()`. -- `finaliser=Flux.softmax`: The final activation function of the neural network. - Defaults to `Flux.softmax`. For a regression task, reasonable alternatives include - `Flux.sigmoid` and the identity function (otherwise known as "linear activation"). - - -# Operations - -- `predict(mach, Xnew)`: return predictions of the target given new - features `Xnew` having the same Scitype as `X` above. Predictions are - deterministic. - - -# Fitted parameters - -The fields of `fitted_params(mach)` are: - -- `chain`: The trained "chain", or series of layers, functions, and activations which - make up the neural network. - - -# Report - -The fields of `report(mach)` are: - -- `training_losses`: The history of training losses, a vector containing the history of all the losses during training. The first element of the vector is the initial penalized loss. After the first element, the nth element corresponds to the loss of epoch n-1. - all the losses during training. The first element of the vector is the initial penalized loss. After the first element, the nth element corresponds to the loss of epoch n-1. - penalized loss. After the first element, the nth element corresponds to the loss of epoch n-1. - epoch n-1. -# Examples - -In this example we build a regression model using the Boston house price dataset -```julia - using MLJ - using MLJFlux - using Flux - using Plots -``` -First, we load in the data, with target `:MEDV`. We load in all features except `:CHAS`: -```julia -data = OpenML.load(531); # Loads from https://www.openml.org/d/531 - -y, X = unpack(data, ==(:MEDV), !=(:CHAS); rng=123); - -scitype(y) -schema(X) -``` -Since MLJFlux models do not handle ordered factors, we can treat `:RAD` as `Continuous`: -```julia -X = coerce(X, :RAD=>Continuous) -``` -Lets also make a test set: -```julia -(X, Xtest), (y, ytest) = partition((X, y), 0.7, multi=true); -``` -Next, we can define a `builder`. In the following macro call, `n_in` is the number of expected input features, and rng is a RNG. `init` is the function used to generate the random initial weights of the network. -expected input features, and rng is a RNG. `init` is the function used to generate the random initial weights of the network. -random initial weights of the network. -```julia -builder = MLJFlux.@builder begin - init=Flux.glorot_uniform(rng) - Chain(Dense(n_in, 64, relu, init=init), - Dense(64, 32, relu, init=init), - Dense(32, 1, init=init)) -end -``` -Finally, we can define the model! -```julia -NeuralNetworkRegressor = @load NeuralNetworkRegressor - model = NeuralNetworkRegressor(builder=builder, - rng=123, - epochs=20) -``` -For our neural network, since different features likely have different scales, if we do not standardize the network may be implicitly biased towards features with higher magnitudes, or may have [saturated neurons](https://www.informit.com/articles/article.aspx%3fp=3131594&seqNum=2) and not train well. Therefore, standardization is key! -not standardize the network may be implicitly biased towards features with higher magnitudes, or may have [saturated neurons](https://www.informit.com/articles/article.aspx%3fp=3131594&seqNum=2) and not train well. Therefore, standardization is key! -magnitudes, or may have [saturated neurons](https://www.informit.com/articles/article.aspx%3fp=3131594&seqNum=2) and not train well. Therefore, standardization is key! -neurons](https://www.informit.com/articles/article.aspx%3fp=3131594&seqNum=2) and not train well. Therefore, standardization is key! -```julia -pipe = Standardizer |> TransformedTargetModel(model, target=Standardizer) -``` -If we fit with a high verbosity (>1), we will see the losses during training. We can also see the losses in the output of `report(mach)` -also see the losses in the output of `report(mach)` -```julia -mach = machine(pipe, X, y) -fit!(mach, verbosity=2) - -# first element initial loss, 2:end per epoch training losses -report(mach).transformed_target_model_deterministic.training_losses - -``` - -## Experimenting with learning rate - -We can visually compare how the learning rate affects the predictions: -```julia -plt = plot() - -rates = 10. .^ (-5:0) - -foreach(rates) do η - pipe.transformed_target_model_deterministic.model.optimiser.eta = η - fit!(mach, force=true, verbosity=0) - losses = - report(mach).transformed_target_model_deterministic.model.training_losses[3:end] - plot!(1:length(losses), losses, label=η) -end -plt #!md - -savefig(joinpath("assets", "learning_rate.png")) - -pipe.transformed_target_model_deterministic.model.optimiser.eta = 0.0001 -``` - -## Using Iteration Controls - -We can also wrap the model with MLJ Iteration controls. Suppose we want a model that trains until the out of sample loss does not improve for 6 epochs. We can use the `NumberSinceBest(6)` stopping criterion. We can also add some extra stopping criterion, `InvalidValue` and `Timelimit(1/60)`, as well as some controls to print traces of the losses. First we can define some methods to initialize or clear the traces as well as updte the traces. -trains until the out of sample loss does not improve for 6 epochs. We can use the `NumberSinceBest(6)` stopping criterion. We can also add some extra stopping criterion, `InvalidValue` and `Timelimit(1/60)`, as well as some controls to print traces of the losses. First we can define some methods to initialize or clear the traces as well as update the traces. -`NumberSinceBest(6)` stopping criterion. We can also add some extra stopping criterion, `InvalidValue` and `Timelimit(1/60)`, as well as some controls to print traces of the losses. First we can define some methods to initialize or clear the traces as well as update the traces. -```julia -# For initializing or clearing the traces: - -clear() = begin - global losses = [] - global training_losses = [] - global epochs = [] - return nothing -end - - # And to update the traces: - -update_loss(loss) = push!(losses, loss) -update_training_loss(report) = - push!(training_losses, - report.transformed_target_model_deterministic.model.training_losses[end]) -update_epochs(epoch) = push!(epochs, epoch) -``` -For further reference of controls, see [the documentation](https://alan-turing-institute.github.io/MLJ.jl/dev/controlling_iterative_models/%23Controls-provided). To apply the controls, we simply stack them in a vector and then make an `IteratedModel`: -```julia -controls=[Step(1), - NumberSinceBest(6), - InvalidValue(), - TimeLimit(1/60), - WithLossDo(update_loss), - WithReportDo(update_training_loss), -WithIterationsDo(update_epochs)] - - -iterated_pipe = - IteratedModel(model=pipe, - controls=controls, - resampling=Holdout(fraction_train=0.8), - measure = l2) -``` -Next, we can clear the traces, fit the model, and plot the traces: -```julia -clear() -mach = machine(iterated_pipe, X, y) -fit!(mach) - -plot(epochs, losses, - xlab = "epoch", - ylab = "mean sum of squares error", - label="out-of-sample", - legend = :topleft); -scatter!(twinx(), epochs, training_losses, label="training", color=:red) #!md - -savefig(joinpath("assets", "loss.png")) -``` - -### Brief note on iterated models - -Training an `IteratedModel` means holding out some data (80% in this case) so an -out-of-sample loss can be tracked and used in the specified stopping criterion, -`NumberSinceBest(4)`. However, once the stop is triggered, the model wrapped by -`IteratedModel` (our pipeline model) is retrained on all data for the same number of -iterations. Calling `predict(mach, Xnew)` on new data uses the updated learned -parameters. - -## Evaluating Iterated Models - -We can evaluate our model with the `evaluate!` function: -```julia -e = evaluate!(mach, - resampling=CV(nfolds=8), - measures=[l1, l2]) - -using Measurements -l1_loss = e.measurement[1] ± std(e.per_fold[1])/sqrt(7) -@show l1_loss -``` -We take this estimate of the uncertainty of the generalization error with a [grain of -salt](https://direct.mit.edu/neco/article-abstract/10/7/1895/6224/Approximate-Statistical-Tests-for-Comparing)). - -## Comparison with other models on the test set - -Although we cannot assign them statistical significance, here are comparisons, on the -untouched test set, of the eror of our self-iterating neural network regressor with a -couple of other models trained on the same data (using default hyperparameters): -```julia -function performance(model) - mach = machine(model, X, y) |> fit! - yhat = predict(mach, Xtest) - l1(yhat, ytest) |> mean -end -performance(iterated_pipe) - -three_models = [(@load EvoTreeRegressor)(), # tree boosting model - (@load LinearRegressor pkg=MLJLinearModels)(), - iterated_pipe] - -errs = performance.(three_models) - -(models=MLJ.name.(three_models), mean_square_errors=errs) |> pretty -``` - -See also -[`MultitargetNeuralNetworkRegressor`](@ref) -""" -NeuralNetworkRegressor - -""" -$(MMI.doc_header(MultitargetNeuralNetworkRegressor)) - -`MultitargetNeuralNetworkRegressor`: A neural network model for making deterministic -predictions of a `Continuous` multi-target, presented as a table, given a table of -`Continuous` features. - -# Training data - -In MLJ or MLJBase, bind an instance `model` to data with - mach = machine(model, X, y) - -Where - -- `X`: is any table of input features (eg, a `DataFrame`) whose columns - are of scitype `Continuous`; check the scitype with `schema(X)` -- `y`: is the target, which can be any table of output targets whose element - scitype is `Continuous`; check the scitype with `schema(y)` - - -# Hyper-parameters - -- `builder=MLJFlux.Linear(σ=Flux.relu)`: An MLJFlux builder that constructs a neural - network. Possible `builders` include: `Linear`, `Short`, and `MLP`. You can construct - your own builder using the `@builder` macro, see examples for further information. -- `optimiser::Flux.ADAM()`: A `Flux.Optimise` optimiser. The optimiser performs the - updating of the weights of the network. For further reference, see either the examples - or [the Flux optimiser - documentation](https://fluxml.ai/Flux.jl/stable/training/optimisers/). To choose a - learning rate (the update rate of the optimizer), a good rule of thumb is to start out - at `10e-3`, and tune using powers of 10 between `1` and `1e-7`. -- `loss=Flux.mse`: The loss function which the network will optimize. Should be a - function which can be called in the form `loss(yhat, y)`. Possible loss functions are - listed in [the Flux loss function - documentation](https://fluxml.ai/Flux.jl/stable/models/losses/). For a regression task, - the most natural loss functions are: - - `Flux.mse` - - `Flux.mae` - - `Flux.msle` - - `Flux.huber_loss` -- `epochs::Int=10`: The number of epochs to train for. Typically, one epoch represents - one pass through the entirety of the training dataset. -- `batch_size::Int=1`: The batch size to be used for training. The batch size represents - the number of samples per update of the networks weights. Typcally, batch size should be - somewhere between 8 and 512. Smaller batch sizes lead to noisier training loss curves, - while larger batch sizes lead towards smoother training loss curves. In general, it is a - good idea to pick one fairly large batch size (e.g. 32, 64, 128), and stick with it, and - only tune the learning rate. In most literature, batch size is set in powers of twos, - but this is fairly arbitrary. -- `lambda::Float64=0`: The stregth of the regularization used during training. Can be - any value in the range `[0, ∞)`. -- `alpha::Float64=0`: The L2/L1 mix of regularization, in the range `[0, 1]`. A value of - 0 represents L2 regularization, and a value of 1 represents L1 regularization. -- `rng::Union{AbstractRNG, Int64}`: The random number generator/seed used during - training. -- `optimizer_changes_trigger_retraining::Bool=false`: Defines what happens when fitting - a machine if the associated optimiser has changed. If true, the associated machine will - retrain from scratch on `fit`, otherwise it will not. -- `acceleration::AbstractResource=CPU1()`: Defines on what hardware training is done. - For Training on GPU, use `CudaLibs()`, otherwise defaults to `CPU`()`. -- `finaliser=Flux.softmax`: The final activation function of the neural network. -Defaults to `Flux.softmax`. For a regression task, reasonable alternatives include -`Flux.sigmoid` and the identity function (otherwise known as "linear activation"). - -# Operations - -- `predict(mach, Xnew)`: return predictions of the target given new - features `Xnew` having the same Scitype as `X` above. Predictions are - deterministic. - - -# Fitted parameters - -The fields of `fitted_params(mach)` are: - -- `chain`: The trained "chain", or series of layers, functions, and activations which - make up the neural network. - - -# Report - -The fields of `report(mach)` are: - -- `training_losses`: The history of training losses, a vector containing the history of - all the losses during training. The first element of the vector is the initial - penalized loss. After the first element, the nth element corresponds to the loss of - epoch n-1. - -# Examples - -In this example we build a regression model using a toy dataset. -```julia -using MLJ -using MLJFlux -using Flux -using Plots -using MLJBase: augment_X -``` -First, we generate some data: -```julia -X = augment_X(randn(10000, 8), true); -θ = randn((9,2)); -y = X * θ; -X = MLJ.table(X) -y = MLJ.table(y) - -schema(y) -schema(X) -``` -Lets also make a test set: -```julia -(X, Xtest), (y, ytest) = partition((X, y), 0.7, multi=true); -``` -Next, we can define a `builder`. In the following macro call, `n_in` is the number of expected input features, and rng is a RNG. `init` is the function used to generate the random initial weights of the network. -```julia -builder = MLJFlux.@builder begin - init=Flux.glorot_uniform(rng) - Chain(Dense(n_in, 64, relu, init=init), - Dense(64, 32, relu, init=init), - Dense(32, 1, init=init)) -end -``` -Finally, we can define the model! -```julia -MultitargetNeuralNetworkRegressor = @load MultitargetNeuralNetworkRegressor - model = MultitargetNeuralNetworkRegressor(builder=builder, - rng=123, - epochs=20) -``` -For our neural network, since different features likely have different scales, if we do not standardize the network may be implicitly biased towards features with higher magnitudes, or may have [saturated neurons](https://www.informit.com/articles/article.aspx%3fp=3131594&seqNum=2) and not train well. Therefore, standardization is key! -```julia -pipe = Standardizer |> TransformedTargetModel(model, target=Standardizer) -``` -If we fit with a high verbosity (>1), we will see the losses during training. We can also see the losses in the output of `report(mach)` - -```julia -mach = machine(pipe, X, y) -fit!(mach, verbosity=2) - -# first element initial loss, 2:end per epoch training losses -report(mach).transformed_target_model_deterministic.training_losses - -``` - -## Experimenting with learning rate - -We can visually compare how the learning rate affects the predictions: -```julia -plt = plot() - -rates = 10. .^ (-5:0) - -foreach(rates) do η - pipe.transformed_target_model_deterministic.model.optimiser.eta = η - fit!(mach, force=true, verbosity=0) - losses = - report(mach).transformed_target_model_deterministic.model.training_losses[3:end] - plot!(1:length(losses), losses, label=η) -end -plt #!md - -savefig(joinpath("assets", "learning_rate.png")) - - -pipe.transformed_target_model_deterministic.model.optimiser.eta = 0.0001 - -``` - -## Using Iteration Controls - -We can also wrap the model with MLJ Iteration controls. Suppose we want a model that trains until the out of sample loss does not improve for 6 epochs. We can use the `NumberSinceBest(6)` stopping criterion. We can also add some extra stopping criterion, `InvalidValue` and `Timelimit(1/60)`, as well as some controls to print traces of the losses. First we can define some methods to initialize or clear the traces as well as updte the traces. -```julia -# For initializing or clearing the traces: - -clear() = begin - global losses = [] - global training_losses = [] - global epochs = [] - return nothing -end - -# And to update the traces: - -update_loss(loss) = push!(losses, loss) -update_training_loss(report) = - push!(training_losses, - report.transformed_target_model_deterministic.model.training_losses[end]) -update_epochs(epoch) = push!(epochs, epoch) -``` -For further reference of controls, see [the documentation](https://alan-turing-institute.github.io/MLJ.jl/dev/controlling_iterative_models/%23Controls-provided). To apply the controls, we simply stack them in a vector and then make an `IteratedModel`: -```julia -controls=[Step(1), - NumberSinceBest(6), - InvalidValue(), - TimeLimit(1/60), - WithLossDo(update_loss), - WithReportDo(update_training_loss), -WithIterationsDo(update_epochs)] - -iterated_pipe = - IteratedModel(model=pipe, - controls=controls, - resampling=Holdout(fraction_train=0.8), - measure = l2) -``` -Next, we can clear the traces, fit the model, and plot the traces: -```julia -clear() -mach = machine(iterated_pipe, X, y) -fit!(mach) - -plot(epochs, losses, - xlab = "epoch", - ylab = "mean sum of squares error", - label="out-of-sample", - legend = :topleft); -scatter!(twinx(), epochs, training_losses, label="training", color=:red) #!md - -savefig(joinpath("assets", "loss.png")) -``` - -### Brief note on iterated models - -Training an `IteratedModel` means holding out some data (80% in this case) so an out-of-sample loss can be tracked and used in the specified stopping criterion, `NumberSinceBest(4)`. However, once the stop is triggered, the model wrapped by `IteratedModel` (our pipeline model) is retrained on all data for the same number of iterations. Calling `predict(mach, Xnew)` on new data uses the updated learned parameters. - -## Evaluating Iterated Models - -We can evaluate our model with the `evaluate!` function: -```julia -e = evaluate!(mach, - resampling=CV(nfolds=8), - measures=[l1, l2]) - -using Measurements -l1_loss = e.measurement[1] ± std(e.per_fold[1])/sqrt(7) -@show l1_loss -``` -We take this estimate of the uncertainty of the generalization error with a [grain of salt](https://direct.mit.edu/neco/article-abstract/10/7/1895/6224/Approximate-Statistical-Tests-for-Comparing)). - -## Comparison with other models on the test set - -Although we cannot assign them statistical significance, here are comparisons, on the untouched test set, of the eror of our self-iterating neural network regressor with a couple of other models trained on the same data (using default hyperparameters): -```julia - -function performance(model) - mach = machine(model, X, y) |> fit! - yhat = predict(mach, Xtest) - l1(yhat, ytest) |> mean -end -performance(iterated_pipe) - -three_models = [(@load EvoTreeRegressor)(), # tree boosting model - (@load LinearRegressor pkg=MLJLinearModels)(), - iterated_pipe] - -errs = performance.(three_models) - -(models=MLJ.name.(three_models), mean_square_errors=errs) |> pretty - - -``` -See also -[`NeuralNetworkRegressor`](@ref) -""" -MultitargetNeuralNetworkRegressor -""" -$(MMI.doc_header(NeuralNetworkClassifier)) - -`NeuralNetworkClassifier`: a neural network model for making probabilistic predictions -of a Multiclass or OrderedFactor target, given a table of Continuous features. ) - TODO: - -# Training data - -In MLJ or MLJBase, bind an instance `model` to data with - mach = machine(model, X, y) - -Where - -- `X`: is any table of input features (eg, a `DataFrame`) whose columns - are of scitype `Continuous`; check the scitype with `schema(X)` -- `y`: is the target, which can be any `AbstractVector` whose element - scitype is `Multiclass` or `OrderedFactor` with `n_out` classes; - check the scitype with `scitype(y)` - - -# Hyper-parameters - -- `builder=MLJFlux.Short()`: An MLJFlux builder that constructs a neural network. Possible `builders` include: `Linear`, `Short`, and `MLP`. You can construct your own builder using the `@builder` macro, see examples for further information. -- `optimiser::Flux.ADAM()`: A `Flux.Optimise` optimiser. The optimiser performs the updating of the weights of the network. For further reference, see either the examples or [the Flux optimiser documentation](https://fluxml.ai/Flux.jl/stable/training/optimisers/). To choose a learning rate (the update rate of the optimizer), a good rule of thumb is to start out at `10e-3`, and tune using powers of 10 between `1` and `1e-7`. -- `loss=Flux.crossentropy`: The loss function which the network will optimize. Should be a function which can be called in the form `loss(yhat, y)`. Possible loss functions are listed in [the Flux loss function documentation](https://fluxml.ai/Flux.jl/stable/models/losses/). For a classification task, the most natural loss functions are: - - `Flux.crossentropy`: Typically used as loss in multiclass classification, with labels in a 1-hot encoded format. - - `Flux.logitcrossentopy`: Mathematically equal to crossentropy, but computationally more numerically stable than finalising the outputs with `softmax` and then calculating crossentropy. - - `Flux.binarycrossentropy`: Typically used as loss in binary classification, with labels in a 1-hot encoded format. - - `Flux.logitbinarycrossentopy`: Mathematically equal to crossentropy, but computationally more numerically stable than finalising the outputs with `sigmoid` and then calculating binary crossentropy. - - `Flux.tversky_loss`: Used with imbalanced data to give more weight to false negatives. - - `Flux.focal_loss`: Used with highly imbalanced data. Weights harder examples more than easier examples. - - `Flux.binary_focal_loss`: Binary version of the above -- `epochs::Int=10`: The number of epochs to train for. Typically, one epoch represents one pass through the entirety of the training dataset. -- `batch_size::Int=1`: The batch size to be used for training. The batch size represents the number of samples per update of the networks weights. Typcally, batch size should be somewhere between 8 and 512. Smaller batch sizes lead to noisier training loss curves, while larger batch sizes lead towards smoother training loss curves. In general, it is a good idea to pick one fairly large batch size (e.g. 32, 64, 128), and stick with it, and only tune the learning rate. In most literature, batch size is set in powers of twos, but this is fairly arbitrary. -- `lambda::Float64=0`: The stregth of the regularization used during training. Can be any value in the range `[0, ∞)`. -- `alpha::Float64=0`: The L2/L1 mix of regularization, in the range `[0, 1]`. A value of 0 represents L2 regularization, and a value of 1 represents L1 regularization. -- `rng::Union{AbstractRNG, Int64}`: The random number generator/seed used during training. -- `optimizer_changes_trigger_retraining::Bool=false`: Defines what happens when fitting a machine if the associated optimiser has changed. If true, the associated machine will retrain from scratch on `fit`, otherwise it will not. -- `acceleration::AbstractResource=CPU1()`: Defines on what hardware training is done. For Training on GPU, use `CudaLibs()`, otherwise defaults to `CPU`()`. -- `finaliser=Flux.softmax`: The final activation function of the neural network. Defaults to `Flux.softmax`. For a classification task, `softmax` is used for multiclass, single label regression, `sigmoid` is used for either binary classification or multi label classification (when there are multiple possible labels for a given sample). - - -# Operations - -- `predict(mach, Xnew)`: return predictions of the target given new - features `Xnew` having the same Scitype as `X` above. Predictions are - probabilistic. -- `predict_mode(mach, Xnew)`: Return the modes of the probabilistic predictions - returned above. - - -# Fitted parameters - -The fields of `fitted_params(mach)` are: - -- `chain`: The trained "chain", or series of layers, functions, and activations which make up the neural network. - - -# Report - -The fields of `report(mach)` are: - -- `training_losses`: The history of training losses, a vector containing the history of all the losses during training. The first element of the vector is the initial penalized loss. After the first element, the nth element corresponds to the loss of epoch n-1. - -# Examples - -In this example we build a classification model using the Iris dataset. -```julia -using MLJ -using Flux -import RDatasets - -using Random -Random.seed!(123) - -MLJ.color_off() - -using Plots -pyplot(size=(600, 300*(sqrt(5)-1))); -``` -This is a very basic example, using a default builder and no standardization. -For a more advance illustration, see [`NeuralNetworkRegressor`](@ref) or [`ImageClassifier`](@ref). First, we can load the data: -```julia -iris = RDatasets.dataset("datasets", "iris"); -y, X = unpack(iris, ==(:Species), colname -> true, rng=123); -NeuralNetworkClassifier = @load NeuralNetworkClassifier -clf = NeuralNetworkClassifier() -``` -Next, we can train the model: -```julia -import Random.seed!; seed!(123) -mach = machine(clf, X, y) -fit!(mach) -``` -We can train the model in an incremental fashion with the `optimizer_changes_trigger_retraining` flag set to false (which is by default). Here, we change the number of iterations and the learning rate of the optimiser: -```julia -clf.optimiser.eta = clf.optimiser.eta * 2 -clf.epochs = clf.epochs + 5 - -# note that if the optimizer_changes_trigger_retraining flag was set to true -# the model would be completely retrained from scratch because the optimizer was -# updated -fit!(mach, verbosity=2); -``` -We can inspect the mean training loss using the `cross_entropy` function: -```julia - -training_loss = cross_entropy(predict(mach, X), y) |> mean - -``` -And we can access the Flux chain (model) using `fitted_params`: -```julia -chain = fitted_params(mach).chain -``` -Finally, we can see how the out-of-sample performance changes over time, using the `learning_curve` function -```julia -r = range(clf, :epochs, lower=1, upper=200, scale=:log10) -curve = learning_curve(clf, X, y, - range=r, - resampling=Holdout(fraction_train=0.7), - measure=cross_entropy) -using Plots -plot(curve.parameter_values, - curve.measurements, - xlab=curve.parameter_name, - xscale=curve.parameter_scale, - ylab = "Cross Entropy") - -savefig("iris_history.png") -``` -See also -[`ImageClassifier`](@ref) -""" -NeuralNetworkClassifier -""" -$(MMI.doc_header(ImageClassifier)) - -`ImageClassifier`: A neural network model for making probabilistic -"predictions of a `GrayImage` target, given a table of `Continuous` features. - -# Training data - -In MLJ or MLJBase, bind an instance `model` to data with -mach = machine(model, X, y) -Where -- `X`: is any `AbstractVector` of input features (eg, a `DataFrame`) whose items - are of scitype `GrayImage`; check the scitype with `scitype(X)` -- `y`: is the target, which can be any `AbstractVector` whose element - scitype is `Multiclass` or `OrderedFactor` with `n_out` classes; - check the scitype with `scitype(y)` - - -# Hyper-parameters - -- `builder=MLJFlux.Short()`: An MLJFlux builder that constructs a neural network. Possible `builders` include: `Linear`, `Short`, and `MLP`. You can construct your own builder using the `@builder` macro, see examples for further information. -- `optimiser::Flux.ADAM()`: A `Flux.Optimise` optimiser. The optimiser performs the updating of the weights of the network. For further reference, see either the examples or [the Flux optimiser documentation](https://fluxml.ai/Flux.jl/stable/training/optimisers/). To choose a learning rate (the update rate of the optimizer), a good rule of thumb is to start out at `10e-3`, and tune using powers of 10 between `1` and `1e-7`. -- `loss=Flux.crossentropy`: The loss function which the network will optimize. Should be a function which can be called in the form `loss(yhat, y)`. Possible loss functions are listed in [the Flux loss function documentation](https://fluxml.ai/Flux.jl/stable/models/losses/). For a classification task, the most natural loss functions are: - - `Flux.crossentropy`: Typically used as loss in multiclass classification, with labels in a 1-hot encoded format. - - `Flux.logitcrossentopy`: Mathematically equal to crossentropy, but computationally more numerically stable than finalising the outputs with `softmax` and then calculating crossentropy. - - `Flux.binarycrossentropy`: Typically used as loss in binary classification, with labels in a 1-hot encoded format. - - `Flux.logitbinarycrossentopy`: Mathematically equal to crossentropy, but computationally more numerically stable than finalising the outputs with `sigmoid` and then calculating binary crossentropy. - - `Flux.tversky_loss`: Used with imbalanced data to give more weight to false negatives. - - `Flux.focal_loss`: Used with highly imbalanced data. Weights harder examples more than easier examples. - - `Flux.binary_focal_loss`: Binary version of the above -- `epochs::Int=10`: The number of epochs to train for. Typically, one epoch represents one pass through the entirety of the training dataset. -- `batch_size::Int=1`: The batch size to be used for training. The batch size represents the number of samples per update of the networks weights. Typcally, batch size should be somewhere between 8 and 512. Smaller batch sizes lead to noisier training loss curves, while larger batch sizes lead towards smoother training loss curves. In general, it is a good idea to pick one fairly large batch size (e.g. 32, 64, 128), and stick with it, and only tune the learning rate. In most literature, batch size is set in powers of twos, but this is fairly arbitrary. -- `lambda::Float64=0`: The stregth of the regularization used during training. Can be any value in the range `[0, ∞)`. -- `alpha::Float64=0`: The L2/L1 mix of regularization, in the range `[0, 1]`. A value of 0 represents L2 regularization, and a value of 1 represents L1 regularization. -- `rng::Union{AbstractRNG, Int64}`: The random number generator/seed used during training. -- `optimizer_changes_trigger_retraining::Bool=false`: Defines what happens when fitting a machine if the associated optimiser has changed. If true, the associated machine will retrain from scratch on `fit`, otherwise it will not. -- `acceleration::AbstractResource=CPU1()`: Defines on what hardware training is done. For Training on GPU, use `CudaLibs()`, otherwise defaults to `CPU`()`. -- `finaliser=Flux.softmax`: The final activation function of the neural network. Defaults to `Flux.softmax`. For a regression task, reasonable alternatives include `Flux.sigmoid` and the identity function (otherwise known as "linear activation"). - - -# Operations - -- `predict(mach, Xnew)`: return predictions of the target given new - features `Xnew` having the same Scitype as `X` above. Predictions are - probabilistic. -- `predict_mode(mach, Xnew)`: Return the modes of the probabilistic predictions - returned above. - - -# Fitted parameters - -The fields of `fitted_params(mach)` are: -- `chain`: The trained "chain", or series of layers, functions, and activations which make up the neural network. - - -# Report - -The fields of `report(mach)` are: -- `training_losses`: The history of training losses, a vector containing the history of all the losses during training. The first element of the vector is the initial penalized loss. After the first element, the nth element corresponds to the loss of epoch n-1. - -# Examples - -In this example we use MLJ to classify the MNIST image dataset -```julia -using MLJ -using Flux -import MLJFlux -import MLJIteration # for `skip` - -MLJ.color_off() - -using Plots -pyplot(size=(600, 300*(sqrt(5)-1))); -``` -First we want to download the MNIST dataset, and unpack into images and labels -```julia -import MLDatasets: MNIST - -ENV["DATADEPS_ALWAYS_ACCEPT"] = true -images, labels = MNIST.traindata(); -``` -In MLJ, integers cannot be used for encoding categorical data, so we must coerce them into the `Multiclass` [scientific type](https://juliaai.github.io/ScientificTypes.jl/dev/). For more in this, see [Working with Categorical Data](https://alan-turing-institute.github.io/MLJ.jl/dev/working_with_categorical_data/): -```julia -labels = coerce(labels, Multiclass); -images = coerce(images, GrayImage); - -# Checking scientific types: - -@assert scitype(images) <: AbstractVector{<:Image} -@assert scitype(labels) <: AbstractVector{<:Finite} - -images[1] -``` -For general instructions on coercing image data, see [type coercion for image data](https://alan-turing-institute.github.io/ScientificTypes.jl/dev/%23Type-coercion-for-image-data-1) -We start by defining a suitable `builder` object. This is a recipe -for building the neural network. Our builder will work for images of -any (constant) size, whether they be color or black and white (ie, -single or multi-channel). The architecture always consists of six -alternating convolution and max-pool layers, and a final dense -layer; the filter size and the number of channels after each -convolution layer is customisable. -```julia -import MLJFlux - -struct MyConvBuilder - filter_size::Int - channels1::Int - channels2::Int - channels3::Int -end - -make2d(x::AbstractArray) = reshape(x, :, size(x)[end]) - -function MLJFlux.build(b::MyConvBuilder, rng, n_in, n_out, n_channels) - k, c1, c2, c3 = b.filter_size, b.channels1, b.channels2, b.channels3 - mod(k, 2) == 1 || error("`filter_size` must be odd. ") - p = div(k - 1, 2) # padding to preserve image size - init = Flux.glorot_uniform(rng) - front = Chain( - Conv((k, k), n_channels => c1, pad=(p, p), relu, init=init), - MaxPool((2, 2)), - Conv((k, k), c1 => c2, pad=(p, p), relu, init=init), - MaxPool((2, 2)), - Conv((k, k), c2 => c3, pad=(p, p), relu, init=init), - MaxPool((2 ,2)), - make2d) - d = Flux.outputsize(front, (n_in..., n_channels, 1)) |> first - return Chain(front, Dense(d, n_out, init=init)) -end -``` -It is important to note that in our `build` function, there is no final softmax. This is applie by default in all MLJFlux classifiers, using the `finaliser` hyperparameter of the classifier. Now that we have our builder defined, we can define the actual moel. If you have a GPU, you can substitute in `acceleration=CudaLibs()` below. Note that in the case of convolutions, this will **greatly** increase the speed of training. -```julia -ImageClassifier = @load ImageClassifier -clf = ImageClassifier(builder=MyConvBuilder(3, 16, 32, 32), - batch_size=50, - epochs=10, - rng=123) -``` -You can add flux options such as `optimiser` and `loss` in the snippet above. Currently, `loss` must be a flux-compatible loss, and not an MLJ measure. -Next, we can bind the model with the data in a machine, and fit the first 500 or so images: -```julia -mach = machine(clf, images, labels); - -fit!(mach, rows=1:500, verbosity=2); - -report(mach) - -chain = fitted_params(mach) - -Flux.params(chain)[2] -``` -We can tack on 20 more epochs by modifying the `epochs` field, and iteratively fit some more: -```julia -clf.epochs = clf.epochs + 20 -fit!(mach, rows=1:500); -``` -We can also make predictions and calculate an out-of-sample loss estimate, in two ways! -```julia -predicted_labels = predict(mach, rows=501:1000); -cross_entropy(predicted_labels, labels[501:1000]) |> mean -# alternative one liner! -evaluate!(mach, - resampling=Holdout(fraction_train=0.5), - measure=cross_entropy, - rows=1:1000, - verbosity=0) -``` - -## Wrapping in iteration controls - -Any iterative MLJFlux model can be wrapped in **iteration controls**, as we demonstrate next. For more on MLJ's `IteratedModel` wrapper, see the [MLJ documentation](https://alan-turing-institute.github.io/MLJ.jl/dev/controlling_iterative_models/). -The "self-iterating" classifier (`iterated_clf` below) is for iterating the image classifier defined above until a stopping criterion is hit. We use the following stopping criterion: -- `Patience(3)`: 3 consecutive increases in the loss -- `InvalidValue()`: an out-of-sample loss or a training loss that is `NaN` or `±Inf` -- `TimeLimit(t=5/60)`: training time has exceeded 5 minutes. -We can specify how often these checks (and other controls) are applied using the `Step` control. Additionally, we can define controls to -- save a snapshot of the machine every N control cycles (`save_control`) -- record traces of the out-of-sample loss and training losses for plotting (`WithLossDo`) -- record mean value traces of each Flux parameter for plotting (`Callback`) -And other controls. For a full list, see [the documentation](https://alan-turing-institute.github.io/MLJ.jl/dev/controlling_iterative_models/%23Controls-provided). -First, we define some helper functions and some empty vectors to store traces: -```julia -make2d(x::AbstractArray) = reshape(x, :, size(x)[end]) -make1d(x::AbstractArray) = reshape(x, length(x)); - -# to extract the flux parameters from a machine -parameters(mach) = make1d.(Flux.params(fitted_params(mach))); - -# trace storage -losses = [] -training_losses = [] -parameter_means = Float32[]; -epochs = [] - -# to update traces -update_loss(loss) = push!(losses, loss) -update_training_loss(losses) = push!(training_losses, losses[end]) -update_means(mach) = append!(parameter_means, mean.(parameters(mach))); -update_epochs(epoch) = push!(epochs, epoch) -``` -Next, we can define our controls! We store them in a simple vector: -```julia -save_control = - MLJIteration.skip(Save(joinpath(DIR, "mnist.jlso")), predicate=3) - -controls=[Step(2), - Patience(3), - InvalidValue(), - TimeLimit(5/60), - save_control, - WithLossDo(), - WithLossDo(update_loss), - WithTrainingLossesDo(update_training_loss), - Callback(update_means), - WithIterationsDo(update_epochs) -``` -Once the controls are defined, we can instantiate and fit our "self-iterating" classifier: -```julia -iterated_clf = IteratedModel(model=clf, - controls=controls, - resampling=Holdout(fraction_train=0.7), - measure=log_loss) - -mach = machine(iterated_clf, images, labels); -fit!(mach, rows=1:500); -``` -Next we can compare the training and out-of-sample losses, as well as view the evolution of the weights: -```julia -plot(epochs, losses, - xlab = "epoch", - ylab = "root squared error", - label="out-of-sample") -plot!(epochs, training_losses, label="training") - -savefig(joinpath(DIR, "loss.png")) - -n_epochs = length(losses) -n_parameters = div(length(parameter_means), n_epochs) -parameter_means2 = reshape(copy(parameter_means), n_parameters, n_epochs)' -plot(epochs, parameter_means2, - title="Flux parameter mean weights", - xlab = "epoch") -# **Note.** The the higher the number, the deeper the chain parameter. -savefig(joinpath(DIR, "weights.png")) -``` -Since we saved our model every few epochs, we can retrieve the snapshots so we can make predictions! -```julia -mach2 = machine(joinpath(DIR, "mnist3.jlso")) -predict_mode(mach2, images[501:503]) -``` - -## Resuming training - -If we change `iterated_clf.controls` or `clf.epochs`, we can resume training from where it left off. This is very useful for long-running training sessions, where you may be interrupted by for example a bad connection or computer hibernation. -```julia -iterated_clf.controls[2] = Patience(4) -fit!(mach, rows=1:500) - -plot(epochs, losses, - xlab = "epoch", - ylab = "root squared error", - label="out-of-sample") -plot!(epochs, training_losses, label="training") -``` -See also -[`NeuralNetworkClassifier`](@ref) -""" -ImageClassifier end #module diff --git a/src/classifier.jl b/src/classifier.jl index 82d4efc9..2825dff7 100644 --- a/src/classifier.jl +++ b/src/classifier.jl @@ -31,8 +31,4 @@ end MLJModelInterface.metadata_model(NeuralNetworkClassifier, input=Table(Continuous), target=AbstractVector{<:Finite}, - path="MLJFlux.NeuralNetworkClassifier", - descr="A neural network model for making "* - "probabilistic predictions of a "* - "`Multiclass` or `OrderedFactor` target, "* - "given a table of `Continuous` features. ") + path="MLJFlux.NeuralNetworkClassifier") diff --git a/src/image.jl b/src/image.jl index 5c973eb7..dc8d5637 100644 --- a/src/image.jl +++ b/src/image.jl @@ -29,7 +29,4 @@ end MLJModelInterface.metadata_model(ImageClassifier, input=AbstractVector{<:MLJModelInterface.Image}, target=AbstractVector{<:Multiclass}, - path="MLJFlux.ImageClassifier", - descr="A neural network model for making probabilistic "* - "predictions of a `GrayImage` target, "* - "given a table of `Continuous` features. ") + path="MLJFlux.ImageClassifier") diff --git a/src/regressor.jl b/src/regressor.jl index f932bff7..85a431aa 100644 --- a/src/regressor.jl +++ b/src/regressor.jl @@ -23,11 +23,7 @@ end MLJModelInterface.metadata_model(NeuralNetworkRegressor, input=Table(Continuous), target=AbstractVector{<:Continuous}, - path="MLJFlux.NeuralNetworkRegressor", - descr="A neural network model for making "* - "deterministic predictions of a "* - "`Continuous` target, given a table of "* - "`Continuous` features. ") + path="MLJFlux.NeuralNetworkRegressor") # # MULTITARGET NEURAL NETWORK REGRESSOR @@ -59,9 +55,4 @@ end MLJModelInterface.metadata_model(MultitargetNeuralNetworkRegressor, input=Table(Continuous), target=Table(Continuous), - path="MLJFlux.MultitargetNeuralNetworkRegressor", - descr = "A neural network model for making "* - "deterministic predictions of a "* - "`Continuous` multi-target, presented "* - "as a table, given a table of "* - "`Continuous` features. ") + path="MLJFlux.MultitargetNeuralNetworkRegressor") diff --git a/src/types.jl b/src/types.jl index bf5674af..16c3c295 100644 --- a/src/types.jl +++ b/src/types.jl @@ -3,51 +3,6 @@ abstract type MLJFluxDeterministic <: MLJModelInterface.Deterministic end const MLJFluxModel = Union{MLJFluxProbabilistic,MLJFluxDeterministic} -const doc_regressor(model_name) = """ - - $model_name(; hyparameters...) - -Instantiate an MLJFlux model. Available hyperparameters: - -- `builder`: Default = `MLJFlux.Linear(σ=Flux.relu)` (regressors) or - `MLJFlux.Short(n_hidden=0, dropout=0.5, σ=Flux.σ)` (classifiers) - -- `optimiser`: The optimiser to use for training. Default = - `Flux.ADAM()` - -- `loss`: The loss function used for training. Default = `Flux.mse` - (regressors) and `Flux.crossentropy` (classifiers) - -- `epochs`: Number of epochs to train for. Default = `10` - -- `batch_size`: The batch_size for the data. Default = 1 - -- `lambda`: The regularization strength. Default = 0. Range = [0, ∞) - -- `alpha`: The L2/L1 mix of regularization. Default = 0. Range = [0, 1] - -- `rng`: The random number generator (RNG) passed to builders, for - weight intitialization, for example. Can be any `AbstractRNG` or - the seed (integer) for a `MersenneTwister` that is reset on every - cold restart of model (machine) training. Default = - `GLOBAL_RNG`. - -- `acceleration`: Use `CUDALibs()` for training on GPU; default is `CPU1()`. - -- `optimiser_changes_trigger_retraining`: True if fitting an - associated machine should trigger retraining from scratch whenever - the optimiser changes. Default = `false` - -""" - -doc_classifier(model_name) = doc_regressor(model_name)*""" -- `finaliser`: Operation applied to the unnormalized output of the - final layer to obtain probabilities (outputs summing to - one). The shape of the inputs and outputs - of this operator must match. Default = `Flux.softmax`. - -""" - for Model in [:NeuralNetworkClassifier, :ImageClassifier] ex = quote @@ -97,13 +52,416 @@ for Model in [:NeuralNetworkClassifier, :ImageClassifier] return model end - @doc doc_classifier($Model) $Model - end eval(ex) end +""" +$(MMI.doc_header(NeuralNetworkClassifier)) + +`NeuralNetworkClassifier`: a neural network model for making probabilistic predictions +of a Multiclass or OrderedFactor target, given a table of Continuous features. ) + TODO: + +# Training data + +In MLJ or MLJBase, bind an instance `model` to data with + mach = machine(model, X, y) + +Where + +- `X`: is any table of input features (eg, a `DataFrame`) whose columns + are of scitype `Continuous`; check the scitype with `schema(X)` +- `y`: is the target, which can be any `AbstractVector` whose element + scitype is `Multiclass` or `OrderedFactor` with `n_out` classes; + check the scitype with `scitype(y)` + + +# Hyper-parameters + +- `builder=MLJFlux.Short()`: An MLJFlux builder that constructs a neural network. Possible `builders` include: `Linear`, `Short`, and `MLP`. You can construct your own builder using the `@builder` macro, see examples for further information. +- `optimiser::Flux.ADAM()`: A `Flux.Optimise` optimiser. The optimiser performs the updating of the weights of the network. For further reference, see either the examples or [the Flux optimiser documentation](https://fluxml.ai/Flux.jl/stable/training/optimisers/). To choose a learning rate (the update rate of the optimizer), a good rule of thumb is to start out at `10e-3`, and tune using powers of 10 between `1` and `1e-7`. +- `loss=Flux.crossentropy`: The loss function which the network will optimize. Should be a function which can be called in the form `loss(yhat, y)`. Possible loss functions are listed in [the Flux loss function documentation](https://fluxml.ai/Flux.jl/stable/models/losses/). For a classification task, the most natural loss functions are: + - `Flux.crossentropy`: Typically used as loss in multiclass classification, with labels in a 1-hot encoded format. + - `Flux.logitcrossentopy`: Mathematically equal to crossentropy, but computationally more numerically stable than finalising the outputs with `softmax` and then calculating crossentropy. + - `Flux.binarycrossentropy`: Typically used as loss in binary classification, with labels in a 1-hot encoded format. + - `Flux.logitbinarycrossentopy`: Mathematically equal to crossentropy, but computationally more numerically stable than finalising the outputs with `sigmoid` and then calculating binary crossentropy. + - `Flux.tversky_loss`: Used with imbalanced data to give more weight to false negatives. + - `Flux.focal_loss`: Used with highly imbalanced data. Weights harder examples more than easier examples. + - `Flux.binary_focal_loss`: Binary version of the above +- `epochs::Int=10`: The number of epochs to train for. Typically, one epoch represents one pass through the entirety of the training dataset. +- `batch_size::Int=1`: The batch size to be used for training. The batch size represents the number of samples per update of the networks weights. Typcally, batch size should be somewhere between 8 and 512. Smaller batch sizes lead to noisier training loss curves, while larger batch sizes lead towards smoother training loss curves. In general, it is a good idea to pick one fairly large batch size (e.g. 32, 64, 128), and stick with it, and only tune the learning rate. In most literature, batch size is set in powers of twos, but this is fairly arbitrary. +- `lambda::Float64=0`: The stregth of the regularization used during training. Can be any value in the range `[0, ∞)`. +- `alpha::Float64=0`: The L2/L1 mix of regularization, in the range `[0, 1]`. A value of 0 represents L2 regularization, and a value of 1 represents L1 regularization. +- `rng::Union{AbstractRNG, Int64}`: The random number generator/seed used during training. +- `optimizer_changes_trigger_retraining::Bool=false`: Defines what happens when fitting a machine if the associated optimiser has changed. If true, the associated machine will retrain from scratch on `fit`, otherwise it will not. +- `acceleration::AbstractResource=CPU1()`: Defines on what hardware training is done. For Training on GPU, use `CudaLibs()`, otherwise defaults to `CPU`()`. +- `finaliser=Flux.softmax`: The final activation function of the neural network. Defaults to `Flux.softmax`. For a classification task, `softmax` is used for multiclass, single label regression, `sigmoid` is used for either binary classification or multi label classification (when there are multiple possible labels for a given sample). + + +# Operations + +- `predict(mach, Xnew)`: return predictions of the target given new + features `Xnew` having the same Scitype as `X` above. Predictions are + probabilistic. +- `predict_mode(mach, Xnew)`: Return the modes of the probabilistic predictions + returned above. + + +# Fitted parameters + +The fields of `fitted_params(mach)` are: + +- `chain`: The trained "chain", or series of layers, functions, and activations which make up the neural network. + + +# Report + +The fields of `report(mach)` are: + +- `training_losses`: The history of training losses, a vector containing the history of all the losses during training. The first element of the vector is the initial penalized loss. After the first element, the nth element corresponds to the loss of epoch n-1. + +# Examples + +In this example we build a classification model using the Iris dataset. +```julia +using MLJ +using Flux +import RDatasets + +using Random +Random.seed!(123) + +MLJ.color_off() + +using Plots +pyplot(size=(600, 300*(sqrt(5)-1))); +``` +This is a very basic example, using a default builder and no standardization. +For a more advance illustration, see [`NeuralNetworkRegressor`](@ref) or [`ImageClassifier`](@ref). First, we can load the data: +```julia +iris = RDatasets.dataset("datasets", "iris"); +y, X = unpack(iris, ==(:Species), colname -> true, rng=123); +NeuralNetworkClassifier = @load NeuralNetworkClassifier +clf = NeuralNetworkClassifier() +``` +Next, we can train the model: +```julia +import Random.seed!; seed!(123) +mach = machine(clf, X, y) +fit!(mach) +``` +We can train the model in an incremental fashion with the `optimizer_changes_trigger_retraining` flag set to false (which is by default). Here, we change the number of iterations and the learning rate of the optimiser: +```julia +clf.optimiser.eta = clf.optimiser.eta * 2 +clf.epochs = clf.epochs + 5 + +# note that if the optimizer_changes_trigger_retraining flag was set to true +# the model would be completely retrained from scratch because the optimizer was +# updated +fit!(mach, verbosity=2); +``` +We can inspect the mean training loss using the `cross_entropy` function: +```julia + +training_loss = cross_entropy(predict(mach, X), y) |> mean + +``` +And we can access the Flux chain (model) using `fitted_params`: +```julia +chain = fitted_params(mach).chain +``` +Finally, we can see how the out-of-sample performance changes over time, using the `learning_curve` function +```julia +r = range(clf, :epochs, lower=1, upper=200, scale=:log10) +curve = learning_curve(clf, X, y, + range=r, + resampling=Holdout(fraction_train=0.7), + measure=cross_entropy) +using Plots +plot(curve.parameter_values, + curve.measurements, + xlab=curve.parameter_name, + xscale=curve.parameter_scale, + ylab = "Cross Entropy") + +savefig("iris_history.png") +``` +See also +[`ImageClassifier`](@ref) +""" +NeuralNetworkClassifier + +""" +$(MMI.doc_header(ImageClassifier)) + +`ImageClassifier`: A neural network model for making probabilistic +"predictions of a `GrayImage` target, given a table of `Continuous` features. + +# Training data + +In MLJ or MLJBase, bind an instance `model` to data with +mach = machine(model, X, y) +Where +- `X`: is any `AbstractVector` of input features (eg, a `DataFrame`) whose items + are of scitype `GrayImage`; check the scitype with `scitype(X)` +- `y`: is the target, which can be any `AbstractVector` whose element + scitype is `Multiclass` or `OrderedFactor` with `n_out` classes; + check the scitype with `scitype(y)` + + +# Hyper-parameters + +- `builder=MLJFlux.Short()`: An MLJFlux builder that constructs a neural network. Possible `builders` include: `Linear`, `Short`, and `MLP`. You can construct your own builder using the `@builder` macro, see examples for further information. +- `optimiser::Flux.ADAM()`: A `Flux.Optimise` optimiser. The optimiser performs the updating of the weights of the network. For further reference, see either the examples or [the Flux optimiser documentation](https://fluxml.ai/Flux.jl/stable/training/optimisers/). To choose a learning rate (the update rate of the optimizer), a good rule of thumb is to start out at `10e-3`, and tune using powers of 10 between `1` and `1e-7`. +- `loss=Flux.crossentropy`: The loss function which the network will optimize. Should be a function which can be called in the form `loss(yhat, y)`. Possible loss functions are listed in [the Flux loss function documentation](https://fluxml.ai/Flux.jl/stable/models/losses/). For a classification task, the most natural loss functions are: + - `Flux.crossentropy`: Typically used as loss in multiclass classification, with labels in a 1-hot encoded format. + - `Flux.logitcrossentopy`: Mathematically equal to crossentropy, but computationally more numerically stable than finalising the outputs with `softmax` and then calculating crossentropy. + - `Flux.binarycrossentropy`: Typically used as loss in binary classification, with labels in a 1-hot encoded format. + - `Flux.logitbinarycrossentopy`: Mathematically equal to crossentropy, but computationally more numerically stable than finalising the outputs with `sigmoid` and then calculating binary crossentropy. + - `Flux.tversky_loss`: Used with imbalanced data to give more weight to false negatives. + - `Flux.focal_loss`: Used with highly imbalanced data. Weights harder examples more than easier examples. + - `Flux.binary_focal_loss`: Binary version of the above +- `epochs::Int=10`: The number of epochs to train for. Typically, one epoch represents one pass through the entirety of the training dataset. +- `batch_size::Int=1`: The batch size to be used for training. The batch size represents the number of samples per update of the networks weights. Typcally, batch size should be somewhere between 8 and 512. Smaller batch sizes lead to noisier training loss curves, while larger batch sizes lead towards smoother training loss curves. In general, it is a good idea to pick one fairly large batch size (e.g. 32, 64, 128), and stick with it, and only tune the learning rate. In most literature, batch size is set in powers of twos, but this is fairly arbitrary. +- `lambda::Float64=0`: The stregth of the regularization used during training. Can be any value in the range `[0, ∞)`. +- `alpha::Float64=0`: The L2/L1 mix of regularization, in the range `[0, 1]`. A value of 0 represents L2 regularization, and a value of 1 represents L1 regularization. +- `rng::Union{AbstractRNG, Int64}`: The random number generator/seed used during training. +- `optimizer_changes_trigger_retraining::Bool=false`: Defines what happens when fitting a machine if the associated optimiser has changed. If true, the associated machine will retrain from scratch on `fit`, otherwise it will not. +- `acceleration::AbstractResource=CPU1()`: Defines on what hardware training is done. For Training on GPU, use `CudaLibs()`, otherwise defaults to `CPU`()`. +- `finaliser=Flux.softmax`: The final activation function of the neural network. Defaults to `Flux.softmax`. For a regression task, reasonable alternatives include `Flux.sigmoid` and the identity function (otherwise known as "linear activation"). + + +# Operations + +- `predict(mach, Xnew)`: return predictions of the target given new + features `Xnew` having the same Scitype as `X` above. Predictions are + probabilistic. +- `predict_mode(mach, Xnew)`: Return the modes of the probabilistic predictions + returned above. + + +# Fitted parameters + +The fields of `fitted_params(mach)` are: +- `chain`: The trained "chain", or series of layers, functions, and activations which make up the neural network. + + +# Report + +The fields of `report(mach)` are: +- `training_losses`: The history of training losses, a vector containing the history of all the losses during training. The first element of the vector is the initial penalized loss. After the first element, the nth element corresponds to the loss of epoch n-1. + +# Examples + +In this example we use MLJ to classify the MNIST image dataset +```julia +using MLJ +using Flux +import MLJFlux +import MLJIteration # for `skip` + +MLJ.color_off() + +using Plots +pyplot(size=(600, 300*(sqrt(5)-1))); +``` +First we want to download the MNIST dataset, and unpack into images and labels +```julia +import MLDatasets: MNIST + +ENV["DATADEPS_ALWAYS_ACCEPT"] = true +images, labels = MNIST.traindata(); +``` +In MLJ, integers cannot be used for encoding categorical data, so we must coerce them into the `Multiclass` [scientific type](https://juliaai.github.io/ScientificTypes.jl/dev/). For more in this, see [Working with Categorical Data](https://alan-turing-institute.github.io/MLJ.jl/dev/working_with_categorical_data/): +```julia +labels = coerce(labels, Multiclass); +images = coerce(images, GrayImage); + +# Checking scientific types: + +@assert scitype(images) <: AbstractVector{<:Image} +@assert scitype(labels) <: AbstractVector{<:Finite} + +images[1] +``` +For general instructions on coercing image data, see [type coercion for image data](https://alan-turing-institute.github.io/ScientificTypes.jl/dev/%23Type-coercion-for-image-data-1) +We start by defining a suitable `builder` object. This is a recipe +for building the neural network. Our builder will work for images of +any (constant) size, whether they be color or black and white (ie, +single or multi-channel). The architecture always consists of six +alternating convolution and max-pool layers, and a final dense +layer; the filter size and the number of channels after each +convolution layer is customisable. +```julia +import MLJFlux + +struct MyConvBuilder + filter_size::Int + channels1::Int + channels2::Int + channels3::Int +end + +make2d(x::AbstractArray) = reshape(x, :, size(x)[end]) + +function MLJFlux.build(b::MyConvBuilder, rng, n_in, n_out, n_channels) + k, c1, c2, c3 = b.filter_size, b.channels1, b.channels2, b.channels3 + mod(k, 2) == 1 || error("`filter_size` must be odd. ") + p = div(k - 1, 2) # padding to preserve image size + init = Flux.glorot_uniform(rng) + front = Chain( + Conv((k, k), n_channels => c1, pad=(p, p), relu, init=init), + MaxPool((2, 2)), + Conv((k, k), c1 => c2, pad=(p, p), relu, init=init), + MaxPool((2, 2)), + Conv((k, k), c2 => c3, pad=(p, p), relu, init=init), + MaxPool((2 ,2)), + make2d) + d = Flux.outputsize(front, (n_in..., n_channels, 1)) |> first + return Chain(front, Dense(d, n_out, init=init)) +end +``` +It is important to note that in our `build` function, there is no final softmax. This is applie by default in all MLJFlux classifiers, using the `finaliser` hyperparameter of the classifier. Now that we have our builder defined, we can define the actual moel. If you have a GPU, you can substitute in `acceleration=CudaLibs()` below. Note that in the case of convolutions, this will **greatly** increase the speed of training. +```julia +ImageClassifier = @load ImageClassifier +clf = ImageClassifier(builder=MyConvBuilder(3, 16, 32, 32), + batch_size=50, + epochs=10, + rng=123) +``` +You can add flux options such as `optimiser` and `loss` in the snippet above. Currently, `loss` must be a flux-compatible loss, and not an MLJ measure. +Next, we can bind the model with the data in a machine, and fit the first 500 or so images: +```julia +mach = machine(clf, images, labels); + +fit!(mach, rows=1:500, verbosity=2); + +report(mach) + +chain = fitted_params(mach) + +Flux.params(chain)[2] +``` +We can tack on 20 more epochs by modifying the `epochs` field, and iteratively fit some more: +```julia +clf.epochs = clf.epochs + 20 +fit!(mach, rows=1:500); +``` +We can also make predictions and calculate an out-of-sample loss estimate, in two ways! +```julia +predicted_labels = predict(mach, rows=501:1000); +cross_entropy(predicted_labels, labels[501:1000]) |> mean +# alternative one liner! +evaluate!(mach, + resampling=Holdout(fraction_train=0.5), + measure=cross_entropy, + rows=1:1000, + verbosity=0) +``` + +## Wrapping in iteration controls + +Any iterative MLJFlux model can be wrapped in **iteration controls**, as we demonstrate next. For more on MLJ's `IteratedModel` wrapper, see the [MLJ documentation](https://alan-turing-institute.github.io/MLJ.jl/dev/controlling_iterative_models/). +The "self-iterating" classifier (`iterated_clf` below) is for iterating the image classifier defined above until a stopping criterion is hit. We use the following stopping criterion: +- `Patience(3)`: 3 consecutive increases in the loss +- `InvalidValue()`: an out-of-sample loss or a training loss that is `NaN` or `±Inf` +- `TimeLimit(t=5/60)`: training time has exceeded 5 minutes. +We can specify how often these checks (and other controls) are applied using the `Step` control. Additionally, we can define controls to +- save a snapshot of the machine every N control cycles (`save_control`) +- record traces of the out-of-sample loss and training losses for plotting (`WithLossDo`) +- record mean value traces of each Flux parameter for plotting (`Callback`) +And other controls. For a full list, see [the documentation](https://alan-turing-institute.github.io/MLJ.jl/dev/controlling_iterative_models/%23Controls-provided). +First, we define some helper functions and some empty vectors to store traces: +```julia +make2d(x::AbstractArray) = reshape(x, :, size(x)[end]) +make1d(x::AbstractArray) = reshape(x, length(x)); + +# to extract the flux parameters from a machine +parameters(mach) = make1d.(Flux.params(fitted_params(mach))); + +# trace storage +losses = [] +training_losses = [] +parameter_means = Float32[]; +epochs = [] + +# to update traces +update_loss(loss) = push!(losses, loss) +update_training_loss(losses) = push!(training_losses, losses[end]) +update_means(mach) = append!(parameter_means, mean.(parameters(mach))); +update_epochs(epoch) = push!(epochs, epoch) +``` +Next, we can define our controls! We store them in a simple vector: +```julia +save_control = + MLJIteration.skip(Save(joinpath(DIR, "mnist.jlso")), predicate=3) + +controls=[Step(2), + Patience(3), + InvalidValue(), + TimeLimit(5/60), + save_control, + WithLossDo(), + WithLossDo(update_loss), + WithTrainingLossesDo(update_training_loss), + Callback(update_means), + WithIterationsDo(update_epochs) +``` +Once the controls are defined, we can instantiate and fit our "self-iterating" classifier: +```julia +iterated_clf = IteratedModel(model=clf, + controls=controls, + resampling=Holdout(fraction_train=0.7), + measure=log_loss) + +mach = machine(iterated_clf, images, labels); +fit!(mach, rows=1:500); +``` +Next we can compare the training and out-of-sample losses, as well as view the evolution of the weights: +```julia +plot(epochs, losses, + xlab = "epoch", + ylab = "root squared error", + label="out-of-sample") +plot!(epochs, training_losses, label="training") + +savefig(joinpath(DIR, "loss.png")) + +n_epochs = length(losses) +n_parameters = div(length(parameter_means), n_epochs) +parameter_means2 = reshape(copy(parameter_means), n_parameters, n_epochs)' +plot(epochs, parameter_means2, + title="Flux parameter mean weights", + xlab = "epoch") +# **Note.** The the higher the number, the deeper the chain parameter. +savefig(joinpath(DIR, "weights.png")) +``` +Since we saved our model every few epochs, we can retrieve the snapshots so we can make predictions! +```julia +mach2 = machine(joinpath(DIR, "mnist3.jlso")) +predict_mode(mach2, images[501:503]) +``` + +## Resuming training + +If we change `iterated_clf.controls` or `clf.epochs`, we can resume training from where it left off. This is very useful for long-running training sessions, where you may be interrupted by for example a bad connection or computer hibernation. +```julia +iterated_clf.controls[2] = Patience(4) +fit!(mach, rows=1:500) + +plot(epochs, losses, + xlab = "epoch", + ylab = "root squared error", + label="out-of-sample") +plot!(epochs, training_losses, label="training") +``` +See also +[`NeuralNetworkClassifier`](@ref) +""" +ImageClassifier + for Model in [:NeuralNetworkRegressor, :MultitargetNeuralNetworkRegressor] ex = quote @@ -149,12 +507,552 @@ for Model in [:NeuralNetworkRegressor, :MultitargetNeuralNetworkRegressor] return model end - @doc $doc_regressor($Model) $Model - end eval(ex) end + +""" +$(MMI.doc_header(NeuralNetworkRegressor)) + +`NeuralNetworkRegressor`: A neural network model for making deterministic +predictions of a `Continuous` target, given a table of `Continuous` features. + +# Training data + +In MLJ or MLJBase, bind an instance `model` to data with + mach = machine(model, X, y) + +Where + +- `X`: is any table of input features (eg, a `DataFrame`) whose columns + are of scitype `Continuous`; check the scitype with `schema(X)` +- `y`: is the target, which can be any `AbstractVector` whose element + scitype is `Continuous`; check the scitype with `scitype(y)` + + +# Hyper-parameters + +- `builder=MLJFlux.Linear(σ=Flux.relu)`: An MLJFlux builder that constructs a neural network. + Possible `builders` include: `Linear`, `Short`, and `MLP`. You can construct your own builder + using the `@builder` macro, see examples for further information. +- `optimiser::Flux.ADAM()`: A `Flux.Optimise` optimiser. The optimiser performs the updating + of the weights of the network. For further reference, see either the examples or + [the Flux optimiser documentation](https://fluxml.ai/Flux.jl/stable/training/optimisers/). + To choose a learning rate (the update rate of the optimizer), a good rule of thumb is to + start out at `10e-3`, and tune using powers of 10 between `1` and `1e-7`. +- `loss=Flux.mse`: The loss function which the network will optimize. Should be a function + which can be called in the form `loss(yhat, y)`. + Possible loss functions are listed in [the Flux loss function documentation](https://fluxml.ai/Flux.jl/stable/models/losses/). + For a regression task, the most natural loss functions are: + - `Flux.mse` + - `Flux.mae` + - `Flux.msle` + - `Flux.huber_loss` +- `epochs::Int=10`: The number of epochs to train for. Typically, one epoch represents + one pass through the entirety of the training dataset. +- `batch_size::Int=1`: The batch size to be used for training. The batch size represents + the number of samples per update of the networks weights. Typcally, batch size should be + somewhere between 8 and 512. Smaller batch sizes lead to noisier training loss curves, + while larger batch sizes lead towards smoother training loss curves. + In general, it is a good idea to pick one fairly large batch size (e.g. 32, 64, 128), + and stick with it, and only tune the learning rate. In most examples, batch size is set + in powers of twos, but this is fairly arbitrary. +- `lambda::Float64=0`: The stregth of the regularization used during training. Can be any value + in the range `[0, ∞)`. +- `alpha::Float64=0`: The L2/L1 mix of regularization, in the range `[0, 1]`. + A value of 0 represents L2 regularization, and a value of 1 represents L1 regularization. +- `rng::Union{AbstractRNG, Int64}`: The random number generator/seed used during training. +- `optimizer_changes_trigger_retraining::Bool=false`: Defines what happens when fitting a + machine if the associated optimiser has changed. If true, the associated machine will + retrain from scratch on `fit`, otherwise it will not. +- `acceleration::AbstractResource=CPU1()`: Defines on what hardware training is done. + For training on GPU, use `CudaLibs()`, otherwise defaults to `CPU`()`. +- `finaliser=Flux.softmax`: The final activation function of the neural network. + Defaults to `Flux.softmax`. For a regression task, reasonable alternatives include + `Flux.sigmoid` and the identity function (otherwise known as "linear activation"). + + +# Operations + +- `predict(mach, Xnew)`: return predictions of the target given new + features `Xnew` having the same Scitype as `X` above. Predictions are + deterministic. + + +# Fitted parameters + +The fields of `fitted_params(mach)` are: + +- `chain`: The trained "chain", or series of layers, functions, and activations which + make up the neural network. + + +# Report + +The fields of `report(mach)` are: + +- `training_losses`: The history of training losses, a vector containing the history of all the losses during training. The first element of the vector is the initial penalized loss. After the first element, the nth element corresponds to the loss of epoch n-1. + all the losses during training. The first element of the vector is the initial penalized loss. After the first element, the nth element corresponds to the loss of epoch n-1. + penalized loss. After the first element, the nth element corresponds to the loss of epoch n-1. + epoch n-1. +# Examples + +In this example we build a regression model using the Boston house price dataset +```julia + using MLJ + using MLJFlux + using Flux + using Plots +``` +First, we load in the data, with target `:MEDV`. We load in all features except `:CHAS`: +```julia +data = OpenML.load(531); # Loads from https://www.openml.org/d/531 + +y, X = unpack(data, ==(:MEDV), !=(:CHAS); rng=123); + +scitype(y) +schema(X) +``` +Since MLJFlux models do not handle ordered factors, we can treat `:RAD` as `Continuous`: +```julia +X = coerce(X, :RAD=>Continuous) +``` +Lets also make a test set: +```julia +(X, Xtest), (y, ytest) = partition((X, y), 0.7, multi=true); +``` +Next, we can define a `builder`. In the following macro call, `n_in` is the number of expected input features, and rng is a RNG. `init` is the function used to generate the random initial weights of the network. +expected input features, and rng is a RNG. `init` is the function used to generate the random initial weights of the network. +random initial weights of the network. +```julia +builder = MLJFlux.@builder begin + init=Flux.glorot_uniform(rng) + Chain(Dense(n_in, 64, relu, init=init), + Dense(64, 32, relu, init=init), + Dense(32, 1, init=init)) +end +``` +Finally, we can define the model! +```julia +NeuralNetworkRegressor = @load NeuralNetworkRegressor + model = NeuralNetworkRegressor(builder=builder, + rng=123, + epochs=20) +``` +For our neural network, since different features likely have different scales, if we do not standardize the network may be implicitly biased towards features with higher magnitudes, or may have [saturated neurons](https://www.informit.com/articles/article.aspx%3fp=3131594&seqNum=2) and not train well. Therefore, standardization is key! +not standardize the network may be implicitly biased towards features with higher magnitudes, or may have [saturated neurons](https://www.informit.com/articles/article.aspx%3fp=3131594&seqNum=2) and not train well. Therefore, standardization is key! +magnitudes, or may have [saturated neurons](https://www.informit.com/articles/article.aspx%3fp=3131594&seqNum=2) and not train well. Therefore, standardization is key! +neurons](https://www.informit.com/articles/article.aspx%3fp=3131594&seqNum=2) and not train well. Therefore, standardization is key! +```julia +pipe = Standardizer |> TransformedTargetModel(model, target=Standardizer) +``` +If we fit with a high verbosity (>1), we will see the losses during training. We can also see the losses in the output of `report(mach)` +also see the losses in the output of `report(mach)` +```julia +mach = machine(pipe, X, y) +fit!(mach, verbosity=2) + +# first element initial loss, 2:end per epoch training losses +report(mach).transformed_target_model_deterministic.training_losses + +``` + +## Experimenting with learning rate + +We can visually compare how the learning rate affects the predictions: +```julia +plt = plot() + +rates = 10. .^ (-5:0) + +foreach(rates) do η + pipe.transformed_target_model_deterministic.model.optimiser.eta = η + fit!(mach, force=true, verbosity=0) + losses = + report(mach).transformed_target_model_deterministic.model.training_losses[3:end] + plot!(1:length(losses), losses, label=η) +end +plt #!md + +savefig(joinpath("assets", "learning_rate.png")) + +pipe.transformed_target_model_deterministic.model.optimiser.eta = 0.0001 +``` + +## Using Iteration Controls + +We can also wrap the model with MLJ Iteration controls. Suppose we want a model that trains until the out of sample loss does not improve for 6 epochs. We can use the `NumberSinceBest(6)` stopping criterion. We can also add some extra stopping criterion, `InvalidValue` and `Timelimit(1/60)`, as well as some controls to print traces of the losses. First we can define some methods to initialize or clear the traces as well as updte the traces. +trains until the out of sample loss does not improve for 6 epochs. We can use the `NumberSinceBest(6)` stopping criterion. We can also add some extra stopping criterion, `InvalidValue` and `Timelimit(1/60)`, as well as some controls to print traces of the losses. First we can define some methods to initialize or clear the traces as well as update the traces. +`NumberSinceBest(6)` stopping criterion. We can also add some extra stopping criterion, `InvalidValue` and `Timelimit(1/60)`, as well as some controls to print traces of the losses. First we can define some methods to initialize or clear the traces as well as update the traces. +```julia +# For initializing or clearing the traces: + +clear() = begin + global losses = [] + global training_losses = [] + global epochs = [] + return nothing +end + + # And to update the traces: + +update_loss(loss) = push!(losses, loss) +update_training_loss(report) = + push!(training_losses, + report.transformed_target_model_deterministic.model.training_losses[end]) +update_epochs(epoch) = push!(epochs, epoch) +``` +For further reference of controls, see [the documentation](https://alan-turing-institute.github.io/MLJ.jl/dev/controlling_iterative_models/%23Controls-provided). To apply the controls, we simply stack them in a vector and then make an `IteratedModel`: +```julia +controls=[Step(1), + NumberSinceBest(6), + InvalidValue(), + TimeLimit(1/60), + WithLossDo(update_loss), + WithReportDo(update_training_loss), +WithIterationsDo(update_epochs)] + + +iterated_pipe = + IteratedModel(model=pipe, + controls=controls, + resampling=Holdout(fraction_train=0.8), + measure = l2) +``` +Next, we can clear the traces, fit the model, and plot the traces: +```julia +clear() +mach = machine(iterated_pipe, X, y) +fit!(mach) + +plot(epochs, losses, + xlab = "epoch", + ylab = "mean sum of squares error", + label="out-of-sample", + legend = :topleft); +scatter!(twinx(), epochs, training_losses, label="training", color=:red) #!md + +savefig(joinpath("assets", "loss.png")) +``` + +### Brief note on iterated models + +Training an `IteratedModel` means holding out some data (80% in this case) so an +out-of-sample loss can be tracked and used in the specified stopping criterion, +`NumberSinceBest(4)`. However, once the stop is triggered, the model wrapped by +`IteratedModel` (our pipeline model) is retrained on all data for the same number of +iterations. Calling `predict(mach, Xnew)` on new data uses the updated learned +parameters. + +## Evaluating Iterated Models + +We can evaluate our model with the `evaluate!` function: +```julia +e = evaluate!(mach, + resampling=CV(nfolds=8), + measures=[l1, l2]) + +using Measurements +l1_loss = e.measurement[1] ± std(e.per_fold[1])/sqrt(7) +@show l1_loss +``` +We take this estimate of the uncertainty of the generalization error with a [grain of +salt](https://direct.mit.edu/neco/article-abstract/10/7/1895/6224/Approximate-Statistical-Tests-for-Comparing)). + +## Comparison with other models on the test set + +Although we cannot assign them statistical significance, here are comparisons, on the +untouched test set, of the eror of our self-iterating neural network regressor with a +couple of other models trained on the same data (using default hyperparameters): +```julia +function performance(model) + mach = machine(model, X, y) |> fit! + yhat = predict(mach, Xtest) + l1(yhat, ytest) |> mean +end +performance(iterated_pipe) + +three_models = [(@load EvoTreeRegressor)(), # tree boosting model + (@load LinearRegressor pkg=MLJLinearModels)(), + iterated_pipe] + +errs = performance.(three_models) + +(models=MLJ.name.(three_models), mean_square_errors=errs) |> pretty +``` + +See also +[`MultitargetNeuralNetworkRegressor`](@ref) +""" +NeuralNetworkRegressor + +""" +$(MMI.doc_header(MultitargetNeuralNetworkRegressor)) + +`MultitargetNeuralNetworkRegressor`: A neural network model for making deterministic +predictions of a `Continuous` multi-target, presented as a table, given a table of +`Continuous` features. + +# Training data + +In MLJ or MLJBase, bind an instance `model` to data with + mach = machine(model, X, y) + +Where + +- `X`: is any table of input features (eg, a `DataFrame`) whose columns + are of scitype `Continuous`; check the scitype with `schema(X)` +- `y`: is the target, which can be any table of output targets whose element + scitype is `Continuous`; check the scitype with `schema(y)` + + +# Hyper-parameters + +- `builder=MLJFlux.Linear(σ=Flux.relu)`: An MLJFlux builder that constructs a neural + network. Possible `builders` include: `Linear`, `Short`, and `MLP`. You can construct + your own builder using the `@builder` macro, see examples for further information. +- `optimiser::Flux.ADAM()`: A `Flux.Optimise` optimiser. The optimiser performs the + updating of the weights of the network. For further reference, see either the examples + or [the Flux optimiser + documentation](https://fluxml.ai/Flux.jl/stable/training/optimisers/). To choose a + learning rate (the update rate of the optimizer), a good rule of thumb is to start out + at `10e-3`, and tune using powers of 10 between `1` and `1e-7`. +- `loss=Flux.mse`: The loss function which the network will optimize. Should be a + function which can be called in the form `loss(yhat, y)`. Possible loss functions are + listed in [the Flux loss function + documentation](https://fluxml.ai/Flux.jl/stable/models/losses/). For a regression task, + the most natural loss functions are: + - `Flux.mse` + - `Flux.mae` + - `Flux.msle` + - `Flux.huber_loss` +- `epochs::Int=10`: The number of epochs to train for. Typically, one epoch represents + one pass through the entirety of the training dataset. +- `batch_size::Int=1`: The batch size to be used for training. The batch size represents + the number of samples per update of the networks weights. Typcally, batch size should be + somewhere between 8 and 512. Smaller batch sizes lead to noisier training loss curves, + while larger batch sizes lead towards smoother training loss curves. In general, it is a + good idea to pick one fairly large batch size (e.g. 32, 64, 128), and stick with it, and + only tune the learning rate. In most literature, batch size is set in powers of twos, + but this is fairly arbitrary. +- `lambda::Float64=0`: The stregth of the regularization used during training. Can be + any value in the range `[0, ∞)`. +- `alpha::Float64=0`: The L2/L1 mix of regularization, in the range `[0, 1]`. A value of + 0 represents L2 regularization, and a value of 1 represents L1 regularization. +- `rng::Union{AbstractRNG, Int64}`: The random number generator/seed used during + training. +- `optimizer_changes_trigger_retraining::Bool=false`: Defines what happens when fitting + a machine if the associated optimiser has changed. If true, the associated machine will + retrain from scratch on `fit`, otherwise it will not. +- `acceleration::AbstractResource=CPU1()`: Defines on what hardware training is done. + For Training on GPU, use `CudaLibs()`, otherwise defaults to `CPU`()`. +- `finaliser=Flux.softmax`: The final activation function of the neural network. +Defaults to `Flux.softmax`. For a regression task, reasonable alternatives include +`Flux.sigmoid` and the identity function (otherwise known as "linear activation"). + +# Operations + +- `predict(mach, Xnew)`: return predictions of the target given new + features `Xnew` having the same Scitype as `X` above. Predictions are + deterministic. + + +# Fitted parameters + +The fields of `fitted_params(mach)` are: + +- `chain`: The trained "chain", or series of layers, functions, and activations which + make up the neural network. + + +# Report + +The fields of `report(mach)` are: + +- `training_losses`: The history of training losses, a vector containing the history of + all the losses during training. The first element of the vector is the initial + penalized loss. After the first element, the nth element corresponds to the loss of + epoch n-1. + +# Examples + +In this example we build a regression model using a toy dataset. +```julia +using MLJ +using MLJFlux +using Flux +using Plots +using MLJBase: augment_X +``` +First, we generate some data: +```julia +X = augment_X(randn(10000, 8), true); +θ = randn((9,2)); +y = X * θ; +X = MLJ.table(X) +y = MLJ.table(y) + +schema(y) +schema(X) +``` +Lets also make a test set: +```julia +(X, Xtest), (y, ytest) = partition((X, y), 0.7, multi=true); +``` +Next, we can define a `builder`. In the following macro call, `n_in` is the number of expected input features, and rng is a RNG. `init` is the function used to generate the random initial weights of the network. +```julia +builder = MLJFlux.@builder begin + init=Flux.glorot_uniform(rng) + Chain(Dense(n_in, 64, relu, init=init), + Dense(64, 32, relu, init=init), + Dense(32, 1, init=init)) +end +``` +Finally, we can define the model! +```julia +MultitargetNeuralNetworkRegressor = @load MultitargetNeuralNetworkRegressor + model = MultitargetNeuralNetworkRegressor(builder=builder, + rng=123, + epochs=20) +``` +For our neural network, since different features likely have different scales, if we do not standardize the network may be implicitly biased towards features with higher magnitudes, or may have [saturated neurons](https://www.informit.com/articles/article.aspx%3fp=3131594&seqNum=2) and not train well. Therefore, standardization is key! +```julia +pipe = Standardizer |> TransformedTargetModel(model, target=Standardizer) +``` +If we fit with a high verbosity (>1), we will see the losses during training. We can also see the losses in the output of `report(mach)` + +```julia +mach = machine(pipe, X, y) +fit!(mach, verbosity=2) + +# first element initial loss, 2:end per epoch training losses +report(mach).transformed_target_model_deterministic.training_losses + +``` + +## Experimenting with learning rate + +We can visually compare how the learning rate affects the predictions: +```julia +plt = plot() + +rates = 10. .^ (-5:0) + +foreach(rates) do η + pipe.transformed_target_model_deterministic.model.optimiser.eta = η + fit!(mach, force=true, verbosity=0) + losses = + report(mach).transformed_target_model_deterministic.model.training_losses[3:end] + plot!(1:length(losses), losses, label=η) +end +plt #!md + +savefig(joinpath("assets", "learning_rate.png")) + + +pipe.transformed_target_model_deterministic.model.optimiser.eta = 0.0001 + +``` + +## Using Iteration Controls + +We can also wrap the model with MLJ Iteration controls. Suppose we want a model that trains until the out of sample loss does not improve for 6 epochs. We can use the `NumberSinceBest(6)` stopping criterion. We can also add some extra stopping criterion, `InvalidValue` and `Timelimit(1/60)`, as well as some controls to print traces of the losses. First we can define some methods to initialize or clear the traces as well as updte the traces. +```julia +# For initializing or clearing the traces: + +clear() = begin + global losses = [] + global training_losses = [] + global epochs = [] + return nothing +end + +# And to update the traces: + +update_loss(loss) = push!(losses, loss) +update_training_loss(report) = + push!(training_losses, + report.transformed_target_model_deterministic.model.training_losses[end]) +update_epochs(epoch) = push!(epochs, epoch) +``` +For further reference of controls, see [the documentation](https://alan-turing-institute.github.io/MLJ.jl/dev/controlling_iterative_models/%23Controls-provided). To apply the controls, we simply stack them in a vector and then make an `IteratedModel`: +```julia +controls=[Step(1), + NumberSinceBest(6), + InvalidValue(), + TimeLimit(1/60), + WithLossDo(update_loss), + WithReportDo(update_training_loss), +WithIterationsDo(update_epochs)] + +iterated_pipe = + IteratedModel(model=pipe, + controls=controls, + resampling=Holdout(fraction_train=0.8), + measure = l2) +``` +Next, we can clear the traces, fit the model, and plot the traces: +```julia +clear() +mach = machine(iterated_pipe, X, y) +fit!(mach) + +plot(epochs, losses, + xlab = "epoch", + ylab = "mean sum of squares error", + label="out-of-sample", + legend = :topleft); +scatter!(twinx(), epochs, training_losses, label="training", color=:red) #!md + +savefig(joinpath("assets", "loss.png")) +``` + +### Brief note on iterated models + +Training an `IteratedModel` means holding out some data (80% in this case) so an out-of-sample loss can be tracked and used in the specified stopping criterion, `NumberSinceBest(4)`. However, once the stop is triggered, the model wrapped by `IteratedModel` (our pipeline model) is retrained on all data for the same number of iterations. Calling `predict(mach, Xnew)` on new data uses the updated learned parameters. + +## Evaluating Iterated Models + +We can evaluate our model with the `evaluate!` function: +```julia +e = evaluate!(mach, + resampling=CV(nfolds=8), + measures=[l1, l2]) + +using Measurements +l1_loss = e.measurement[1] ± std(e.per_fold[1])/sqrt(7) +@show l1_loss +``` +We take this estimate of the uncertainty of the generalization error with a [grain of salt](https://direct.mit.edu/neco/article-abstract/10/7/1895/6224/Approximate-Statistical-Tests-for-Comparing)). + +## Comparison with other models on the test set + +Although we cannot assign them statistical significance, here are comparisons, on the untouched test set, of the eror of our self-iterating neural network regressor with a couple of other models trained on the same data (using default hyperparameters): +```julia + +function performance(model) + mach = machine(model, X, y) |> fit! + yhat = predict(mach, Xtest) + l1(yhat, ytest) |> mean +end +performance(iterated_pipe) + +three_models = [(@load EvoTreeRegressor)(), # tree boosting model + (@load LinearRegressor pkg=MLJLinearModels)(), + iterated_pipe] + +errs = performance.(three_models) + +(models=MLJ.name.(three_models), mean_square_errors=errs) |> pretty + + +``` +See also +[`NeuralNetworkRegressor`](@ref) +""" +MultitargetNeuralNetworkRegressor + const Regressor = Union{NeuralNetworkRegressor, MultitargetNeuralNetworkRegressor} From a19d93a28b0427d33de18fc13a045585a0db0d73 Mon Sep 17 00:00:00 2001 From: josephsdavid Date: Tue, 12 Jul 2022 16:42:47 -0500 Subject: [PATCH 5/7] git killing me --- nn.md | 245 ------------------------------------------ nnc.md | 128 ---------------------- nnclassif.norg | 148 ------------------------- nnm.md | 247 ------------------------------------------ nnregressor.norg | 273 ----------------------------------------------- 5 files changed, 1041 deletions(-) delete mode 100644 nn.md delete mode 100644 nnc.md delete mode 100644 nnclassif.norg delete mode 100644 nnm.md delete mode 100644 nnregressor.norg diff --git a/nn.md b/nn.md deleted file mode 100644 index 46641a33..00000000 --- a/nn.md +++ /dev/null @@ -1,245 +0,0 @@ -# NeuralNetworkRegressor - -`NeuralNetworkRegressor`: A neural network model for making deterministic -predictions of a `Continuous` target, given a table of `Continuous` features. - -# Training data - -In MLJ or MLJBase, bind an instance `model` to data with -mach = machine(model, X, y) -Where -- `X`: is any table of input features (eg, a `DataFrame`) whose columns - are of scitype `Continuous`; check the scitype with `schema(X)` -- `y`: is the target, which can be any `AbstractVector` whose element - scitype is `Continuous`; check the scitype with `scitype(y)` - - -# Hyper-parameters - -- `builder=MLJFlux.Linear(σ=Flux.relu)`: An MLJFlux builder that constructs a neural network. Possible `builders` include: `Linear`, `Short`, and `MLP`. You can construct your own builder using the `@builder` macro, see examples for further information. -- `optimiser::Flux.ADAM()`: A `Flux.Optimise` optimiser. The optimiser performs the updating of the weights of the network. For further reference, see either the examples or [the Flux optimiser documentation](https://fluxml.ai/Flux.jl/stable/training/optimisers/). To choose a learning rate (the update rate of the optimizer), a good rule of thumb is to start out at `10e-3`, and tune using powers of 10 between `1` and `1e-7`. -- `loss=Flux.mse`: The loss function which the network will optimize. Should be a function which can be called in the form `loss(yhat, y)`. Possible loss functions are listed in [the Flux loss function documentation](https://fluxml.ai/Flux.jl/stable/models/losses/). For a regression task, the most natural loss functions are: - - `Flux.mse` - - `Flux.mae` - - `Flux.msle` - - `Flux.huber_loss` -- `epochs::Int=10`: The number of epochs to train for. Typically, one epoch represents one pass through the entirety of the training dataset. -- `batch_size::Int=1`: The batch size to be used for training. The batch size represents the number of samples per update of the networks weights. Typcally, batch size should be somewhere between 8 and 512. Smaller batch sizes lead to noisier training loss curves, while larger batch sizes lead towards smoother training loss curves. In general, it is a good idea to pick one fairly large batch size (e.g. 32, 64, 128), and stick with it, and only tune the learning rate. In most literature, batch size is set in powers of twos, but this is fairly arbitrary. -- `lambda::Float64=0`: The stregth of the regularization used during training. Can be any value in the range `[0, ∞)`. -- `alpha::Float64=0`: The L2/L1 mix of regularization, in the range `[0, 1]`. A value of 0 represents L2 regularization, and a value of 1 represents L1 regularization. -- `rng::Union{AbstractRNG, Int64}`: The random number generator/seed used during training. -- `optimizer_changes_trigger_retraining::Bool=false`: Defines what happens when fitting a machine if the associated optimiser has changed. If true, the associated machine will retrain from scratch on `fit`, otherwise it will not. -- `acceleration::AbstractResource=CPU1()`: Defines on what hardware training is done. For Training on GPU, use `CudaLibs()`, otherwise defaults to `CPU`()`. -- `finaliser=Flux.softmax`: The final activation function of the neural network. Defaults to `Flux.softmax`. For a regression task, reasonable alternatives include `Flux.sigmoid` and the identity function (otherwise known as "linear activation"). - - -# Operations - -- `predict(mach, Xnew)`: return predictions of the target given new - features `Xnew` having the same Scitype as `X` above. Predictions are - deterministic. - - -# Fitted parameters - -The fields of `fitted_params(mach)` are: -- `chain`: The trained "chain", or series of layers, functions, and activations which make up the neural network. - - -# Report - -The fields of `report(mach)` are: -- `training_losses`: The history of training losses, a vector containing the history of all the losses during training. The first element of the vector is the initial penalized loss. After the first element, the nth element corresponds to the loss of epoch n-1. - -# Examples - -In this example we build a regression model using the Boston house price dataset -```julia - - using MLJ - using MLJFlux - using Flux - using Plots - -``` -First, we load in the data, with target `:MEDV`. We load in all features except `:CHAS`: -```julia - - data = OpenML.load(531); # Loads from https://www.openml.org/d/531 - - y, X = unpack(data, ==(:MEDV), !=(:CHAS); rng=123); - - scitype(y) - schema(X) - -``` -Since MLJFlux models do not handle ordered factos, we can treat `:RAD` as `Continuous`: -```julia -X = coerce(X, :RAD=>Continuous) -``` -Lets also make a test set: -```julia - - (X, Xtest), (y, ytest) = partition((X, y), 0.7, multi=true); - -``` -Next, we can define a `builder`. In the following macro call, `n_in` is the number of expected input features, and rng is a RNG. `init` is the function used to generate the random initial weights of the network. -```julia -builder = MLJFlux.@builder begin - init=Flux.glorot_uniform(rng) - Chain(Dense(n_in, 64, relu, init=init), - Dense(64, 32, relu, init=init), - Dense(32, 1, init=init)) - end -``` -Finally, we can define the model! -```julia - - NeuralNetworkRegressor = @load NeuralNetworkRegressor - model = NeuralNetworkRegressor(builder=builder, - rng=123, - epochs=20) -``` -For our neural network, since different features likely have different scales, if we do not standardize the network may be implicitly biased towards features with higher magnitudes, or may have [saturated neurons](https://www.informit.com/articles/article.aspx%3fp=3131594&seqNum=2) and not train well. Therefore, standardization is key! -```julia -pipe = Standardizer |> TransformedTargetModel(model, target=Standardizer) -``` -If we fit with a high verbosity (>1), we will see the losses during training. We can also see the losses in the output of `report(mach)` - -```julia -mach = machine(pipe, X, y) - fit!(mach, verbosity=2) - - # first element initial loss, 2:end per epoch training losses - report(mach).transformed_target_model_deterministic.training_losses - -``` - -## Experimenting with learning rate - -We can visually compare how the learning rate affects the predictions: -```julia -plt = plot() - - rates = 10. .^ (-5:0) - - foreach(rates) do η - pipe.transformed_target_model_deterministic.model.optimiser.eta = η - fit!(mach, force=true, verbosity=0) - losses = - report(mach).transformed_target_model_deterministic.model.training_losses[3:end] - plot!(1:length(losses), losses, label=η) - end - plt #!md - - savefig(joinpath("assets", "learning_rate.png")) - - - pipe.transformed_target_model_deterministic.model.optimiser.eta = 0.0001 - -``` - -## Using Iteration Controls - -We can also wrap the model with MLJ Iteration controls. Suppose we want a model that trains until the out of sample loss does not improve for 6 epochs. We can use the `NumberSinceBest(6)` stopping criterion. We can also add some extra stopping criterion, `InvalidValue` and `Timelimit(1/60)`, as well as some controls to print traces of the losses. First we can define some methods to initialize or clear the traces as well as updte the traces. -```julia - - # For initializing or clearing the traces: - - clear() = begin - global losses = [] - global training_losses = [] - global epochs = [] - return nothing - end - - # And to update the traces: - - update_loss(loss) = push!(losses, loss) - update_training_loss(report) = - push!(training_losses, - report.transformed_target_model_deterministic.model.training_losses[end]) - update_epochs(epoch) = push!(epochs, epoch) - -``` -For further reference of controls, see [the documentation](https://alan-turing-institute.github.io/MLJ.jl/dev/controlling_iterative_models/%23Controls-provided). To apply the controls, we simply stack them in a vector and then make an `IteratedModel`: -```julia - - controls=[Step(1), - NumberSinceBest(6), - InvalidValue(), - TimeLimit(1/60), - WithLossDo(update_loss), - WithReportDo(update_training_loss), - WithIterationsDo(update_epochs)] - - - iterated_pipe = - IteratedModel(model=pipe, - controls=controls, - resampling=Holdout(fraction_train=0.8), - measure = l2) - -``` -Next, we can clear the traces, fit the model, and plot the traces: -```julia - - - clear() - mach = machine(iterated_pipe, X, y) - fit!(mach) - - plot(epochs, losses, - xlab = "epoch", - ylab = "mean sum of squares error", - label="out-of-sample", - legend = :topleft); - scatter!(twinx(), epochs, training_losses, label="training", color=:red) #!md - - savefig(joinpath("assets", "loss.png")) -``` - -### Brief note on iterated models - -Training an `IteratedModel` means holding out some data (80% in this case) so an out-of-sample loss can be tracked and used in the specified stopping criterion, `NumberSinceBest(4)`. However, once the stop is triggered, the model wrapped by `IteratedModel` (our pipeline model) is retrained on all data for the same number of iterations. Calling `predict(mach, Xnew)` on new data uses the updated learned parameters. - -## Evaluating Iterated Models - -We can evaluate our model with the `evaluate!` function: -```julia - - e = evaluate!(mach, - resampling=CV(nfolds=8), - measures=[l1, l2]) - -#- - - using Measurements - l1_loss = e.measurement[1] ± std(e.per_fold[1])/sqrt(7) - @show l1_loss - -``` -We take this estimate of the uncertainty of the generalization error with a [grain of salt](https://direct.mit.edu/neco/article-abstract/10/7/1895/6224/Approximate-Statistical-Tests-for-Comparing)). - -## Comparison with other models on the test set - -Although we cannot assign them statistical significance, here are comparisons, on the untouched test set, of the eror of our self-iterating neural network regressor with a couple of other models trained on the same data (using default hyperparameters): -```julia - - function performance(model) - mach = machine(model, X, y) |> fit! - yhat = predict(mach, Xtest) - l1(yhat, ytest) |> mean - end - performance(iterated_pipe) - - three_models = [(@load EvoTreeRegressor)(), # tree boosting model - (@load LinearRegressor pkg=MLJLinearModels)(), - iterated_pipe] - - errs = performance.(three_models) - - (models=MLJ.name.(three_models), mean_square_errors=errs) |> pretty - - -``` diff --git a/nnc.md b/nnc.md deleted file mode 100644 index a14f2ecc..00000000 --- a/nnc.md +++ /dev/null @@ -1,128 +0,0 @@ -# NeuralNetworkClassifier - -`NeuralNetworkClassifier`: -- TODO - -# Training data - -In MLJ or MLJBase, bind an instance `model` to data with -mach = machine(model, X, y) -Where -- `X`: is any table of input features (eg, a `DataFrame`) whose columns - are of scitype `Continuous`; check the scitype with `schema(X)` -- `y`: is the target, which can be any `AbstractVector` whose element - scitype is `Finite` with `n_out` classes; check the scitype with `scitype(y)` - - -# Hyper-parameters - -- `builder=MLJFlux.Short()`: An MLJFlux builder that constructs a neural network. Possible `builders` include: `Linear`, `Short`, and `MLP`. You can construct your own builder using the `@builder` macro, see examples for further information. -- `optimiser::Flux.ADAM()`: A `Flux.Optimise` optimiser. The optimiser performs the updating of the weights of the network. For further reference, see either the examples or [the Flux optimiser documentation](https://fluxml.ai/Flux.jl/stable/training/optimisers/). To choose a learning rate (the update rate of the optimizer), a good rule of thumb is to start out at `10e-3`, and tune using powers of 10 between `1` and `1e-7`. -- `loss=Flux.crossentropy`: The loss function which the network will optimize. Should be a function which can be called in the form `loss(yhat, y)`. Possible loss functions are listed in [the Flux loss function documentation](https://fluxml.ai/Flux.jl/stable/models/losses/). For a classification task, the most natural loss functions are: - - `Flux.crossentropy`: Typically used as loss in multiclass classification, with labels in a 1-hot encoded format. - - `Flux.logitcrossentopy`: Mathematically equal to crossentropy, but computationally more numerically stable than finalising the outputs with `softmax` and then calculating crossentropy. - - `Flux.binarycrossentropy`: Typically used as loss in binary classification, with labels in a 1-hot encoded format. - - `Flux.logitbinarycrossentopy`: Mathematically equal to crossentropy, but computationally more numerically stable than finalising the outputs with `sigmoid` and then calculating binary crossentropy. - - `Flux.tversky_loss`: Used with imbalanced data to give more weight to false negatives. - - `Flux.focal_loss`: Used with highly imbalanced data. Weights harder examples more than easier examples. - - `Flux.binary_focal_loss`: Binary version of the above -- `epochs::Int=10`: The number of epochs to train for. Typically, one epoch represents one pass through the entirety of the training dataset. -- `batch_size::Int=1`: The batch size to be used for training. The batch size represents the number of samples per update of the networks weights. Typcally, batch size should be somewhere between 8 and 512. Smaller batch sizes lead to noisier training loss curves, while larger batch sizes lead towards smoother training loss curves. In general, it is a good idea to pick one fairly large batch size (e.g. 32, 64, 128), and stick with it, and only tune the learning rate. In most literature, batch size is set in powers of twos, but this is fairly arbitrary. -- `lambda::Float64=0`: The stregth of the regularization used during training. Can be any value in the range `[0, ∞)`. -- `alpha::Float64=0`: The L2/L1 mix of regularization, in the range `[0, 1]`. A value of 0 represents L2 regularization, and a value of 1 represents L1 regularization. -- `rng::Union{AbstractRNG, Int64}`: The random number generator/seed used during training. -- `optimizer_changes_trigger_retraining::Bool=false`: Defines what happens when fitting a machine if the associated optimiser has changed. If true, the associated machine will retrain from scratch on `fit`, otherwise it will not. -- `acceleration::AbstractResource=CPU1()`: Defines on what hardware training is done. For Training on GPU, use `CudaLibs()`, otherwise defaults to `CPU`()`. -- `finaliser=Flux.softmax`: The final activation function of the neural network. Defaults to `Flux.softmax`. For a regression task, reasonable alternatives include `Flux.sigmoid` and the identity function (otherwise known as "linear activation"). - - -# Operations - -- `predict(mach, Xnew)`: return predictions of the target given new - features `Xnew` having the same Scitype as `X` above. Predictions are - probabilistic. -- `predict_mode(mach, Xnew)`: Return the modes of the probabilistic predictions - returned above. - - -# Fitted parameters - -The fields of `fitted_params(mach)` are: -- `chain`: The trained "chain", or series of layers, functions, and activations which make up the neural network. - - -# Report - -The fields of `report(mach)` are: -- `training_losses`: The history of training losses, a vector containing the history of all the losses during training. The first element of the vector is the initial penalized loss. After the first element, the nth element corresponds to the loss of epoch n-1. - -# Examples - -In this example we build a classification model using the Iris dataset. -```julia - - using MLJ - using Flux - import RDatasets - - using Random - Random.seed!(123) - - MLJ.color_off() - - using Plots - pyplot(size=(600, 300*(sqrt(5)-1))); - -``` -This is a very basic example, using a default builder and no standardization. -For a more advance illustration, see [`NeuralNetworkRegressor`](@ref) or [`ImageClassifier`](@ref). First, we can load the data: -```julia - - iris = RDatasets.dataset("datasets", "iris"); - y, X = unpack(iris, ==(:Species), colname -> true, rng=123); - NeuralNetworkClassifier = @load NeuralNetworkClassifier - clf = NeuralNetworkClassifier() - -``` -Next, we can train the model: -```julia -import Random.seed!; seed!(123) - mach = machine(clf, X, y) - fit!(mach) -``` -We can train the model in an incremental fashion with the `optimizer_changes_trigger_retraining` flag set to false (which is by default). Here, we change the number of iterations and the learning rate of the optimiser: -```julia -clf.optimiser.eta = clf.optimiser.eta * 2 - clf.epochs = clf.epochs + 5 - - # note that if the optimizer_changes_trigger_retraining flag was set to true - # the model would be completely retrained from scratch because the optimizer was - # updated - fit!(mach, verbosity=2); -``` -We can inspect the mean training loss using the `cross_entropy` function: -```julia - - training_loss = cross_entropy(predict(mach, X), y) |> mean - -``` -And we can access the Flux chain (model) using `fitted_params`: -```julia -training_loss = cross_entropy(predict(mach, X), y) |> mean -``` -Finally, we can see how the out-of-sample performance changes over time, using the `learning_curve` function -```julia -r = range(clf, :epochs, lower=1, upper=200, scale=:log10) - curve = learning_curve(clf, X, y, - range=r, - resampling=Holdout(fraction_train=0.7), - measure=cross_entropy) - using Plots - plot(curve.parameter_values, - curve.measurements, - xlab=curve.parameter_name, - xscale=curve.parameter_scale, - ylab = "Cross Entropy") - - savefig("iris_history.png") -``` diff --git a/nnclassif.norg b/nnclassif.norg deleted file mode 100644 index 25ba3847..00000000 --- a/nnclassif.norg +++ /dev/null @@ -1,148 +0,0 @@ -* NeuralNetworkClassifier - - `NeuralNetworkClassifier`: - - [ ] TODO - -* Training data - - In MLJ or MLJBase, bind an instance `model` to data with - - mach = machine(model, X, y) - - Where - - - `X`: is any table of input features (eg, a `DataFrame`) whose columns - are of scitype `Continuous`; check the scitype with `schema(X)` - - - `y`: is the target, which can be any `AbstractVector` whose element - scitype is `Finite` with `n_out` classes; check the scitype with `scitype(y)` - - -* Hyper-parameters - - - `builder=MLJFlux.Short()`: An MLJFlux builder that constructs a neural network. Possible `builders` include: `Linear`, `Short`, and `MLP`. You can construct your own builder using the `@builder` macro, see examples for further information. - - `optimiser::Flux.ADAM()`: A `Flux.Optimise` optimiser. The optimiser performs the updating of the weights of the network. For further reference, see either the examples or {https://fluxml.ai/Flux.jl/stable/training/optimisers/}[the Flux optimiser documentation]. To choose a learning rate (the update rate of the optimizer), a good rule of thumb is to start out at `10e-3`, and tune using powers of 10 between `1` and `1e-7`. - - `loss=Flux.crossentropy`: The loss function which the network will optimize. Should be a function which can be called in the form `loss(yhat, y)`. Possible loss functions are listed in {https://fluxml.ai/Flux.jl/stable/models/losses/}[the Flux loss function documentation]. For a classification task, the most natural loss functions are: - -- `Flux.crossentropy`: Typically used as loss in multiclass classification, with labels in a 1-hot encoded format. - -- `Flux.logitcrossentopy`: Mathematically equal to crossentropy, but computationally more numerically stable than finalising the outputs with `softmax` and then calculating crossentropy. - -- `Flux.binarycrossentropy`: Typically used as loss in binary classification, with labels in a 1-hot encoded format. - -- `Flux.logitbinarycrossentopy`: Mathematically equal to crossentropy, but computationally more numerically stable than finalising the outputs with `sigmoid` and then calculating binary crossentropy. - -- `Flux.tversky_loss`: Used with imbalanced data to give more weight to false negatives. - -- `Flux.focal_loss`: Used with highly imbalanced data. Weights harder examples more than easier examples. - -- `Flux.binary_focal_loss`: Binary version of the above - - `epochs::Int=10`: The number of epochs to train for. Typically, one epoch represents one pass through the entirety of the training dataset. - - `batch_size::Int=1`: The batch size to be used for training. The batch size represents the number of samples per update of the networks weights. Typcally, batch size should be somewhere between 8 and 512. Smaller batch sizes lead to noisier training loss curves, while larger batch sizes lead towards smoother training loss curves. In general, it is a good idea to pick one fairly large batch size (e.g. 32, 64, 128), and stick with it, and only tune the learning rate. In most literature, batch size is set in powers of twos, but this is fairly arbitrary. - - `lambda::Float64=0`: The stregth of the regularization used during training. Can be any value in the range `[0, ∞)`. - - `alpha::Float64=0`: The L2/L1 mix of regularization, in the range `[0, 1]`. A value of 0 represents L2 regularization, and a value of 1 represents L1 regularization. - - `rng::Union{AbstractRNG, Int64}`: The random number generator/seed used during training. - - `optimizer_changes_trigger_retraining::Bool=false`: Defines what happens when fitting a machine if the associated optimiser has changed. If true, the associated machine will retrain from scratch on `fit`, otherwise it will not. - - `acceleration::AbstractResource=CPU1()`: Defines on what hardware training is done. For Training on GPU, use `CudaLibs()`, otherwise defaults to `CPU`()`. - - `finaliser=Flux.softmax`: The final activation function of the neural network. Defaults to `Flux.softmax`. For a regression task, reasonable alternatives include `Flux.sigmoid` and the identity function (otherwise known as "linear activation"). - - -* Operations - - - `predict(mach, Xnew)`: return predictions of the target given new - features `Xnew` having the same Scitype as `X` above. Predictions are - probabilistic. - - `predict_mode(mach, Xnew)`: Return the modes of the probabilistic predictions - returned above. - - - -* Fitted parameters - - The fields of `fitted_params(mach)` are: - - - `chain`: The trained "chain", or series of layers, functions, and activations which make up the neural network. - - -* Report - - The fields of `report(mach)` are: - - - `training_losses`: The history of training losses, a vector containing the history of all the losses during training. The first element of the vector is the initial penalized loss. After the first element, the nth element corresponds to the loss of epoch $n-1$. - -* Examples - - In this example we build a classification model using the Iris dataset. - - @code julia - - using MLJ - using Flux - import RDatasets - - using Random - Random.seed!(123) - - MLJ.color_off() - - using Plots - pyplot(size=(600, 300*(sqrt(5)-1))); - - @end - - This is a very basic example, using a default builder and no standardization. - For a more advance illustration, see [`NeuralNetworkRegressor`](@ref) or [`ImageClassifier`](@ref). First, we can load the data: - - @code julia - - iris = RDatasets.dataset("datasets", "iris"); - y, X = unpack(iris, ==(:Species), colname -> true, rng=123); - NeuralNetworkClassifier = @load NeuralNetworkClassifier - clf = NeuralNetworkClassifier() - - @end - - Next, we can train the model: - @code julia - import Random.seed!; seed!(123) - mach = machine(clf, X, y) - fit!(mach) - @end - - We can train the model in an incremental fashion with the `optimizer_changes_trigger_retraining` flag set to false (which is by default). Here, we change the number of iterations and the learning rate of the optimiser: - - @code julia - clf.optimiser.eta = clf.optimiser.eta * 2 - clf.epochs = clf.epochs + 5 - - # note that if the optimizer_changes_trigger_retraining flag was set to true - # the model would be completely retrained from scratch because the optimizer was - # updated - fit!(mach, verbosity=2); - @end - - We can inspect the mean training loss using the `cross_entropy` function: - - @code julia - - training_loss = cross_entropy(predict(mach, X), y) |> mean - - @end - - And we can access the Flux chain (model) using `fitted_params`: - - @code julia - training_loss = cross_entropy(predict(mach, X), y) |> mean - @end - - Finally, we can see how the out-of-sample performance changes over time, using the `learning_curve` function - - @code julia - r = range(clf, :epochs, lower=1, upper=200, scale=:log10) - curve = learning_curve(clf, X, y, - range=r, - resampling=Holdout(fraction_train=0.7), - measure=cross_entropy) - using Plots - plot(curve.parameter_values, - curve.measurements, - xlab=curve.parameter_name, - xscale=curve.parameter_scale, - ylab = "Cross Entropy") - - savefig("iris_history.png") - @end - diff --git a/nnm.md b/nnm.md deleted file mode 100644 index 5c0234dc..00000000 --- a/nnm.md +++ /dev/null @@ -1,247 +0,0 @@ -# MultitargetNeuralNetworkRegressor - -`MultitargetNeuralNetworkRegressor`: A neural network model for making deterministic -predictions of a `Continuous` multi-target, presented as a table, given a table of `Continuous` features. - -# Training data - -In MLJ or MLJBase, bind an instance `model` to data with -mach = machine(model, X, y) -Where -- `X`: is any table of input features (eg, a `DataFrame`) whose columns - are of scitype `Continuous`; check the scitype with `schema(X)` -- `y`: is the target, which can be any table of output targets whose element - scitype is `Continuous`; check the scitype with `schema(y)` - - -# Hyper-parameters - -- `builder=MLJFlux.Linear(σ=Flux.relu)`: An MLJFlux builder that constructs a neural network. Possible `builders` include: `Linear`, `Short`, and `MLP`. You can construct your own builder using the `@builder` macro, see examples for further information. -- `optimiser::Flux.ADAM()`: A `Flux.Optimise` optimiser. The optimiser performs the updating of the weights of the network. For further reference, see either the examples or [the Flux optimiser documentation](https://fluxml.ai/Flux.jl/stable/training/optimisers/). To choose a learning rate (the update rate of the optimizer), a good rule of thumb is to start out at `10e-3`, and tune using powers of 10 between `1` and `1e-7`. -- `loss=Flux.mse`: The loss function which the network will optimize. Should be a function which can be called in the form `loss(yhat, y)`. Possible loss functions are listed in [the Flux loss function documentation](https://fluxml.ai/Flux.jl/stable/models/losses/). For a regression task, the most natural loss functions are: - - `Flux.mse` - - `Flux.mae` - - `Flux.msle` - - `Flux.huber_loss` -- `epochs::Int=10`: The number of epochs to train for. Typically, one epoch represents one pass through the entirety of the training dataset. -- `batch_size::Int=1`: The batch size to be used for training. The batch size represents the number of samples per update of the networks weights. Typcally, batch size should be somewhere between 8 and 512. Smaller batch sizes lead to noisier training loss curves, while larger batch sizes lead towards smoother training loss curves. In general, it is a good idea to pick one fairly large batch size (e.g. 32, 64, 128), and stick with it, and only tune the learning rate. In most literature, batch size is set in powers of twos, but this is fairly arbitrary. -- `lambda::Float64=0`: The stregth of the regularization used during training. Can be any value in the range `[0, ∞)`. -- `alpha::Float64=0`: The L2/L1 mix of regularization, in the range `[0, 1]`. A value of 0 represents L2 regularization, and a value of 1 represents L1 regularization. -- `rng::Union{AbstractRNG, Int64}`: The random number generator/seed used during training. -- `optimizer_changes_trigger_retraining::Bool=false`: Defines what happens when fitting a machine if the associated optimiser has changed. If true, the associated machine will retrain from scratch on `fit`, otherwise it will not. -- `acceleration::AbstractResource=CPU1()`: Defines on what hardware training is done. For Training on GPU, use `CudaLibs()`, otherwise defaults to `CPU`()`. -- `finaliser=Flux.softmax`: The final activation function of the neural network. Defaults to `Flux.softmax`. For a regression task, reasonable alternatives include `Flux.sigmoid` and the identity function (otherwise known as "linear activation"). - - -# Operations - -- `predict(mach, Xnew)`: return predictions of the target given new - features `Xnew` having the same Scitype as `X` above. Predictions are - deterministic. - - -# Fitted parameters - -The fields of `fitted_params(mach)` are: -- `chain`: The trained "chain", or series of layers, functions, and activations which make up the neural network. - - -# Report - -The fields of `report(mach)` are: -- `training_losses`: The history of training losses, a vector containing the history of all the losses during training. The first element of the vector is the initial penalized loss. After the first element, the nth element corresponds to the loss of epoch n-1. - -# Examples - -In this example we build a regression model using the Boston house price dataset -```julia - - using MLJ - using MLJFlux - using Flux - using Plots - using MLJBase: augment_X - -``` -First, we generate some data: -```julia - - X = augment_X(randn(10000, 8), true); - θ = randn((9,2)); - y = X * θ; - X = MLJ.table(X) - y = MLJ.table(y) - - - - - schema(y) - schema(X) - -``` -Lets also make a test set: -```julia - - (X, Xtest), (y, ytest) = partition((X, y), 0.7, multi=true); - -``` -Next, we can define a `builder`. In the following macro call, `n_in` is the number of expected input features, and rng is a RNG. `init` is the function used to generate the random initial weights of the network. -```julia -builder = MLJFlux.@builder begin - init=Flux.glorot_uniform(rng) - Chain(Dense(n_in, 64, relu, init=init), - Dense(64, 32, relu, init=init), - Dense(32, 1, init=init)) - end -``` -Finally, we can define the model! -```julia - - MultitargetNeuralNetworkRegressor = @load MultitargetNeuralNetworkRegressor - model = MultitargetNeuralNetworkRegressor(builder=builder, - rng=123, - epochs=20) -``` -For our neural network, since different features likely have different scales, if we do not standardize the network may be implicitly biased towards features with higher magnitudes, or may have [saturated neurons](https://www.informit.com/articles/article.aspx%3fp=3131594&seqNum=2) and not train well. Therefore, standardization is key! -```julia -pipe = Standardizer |> TransformedTargetModel(model, target=Standardizer) -``` -If we fit with a high verbosity (>1), we will see the losses during training. We can also see the losses in the output of `report(mach)` - -```julia -mach = machine(pipe, X, y) - fit!(mach, verbosity=2) - - # first element initial loss, 2:end per epoch training losses - report(mach).transformed_target_model_deterministic.training_losses - -``` - -## Experimenting with learning rate - -We can visually compare how the learning rate affects the predictions: -```julia -plt = plot() - - rates = 10. .^ (-5:0) - - foreach(rates) do η - pipe.transformed_target_model_deterministic.model.optimiser.eta = η - fit!(mach, force=true, verbosity=0) - losses = - report(mach).transformed_target_model_deterministic.model.training_losses[3:end] - plot!(1:length(losses), losses, label=η) - end - plt #!md - - savefig(joinpath("assets", "learning_rate.png")) - - - pipe.transformed_target_model_deterministic.model.optimiser.eta = 0.0001 - -``` - -## Using Iteration Controls - -We can also wrap the model with MLJ Iteration controls. Suppose we want a model that trains until the out of sample loss does not improve for 6 epochs. We can use the `NumberSinceBest(6)` stopping criterion. We can also add some extra stopping criterion, `InvalidValue` and `Timelimit(1/60)`, as well as some controls to print traces of the losses. First we can define some methods to initialize or clear the traces as well as updte the traces. -```julia - - # For initializing or clearing the traces: - - clear() = begin - global losses = [] - global training_losses = [] - global epochs = [] - return nothing - end - - # And to update the traces: - - update_loss(loss) = push!(losses, loss) - update_training_loss(report) = - push!(training_losses, - report.transformed_target_model_deterministic.model.training_losses[end]) - update_epochs(epoch) = push!(epochs, epoch) - -``` -For further reference of controls, see [the documentation](https://alan-turing-institute.github.io/MLJ.jl/dev/controlling_iterative_models/%23Controls-provided). To apply the controls, we simply stack them in a vector and then make an `IteratedModel`: -```julia - - controls=[Step(1), - NumberSinceBest(6), - InvalidValue(), - TimeLimit(1/60), - WithLossDo(update_loss), - WithReportDo(update_training_loss), - WithIterationsDo(update_epochs)] - - - iterated_pipe = - IteratedModel(model=pipe, - controls=controls, - resampling=Holdout(fraction_train=0.8), - measure = l2) - -``` -Next, we can clear the traces, fit the model, and plot the traces: -```julia - - - clear() - mach = machine(iterated_pipe, X, y) - fit!(mach) - - plot(epochs, losses, - xlab = "epoch", - ylab = "mean sum of squares error", - label="out-of-sample", - legend = :topleft); - scatter!(twinx(), epochs, training_losses, label="training", color=:red) #!md - - savefig(joinpath("assets", "loss.png")) -``` - -### Brief note on iterated models - -Training an `IteratedModel` means holding out some data (80% in this case) so an out-of-sample loss can be tracked and used in the specified stopping criterion, `NumberSinceBest(4)`. However, once the stop is triggered, the model wrapped by `IteratedModel` (our pipeline model) is retrained on all data for the same number of iterations. Calling `predict(mach, Xnew)` on new data uses the updated learned parameters. - -## Evaluating Iterated Models - -We can evaluate our model with the `evaluate!` function: -```julia - - e = evaluate!(mach, - resampling=CV(nfolds=8), - measures=[l1, l2]) - -#- - - using Measurements - l1_loss = e.measurement[1] ± std(e.per_fold[1])/sqrt(7) - @show l1_loss - -``` -We take this estimate of the uncertainty of the generalization error with a [grain of salt](https://direct.mit.edu/neco/article-abstract/10/7/1895/6224/Approximate-Statistical-Tests-for-Comparing)). - -## Comparison with other models on the test set - -Although we cannot assign them statistical significance, here are comparisons, on the untouched test set, of the eror of our self-iterating neural network regressor with a couple of other models trained on the same data (using default hyperparameters): -```julia - - function performance(model) - mach = machine(model, X, y) |> fit! - yhat = predict(mach, Xtest) - l1(yhat, ytest) |> mean - end - performance(iterated_pipe) - - three_models = [(@load EvoTreeRegressor)(), # tree boosting model - (@load LinearRegressor pkg=MLJLinearModels)(), - iterated_pipe] - - errs = performance.(three_models) - - (models=MLJ.name.(three_models), mean_square_errors=errs) |> pretty - - -``` diff --git a/nnregressor.norg b/nnregressor.norg deleted file mode 100644 index cdeb2277..00000000 --- a/nnregressor.norg +++ /dev/null @@ -1,273 +0,0 @@ -* MultitargetNeuralNetworkRegressor - - `MultitargetNeuralNetworkRegressor`: A neural network model for making deterministic - predictions of a `Continuous` multi-target, presented as a table, given a table of `Continuous` features. - -* Training data - - In MLJ or MLJBase, bind an instance `model` to data with - - mach = machine(model, X, y) - - Where - - - `X`: is any table of input features (eg, a `DataFrame`) whose columns - are of scitype `Continuous`; check the scitype with `schema(X)` - - - `y`: is the target, which can be any table of output targets whose element - scitype is `Continuous`; check the scitype with `schema(y)` - - -* Hyper-parameters - - - `builder=MLJFlux.Linear(σ=Flux.relu)`: An MLJFlux builder that constructs a neural network. Possible `builders` include: `Linear`, `Short`, and `MLP`. You can construct your own builder using the `@builder` macro, see examples for further information. - - `optimiser::Flux.ADAM()`: A `Flux.Optimise` optimiser. The optimiser performs the updating of the weights of the network. For further reference, see either the examples or {https://fluxml.ai/Flux.jl/stable/training/optimisers/}[the Flux optimiser documentation]. To choose a learning rate (the update rate of the optimizer), a good rule of thumb is to start out at `10e-3`, and tune using powers of 10 between `1` and `1e-7`. - - `loss=Flux.mse`: The loss function which the network will optimize. Should be a function which can be called in the form `loss(yhat, y)`. Possible loss functions are listed in {https://fluxml.ai/Flux.jl/stable/models/losses/}[the Flux loss function documentation]. For a regression task, the most natural loss functions are: - -- `Flux.mse` - -- `Flux.mae` - -- `Flux.msle` - -- `Flux.huber_loss` - - `epochs::Int=10`: The number of epochs to train for. Typically, one epoch represents one pass through the entirety of the training dataset. - - `batch_size::Int=1`: The batch size to be used for training. The batch size represents the number of samples per update of the networks weights. Typcally, batch size should be somewhere between 8 and 512. Smaller batch sizes lead to noisier training loss curves, while larger batch sizes lead towards smoother training loss curves. In general, it is a good idea to pick one fairly large batch size (e.g. 32, 64, 128), and stick with it, and only tune the learning rate. In most literature, batch size is set in powers of twos, but this is fairly arbitrary. - - `lambda::Float64=0`: The stregth of the regularization used during training. Can be any value in the range `[0, ∞)`. - - `alpha::Float64=0`: The L2/L1 mix of regularization, in the range `[0, 1]`. A value of 0 represents L2 regularization, and a value of 1 represents L1 regularization. - - `rng::Union{AbstractRNG, Int64}`: The random number generator/seed used during training. - - `optimizer_changes_trigger_retraining::Bool=false`: Defines what happens when fitting a machine if the associated optimiser has changed. If true, the associated machine will retrain from scratch on `fit`, otherwise it will not. - - `acceleration::AbstractResource=CPU1()`: Defines on what hardware training is done. For Training on GPU, use `CudaLibs()`, otherwise defaults to `CPU`()`. - - `finaliser=Flux.softmax`: The final activation function of the neural network. Defaults to `Flux.softmax`. For a regression task, reasonable alternatives include `Flux.sigmoid` and the identity function (otherwise known as "linear activation"). - - -* Operations - - - `predict(mach, Xnew)`: return predictions of the target given new - features `Xnew` having the same Scitype as `X` above. Predictions are - deterministic. - - -* Fitted parameters - - The fields of `fitted_params(mach)` are: - - - `chain`: The trained "chain", or series of layers, functions, and activations which make up the neural network. - - -* Report - - The fields of `report(mach)` are: - - - `training_losses`: The history of training losses, a vector containing the history of all the losses during training. The first element of the vector is the initial penalized loss. After the first element, the nth element corresponds to the loss of epoch $n-1$. - -* Examples - -In this example we build a regression model using the Boston house price dataset - - @code julia - - using MLJ - using MLJFlux - using Flux - using Plots - using MLJBase: augment_X - - @end - - First, we generate some data: - - @code julia - - X = augment_X(randn(10000, 8), true); - θ = randn((9,2)); - y = X * θ; - X = MLJ.table(X) - y = MLJ.table(y) - - - - - schema(y) - schema(X) - - @end - - Lets also make a test set: - - @code julia - - (X, Xtest), (y, ytest) = partition((X, y), 0.7, multi=true); - - @end - - Next, we can define a `builder`. In the following macro call, `n_in` is the number of expected input features, and rng is a RNG. `init` is the function used to generate the random initial weights of the network. - - @code julia - builder = MLJFlux.@builder begin - init=Flux.glorot_uniform(rng) - Chain(Dense(n_in, 64, relu, init=init), - Dense(64, 32, relu, init=init), - Dense(32, 1, init=init)) - end - @end - - Finally, we can define the model! - - @code julia - - MultitargetNeuralNetworkRegressor = @load MultitargetNeuralNetworkRegressor - model = MultitargetNeuralNetworkRegressor(builder=builder, - rng=123, - epochs=20) - @end - - For our neural network, since different features likely have different scales, if we do not standardize the network may be implicitly biased towards features with higher magnitudes, or may have {https://www.informit.com/articles/article.aspx?p=3131594&seqNum=2}[saturated neurons] and not train well. Therefore, standardization is key! - - @code julia - pipe = Standardizer |> TransformedTargetModel(model, target=Standardizer) - @end - - If we fit with a high verbosity ($>1$), we will see the losses during training. We can also see the losses in the output of `report(mach)` - - - @code julia - mach = machine(pipe, X, y) - fit!(mach, verbosity=2) - - # first element initial loss, 2:end per epoch training losses - report(mach).transformed_target_model_deterministic.training_losses - - @end - -** Experimenting with learning rate - - We can visually compare how the learning rate affects the predictions: - - @code julia - plt = plot() - - rates = 10. .^ (-5:0) - - foreach(rates) do η - pipe.transformed_target_model_deterministic.model.optimiser.eta = η - fit!(mach, force=true, verbosity=0) - losses = - report(mach).transformed_target_model_deterministic.model.training_losses[3:end] - plot!(1:length(losses), losses, label=η) - end - plt #!md - - savefig(joinpath("assets", "learning_rate.png")) - - - pipe.transformed_target_model_deterministic.model.optimiser.eta = 0.0001 - - @end - -** Using Iteration Controls - We can also wrap the model with MLJ Iteration controls. Suppose we want a model that trains until the out of sample loss does not improve for 6 epochs. We can use the `NumberSinceBest(6)` stopping criterion. We can also add some extra stopping criterion, `InvalidValue` and `Timelimit(1/60)`, as well as some controls to print traces of the losses. First we can define some methods to initialize or clear the traces as well as updte the traces. - - @code julia - - # For initializing or clearing the traces: - - clear() = begin - global losses = [] - global training_losses = [] - global epochs = [] - return nothing - end - - # And to update the traces: - - update_loss(loss) = push!(losses, loss) - update_training_loss(report) = - push!(training_losses, - report.transformed_target_model_deterministic.model.training_losses[end]) - update_epochs(epoch) = push!(epochs, epoch) - - @end - - For further reference of controls, see {https://alan-turing-institute.github.io/MLJ.jl/dev/controlling_iterative_models/#Controls-provided}[the documentation]. To apply the controls, we simply stack them in a vector and then make an `IteratedModel`: - - @code julia - - controls=[Step(1), - NumberSinceBest(6), - InvalidValue(), - TimeLimit(1/60), - WithLossDo(update_loss), - WithReportDo(update_training_loss), - WithIterationsDo(update_epochs)] - - - iterated_pipe = - IteratedModel(model=pipe, - controls=controls, - resampling=Holdout(fraction_train=0.8), - measure = l2) - - @end - - Next, we can clear the traces, fit the model, and plot the traces: - - @code julia - - - clear() - mach = machine(iterated_pipe, X, y) - fit!(mach) - - plot(epochs, losses, - xlab = "epoch", - ylab = "mean sum of squares error", - label="out-of-sample", - legend = :topleft); - scatter!(twinx(), epochs, training_losses, label="training", color=:red) #!md - - savefig(joinpath("assets", "loss.png")) - @end - -*** Brief note on iterated models - Training an `IteratedModel` means holding out some data (80% in this case) so an out-of-sample loss can be tracked and used in the specified stopping criterion, `NumberSinceBest(4)`. However, once the stop is triggered, the model wrapped by `IteratedModel` (our pipeline model) is retrained on all data for the same number of iterations. Calling `predict(mach, Xnew)` on new data uses the updated learned parameters. - -** Evaluating Iterated Models - We can evaluate our model with the `evaluate!` function: - - @code julia - - e = evaluate!(mach, - resampling=CV(nfolds=8), - measures=[l1, l2]) - -#- - - using Measurements - l1_loss = e.measurement[1] ± std(e.per_fold[1])/sqrt(7) - @show l1_loss - - @end - -We take this estimate of the uncertainty of the generalization error with a [grain of salt](https://direct.mit.edu/neco/article-abstract/10/7/1895/6224/Approximate-Statistical-Tests-for-Comparing)). - -** Comparison with other models on the test set - - Although we cannot assign them statistical significance, here are comparisons, on the untouched test set, of the eror of our self-iterating neural network regressor with a couple of other models trained on the same data (using default hyperparameters): - - @code julia - - function performance(model) - mach = machine(model, X, y) |> fit! - yhat = predict(mach, Xtest) - l1(yhat, ytest) |> mean - end - performance(iterated_pipe) - - three_models = [(@load EvoTreeRegressor)(), # tree boosting model - (@load LinearRegressor pkg=MLJLinearModels)(), - iterated_pipe] - - errs = performance.(three_models) - - (models=MLJ.name.(three_models), mean_square_errors=errs) |> pretty - - - @end - From 8aaa88f756d541dd77d77f5d1ce5a1adb61e3050 Mon Sep 17 00:00:00 2001 From: josephsdavid Date: Tue, 12 Jul 2022 16:44:17 -0500 Subject: [PATCH 6/7] update to fix CI --- src/MLJFlux.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/MLJFlux.jl b/src/MLJFlux.jl index d2e63add..981bc4d4 100644 --- a/src/MLJFlux.jl +++ b/src/MLJFlux.jl @@ -14,6 +14,8 @@ using ColorTypes using ComputationalResources using Random +const MMI=MLJModelInterface + include("penalizers.jl") include("core.jl") include("builders.jl") @@ -24,7 +26,7 @@ include("image.jl") include("mlj_model_interface.jl") ### Package specific model traits: -MLJModelInterface.metadata_pkg.((NeuralNetworkRegressor, +MMI.metadata_pkg.((NeuralNetworkRegressor, MultitargetNeuralNetworkRegressor, NeuralNetworkClassifier, ImageClassifier), From ccbed5eabeef26b5fc8b5d84a522d0b12ac2564a Mon Sep 17 00:00:00 2001 From: josephsdavid Date: Mon, 18 Jul 2022 16:17:12 -0500 Subject: [PATCH 7/7] code review --- src/types.jl | 556 ++++++++++++++------------------------------------- 1 file changed, 148 insertions(+), 408 deletions(-) diff --git a/src/types.jl b/src/types.jl index 16c3c295..1f454a63 100644 --- a/src/types.jl +++ b/src/types.jl @@ -22,7 +22,7 @@ for Model in [:NeuralNetworkClassifier, :ImageClassifier] function $Model(; builder::B = Short() , finaliser::F = Flux.softmax - , optimiser::O = Flux.Optimise.ADAM() + , optimiser::O = Flux.Optimise.Adam() , loss::L = Flux.crossentropy , epochs = 10 , batch_size = 1 @@ -60,19 +60,22 @@ end """ $(MMI.doc_header(NeuralNetworkClassifier)) -`NeuralNetworkClassifier`: a neural network model for making probabilistic predictions -of a Multiclass or OrderedFactor target, given a table of Continuous features. ) - TODO: +`NeuralNetworkClassifier` is for training a data-dependent Flux.jl neural network +for making probabilistic predictions of a `Multiclass` or `OrderedFactor` target, +given a table of `Continuous` features. Users provide a recipe for constructing + the network, based on properties of the data that is encountered, by specifying + an appropriate `builder`. See MLJFlux documentation for more on builders. # Training data In MLJ or MLJBase, bind an instance `model` to data with + mach = machine(model, X, y) Where - `X`: is any table of input features (eg, a `DataFrame`) whose columns - are of scitype `Continuous`; check the scitype with `schema(X)` + are of scitype `Continuous`; check the column scitypes with `schema(X)`. - `y`: is the target, which can be any `AbstractVector` whose element scitype is `Multiclass` or `OrderedFactor` with `n_out` classes; check the scitype with `scitype(y)` @@ -80,8 +83,11 @@ Where # Hyper-parameters -- `builder=MLJFlux.Short()`: An MLJFlux builder that constructs a neural network. Possible `builders` include: `Linear`, `Short`, and `MLP`. You can construct your own builder using the `@builder` macro, see examples for further information. -- `optimiser::Flux.ADAM()`: A `Flux.Optimise` optimiser. The optimiser performs the updating of the weights of the network. For further reference, see either the examples or [the Flux optimiser documentation](https://fluxml.ai/Flux.jl/stable/training/optimisers/). To choose a learning rate (the update rate of the optimizer), a good rule of thumb is to start out at `10e-3`, and tune using powers of 10 between `1` and `1e-7`. +- `builder=MLJFlux.Short()`: An MLJFlux builder that constructs a neural + network. Possible `builders` include: `MLJFlux.Linear`, `MLJFlux.Short`, + and `MLJFlux.MLP`. See MLJFlux documentation for examples of + user-defined builders. +- `optimiser::Flux.Adam()`: A `Flux.Optimise` optimiser. The optimiser performs the updating of the weights of the network. For further reference, see either the examples or [the Flux optimiser documentation](https://fluxml.ai/Flux.jl/stable/training/optimisers/). To choose a learning rate (the update rate of the optimizer), a good rule of thumb is to start out at `10e-3`, and tune using powers of 10 between `1` and `1e-7`. - `loss=Flux.crossentropy`: The loss function which the network will optimize. Should be a function which can be called in the form `loss(yhat, y)`. Possible loss functions are listed in [the Flux loss function documentation](https://fluxml.ai/Flux.jl/stable/models/losses/). For a classification task, the most natural loss functions are: - `Flux.crossentropy`: Typically used as loss in multiclass classification, with labels in a 1-hot encoded format. - `Flux.logitcrossentopy`: Mathematically equal to crossentropy, but computationally more numerically stable than finalising the outputs with `softmax` and then calculating crossentropy. @@ -90,21 +96,28 @@ Where - `Flux.tversky_loss`: Used with imbalanced data to give more weight to false negatives. - `Flux.focal_loss`: Used with highly imbalanced data. Weights harder examples more than easier examples. - `Flux.binary_focal_loss`: Binary version of the above + Currently MLJ measures are not supported as loss functions here. - `epochs::Int=10`: The number of epochs to train for. Typically, one epoch represents one pass through the entirety of the training dataset. -- `batch_size::Int=1`: The batch size to be used for training. The batch size represents the number of samples per update of the networks weights. Typcally, batch size should be somewhere between 8 and 512. Smaller batch sizes lead to noisier training loss curves, while larger batch sizes lead towards smoother training loss curves. In general, it is a good idea to pick one fairly large batch size (e.g. 32, 64, 128), and stick with it, and only tune the learning rate. In most literature, batch size is set in powers of twos, but this is fairly arbitrary. +- `batch_size::int=1`: the batch size to be used for training. the batch size represents + the number of samples per update of the networks weights. typcally, batch size should be + somewhere between 8 and 512. smaller batch sizes lead to noisier training loss curves, + while larger batch sizes lead towards smoother training loss curves. + In general, it is a good idea to pick one fairly large batch size (e.g. 32, 64, 128), + and stick with it, and only tune the learning rate. In most examples, batch size is set + in powers of twos, but this is fairly arbitrary. - `lambda::Float64=0`: The stregth of the regularization used during training. Can be any value in the range `[0, ∞)`. - `alpha::Float64=0`: The L2/L1 mix of regularization, in the range `[0, 1]`. A value of 0 represents L2 regularization, and a value of 1 represents L1 regularization. - `rng::Union{AbstractRNG, Int64}`: The random number generator/seed used during training. -- `optimizer_changes_trigger_retraining::Bool=false`: Defines what happens when fitting a machine if the associated optimiser has changed. If true, the associated machine will retrain from scratch on `fit`, otherwise it will not. -- `acceleration::AbstractResource=CPU1()`: Defines on what hardware training is done. For Training on GPU, use `CudaLibs()`, otherwise defaults to `CPU`()`. +- `optimizer_changes_trigger_retraining::Bool=false`: Defines what happens when fitting a machine if the associated optimiser has changed. If true, the associated machine will retrain from scratch on `fit!`, otherwise it will not. +- `acceleration::AbstractResource=CPU1()`: Defines on what hardware training is done. For Training on GPU, use `CudaLibs()`. For training on GPU, use `CUDALibs()`. - `finaliser=Flux.softmax`: The final activation function of the neural network. Defaults to `Flux.softmax`. For a classification task, `softmax` is used for multiclass, single label regression, `sigmoid` is used for either binary classification or multi label classification (when there are multiple possible labels for a given sample). # Operations - `predict(mach, Xnew)`: return predictions of the target given new - features `Xnew` having the same Scitype as `X` above. Predictions are - probabilistic. + features `Xnew` having the same scitype as `X` above. Predictions are + probabilistic but uncalibrated. - `predict_mode(mach, Xnew)`: Return the modes of the probabilistic predictions returned above. @@ -113,14 +126,17 @@ Where The fields of `fitted_params(mach)` are: -- `chain`: The trained "chain", or series of layers, functions, and activations which make up the neural network. +- `chain`: The trained "chain" (Flux.jl model), namely the series of layers, + functions, and activations which make up the neural network. This includes + the final layer specified by `finaliser` (eg, `softmax`). # Report The fields of `report(mach)` are: -- `training_losses`: The history of training losses, a vector containing the history of all the losses during training. The first element of the vector is the initial penalized loss. After the first element, the nth element corresponds to the loss of epoch n-1. +- `training_losses`: A vector of training losses (penalised if `lambda != 0`) in + historical order, of length `epochs + 1`. The first element is the pre-training loss. # Examples @@ -133,16 +149,12 @@ import RDatasets using Random Random.seed!(123) -MLJ.color_off() - -using Plots -pyplot(size=(600, 300*(sqrt(5)-1))); ``` This is a very basic example, using a default builder and no standardization. -For a more advance illustration, see [`NeuralNetworkRegressor`](@ref) or [`ImageClassifier`](@ref). First, we can load the data: +For a more advanced illustration, see [`NeuralNetworkRegressor`](@ref) or [`ImageClassifier`](@ref). First, we can load the data: ```julia iris = RDatasets.dataset("datasets", "iris"); -y, X = unpack(iris, ==(:Species), colname -> true, rng=123); +y, X = unpack(iris, ==(:Species), rng=123); NeuralNetworkClassifier = @load NeuralNetworkClassifier clf = NeuralNetworkClassifier() ``` @@ -157,7 +169,7 @@ We can train the model in an incremental fashion with the `optimizer_changes_tri clf.optimiser.eta = clf.optimiser.eta * 2 clf.epochs = clf.epochs + 5 -# note that if the optimizer_changes_trigger_retraining flag was set to true +# note that if the `optimizer_changes_trigger_retraining` flag was set to true # the model would be completely retrained from scratch because the optimizer was # updated fit!(mach, verbosity=2); @@ -186,7 +198,6 @@ plot(curve.parameter_values, xscale=curve.parameter_scale, ylab = "Cross Entropy") -savefig("iris_history.png") ``` See also [`ImageClassifier`](@ref) @@ -196,25 +207,34 @@ NeuralNetworkClassifier """ $(MMI.doc_header(ImageClassifier)) -`ImageClassifier`: A neural network model for making probabilistic -"predictions of a `GrayImage` target, given a table of `Continuous` features. +`ImageClassifier` classifies images using a neural network adapted to the type + of images provided (color or greyscale). Predictions are probabistic. Users + provide a recipe for constructing the network, based on properties of the image + encountered, by specifying an appropriate `builder`. See MLJFlux documentation + for more on builders. # Training data In MLJ or MLJBase, bind an instance `model` to data with -mach = machine(model, X, y) + + mach = machine(model, X, y) + Where -- `X`: is any `AbstractVector` of input features (eg, a `DataFrame`) whose items - are of scitype `GrayImage`; check the scitype with `scitype(X)` +- `X`: is any `AbstractVector` of images with `ColorImage` or `GrayImage` + scitype; check the scitype with `scitype(X)` and refer to ScientificTypes.jl + documentation on coercing typical image formats into an appropriate type. - `y`: is the target, which can be any `AbstractVector` whose element - scitype is `Multiclass` or `OrderedFactor` with `n_out` classes; - check the scitype with `scitype(y)` + scitype is `Multiclass`; check the scitype with `scitype(y)`. # Hyper-parameters -- `builder=MLJFlux.Short()`: An MLJFlux builder that constructs a neural network. Possible `builders` include: `Linear`, `Short`, and `MLP`. You can construct your own builder using the `@builder` macro, see examples for further information. -- `optimiser::Flux.ADAM()`: A `Flux.Optimise` optimiser. The optimiser performs the updating of the weights of the network. For further reference, see either the examples or [the Flux optimiser documentation](https://fluxml.ai/Flux.jl/stable/training/optimisers/). To choose a learning rate (the update rate of the optimizer), a good rule of thumb is to start out at `10e-3`, and tune using powers of 10 between `1` and `1e-7`. +- `builder`: An MLJFlux builder that constructs the neural network. + The fallback builds a depth-16 VGG architecture adapted to the image + size and number of target classes, with no batch normalisation; see the + Metalhead.jl documentation for details. See the example below for a + user-specified builder. +- `optimiser::Flux.Adam()`: A `Flux.Optimise` optimiser. The optimiser performs the updating of the weights of the network. For further reference, see either the examples or [the Flux optimiser documentation](https://fluxml.ai/Flux.jl/stable/training/optimisers/). To choose a learning rate (the update rate of the optimizer), a good rule of thumb is to start out at `10e-3`, and tune using powers of 10 between `1` and `1e-7`. - `loss=Flux.crossentropy`: The loss function which the network will optimize. Should be a function which can be called in the form `loss(yhat, y)`. Possible loss functions are listed in [the Flux loss function documentation](https://fluxml.ai/Flux.jl/stable/models/losses/). For a classification task, the most natural loss functions are: - `Flux.crossentropy`: Typically used as loss in multiclass classification, with labels in a 1-hot encoded format. - `Flux.logitcrossentopy`: Mathematically equal to crossentropy, but computationally more numerically stable than finalising the outputs with `softmax` and then calculating crossentropy. @@ -223,21 +243,26 @@ Where - `Flux.tversky_loss`: Used with imbalanced data to give more weight to false negatives. - `Flux.focal_loss`: Used with highly imbalanced data. Weights harder examples more than easier examples. - `Flux.binary_focal_loss`: Binary version of the above + Currently MLJ measures are not supported as loss functions here. - `epochs::Int=10`: The number of epochs to train for. Typically, one epoch represents one pass through the entirety of the training dataset. -- `batch_size::Int=1`: The batch size to be used for training. The batch size represents the number of samples per update of the networks weights. Typcally, batch size should be somewhere between 8 and 512. Smaller batch sizes lead to noisier training loss curves, while larger batch sizes lead towards smoother training loss curves. In general, it is a good idea to pick one fairly large batch size (e.g. 32, 64, 128), and stick with it, and only tune the learning rate. In most literature, batch size is set in powers of twos, but this is fairly arbitrary. +- `batch_size::Int=1`: The batch size to be used for training. The batch size + represents the number of samples per update of the networks weights. Batch + sizes between 8 and 512 are typical. Increasing batch size can speed up + training, especially on a GPU (`acceleration=CUDALibs()`). - `lambda::Float64=0`: The stregth of the regularization used during training. Can be any value in the range `[0, ∞)`. - `alpha::Float64=0`: The L2/L1 mix of regularization, in the range `[0, 1]`. A value of 0 represents L2 regularization, and a value of 1 represents L1 regularization. - `rng::Union{AbstractRNG, Int64}`: The random number generator/seed used during training. -- `optimizer_changes_trigger_retraining::Bool=false`: Defines what happens when fitting a machine if the associated optimiser has changed. If true, the associated machine will retrain from scratch on `fit`, otherwise it will not. -- `acceleration::AbstractResource=CPU1()`: Defines on what hardware training is done. For Training on GPU, use `CudaLibs()`, otherwise defaults to `CPU`()`. -- `finaliser=Flux.softmax`: The final activation function of the neural network. Defaults to `Flux.softmax`. For a regression task, reasonable alternatives include `Flux.sigmoid` and the identity function (otherwise known as "linear activation"). +- `optimizer_changes_trigger_retraining::Bool=false`: Defines what happens when fitting a machine if the associated optimiser has changed. If true, the associated machine will retrain from scratch on `fit!`, otherwise it will not. +- `acceleration::AbstractResource=CPU1()`: Defines on what hardware training is done. For Training on GPU, use `CudaLibs()`. For training on GPU, use `CUDALibs()`. +- `finaliser=Flux.softmax`: The final activation function of the neural network, + needed to convert outputs to probabilities (builders do not provide this). # Operations - `predict(mach, Xnew)`: return predictions of the target given new - features `Xnew` having the same Scitype as `X` above. Predictions are - probabilistic. + features `Xnew` having the same scitype as `X` above. Predictions are + probabilistic but uncalibrated. - `predict_mode(mach, Xnew)`: Return the modes of the probabilistic predictions returned above. @@ -245,13 +270,17 @@ Where # Fitted parameters The fields of `fitted_params(mach)` are: -- `chain`: The trained "chain", or series of layers, functions, and activations which make up the neural network. + +- `chain`: The trained "chain" (Flux.jl model), namely the series of layers, + functions, and activations which make up the neural network. This includes + the final layer specified by `finaliser` (eg, `softmax`). # Report The fields of `report(mach)` are: -- `training_losses`: The history of training losses, a vector containing the history of all the losses during training. The first element of the vector is the initial penalized loss. After the first element, the nth element corresponds to the loss of epoch n-1. +- `training_losses`: A vector of training losses (penalised if `lambda != 0`) in + historical order, of length `epochs + 1`. The first element is the pre-training loss. # Examples @@ -262,31 +291,20 @@ using Flux import MLJFlux import MLJIteration # for `skip` -MLJ.color_off() - -using Plots -pyplot(size=(600, 300*(sqrt(5)-1))); ``` First we want to download the MNIST dataset, and unpack into images and labels ```julia import MLDatasets: MNIST -ENV["DATADEPS_ALWAYS_ACCEPT"] = true images, labels = MNIST.traindata(); ``` -In MLJ, integers cannot be used for encoding categorical data, so we must coerce them into the `Multiclass` [scientific type](https://juliaai.github.io/ScientificTypes.jl/dev/). For more in this, see [Working with Categorical Data](https://alan-turing-institute.github.io/MLJ.jl/dev/working_with_categorical_data/): +In MLJ, integers cannot be used for encoding categorical data, so we must coerce them into the `Multiclass` scitype: ```julia labels = coerce(labels, Multiclass); images = coerce(images, GrayImage); -# Checking scientific types: - -@assert scitype(images) <: AbstractVector{<:Image} -@assert scitype(labels) <: AbstractVector{<:Finite} - images[1] ``` -For general instructions on coercing image data, see [type coercion for image data](https://alan-turing-institute.github.io/ScientificTypes.jl/dev/%23Type-coercion-for-image-data-1) We start by defining a suitable `builder` object. This is a recipe for building the neural network. Our builder will work for images of any (constant) size, whether they be color or black and white (ie, @@ -323,7 +341,7 @@ function MLJFlux.build(b::MyConvBuilder, rng, n_in, n_out, n_channels) return Chain(front, Dense(d, n_out, init=init)) end ``` -It is important to note that in our `build` function, there is no final softmax. This is applie by default in all MLJFlux classifiers, using the `finaliser` hyperparameter of the classifier. Now that we have our builder defined, we can define the actual moel. If you have a GPU, you can substitute in `acceleration=CudaLibs()` below. Note that in the case of convolutions, this will **greatly** increase the speed of training. +It is important to note that in our `build` function, there is no final `softmax`. This is applied by default in all MLJFlux classifiers (override this using the `finaliser` hyperparameter). Now that we have our builder defined, we can define the actual model. If you have a GPU, you can substitute in `acceleration=CUDALibs()` below to greatly speed up training. ```julia ImageClassifier = @load ImageClassifier clf = ImageClassifier(builder=MyConvBuilder(3, 16, 32, 32), @@ -349,114 +367,20 @@ We can tack on 20 more epochs by modifying the `epochs` field, and iteratively f clf.epochs = clf.epochs + 20 fit!(mach, rows=1:500); ``` -We can also make predictions and calculate an out-of-sample loss estimate, in two ways! +We can also make predictions and calculate an out-of-sample loss estimate: ```julia predicted_labels = predict(mach, rows=501:1000); cross_entropy(predicted_labels, labels[501:1000]) |> mean -# alternative one liner! +``` +The preceding `fit!`/`predict`/evaluate workflow can be alternatively executed as folllows: + +```julia evaluate!(mach, resampling=Holdout(fraction_train=0.5), measure=cross_entropy, rows=1:1000, verbosity=0) ``` - -## Wrapping in iteration controls - -Any iterative MLJFlux model can be wrapped in **iteration controls**, as we demonstrate next. For more on MLJ's `IteratedModel` wrapper, see the [MLJ documentation](https://alan-turing-institute.github.io/MLJ.jl/dev/controlling_iterative_models/). -The "self-iterating" classifier (`iterated_clf` below) is for iterating the image classifier defined above until a stopping criterion is hit. We use the following stopping criterion: -- `Patience(3)`: 3 consecutive increases in the loss -- `InvalidValue()`: an out-of-sample loss or a training loss that is `NaN` or `±Inf` -- `TimeLimit(t=5/60)`: training time has exceeded 5 minutes. -We can specify how often these checks (and other controls) are applied using the `Step` control. Additionally, we can define controls to -- save a snapshot of the machine every N control cycles (`save_control`) -- record traces of the out-of-sample loss and training losses for plotting (`WithLossDo`) -- record mean value traces of each Flux parameter for plotting (`Callback`) -And other controls. For a full list, see [the documentation](https://alan-turing-institute.github.io/MLJ.jl/dev/controlling_iterative_models/%23Controls-provided). -First, we define some helper functions and some empty vectors to store traces: -```julia -make2d(x::AbstractArray) = reshape(x, :, size(x)[end]) -make1d(x::AbstractArray) = reshape(x, length(x)); - -# to extract the flux parameters from a machine -parameters(mach) = make1d.(Flux.params(fitted_params(mach))); - -# trace storage -losses = [] -training_losses = [] -parameter_means = Float32[]; -epochs = [] - -# to update traces -update_loss(loss) = push!(losses, loss) -update_training_loss(losses) = push!(training_losses, losses[end]) -update_means(mach) = append!(parameter_means, mean.(parameters(mach))); -update_epochs(epoch) = push!(epochs, epoch) -``` -Next, we can define our controls! We store them in a simple vector: -```julia -save_control = - MLJIteration.skip(Save(joinpath(DIR, "mnist.jlso")), predicate=3) - -controls=[Step(2), - Patience(3), - InvalidValue(), - TimeLimit(5/60), - save_control, - WithLossDo(), - WithLossDo(update_loss), - WithTrainingLossesDo(update_training_loss), - Callback(update_means), - WithIterationsDo(update_epochs) -``` -Once the controls are defined, we can instantiate and fit our "self-iterating" classifier: -```julia -iterated_clf = IteratedModel(model=clf, - controls=controls, - resampling=Holdout(fraction_train=0.7), - measure=log_loss) - -mach = machine(iterated_clf, images, labels); -fit!(mach, rows=1:500); -``` -Next we can compare the training and out-of-sample losses, as well as view the evolution of the weights: -```julia -plot(epochs, losses, - xlab = "epoch", - ylab = "root squared error", - label="out-of-sample") -plot!(epochs, training_losses, label="training") - -savefig(joinpath(DIR, "loss.png")) - -n_epochs = length(losses) -n_parameters = div(length(parameter_means), n_epochs) -parameter_means2 = reshape(copy(parameter_means), n_parameters, n_epochs)' -plot(epochs, parameter_means2, - title="Flux parameter mean weights", - xlab = "epoch") -# **Note.** The the higher the number, the deeper the chain parameter. -savefig(joinpath(DIR, "weights.png")) -``` -Since we saved our model every few epochs, we can retrieve the snapshots so we can make predictions! -```julia -mach2 = machine(joinpath(DIR, "mnist3.jlso")) -predict_mode(mach2, images[501:503]) -``` - -## Resuming training - -If we change `iterated_clf.controls` or `clf.epochs`, we can resume training from where it left off. This is very useful for long-running training sessions, where you may be interrupted by for example a bad connection or computer hibernation. -```julia -iterated_clf.controls[2] = Patience(4) -fit!(mach, rows=1:500) - -plot(epochs, losses, - xlab = "epoch", - ylab = "root squared error", - label="out-of-sample") -plot!(epochs, training_losses, label="training") -``` See also [`NeuralNetworkClassifier`](@ref) """ @@ -479,7 +403,7 @@ for Model in [:NeuralNetworkRegressor, :MultitargetNeuralNetworkRegressor] end function $Model(; builder::B = Linear() - , optimiser::O = Flux.Optimise.ADAM() + , optimiser::O = Flux.Optimise.Adam() , loss::L = Flux.mse , epochs = 10 , batch_size = 1 @@ -516,28 +440,32 @@ end """ $(MMI.doc_header(NeuralNetworkRegressor)) -`NeuralNetworkRegressor`: A neural network model for making deterministic -predictions of a `Continuous` target, given a table of `Continuous` features. +`NeuralNetworkRegressor` is for training a data-dependent Flux.jl neural +network to predict a `Continuous` target, given a table of +`Continuous` features. Users provide a recipe for constructing the +network, based on properties of the data that is encountered, by specifying +an appropriate `builder`. See MLJFlux documentation for more on builders. # Training data In MLJ or MLJBase, bind an instance `model` to data with + mach = machine(model, X, y) Where - `X`: is any table of input features (eg, a `DataFrame`) whose columns - are of scitype `Continuous`; check the scitype with `schema(X)` + are of scitype `Continuous`; check the column scitypes with `schema(X)`. - `y`: is the target, which can be any `AbstractVector` whose element scitype is `Continuous`; check the scitype with `scitype(y)` # Hyper-parameters -- `builder=MLJFlux.Linear(σ=Flux.relu)`: An MLJFlux builder that constructs a neural network. - Possible `builders` include: `Linear`, `Short`, and `MLP`. You can construct your own builder - using the `@builder` macro, see examples for further information. -- `optimiser::Flux.ADAM()`: A `Flux.Optimise` optimiser. The optimiser performs the updating +- `builder=MLJFlux.Linear(σ=Flux.relu)`: An MLJFlux builder that constructs + a neural network. Possible `builders` include: `MLJFlux.Linear`, `MLJFlux.Short`, + and `MLJFlux.MLP`. See below for an example of a user-specified builder. +- `optimiser::Flux.Adam()`: A `Flux.Optimise` optimiser. The optimiser performs the updating of the weights of the network. For further reference, see either the examples or [the Flux optimiser documentation](https://fluxml.ai/Flux.jl/stable/training/optimisers/). To choose a learning rate (the update rate of the optimizer), a good rule of thumb is to @@ -550,34 +478,27 @@ Where - `Flux.mae` - `Flux.msle` - `Flux.huber_loss` + Currently MLJ measures are not supported as loss functions here. - `epochs::Int=10`: The number of epochs to train for. Typically, one epoch represents one pass through the entirety of the training dataset. -- `batch_size::Int=1`: The batch size to be used for training. The batch size represents - the number of samples per update of the networks weights. Typcally, batch size should be - somewhere between 8 and 512. Smaller batch sizes lead to noisier training loss curves, - while larger batch sizes lead towards smoother training loss curves. - In general, it is a good idea to pick one fairly large batch size (e.g. 32, 64, 128), - and stick with it, and only tune the learning rate. In most examples, batch size is set - in powers of twos, but this is fairly arbitrary. +- `batch_size::Int=1`: The batch size to be used for training. The batch size + represents the number of samples per update of the networks weights. Batch + sizes between 8 and 512 are typical. Increasing batch size can speed up + training, especially on a GPU (`acceleration=CUDALibs()`). - `lambda::Float64=0`: The stregth of the regularization used during training. Can be any value in the range `[0, ∞)`. - `alpha::Float64=0`: The L2/L1 mix of regularization, in the range `[0, 1]`. A value of 0 represents L2 regularization, and a value of 1 represents L1 regularization. - `rng::Union{AbstractRNG, Int64}`: The random number generator/seed used during training. -- `optimizer_changes_trigger_retraining::Bool=false`: Defines what happens when fitting a - machine if the associated optimiser has changed. If true, the associated machine will - retrain from scratch on `fit`, otherwise it will not. +- `optimizer_changes_trigger_retraining::Bool=false`: Defines what happens when fitting a machine if the associated optimiser has changed. If true, the associated machine will retrain from scratch on `fit!`, otherwise it will not. - `acceleration::AbstractResource=CPU1()`: Defines on what hardware training is done. - For training on GPU, use `CudaLibs()`, otherwise defaults to `CPU`()`. -- `finaliser=Flux.softmax`: The final activation function of the neural network. - Defaults to `Flux.softmax`. For a regression task, reasonable alternatives include - `Flux.sigmoid` and the identity function (otherwise known as "linear activation"). +For training on GPU, use `CudaLibs()`. For training on GPU, use `CUDALibs()`. # Operations - `predict(mach, Xnew)`: return predictions of the target given new - features `Xnew` having the same Scitype as `X` above. Predictions are + features `Xnew` having the same scitype as `X` above. Predictions are deterministic. @@ -585,18 +506,18 @@ Where The fields of `fitted_params(mach)` are: -- `chain`: The trained "chain", or series of layers, functions, and activations which - make up the neural network. +- `chain`: The trained "chain" (Flux.jl model), namely the series of layers, + functions, and activations which make up the neural network. This includes + the final layer specified by `finaliser` (eg, `softmax`). # Report The fields of `report(mach)` are: -- `training_losses`: The history of training losses, a vector containing the history of all the losses during training. The first element of the vector is the initial penalized loss. After the first element, the nth element corresponds to the loss of epoch n-1. - all the losses during training. The first element of the vector is the initial penalized loss. After the first element, the nth element corresponds to the loss of epoch n-1. - penalized loss. After the first element, the nth element corresponds to the loss of epoch n-1. - epoch n-1. +- `training_losses`: A vector of training losses (penalised if `lambda != 0`) in + historical order, of length `epochs + 1`. The first element is the pre-training loss. + # Examples In this example we build a regression model using the Boston house price dataset @@ -604,7 +525,6 @@ In this example we build a regression model using the Boston house price dataset using MLJ using MLJFlux using Flux - using Plots ``` First, we load in the data, with target `:MEDV`. We load in all features except `:CHAS`: ```julia @@ -641,10 +561,9 @@ NeuralNetworkRegressor = @load NeuralNetworkRegressor rng=123, epochs=20) ``` -For our neural network, since different features likely have different scales, if we do not standardize the network may be implicitly biased towards features with higher magnitudes, or may have [saturated neurons](https://www.informit.com/articles/article.aspx%3fp=3131594&seqNum=2) and not train well. Therefore, standardization is key! -not standardize the network may be implicitly biased towards features with higher magnitudes, or may have [saturated neurons](https://www.informit.com/articles/article.aspx%3fp=3131594&seqNum=2) and not train well. Therefore, standardization is key! -magnitudes, or may have [saturated neurons](https://www.informit.com/articles/article.aspx%3fp=3131594&seqNum=2) and not train well. Therefore, standardization is key! -neurons](https://www.informit.com/articles/article.aspx%3fp=3131594&seqNum=2) and not train well. Therefore, standardization is key! +We will arrange for standardizaion of the the target by wrapping our model + in `TransformedTargetModel`, and standardization of the features by +inserting the wrapped model in a pipeline: ```julia pipe = Standardizer |> TransformedTargetModel(model, target=Standardizer) ``` @@ -663,7 +582,7 @@ report(mach).transformed_target_model_deterministic.training_losses We can visually compare how the learning rate affects the predictions: ```julia -plt = plot() +using Plots rates = 10. .^ (-5:0) @@ -674,114 +593,21 @@ foreach(rates) do η report(mach).transformed_target_model_deterministic.model.training_losses[3:end] plot!(1:length(losses), losses, label=η) end -plt #!md -savefig(joinpath("assets", "learning_rate.png")) pipe.transformed_target_model_deterministic.model.optimiser.eta = 0.0001 -``` -## Using Iteration Controls +# CV estimate, based on `(X, y)`: +evaluate!(mach, resampling=CV(nfolds=5), measure=l2) -We can also wrap the model with MLJ Iteration controls. Suppose we want a model that trains until the out of sample loss does not improve for 6 epochs. We can use the `NumberSinceBest(6)` stopping criterion. We can also add some extra stopping criterion, `InvalidValue` and `Timelimit(1/60)`, as well as some controls to print traces of the losses. First we can define some methods to initialize or clear the traces as well as updte the traces. -trains until the out of sample loss does not improve for 6 epochs. We can use the `NumberSinceBest(6)` stopping criterion. We can also add some extra stopping criterion, `InvalidValue` and `Timelimit(1/60)`, as well as some controls to print traces of the losses. First we can define some methods to initialize or clear the traces as well as update the traces. -`NumberSinceBest(6)` stopping criterion. We can also add some extra stopping criterion, `InvalidValue` and `Timelimit(1/60)`, as well as some controls to print traces of the losses. First we can define some methods to initialize or clear the traces as well as update the traces. -```julia -# For initializing or clearing the traces: - -clear() = begin - global losses = [] - global training_losses = [] - global epochs = [] - return nothing -end - - # And to update the traces: - -update_loss(loss) = push!(losses, loss) -update_training_loss(report) = - push!(training_losses, - report.transformed_target_model_deterministic.model.training_losses[end]) -update_epochs(epoch) = push!(epochs, epoch) +# loss for `(Xtest, test)`: +fit!(mach) # train on `(X, y)` +yhat = predict(mach, Xtest) +l2(yhat, ytest) |> mean ``` -For further reference of controls, see [the documentation](https://alan-turing-institute.github.io/MLJ.jl/dev/controlling_iterative_models/%23Controls-provided). To apply the controls, we simply stack them in a vector and then make an `IteratedModel`: -```julia -controls=[Step(1), - NumberSinceBest(6), - InvalidValue(), - TimeLimit(1/60), - WithLossDo(update_loss), - WithReportDo(update_training_loss), -WithIterationsDo(update_epochs)] - - -iterated_pipe = - IteratedModel(model=pipe, - controls=controls, - resampling=Holdout(fraction_train=0.8), - measure = l2) -``` -Next, we can clear the traces, fit the model, and plot the traces: -```julia -clear() -mach = machine(iterated_pipe, X, y) -fit!(mach) - -plot(epochs, losses, - xlab = "epoch", - ylab = "mean sum of squares error", - label="out-of-sample", - legend = :topleft); -scatter!(twinx(), epochs, training_losses, label="training", color=:red) #!md - -savefig(joinpath("assets", "loss.png")) -``` - -### Brief note on iterated models - -Training an `IteratedModel` means holding out some data (80% in this case) so an -out-of-sample loss can be tracked and used in the specified stopping criterion, -`NumberSinceBest(4)`. However, once the stop is triggered, the model wrapped by -`IteratedModel` (our pipeline model) is retrained on all data for the same number of -iterations. Calling `predict(mach, Xnew)` on new data uses the updated learned -parameters. - -## Evaluating Iterated Models - -We can evaluate our model with the `evaluate!` function: -```julia -e = evaluate!(mach, - resampling=CV(nfolds=8), - measures=[l1, l2]) - -using Measurements -l1_loss = e.measurement[1] ± std(e.per_fold[1])/sqrt(7) -@show l1_loss -``` -We take this estimate of the uncertainty of the generalization error with a [grain of -salt](https://direct.mit.edu/neco/article-abstract/10/7/1895/6224/Approximate-Statistical-Tests-for-Comparing)). -## Comparison with other models on the test set - -Although we cannot assign them statistical significance, here are comparisons, on the -untouched test set, of the eror of our self-iterating neural network regressor with a -couple of other models trained on the same data (using default hyperparameters): -```julia -function performance(model) - mach = machine(model, X, y) |> fit! - yhat = predict(mach, Xtest) - l1(yhat, ytest) |> mean -end -performance(iterated_pipe) - -three_models = [(@load EvoTreeRegressor)(), # tree boosting model - (@load LinearRegressor pkg=MLJLinearModels)(), - iterated_pipe] - -errs = performance.(three_models) - -(models=MLJ.name.(three_models), mean_square_errors=errs) |> pretty -``` +For impementing stopping criterion and other iteration controls, refer to examples linked +from the MLJFlux documentation See also [`MultitargetNeuralNetworkRegressor`](@ref) @@ -791,19 +617,22 @@ NeuralNetworkRegressor """ $(MMI.doc_header(MultitargetNeuralNetworkRegressor)) -`MultitargetNeuralNetworkRegressor`: A neural network model for making deterministic -predictions of a `Continuous` multi-target, presented as a table, given a table of -`Continuous` features. +`MultitargetNeuralNetworkRegressor` is for training a data-dependent Flux.jl + neural network to predict a multivalued `Continuous` target, represented as a table, + given a table of `Continuous` features. Users provide a recipe for constructing the + network, based on properties of the data that is encountered, by specifying an +appropriate `builder`. See MLJFlux documentation for more on builders. # Training data In MLJ or MLJBase, bind an instance `model` to data with + mach = machine(model, X, y) Where - `X`: is any table of input features (eg, a `DataFrame`) whose columns - are of scitype `Continuous`; check the scitype with `schema(X)` + are of scitype `Continuous`; check the column scitypes with `schema(X)`. - `y`: is the target, which can be any table of output targets whose element scitype is `Continuous`; check the scitype with `schema(y)` @@ -813,7 +642,7 @@ Where - `builder=MLJFlux.Linear(σ=Flux.relu)`: An MLJFlux builder that constructs a neural network. Possible `builders` include: `Linear`, `Short`, and `MLP`. You can construct your own builder using the `@builder` macro, see examples for further information. -- `optimiser::Flux.ADAM()`: A `Flux.Optimise` optimiser. The optimiser performs the +- `optimiser::Flux.Adam()`: A `Flux.Optimise` optimiser. The optimiser performs the updating of the weights of the network. For further reference, see either the examples or [the Flux optimiser documentation](https://fluxml.ai/Flux.jl/stable/training/optimisers/). To choose a @@ -828,34 +657,27 @@ Where - `Flux.mae` - `Flux.msle` - `Flux.huber_loss` + Currently MLJ measures are not supported as loss functions here. - `epochs::Int=10`: The number of epochs to train for. Typically, one epoch represents one pass through the entirety of the training dataset. -- `batch_size::Int=1`: The batch size to be used for training. The batch size represents - the number of samples per update of the networks weights. Typcally, batch size should be - somewhere between 8 and 512. Smaller batch sizes lead to noisier training loss curves, - while larger batch sizes lead towards smoother training loss curves. In general, it is a - good idea to pick one fairly large batch size (e.g. 32, 64, 128), and stick with it, and - only tune the learning rate. In most literature, batch size is set in powers of twos, - but this is fairly arbitrary. +- `batch_size::Int=1`: The batch size to be used for training. The batch size + represents the number of samples per update of the networks weights. Batch + sizes between 8 and 512 are typical. Increasing batch size can speed up + training, especially on a GPU (`acceleration=CUDALibs()`). - `lambda::Float64=0`: The stregth of the regularization used during training. Can be any value in the range `[0, ∞)`. - `alpha::Float64=0`: The L2/L1 mix of regularization, in the range `[0, 1]`. A value of 0 represents L2 regularization, and a value of 1 represents L1 regularization. - `rng::Union{AbstractRNG, Int64}`: The random number generator/seed used during training. -- `optimizer_changes_trigger_retraining::Bool=false`: Defines what happens when fitting - a machine if the associated optimiser has changed. If true, the associated machine will - retrain from scratch on `fit`, otherwise it will not. +- `optimizer_changes_trigger_retraining::Bool=false`: Defines what happens when fitting a machine if the associated optimiser has changed. If true, the associated machine will retrain from scratch on `fit!`, otherwise it will not. - `acceleration::AbstractResource=CPU1()`: Defines on what hardware training is done. - For Training on GPU, use `CudaLibs()`, otherwise defaults to `CPU`()`. -- `finaliser=Flux.softmax`: The final activation function of the neural network. -Defaults to `Flux.softmax`. For a regression task, reasonable alternatives include -`Flux.sigmoid` and the identity function (otherwise known as "linear activation"). +For Training on GPU, use `CudaLibs()`. For training on GPU, use `CUDALibs()`. # Operations - `predict(mach, Xnew)`: return predictions of the target given new - features `Xnew` having the same Scitype as `X` above. Predictions are + features `Xnew` having the same scitype as `X` above. Predictions are deterministic. @@ -863,18 +685,17 @@ Defaults to `Flux.softmax`. For a regression task, reasonable alternatives inclu The fields of `fitted_params(mach)` are: -- `chain`: The trained "chain", or series of layers, functions, and activations which - make up the neural network. +- `chain`: The trained "chain" (Flux.jl model), namely the series of layers, + functions, and activations which make up the neural network. This includes + the final layer specified by `finaliser` (eg, `softmax`). # Report The fields of `report(mach)` are: -- `training_losses`: The history of training losses, a vector containing the history of - all the losses during training. The first element of the vector is the initial - penalized loss. After the first element, the nth element corresponds to the loss of - epoch n-1. +- `training_losses`: A vector of training losses (penalised if `lambda != 0`) in + historical order, of length `epochs + 1`. The first element is the pre-training loss. # Examples @@ -883,7 +704,6 @@ In this example we build a regression model using a toy dataset. using MLJ using MLJFlux using Flux -using Plots using MLJBase: augment_X ``` First, we generate some data: @@ -913,11 +733,11 @@ end Finally, we can define the model! ```julia MultitargetNeuralNetworkRegressor = @load MultitargetNeuralNetworkRegressor - model = MultitargetNeuralNetworkRegressor(builder=builder, - rng=123, - epochs=20) +model = MultitargetNeuralNetworkRegressor(builder=builder, rng=123, epochs=20) ``` -For our neural network, since different features likely have different scales, if we do not standardize the network may be implicitly biased towards features with higher magnitudes, or may have [saturated neurons](https://www.informit.com/articles/article.aspx%3fp=3131594&seqNum=2) and not train well. Therefore, standardization is key! +We will arrange for standardizaion of the the target by wrapping our model + in `TransformedTargetModel`, and standardization of the features by +inserting the wrapped model in a pipeline: ```julia pipe = Standardizer |> TransformedTargetModel(model, target=Standardizer) ``` @@ -936,7 +756,7 @@ report(mach).transformed_target_model_deterministic.training_losses We can visually compare how the learning rate affects the predictions: ```julia -plt = plot() +using Plots rates = 10. .^ (-5:0) @@ -947,108 +767,28 @@ foreach(rates) do η report(mach).transformed_target_model_deterministic.model.training_losses[3:end] plot!(1:length(losses), losses, label=η) end -plt #!md -savefig(joinpath("assets", "learning_rate.png")) pipe.transformed_target_model_deterministic.model.optimiser.eta = 0.0001 ``` -## Using Iteration Controls - -We can also wrap the model with MLJ Iteration controls. Suppose we want a model that trains until the out of sample loss does not improve for 6 epochs. We can use the `NumberSinceBest(6)` stopping criterion. We can also add some extra stopping criterion, `InvalidValue` and `Timelimit(1/60)`, as well as some controls to print traces of the losses. First we can define some methods to initialize or clear the traces as well as updte the traces. +With the learning rate fixed, we can now compute a CV estimate of the performance (using +all data bound to `mach`) and compare this with performance on the test set: ```julia -# For initializing or clearing the traces: +# custom MLJ loss: +multi_loss(yhat, y) = l2(MLJ.matrix(yhat), MLJ.matrix(y)) |> mean -clear() = begin - global losses = [] - global training_losses = [] - global epochs = [] - return nothing -end - -# And to update the traces: +# CV estimate, based on `(X, y)`: +evaluate!(mach, resampling=CV(nfolds=5), measure=multi_loss) -update_loss(loss) = push!(losses, loss) -update_training_loss(report) = - push!(training_losses, - report.transformed_target_model_deterministic.model.training_losses[end]) -update_epochs(epoch) = push!(epochs, epoch) -``` -For further reference of controls, see [the documentation](https://alan-turing-institute.github.io/MLJ.jl/dev/controlling_iterative_models/%23Controls-provided). To apply the controls, we simply stack them in a vector and then make an `IteratedModel`: -```julia -controls=[Step(1), - NumberSinceBest(6), - InvalidValue(), - TimeLimit(1/60), - WithLossDo(update_loss), - WithReportDo(update_training_loss), -WithIterationsDo(update_epochs)] - -iterated_pipe = - IteratedModel(model=pipe, - controls=controls, - resampling=Holdout(fraction_train=0.8), - measure = l2) -``` -Next, we can clear the traces, fit the model, and plot the traces: -```julia -clear() -mach = machine(iterated_pipe, X, y) +# loss for `(Xtest, test)`: fit!(mach) - -plot(epochs, losses, - xlab = "epoch", - ylab = "mean sum of squares error", - label="out-of-sample", - legend = :topleft); -scatter!(twinx(), epochs, training_losses, label="training", color=:red) #!md - -savefig(joinpath("assets", "loss.png")) +yhat = predict(mach, Xtest) +multi_loss(yhat, y) ``` -### Brief note on iterated models - -Training an `IteratedModel` means holding out some data (80% in this case) so an out-of-sample loss can be tracked and used in the specified stopping criterion, `NumberSinceBest(4)`. However, once the stop is triggered, the model wrapped by `IteratedModel` (our pipeline model) is retrained on all data for the same number of iterations. Calling `predict(mach, Xnew)` on new data uses the updated learned parameters. - -## Evaluating Iterated Models - -We can evaluate our model with the `evaluate!` function: -```julia -e = evaluate!(mach, - resampling=CV(nfolds=8), - measures=[l1, l2]) - -using Measurements -l1_loss = e.measurement[1] ± std(e.per_fold[1])/sqrt(7) -@show l1_loss -``` -We take this estimate of the uncertainty of the generalization error with a [grain of salt](https://direct.mit.edu/neco/article-abstract/10/7/1895/6224/Approximate-Statistical-Tests-for-Comparing)). - -## Comparison with other models on the test set - -Although we cannot assign them statistical significance, here are comparisons, on the untouched test set, of the eror of our self-iterating neural network regressor with a couple of other models trained on the same data (using default hyperparameters): -```julia - -function performance(model) - mach = machine(model, X, y) |> fit! - yhat = predict(mach, Xtest) - l1(yhat, ytest) |> mean -end -performance(iterated_pipe) - -three_models = [(@load EvoTreeRegressor)(), # tree boosting model - (@load LinearRegressor pkg=MLJLinearModels)(), - iterated_pipe] - -errs = performance.(three_models) - -(models=MLJ.name.(three_models), mean_square_errors=errs) |> pretty - - -``` See also [`NeuralNetworkRegressor`](@ref) """