-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #4 from mrc-ide/mrc-5220
Parse basic expressions
- Loading branch information
Showing
12 changed files
with
646 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,4 @@ | ||
# Generated by roxygen2: do not edit by hand | ||
|
||
S3method(cnd_footer,odin_parse_error) | ||
importFrom(rlang,cnd_footer) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
SPECIAL_LHS <- c( | ||
"initial", "deriv", "update", "output", "dim", "config", "compare") | ||
|
||
FUNCTIONS <- list( | ||
exp = function(x) {}) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
find_dependencies <- function(expr) { | ||
functions <- collector() | ||
variables <- collector() | ||
descend <- function(e) { | ||
if (is.recursive(e)) { | ||
nm <- deparse(e[[1L]]) | ||
## If we hit dim/length here we should not take a dependency | ||
## most of the time though. | ||
functions$add(nm) | ||
for (el in as.list(e[-1])) { | ||
if (!missing(el)) { | ||
descend(el) | ||
} | ||
} | ||
} else { | ||
if (is.symbol(e)) { | ||
variables$add(deparse(e)) | ||
} | ||
} | ||
} | ||
descend(expr) | ||
list(functions = unique(functions$get()), | ||
variables = unique(variables$get())) | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
odin_parse <- function(expr, input_type = NULL) { | ||
call <- environment() | ||
dat <- parse_prepare(rlang::enquo(expr), input_type, call) | ||
exprs <- lapply(dat$exprs, function(x) parse_expr(x$value, x, call = call)) | ||
NULL | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
odin_parse_error <- function(msg, src, call, ..., | ||
.envir = parent.frame()) { | ||
if (!is.null(names(src))) { | ||
src <- list(src) | ||
} | ||
cli::cli_abort(msg, | ||
class = "odin_parse_error", | ||
src = src, | ||
call = call, | ||
..., | ||
.envir = .envir) | ||
} | ||
|
||
|
||
##' @importFrom rlang cnd_footer | ||
##' @export | ||
cnd_footer.odin_parse_error <- function(cnd, ...) { | ||
## TODO: later, we might want to point at specific bits of the error | ||
## and say "here, this is where you are wrong" but that's not done | ||
## yet... | ||
src <- cnd$src | ||
if (is.null(src[[1]]$str)) { | ||
context <- unlist(lapply(cnd$src, function(x) deparse1(x$value))) | ||
} else { | ||
line <- unlist(lapply(src, function(x) seq(x$start, x$end))) | ||
src <- unlist(lapply(src, "[[", "str")) | ||
context <- sprintf("%s| %s", gsub(" ", "\u00a0", format(line)), src) | ||
} | ||
c(">" = "Context:", context) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,233 @@ | ||
parse_expr <- function(expr, src, call) { | ||
if (rlang::is_call(expr, c("<-", "="))) { | ||
parse_expr_assignment(expr, src, call) | ||
} else if (rlang::is_call(expr, "~")) { | ||
parse_expr_compare(expr, src, call) | ||
} else if (rlang::is_call(expr, "print")) { | ||
parse_expr_print(expr, src, call) | ||
} else { | ||
odin_parse_error( | ||
c("Unclassifiable expression", | ||
i = "Expected an assignment (with '<-') or a relationship (with '~')"), | ||
src, call) | ||
} | ||
} | ||
|
||
|
||
parse_expr_assignment <- function(expr, src, call) { | ||
lhs <- parse_expr_assignment_lhs(expr[[2]], src, call) | ||
rhs <- parse_expr_assignment_rhs(expr[[3]], src, call) | ||
|
||
special <- lhs$special | ||
lhs$special <- NULL | ||
if (rhs$type == "data") { | ||
if (!is.null(special)) { | ||
odin_parse_error( | ||
"Calls to 'data()' must be assigned to a symbol", | ||
src, call) | ||
} | ||
special <- "data" | ||
} | ||
|
||
list(special = special, | ||
lhs = lhs, | ||
rhs = rhs, | ||
src = src) | ||
} | ||
|
||
|
||
parse_expr_assignment_lhs <- function(lhs, src, call) { | ||
array <- NULL | ||
special <- NULL | ||
name <- NULL | ||
|
||
if (rlang::is_call(lhs, SPECIAL_LHS)) { | ||
special <- deparse1(lhs[[1]]) | ||
if (length(lhs) != 2 || !is.null(names(lhs))) { | ||
odin_parse_error( | ||
c("Invalid special function call", | ||
i = "Expected a single unnamed argument to '{special}()'"), | ||
src, call) | ||
} | ||
if (special == "compare") { | ||
## TODO: a good candidate for pointing at the source location of | ||
## the error. | ||
odin_parse_error( | ||
c("'compare()' expressions must use '~', not '<-'", | ||
i = paste("Compare expressions do not represent assignents, but", | ||
"relationships, which we emphasise by using '~'. This", | ||
"also keeps the syntax close to that for the prior", | ||
"specification in mcstate2")), | ||
src, call) | ||
} | ||
lhs <- lhs[[2]] | ||
} | ||
|
||
is_array <- rlang::is_call(lhs, "[") | ||
if (is_array) { | ||
odin_parse_error( | ||
"Arrays are not supported yet", src, call) | ||
} | ||
|
||
name <- parse_expr_check_lhs_name(lhs, src, call) | ||
|
||
lhs <- list( | ||
name = name, | ||
special = special) | ||
} | ||
|
||
|
||
parse_expr_assignment_rhs <- function(rhs, src, call) { | ||
if (rlang::is_call(rhs, "delay")) { | ||
odin_parse_error("'delay()' is not implemented yet", src, call) | ||
} else if (rlang::is_call(rhs, "parameter")) { | ||
parse_expr_assignment_rhs_parameter(rhs, src, call) | ||
} else if (rlang::is_call(rhs, "data")) { | ||
parse_expr_assignment_rhs_data(rhs, src, call) | ||
} else if (rlang::is_call(rhs, "interpolate")) { | ||
odin_parse_error("'interpolate()' is not implemented yet", src, call) | ||
} else { | ||
parse_expr_assignment_rhs_expression(rhs, src, call) | ||
} | ||
} | ||
|
||
|
||
parse_expr_assignment_rhs_expression <- function(rhs, src, call) { | ||
## So here, we want to find dependencies used in the rhs and make | ||
## sure that the user correctly restricts to the right set of | ||
## dependencies. There is some faff with sum, and we detect here if | ||
## the user uses anything stochastic. We do look for the range | ||
## operator but I'm not totlaly sure that's the best place to do so. | ||
depends <- find_dependencies(rhs) | ||
list(type = "expression", | ||
expr = rhs, | ||
depends = depends) | ||
} | ||
|
||
|
||
parse_expr_check_lhs_name <- function(lhs, src, call) { | ||
## There are lots of checks we should add here, but fundamentally | ||
## it's a case of making sure that we have been given a symbol and | ||
## that symbol is not anything reserved, nor does it start with | ||
## anything reserved. Add these in later, see | ||
## "ir_parse_expr_check_lhs_name" for details. | ||
if (!rlang::is_symbol(lhs)) { | ||
odin_parse_error("Expected a symbol on the lhs", src, call) | ||
} | ||
name <- deparse1(lhs) | ||
name | ||
} | ||
|
||
|
||
## TODO: we'll have a variant of this that acts as a compatibility | ||
## layer for user(), as this is probably the biggest required change | ||
## to people's code, really. | ||
parse_expr_assignment_rhs_parameter <- function(rhs, src, call) { | ||
template <- function(default = NULL, constant = NULL, differentiate = FALSE) { | ||
} | ||
result <- match_call(rhs, template) | ||
if (!result$success) { | ||
odin_parse_error( | ||
c("Invalid call to 'parameter()'", | ||
x = conditionMessage(result$error)), | ||
src, call) | ||
} | ||
args <- as.list(result$value)[-1] | ||
if (is.language(args$default)) { | ||
deps <- find_dependencies(args$default) | ||
if (length(deps$variables) > 0) { | ||
default_str <- deparse1(args$default) | ||
odin_parse_error( | ||
c("Invalid default argument to 'parameter()': {default_str}", | ||
i = paste("Default arguments can only perform basic arithmetic", | ||
"operations on numbers, and may not reference any", | ||
"other parameter or variable")), | ||
src, call) | ||
} | ||
## TODO: validate the functions used at some point, once we do | ||
## that generally. | ||
} | ||
|
||
if (!is_scalar_logical(args$differentiate)) { | ||
str <- deparse1(args$differentiate) | ||
odin_parse_error( | ||
"'differentiate' must be a scalar logical, but was '{str}'", | ||
src, call) | ||
} | ||
## constant has a different default | ||
if (is.null(args$constant)) { | ||
args$constant <- NA | ||
} else if (!is_scalar_logical(args$constant)) { | ||
str <- deparse1(args$constant) | ||
odin_parse_error( | ||
"'constant' must be a scalar logical if given, but was '{str}'", | ||
src, call) | ||
} | ||
list(type = "parameter", | ||
args = args) | ||
} | ||
|
||
|
||
parse_expr_assignment_rhs_data <- function(rhs, src, call) { | ||
if (length(rhs) != 1) { | ||
odin_parse_error("Calls to 'data()' must have no arguments", | ||
src, call) | ||
} | ||
list(type = "data") | ||
} | ||
|
||
|
||
parse_expr_compare <- function(expr, src, call) { | ||
lhs <- parse_expr_compare_lhs(expr[[2]], src, call) | ||
rhs <- parse_expr_compare_rhs(expr[[3]], src, call) | ||
|
||
## Quickly rewrite the expression, at least for now: | ||
rhs$expr <- as.call(c(list(rhs$expr[[1]], lhs), | ||
as.list(rhs$expr[-1]))) | ||
rhs$depends$variables <- union(rhs$depends$variables, | ||
as.character(lhs)) | ||
list(special = "compare", | ||
rhs = rhs, | ||
src = src) | ||
} | ||
|
||
|
||
parse_expr_compare_lhs <- function(lhs, src, call) { | ||
if (!rlang::is_call(lhs, "compare")) { | ||
## TODO: this is a good candidate for pointing at the assignment | ||
## symbol in the error message, if we have access to the source, | ||
## as that's the most likely fix. | ||
odin_parse_error( | ||
c("Expected the lhs of '~' to be a 'compare()' call", | ||
i = "Did you mean to use '<-' in place of '~'?"), | ||
src, call) | ||
} | ||
lhs <- lhs[[2]] | ||
if (!is.symbol(lhs)) { | ||
odin_parse_error( | ||
"Expected the argument of 'compare()' to be a symbol", | ||
src, call) | ||
} | ||
lhs | ||
} | ||
|
||
|
||
parse_expr_compare_rhs <- function(rhs, src, call) { | ||
result <- mcstate2::mcstate_dsl_parse_distribution(rhs, "The rhs of '~'") | ||
if (!result$success) { | ||
odin_parse_error( | ||
result$error, | ||
src, call) | ||
} | ||
depends <- find_dependencies(rhs) | ||
list(type = "compare", | ||
distribution = result$value$cpp$density, | ||
args = result$value$args, | ||
depends = depends) | ||
} | ||
|
||
|
||
parse_expr_print <- function(expr, src, call) { | ||
odin_parse_error( | ||
"'print()' is not implemented yet", src, call) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.