From 9543f237ef7e12ae8fd375e5b8c56651cff49c78 Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Fri, 12 Jul 2024 13:37:55 +0100 Subject: [PATCH 1/8] Basic parsing of expressions --- R/constants.R | 12 ++ R/dependencies.R | 25 +++ R/parse.R | 12 ++ R/parse_expr.R | 250 +++++++++++++++++++++++ R/util.R | 34 +++ tests/testthat/test-parse-expr-compare.R | 58 ++++++ tests/testthat/test-parse-expr.R | 187 +++++++++++++++++ tests/testthat/test-util.R | 22 ++ 8 files changed, 600 insertions(+) create mode 100644 R/constants.R create mode 100644 R/dependencies.R create mode 100644 R/parse_expr.R create mode 100644 tests/testthat/test-parse-expr-compare.R create mode 100644 tests/testthat/test-parse-expr.R diff --git a/R/constants.R b/R/constants.R new file mode 100644 index 0000000..b157b20 --- /dev/null +++ b/R/constants.R @@ -0,0 +1,12 @@ +SPECIAL_LHS <- c( + "initial", "deriv", "update", "output", "dim", "config", "compare") + +COMPARE <- list( + Normal = function(mean, sd) {}, + Poisson = function(lambda) {}) + +STOCHASTIC <- list( + Binomial = function(size, prob) {}) + +FUNCTIONS <- list( + exp = function(x) {}) diff --git a/R/dependencies.R b/R/dependencies.R new file mode 100644 index 0000000..d7db84a --- /dev/null +++ b/R/dependencies.R @@ -0,0 +1,25 @@ +find_dependencies <- function(expr) { + functions <- collector() + variables <- collector() + descend <- function(e) { + if (is.recursive(e)) { + nm <- deparse(e[[1L]]) + ## If we hit dim/length here we should not take a dependency + ## most of the time though. + functions$add(nm) + for (el in as.list(e[-1])) { + if (!missing(el)) { + descend(el) + } + } + } else { + if (is.symbol(e)) { + variables$add(deparse(e)) + } + } + } + descend(expr) + list(functions = unique(functions$get()), + variables = unique(variables$get())) + +} diff --git a/R/parse.R b/R/parse.R index 66eaf33..4cf6366 100644 --- a/R/parse.R +++ b/R/parse.R @@ -1,5 +1,17 @@ odin_parse <- function(expr, input_type = NULL) { call <- environment() dat <- parse_prepare(rlang::enquo(expr), input_type, call) + exprs <- lapply(dat$exprs, function(x) parse_expr(x$value, x, call = call)) NULL } + + +odin_parse_error <- function(msg, src, call, parent = NULL, + .envir = parent.frame()) { + cli::cli_abort(msg, + class = "odin_parse_error", + src = src, + call = call, + parent = parent, + .envir = .envir) +} diff --git a/R/parse_expr.R b/R/parse_expr.R new file mode 100644 index 0000000..5fa35ce --- /dev/null +++ b/R/parse_expr.R @@ -0,0 +1,250 @@ +parse_expr <- function(expr, src, call) { + if (rlang::is_call(expr, c("<-", "="))) { + parse_expr_assignment(expr, src, call) + } else if (rlang::is_call(expr, "~")) { + parse_expr_compare(expr, src, call) + } else if (rlang::is_call(expr, "print")) { + parse_expr_print(expr, src, call) + } else { + odin_parse_error( + c("Unclassifiable expression", + i = "Expected an assignment (with '<-') or a relationship (with '~')"), + src, call) + } +} + + +parse_expr_assignment <- function(expr, src, call) { + lhs <- parse_expr_assignment_lhs(expr[[2]], src, call) + rhs <- parse_expr_assignment_rhs(expr[[3]], src, call) + + special <- lhs$special + lhs$special <- NULL + if (rhs$type == "data") { + if (!is.null(special)) { + odin_parse_error( + "Calls to 'data()' must be assigned to a symbol", + src, call) + } + special <- "data" + } + + list(special = special, + lhs = lhs, + rhs = rhs, + src = src) +} + + +parse_expr_assignment_lhs <- function(lhs, src, call) { + array <- NULL + special <- NULL + name <- NULL + + if (rlang::is_call(lhs, SPECIAL_LHS)) { + special <- deparse1(lhs[[1]]) + if (length(lhs) != 2 || !is.null(names(lhs))) { + odin_parse_error( + c("Invalid special function call", + i = "Expected a single unnamed argument to '{special}()'"), + src, call) + } + if (special == "compare") { + ## TODO: a good candidate for pointing at the source location of + ## the error. + odin_parse_error( + c("'compare()' expressions must use '~', not '<-'", + i = paste("Compare expressions do not represent assignents, but", + "relationships, which we emphasise by using '~'. This", + "also keeps the syntax close to that for the prior", + "specification in mcstate2")), + src, call) + } + lhs <- lhs[[2]] + } + + is_array <- rlang::is_call(lhs, "[") + if (is_array) { + odin_parse_error( + "Arrays are not supported yet", src, call) + } + + name <- parse_expr_check_lhs_name(lhs, src, call) + + lhs <- list( + name = name, + special = special) +} + + +parse_expr_assignment_rhs <- function(rhs, src, call) { + if (rlang::is_call(rhs, "delay")) { + odin_parse_error("'delay()' is not implemented yet", src, call) + } else if (rlang::is_call(rhs, "parameter")) { + parse_expr_assignment_rhs_parameter(rhs, src, call) + } else if (rlang::is_call(rhs, "data")) { + parse_expr_assignment_rhs_data(rhs, src, call) + } else if (rlang::is_call(rhs, "interpolate")) { + odin_parse_error("'interpolate()' is not implemented yet", src, call) + } else { + parse_expr_assignment_rhs_expression(rhs, src, call) + } +} + + +parse_expr_assignment_rhs_expression <- function(rhs, src, call) { + ## So here, we want to find dependencies used in the rhs and make + ## sure that the user correctly restricts to the right set of + ## dependencies. There is some faff with sum, and we detect here if + ## the user uses anything stochastic. We do look for the range + ## operator but I'm not totlaly sure that's the best place to do so. + depends <- find_dependencies(rhs) + list(type = "expression", + expr = rhs, + depends = depends) +} + + +parse_expr_check_lhs_name <- function(lhs, src, call) { + ## There are lots of checks we should add here, but fundamentally + ## it's a case of making sure that we have been given a symbol and + ## that symbol is not anything reserved, nor does it start with + ## anything reserved. Add these in later, see + ## "ir_parse_expr_check_lhs_name" for details. + if (!rlang::is_symbol(lhs)) { + odin_parse_error("Expected a symbol on the lhs", src, call) + } + name <- deparse1(lhs) + name +} + + +## TODO: we'll have a variant of this that acts as a compatibility +## layer for user(), as this is probably the biggest required change +## to people's code, really. +parse_expr_assignment_rhs_parameter <- function(rhs, src, call) { + template <- function(default = NULL, constant = NULL, differentiate = FALSE) { + } + result <- match_call(rhs, template) + if (!result$success) { + ## I don't think this is quite correct really, and I'm not sure + ## the generated error is hugely informative for the user. + odin_parse_error("Invalid call to 'parameter()'", + src, call, parent = result$error) + } + args <- as.list(result$value)[-1] + if (is.language(args$default)) { + deps <- find_dependencies(args$default) + if (length(deps$variables) > 0) { + default_str <- deparse1(args$default) + odin_parse_error( + c("Invalid default argument to 'parameter()': {default_str}", + i = paste("Default arguments can only perform basic arithmetic", + "operations on numbers, and may not reference any", + "other parameter or variable")), + src, call) + } + ## TODO: validate the functions used at some point, once we do + ## that generally. + } + + if (!is_scalar_logical(args$differentiate)) { + str <- deparse1(args$differentiate) + odin_parse_error( + "'differentiate' must be a scalar logical, but was '{str}'", + src, call) + } + ## constant has a different default + if (is.null(args$constant)) { + args$constant <- NA + } else if (!is_scalar_logical(args$constant)) { + str <- deparse1(args$constant) + odin_parse_error( + "'constant' must be a scalar logical if given, but was '{str}'", + src, call) + } + list(type = "parameter", + args = args) +} + + +parse_expr_assignment_rhs_data <- function(rhs, src, call) { + if (length(rhs) != 1) { + odin_parse_error("Calls to 'data()' must have no arguments", + src, call) + } + list(type = "data") +} + + +parse_expr_compare <- function(expr, src, call) { + lhs <- parse_expr_compare_lhs(expr[[2]], src, call) + rhs <- parse_expr_compare_rhs(expr[[3]], src, call) + + ## Quickly rewrite the expression, at least for now: + rhs$expr <- as.call(c(list(rhs$expr[[1]], lhs), + as.list(rhs$expr[-1]))) + rhs$depends$variables <- union(rhs$depends$variables, + as.character(lhs)) + list(special = "compare", + rhs = rhs, + src = src) +} + + +parse_expr_compare_lhs <- function(lhs, src, call) { + if (!rlang::is_call(lhs, "compare")) { + ## TODO: this is a good candidate for pointing at the assignment + ## symbol in the error message, if we have access to the source, + ## as that's the most likely fix. + odin_parse_error( + c("Expected the lhs of '~' to be a 'compare()' call", + i = "Did you mean to use '<-' in place of '~'?"), + src, call) + } + lhs <- lhs[[2]] + if (!is.symbol(lhs)) { + odin_parse_error( + "Expected the argument of 'compare()' to be a symbol", + src, call) + } + lhs +} + + +## TODO: See mcstate2 with `match_call_candidate()` for doing this +## properly with choices; we may want to leverage some of the code +## there to keep the same semantics, especially once we start +## differentiating. +parse_expr_compare_rhs <- function(rhs, src, call) { + if (!rlang::is_call(rhs, names(COMPARE))) { + ## TODO: Add DYM support here, including incorrect cases, and + ## dnorm() etc. + odin_parse_error( + "Expected the rhs of '~' to be a call to a distribution function", + src, call) + } + nm <- as.character(rhs[[1]]) + ## TODO: we really need this to come from mcstate earlier rather + ## than later; add a small utility there which does all the work for + ## us, as it's important that we match this well. The same thing is + ## going to happen with the stochastic functions, and again, we're + ## using the actual support from mcstate again so this should be + ## very consistent. + result <- match_call(rhs, COMPARE[[nm]]) + if (!result$success) { + odin_parse_error( + "Invalid call to '{nm}()': {conditionMessage(result$error)}", + src, call) + } + depends <- find_dependencies(rhs) + list(type = "compare", + expr = result$value, + depends = depends) +} + + +parse_expr_print <- function(expr, src, call) { + odin_parse_error( + "'print()' is not implemented yet", src, call) +} diff --git a/R/util.R b/R/util.R index acb05e8..5b3d882 100644 --- a/R/util.R +++ b/R/util.R @@ -24,3 +24,37 @@ match_value <- function(x, choices, name = deparse(substitute(x)), arg = name, } x } + + +match_call <- function(call, fn) { + ## We'll probably expand on the error case here to return something + ## much nicer? + + ## TODO: it would be great to totally prevent partial matching here. + ## The warning emitted by R is not easily caught (no special class + ## for example) and neither match.call nor the rlang wrapper provide + ## a hook here to really pick this up. We can look for expanded + ## names in the results, though that's not super obvious either + ## since we're also filling them in and reordering. + tryCatch( + list(success = TRUE, + value = rlang::call_match(call, fn, defaults = TRUE)), + error = function(e) { + list(success = FALSE, + error = e) + }) +} + + +is_scalar_logical <- function(x) { + is.logical(x) && length(x) == 1 && !is.na(x) +} + + +collector <- function(init = character(0)) { + env <- new.env(parent = emptyenv()) + env$res <- init + list( + add = function(x) env$res <- c(env$res, x), + get = function() env$res) +} diff --git a/tests/testthat/test-parse-expr-compare.R b/tests/testthat/test-parse-expr-compare.R new file mode 100644 index 0000000..8e84f01 --- /dev/null +++ b/tests/testthat/test-parse-expr-compare.R @@ -0,0 +1,58 @@ +test_that("Can parse compare expression", { + res <- parse_expr(quote(compare(x) ~ Normal(0, 1)), NULL, NULL) + expect_equal(res$special, "compare") + expect_equal(res$rhs$type, "compare") + expect_equal(res$rhs$expr, quote(Normal(x, mean = 0, sd = 1))) + expect_equal(res$rhs$depends, + list(functions = "Normal", variables = "x")) +}) + + +test_that("compare expressions must use '~'", { + expect_error( + parse_expr(quote(compare(x) <- Normal(0, 1)), NULL, NULL), + "'compare()' expressions must use '~', not '<-'", + fixed = TRUE) +}) + + +test_that("only compare expressions may use '~'", { + expect_error( + parse_expr(quote(initial(x) ~ 1), NULL, NULL), + "Expected the lhs of '~' to be a 'compare()' call", + fixed = TRUE) + expect_error( + parse_expr(quote(x ~ 1), NULL, NULL), + "Expected the lhs of '~' to be a 'compare()' call", + fixed = TRUE) +}) + + +test_that("compare() calls must wrap symbols", { + expect_error( + parse_expr(quote(compare(1) ~ Normal(0, 1)), NULL, NULL), + "Expected the argument of 'compare()' to be a symbol", + fixed = TRUE) + expect_error( + parse_expr(quote(compare(f(x)) ~ Normal(0, 1)), NULL, NULL), + "Expected the argument of 'compare()' to be a symbol", + fixed = TRUE) + expect_error( + parse_expr(quote(compare(x[]) ~ Normal(0, 1)), NULL, NULL), + "Expected the argument of 'compare()' to be a symbol", + fixed = TRUE) +}) + + +test_that("parse compare call rhs as distributions", { + expect_error( + parse_expr(quote(compare(x) ~ 1), NULL, NULL), + "Expected the rhs of '~' to be a call to a distribution function") + expect_error( + parse_expr(quote(compare(x) ~ Foo(0, 1)), NULL, NULL), + "Expected the rhs of '~' to be a call to a distribution function") + expect_error( + parse_expr(quote(compare(x) ~ Normal(mu = 1)), NULL, NULL), + "Invalid call to 'Normal()'", + fixed = TRUE) +}) diff --git a/tests/testthat/test-parse-expr.R b/tests/testthat/test-parse-expr.R new file mode 100644 index 0000000..0b41877 --- /dev/null +++ b/tests/testthat/test-parse-expr.R @@ -0,0 +1,187 @@ +test_that("can parse simple assignments", { + res <- parse_expr(quote(a <- 1), NULL, NULL) + expect_equal(res$lhs$name, "a") + expect_equal(res$rhs$type, "expression") + expect_equal(res$rhs$expr, 1) + expect_equal(res$rhs$depends, + list(functions = character(), variables = character())) +}) + + +test_that("can parse simple expressions involving functions/variables", { + res <- parse_expr(quote(a <- b + c / b), NULL, NULL) + expect_equal(res$rhs$expr, quote(b + c / b)) + expect_equal(res$rhs$depends, + list(functions = c("+", "/"), variables = c("b", "c"))) +}) + + +test_that("require that assignment lhs is reasonable", { + expect_error( + parse_expr(quote(1 <- 1), NULL, NULL), + "Expected a symbol on the lhs") + expect_error( + parse_expr(quote(f(1) <- 1), NULL, NULL), + "Expected a symbol on the lhs") +}) + + +## Special calls are initial/deriv/update/dim/output/config/compare +test_that("allow calls on lhs", { + res <- parse_expr(quote(initial(x) <- 1), NULL, NULL) + expect_equal(res$lhs$name, "x") + expect_equal(res$special, "initial") + expect_equal(res$rhs$expr, 1) + expect_equal(parse_expr(quote(deriv(x) <- 1), NULL, NULL)$special, "deriv") + expect_equal(parse_expr(quote(update(x) <- 1), NULL, NULL)$special, "update") +}) + + +test_that("requre that special calls are (currently) simple", { + expect_error( + parse_expr(quote(initial(x, TRUE) <- 1), NULL, NULL), + "Invalid special function call") + expect_error( + parse_expr(quote(initial() <- 1), NULL, NULL), + "Invalid special function call") + expect_error( + parse_expr(quote(initial(x = 1) <- 1), NULL, NULL), + "Invalid special function call") +}) + + +test_that("can parse parameter definitions", { + res <- parse_expr(quote(a <- parameter()), NULL, NULL) + expect_equal(res$rhs$type, "parameter") + expect_null(res$rhs$args$default) + expect_equal(res$rhs$args$constant, NA) + expect_false(res$rhs$args$differentiate) +}) + + +test_that("can parse parameter definitions with defaults", { + res <- parse_expr(quote(a <- parameter(10)), NULL, NULL) + expect_equal(res$rhs$type, "parameter") + expect_equal(res$rhs$args$default, 10) + expect_equal(res$rhs$args$constant, NA) + expect_false(res$rhs$args$differentiate) +}) + + +test_that("can parse parameter definitions with expression defaults", { + res <- parse_expr(quote(a <- parameter(4 / 3)), NULL, NULL) + expect_equal(res$rhs$type, "parameter") + expect_equal(res$rhs$args$default, quote(4 / 3)) + expect_equal(res$rhs$args$constant, NA) + expect_false(res$rhs$args$differentiate) +}) + + +test_that("parameter defaults must be simple", { + expect_error( + parse_expr(quote(a <- parameter(a)), NULL, NULL), + "Invalid default argument to 'parameter()': a", + fixed = TRUE) +}) + + +test_that("validate differentiate argument", { + res <- parse_expr(quote(a <- parameter(differentiate = TRUE)), NULL, NULL) + expect_true(res$rhs$args$differentiate) + res <- parse_expr(quote(a <- parameter()), NULL, NULL) + expect_false(res$rhs$args$differentiate) + + expect_error( + parse_expr(quote(a <- parameter(differentiate = x)), NULL, NULL), + "'differentiate' must be a scalar logical, but was 'x'") + expect_error( + parse_expr(quote(a <- parameter(differentiate = NA)), NULL, NULL), + "'differentiate' must be a scalar logical, but was 'NA'") + expect_error( + parse_expr(quote(a <- parameter(differentiate = NULL)), NULL, NULL), + "'differentiate' must be a scalar logical, but was 'NULL'") +}) + + +test_that("validate constant argument", { + res <- parse_expr(quote(a <- parameter(constant = TRUE)), NULL, NULL) + expect_true(res$rhs$args$constant) + res <- parse_expr(quote(a <- parameter(constant = FALSE)), NULL, NULL) + expect_false(res$rhs$args$constant) + res <- parse_expr(quote(a <- parameter()), NULL, NULL) + expect_equal(res$rhs$args$constant, NA) + + expect_error( + parse_expr(quote(a <- parameter(constant = x)), NULL, NULL), + "'constant' must be a scalar logical if given, but was 'x'") + expect_error( + parse_expr(quote(a <- parameter(constant = NA)), NULL, NULL), + "'constant' must be a scalar logical if given, but was 'NA'") +}) + + +test_that("sensible error if parameters are incorrectly specified", { + expect_error( + parse_expr(quote(a <- parameter(other = TRUE)), NULL, NULL), + "Invalid call to 'parameter()'", + fixed = TRUE) +}) + + +test_that("parse data assignment", { + res <- parse_expr(quote(d <- data()), NULL, NULL) + expect_equal(res$special, "data") + expect_equal(res$lhs$name, "d") + expect_equal(res$rhs, list(type = "data")) +}) + + +test_that("data calls must be very simple", { + expect_error( + parse_expr(quote(d <- data(integer = TRUE)), NULL, NULL), + "Calls to 'data()' must have no arguments", + fixed = TRUE) + expect_error( + parse_expr(quote(deriv(d) <- data()), NULL, NULL), + "Calls to 'data()' must be assigned to a symbol", + fixed = TRUE) +}) + + +test_that("print not yet supported", { + expect_error( + parse_expr(quote(print(x)), NULL, NULL), + "'print()' is not implemented yet", + fixed = TRUE) +}) + + +test_that("delays not yet supported", { + expect_error( + parse_expr(quote(a <- delay(b, 1)), NULL, NULL), + "'delay()' is not implemented yet", + fixed = TRUE) +}) + + +test_that("delays not yet supported", { + expect_error( + parse_expr(quote(a <- interpolate(b, "constant")), NULL, NULL), + "'interpolate()' is not implemented yet", + fixed = TRUE) +}) + + +test_that("arrays not yet supported", { + expect_error( + parse_expr(quote(a[] <- 1), NULL, NULL), + "Arrays are not supported yet", + fixed = TRUE) +}) + + +test_that("Reject unclassifiable expressions", { + expect_error( + parse_expr(quote(a), NULL, NULL), + "Unclassifiable expression") +}) diff --git a/tests/testthat/test-util.R b/tests/testthat/test-util.R index 11588be..d14e37c 100644 --- a/tests/testthat/test-util.R +++ b/tests/testthat/test-util.R @@ -10,3 +10,25 @@ test_that("match_value", { expect_error(match_value("foo", letters), "must be one of") expect_silent(match_value("a", letters)) }) + + +test_that("can match a simple call", { + fn <- function(foo, bar = 1) NULL + expect_equal( + match_call(quote(f(a)), fn), + list(success = TRUE, value = quote(f(foo = a, bar = 1)))) + expect_equal( + match_call(quote(f(a, b)), fn), + list(success = TRUE, value = quote(f(foo = a, bar = b)))) + expect_equal( + match_call(quote(f(bar = 2, foo = 1)), fn), + list(success = TRUE, value = quote(f(foo = 1, bar = 2)))) + + res <- match_call(quote(f(baz = 1)), fn) + expect_false(res$success) + + ## Partial matching still enabled + expect_equal( + suppressWarnings(match_call(quote(f(fo = 2)), fn)), + list(success = TRUE, value = quote(f(foo = 2, bar = 1)))) +}) From 0bd8cc23fe46e1bb78a0dc72978893dcde9fe6a1 Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Fri, 12 Jul 2024 16:20:44 +0100 Subject: [PATCH 2/8] Use helper from mcstate --- DESCRIPTION | 3 +++ R/constants.R | 7 ------ R/parse_expr.R | 27 +++++------------------- tests/testthat/test-parse-expr-compare.R | 7 +++--- 4 files changed, 12 insertions(+), 32 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index 6ad17f0..dda7aac 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -14,8 +14,11 @@ URL: https://github.com/mrc-ide/odin2 BugReports: https://github.com/mrc-ide/odin2/issues Imports: cli, + mcstate, rlang Suggests: testthat (>= 3.0.0), withr Config/testthat/edition: 3 +Remotes: + mrc-ide/mcstate diff --git a/R/constants.R b/R/constants.R index b157b20..af24611 100644 --- a/R/constants.R +++ b/R/constants.R @@ -1,12 +1,5 @@ SPECIAL_LHS <- c( "initial", "deriv", "update", "output", "dim", "config", "compare") -COMPARE <- list( - Normal = function(mean, sd) {}, - Poisson = function(lambda) {}) - -STOCHASTIC <- list( - Binomial = function(size, prob) {}) - FUNCTIONS <- list( exp = function(x) {}) diff --git a/R/parse_expr.R b/R/parse_expr.R index 5fa35ce..6d53f9b 100644 --- a/R/parse_expr.R +++ b/R/parse_expr.R @@ -212,34 +212,17 @@ parse_expr_compare_lhs <- function(lhs, src, call) { } -## TODO: See mcstate2 with `match_call_candidate()` for doing this -## properly with choices; we may want to leverage some of the code -## there to keep the same semantics, especially once we start -## differentiating. parse_expr_compare_rhs <- function(rhs, src, call) { - if (!rlang::is_call(rhs, names(COMPARE))) { - ## TODO: Add DYM support here, including incorrect cases, and - ## dnorm() etc. - odin_parse_error( - "Expected the rhs of '~' to be a call to a distribution function", - src, call) - } - nm <- as.character(rhs[[1]]) - ## TODO: we really need this to come from mcstate earlier rather - ## than later; add a small utility there which does all the work for - ## us, as it's important that we match this well. The same thing is - ## going to happen with the stochastic functions, and again, we're - ## using the actual support from mcstate again so this should be - ## very consistent. - result <- match_call(rhs, COMPARE[[nm]]) + result <- mcstate2::mcstate_dsl_parse_distribution(rhs, "The rhs of '~'") if (!result$success) { odin_parse_error( - "Invalid call to '{nm}()': {conditionMessage(result$error)}", + result$error, src, call) } - depends <- find_dependencies(rhs) + depends <- find_dependencies(rhs) list(type = "compare", - expr = result$value, + distribution = result$value$cpp$density, + args = result$value$args, depends = depends) } diff --git a/tests/testthat/test-parse-expr-compare.R b/tests/testthat/test-parse-expr-compare.R index 8e84f01..ba58de9 100644 --- a/tests/testthat/test-parse-expr-compare.R +++ b/tests/testthat/test-parse-expr-compare.R @@ -2,7 +2,8 @@ test_that("Can parse compare expression", { res <- parse_expr(quote(compare(x) ~ Normal(0, 1)), NULL, NULL) expect_equal(res$special, "compare") expect_equal(res$rhs$type, "compare") - expect_equal(res$rhs$expr, quote(Normal(x, mean = 0, sd = 1))) + expect_equal(res$rhs$distribution, "normal") + expect_equal(res$rhs$args, list(0, 1)) expect_equal(res$rhs$depends, list(functions = "Normal", variables = "x")) }) @@ -47,10 +48,10 @@ test_that("compare() calls must wrap symbols", { test_that("parse compare call rhs as distributions", { expect_error( parse_expr(quote(compare(x) ~ 1), NULL, NULL), - "Expected the rhs of '~' to be a call to a distribution function") + "The rhs of '~' is not a function call") expect_error( parse_expr(quote(compare(x) ~ Foo(0, 1)), NULL, NULL), - "Expected the rhs of '~' to be a call to a distribution function") + "Unknown distribution 'Foo'") expect_error( parse_expr(quote(compare(x) ~ Normal(mu = 1)), NULL, NULL), "Invalid call to 'Normal()'", From a2a5c3791b91eb34d5b8048c5bc7509d3d9a7655 Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Fri, 12 Jul 2024 16:21:09 +0100 Subject: [PATCH 3/8] Set branch pointer for now --- DESCRIPTION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DESCRIPTION b/DESCRIPTION index dda7aac..f60b921 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -21,4 +21,4 @@ Suggests: withr Config/testthat/edition: 3 Remotes: - mrc-ide/mcstate + mrc-ide/mcstate@mrc-5522 From afbafe8b43c4cd3189418c2852fb481fe776cc39 Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Fri, 12 Jul 2024 16:34:07 +0100 Subject: [PATCH 4/8] Add test of error throwing in context --- tests/testthat/test-parse.R | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/testthat/test-parse.R b/tests/testthat/test-parse.R index 3f4f9ab..7e1e0e4 100644 --- a/tests/testthat/test-parse.R +++ b/tests/testthat/test-parse.R @@ -5,3 +5,17 @@ test_that("can parse trivial system", { }) expect_null(res) }) + + +test_that("throw error with context", { + path <- withr::local_tempfile() + writeLines(c("initial(x) <- a", + "update(x) <- x + b", + "b <- parameter(invalid = TRUE)", + "a <- 5"), + path) + err <- expect_error( + odin_parse(path), + "Invalid call to 'parameter()'", + fixed = TRUE) +}) From 0c9910ecb6f1bd102199b09f0e65181dbcd631b9 Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Fri, 12 Jul 2024 18:58:12 +0100 Subject: [PATCH 5/8] Throw errors with context --- DESCRIPTION | 6 +++--- NAMESPACE | 3 +++ R/parse.R | 25 +++++++++++++++++++++++-- R/parse_expr.R | 8 ++++---- tests/testthat/test-parse.R | 31 ++++++++++++++++++++++++++++++- 5 files changed, 63 insertions(+), 10 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index f60b921..984f4d5 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -9,16 +9,16 @@ Description: Temporary package for rewriting odin. License: MIT + file LICENSE Encoding: UTF-8 Roxygen: list(markdown = TRUE) -RoxygenNote: 7.1.1 +RoxygenNote: 7.3.2 URL: https://github.com/mrc-ide/odin2 BugReports: https://github.com/mrc-ide/odin2/issues Imports: cli, - mcstate, + mcstate2, rlang Suggests: testthat (>= 3.0.0), withr Config/testthat/edition: 3 Remotes: - mrc-ide/mcstate@mrc-5522 + mrc-ide/mcstate2@mrc-5522 diff --git a/NAMESPACE b/NAMESPACE index e651b94..e39bb5a 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -1 +1,4 @@ # Generated by roxygen2: do not edit by hand + +S3method(cnd_footer,odin_parse_error) +importFrom(rlang,cnd_footer) diff --git a/R/parse.R b/R/parse.R index 4cf6366..bc102f2 100644 --- a/R/parse.R +++ b/R/parse.R @@ -6,12 +6,33 @@ odin_parse <- function(expr, input_type = NULL) { } -odin_parse_error <- function(msg, src, call, parent = NULL, +odin_parse_error <- function(msg, src, call, ..., .envir = parent.frame()) { + if (!is.null(names(src))) { + src <- list(src) + } cli::cli_abort(msg, class = "odin_parse_error", src = src, call = call, - parent = parent, + ..., .envir = .envir) } + + +##' @importFrom rlang cnd_footer +##' @export +cnd_footer.odin_parse_error <- function(cnd, ...) { + ## TODO: later, we might want to point at specific bits of the error + ## and say "here, this is where you are wrong" but that's not done + ## yet... + src <- cnd$src + if (is.null(src[[1]]$str)) { + context <- unlist(lapply(cnd$src, function(x) deparse1(x$value))) + } else { + line <- unlist(lapply(src, function(x) seq(x$start, x$end))) + src <- unlist(lapply(src, "[[", "str")) + context <- sprintf("%s| %s", gsub(" ", "\u00a0", format(line)), src) + } + c(">" = "Context:", context) +} diff --git a/R/parse_expr.R b/R/parse_expr.R index 6d53f9b..9d8ff30 100644 --- a/R/parse_expr.R +++ b/R/parse_expr.R @@ -127,10 +127,10 @@ parse_expr_assignment_rhs_parameter <- function(rhs, src, call) { } result <- match_call(rhs, template) if (!result$success) { - ## I don't think this is quite correct really, and I'm not sure - ## the generated error is hugely informative for the user. - odin_parse_error("Invalid call to 'parameter()'", - src, call, parent = result$error) + odin_parse_error( + c("Invalid call to 'parameter()'", + x = conditionMessage(result$error)), + src, call) } args <- as.list(result$value)[-1] if (is.language(args$default)) { diff --git a/tests/testthat/test-parse.R b/tests/testthat/test-parse.R index 7e1e0e4..a7c8fe2 100644 --- a/tests/testthat/test-parse.R +++ b/tests/testthat/test-parse.R @@ -11,11 +11,40 @@ test_that("throw error with context", { path <- withr::local_tempfile() writeLines(c("initial(x) <- a", "update(x) <- x + b", - "b <- parameter(invalid = TRUE)", + "b<-parameter(invalid=TRUE)", "a <- 5"), path) err <- expect_error( odin_parse(path), "Invalid call to 'parameter()'", + fixed = TRUE, + class = "odin_parse_error") + expect_equal( + err$src, + list(list(value = quote(b <- parameter(invalid = TRUE)), + start = 3, + end = 3, + str = c("b<-parameter(invalid=TRUE)")))) + expect_match( + conditionMessage(err), + "Context:\n3| b<-parameter(invalid=TRUE)", + fixed = TRUE) +}) + + +test_that("throw error with context where source code unavailable", { + err <- expect_error( + odin_parse({ + initial(x) <- a + update(x) <- x + b + b <- parameter(invalid = TRUE) + a <- 5 + }), + "Invalid call to 'parameter()'", + fixed = TRUE, + class = "odin_parse_error") + expect_match( + conditionMessage(err), + "Context:\nb <- parameter(invalid = TRUE)", fixed = TRUE) }) From 2e066cb82fab380520625cae3eb80d4e591d0f67 Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Mon, 15 Jul 2024 12:19:27 +0100 Subject: [PATCH 6/8] Reorganise source --- R/parse.R | 32 -------------------------------- R/parse_error.R | 30 ++++++++++++++++++++++++++++++ 2 files changed, 30 insertions(+), 32 deletions(-) create mode 100644 R/parse_error.R diff --git a/R/parse.R b/R/parse.R index bc102f2..b020c36 100644 --- a/R/parse.R +++ b/R/parse.R @@ -4,35 +4,3 @@ odin_parse <- function(expr, input_type = NULL) { exprs <- lapply(dat$exprs, function(x) parse_expr(x$value, x, call = call)) NULL } - - -odin_parse_error <- function(msg, src, call, ..., - .envir = parent.frame()) { - if (!is.null(names(src))) { - src <- list(src) - } - cli::cli_abort(msg, - class = "odin_parse_error", - src = src, - call = call, - ..., - .envir = .envir) -} - - -##' @importFrom rlang cnd_footer -##' @export -cnd_footer.odin_parse_error <- function(cnd, ...) { - ## TODO: later, we might want to point at specific bits of the error - ## and say "here, this is where you are wrong" but that's not done - ## yet... - src <- cnd$src - if (is.null(src[[1]]$str)) { - context <- unlist(lapply(cnd$src, function(x) deparse1(x$value))) - } else { - line <- unlist(lapply(src, function(x) seq(x$start, x$end))) - src <- unlist(lapply(src, "[[", "str")) - context <- sprintf("%s| %s", gsub(" ", "\u00a0", format(line)), src) - } - c(">" = "Context:", context) -} diff --git a/R/parse_error.R b/R/parse_error.R new file mode 100644 index 0000000..9dfc018 --- /dev/null +++ b/R/parse_error.R @@ -0,0 +1,30 @@ +odin_parse_error <- function(msg, src, call, ..., + .envir = parent.frame()) { + if (!is.null(names(src))) { + src <- list(src) + } + cli::cli_abort(msg, + class = "odin_parse_error", + src = src, + call = call, + ..., + .envir = .envir) +} + + +##' @importFrom rlang cnd_footer +##' @export +cnd_footer.odin_parse_error <- function(cnd, ...) { + ## TODO: later, we might want to point at specific bits of the error + ## and say "here, this is where you are wrong" but that's not done + ## yet... + src <- cnd$src + if (is.null(src[[1]]$str)) { + context <- unlist(lapply(cnd$src, function(x) deparse1(x$value))) + } else { + line <- unlist(lapply(src, function(x) seq(x$start, x$end))) + src <- unlist(lapply(src, "[[", "str")) + context <- sprintf("%s| %s", gsub(" ", "\u00a0", format(line)), src) + } + c(">" = "Context:", context) +} From b3a264cb14aeb37237e05c72e50290eff880a361 Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Mon, 15 Jul 2024 15:24:14 +0100 Subject: [PATCH 7/8] Drop branch pointer --- DESCRIPTION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DESCRIPTION b/DESCRIPTION index 984f4d5..9ac1e1b 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -21,4 +21,4 @@ Suggests: withr Config/testthat/edition: 3 Remotes: - mrc-ide/mcstate2@mrc-5522 + mrc-ide/mcstate2 From 88483116852f9c9e6ed86005caa7d0269b1e918f Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Tue, 16 Jul 2024 07:40:38 +0100 Subject: [PATCH 8/8] Update tests/testthat/test-parse-expr.R Co-authored-by: Wes Hinsley --- tests/testthat/test-parse-expr.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/testthat/test-parse-expr.R b/tests/testthat/test-parse-expr.R index 0b41877..d9d1247 100644 --- a/tests/testthat/test-parse-expr.R +++ b/tests/testthat/test-parse-expr.R @@ -37,7 +37,7 @@ test_that("allow calls on lhs", { }) -test_that("requre that special calls are (currently) simple", { +test_that("require that special calls are (currently) simple", { expect_error( parse_expr(quote(initial(x, TRUE) <- 1), NULL, NULL), "Invalid special function call")