"""Token-classification training and inference utilities built on Hugging Face Transformers."""

import json
import re
from functools import partial
from pathlib import Path

import torch
from datasets import Dataset
from transformers import (
    AutoModelForTokenClassification,
    AutoTokenizer,
    DataCollatorForTokenClassification,
    Trainer,
    TrainingArguments,
)

from enums import TokenLabel

MODEL_NAME = "microsoft/MiniLM-L12-H384-uncased"
MODEL_DIR = "./model"

# The label vocabulary comes from the TokenLabel enum: enum names are the model's
# class labels, and enum values map back to names when normalizing annotations.
LABEL_LIST = [label.name for label in TokenLabel]
LABEL2ID = {label: i for i, label in enumerate(LABEL_LIST)}
ID2LABEL = {i: label for i, label in enumerate(LABEL_LIST)}
VALUE_TO_NAME = {label.value: label.name for label in TokenLabel}


def _normalize_label(label: str) -> str:
    """Map an annotation label (enum name or enum value) to its TokenLabel name."""
    if label in LABEL2ID:
        return label
    if label in VALUE_TO_NAME:
        return VALUE_TO_NAME[label]
    raise ValueError(f"Unknown label: {label}")


def _text_to_tokens(text: str) -> tuple[list[str], list[tuple[int, int]]]:
    """Split text on whitespace, returning the tokens and their character spans."""
    tokens = []
    spans = []
    for match in re.finditer(r"\S+", text):
        tokens.append(match.group(0))
        spans.append(match.span())
    return tokens, spans


def _extract_spans(task: dict) -> list[tuple[int, int, str]]:
    """Return (start, end, label) character spans from the first annotation of a task."""
    annotations = task.get("annotations") or []
    if not annotations:
        return []
    results = annotations[0].get("result", []) or []
    spans = []
    for result in results:
        value = result.get("value", {})
        start = value.get("start")
        end = value.get("end")
        labels = value.get("labels") or []
        if start is None or end is None or not labels:
            continue
        spans.append((int(start), int(end), _normalize_label(labels[0])))
    return spans

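# Minimal sketch of the task structure that _extract_spans() above and
# parse_annotated_dataset() below expect (a Label Studio-style export). Only the
# fields the code actually reads are shown; the text, offsets, and label are
# placeholders, and the label must be a TokenLabel name or value:
#
# {
#     "data": {"text": "ffmpeg -i input.mp4 output.avi"},
#     "annotations": [
#         {"result": [{"value": {"start": 0, "end": 6, "labels": ["<TokenLabel>"]}}]}
#     ]
# }
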
def parse_annotated_dataset(path: str) -> list[dict]:
    """Load an annotated JSON export and convert it to token/label examples.

    Each whitespace token is assigned the label of the first annotated span it
    overlaps; tokens outside every span fall back to "RAW_PHRASE".
    """
    data = json.loads(Path(path).read_text(encoding="utf-8"))
    if not isinstance(data, list):
        raise ValueError("Annotated dataset must be a JSON array.")

    examples = []
    for task in data:
        text = task["data"]["text"]
        spans = _extract_spans(task)
        tokens, token_spans = _text_to_tokens(text)
        labels = []
        for token_start, token_end in token_spans:
            assigned = "RAW_PHRASE"
            for span_start, span_end, span_label in spans:
                # Any character overlap between the token and the span counts.
                if token_start < span_end and token_end > span_start:
                    assigned = span_label
                    break
            labels.append(assigned)
        examples.append({"tokens": tokens, "labels": labels})

    return examples

def tokenize_and_align(example, tokenizer, label2id):
    """Tokenize pre-split words and align word labels to the resulting sub-tokens.

    Only the first sub-token of each word keeps the word's label; special tokens
    and continuation sub-tokens get -100 so the loss ignores them.
    """
    tokenized = tokenizer(example["tokens"], is_split_into_words=True, truncation=True)
    word_ids = tokenized.word_ids()
    labels = []
    prev = None

    for word_id in word_ids:
        if word_id is None:
            labels.append(-100)
        elif word_id != prev:
            labels.append(label2id[example["labels"][word_id]])
        else:
            labels.append(-100)
        prev = word_id

    tokenized["labels"] = labels
    return tokenized

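# Illustrative alignment (subword splits depend on the tokenizer, so this is a
# sketch rather than exact output):
#
#   words:      ffmpeg        -codecs
#   subwords:   [CLS] ffmpeg  -      ##codecs  [SEP]
#   label ids:  -100  L(w0)   L(w1)  -100      -100
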
def train(dataset_path: str = "./datasets/annotated/ffmpeg_gemini_v1.json", model_dir: str = MODEL_DIR):
    """Fine-tune the base model on the annotated dataset and save it to model_dir."""
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForTokenClassification.from_pretrained(
        MODEL_NAME, num_labels=len(LABEL_LIST), id2label=ID2LABEL, label2id=LABEL2ID
    )

    examples = parse_annotated_dataset(dataset_path)
    dataset = Dataset.from_list(examples)
    dataset = dataset.map(partial(tokenize_and_align, tokenizer=tokenizer, label2id=LABEL2ID))
    dataset = dataset.train_test_split(test_size=0.1)

    training_args = TrainingArguments(
        output_dir=model_dir,
        learning_rate=3e-5,
        per_device_train_batch_size=8,
        num_train_epochs=5,
        weight_decay=0.01,
        logging_steps=10,
        save_strategy="epoch",
    )

    data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
        data_collator=data_collator,
    )

    trainer.train()
    trainer.save_model(model_dir)
    tokenizer.save_pretrained(model_dir)

def _load_model(model_dir: str = MODEL_DIR):
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForTokenClassification.from_pretrained(model_dir)
    return tokenizer, model


def predict(text: str, model_dir: str = MODEL_DIR, top_k: int = 3):
    """Print the top-k label probabilities for every sub-token of a whitespace-split text."""
    tokenizer, model = _load_model(model_dir)
    inputs = tokenizer(text.split(), return_tensors="pt", is_split_into_words=True)
    with torch.no_grad():  # inference only; no gradients needed
        outputs = model(**inputs)

    logits = outputs.logits[0]
    probs = torch.softmax(logits, dim=-1)
    # Clamp top_k to the number of classes.
    k = min(max(1, top_k), probs.shape[-1])
    top_probs, top_ids = torch.topk(probs, k, dim=-1)
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

    for token, prob_row, id_row in zip(tokens, top_probs, top_ids):
        preds = [
            f"{ID2LABEL.get(int(label_id), 'RAW_PHRASE')}:{float(score):.4f}"
            for label_id, score in zip(id_row, prob_row)
        ]
        print(token, ", ".join(preds))

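# Example use of predict() (illustrative; the command string is a placeholder):
#
#   predict("ffmpeg -i input.mp4 -vf scale=1280:720 output.mp4", top_k=2)
#
# prints one line per sub-token (including [CLS]/[SEP]) with its top-k
# "LABEL:probability" pairs.
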
if __name__ == "__main__":
    train()