refactor: engine.py
This commit is contained in:
32
llm_runtime.py
Normal file
32
llm_runtime.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import logging
|
||||
import multiprocessing
|
||||
|
||||
from langchain_community.chat_models import ChatLlamaCpp
|
||||
from langchain_community.embeddings import HuggingFaceEmbeddings
|
||||
|
||||
# Absolute path to the local GGUF checkpoint served through llama.cpp.
# NOTE(review): hard-coded to one user's HuggingFace cache directory —
# consider reading this from an environment variable or config instead.
DEFAULT_MODEL_PATH = "/home/sortedcord/.cache/huggingface/hub/models--ggml-org--gemma-4-E4B-it-GGUF/snapshots/6b352c53e1d2e4bb974d9f8cafcf85887c224219/gemma-4-e4b-it-Q4_K_M.gguf"

# Module-level logger, named after this module per logging convention.
logger = logging.getLogger(__name__)

# Sentence-embedding model shared by the module; constructed at import time,
# so importing this module downloads/loads the embedding weights.
embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")

# Shared chat model backed by llama.cpp; also constructed once at import
# time, which loads the (multi-GB) GGUF checkpoint as a side effect.
llm = ChatLlamaCpp(
    temperature=0.2,  # low temperature: mostly deterministic completions
    model_path=DEFAULT_MODEL_PATH,
    n_ctx=4096,  # context window size in tokens
    n_gpu_layers=8,  # number of transformer layers offloaded to the GPU
    max_tokens=512,  # cap on generated tokens per response
    n_threads=multiprocessing.cpu_count() - 1,  # leave one core free for the host
    # NOTE(review): 1.5 is an aggressive repeat penalty (typical values are
    # ~1.1); confirm this is intentional, as it can degrade fluency.
    repeat_penalty=1.5,
)
|
||||
|
||||
|
||||
def _format_prompt(messages):
|
||||
formatted = []
|
||||
for message in messages:
|
||||
formatted.append(f"{message.__class__.__name__}:\n{message.content}")
|
||||
return "\n\n".join(formatted)
|
||||
|
||||
|
||||
def _normalize_llm_output(text: str) -> str:
|
||||
return text.replace("\r", "").replace("\n", "").strip()
|
||||
Reference in New Issue
Block a user