[DRAFT] Initial model training #25

389 changes: 389 additions & 0 deletions pytact/notebooks/explore_dgl_lree_lstm.ipynb

Large diffs are not rendered by default.

283 changes: 283 additions & 0 deletions pytact/notebooks/model_training_v1.ipynb
@@ -0,0 +1,283 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "4b3e05b4",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"kj/filesystem-disk-unix.c++:1703: warning: PWD environment variable doesn't match current directory; pwd = /root\n"
]
}
],
"source": [
"# Load the dataset into PyTactician's visualizer.\n",
"from pytact import data_reader, graph_visualize_browse\n",
"import pathlib\n",
"from typing import Optional, List, DefaultDict\n",
"from pytact.data_reader import Node\n",
"from pytact.graph_api_capnp_cython import EdgeClassification\n",
"from pytact.graph_api_capnp_cython import Graph_Node_Label_Which\n",
"from collections import defaultdict\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"import torch.nn.functional as F\n",
"import random\n",
"import numpy as np\n",
"from sklearn.metrics import classification_report"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "3150d1b2",
"metadata": {},
"outputs": [],
"source": [
"class BasicCSRNN(nn.Module):\n",
" def __init__(self, embedding_size, hidden_size, nodes_number, edges_number): #Nodes_number is an input tokens size\n",
" super(BasicCSRNN, self).__init__()\n",
" self.embedding = nn.Embedding(nodes_number, embedding_size)\n",
" self.Wx = torch.randn(embedding_size, hidden_size) # n_inputs X n_neurons\n",
" self.We = torch.randn(edges_number, hidden_size, hidden_size) # n_edges X 1 X n_neurons\n",
" self.hidden_size = hidden_size\n",
" self.b = torch.zeros(1, hidden_size) # 1 X n_neurons\n",
"\n",
" def forward(self, node):\n",
" return self.node_forward(node)\n",
"\n",
" def node_forward(self, node):\n",
" emb = self.embedding(torch.tensor(node.label.which.value))\n",
" emb = emb.view(1, -1)\n",
" x = torch.mm(emb, self.Wx)\n",
" if node.children and not node.label.which.name == 'REL':\n",
" hidden = torch.mean(torch.stack([x + torch.mm(self.node_forward(child), self.We[edge_type.value]) for edge_type, child in list(node.children)]), dim=0) \n",
" else:\n",
" # Ensure that the zero tensor is of the correct shape [batch size, hidden size]\n",
" hidden = torch.zeros(1, self.hidden_size, dtype=torch.float, device=x.device)\n",
" return torch.tanh(x + hidden + self.b)\n",
" \n",
" \n",
"class RNNLabelDecode(nn.Module):\n",
" def __init__(self, hidden_size, output_size, edges_number):\n",
" super(RNNLabelDecode, self).__init__()\n",
" self.hidden_size = hidden_size\n",
" self.We = nn.Parameter(torch.randn(edges_number, hidden_size, hidden_size))\n",
" self.Wdc = nn.Linear(hidden_size, output_size, bias=True)\n",
" self.be = torch.zeros(1, hidden_size) # 1 X n_neurons\n",
" \n",
" # Keep track of edges if needed\n",
" self.decoded_edges = []\n",
" def forward(self, embedding, node, max_depth):\n",
" self.decoded_edges = []\n",
" self.node_decode_forward(embedding, node, depth=1, max_depth=max_depth)\n",
" return self.decoded_edges\n",
" \n",
" def node_decode_forward(self, embedding, node, depth, max_depth):\n",
" # Decode label \n",
" logits = self.Wdc(embedding)\n",
" probabilities = F.softmax(logits)\n",
" self.decoded_edges.append(probabilities)\n",
" if node.children and not node.label.which.name == 'REL' and depth < max_depth:\n",
" for edge_type, child in node.children: \n",
" new_embedding = torch.mm(embedding, self.We[edge_type.value]) + self.be #Calculate new hidden state\n",
" self.node_decode_forward(new_embedding, child, depth=depth+1, max_depth=max_depth) # Decode child\n",
" \n",
" \n",
"class DecoderRNNClasifier(nn.Module):\n",
" def __init__(self, embedding_size, hidden_size, nodes_number, edges_number):\n",
" super(DecoderRNNClasifier, self).__init__() \n",
" self.dec = RNNLabelDecode(hidden_size, nodes_number, edges_number)\n",
" self.enc = BasicCSRNN(embedding_size, hidden_size, nodes_number, edges_number) \n",
" \n",
" def forward(self, node, max_depth): \n",
" emb = self.enc(node)\n",
" dec = self.dec(emb, node, max_depth)\n",
" \n",
" return dec\n",
"\n",
"\n",
"class LabelGetter: \n",
" def __init__(self): \n",
" self.labels = []\n",
" def get_labels(self, graph, max_depth):\n",
" self.labels = []\n",
" self.get_labels_helper(graph, 1, max_depth)\n",
" return self.labels\n",
" def get_labels_helper(self, graph, depth, max_depth):\n",
" self.labels.append(graph.label.which.value)\n",
" if graph.children and not graph.label.which.name == 'REL' and depth < max_depth: \n",
" for _, child in list(graph.children):\n",
" self.get_labels_helper(child, depth+1, max_depth)\n",
" \n",
"def get_file_size(reader, dataset_pointer): \n",
" pdl = dataset_pointer.lowlevel\n",
" size = len(pdl.graph.nodes)\n",
" return size"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "aa4eda44",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 3,
"id": "82e5c6ad",
"metadata": {},
"outputs": [],
"source": [
"# Constants and configurations\n",
"DATASET_PATH = '../../../../v15-stdlib-coq8.11/dataset'\n",
"FILE_PATH = \"coq-tactician-stdlib.8.11.dev/theories/Init/Logic.bin\"\n",
"DATASET_PATH = pathlib.Path(DATASET_PATH)\n",
"FILE_PATH = pathlib.Path(FILE_PATH)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "1f62ece0",
"metadata": {},
"outputs": [],
"source": [
"# Randomness\n",
"RANDOM_SEED = 42\n",
"random.seed(RANDOM_SEED)\n",
"torch.manual_seed(RANDOM_SEED)\n",
"\n",
"# Model Parameters \n",
"NODES_NUMBER = 30\n",
"EMBEDDING_SIZE = 8\n",
"HIDDEN_SIZE = 16\n",
"EDGES_NUMBER = 50\n",
"\n",
"# Model Introduction\n",
"model = DecoderRNNClasifier(EMBEDDING_SIZE, HIDDEN_SIZE, NODES_NUMBER, EDGES_NUMBER)\n",
"lg = LabelGetter() #graph node_labels extractor\n",
"\n",
"# Model Training Details\n",
"LEARNING_RATE = 0.001\n",
"BATCH_SIZE = 20\n",
"MAX_DECODING_DEPTH = 3\n",
"EPOCHS = 3\n",
"criterion = nn.CrossEntropyLoss()\n",
"optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "478f003e",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/tmp/ipykernel_150483/3908839096.py:43: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n",
" probabilities = F.softmax(logits)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Max decoding depth: 1, Epoch 1/3, Training Loss: 3.052128731485556, TrainingAccuracy: 67.51%, Test Accuracy: 89.13%\n",
"Max decoding depth: 1, Epoch 2/3, Training Loss: 2.6578493127871217, TrainingAccuracy: 91.94%, Test Accuracy: 92.03%\n",
"Max decoding depth: 1, Epoch 3/3, Training Loss: 2.562999010732011, TrainingAccuracy: 93.79%, Test Accuracy: 94.66%\n",
"Max decoding depth: 2, Epoch 1/3, Training Loss: 1.1934545814126256, TrainingAccuracy: 48.01%, Test Accuracy: 60.79%\n",
"Max decoding depth: 2, Epoch 2/3, Training Loss: 1.1457983314641975, TrainingAccuracy: 66.84%, Test Accuracy: 71.08%\n",
"Max decoding depth: 2, Epoch 3/3, Training Loss: 1.132676905871813, TrainingAccuracy: 76.18%, Test Accuracy: 77.36%\n",
"Max decoding depth: 3, Epoch 1/3, Training Loss: 0.8206769919992821, TrainingAccuracy: 47.33%, Test Accuracy: 48.19%\n",
"Max decoding depth: 3, Epoch 2/3, Training Loss: 0.8182403788558541, TrainingAccuracy: 49.05%, Test Accuracy: 51.46%\n",
"Max decoding depth: 3, Epoch 3/3, Training Loss: 0.8166042187206709, TrainingAccuracy: 51.80%, Test Accuracy: 52.30%\n"
]
}
],
"source": [
"with data_reader.data_reader(DATASET_PATH) as reader:\n",
" dataset_pointer = reader[FILE_PATH] \n",
" grpahs_number = get_file_size(reader, dataset_pointer)\n",
" shuffled_indexes = list(range(grpahs_number)) # change indexes to random_shuffle\n",
" random.shuffle(shuffled_indexes)\n",
" train_indexes = shuffled_indexes[:grpahs_number*7//10]\n",
" test_indexes = shuffled_indexes[grpahs_number*7//10:]\n",
" for max_depth in range(1, MAX_DECODING_DEPTH+1):\n",
" for epoch in range(EPOCHS):\n",
" # Training Loop\n",
" correct = 0\n",
" total = 0\n",
" total_loss = 0\n",
" for i in train_indexes:\n",
" graph = dataset_pointer.node_by_id(i)\n",
" labels = lg.get_labels(graph, max_depth)\n",
" optimizer.zero_grad()\n",
" output_whole = model(graph, max_depth=max_depth)\n",
" loss = criterion(torch.stack(output_whole).squeeze(1), torch.tensor(labels))/len(labels)\n",
" loss.backward()\n",
" if (i + 1) % BATCH_SIZE == 0:\n",
" optimizer.step()\n",
" optimizer.zero_grad()\n",
" total_loss += loss.item()\n",
" predictions = torch.argmax(torch.stack(output_whole).squeeze(1), dim=1)\n",
" correct += (predictions == torch.tensor(labels)).sum().item()\n",
" total += len(labels)\n",
" trainig_accuracy = correct / total if total > 0 else 0\n",
" \n",
" # Testing Loop\n",
" correct = 0\n",
" total = 0\n",
" for i in test_indexes:\n",
" graph = dataset_pointer.node_by_id(i)\n",
" labels = lg.get_labels(graph, max_depth)\n",
" with torch.no_grad():\n",
" output_whole = model(graph, max_depth=max_depth)\n",
" predictions = torch.argmax(torch.stack(output_whole).squeeze(1), dim=1)\n",
" correct += (predictions == torch.tensor(labels)).sum().item()\n",
" total += len(labels)\n",
" \n",
" accuracy = correct / total if total > 0 else 0\n",
" print(f'Max decoding depth: {max_depth}, Epoch {epoch+1}/{EPOCHS}, Training Loss: {total_loss / len(train_indexes)}, TrainingAccuracy: {trainig_accuracy* 100:.2f}%, Test Accuracy: {accuracy * 100:.2f}%')"
]
},
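{
"cell_type": "code",
"execution_count": null,
"id": "9f3e2c1a",
"metadata": {},
"outputs": [],
"source": [
"# A hedged evaluation sketch (not executed here): per-label precision/recall on the held-out\n",
"# node ids via sklearn's classification_report, which is imported above but unused so far.\n",
"# It assumes `model`, `lg`, `test_indexes`, and the constants from the previous cells, and\n",
"# reopens the reader since the earlier `with` block has closed it.\n",
"all_preds, all_golds = [], []\n",
"with data_reader.data_reader(DATASET_PATH) as reader:\n",
"    dataset_pointer = reader[FILE_PATH]\n",
"    for i in test_indexes:\n",
"        graph = dataset_pointer.node_by_id(i)\n",
"        labels = lg.get_labels(graph, MAX_DECODING_DEPTH)\n",
"        with torch.no_grad():\n",
"            output_whole = model(graph, max_depth=MAX_DECODING_DEPTH)\n",
"        predictions = torch.argmax(torch.stack(output_whole).squeeze(1), dim=1)\n",
"        all_preds.extend(predictions.tolist())\n",
"        all_golds.extend(labels)\n",
"print(classification_report(all_golds, all_preds, zero_division=0))"
]
},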
{
"cell_type": "code",
"execution_count": null,
"id": "1de41da5",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "tactician",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.11"
}
},
"nbformat": 4,
"nbformat_minor": 5
}