From a28b779f3c6de31b06c0c220eab7a3b62bc876df Mon Sep 17 00:00:00 2001 From: Kamil Jankowski Date: Wed, 14 Feb 2024 12:44:08 +0100 Subject: [PATCH] [`snforge-test-collector`] Add `TestDetails` to `TestCaseRaw` (#1121) https://github.com/foundry-rs/starknet-foundry/issues/1394 --- Cargo.lock | 1 + Cargo.toml | 1 + .../scarb-snforge-test-collector/Cargo.toml | 1 + .../src/compilation/test_collector.rs | 54 +++++++--- .../test_collector/function_finder.rs | 26 +++++ .../tests/test.rs | 98 +++++++++---------- 6 files changed, 115 insertions(+), 66 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 7dc6127b2..494edc64e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4645,6 +4645,7 @@ dependencies = [ "cairo-lang-semantic", "cairo-lang-sierra", "cairo-lang-sierra-generator", + "cairo-lang-sierra-type-size", "cairo-lang-starknet", "cairo-lang-syntax", "cairo-lang-test-plugin", diff --git a/Cargo.toml b/Cargo.toml index b06b5c806..fd5f68d9d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -48,6 +48,7 @@ cairo-lang-semantic = { git = "https://github.com/starkware-libs/cairo", rev = " cairo-lang-sierra = { git = "https://github.com/starkware-libs/cairo", rev = "c20ac70ca515de15692f2f514325abb84cc79152" } cairo-lang-sierra-generator = { git = "https://github.com/starkware-libs/cairo", rev = "c20ac70ca515de15692f2f514325abb84cc79152" } cairo-lang-sierra-to-casm = { git = "https://github.com/starkware-libs/cairo", rev = "c20ac70ca515de15692f2f514325abb84cc79152" } +cairo-lang-sierra-type-size = { git = "https://github.com/starkware-libs/cairo", rev = "c20ac70ca515de15692f2f514325abb84cc79152" } cairo-lang-starknet = { git = "https://github.com/starkware-libs/cairo", rev = "c20ac70ca515de15692f2f514325abb84cc79152" } cairo-lang-starknet-classes = { git = "https://github.com/starkware-libs/cairo", rev = "c20ac70ca515de15692f2f514325abb84cc79152" } cairo-lang-syntax = { git = "https://github.com/starkware-libs/cairo", rev = "c20ac70ca515de15692f2f514325abb84cc79152" } diff --git a/extensions/scarb-snforge-test-collector/Cargo.toml b/extensions/scarb-snforge-test-collector/Cargo.toml index 824657322..fb52c6e26 100644 --- a/extensions/scarb-snforge-test-collector/Cargo.toml +++ b/extensions/scarb-snforge-test-collector/Cargo.toml @@ -20,6 +20,7 @@ cairo-lang-project.workspace = true cairo-lang-semantic.workspace = true cairo-lang-sierra.workspace = true cairo-lang-sierra-generator.workspace = true +cairo-lang-sierra-type-size.workspace = true cairo-lang-starknet.workspace = true cairo-lang-syntax.workspace = true cairo-lang-test-plugin.workspace = true diff --git a/extensions/scarb-snforge-test-collector/src/compilation/test_collector.rs b/extensions/scarb-snforge-test-collector/src/compilation/test_collector.rs index 39d77b2c3..3fc449152 100644 --- a/extensions/scarb-snforge-test-collector/src/compilation/test_collector.rs +++ b/extensions/scarb-snforge-test-collector/src/compilation/test_collector.rs @@ -17,6 +17,7 @@ use cairo_lang_semantic::items::functions::GenericFunctionId; use cairo_lang_semantic::{ConcreteFunction, FunctionLongId}; use cairo_lang_sierra::extensions::enm::EnumType; use cairo_lang_sierra::extensions::NamedType; +use cairo_lang_sierra::ids::GenericTypeId; use cairo_lang_sierra::program::{GenericArg, Program}; use cairo_lang_sierra_generator::db::SierraGenGroup; use cairo_lang_sierra_generator::replace_ids::replace_sierra_ids_in_program; @@ -74,6 +75,14 @@ pub struct TestCaseRaw { pub expected_result: ExpectedTestResult, pub fork_config: Option, pub fuzzer_config: Option, + pub test_details: TestDetails, +} + +#[derive(Debug, PartialEq, Clone, Serialize)] +pub struct TestDetails { + pub entry_point_offset: usize, + pub parameter_types: Vec<(GenericTypeId, i16)>, + pub return_types: Vec<(GenericTypeId, i16)>, } pub fn collect_tests( @@ -138,6 +147,9 @@ pub fn collect_tests( .context("Compilation failed without any diagnostics") .context("Failed to get sierra program")?; + let sierra_program = replace_sierra_ids_in_program(db, &sierra_program.program); + let function_finder = FunctionFinder::new(sierra_program.clone())?; + let collected_tests = all_tests .into_iter() .map(|(func_id, test)| { @@ -157,23 +169,39 @@ pub fn collect_tests( }) .collect_vec() .into_iter() - .map(|(test_name, config)| TestCaseRaw { - name: test_name, - available_gas: config.available_gas, - ignored: config.ignored, - expected_result: config.expected_result, - fork_config: config.fork_config, - fuzzer_config: config.fuzzer_config, + .map(|(test_name, config)| { + let test_details = build_test_details(&function_finder, &test_name).unwrap(); + TestCaseRaw { + name: test_name, + available_gas: config.available_gas, + ignored: config.ignored, + expected_result: config.expected_result, + fork_config: config.fork_config, + fuzzer_config: config.fuzzer_config, + test_details, + } }) .collect(); - let sierra_program = replace_sierra_ids_in_program(db, &sierra_program.program); - - validate_tests(sierra_program.clone(), &collected_tests)?; + validate_tests(&function_finder, &collected_tests)?; Ok((sierra_program, collected_tests)) } +fn build_test_details(function_finder: &FunctionFinder, test_name: &str) -> Result { + let func = function_finder.find_function(test_name)?; + + let parameter_types = + function_finder.generic_id_and_size_from_concrete(&func.signature.param_types); + let return_types = function_finder.generic_id_and_size_from_concrete(&func.signature.ret_types); + + Ok(TestDetails { + entry_point_offset: func.entry_point.0, + parameter_types, + return_types, + }) +} + fn build_diagnostics_reporter(compilation_unit: &CompilationUnit) -> DiagnosticsReporter<'static> { if compilation_unit.allow_warnings() { DiagnosticsReporter::stderr().allow_warnings() @@ -208,13 +236,9 @@ fn insert_lib_entrypoint_content_into_db( } fn validate_tests( - sierra_program: Program, + function_finder: &FunctionFinder, collected_tests: &Vec, ) -> Result<(), anyhow::Error> { - let function_finder = match FunctionFinder::new(sierra_program) { - Ok(casm_generator) => casm_generator, - Err(e) => panic!("{}", e), - }; for test in collected_tests { let func = function_finder.find_function(&test.name)?; let signature = &func.signature; diff --git a/extensions/scarb-snforge-test-collector/src/compilation/test_collector/function_finder.rs b/extensions/scarb-snforge-test-collector/src/compilation/test_collector/function_finder.rs index b4452195b..852e2e3c5 100644 --- a/extensions/scarb-snforge-test-collector/src/compilation/test_collector/function_finder.rs +++ b/extensions/scarb-snforge-test-collector/src/compilation/test_collector/function_finder.rs @@ -2,8 +2,10 @@ use cairo_lang_sierra::extensions::core::{CoreLibfunc, CoreType}; use cairo_lang_sierra::extensions::ConcreteType; +use cairo_lang_sierra::ids::{ConcreteTypeId, GenericTypeId}; use cairo_lang_sierra::program::Function; use cairo_lang_sierra::program_registry::{ProgramRegistry, ProgramRegistryError}; +use cairo_lang_sierra_type_size::{get_type_size_map, TypeSizeMap}; use thiserror::Error; #[derive(Debug, Error)] @@ -12,6 +14,8 @@ pub enum FinderError { MissingFunction { suffix: String }, #[error(transparent)] ProgramRegistryError(#[from] Box), + #[error("Unable to create TypeSizeMap.")] + TypeSizeMapError, } pub struct FunctionFinder { @@ -19,6 +23,8 @@ pub struct FunctionFinder { sierra_program: cairo_lang_sierra::program::Program, /// Program registry for the Sierra program. sierra_program_registry: ProgramRegistry, + // Mapping for the sizes of all types for sierra_program + type_size_map: TypeSizeMap, } #[allow(clippy::result_large_err)] @@ -26,9 +32,13 @@ impl FunctionFinder { pub fn new(sierra_program: cairo_lang_sierra::program::Program) -> Result { let sierra_program_registry = ProgramRegistry::::new(&sierra_program)?; + let type_size_map = get_type_size_map(&sierra_program, &sierra_program_registry) + .ok_or(FinderError::TypeSizeMapError)?; + Ok(Self { sierra_program, sierra_program_registry, + type_size_map, }) } @@ -57,4 +67,20 @@ impl FunctionFinder { ) -> &cairo_lang_sierra::extensions::types::TypeInfo { self.sierra_program_registry.get_type(ty).unwrap().info() } + + /// Converts array of `ConcreteTypeId`s into corresponding `GenericTypeId`s and their sizes + pub fn generic_id_and_size_from_concrete( + &self, + types: &[ConcreteTypeId], + ) -> Vec<(GenericTypeId, i16)> { + types + .iter() + .map(|pt| { + let info = self.get_info(pt); + let generic_id = &info.long_id.generic_id; + let size = self.type_size_map[pt]; + (generic_id.clone(), size) + }) + .collect() + } } diff --git a/extensions/scarb-snforge-test-collector/tests/test.rs b/extensions/scarb-snforge-test-collector/tests/test.rs index 3326d2513..d45aac40f 100644 --- a/extensions/scarb-snforge-test-collector/tests/test.rs +++ b/extensions/scarb-snforge-test-collector/tests/test.rs @@ -48,11 +48,20 @@ fn forge_test_locations() { assert_eq!(&json[1]["test_cases"][0]["name"], "tests::tests::test"); assert_eq!(&json[1]["tests_location"], "Tests"); - assert_eq!(&json[0]["test_cases"][0]["available_gas"], &Value::Null); - assert_eq!(&json[0]["test_cases"][0]["expected_result"], "Success"); - assert_eq!(&json[0]["test_cases"][0]["fork_config"], &Value::Null); - assert_eq!(&json[0]["test_cases"][0]["fuzzer_config"], &Value::Null); - assert_eq!(&json[0]["test_cases"][0]["ignored"], false); + let case_0 = &json[0]["test_cases"][0]; + + assert_eq!(&case_0["available_gas"], &Value::Null); + assert_eq!(&case_0["expected_result"], "Success"); + assert_eq!(&case_0["fork_config"], &Value::Null); + assert_eq!(&case_0["fuzzer_config"], &Value::Null); + assert_eq!(&case_0["ignored"], false); + assert_eq!(&case_0["test_details"]["entry_point_offset"], 0); + assert_eq!( + &case_0["test_details"]["parameter_types"], + &Value::Array(vec![]) + ); + assert_eq!(&case_0["test_details"]["return_types"][0][0], "Enum"); + assert_eq!(&case_0["test_details"]["return_types"][0][1], 3); } #[test] @@ -88,7 +97,9 @@ const WITH_MANY_ATTRIBUTES_TEST: &str = indoc! {r#" #[available_gas(100)] #[test] fn test(a: felt252) { - assert(true == true, 'it works!') + let (x, y) = (1_u8, 2_u8); + let z = x + y; + assert(x < z, 'it works!') } } "# @@ -114,37 +125,32 @@ fn forge_test_with_attributes() { .read_to_string(); let json: Value = serde_json::from_str(&snforge_sierra).unwrap(); + let case_0 = &json[0]["test_cases"][0]; + assert_eq!(&case_0["available_gas"], &Value::Number(Number::from(100))); + assert_eq!(&case_0["expected_result"]["Panics"], "Any"); + assert_eq!(&case_0["fork_config"]["Params"]["block_id_type"], "Number"); + assert_eq!(&case_0["fork_config"]["Params"]["block_id_value"], "123"); assert_eq!( - &json[0]["test_cases"][0]["available_gas"], - &Value::Number(Number::from(100)) - ); - assert_eq!( - &json[0]["test_cases"][0]["expected_result"]["Panics"], - "Any" - ); - assert_eq!( - &json[0]["test_cases"][0]["fork_config"]["Params"]["block_id_type"], - "Number" - ); - assert_eq!( - &json[0]["test_cases"][0]["fork_config"]["Params"]["block_id_value"], - "123" - ); - assert_eq!( - &json[0]["test_cases"][0]["fork_config"]["Params"]["url"], + case_0["fork_config"]["Params"]["url"], "http://your.rpc.url" ); + assert_eq!(&case_0["fuzzer_config"]["fuzzer_runs"], 22); + assert_eq!(&case_0["fuzzer_config"]["fuzzer_seed"], 38); + assert_eq!(&case_0["ignored"], true); + assert_eq!(&case_0["name"], "forge_test::tests::test"); + assert_eq!(&case_0["test_details"]["entry_point_offset"], 0); assert_eq!( - &json[0]["test_cases"][0]["fuzzer_config"]["fuzzer_runs"], - 22 - ); - assert_eq!( - &json[0]["test_cases"][0]["fuzzer_config"]["fuzzer_seed"], - 38 + &case_0["test_details"]["parameter_types"][0][0], + "RangeCheck" ); - assert_eq!(&json[0]["test_cases"][0]["ignored"], true); - assert_eq!(&json[0]["test_cases"][0]["name"], "forge_test::tests::test"); + assert_eq!(&case_0["test_details"]["parameter_types"][0][1], 1); + assert_eq!(&case_0["test_details"]["parameter_types"][1][0], "felt252"); + assert_eq!(&case_0["test_details"]["parameter_types"][1][1], 1); + assert_eq!(&case_0["test_details"]["return_types"][0][0], "RangeCheck"); + assert_eq!(&case_0["test_details"]["return_types"][0][1], 1); + assert_eq!(&case_0["test_details"]["return_types"][1][0], "Enum"); + assert_eq!(&case_0["test_details"]["return_types"][1][1], 3); } const FORK_TAG_TEST: &str = indoc! {r#" @@ -179,21 +185,16 @@ fn forge_test_with_fork_tag_attribute() { .read_to_string(); let json: Value = serde_json::from_str(&snforge_sierra).unwrap(); + let case_0 = &json[0]["test_cases"][0]; + assert_eq!(&case_0["fork_config"]["Params"]["block_id_type"], "Tag"); + assert_eq!(&case_0["fork_config"]["Params"]["block_id_value"], "Latest"); assert_eq!( - &json[0]["test_cases"][0]["fork_config"]["Params"]["block_id_type"], - "Tag" - ); - assert_eq!( - &json[0]["test_cases"][0]["fork_config"]["Params"]["block_id_value"], - "Latest" - ); - assert_eq!( - &json[0]["test_cases"][0]["fork_config"]["Params"]["url"], + case_0["fork_config"]["Params"]["url"], "http://your.rpc.url" ); - assert_eq!(&json[0]["test_cases"][0]["name"], "forge_test::tests::test"); + assert_eq!(&case_0["name"], "forge_test::tests::test"); } const FORK_HASH_TEST: &str = indoc! {r#" @@ -228,21 +229,16 @@ fn forge_test_with_fork_hash_attribute() { .read_to_string(); let json: Value = serde_json::from_str(&snforge_sierra).unwrap(); + let case_0 = &json[0]["test_cases"][0]; + assert_eq!(&case_0["fork_config"]["Params"]["block_id_type"], "Hash"); + assert_eq!(&case_0["fork_config"]["Params"]["block_id_value"], "123"); assert_eq!( - &json[0]["test_cases"][0]["fork_config"]["Params"]["block_id_type"], - "Hash" - ); - assert_eq!( - &json[0]["test_cases"][0]["fork_config"]["Params"]["block_id_value"], - "123" - ); - assert_eq!( - &json[0]["test_cases"][0]["fork_config"]["Params"]["url"], + case_0["fork_config"]["Params"]["url"], "http://your.rpc.url" ); - assert_eq!(&json[0]["test_cases"][0]["name"], "forge_test::tests::test"); + assert_eq!(&case_0["name"], "forge_test::tests::test"); } const SHOULD_PANIC_TEST: &str = indoc! {r#"