This commit is contained in:
2026-04-08 18:58:15 +05:30
parent d227f37cc2
commit fe7b159ff6
3 changed files with 38 additions and 34 deletions

View File

@@ -1,5 +1,4 @@
import json
import re
from functools import partial
from pathlib import Path
@@ -32,16 +31,7 @@ def _normalize_label(label: str) -> str:
raise ValueError(f"Unknown label: {label}")
def _text_to_tokens(text: str) -> tuple[list[str], list[tuple[int, int]]]:
tokens = []
spans = []
for match in re.finditer(r"\S+", text):
tokens.append(match.group(0))
spans.append(match.span())
return tokens, spans
def _extract_spans(task: dict) -> list[tuple[int, int, str]]:
def _extract_spans(task: dict) -> list[dict]:
annotations = task.get("annotations") or []
if not annotations:
return []
@@ -54,7 +44,13 @@ def _extract_spans(task: dict) -> list[tuple[int, int, str]]:
labels = value.get("labels") or []
if start is None or end is None or not labels:
continue
spans.append((int(start), int(end), _normalize_label(labels[0])))
spans.append(
{
"start": int(start),
"end": int(end),
"label": _normalize_label(labels[0]),
}
)
return spans
@@ -67,36 +63,42 @@ def parse_annotated_dataset(path: str) -> list[dict]:
for task in data:
text = task["data"]["text"]
spans = _extract_spans(task)
tokens, token_spans = _text_to_tokens(text)
labels = []
for token_start, token_end in token_spans:
assigned = "RAW_PHRASE"
for span_start, span_end, span_label in spans:
if token_start < span_end and token_end > span_start:
assigned = span_label
break
labels.append(assigned)
examples.append({"tokens": tokens, "labels": labels})
examples.append({"text": text, "spans": spans})
return examples
def tokenize_and_align(example, tokenizer, label2id):
tokenized = tokenizer(example["tokens"], is_split_into_words=True, truncation=True)
word_ids = tokenized.word_ids()
labels = []
prev = None
def _label_for_offset(start: int, end: int, spans: list[dict]) -> str:
best_label = "RAW_PHRASE"
best_overlap = 0
for span in spans:
span_start = span["start"]
span_end = span["end"]
span_label = span["label"]
overlap = min(end, span_end) - max(start, span_start)
if overlap > best_overlap:
best_overlap = overlap
best_label = span_label
return best_label
for word_id in word_ids:
if word_id is None:
def tokenize_and_align(example, tokenizer, label2id):
tokenized = tokenizer(
example["text"],
truncation=True,
return_offsets_mapping=True,
)
labels = []
for start, end in tokenized["offset_mapping"]:
if start == end:
labels.append(-100)
elif word_id != prev:
labels.append(label2id[example["labels"][word_id]])
else:
labels.append(-100)
prev = word_id
label_name = _label_for_offset(start, end, example["spans"])
labels.append(label2id[label_name])
tokenized["labels"] = labels
tokenized.pop("offset_mapping")
return tokenized
@@ -144,7 +146,7 @@ def _load_model(model_dir: str = MODEL_DIR):
def predict(text: str, model_dir: str = MODEL_DIR, top_k: int = 3):
tokenizer, model = _load_model(model_dir)
inputs = tokenizer(text.split(), return_tensors="pt", is_split_into_words=True)
inputs = tokenizer(text, return_tensors="pt")
outputs = model(**inputs)
logits = outputs.logits[0]