Skip to content

Commit

Permalink
token labels
Browse files Browse the repository at this point in the history
  • Loading branch information
Damien Sileo committed May 4, 2023
1 parent e5908d9 commit de7dfb9
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/tasknet/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def __init__(self, tasks, args, warm_start=None):
if task.task_type=='MultipleChoice':
key=task.task_type
else:
labels = getattr(task.dataset["train"].features[task.y],"names",None)
labels = task.get_labels()#getattr(task.dataset["train"].features[task.y],"names",None)
key= tuple([normalize_label(x) for x in labels]) if labels else None
#key = key if task.num_labels!=2 or key else "binary"

Expand Down
10 changes: 8 additions & 2 deletions src/tasknet/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,15 +284,21 @@ def __post_init__(self):
target = self.dataset[self.main_split].features[self.y]
if not self.num_labels:
self.num_labels = 1 if "float" in target.dtype else target.feature.num_classes
self.label_names = [f"{i}" for i in range(self.num_labels)]

try:
self.label_names=target.feature.names
except:
self.label_names = [f"{i}" for i in range(self.num_labels)]
def get_labels(self):
    """Return the human-readable label names for this task.

    The names are the ``label_names`` list populated during task
    initialization (feature ``names`` when available, otherwise
    stringified indices).
    """
    names = self.label_names
    return names
def set_tokenizer(self, tokenizer):
    """Attach *tokenizer* to the task and build a token-classification
    data collator bound to it.

    Side effects: sets ``self.tokenizer`` and ``self.data_collator``,
    and mutates the tokenizer's ``add_prefix_space`` attribute.
    """
    self.tokenizer = tokenizer
    # NOTE(review): assumes assigning add_prefix_space after construction
    # is honored by the tokenizer — for fast tokenizers this flag is
    # normally a constructor argument; confirm it takes effect.
    tokenizer.add_prefix_space = True
    # Collator pads token/label sequences per batch for token classification.
    self.data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)



def preprocess_function(self, examples):
if examples[self.tokens] and type(examples[self.tokens][0])==str:
unsqueeze, examples= True, wrap_examples(examples)
Expand Down
6 changes: 4 additions & 2 deletions src/tasknet/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,15 +134,17 @@ def search_module(m,name, mode='attr', lowercase=True):
raise ValueError('mode must be "attr" or "class"')


def load_pipeline(model_name, task_name, adapt_task_embedding=True):
def load_pipeline(model_name, task_name, adapt_task_embedding=True,multilingual=False):
if multilingual or 'mdeberta' in model_name:
multilingual=True

from transformers import AutoModelForSequenceClassification, TextClassificationPipeline, AutoTokenizer
from .models import Adapter
try:
import tasksource
except:
raise ImportError('Requires tasksource.\n pip install tasksource')
task = tasksource.load_task(task_name)
task = tasksource.load_task(task_name,multilingual=multilingual)

model = AutoModelForSequenceClassification.from_pretrained(model_name,ignore_mismatched_sizes=True)
adapter = Adapter.from_pretrained(model_name.replace('-nli','')+'-adapters')
Expand Down

0 comments on commit de7dfb9

Please sign in to comment.