#
import sys
#print('sys.path: ___ ', sys.path)
#print(f"Current Python Executable: {sys.executable}")
### dynamo warning
import warnings
# Ignore FutureWarning: prims_common.check, Online Softmax
warnings.filterwarnings("ignore", category=FutureWarning, module='torch._inductor.lowering')
warnings.filterwarnings("ignore", message=".*Online softmax is disabled on the fly.*", category=UserWarning)
warnings.filterwarnings("ignore", message=".*Our suggested max number of worker in current system is 1.*", category=UserWarning)
warnings.filterwarnings("ignore", message=".*will be initialized from a multivariate normal distribution.*")
warnings.filterwarnings("ignore", message=".*that differ from the model config and generation config.*", category=UserWarning)
warnings.filterwarnings("ignore", message=".*torch.backends.cudnn.conv.fp32_precision = 'tf32' or torch..*", category=UserWarning)
import torch
torch.backends.cuda.matmul.fp32_precision = 'tf32'
# import wandb
import os
torch.set_num_threads(1)
os.environ["OMP_NUM_THREADS"]="1"
os.environ["MKL_NUM_THREADS"]="1"
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"PyTorch built with CUDA version: {torch.version.cuda}")
import yaml
#from peft import LoraConfig, get_peft_model_state_dict
from torch.utils.data import DataLoader
import time
from datetime import datetime
import math
from typing import List, Tuple, Union, Optional, Callable
# import prodigyopt
###
import copy
from dataclasses import field, dataclass, asdict
from typing import Sequence, Literal, Dict
import transformers
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
from transformers import Trainer
from transformers.modeling_utils import *
from transformers.trainer import _is_peft_model
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
from transformers.data.data_collator import DataCollator
from transformers.training_args import TrainingArguments
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalPrediction
from torch.utils.data import Dataset, IterableDataset
from datasets import load_dataset
##
#from ..pipeline.flux_omini import transformer_forward, encode_images
# from ...omini.rotation import RotationTuner, RotationConfig
from rpeft.rotation import RotationTuner, RotationConfig
from rpeft import get_peft_model, PeftModel
from .config import MainConfig, convert_to_trainer_args
import pyrallis
from omegaconf import OmegaConf
import torch.nn as nn
import torch.optim as optim
import wandb
from torch.nn.utils.rnn import pad_sequence
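# Label value ignored by torch.nn.CrossEntropyLoss (its default ignore_index); used to mask prompt tokens.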
IGNORE_INDEX = -100
PROMPT = (
"Below is an instruction that describes a task. "
"Write a response that appropriately completes the request.\n\n"
"### Instruction:\n{instruction}\n\n### Response:"
)
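# Alpaca-style instruction template; `train_tokenize_function` fills {instruction} from the query column.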
def get_rank():
    try:
        rank = int(os.environ.get("LOCAL_RANK", 0))
    except (TypeError, ValueError):
        rank = 0
    return rank
def get_config():
config_path = os.environ.get("OMINI_CONFIG")
assert config_path is not None, "Please set the OMINI_CONFIG environment variable"
with open(config_path, "r") as f:
config = yaml.safe_load(f)
return config
def init_wandb(wandb_config, run_name):
import wandb
try:
assert os.environ.get("WANDB_API_KEY") is not None
wandb.init(
project=wandb_config["project"],
name=run_name,
config={},
)
except Exception as e:
print("Failed to initialize WanDB:", e)
def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
"""Collects the state dict and dump to disk."""
state_dict = trainer.model.state_dict()
if trainer.args.should_save:
cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
del state_dict
trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
def smart_tokenizer_and_embedding_resize(
special_tokens_dict: Dict,
tokenizer: transformers.PreTrainedTokenizer,
model: transformers.PreTrainedModel,
):
"""Resize tokenizer and embedding.
Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
"""
num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
model.resize_token_embeddings(len(tokenizer))
if num_new_tokens > 0:
input_embeddings = model.get_input_embeddings().weight.data
output_embeddings = model.get_output_embeddings().weight.data
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
input_embeddings[-num_new_tokens:] = input_embeddings_avg
output_embeddings[-num_new_tokens:] = output_embeddings_avg
def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
"""Tokenize a list of strings."""
tokenized_list = [
tokenizer(
text,
return_tensors="pt",
padding="longest",
max_length=tokenizer.model_max_length,
truncation=True,
)
for text in strings
]
input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
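    # input_ids and labels intentionally alias the same tensors here; `preprocess`
    # deep-copies labels before masking the prompt portion.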
input_ids_lens = labels_lens = [
tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
]
return dict(
input_ids=input_ids,
labels=labels,
input_ids_lens=input_ids_lens,
labels_lens=labels_lens,
)
def preprocess(
sources: Sequence[str],
targets: Sequence[str],
tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
"""Preprocess the data by tokenizing."""
examples = [s + t for s, t in zip(sources, targets)]
examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)]
input_ids = examples_tokenized["input_ids"]
labels = copy.deepcopy(input_ids)
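    # Mask the prompt (source) tokens so the loss is computed only on the response tokens.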
for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
label[:source_len] = IGNORE_INDEX
return dict(input_ids=input_ids, labels=labels)
# @dataclass
# class DataCollatorForSupervisedDataset():
# """Collate examples for supervised fine-tuning."""
# tokenizer: transformers.PreTrainedTokenizer
# max_length: int = field(default=512)
# mode: str = field(default="fixed") # dynamic -> dynamo
# def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
# if self.mode == 'dynamic':
# input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
# input_ids = [torch.tensor(x) for x in input_ids]
# input_ids = torch.nn.utils.rnn.pad_sequence(
# input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
# )
# labels = [torch.tensor(x) for x in labels]
# labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
# return dict(
# input_ids=input_ids,
# labels=labels,
# attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
# )
# elif self.mode == 'fixed':
# input_ids = [torch.tensor(x["input_ids"][:self.max_length]) for x in instances]
# input_ids = torch.stack([
# torch.nn.functional.pad(x, (0, self.max_length - x.size(0)), value=self.tokenizer.pad_token_id)
# for x in input_ids
# ])
# # Labels
# labels = [torch.tensor(x["labels"][:self.max_length]) for x in instances]
# labels = torch.stack([
# torch.nn.functional.pad(x, (0, self.max_length - x.size(0)), value=IGNORE_INDEX)
# for x in labels
# ])
# return dict(
# input_ids=input_ids,
# labels=labels,
# attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
# )
# else:
# raise NotImplementedError
# @dataclass
# class DataCollatorForSupervisedDataset(object):
# tokenizer: transformers.PreTrainedTokenizer
# max_length: int = field(default=512)
# mode: str = field(default="fixed") # "dynamic" or "fixed"
# def _pad_to_length(self, tensors: Sequence[torch.Tensor], pad_value: int, target_len: int):
# """Pad a list of 1D tensors to target_len (int) and stack -> (B, target_len)."""
# batch_size = len(tensors)
# out = torch.full((batch_size, target_len), pad_value, dtype=tensors[0].dtype)
# for i, t in enumerate(tensors):
# L = min(t.size(0), target_len)
# out[i, :L] = t[:L]
# return out
# def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
# # Collect raw sequences (lists or tensors)
# input_seqs = [torch.tensor(x["input_ids"], dtype=torch.long) for x in instances]
# label_seqs = [torch.tensor(x["labels"], dtype=torch.long) for x in instances]
# if self.mode == "dynamic":
# # pad to the max length present in this batch (<= self.max_length)
# batch_max_len = min(max([s.size(0) for s in input_seqs]), self.max_length)
# input_ids = self._pad_to_length(input_seqs, pad_value=self.tokenizer.pad_token_id, target_len=batch_max_len)
# labels = self._pad_to_length(label_seqs, pad_value=IGNORE_INDEX, target_len=batch_max_len)
# elif self.mode == "fixed":
# # always pad/truncate to self.max_length
# input_ids = self._pad_to_length(input_seqs, pad_value=self.tokenizer.pad_token_id, target_len=self.max_length)
# labels = self._pad_to_length(label_seqs, pad_value=IGNORE_INDEX, target_len=self.max_length)
# else:
# raise NotImplementedError(f"Unknown mode: {self.mode}")
# attention_mask = input_ids.ne(self.tokenizer.pad_token_id).long()
# return {
# "input_ids": input_ids,
# "labels": labels,
# "attention_mask": attention_mask
# }
@dataclass
class DataCollatorForSupervisedDataset():
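    """Collate supervised examples into padded `input_ids`, `labels`, and `attention_mask`.

    mode="dynamic" pads to the longest sequence in the batch (capped at max_length);
    mode="fixed" always pads/truncates to max_length.
    """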
tokenizer: transformers.PreTrainedTokenizer
max_length: int = field(default=512)
mode: str = field(default="fixed") # "dynamic" or "fixed"
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
# Extract inputs and labels
# Assuming instances is a list of dicts like {'input_ids': [...], 'labels': [...]}
input_ids_list = [torch.tensor(x["input_ids"], dtype=torch.long) for x in instances]
labels_list = [torch.tensor(x["labels"], dtype=torch.long) for x in instances]
# 1. Determine padding logic
if self.mode == "dynamic":
# Dynamic padding: pad to the longest sequence in the batch
# But cap it at self.max_length to prevent OOM
batch_max_len = max([len(x) for x in input_ids_list])
target_len = min(batch_max_len, self.max_length)
else:
# Fixed padding: always pad to max_length
target_len = self.max_length
# 2. Helper to pad and truncate
def pad_and_truncate(tensors, padding_value):
# First, pad everything using PyTorch's optimized utility (batch_first=True)
padded = pad_sequence(tensors, batch_first=True, padding_value=padding_value)
# Handle truncation/extending to exact target_len
curr_len = padded.shape[1]
if curr_len > target_len:
# Truncate if too long (rare if filtered beforehand)
return padded[:, :target_len]
elif curr_len < target_len:
# Pad more if shorter than target_len (happens in fixed mode)
diff = target_len - curr_len
padding = torch.full((padded.shape[0], diff), padding_value, dtype=padded.dtype)
return torch.cat([padded, padding], dim=1)
else:
return padded
# 3. Apply padding
# Critical: tokenizer.pad_token_id must NOT be None here
if self.tokenizer.pad_token_id is None:
raise ValueError("Tokenizer.pad_token_id is None. Please set it to eos_token_id or unk_token_id.")
input_ids = pad_and_truncate(input_ids_list, self.tokenizer.pad_token_id)
labels = pad_and_truncate(labels_list, IGNORE_INDEX)
# 4. Create Attention Mask explicitly
# .ne() creates Bools, .long() casts to 0s and 1s for compatibility
attention_mask = input_ids.ne(self.tokenizer.pad_token_id).long()
return {
"input_ids": input_ids,
"labels": labels,
"attention_mask": attention_mask
}
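# Minimal usage sketch (illustrative only; the Trainer below wires this up for real):
#   collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=512, mode="dynamic")
#   loader = DataLoader(train_dataset, batch_size=4, collate_fn=collator)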
def train_tokenize_function(examples, tokenizer, query, response):
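    """Batched `datasets.map` callback: builds Alpaca-style prompts from the `query` column,
    appends the EOS-terminated `response` column, then tokenizes and masks via `preprocess`."""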
sources = [PROMPT.format_map(dict(instruction=instruction)) for instruction in examples[query]]
targets = [f"{output}{tokenizer.eos_token}" for output in examples[response]]
data_dict = preprocess(sources, targets, tokenizer)
return data_dict
### Trainer
def default_worker_init_fn(worker_id):
    # Limit each DataLoader worker to a single BLAS thread to avoid CPU oversubscription.
    torch.set_num_threads(1)
    os.environ.setdefault("OMP_NUM_THREADS", "1")
    os.environ.setdefault("MKL_NUM_THREADS", "1")
    os.environ.setdefault("OPENBLAS_NUM_THREADS", "1")
    # Optional: pin each worker to a small, round-robin block of cores to reduce
    # contention (roughly NUMA-aware). Skip silently where CPU affinity is unsupported.
    try:
        cpu_count = os.cpu_count() or 1
        # Split the CPUs evenly across workers; chunk is 1-2 cores on typical machines.
        chunk = max(1, cpu_count // max(1, min(64, cpu_count)))
        start = (worker_id * chunk) % cpu_count
        mask = set(range(start, min(start + chunk, cpu_count)))
        try:
            os.sched_setaffinity(0, mask)
        except (AttributeError, OSError):
            pass
    except Exception:
        pass
def set_seed(seed: int):
# random.seed(seed)
# np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
transformers.set_seed(seed)
@pyrallis.wrap()
def main(mainCfg: MainConfig):
#mainCfg = get_config()
#print(mainCfg)
print('='*120)
# print(OmegaConf.to_yaml(mainCfg))
# print('-'*40)
#
# print((training_args))
set_seed(mainCfg.seed)
training_args = convert_to_trainer_args(mainCfg)
# wandb
ENTITY = "nvan-13-korea-university"
PROJECT = os.environ.get("WANDB_PROJECT")
    try:
        api = wandb.Api()
        runs_list = api.runs(f"{ENTITY}/{PROJECT}")
        next_run_num = len(runs_list) + 1
    except Exception as e:
        next_run_num = 1
training_args.run_name = f'[{next_run_num}]lr={mainCfg.trainer_args.learning_rate:.1e},b={mainCfg.trainer_args.per_device_train_batch_size},'\
f'n={mainCfg.rotation_adapter_config.num_rotations},r={mainCfg.rotation_adapter_config.r},'\
f'init={mainCfg.run_text}'
# training_args.project = f'Rotation-Llama2-{mainCfg.data.dataset_name}'
# print('-'*40)
# print(training_args.to_json_string())
# exit()
model = AutoModelForCausalLM.from_pretrained(mainCfg.model.model_name,
device_map="auto", low_cpu_mem_usage=True,
dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16,
attn_implementation="sdpa",
)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("DEVICE", DEVICE)
# for name, param in model.named_parameters():
# if 'q_proj' in name and 'layers.5' in name:
# print(f"Name: {name} | {param.shape} ")
# print(f"Name (pretrained): {name} | {param.shape} | {param.data[0:5,0:5]}")
# print('model', model)
# exit()
total_params_now = sum(p.numel() for p in model.parameters())
print(f'#params of the pretrained model, {total_params_now:,}')
# print(model)
if mainCfg.model.adapter_path is not None:
print('___ Loading from: ', mainCfg.model.adapter_path)
model = PeftModel.from_pretrained(model, mainCfg.model.adapter_path, is_trainable = True)
elif mainCfg.rotation_adapter_config.r is not None:
rotation_adapter_config = asdict(mainCfg.rotation_adapter_config)
# rotation_adapter_config[peft_type]
for adapter_name in mainCfg.data.adapter_names:
rotation_config = RotationConfig(**rotation_adapter_config)
model = get_peft_model(model, rotation_config, adapter_name=adapter_name)
# model.set_adapter(adapter_name)
else:
print("Full Parameter Fine-Tuning")
model = model.to(DEVICE)
# print('model', model)
model.print_trainable_parameters()
# print("Program starts")
# time.sleep(300)
# exit()
# for name, param in model.named_parameters():
# if 'q_proj' in name and 'rotation' in name and 'layers.5' in name:
# print(f"Name: {name} | {param.shape} ")
# print(f"Name (pretrained): {name} | {param.shape} ")
# X = param.data
# print('model', type(model), X.shape)
# visualize_value_distribution(X)
# exit()
rotation_layers = filter(
lambda p: p.requires_grad, model.parameters()
)
tokenizer = AutoTokenizer.from_pretrained(
mainCfg.model.model_name,
model_max_length=mainCfg.model.model_max_seq_length,
padding_side="right",
use_fast=True,
)
if tokenizer.pad_token is None:
if tokenizer.unk_token_id is not None:
tokenizer.pad_token_id = tokenizer.unk_token_id
tokenizer.pad_token = tokenizer.unk_token
print("Set PAD token to UNK token.")
elif tokenizer.eos_token_id is not None:
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.pad_token = tokenizer.eos_token
print("Set PAD token to EOS token.")
if model is not None:
model.config.pad_token_id = tokenizer.pad_token_id
if model.config.pad_token_id != tokenizer.pad_token_id:
raise ValueError("Failed to sync pad_token_id between tokenizer and model config")
# local MetaMathQA-40K
raw_datasets = load_dataset("json", data_files=mainCfg.data.path, split=mainCfg.data.dataset_split)
#raw_train_datasets = load_dataset("MetaMathQA-40K", split=mainCfg.data.dataset_split)
# print('raw', type(raw_train_datasets), len(raw_train_datasets))
# split a single set
split_ratio = mainCfg.data.split_ratio
split_data = raw_datasets.train_test_split(test_size=split_ratio, seed=42)
raw_train_datasets = split_data['train']
raw_valid_datasets = split_data['test']
train_dataset = raw_train_datasets.map(
train_tokenize_function,
batched=True,
batch_size=30000,
num_proc=32,
remove_columns=raw_train_datasets.column_names,
load_from_cache_file=True,
desc="Running tokenizer on train dataset",
fn_kwargs={"tokenizer": tokenizer, "query": mainCfg.data.dataset_field[0],
"response": mainCfg.data.dataset_field[1]}
)
    valid_dataset = raw_valid_datasets.map(
        train_tokenize_function,
        batched=True,
        batch_size=30000,
        num_proc=32,
        remove_columns=raw_valid_datasets.column_names,
        load_from_cache_file=True,
        desc="Running tokenizer on validation dataset",
        fn_kwargs={"tokenizer": tokenizer, "query": mainCfg.data.dataset_field[0],
                   "response": mainCfg.data.dataset_field[1]}
    )
    print('- dataset sizes (train / valid): ', len(train_dataset), len(valid_dataset))
# print('dataset', type(train_dataset))
# print('process', len(train_dataset))
# print(f"Sample features: {train_dataset.column_names}, {train_dataset.num_rows}")
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=mainCfg.model.model_max_seq_length,
#mode=mainCfg.model.data_collator_mode,
)
data_module = dict(train_dataset=train_dataset, data_collator=data_collator, eval_dataset=valid_dataset)
optimizer = optim.AdamW(
rotation_layers,
lr=mainCfg.trainer_args.learning_rate, #
eps=1e-8
)
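    # AdamW is built only over the trainable (adapter) parameters collected above and is
    # handed to the Trainer via optimizers=(optimizer, None), so HF only creates the LR scheduler.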
# print('model x', model)
start_time = datetime.now()
print('start time: ', start_time.strftime("%Y-%m-%d %H:%M:%S"))
trainer = MyTrainer(model=model, processing_class=tokenizer,
lamda=mainCfg.model.lambda_reg,
optimizers=(optimizer, None),
args=training_args, **data_module)
model.config.use_cache = False
# now = time.time()
# for i in range(20):
# next(iter(trainer.get_train_dataloader()))
# print('time', time.time()-now)
# now = time.time()
# dl = trainer.get_train_dataloader()
# t0 = time.time()
# for i, batch in enumerate(dl):
# if i==20: break
# print("time / 20 batches =", time.time() - t0)
# exit()
# model2 = model.merge_and_unload()
# results2 = trainer2.evaluate()
# print('results2: ', results2)
# exit()
start_time = datetime.now()
trainer.train()
end_time = datetime.now()
print('end time: ', end_time.strftime("%Y-%m-%d %H:%M:%S"), '| duration: ', end_time - start_time)
# Save Model (Includes Adapter weights & Config)
# trainer.save_model(os.path.join(training_args.output_dir, 'ft'))
# Save Tokenizer
tokenizer.save_pretrained(os.path.join(training_args.output_dir, 'ft'))
# Save Training State (Metrics & Logs)
trainer.save_state()
    # Save the adapter config(s); model.peft_config is a dict keyed by adapter name
    # (e.g. model.peft_config['default']).
    for adapter_name, peft_cfg in model.peft_config.items():
        peft_cfg.save_pretrained(os.path.join(training_args.output_dir, 'ft', adapter_name))
# the easiest way
model.save_pretrained(os.path.join(training_args.output_dir, 'ft2'))
return
class MyTrainer(Trainer):
def __init__(
self,
model: Union[PreTrainedModel, nn.Module] = None,
args: TrainingArguments = None,
data_collator: Optional[DataCollator] = None,
train_dataset: Optional[Union[Dataset, IterableDataset, "datasets.Dataset"]] = None,
eval_dataset: Optional[Union[Dataset, Dict[str, Dataset], "datasets.Dataset"]] = None,
processing_class: Optional[PreTrainedTokenizerBase] = None,
model_init: Optional[Callable[[], PreTrainedModel]] = None,
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
callbacks: Optional[List[TrainerCallback]] = None,
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
#run_name: Optional[str] = None,
#report_to: Optional[Union[str, list[str]]] = None,
# project
lamda: float = 1e-4
):
super().__init__(model=model, args=args, data_collator=data_collator,
train_dataset=train_dataset, eval_dataset=eval_dataset, processing_class=processing_class,
model_init=model_init, compute_metrics=compute_metrics, callbacks=callbacks,
optimizers=optimizers, preprocess_logits_for_metrics=preprocess_logits_for_metrics,
#run_name=run_name, report_to=report_to
)
self.lamda = lamda
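        # Coefficient for the orthogonality regularizer in the (currently commented-out)
        # custom compute_loss below; unused while that override is disabled.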
# def compute_loss(self, model, inputs, return_outputs=False,
# num_items_in_batch: Optional[torch.Tensor] = None,):
# """
# How the loss is computed by Trainer. By default, all models return the loss in the first element.
# Subclass and override for custom behavior.
# """
# if self.label_smoother is not None and "labels" in inputs:
# labels = inputs.pop("labels")
# else:
# labels = None
# if self.model_accepts_loss_kwargs:
# kwargs = {}
# if num_items_in_batch is not None:
# kwargs["num_items_in_batch"] = num_items_in_batch
# inputs = {**inputs, **kwargs}
# outputs = model(**inputs)
# # Save past state if it exists
# # TODO: this needs to be fixed and made cleaner later.
# if self.args.past_index >= 0:
# self._past = outputs[self.args.past_index]
# if labels is not None:
# unwrapped_model = unwrap_model(model)
# if _is_peft_model(unwrapped_model):
# model_name = unwrapped_model.base_model.model._get_name()
# else:
# model_name = unwrapped_model._get_name()
# if model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
# loss = self.label_smoother(outputs, labels, shift_labels=True)
# else:
# loss = self.label_smoother(outputs, labels)
# else:
# if isinstance(outputs, dict) and "loss" not in outputs:
# raise ValueError(
# "The model did not return a loss from the inputs, only the following keys: "
# f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
# )
# # We don't use .loss here since the model may return tuples instead of ModelOutput.
# loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
# # ------------------------------------------------------------------------------
# # for name, param in model.named_parameters():
# # if 'oft_r' in name:
# # device = param.device
# # householder_U_norm = param / param.norm(dim=0)
# # orth_loss = torch.norm(
# # torch.eye(householder_U_norm.size(1), device=device) - householder_U_norm.t() @ householder_U_norm)
# # print(self.lamda)
# # loss = loss + self.lamda * orth_loss.to(loss.device)
# # ------------------------------------------------------------------------------
# return (loss, outputs) if return_outputs else loss
    def get_train_dataloader(self):
        # Build the training DataLoader manually so worker/prefetch settings are explicit.
        train_dataset = self.train_dataset
        sampler = self._get_train_sampler()
        # Effective per-step batch size (per_device_train_batch_size scaled by the number of devices).
        batch_size = getattr(self.args, "train_batch_size", self.args.per_device_train_batch_size)
        # These attributes exist on TrainingArguments; the fallbacks only apply if they are missing.
        num_workers = getattr(self.args, "dataloader_num_workers", 16)
        pin_memory = getattr(self.args, "dataloader_pin_memory", True)
        prefetch_factor = getattr(self.args, "dataloader_prefetch_factor", 2)
        persistent_workers = getattr(self.args, "dataloader_persistent_workers", True)
        # DataLoader rejects prefetch_factor / persistent_workers when num_workers == 0.
        if num_workers == 0:
            prefetch_factor = None
            persistent_workers = False
        return DataLoader(
            train_dataset,
            batch_size=batch_size,
            sampler=sampler,
            collate_fn=self.data_collator,
            drop_last=getattr(self.args, "dataloader_drop_last", False),
            num_workers=num_workers,
            pin_memory=pin_memory,
            persistent_workers=persistent_workers,
            prefetch_factor=prefetch_factor,
            worker_init_fn=default_worker_init_fn,
        )
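    # Note: only the training DataLoader is customized; evaluation still uses the
    # parent Trainer.get_eval_dataloader().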
if __name__ == "__main__":
main()