fixed
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -5,6 +5,7 @@ build/
|
||||
dist/
|
||||
wheels/
|
||||
*.egg-info
|
||||
model/*
|
||||
|
||||
# Virtual environments
|
||||
.venv
|
||||
|
||||
70
mini_lm.py
70
mini_lm.py
@@ -1,5 +1,4 @@
|
||||
import json
|
||||
import re
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
|
||||
@@ -32,16 +31,7 @@ def _normalize_label(label: str) -> str:
|
||||
raise ValueError(f"Unknown label: {label}")
|
||||
|
||||
|
||||
def _text_to_tokens(text: str) -> tuple[list[str], list[tuple[int, int]]]:
|
||||
tokens = []
|
||||
spans = []
|
||||
for match in re.finditer(r"\S+", text):
|
||||
tokens.append(match.group(0))
|
||||
spans.append(match.span())
|
||||
return tokens, spans
|
||||
|
||||
|
||||
def _extract_spans(task: dict) -> list[tuple[int, int, str]]:
|
||||
def _extract_spans(task: dict) -> list[dict]:
|
||||
annotations = task.get("annotations") or []
|
||||
if not annotations:
|
||||
return []
|
||||
@@ -54,7 +44,13 @@ def _extract_spans(task: dict) -> list[tuple[int, int, str]]:
|
||||
labels = value.get("labels") or []
|
||||
if start is None or end is None or not labels:
|
||||
continue
|
||||
spans.append((int(start), int(end), _normalize_label(labels[0])))
|
||||
spans.append(
|
||||
{
|
||||
"start": int(start),
|
||||
"end": int(end),
|
||||
"label": _normalize_label(labels[0]),
|
||||
}
|
||||
)
|
||||
return spans
|
||||
|
||||
|
||||
@@ -67,36 +63,42 @@ def parse_annotated_dataset(path: str) -> list[dict]:
|
||||
for task in data:
|
||||
text = task["data"]["text"]
|
||||
spans = _extract_spans(task)
|
||||
tokens, token_spans = _text_to_tokens(text)
|
||||
labels = []
|
||||
for token_start, token_end in token_spans:
|
||||
assigned = "RAW_PHRASE"
|
||||
for span_start, span_end, span_label in spans:
|
||||
if token_start < span_end and token_end > span_start:
|
||||
assigned = span_label
|
||||
break
|
||||
labels.append(assigned)
|
||||
examples.append({"tokens": tokens, "labels": labels})
|
||||
examples.append({"text": text, "spans": spans})
|
||||
|
||||
return examples
|
||||
|
||||
|
||||
def tokenize_and_align(example, tokenizer, label2id):
|
||||
tokenized = tokenizer(example["tokens"], is_split_into_words=True, truncation=True)
|
||||
word_ids = tokenized.word_ids()
|
||||
labels = []
|
||||
prev = None
|
||||
def _label_for_offset(start: int, end: int, spans: list[dict]) -> str:
|
||||
best_label = "RAW_PHRASE"
|
||||
best_overlap = 0
|
||||
for span in spans:
|
||||
span_start = span["start"]
|
||||
span_end = span["end"]
|
||||
span_label = span["label"]
|
||||
overlap = min(end, span_end) - max(start, span_start)
|
||||
if overlap > best_overlap:
|
||||
best_overlap = overlap
|
||||
best_label = span_label
|
||||
return best_label
|
||||
|
||||
for word_id in word_ids:
|
||||
if word_id is None:
|
||||
|
||||
def tokenize_and_align(example, tokenizer, label2id):
|
||||
tokenized = tokenizer(
|
||||
example["text"],
|
||||
truncation=True,
|
||||
return_offsets_mapping=True,
|
||||
)
|
||||
labels = []
|
||||
|
||||
for start, end in tokenized["offset_mapping"]:
|
||||
if start == end:
|
||||
labels.append(-100)
|
||||
elif word_id != prev:
|
||||
labels.append(label2id[example["labels"][word_id]])
|
||||
else:
|
||||
labels.append(-100)
|
||||
prev = word_id
|
||||
label_name = _label_for_offset(start, end, example["spans"])
|
||||
labels.append(label2id[label_name])
|
||||
|
||||
tokenized["labels"] = labels
|
||||
tokenized.pop("offset_mapping")
|
||||
return tokenized
|
||||
|
||||
|
||||
@@ -144,7 +146,7 @@ def _load_model(model_dir: str = MODEL_DIR):
|
||||
|
||||
def predict(text: str, model_dir: str = MODEL_DIR, top_k: int = 3):
|
||||
tokenizer, model = _load_model(model_dir)
|
||||
inputs = tokenizer(text.split(), return_tensors="pt", is_split_into_words=True)
|
||||
inputs = tokenizer(text, return_tensors="pt")
|
||||
outputs = model(**inputs)
|
||||
|
||||
logits = outputs.logits[0]
|
||||
|
||||
Reference in New Issue
Block a user