-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
33 changed files
with
1,593 additions
and
293 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,175 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from kogito.models.bart.comet import COMETBART\n", | ||
"from kogito.inference import CommonsenseInference\n", | ||
"\n", | ||
"# Load pre-trained knowledge model from HuggingFace\n", | ||
"model = COMETBART.from_pretrained()\n", | ||
"\n", | ||
"# Initialize inference module\n", | ||
"csi = CommonsenseInference()\n", | ||
"\n", | ||
"# Run inference\n", | ||
"text = \"Student gets a library card\"\n", | ||
"context = \"library\"\n", | ||
"kgraph = csi.infer(text, model, context=context)\n", | ||
"\n", | ||
"# Save output knowledge graph to JSON file\n", | ||
"kgraph.to_jsonl(\"kgraph.jsonl\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from kogito.core.head import KnowledgeHead\n", | ||
"from kogito.core.model import KnowledgeModel\n", | ||
"from kogito.core.knowledge import Knowledge, KnowledgeGraph\n", | ||
"from kogito.core.linker import KnowledgeLinker\n", | ||
"from kogito.core.relation import DESIRES\n", | ||
"from kogito.linkers.deberta import DeBERTaLinker\n", | ||
"\n", | ||
"# Knowledge representation\n", | ||
"knowledge = Knowledge(head=KnowledgeHead(\"student\"),\n", | ||
" relation=DESIRES,\n", | ||
" tails=[\"get good grades\"])\n", | ||
"input_graph = KnowledgeGraph([knowledge, ...])\n", | ||
"input_graph = KnowledgeGraph.from_jsonl(\"kgraph.jsonl\")\n", | ||
"\n", | ||
"# Knowledge models and linkers\n", | ||
"model: KnowledgeModel = COMETBART.from_pretrained()\n", | ||
"linker: KnowledgeLinker = DeBERTaLinker()\n", | ||
"\n", | ||
"# Train, evaluate, predict and save model\n", | ||
"model = model.train(input_graph, ...)\n", | ||
"output_graph = model.generate(input_graph, ...)\n", | ||
"metrics = model.evaluate(input_graph, ...)\n", | ||
"model.save_pretrained(...)\n", | ||
"\n", | ||
"# Run inference\n", | ||
"output_graph = csi.infer(text, model=model, linker=linker)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from typing import List, Tuple\n", | ||
"from kogito.core.head import KnowledgeHead\n", | ||
"from kogito.core.relation import KnowledgeRelation\n", | ||
"\n", | ||
"from kogito.core.processors.head import KnowledgeHeadExtractor\n", | ||
"from kogito.core.processors.relation import KnowledgeRelationMatcher\n", | ||
"\n", | ||
"HeadRelationMatch = Tuple[KnowledgeHead, KnowledgeRelation]\n", | ||
"\n", | ||
"class CustomHeadExtractor(KnowledgeHeadExtractor):\n", | ||
" def extract(self, text, doc) -> List[KnowledgeHead]:\n", | ||
" \"\"\"your custom head extraction logic\"\"\"\n", | ||
"\n", | ||
"class CustomRelationMatcher(KnowledgeRelationMatcher):\n", | ||
" def match(self, heads, relations) -> List[HeadRelationMatch]:\n", | ||
" \"\"\"your custom relation matching logic\"\"\"\n", | ||
"\n", | ||
"class CustomKnowledgeModel(KnowledgeModel):\n", | ||
" def train(self, train_graph, *args, **kwargs) -> KnowledgeModel:\n", | ||
" \"\"\"your custom training logic\"\"\"\n", | ||
"\n", | ||
" def generate(self, input_graph, *args, **kwargs) -> KnowledgeGraph:\n", | ||
" \"\"\"your custom inference logic\"\"\"\n", | ||
"\n", | ||
"class CustomKnowledgeLinker(KnowledgeLinker):\n", | ||
" def link(self, input_graph, context) -> List[List[float]]:\n", | ||
" \"\"\"your custom linking logic\"\"\"\n", | ||
"\n", | ||
"csi = CommonsenseInference()\n", | ||
"custom_extractor = CustomHeadExtractor(\"custom_extractor\")\n", | ||
"custom_matcher = CustomRelationMatcher(\"custom_matcher\")\n", | ||
"custom_model = CustomKnowledgeModel()\n", | ||
"custom_linker = CustomKnowledgeLinker()\n", | ||
"csi.add_processor(custom_extractor)\n", | ||
"csi.add_processor(custom_matcher)\n", | ||
"csi.infer(text, model=custom_model, linker=custom_linker)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from kogito.inference import CommonsenseInference\n", | ||
"from kogito.models.gpt3.zeroshot import GPT3Zeroshot\n", | ||
"from kogito.core.relation import KnowledgeRelation, register_relation\n", | ||
"\n", | ||
"# Verbalizer for new relation\n", | ||
"def x_wishes_verbalizer(head, **kwargs):\n", | ||
" index = kwargs.get(\"index\")\n", | ||
" index_txt = f\"{index}\" if index is not None else \"\"\n", | ||
" return f\"Situation {index_txt}: {head}\\nWishes: As a result, PersonX wishes\"\n", | ||
"\n", | ||
"X_WISHES = KnowledgeRelation(\"xWishes\",\n", | ||
" verbalizer=x_wishes_verbalizer,\n", | ||
" prompt=\"How does this situation affect each character's wishes?\")\n", | ||
"register_relation(X_WISHES)\n", | ||
"\n", | ||
"# Define sample graph for new relation\n", | ||
"sample_graph = KnowledgeGraph.from_csv(\"sample_graph.tsv\", sep=\"\\t\")\n", | ||
"# Sample graph should contain example knowledge inferences for the new relation\n", | ||
"# Example sample graph:\n", | ||
"# PersonX is at a party\t xWishes\t to drink beer and dance\n", | ||
"# PersonX bleeds a lot\t xWishes\t to see a doctor\n", | ||
"# PersonX works as a cashier\txWishes\t to be a store manager\n", | ||
"# PersonX gets dirty\t xWishes\t to clean up\n", | ||
"\n", | ||
"# GPT-3 is few-shot prompted with the sample graph\n", | ||
"model = GPT3Zeroshot(api_key=\"\", model_name=\"text-davinci-003\")\n", | ||
"\n", | ||
"csi = CommonsenseInference()\n", | ||
"kgraph = csi.infer(text, model, sample_graph=sample_graph)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from kogito.linkers.deberta import DebertaLinker\n", | ||
"from kogito.inference import CommonsenseInference\n", | ||
"\n", | ||
"csi = CommonsenseInference()\n", | ||
"linker = DebertaLinker()\n", | ||
"context = ...\n", | ||
"input_graph = ...\n", | ||
"\n", | ||
"# Link input graph to context\n", | ||
"relevance_probs = linker.link(input_graph, context)\n", | ||
"\n", | ||
"# Link and filter input graph based on relevancy to the context\n", | ||
"filtered_graph = linker.filter(input_graph, context, threshold=0.6)\n", | ||
"\n", | ||
"# Generate inferences and filter based on relevancy to the context\n", | ||
"kgraph = csi.infer(text, model, context=context, linker=linker, threshold=0.6)" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"language_info": { | ||
"name": "python" | ||
}, | ||
"orig_nbformat": 4 | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
Oops, something went wrong.