From d1ab5446c73d49202ce13f0d0ed3a873259ab581 Mon Sep 17 00:00:00 2001 From: Lasse Blaauwbroek Date: Wed, 21 Aug 2024 14:55:54 +0200 Subject: [PATCH] Inline and incremental supplementary environment construction --- src/neural_learner.ml | 41 ++++++++++++++++++++++------------------- 1 file changed, 22 insertions(+), 19 deletions(-) diff --git a/src/neural_learner.ml b/src/neural_learner.ml index 6c712f7..a2b3796 100644 --- a/src/neural_learner.ml +++ b/src/neural_learner.ml @@ -710,12 +710,13 @@ module NeuralLearner : TacticianOnlineLearnerType = functor (TS : TacticianStruc type model = { tactics : (glob_tactic_expr * int) TacticMap.t - ; proofs : (((proof_state * tactic_result) list * tactic option) list * data_status * KerName.t) list } + ; proofs : env_extra } - let last_model = Summary.ref ~name:"neural-learner-lastmodel" { tactics = TacticMap.empty; proofs = [] } + let last_model = Summary.ref ~name:"neural-learner-lastmodel" { tactics = TacticMap.empty + ; proofs = Id.Map.empty, Cmap.empty } let empty () = - { tactics = TacticMap.empty; proofs = [] } + { tactics = TacticMap.empty; proofs = Id.Map.empty, Cmap.empty } let add_tactic_info env map tac = let tac = Tactic_normalize.tactic_normalize @@ Tactic_normalize.tactic_strict tac in @@ -727,25 +728,28 @@ module NeuralLearner : TacticianOnlineLearnerType = functor (TS : TacticianStruc TacticMap.add (Tactic_hash.tactic_hash env base_tactic) (base_tactic, params) map - let learn { tactics; proofs } (kn, path, status) outcomes tac = + let learn { tactics; proofs = var_proofs, const_proofs } (kn, path, status) outcomes tac = let tactics = match tac with | None -> tactics | Some tac -> let tac = tactic_repr tac in let tactics = add_tactic_info (Global.env ()) tactics tac in tactics in - let proofs = - let proof_states = List.map (fun x -> - x.before, x.result) outcomes in - match proofs with - | (ls, pstatus, pkn)::data when KerName.equal kn pkn -> - ((proof_states, tac)::ls, pstatus, pkn)::data - | _ -> ([proof_states, tac], status, kn)::proofs in - let db = { tactics; proofs } in + let constant = Constant.make1 kn in + + let proof_step = List.map (fun x -> + x.before, x.result) outcomes in + let proof_step = List.map (fun (before, result) -> mk_outcome before result) proof_step, + Option.map tactic_repr tac in + + let proof = Option.default [] @@ Cmap.find_opt constant const_proofs in + let proof = proof_step :: proof in + let const_proofs = Cmap.add constant proof const_proofs in + let db = { tactics; proofs = var_proofs, const_proofs } in last_model := db; db - let env_extra proofs = + let _env_extra proofs = let globrefs = Environ.Globals.view (Global.env ()).env_globals in let section_vars = Id.Set.of_list @@ List.map Context.Named.Declaration.get_id @@ Environ.named_context @@ Global.env () in @@ -782,9 +786,8 @@ module NeuralLearner : TacticianOnlineLearnerType = functor (TS : TacticianStruc ; request_prediction; request_text_prediction; _ } = get_communicator () in let env = Global.env () in if not @@ textmode_option () then - let env_extra = env_extra proofs in let state, stack_size = - sync_context_stack ~keep_cache:false tactics env_extra env in + sync_context_stack ~keep_cache:false tactics proofs env in let find_global_argument = find_global_argument state in fun f -> if f = [] then IStream.empty else @@ -809,8 +812,8 @@ module NeuralLearner : TacticianOnlineLearnerType = functor (TS : TacticianStruc let module Response = Api.Reader.PredictionProtocol.Response in let env = Global.env () in let { tactics; proofs } = !last_model in - let env_extra = env_extra (!last_model).proofs in - let state, stack_size = sync_context_stack ~keep_cache:false tactics env_extra env in + let proofs = (!last_model).proofs in + let state, stack_size = sync_context_stack ~keep_cache:false tactics proofs env in let request = Request.init_root () in Request.check_alignment_set request; let unaligned_tacs, unaligned_defs = check_alignment () in @@ -851,8 +854,8 @@ module NeuralLearner : TacticianOnlineLearnerType = functor (TS : TacticianStruc let { sync_context_stack; _ } = get_communicator () in (* We don't send the list of tactics, hence the empty list. Tactics are only sent right before prediction requests are made. *) - let env_extra = env_extra (!last_model).proofs in - let _, stack_size = sync_context_stack TacticMap.empty env_extra (Global.env ()) in + let proofs = (!last_model).proofs in + let _, stack_size = sync_context_stack TacticMap.empty proofs (Global.env ()) in if debug_option () then Feedback.msg_notice Pp.(str "Cache stack size: " ++ int stack_size)