Commit: bug fixes

mismayil committed May 5, 2023
1 parent 7dc2504 commit 186fa43

Showing 33 changed files with 1,593 additions and 293 deletions.
53 changes: 29 additions & 24 deletions app/server/inference.py
@@ -1,11 +1,16 @@
import spacy

from kogito.models.bart.comet import COMETBART
from kogito.models.gpt2.comet import COMETGPT2
from kogito.models.gpt2.zeroshot import GPT2Zeroshot

# from kogito.models.gpt2.comet import COMETGPT2
# from kogito.models.gpt2.zeroshot import GPT2Zeroshot
from kogito.inference import CommonsenseInference
from kogito.core.relation import KnowledgeRelation
from kogito.core.processors.relation import SWEMRelationMatcher, DistilBERTRelationMatcher, BERTRelationMatcher
from kogito.core.processors.relation import (
# SWEMRelationMatcher,
DistilBERTRelationMatcher,
# BERTRelationMatcher,
)
from kogito.linkers.deberta import DebertaLinker

MODEL_MAP = {
@@ -14,20 +19,21 @@
# "gpt2": GPT2Zeroshot("gpt2-xl")
}

LINKER_MAP = {
"deberta": DebertaLinker()
}
LINKER_MAP = {"deberta": DebertaLinker()}

PROCESSOR_MAP = {
# "swem_relation_matcher": SWEMRelationMatcher("swem_relation_matcher"),
"distilbert_relation_matcher": DistilBERTRelationMatcher("distilbert_relation_matcher"),
"distilbert_relation_matcher": DistilBERTRelationMatcher(
"distilbert_relation_matcher"
),
# "bert_relation_matcher": BERTRelationMatcher("bert_relation_matcher"),
}

nlp = spacy.load("en_core_web_sm")

print("Ready for inference.")


def infer(data):
text = data.get("text")
model = MODEL_MAP.get(data.get("model"))
@@ -50,7 +56,7 @@ def infer(data):

for proc in set(csi_rel_procs).difference(set(rel_procs)):
csi.remove_processor(proc)

for proc in set(head_procs).difference(set(csi_head_procs)):
csi.add_processor(PROCESSOR_MAP[proc])

@@ -63,21 +69,20 @@

linker = LINKER_MAP["deberta"]

output_graph = csi.infer(text=text,
model=model,
heads=heads,
relations=relations,
extract_heads=extract_heads,
match_relations=match_relations,
dry_run=dry_run,
context=context,
threshold=threshold,
linker=linker)

result = {
"text": [],
"graph": []
}
output_graph = csi.infer(
text=text,
model=model,
heads=heads,
relations=relations,
extract_heads=extract_heads,
match_relations=match_relations,
dry_run=dry_run,
context=context,
threshold=threshold,
linker=linker,
)

result = {"text": [], "graph": []}

if output_graph:
result["graph"] = [kg.to_json() for kg in output_graph]
@@ -86,4 +91,4 @@ def infer(data):
doc = nlp(text)
result["text"] = [token.lemma_.lower() for token in doc]

return result
return result
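
For reference, a minimal sketch of the kind of payload infer() expects. Only the "text" and "model" keys, the context/threshold handling, and the example strings from the eacl.ipynb notebook below are taken from the source; the remaining key names and the "bart" model key are assumptions, since MODEL_MAP's entries and part of the request parsing are collapsed above.

# Hypothetical payload for infer(); keys not visible in the diff are assumptions.
payload = {
    "text": "Student gets a library card",  # free text to extract heads from
    "model": "bart",                        # MODEL_MAP key (assumed)
    "heads": None,                          # optional explicit knowledge heads (assumed key)
    "relations": None,                      # optional relation subset (assumed key)
    "extract_heads": True,
    "match_relations": True,
    "dry_run": False,
    "context": "library",                   # context passed to the DeBERTa linker
    "threshold": 0.6,                       # relevance cut-off for filtering
}
result = infer(payload)  # -> {"text": [lemmas], "graph": [knowledge JSON]}
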
10 changes: 7 additions & 3 deletions app/server/server.py
@@ -8,10 +8,12 @@
app = Flask(__name__)
CORS(app)


@app.route("/")
def heartbeat():
return "Running"


@app.route("/inference", methods=["POST"])
def inference():
try:
@@ -20,9 +22,11 @@ def inference():
traceback.print_exc()
return str(e), 500


def main():
port = int(os.environ.get('PORT', 8080))
app.run(debug=os.environ.get('FLASK_DEBUG', False), host='0.0.0.0', port=port)
port = int(os.environ.get("PORT", 8080))
app.run(debug=os.environ.get("FLASK_DEBUG", False), host="0.0.0.0", port=port)


if __name__ == "__main__":
main()
main()
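
A minimal client-side sketch for exercising this endpoint; it assumes the collapsed body of inference() parses the POST body as JSON, hands it to infer(), and returns the result as JSON. Only the /inference route and the default port 8080 come from this diff, and the payload mirrors the hypothetical example above.

import requests  # third-party HTTP client, assumed to be available

resp = requests.post(
    "http://localhost:8080/inference",
    json={"text": "Student gets a library card", "model": "bart", "context": "library"},
    timeout=600,  # knowledge-model generation can take a while
)
resp.raise_for_status()
result = resp.json()
print(result.get("graph", [])[:3])  # first few generated knowledge tuples
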
2 changes: 1 addition & 1 deletion examples/docker/Dockerfile → docker/Dockerfile
@@ -9,7 +9,7 @@ ENV KOGITO_DIR=${HOME}/kogito
SHELL ["/bin/bash", "-cu"]

# Install dependencies
RUN apt-get update && apt-get install -y --allow-downgrades --allow-change-held-packages --no-install-recommends openssh-server vim wget unzip tmux
RUN apt-get update && apt-get install -y --allow-downgrades --allow-change-held-packages --no-install-recommends openssh-server vim wget unzip tmux git

# Set up SSH server
RUN mkdir /var/run/sshd
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
12 changes: 8 additions & 4 deletions examples/docker/train.py → docker/train.py
@@ -5,14 +5,18 @@
if __name__ == "__main__":
data_dir = os.environ.get("KOGITO_DATA_DIR")
model = COMETGPT2("gpt2-xl")
train_graph = KnowledgeGraph.from_csv(f"{data_dir}/atomic2020_data-feb2021/train.tsv", header=None, sep="\t")
val_graph = KnowledgeGraph.from_csv(f"{data_dir}/atomic2020_data-feb2021/dev.tsv", header=None, sep="\t")
train_graph = KnowledgeGraph.from_csv(
f"{data_dir}/atomic2020_data-feb2021/train.tsv", header=None, sep="\t"
)
val_graph = KnowledgeGraph.from_csv(
f"{data_dir}/atomic2020_data-feb2021/dev.tsv", header=None, sep="\t"
)
model.train(
train_graph=train_graph,
val_graph=val_graph,
batch_size=16,
output_dir="/scratch/mete/models/comet-gpt2",
log_wandb=True,
lr=5e-5,
epochs=1
)
epochs=1,
)
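
For completeness, a hedged sketch of one way this training script could be launched; the KOGITO_DATA_DIR variable and the atomic2020_data-feb2021/train.tsv and dev.tsv layout come from the code above, while the /data/kogito path and the call site are placeholders.

import os
import subprocess

# Point the script at a directory containing atomic2020_data-feb2021/train.tsv and dev.tsv.
env = dict(os.environ, KOGITO_DATA_DIR="/data/kogito")  # placeholder path
subprocess.run(["python", "docker/train.py"], env=env, check=True)
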
23 changes: 12 additions & 11 deletions docs/conf.py
@@ -12,14 +12,15 @@
#
import os
import sys
sys.path.insert(0, os.path.abspath('..'))

sys.path.insert(0, os.path.abspath(".."))


# -- Project information -----------------------------------------------------

project = 'kogito'
copyright = '2022, Mete Ismayil'
author = 'Mete Ismayil'
project = "kogito"
copyright = "2022, Mete Ismayil"
author = "Mete Ismayil"


# -- General configuration ---------------------------------------------------
@@ -31,32 +32,32 @@
"insegel",
"sphinx.ext.autodoc",
"sphinx.ext.coverage",
"sphinx.ext.napoleon"
"sphinx.ext.napoleon",
]

# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']
templates_path = ["_templates"]

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]


# -- Options for HTML output -------------------------------------------------

# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = 'insegel'
html_theme = "insegel"

# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']
html_static_path = ["_static"]

html_css_files = [
'css/custom.css',
"css/custom.css",
]

autodoc_member_order = 'bysource'
autodoc_member_order = "bysource"
Binary file added eacl2023/kogito-poster-eacl2023.pdf
Binary file added eacl2023/kogito-presentation-eacl2023.pdf
175 changes: 175 additions & 0 deletions examples/eacl.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from kogito.models.bart.comet import COMETBART\n",
"from kogito.inference import CommonsenseInference\n",
"\n",
"# Load pre-trained knowledge model from HuggingFace\n",
"model = COMETBART.from_pretrained()\n",
"\n",
"# Initialize inference module\n",
"csi = CommonsenseInference()\n",
"\n",
"# Run inference\n",
"text = \"Student gets a library card\"\n",
"context = \"library\"\n",
"kgraph = csi.infer(text, model, context=context)\n",
"\n",
"# Save output knowledge graph to JSON file\n",
"kgraph.to_jsonl(\"kgraph.jsonl\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from kogito.core.head import KnowledgeHead\n",
"from kogito.core.model import KnowledgeModel\n",
"from kogito.core.knowledge import Knowledge, KnowledgeGraph\n",
"from kogito.core.linker import KnowledgeLinker\n",
"from kogito.core.relation import DESIRES\n",
"from kogito.linkers.deberta import DeBERTaLinker\n",
"\n",
"# Knowledge representation\n",
"knowledge = Knowledge(head=KnowledgeHead(\"student\"),\n",
" relation=DESIRES,\n",
" tails=[\"get good grades\"])\n",
"input_graph = KnowledgeGraph([knowledge, ...])\n",
"input_graph = KnowledgeGraph.from_jsonl(\"kgraph.jsonl\")\n",
"\n",
"# Knowledge models and linkers\n",
"model: KnowledgeModel = COMETBART.from_pretrained()\n",
"linker: KnowledgeLinker = DeBERTaLinker()\n",
"\n",
"# Train, evaluate, predict and save model\n",
"model = model.train(input_graph, ...)\n",
"output_graph = model.generate(input_graph, ...)\n",
"metrics = model.evaluate(input_graph, ...)\n",
"model.save_pretrained(...)\n",
"\n",
"# Run inference\n",
"output_graph = csi.infer(text, model=model, linker=linker)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from typing import List, Tuple\n",
"from kogito.core.head import KnowledgeHead\n",
"from kogito.core.relation import KnowledgeRelation\n",
"\n",
"from kogito.core.processors.head import KnowledgeHeadExtractor\n",
"from kogito.core.processors.relation import KnowledgeRelationMatcher\n",
"\n",
"HeadRelationMatch = Tuple[KnowledgeHead, KnowledgeRelation]\n",
"\n",
"class CustomHeadExtractor(KnowledgeHeadExtractor):\n",
" def extract(self, text, doc) -> List[KnowledgeHead]:\n",
" \"\"\"your custom head extraction logic\"\"\"\n",
"\n",
"class CustomRelationMatcher(KnowledgeRelationMatcher):\n",
" def match(self, heads, relations) -> List[HeadRelationMatch]:\n",
" \"\"\"your custom relation matching logic\"\"\"\n",
"\n",
"class CustomKnowledgeModel(KnowledgeModel):\n",
" def train(self, train_graph, *args, **kwargs) -> KnowledgeModel:\n",
" \"\"\"your custom training logic\"\"\"\n",
"\n",
" def generate(self, input_graph, *args, **kwargs) -> KnowledgeGraph:\n",
" \"\"\"your custom inference logic\"\"\"\n",
"\n",
"class CustomKnowledgeLinker(KnowledgeLinker):\n",
" def link(self, input_graph, context) -> List[List[float]]:\n",
" \"\"\"your custom linking logic\"\"\"\n",
"\n",
"csi = CommonsenseInference()\n",
"custom_extractor = CustomHeadExtractor(\"custom_extractor\")\n",
"custom_matcher = CustomRelationMatcher(\"custom_matcher\")\n",
"custom_model = CustomKnowledgeModel()\n",
"custom_linker = CustomKnowledgeLinker()\n",
"csi.add_processor(custom_extractor)\n",
"csi.add_processor(custom_matcher)\n",
"csi.infer(text, model=custom_model, linker=custom_linker)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from kogito.inference import CommonsenseInference\n",
"from kogito.models.gpt3.zeroshot import GPT3Zeroshot\n",
"from kogito.core.relation import KnowledgeRelation, register_relation\n",
"\n",
"# Verbalizer for new relation\n",
"def x_wishes_verbalizer(head, **kwargs):\n",
" index = kwargs.get(\"index\")\n",
" index_txt = f\"{index}\" if index is not None else \"\"\n",
" return f\"Situation {index_txt}: {head}\\nWishes: As a result, PersonX wishes\"\n",
"\n",
"X_WISHES = KnowledgeRelation(\"xWishes\",\n",
" verbalizer=x_wishes_verbalizer,\n",
" prompt=\"How does this situation affect each character's wishes?\")\n",
"register_relation(X_WISHES)\n",
"\n",
"# Define sample graph for new relation\n",
"sample_graph = KnowledgeGraph.from_csv(\"sample_graph.tsv\", sep=\"\\t\")\n",
"# Sample graph should contain example knowledge inferences for the new relation\n",
"# Example sample graph:\n",
"# PersonX is at a party\t xWishes\t to drink beer and dance\n",
"# PersonX bleeds a lot\t xWishes\t to see a doctor\n",
"# PersonX works as a cashier\txWishes\t to be a store manager\n",
"# PersonX gets dirty\t xWishes\t to clean up\n",
"\n",
"# GPT-3 is few-shot prompted with the sample graph\n",
"model = GPT3Zeroshot(api_key=\"\", model_name=\"text-davinci-003\")\n",
"\n",
"csi = CommonsenseInference()\n",
"kgraph = csi.infer(text, model, sample_graph=sample_graph)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from kogito.linkers.deberta import DebertaLinker\n",
"from kogito.inference import CommonsenseInference\n",
"\n",
"csi = CommonsenseInference()\n",
"linker = DebertaLinker()\n",
"context = ...\n",
"input_graph = ...\n",
"\n",
"# Link input graph to context\n",
"relevance_probs = linker.link(input_graph, context)\n",
"\n",
"# Link and filter input graph based on relevancy to the context\n",
"filtered_graph = linker.filter(input_graph, context, threshold=0.6)\n",
"\n",
"# Generate inferences and filter based on relevancy to the context\n",
"kgraph = csi.infer(text, model, context=context, linker=linker, threshold=0.6)"
]
}
],
"metadata": {
"language_info": {
"name": "python"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
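
The third notebook cell leaves the custom-processor bodies as docstring placeholders. As one hedged illustration of what such a body could look like, restricted to the constructor and extract() signature shown in that cell and assuming KnowledgeHeadExtractor imposes no further abstract methods, a noun-chunk-based head extractor might read:

from typing import List

from kogito.core.head import KnowledgeHead
from kogito.core.processors.head import KnowledgeHeadExtractor

class NounChunkHeadExtractor(KnowledgeHeadExtractor):
    def extract(self, text, doc) -> List[KnowledgeHead]:
        # Treat each spaCy noun chunk as a candidate knowledge head,
        # falling back to the whole text if no chunks are found.
        heads = [KnowledgeHead(chunk.text) for chunk in doc.noun_chunks]
        return heads or [KnowledgeHead(text)]

# csi.add_processor(NounChunkHeadExtractor("noun_chunk_extractor"))  # mirrors the cell's add_processor usage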