Skip to content

Commit

Permalink
Merge pull request #4 from mrc-ide/mrc-5220
Browse files Browse the repository at this point in the history
Parse basic expressions
  • Loading branch information
richfitz authored Jul 16, 2024
2 parents 232910a + 8848311 commit 86ebfea
Show file tree
Hide file tree
Showing 12 changed files with 646 additions and 1 deletion.
5 changes: 4 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,16 @@ Description: Temporary package for rewriting odin.
License: MIT + file LICENSE
Encoding: UTF-8
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.1.1
RoxygenNote: 7.3.2
URL: https://github.com/mrc-ide/odin2
BugReports: https://github.com/mrc-ide/odin2/issues
Imports:
cli,
mcstate2,
rlang
Suggests:
testthat (>= 3.0.0),
withr
Config/testthat/edition: 3
Remotes:
mrc-ide/mcstate2
3 changes: 3 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1 +1,4 @@
# Generated by roxygen2: do not edit by hand

S3method(cnd_footer,odin_parse_error)
importFrom(rlang,cnd_footer)
5 changes: 5 additions & 0 deletions R/constants.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
SPECIAL_LHS <- c(
"initial", "deriv", "update", "output", "dim", "config", "compare")

FUNCTIONS <- list(
exp = function(x) {})
25 changes: 25 additions & 0 deletions R/dependencies.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
find_dependencies <- function(expr) {
functions <- collector()
variables <- collector()
descend <- function(e) {
if (is.recursive(e)) {
nm <- deparse(e[[1L]])
## If we hit dim/length here we should not take a dependency
## most of the time though.
functions$add(nm)
for (el in as.list(e[-1])) {
if (!missing(el)) {
descend(el)
}
}
} else {
if (is.symbol(e)) {
variables$add(deparse(e))
}
}
}
descend(expr)
list(functions = unique(functions$get()),
variables = unique(variables$get()))

}
1 change: 1 addition & 0 deletions R/parse.R
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
odin_parse <- function(expr, input_type = NULL) {
call <- environment()
dat <- parse_prepare(rlang::enquo(expr), input_type, call)
exprs <- lapply(dat$exprs, function(x) parse_expr(x$value, x, call = call))
NULL
}
30 changes: 30 additions & 0 deletions R/parse_error.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
odin_parse_error <- function(msg, src, call, ...,
.envir = parent.frame()) {
if (!is.null(names(src))) {
src <- list(src)
}
cli::cli_abort(msg,
class = "odin_parse_error",
src = src,
call = call,
...,
.envir = .envir)
}


##' @importFrom rlang cnd_footer
##' @export
cnd_footer.odin_parse_error <- function(cnd, ...) {
## TODO: later, we might want to point at specific bits of the error
## and say "here, this is where you are wrong" but that's not done
## yet...
src <- cnd$src
if (is.null(src[[1]]$str)) {
context <- unlist(lapply(cnd$src, function(x) deparse1(x$value)))
} else {
line <- unlist(lapply(src, function(x) seq(x$start, x$end)))
src <- unlist(lapply(src, "[[", "str"))
context <- sprintf("%s| %s", gsub(" ", "\u00a0", format(line)), src)
}
c(">" = "Context:", context)
}
233 changes: 233 additions & 0 deletions R/parse_expr.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,233 @@
parse_expr <- function(expr, src, call) {
if (rlang::is_call(expr, c("<-", "="))) {
parse_expr_assignment(expr, src, call)
} else if (rlang::is_call(expr, "~")) {
parse_expr_compare(expr, src, call)
} else if (rlang::is_call(expr, "print")) {
parse_expr_print(expr, src, call)
} else {
odin_parse_error(
c("Unclassifiable expression",
i = "Expected an assignment (with '<-') or a relationship (with '~')"),
src, call)
}
}


parse_expr_assignment <- function(expr, src, call) {
lhs <- parse_expr_assignment_lhs(expr[[2]], src, call)
rhs <- parse_expr_assignment_rhs(expr[[3]], src, call)

special <- lhs$special
lhs$special <- NULL
if (rhs$type == "data") {
if (!is.null(special)) {
odin_parse_error(
"Calls to 'data()' must be assigned to a symbol",
src, call)
}
special <- "data"
}

list(special = special,
lhs = lhs,
rhs = rhs,
src = src)
}


parse_expr_assignment_lhs <- function(lhs, src, call) {
array <- NULL
special <- NULL
name <- NULL

if (rlang::is_call(lhs, SPECIAL_LHS)) {
special <- deparse1(lhs[[1]])
if (length(lhs) != 2 || !is.null(names(lhs))) {
odin_parse_error(
c("Invalid special function call",
i = "Expected a single unnamed argument to '{special}()'"),
src, call)
}
if (special == "compare") {
## TODO: a good candidate for pointing at the source location of
## the error.
odin_parse_error(
c("'compare()' expressions must use '~', not '<-'",
i = paste("Compare expressions do not represent assignents, but",
"relationships, which we emphasise by using '~'. This",
"also keeps the syntax close to that for the prior",
"specification in mcstate2")),
src, call)
}
lhs <- lhs[[2]]
}

is_array <- rlang::is_call(lhs, "[")
if (is_array) {
odin_parse_error(
"Arrays are not supported yet", src, call)
}

name <- parse_expr_check_lhs_name(lhs, src, call)

lhs <- list(
name = name,
special = special)
}


parse_expr_assignment_rhs <- function(rhs, src, call) {
if (rlang::is_call(rhs, "delay")) {
odin_parse_error("'delay()' is not implemented yet", src, call)
} else if (rlang::is_call(rhs, "parameter")) {
parse_expr_assignment_rhs_parameter(rhs, src, call)
} else if (rlang::is_call(rhs, "data")) {
parse_expr_assignment_rhs_data(rhs, src, call)
} else if (rlang::is_call(rhs, "interpolate")) {
odin_parse_error("'interpolate()' is not implemented yet", src, call)
} else {
parse_expr_assignment_rhs_expression(rhs, src, call)
}
}


parse_expr_assignment_rhs_expression <- function(rhs, src, call) {
## So here, we want to find dependencies used in the rhs and make
## sure that the user correctly restricts to the right set of
## dependencies. There is some faff with sum, and we detect here if
## the user uses anything stochastic. We do look for the range
## operator but I'm not totlaly sure that's the best place to do so.
depends <- find_dependencies(rhs)
list(type = "expression",
expr = rhs,
depends = depends)
}


parse_expr_check_lhs_name <- function(lhs, src, call) {
## There are lots of checks we should add here, but fundamentally
## it's a case of making sure that we have been given a symbol and
## that symbol is not anything reserved, nor does it start with
## anything reserved. Add these in later, see
## "ir_parse_expr_check_lhs_name" for details.
if (!rlang::is_symbol(lhs)) {
odin_parse_error("Expected a symbol on the lhs", src, call)
}
name <- deparse1(lhs)
name
}


## TODO: we'll have a variant of this that acts as a compatibility
## layer for user(), as this is probably the biggest required change
## to people's code, really.
parse_expr_assignment_rhs_parameter <- function(rhs, src, call) {
template <- function(default = NULL, constant = NULL, differentiate = FALSE) {
}
result <- match_call(rhs, template)
if (!result$success) {
odin_parse_error(
c("Invalid call to 'parameter()'",
x = conditionMessage(result$error)),
src, call)
}
args <- as.list(result$value)[-1]
if (is.language(args$default)) {
deps <- find_dependencies(args$default)
if (length(deps$variables) > 0) {
default_str <- deparse1(args$default)
odin_parse_error(
c("Invalid default argument to 'parameter()': {default_str}",
i = paste("Default arguments can only perform basic arithmetic",
"operations on numbers, and may not reference any",
"other parameter or variable")),
src, call)
}
## TODO: validate the functions used at some point, once we do
## that generally.
}

if (!is_scalar_logical(args$differentiate)) {
str <- deparse1(args$differentiate)
odin_parse_error(
"'differentiate' must be a scalar logical, but was '{str}'",
src, call)
}
## constant has a different default
if (is.null(args$constant)) {
args$constant <- NA
} else if (!is_scalar_logical(args$constant)) {
str <- deparse1(args$constant)
odin_parse_error(
"'constant' must be a scalar logical if given, but was '{str}'",
src, call)
}
list(type = "parameter",
args = args)
}


parse_expr_assignment_rhs_data <- function(rhs, src, call) {
if (length(rhs) != 1) {
odin_parse_error("Calls to 'data()' must have no arguments",
src, call)
}
list(type = "data")
}


parse_expr_compare <- function(expr, src, call) {
lhs <- parse_expr_compare_lhs(expr[[2]], src, call)
rhs <- parse_expr_compare_rhs(expr[[3]], src, call)

## Quickly rewrite the expression, at least for now:
rhs$expr <- as.call(c(list(rhs$expr[[1]], lhs),
as.list(rhs$expr[-1])))
rhs$depends$variables <- union(rhs$depends$variables,
as.character(lhs))
list(special = "compare",
rhs = rhs,
src = src)
}


parse_expr_compare_lhs <- function(lhs, src, call) {
if (!rlang::is_call(lhs, "compare")) {
## TODO: this is a good candidate for pointing at the assignment
## symbol in the error message, if we have access to the source,
## as that's the most likely fix.
odin_parse_error(
c("Expected the lhs of '~' to be a 'compare()' call",
i = "Did you mean to use '<-' in place of '~'?"),
src, call)
}
lhs <- lhs[[2]]
if (!is.symbol(lhs)) {
odin_parse_error(
"Expected the argument of 'compare()' to be a symbol",
src, call)
}
lhs
}


parse_expr_compare_rhs <- function(rhs, src, call) {
result <- mcstate2::mcstate_dsl_parse_distribution(rhs, "The rhs of '~'")
if (!result$success) {
odin_parse_error(
result$error,
src, call)
}
depends <- find_dependencies(rhs)
list(type = "compare",
distribution = result$value$cpp$density,
args = result$value$args,
depends = depends)
}


parse_expr_print <- function(expr, src, call) {
odin_parse_error(
"'print()' is not implemented yet", src, call)
}
34 changes: 34 additions & 0 deletions R/util.R
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,37 @@ match_value <- function(x, choices, name = deparse(substitute(x)), arg = name,
}
x
}


match_call <- function(call, fn) {
## We'll probably expand on the error case here to return something
## much nicer?

## TODO: it would be great to totally prevent partial matching here.
## The warning emitted by R is not easily caught (no special class
## for example) and neither match.call nor the rlang wrapper provide
## a hook here to really pick this up. We can look for expanded
## names in the results, though that's not super obvious either
## since we're also filling them in and reordering.
tryCatch(
list(success = TRUE,
value = rlang::call_match(call, fn, defaults = TRUE)),
error = function(e) {
list(success = FALSE,
error = e)
})
}


is_scalar_logical <- function(x) {
is.logical(x) && length(x) == 1 && !is.na(x)
}


collector <- function(init = character(0)) {
env <- new.env(parent = emptyenv())
env$res <- init
list(
add = function(x) env$res <- c(env$res, x),
get = function() env$res)
}
Loading

0 comments on commit 86ebfea

Please sign in to comment.