Skip to content

Commit

Permalink
Merge pull request #21 from JuliaAI/measure
Browse files Browse the repository at this point in the history
Address breaking changes in MLJBase 1.0
  • Loading branch information
ablaom authored Sep 27, 2023
2 parents bed1dfc + 76ad34f commit e6a02cc
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 9 deletions.
7 changes: 4 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLJFlow"
uuid = "7b7b8358-b45c-48ea-a8ef-7ca328ad328f"
authors = ["Jose Esparza <[email protected]>"]
version = "0.1.1"
version = "0.2.0"

[deps]
MLFlowClient = "64a0f543-368b-4a9a-827a-e71edb2a0b83"
Expand All @@ -10,15 +10,16 @@ MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"

[compat]
MLFlowClient = "0.4.4"
MLJBase = "0.21.14"
MLJBase = "1"
MLJModelInterface = "1.9.1"
julia = "1.6"

[extras]
MLFlowClient = "64a0f543-368b-4a9a-827a-e71edb2a0b83"
MLJDecisionTreeInterface = "c6f25543-311c-4c74-83dc-3ea6d1015661"
MLJModels = "d491faf4-2d78-11e9-2867-c94bc002c0b7"
StatisticalMeasures = "a19d573c-0a75-4610-95b3-7071388c7541"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "MLFlowClient", "MLJModels", "MLJDecisionTreeInterface"]
test = ["Test", "MLFlowClient", "MLJModels", "MLJDecisionTreeInterface", "StatisticalMeasures"]
3 changes: 1 addition & 2 deletions src/MLJFlow.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
module MLJFlow

using MLJBase: info, name, Model,
Machine
using MLJBase: Model, Machine, name
using MLJModelInterface: flat_params
using MLFlowClient: MLFlow, logparam, logmetric,
createrun, MLFlowRun, updaterun,
Expand Down
5 changes: 4 additions & 1 deletion src/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@ function log_evaluation(logger::MLFlowLogger, performance_evaluation)
artifact_location=logger.artifact_location)
run = createrun(logger.service, experiment;
tags=[
Dict("key" => "resampling", "value" => string(performance_evaluation.resampling)),
Dict(
"key" => "resampling",
"value" => string(performance_evaluation.resampling)
),
Dict("key" => "repeats", "value" => string(performance_evaluation.repeats)),
Dict("key" => "model type", "value" => name(performance_evaluation.model)),
]
Expand Down
43 changes: 40 additions & 3 deletions src/service.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,57 @@ function logmodelparams(service::MLFlow, run::MLFlowRun, model::Model)
end
end

const MLFLOW_CHAR_SET =
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789_-. /"

"""
good_name(measure)
**Private method.**
Returns a string representation of `measure` that can be used as a valid name in
MLflow. Includes the value of the first hyperparameter, if there is one.
```julia
julia> good_name(macro_f1score)
"MulticlassFScore-beta_1.0"
"""
function good_name(measure)
name = string(measure)
name = replace(name, ", …" => "")
name = replace(name, " = " => "_")
name = replace(name, "()" => "")
name = replace(name, ")" => "")
map(collect(name)) do char
char in ['(', ','] && return '-'
char == '=' && return '_'
char in MLFLOW_CHAR_SET && return char
" "
end |> join
end

"""
logmachinemeasures(service::MLFlow, run::MLFlowRun, model::Model)
Extracts the parameters of a model and logs them to the MLFlow server.
# Arguments
- `service::MLFlow`: An MLFlow service. See [MLFlowClient.jl](https://juliaai.github.io/MLFlowClient.jl/dev/reference/#MLFlowClient.MLFlow)
- `run::MLFlowRun`: An MLFlow run. See [MLFlowClient.jl](https://juliaai.github.io/MLFlowClient.jl/dev/reference/#MLFlowClient.MLFlowRun)
- `service::MLFlow`: An MLFlow service. See
[MLFlowClient.jl](https://juliaai.github.io/MLFlowClient.jl/dev/reference/#MLFlowClient.MLFlow)
- `run::MLFlowRun`: An MLFlow run. See
[MLFlowClient.jl](https://juliaai.github.io/MLFlowClient.jl/dev/reference/#MLFlowClient.MLFlowRun)
- `measures`: A vector of measures.
- `measurements`: A vector of measurements.
"""
function logmachinemeasures(service::MLFlow, run::MLFlowRun, measures,
measurements)
measure_names = measures .|> info .|> x -> x.name
measure_names = measures .|> good_name
for (name, value) in zip(measure_names, measurements)
logmetric(service, run, name, value)
end
Expand Down
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ using MLJBase
using MLJModels
using MLFlowClient
using MLJModelInterface
using StatisticalMeasures

include("base.jl")
include("types.jl")
include("service.jl")

7 changes: 7 additions & 0 deletions test/service.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
@testset "good_name" begin
@test MLJFlow.good_name(rms) == "RootMeanSquaredError"
@test MLJFlow.good_name(macro_f1score) == "MulticlassFScore-beta_1.0"
@test MLJFlow.good_name(log_score) == "LogScore-tol_2.22045e-16"
end

true

0 comments on commit e6a02cc

Please sign in to comment.