1
2from celery import Celery
3from datetime import datetime
4import torch
5from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, BitsAndBytesConfig
6from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, TaskType
7from trl import SFTTrainer
8from datasets import load_dataset
9from app.config import settings
10
# Celery application for the training worker; jobs are enqueued through the
# Redis broker configured in app settings.
celery_app = Celery("training", broker=settings.redis_url)
12
@celery_app.task(bind=True, max_retries=0)
def start_training_job(self, job_id: str):
    """Fine-tune a causal LM with LoRA (optionally QLoRA) for one job.

    Marks the job RUNNING, trains with TRL's ``SFTTrainer`` on the job's
    JSON dataset, then records COMPLETED (with output path and metrics) or
    FAILED (with the error message) back to the database.

    Args:
        job_id: Primary key of the ``TrainingJob`` row to execute.

    Raises:
        Exception: any training/setup failure is re-raised after the job row
            is marked FAILED, so Celery also records the task as failed.
    """
    # Imported lazily so importing this module does not require the DB layer.
    # BUGFIX: TrainingDataset was missing from this import, so the
    # dataset_path lookup below raised NameError on every run.
    from app.models.database import (
        TrainingDataset,
        TrainingJob,
        TrainingStatus,
        sync_session,
    )

    # Mark the job RUNNING and read everything we need from the DB up front,
    # so no session is held open across the (long) training run.
    with sync_session() as session:
        job = session.get(TrainingJob, job_id)
        job.status = TrainingStatus.RUNNING
        # NOTE(review): naive UTC timestamp via utcnow(), which is deprecated
        # in Python 3.12. Kept for consistency with the other timestamp
        # columns below — confirm the column convention before switching to
        # datetime.now(timezone.utc).
        job.started_at = datetime.utcnow()
        session.commit()
        # Attribute access after commit re-loads from this (still open)
        # session if expire_on_commit is set.
        config = job.config
        dataset_path = session.get(TrainingDataset, job.dataset_id).file_path

    try:
        output_dir = f"{settings.model_storage_path}/{job_id}"

        tokenizer = AutoTokenizer.from_pretrained(config["base_model"])
        # Many causal-LM tokenizers ship without a pad token; reuse EOS so
        # batching/padding works.
        tokenizer.pad_token = tokenizer.eos_token

        if config.get("use_qlora"):
            # QLoRA path: 4-bit NF4 base weights with bf16 compute, LoRA
            # adapters trained on top of the frozen quantized model.
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16,
            )
            model = AutoModelForCausalLM.from_pretrained(
                config["base_model"],
                quantization_config=bnb_config,
                device_map="auto",
            )
            # Required prep for k-bit training (casts norm layers, enables
            # input gradients) before attaching adapters.
            model = prepare_model_for_kbit_training(model)
        else:
            model = AutoModelForCausalLM.from_pretrained(
                config["base_model"],
                torch_dtype=torch.bfloat16,
                device_map="auto",
            )

        lora_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            r=config["lora_r"],
            lora_alpha=config["lora_alpha"],
            lora_dropout=0.05,
            # Attention projections only; MLP modules stay frozen.
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        )
        model = get_peft_model(model, lora_config)

        dataset = load_dataset("json", data_files=dataset_path, split="train")

        training_args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=config["num_epochs"],
            per_device_train_batch_size=config["batch_size"],
            gradient_accumulation_steps=config["gradient_accumulation"],
            learning_rate=config["learning_rate"],
            warmup_ratio=0.1,
            lr_scheduler_type="cosine",
            bf16=True,
            gradient_checkpointing=True,
            save_strategy="epoch",
            logging_steps=10,
            report_to="none",
        )

        trainer = SFTTrainer(
            model=model,
            train_dataset=dataset,
            args=training_args,
            tokenizer=tokenizer,
            max_seq_length=config.get("max_seq_length", 2048),
        )

        trainer.train()
        trainer.save_model(output_dir)

        final_metrics = {
            # The last log entry may be a run summary without "train_loss",
            # hence .get() (-> None) rather than indexing.
            "final_loss": trainer.state.log_history[-1].get("train_loss"),
            "total_steps": trainer.state.global_step,
            "epochs_completed": config["num_epochs"],
        }

        # Fresh session: the first one was closed before training started.
        with sync_session() as session:
            job = session.get(TrainingJob, job_id)
            job.status = TrainingStatus.COMPLETED
            job.completed_at = datetime.utcnow()
            job.output_path = output_dir
            job.metrics = final_metrics
            session.commit()

    except Exception as e:
        # Persist the failure for operators, then re-raise so Celery marks
        # the task itself as failed (max_retries=0: no automatic retry).
        with sync_session() as session:
            job = session.get(TrainingJob, job_id)
            job.status = TrainingStatus.FAILED
            job.completed_at = datetime.utcnow()
            job.error_message = str(e)
            session.commit()
        raise
108