diff --git a/pytact/fake_python_server.py b/pytact/fake_python_server.py index 9f768df..0e91526 100644 --- a/pytact/fake_python_server.py +++ b/pytact/fake_python_server.py @@ -32,19 +32,25 @@ async def text_prediction_loop(context : GlobalContextMessage): else: raise Exception("Capnp protocol error") -async def graph_prediction_loop(context : GlobalContextMessage, level): +async def graph_prediction_loop(context : GlobalContextMessage, prev_tactics, level): print(f"level {level}") for cluster in context.definitions.clustered_definitions(full = False): print('cluster:') for d in cluster: print(f' {d.name}') - for t in context.tactics: - print(t) + + tactics = prev_tactics.copy() + for d in context.definitions.definitions(full = False): + if p := d.proof: + for ps in p: + if t := ps.tactic: + tactics.add((t.ident, len(ps.outcomes[0].tactic_arguments))) + print(context.log_annotation) prediction_requests = context.prediction_requests cool_definitions = [ d.node for d in context.definitions.definitions() if d.name == "Coq.Init.Logic.I" ] - zeroArgs = [t.ident for t in context.tactics if t.parameters == 0] - oneArg = [t.ident for t in context.tactics if t.parameters == 1] + zeroArgs = [ident for (ident, parameters) in tactics if parameters == 0] + oneArg = [ident for (ident, parameters) in tactics if parameters == 1] async for msg in prediction_requests: # Redirect any exceptions to Coq. Additionally, deal with CancellationError # thrown when a request from Coq is cancelled @@ -61,11 +67,11 @@ async def graph_prediction_loop(context : GlobalContextMessage, level): await prediction_requests.asend(TacticPredictionsGraph(preds)) elif isinstance(msg, CheckAlignmentMessage): unknown_definitions = list(context.definitions.definitions()) - unknown_tactics = [t.ident for t in context.tactics] + unknown_tactics = [ident for (ident, _) in tactics] alignment = CheckAlignmentResponse(unknown_definitions, unknown_tactics) await prediction_requests.asend(alignment) elif isinstance(msg, GlobalContextMessage): - await graph_prediction_loop(msg, level + 1) + await graph_prediction_loop(msg, tactics, level + 1) else: raise Exception(f"Capnp protocol error {msg}") @@ -80,7 +86,7 @@ async def run_session(args, record_file, capnp_stream): await text_prediction_loop(messages_generator) elif args.mode == 'graph': print('Python server running in graph mode') - await graph_prediction_loop(messages_generator, 0) + await graph_prediction_loop(messages_generator, set(), 0) else: raise Exception("The 'mode' argument needs to be either 'text' or 'graph'") diff --git a/src/neural_learner.ml b/src/neural_learner.ml index 486c6ec..db4056c 100644 --- a/src/neural_learner.ml +++ b/src/neural_learner.ml @@ -43,6 +43,12 @@ let find_tactic tacs id = | None -> raise NoSuchTactic | Some x -> x +let add_tactic_info env map tac = + let { base_tactic; args; _ } = analyze_tactic tac in + let params = List.length args in + TacticMap.add + (Tactic_hash.tactic_hash env base_tactic) (base_tactic, params) map + let find_local_argument context_map = let context_map_inv = Names.Id.Map.fold (fun id (_, node) m -> Int.Map.add node (Tactic_one_variable.TVar id) m) @@ -193,18 +199,19 @@ type context_state = ; id : int ; constants : Environ.constant_key Cmap_env.t ; inductives : Environ.mind_key Mindmap_env.t - ; section : Constr.named_context } + ; section : Constr.named_context + ; tactics : (glob_tactic_expr * int) TacticMap.t } type context_stack = { stack : context_state list ; stack_size : int } -let update_context_stack id tacs env_extra env { stack_size; stack } = - let state, old_constants, old_inducives, old_section = match stack with +let update_context_stack ?(force=false) id env_extra env { stack_size; stack } = + let state, old_constants, old_inducives, old_section, tactics = match stack with | [] -> let (empty_state, ()), _ = CICGraphMonad.run_empty (CICGraphMonad.return ()) (G.HashMap.create 0) G.builder_nil 0 in - empty_state, Cmap_env.empty, Mindmap_env.empty, [] - | { state; constants; inductives; section; _ }::_ -> state, constants, inductives, section in + empty_state, Cmap_env.empty, Mindmap_env.empty, [], TacticMap.empty + | { state; constants; inductives; section; tactics; _ }::_ -> state, constants, inductives, section, tactics in let globals = Environ.Globals.view Environ.(env.env_globals) in let section = Environ.named_context env in @@ -236,8 +243,23 @@ let update_context_stack id tacs env_extra env { stack_size; stack } = ++ pr_vertical_list Id.print (Id.Set.elements new_section) ); - if Cset.is_empty new_constants && Mindset.is_empty new_inductives && Id.Set.is_empty new_section && - TacticMap.is_empty tacs then state, { stack_size; stack } else + if (not force) && Cset.is_empty new_constants && Mindset.is_empty new_inductives && Id.Set.is_empty new_section + then state, tactics, { stack_size; stack } else + + let update_tactics fold find tmap set map = + fold (fun c tmap -> + match find c map with + | None -> tmap + | Some ls -> + List.fold_left (fun tmap (_, t) -> + match t with + | None -> tmap + | Some t -> add_tactic_info env tmap t + ) tmap ls + ) set tmap in + let env_vars, env_const = env_extra in + let tactics = update_tactics Cset.fold Cmap.find_opt tactics new_constants env_const in + let tactics = update_tactics Id.Set.fold Id.Map.find_opt tactics new_section env_vars in let { def_count; node_count; edge_count; defs; nodes; edges }, state = let open Monad_util.WithMonadNotations(CICGraphMonad) in @@ -264,12 +286,6 @@ let update_context_stack id tacs env_extra env { stack_size; stack } = GlobalContextAddition.log_annotation_set init @@ log_annotation (); ignore(GlobalContextAddition.data_version_set_reader init Api.Reader.current_version); GlobalContextAddition.stack_size_set_int_exn init stack_size; - let tac_arr = GlobalContextAddition.tactics_init init @@ TacticMap.cardinal tacs in - List.iteri (fun i (hash, (_tac, params)) -> - let arri = Capnp.Array.get tac_arr i in - Api.Builder.AbstractTactic.ident_set arri hash; - Api.Builder.AbstractTactic.parameters_set_exn arri params) - (TacticMap.bindings tacs); W.write_graph ~node_hash ~node_label ~node_lower:(fun n -> fst @@ G.lower n) ~node_dep_index:(fun (stack_id, _) -> stack_size - stack_id) ~node_local_index @@ -284,13 +300,14 @@ let update_context_stack id tacs env_extra env { stack_size; stack } = let state = { state with previous = None ; external_previous = Option.cata (fun p -> [p]) state.external_previous state.previous } in - state, { stack_size = stack_size + 1 - ; stack = { request = builder - ; state; id - ; constants = globals.constants - ; inductives = globals.inductives - ; section } - ::stack } + state, tactics, { stack_size = stack_size + 1 + ; stack = { request = builder + ; state; id + ; constants = globals.constants + ; inductives = globals.inductives + ; section + ; tactics } + ::stack } let context_stack = Summary.ref ~name:"neural-learner-graph-cache" { stack = []; stack_size = 0 } @@ -300,13 +317,13 @@ let sync_context_stack add_global_context = let id = ref 0 in let remote_state = ref [] in let remote_stack_size = ref 0 in - fun ?(keep_cache=true) tacs env_extra env -> + fun ?(keep_cache=true) ?(force=false) env_extra env -> if debug_option () then Feedback.msg_notice Pp.( str "old remote stack : " ++ prlist_with_sep (fun () -> str "-") int !remote_state ++ fnl () ++ str "old local stack : " ++ prlist_with_sep (fun () -> str "-") (fun { id; _ } -> int id) !context_stack.stack); - let state, ({ stack_size; stack } as cache) = update_context_stack !id tacs env_extra env !context_stack in + let state, tactics, ({ stack_size; stack } as cache) = update_context_stack ~force !id env_extra env !context_stack in if keep_cache then context_stack := cache; if debug_option () then @@ -333,7 +350,7 @@ let sync_context_stack add_global_context = remote_stack_size := stack_size; if debug_option () then Feedback.msg_notice Pp.(str "new remote stack : " ++ prlist_with_sep (fun () -> str "-") int !remote_state); - state, stack_size + state, tactics, stack_size type capnp_connection = { rc : Unix.file_descr Capnp_unix.IO.ReadContext.t @@ -383,8 +400,8 @@ type connection = type communicator = { add_global_context : (Api.Builder.GlobalContextAddition.t -> unit) -> unit - ; sync_context_stack : ?keep_cache:bool -> (glob_tactic_expr * location) TacticMap.t -> env_extra -> Environ.env -> - CICGraphMonad.state * int + ; sync_context_stack : ?keep_cache:bool -> ?force:bool -> env_extra -> Environ.env -> + CICGraphMonad.state * (glob_tactic_expr * int) TacticMap.t * int ; request_prediction : (Api.Builder.PredictionRequest.t -> unit) -> (Graph_api.ro, Api.Reader.Prediction.t, Api.Reader.array_t) Capnp.Array.t ; request_text_prediction : (Api.Builder.PredictionRequest.t -> unit) -> @@ -675,29 +692,14 @@ module NeuralLearner : TacticianOnlineLearnerType = functor (TS : TacticianStruc preds type model = - { tactics : (glob_tactic_expr * int) TacticMap.t - ; proofs : env_extra } + { proofs : env_extra } - let last_model = Summary.ref ~name:"neural-learner-lastmodel" { tactics = TacticMap.empty - ; proofs = Id.Map.empty, Cmap.empty } + let last_model = Summary.ref ~name:"neural-learner-lastmodel" { proofs = Id.Map.empty, Cmap.empty } let empty () = - { tactics = TacticMap.empty; proofs = Id.Map.empty, Cmap.empty } - - let add_tactic_info env map tac = - let { base_tactic; args; _ } = analyze_tactic tac in - let params = List.length args in - if params >= 256 then map else - TacticMap.add - (Tactic_hash.tactic_hash env base_tactic) (base_tactic, params) map - - let learn { tactics; proofs = var_proofs, const_proofs } (kn, path, status) outcomes tac = - let tactics = match tac with - | None -> tactics - | Some tac -> - let tac = tactic_repr tac in - let tactics = add_tactic_info (Global.env ()) tactics tac in - tactics in + { proofs = Id.Map.empty, Cmap.empty } + + let learn { proofs = var_proofs, const_proofs } (kn, path, status) outcomes tac = (* TODO: Filtering out bad proof states: Occasionally, proof states refer to section variables that have been filtered out by Coq during section @@ -733,14 +735,6 @@ module NeuralLearner : TacticianOnlineLearnerType = functor (TS : TacticianStruc (TS.proof_state_hypotheses ps) ~init:status in status in - (* TODO: Ridiculous tactic filtering: *) - let tac = match tac with - | None -> None - | Some tac -> - let { base_tactic; args; _ } = analyze_tactic @@ tactic_repr tac in - let params = List.length args in - if params >= 256 then None else Some tac in - (* TODO: Drop-in shadowing replacement for mk_outcome. For now, we don't need the proof term and after states. We butcher them to make the payload smaller and faster to compute. *) let mk_outcome before result = @@ -758,21 +752,21 @@ module NeuralLearner : TacticianOnlineLearnerType = functor (TS : TacticianStruc (* TODO: This is not entirely correct: We always attach a proof to a constant, never to a section variable. No good solution for now. *) let const_proofs = Cmap.update constant update const_proofs in - let db = { tactics; proofs = var_proofs, const_proofs } in + let db = { proofs = var_proofs, const_proofs } in last_model := db; db - let predict { tactics; proofs } = + let predict { proofs } = let { add_global_context; sync_context_stack ; request_prediction; request_text_prediction; _ } = get_communicator () in let env = Global.env () in if not @@ textmode_option () then - let state, stack_size = - sync_context_stack ~keep_cache:false tactics proofs env in + let state, tacs, stack_size = + sync_context_stack ~keep_cache:false ~force:true proofs env in let find_global_argument = find_global_argument state in fun f -> if f = [] then IStream.empty else - let preds = predict request_prediction find_global_argument stack_size state tactics env + let preds = predict request_prediction find_global_argument stack_size state tacs env (List.hd f).state in let preds = List.map (fun (t, c) -> { confidence = c; focus = 0; tactic = tactic_make t }) preds in IStream.of_list preds @@ -792,9 +786,8 @@ module NeuralLearner : TacticianOnlineLearnerType = functor (TS : TacticianStruc let module Request = Api.Builder.PredictionProtocol.Request in let module Response = Api.Reader.PredictionProtocol.Response in let env = Global.env () in - let { tactics; proofs } = !last_model in let proofs = (!last_model).proofs in - let state, stack_size = sync_context_stack ~keep_cache:false tactics proofs env in + let state, tactics, stack_size = sync_context_stack ~keep_cache:false proofs env in let request = Request.init_root () in Request.check_alignment_set request; let unaligned_tacs, unaligned_defs = check_alignment () in @@ -836,7 +829,7 @@ module NeuralLearner : TacticianOnlineLearnerType = functor (TS : TacticianStruc (* We don't send the list of tactics, hence the empty list. Tactics are only sent right before prediction requests are made. *) let proofs = (!last_model).proofs in - let _, stack_size = sync_context_stack TacticMap.empty proofs (Global.env ()) in + let _, _, stack_size = sync_context_stack proofs (Global.env ()) in if debug_option () then Feedback.msg_notice Pp.(str "Cache stack size: " ++ int stack_size)