diff --git a/tests/tyck/wip.chester b/tests/tyck/wip.chester index a052406f..78c2b00a 100644 --- a/tests/tyck/wip.chester +++ b/tests/tyck/wip.chester @@ -1,3 +1,3 @@ record A(a: Integer); let aT = A; -//def getA(x: A): Integer = x.a; \ No newline at end of file +def getA(x: A): Integer = x.a; \ No newline at end of file diff --git a/tyck/src/main/scala/chester/tyck/Elaborater.scala b/tyck/src/main/scala/chester/tyck/Elaborater.scala index 97de24b0..fde0cbcb 100644 --- a/tyck/src/main/scala/chester/tyck/Elaborater.scala +++ b/tyck/src/main/scala/chester/tyck/Elaborater.scala @@ -178,33 +178,13 @@ trait ProvideElaborater extends ProvideCtx with Elaborater with ElaboraterFuncti ck.reporter.apply(problem) ErrorTerm(problem, convertMeta(expr.meta)) } else { - // Elaborate the record expression - val recordTy = newType - val recordTerm = elab(recordExpr, recordTy, effects) - - // Get the field name fieldExpr match { case Identifier(fieldName, _) => - // Read the record type - readMetaVar(toTerm(recordTy)) match { - case RecordCallTerm(recordDef, _, _) => - // Find the field in the record definition - recordDef.fields.find(_.name == fieldName) match { - case Some(fieldTerm) => - // Unify the field type with the expected type - unify(ty, fieldTerm.ty, expr) - // Create field access term - FieldAccessTerm(recordTerm, fieldName, fieldTerm.ty, convertMeta(meta)) - case None => - val problem = FieldNotFound(fieldName, recordDef.name, expr) - ck.reporter.apply(problem) - ErrorTerm(problem, convertMeta(expr.meta)) - } - case other => - val problem = NotARecordType(other, expr) - ck.reporter.apply(problem) - ErrorTerm(problem, convertMeta(expr.meta)) - } + val recordTy = newType + val recordTerm = elab(recordExpr, recordTy, effects) + val resultTerm = FieldAccessTerm(recordTerm, fieldName, toTerm(ty), convertMeta(meta)) + state.addPropagator(RecordFieldPropagator(recordTy, fieldName, ty, expr)) + resultTerm case _ => val problem = InvalidFieldName(fieldExpr) ck.reporter.apply(problem) diff --git a/tyck/src/main/scala/chester/tyck/ElaboraterCommon.scala b/tyck/src/main/scala/chester/tyck/ElaboraterCommon.scala index e97ac230..73fe1955 100644 --- a/tyck/src/main/scala/chester/tyck/ElaboraterCommon.scala +++ b/tyck/src/main/scala/chester/tyck/ElaboraterCommon.scala @@ -492,6 +492,47 @@ trait ElaboraterCommon extends ProvideCtx with ElaboraterBase with CommonPropaga toTerm(argTerm) } + case class RecordFieldPropagator( + recordTy: CellId[Term], + fieldName: Name, + expectedTy: CellId[Term], + cause: Expr + )(using localCtx: Context) extends Propagator[Tyck] { + override val readingCells: Set[CIdOf[Cell[?]]] = Set(recordTy) + override val writingCells: Set[CIdOf[Cell[?]]] = Set(expectedTy) + override val zonkingCells: Set[CIdOf[Cell[?]]] = Set(recordTy, expectedTy) + + override def run(using state: StateAbility[Tyck], more: Tyck): Boolean = { + state.readStable(recordTy) match { + case Some(Meta(id)) => + state.addPropagator(RecordFieldPropagator(id, fieldName, expectedTy, cause)) + true + case Some(RecordCallTerm(recordDef, _, _)) => + recordDef.fields.find(_.name == fieldName) match { + case Some(fieldTerm) => + unify(expectedTy, fieldTerm.ty, cause) + true + case None => + val problem = FieldNotFound(fieldName, recordDef.name, cause) + more.reporter.apply(problem) + true + } + case Some(other) => + val problem = NotARecordType(other, cause) + more.reporter.apply(problem) + true + case None => false + } + } + + override def naiveZonk(needed: Vector[CellIdAny])(using state: StateAbility[Tyck], more: Tyck): ZonkResult = { + state.readStable(recordTy) match { + case None => ZonkResult.Require(Vector(recordTy)) + case _ => ZonkResult.Done + } + } + } + } trait ElaboraterBase extends CommonPropagator[Tyck] {