diff --git a/quint/src/ir/IRTransformer.ts b/quint/src/ir/IRTransformer.ts index 3f3eb6963..5cf1389c1 100644 --- a/quint/src/ir/IRTransformer.ts +++ b/quint/src/ir/IRTransformer.ts @@ -483,6 +483,9 @@ function transformExpression(transformer: IRTransformer, expr: ir.QuintEx): ir.Q newExpr = transformer.enterLambda(newExpr) } + newExpr.params = newExpr.params.map(p => + p.typeAnnotation ? { ...p, typeAnnotation: transformType(transformer, p.typeAnnotation) } : p + ) newExpr.expr = transformExpression(transformer, newExpr.expr) if (transformer.exitLambda) { diff --git a/quint/src/ir/IRVisitor.ts b/quint/src/ir/IRVisitor.ts index 282f048ec..879eb112b 100644 --- a/quint/src/ir/IRVisitor.ts +++ b/quint/src/ir/IRVisitor.ts @@ -456,7 +456,11 @@ export function walkExpression(visitor: IRVisitor, expr: ir.QuintEx): void { if (visitor.enterLambda) { visitor.enterLambda(expr) } - + expr.params.forEach(p => { + if (p.typeAnnotation) { + walkType(visitor, p.typeAnnotation) + } + }) walkExpression(visitor, expr.expr) if (visitor.exitLambda) { diff --git a/quint/test/ir/IRTransformer.test.ts b/quint/test/ir/IRTransformer.test.ts index c01ce0d51..70b210cbc 100644 --- a/quint/test/ir/IRTransformer.test.ts +++ b/quint/test/ir/IRTransformer.test.ts @@ -41,6 +41,26 @@ describe('enterExpr', () => { assert.deepEqual(moduleToString(result), moduleToString(expectedModule)) }) + + it('transforms paramater type annotations', () => { + class TestTransformer implements IRTransformer { + exitType(_: QuintType): QuintType { + return { kind: 'var', name: 'trans' } + } + } + + const transformer = new TestTransformer() + + const m = buildModuleWithDecls(['def foo(x: int, b: int, c: str): int = 42']) + + const transformedDecl = transformModule(transformer, m).declarations[0] + assert(transformedDecl.kind === 'def') + assert(transformedDecl.expr.kind === 'lambda') + transformedDecl.expr.params.forEach(p => { + assert(p.typeAnnotation) + assert.deepEqual(p.typeAnnotation, { kind: 'var', name: 'trans' }) + }) + }) }) describe('enterDecl', () => { diff --git a/quint/test/ir/IRVisitor.test.ts b/quint/test/ir/IRVisitor.test.ts index c3c41d5a3..6c9da5333 100644 --- a/quint/test/ir/IRVisitor.test.ts +++ b/quint/test/ir/IRVisitor.test.ts @@ -902,5 +902,25 @@ describe('walkModule', () => { assert.deepEqual(visitor.entered.map(typeToString), expectedTypes) assert.deepEqual(visitor.exited.map(typeToString), expectedTypes) }) + + it('finds paramater type annotations', () => { + class TestVisitor implements IRVisitor { + typesVisited: QuintType[] = [] + exitType(t: QuintType) { + this.typesVisited.push(t) + } + } + + const visitor = new TestVisitor() + + const m = buildModuleWithDecls(['def foo(x: int, b: str): bool = true']) + walkModule(visitor, m) + const actualTypes = visitor.typesVisited.map(typeToString) + // `int` and `str` should each show up TWICE: + // - once from of the lambda type annotations + // - once from of the parameter type annotation + const expectedTypes = ['int', 'str', 'bool', '(int, str) => bool', 'int', 'str'] + assert.deepEqual(actualTypes, expectedTypes) + }) }) })