-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel_wrapper_interface.py
56 lines (46 loc) · 2.16 KB
/
model_wrapper_interface.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from typing import List
import abc
class ModelWrapperInterface(metaclass=abc.ABCMeta):
    """Abstract interface that a model wrapper has to implement.

    Concrete wrappers either inherit from this class and override every
    abstract method, or are recognised structurally by ``issubclass`` /
    ``isinstance`` through ``__subclasshook__`` when they expose the
    required callables.
    """

    @classmethod
    def __subclasshook__(cls, subclass):
        """Treat ``subclass`` as a virtual subclass if it has the full API.

        Returns ``True`` only when *all* of the checked methods exist on
        ``subclass`` and are callable; otherwise ``NotImplemented`` so the
        regular ``issubclass`` machinery (real inheritance, ``register``)
        decides.
        """
        # Only the interface itself performs the structural check; without
        # this guard concrete subclasses would inherit the hook and accept
        # any unrelated class that merely has the same method names.
        if cls is not ModelWrapperInterface:
            return NotImplemented

        # Same set of methods the original check listed (``get_label_list``
        # is not part of the structural check).
        required_methods = (
            "predict",
            "clean_text",
            "texts_to_sequences",
            "sequences_to_texts",
            "texts_to_tokens",
            "extract_embedding",
        )
        # Every method must be present: the checks are combined with a
        # logical AND, so a class providing just one of them does not match.
        if all(
            callable(getattr(subclass, method_name, None))
            for method_name in required_methods
        ):
            return True
        return NotImplemented

    @abc.abstractmethod
    def get_label_list(self) -> List[int]:
        """Return the list of labels."""
        raise NotImplementedError

    @abc.abstractmethod
    def predict(self, input_texts: List[str]) -> list:
        """Output predictions given a list of input texts."""
        raise NotImplementedError

    @abc.abstractmethod
    def clean_text(self, input_texts: str) -> str:
        """Clean function for the input texts."""
        # Raise like every other abstract method instead of silently
        # returning None when reached through ``super()``.
        raise NotImplementedError

    @abc.abstractmethod
    def texts_to_sequences(self, input_texts: List[str]) -> List[List[int]]:
        """Convert a list of texts into a list of sequences of ids."""
        raise NotImplementedError

    @abc.abstractmethod
    def sequences_to_texts(self, sequences: List[List[int]]) -> List[List[str]]:
        """Convert a list of sequences of ids back into text.

        NOTE(review): the annotation promises a list of lists of strings
        while the original description said "a list of string" -- confirm
        which one implementations are expected to return.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def texts_to_tokens(self, input_texts: List[List[str]]) -> List[List[str]]:
        """Convert a list of texts into a list of lists of tokens.

        NOTE(review): ``input_texts`` is annotated as ``List[List[str]]``
        although the description speaks of a list of texts -- confirm the
        intended input type.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def extract_embedding(self, input_texts, batch_size, layers, layers_aggregation_function):
        """Extract an embedding for each input token or word.

        The implementation should be adapted to the specific model under
        analysis. The expected types of ``batch_size``, ``layers`` and
        ``layers_aggregation_function`` are not fixed by this interface.
        """
        raise NotImplementedError