Files
clint-dataset/mini_lm.py
2026-04-08 18:58:15 +05:30

168 lines
4.9 KiB
Python

import json
from functools import partial
from pathlib import Path
import torch
from datasets import Dataset
from transformers import (
AutoModelForTokenClassification,
AutoTokenizer,
DataCollatorForTokenClassification,
Trainer,
TrainingArguments,
)
from enums import TokenLabel
MODEL_NAME = "microsoft/MiniLM-L12-H384-uncased"
MODEL_DIR = "./model"
# Label vocabulary: TokenLabel member names, in enum iteration order.
LABEL_LIST = [member.name for member in TokenLabel]
# Forward (name -> index) and reverse (index -> name) lookups over LABEL_LIST.
LABEL2ID = {name: index for index, name in enumerate(LABEL_LIST)}
ID2LABEL = dict(enumerate(LABEL_LIST))
# Enum value -> member name, so annotation labels written as values can be mapped back.
VALUE_TO_NAME = {member.value: member.name for member in TokenLabel}
def _normalize_label(label: str) -> str:
    """Map an annotation label to a TokenLabel member name.

    Accepts either a member name (returned unchanged) or a member value
    (translated through VALUE_TO_NAME).

    Raises:
        ValueError: If *label* is neither a known name nor a known value.
    """
    known_name = label in LABEL2ID
    if not known_name and label not in VALUE_TO_NAME:
        raise ValueError(f"Unknown label: {label}")
    return label if known_name else VALUE_TO_NAME[label]
def _extract_spans(task: dict) -> list[dict]:
annotations = task.get("annotations") or []
if not annotations:
return []
results = annotations[0].get("result", []) or []
spans = []
for result in results:
value = result.get("value", {})
start = value.get("start")
end = value.get("end")
labels = value.get("labels") or []
if start is None or end is None or not labels:
continue
spans.append(
{
"start": int(start),
"end": int(end),
"label": _normalize_label(labels[0]),
}
)
return spans
def parse_annotated_dataset(path: str) -> list[dict]:
    """Load a JSON array of annotated tasks into ``{"text", "spans"}`` records.

    Each task must provide ``task["data"]["text"]``; its spans come from
    ``_extract_spans``. The file is read as UTF-8.

    Raises:
        ValueError: If the top-level JSON value is not an array.
    """
    raw = Path(path).read_text(encoding="utf-8")
    tasks = json.loads(raw)
    if not isinstance(tasks, list):
        raise ValueError("Annotated dataset must be a JSON array.")
    return [
        {"text": task["data"]["text"], "spans": _extract_spans(task)}
        for task in tasks
    ]
def _label_for_offset(start: int, end: int, spans: list[dict]) -> str:
best_label = "RAW_PHRASE"
best_overlap = 0
for span in spans:
span_start = span["start"]
span_end = span["end"]
span_label = span["label"]
overlap = min(end, span_end) - max(start, span_start)
if overlap > best_overlap:
best_overlap = overlap
best_label = span_label
return best_label
def tokenize_and_align(example, tokenizer, label2id):
    """Tokenize ``example["text"]`` and attach one label id per token.

    Tokens whose character offsets are empty (``start == end``; typically
    tokenizer-added special tokens) get -100, PyTorch CrossEntropyLoss's
    default ``ignore_index``. Every other token, including each sub-word
    piece, gets the id of the span label it overlaps most
    (``_label_for_offset``). The offset mapping is dropped from the returned
    encoding.
    """
    encoding = tokenizer(
        example["text"],
        truncation=True,
        return_offsets_mapping=True,
    )
    encoding["labels"] = [
        -100
        if tok_start == tok_end
        else label2id[_label_for_offset(tok_start, tok_end, example["spans"])]
        for tok_start, tok_end in encoding["offset_mapping"]
    ]
    encoding.pop("offset_mapping")
    return encoding
def train(
    dataset_path: str = "./datasets/annotated/ffmpeg_gemini_v1.json",
    model_dir: str = MODEL_DIR,
    seed: int = 42,
):
    """Fine-tune MODEL_NAME for token classification and save it to *model_dir*.

    Args:
        dataset_path: JSON array of annotated tasks (read by parse_annotated_dataset).
        model_dir: Output directory for epoch checkpoints, the final model and
            the tokenizer.
        seed: Seed for the 90/10 train/eval split and for the Trainer, so
            repeated runs on the same data use the same split.

    Raises:
        ValueError: If the dataset has fewer than two examples. A train/eval
            split needs at least one example on each side.
    """
    # Parse and validate the data before the (slow) model download/load.
    examples = parse_annotated_dataset(dataset_path)
    if len(examples) < 2:
        raise ValueError(
            f"Need at least 2 annotated examples for a train/eval split, got {len(examples)}."
        )

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForTokenClassification.from_pretrained(
        MODEL_NAME, num_labels=len(LABEL_LIST), id2label=ID2LABEL, label2id=LABEL2ID
    )

    dataset = Dataset.from_list(examples)
    # Drop the raw "text"/"spans" columns; only tokenizer outputs + labels
    # should reach the data collator.
    dataset = dataset.map(
        partial(tokenize_and_align, tokenizer=tokenizer, label2id=LABEL2ID),
        remove_columns=dataset.column_names,
    )
    # Seeded so the held-out split is stable across runs.
    dataset = dataset.train_test_split(test_size=0.1, seed=seed)

    training_args = TrainingArguments(
        output_dir=model_dir,
        learning_rate=3e-5,
        per_device_train_batch_size=8,
        num_train_epochs=5,
        weight_decay=0.01,
        logging_steps=10,
        save_strategy="epoch",
        seed=seed,
    )
    data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
        data_collator=data_collator,
    )
    trainer.train()
    # No eval strategy is configured, so without this call the held-out split
    # would never be scored. (The kwarg name differs across transformers
    # versions: evaluation_strategy vs eval_strategy.)
    trainer.evaluate()
    trainer.save_model(model_dir)
    tokenizer.save_pretrained(model_dir)
def _load_model(model_dir: str = MODEL_DIR):
    """Load the tokenizer and token-classification model saved in *model_dir*.

    Returns:
        A ``(tokenizer, model)`` tuple; the tokenizer is loaded first.
    """
    return (
        AutoTokenizer.from_pretrained(model_dir),
        AutoModelForTokenClassification.from_pretrained(model_dir),
    )
def predict(text: str, model_dir: str = MODEL_DIR, top_k: int = 3):
    """Print the top-k label predictions for every token of *text*.

    Args:
        text: Raw input string.
        model_dir: Directory holding a model and tokenizer saved by ``train``.
        top_k: Number of labels shown per token, clamped to ``[1, num_labels]``.

    Prints one line per token (including any special tokens the tokenizer
    adds) in the form ``token LABEL:prob, LABEL:prob, ...``. Returns nothing.
    """
    tokenizer, model = _load_model(model_dir)
    # Truncate as training does; inputs longer than the model's maximum
    # length can otherwise fail inside the forward pass.
    inputs = tokenizer(text, return_tensors="pt", truncation=True)
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        logits = model(**inputs).logits[0]
    probs = torch.softmax(logits, dim=-1)
    k = min(max(1, top_k), probs.shape[-1])
    top_probs, top_ids = torch.topk(probs, k, dim=-1)
    # Prefer the label map saved with the model (it matches its output head);
    # fall back to this module's mapping if the config has none.
    id2label = getattr(model.config, "id2label", None) or ID2LABEL
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    for token, prob_row, id_row in zip(tokens, top_probs, top_ids):
        preds = [
            f"{id2label.get(int(label_id), 'RAW_PHRASE')}:{float(score):.4f}"
            for label_id, score in zip(id_row, prob_row)
        ]
        print(token, ", ".join(preds))
if __name__ == "__main__":
    # Script entry point: fine-tune with train()'s default dataset path and output directory.
    train()