Skip to content

Commit

Permalink
add more parameters to control fitting, and add data checks (#24)
Browse files Browse the repository at this point in the history
* add more parameters to control fitting, and add data checks

* Apply suggestions from code review

Co-authored-by: Rik Huijzer <[email protected]>

Co-authored-by: Rik Huijzer <[email protected]>
  • Loading branch information
OkonSamuel and rikhuijzer authored Feb 22, 2022
1 parent 5f6db29 commit 35c2bc0
Show file tree
Hide file tree
Showing 2 changed files with 226 additions and 52 deletions.
195 changes: 156 additions & 39 deletions src/MLJGLMInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ end
fit_intercept::Bool = true
link::GLM.Link01 = GLM.LogitLink()
offsetcol::Union{Symbol, Nothing} = nothing
maxiter::Integer = 30
atol::Real = 1e-6
rtol::Real = 1e-6
minstepfac::Real = 0.001
report_keys::KEYS_TYPE = DEFAULT_KEYS::(isnothing(_) || issubset(_, VALID_KEYS))
end

Expand All @@ -68,6 +72,10 @@ end
distribution::Distribution = Poisson()
link::GLM.Link = GLM.LogLink()
offsetcol::Union{Symbol, Nothing} = nothing
maxiter::Integer = 30
atol::Real = 1e-6
rtol::Real = 1e-6
minstepfac::Real = 0.001
report_keys::KEYS_TYPE = DEFAULT_KEYS::(isnothing(_) || issubset(_, VALID_KEYS))
end

Expand All @@ -87,17 +95,19 @@ Augment the matrix `X` with a column of ones if the intercept is to be
fitted (`b=true`), return `X` otherwise.
"""
function augment_X(X::Matrix, b::Bool)::Matrix
    # Diff-merge fix: the page carried both the old `eltype(X)` line and the
    # new `float(Int)` line; only the post-commit version is kept, so the
    # appended intercept column is always Float64 regardless of `eltype(X)`.
    b && return hcat(X, ones(float(Int), size(X, 1), 1))
    return X
end

# Materialize a generic iterable as a `Vector`; pass `Vector`s through as-is.
_to_vector(v::Vector) = v
_to_vector(v) = collect(v)
# Same idea for arrays: keep any `AbstractArray`, `collect` everything else.
_to_array(v::AbstractArray) = v
_to_array(v) = collect(v)

"""
split_X_offset(X, offsetcol::Nothing)
When no offset is specied, return X and an empty vector.
When no offset is specified, return `X` and an empty vector.
"""
split_X_offset(X, offsetcol::Nothing) = (X, Float64[])

Expand All @@ -115,19 +125,125 @@ function split_X_offset(X, offsetcol::Symbol)
return newX, _to_vector(offset)
end

# Whether the model estimates its dispersion parameter from the data.
# If `estimates_dispersion_param` returns `false` then the dispersion
# parameter isn't estimated from data but known apriori to be `1`.
estimates_dispersion_param(::LinearRegressor) = true
estimates_dispersion_param(::LinearBinaryClassifier) = false

# For count models this depends on the chosen distribution: defer to
# `GLM.dispersion_parameter` to decide whether it has a free dispersion.
function estimates_dispersion_param(model::LinearCountRegressor)
return GLM.dispersion_parameter(model.distribution)
end

"""
    _throw_sample_size_error(model, est_dispersion_param)

Throw an `ArgumentError` spelling out the minimum sample-size requirement
for `model`, as rendered by `_requires_info`.
"""
function _throw_sample_size_error(model, est_dispersion_param)
    requires_info = _requires_info(model, est_dispersion_param)

    # The bound depends on whether an offset column is in use; say which.
    offset_info = isnothing(model.offsetcol) ?
        " `offsetcol == nothing`" :
        " `offsetcol !== nothing`"

    modelname = nameof(typeof(model))
    # Only count regressors report their distribution; for the other model
    # kinds "\b" (backspace) swallows the extra space from the interpolation.
    if model isa LinearCountRegressor
        distribution_info = "and `distribution = $(nameof(typeof(model.distribution)))()`"
    else
        distribution_info = "\b"
    end

    message = " `$(modelname)` with `fit_intercept = $(model.fit_intercept)`,"*
        "$(offset_info) $(distribution_info) requires $(requires_info)"
    throw(ArgumentError(message))
    return nothing
end

"""
_requires_info(model, est_dispersion_param)
Returns one of the following strings
- "`n_samples >= n_features`", "`n_samples > n_features`"
- "`n_samples >= n_features - 1`", "`n_samples > n_features - 1`"
- "`n_samples >= n_features + 1`", "`n_samples > n_features + 1`"
"""
function _requires_info(model, est_dispersion_param)
inequality = est_dispersion_param ? ">" : ">="
int_num = model.fit_intercept - !isnothing(model.offsetcol)

if iszero(int_num)
int_num_string = "\b"
elseif int_num < 0
int_num_string = "- $(abs(int_num))"
else
int_num_string = "+ $(int_num)"
end

return "`n_samples $(inequality) n_features $(int_num_string)`."
end

"""
    check_sample_size(model, n, p)

Validate that `n` observations suffice to fit `model` on `p` features,
delegating to `_throw_sample_size_error` when they do not.
"""
function check_sample_size(model, n, p)
    # Models estimating a dispersion parameter need one extra observation,
    # hence the strict comparison in that branch.
    est_dispersion = estimates_dispersion_param(model)
    bound = p + model.fit_intercept
    too_small = est_dispersion ? n <= bound : n < bound
    too_small && _throw_sample_size_error(model, est_dispersion)
    return nothing
end

"""
    _matrix_and_features(model, Xcols, handle_intercept=false)

Turn the Tables.jl columns object `Xcols` into a design matrix (appending a
ones column for the intercept when requested) and return it together with
the column names.
"""
function _matrix_and_features(model, Xcols, handle_intercept=false)
    col_names = Tables.columnnames(Xcols)
    p = length(col_names)
    n = Tables.rowcount(Xcols)

    # Sample-size validation only applies at fit time; prediction calls pass
    # `handle_intercept=true` and skip it.
    handle_intercept || check_sample_size(model, n, p)

    # An empty table still needs a correctly shaped (n × 0) float matrix.
    Xmatrix = p == 0 ? Matrix{float(Int)}(undef, n, p) : Tables.matrix(Xcols)
    Xmatrix = augment_X(Xmatrix, handle_intercept && model.fit_intercept)

    return Xmatrix, col_names
end

# Normalize any Tables.jl-compatible source to a columns object; values that
# already satisfy `Tables.AbstractColumns` pass through untouched.
_to_columns(t::Tables.AbstractColumns) = t
_to_columns(t) = Tables.Columns(t)

"""
prepare_inputs(model, X; handle_intercept=false)
Handle `model.offsetcol` and `model.fit_intercept` if `handle_intercept=true`.
`handle_intercept` is disabled for fitting since the StatsModels.@formula handles the intercept.
"""
function prepare_inputs(model, X; handle_intercept=false)
    # Diff-merge fix: the page interleaved the pre-commit body (immediate
    # `split_X_offset` on `X`, early `return Xmatrix, offset`) with the
    # post-commit body; only the post-commit flow is kept below.
    Xcols = _to_columns(X)
    table_features = Tables.columnnames(Xcols)
    p = length(table_features)
    p >= 1 || throw(
        ArgumentError("`X` must contain at least one feature column.")
    )
    if !isnothing(model.offsetcol)
        # The declared offset column must actually exist in the table.
        # (Message also gains the closing backtick missing in the original.)
        model.offsetcol in table_features || throw(
            ArgumentError("offset column `$(model.offsetcol)` not found in table `X`.")
        )
        # Removing the offset column must leave at least one predictor when
        # no intercept is fitted.
        if p < 2 && !model.fit_intercept
            throw(
                ArgumentError(
                    "At least 2 feature columns are required for learning with"*
                    " `offsetcol !== nothing` and `fit_intercept == false`."
                )
            )
        end
    end
    Xminoffset, offset = split_X_offset(Xcols, model.offsetcol)
    Xminoffset_cols = _to_columns(Xminoffset)
    Xmatrix, features = _matrix_and_features(model, Xminoffset_cols, handle_intercept)
    return Xmatrix, offset, _to_array(features)
end

"""
Expand Down Expand Up @@ -170,7 +286,6 @@ function glm_report(glm_model, features, reportkeys)
return NamedTuple{Tuple(keys(report_dict))}(values(report_dict))
end


"""
glm_formula(model, features) -> FormulaTerm
Expand All @@ -191,31 +306,10 @@ end
Return data which is ready to be passed to `fit(form, data, ...)`.
"""
function glm_data(model, Xmatrix, y, features)
    # Diff-merge fix: the page retained the stale `header = collect(features)`
    # line and the old `Tables.table` call alongside the new one; only the
    # post-commit statement is kept. The response is stored under `:y`.
    data = Tables.table([Xmatrix y]; header=[features...; :y])
    return data
end

# NOTE(review): these duplicate the `_to_array` methods defined earlier in
# this file; Julia treats identical redefinition as harmless, but one copy
# should be removed.
_to_array(v::AbstractArray) = v
_to_array(v) = collect(v)

"""
glm_features(model, X)
Returns an iterable features object, to be used in the construction of
glm formula and glm data header.
"""
function glm_features(model, X)
if Tables.columnaccess(X)
table_features = _to_array(keys(Tables.columns(X)))
else
first_row = iterate(Tables.rows(X), 1)[1]
table_features = first_row === nothing ? Symbol[] : _to_array(keys(first_row))
end
filter!(!=(model.offsetcol), table_features)
return table_features
end

"""
check_weights(w, y)
Expand Down Expand Up @@ -259,8 +353,7 @@ params(fr::FitResult) = fr.params

function MMI.fit(model::LinearRegressor, verbosity::Int, X, y, w=nothing)
# apply the model
Xmatrix, offset = prepare_inputs(model, X)
features = glm_features(model, X)
Xmatrix, offset, features = prepare_inputs(model, X)
y_ = isempty(offset) ? y : y .- offset
wts = check_weights(w, y_)
data = glm_data(model, Xmatrix, y_, features)
Expand All @@ -278,12 +371,20 @@ end

function MMI.fit(model::LinearCountRegressor, verbosity::Int, X, y, w=nothing)
# apply the model
Xmatrix, offset = prepare_inputs(model, X)
features = glm_features(model, X)
Xmatrix, offset, features = prepare_inputs(model, X)
data = glm_data(model, Xmatrix, y, features)
wts = check_weights(w, y)
form = glm_formula(model, features)
fitted_glm = GLM.glm(form, data, model.distribution, model.link; offset, wts).model
fitted_glm_frame = GLM.glm(
form, data, model.distribution, model.link;
offset,
model.maxiter,
model.atol,
model.rtol,
model.minstepfac,
wts
)
fitted_glm = fitted_glm_frame.model
fitresult = FitResult(
GLM.coef(fitted_glm), GLM.dispersion(fitted_glm), (features = features,)
)
Expand All @@ -299,11 +400,19 @@ function MMI.fit(model::LinearBinaryClassifier, verbosity::Int, X, y, w=nothing)
decode = y[1]
y_plain = MMI.int(y) .- 1 # 0, 1 of type Int
wts = check_weights(w, y_plain)
Xmatrix, offset = prepare_inputs(model, X)
features = glm_features(model, X)
Xmatrix, offset, features = prepare_inputs(model, X)
data = glm_data(model, Xmatrix, y_plain, features)
form = glm_formula(model, features)
fitted_glm = GLM.glm(form, data, Bernoulli(), model.link; offset, wts).model
fitted_glm_frame = GLM.glm(
form, data, Bernoulli(), model.link;
offset,
model.maxiter,
model.atol,
model.rtol,
model.minstepfac,
wts
)
fitted_glm = fitted_glm_frame.model
fitresult = FitResult(
GLM.coef(fitted_glm), GLM.dispersion(fitted_glm), (features = features,)
)
Expand Down Expand Up @@ -342,9 +451,17 @@ glm_link(::LinearRegressor) = GLM.IdentityLink()

# more efficient than MLJBase fallback
function MMI.predict_mean(model::GLM_MODELS, fitresult, Xnew)
    # Diff-merge fix: the page carried both the old 2-tuple and the new
    # 3-tuple destructuring of `prepare_inputs`; only the post-commit,
    # 3-return-value form is kept. `handle_intercept=true` appends the
    # intercept column here since no fitting formula is involved.
    Xmatrix, offset, _ = prepare_inputs(model, Xnew; handle_intercept=true)
    result = glm_fitresult(model, fitresult) # ::FitResult
    coef = coefs(result)
    # Guard against a train/predict feature-count mismatch before the
    # linear-predictor computation inside `glm_predict`.
    p = size(Xmatrix, 2)
    if p != length(coef)
        throw(
            DimensionMismatch(
                "The number of features in training and prediction datasets must be equal"
            )
        )
    end
    link = glm_link(model)
    return glm_predict(link, coef, Xmatrix, model.offsetcol, offset)
end
Expand Down
Loading

0 comments on commit 35c2bc0

Please sign in to comment.