This is a comparison of merged models with and without `embed_tokens`.
We compare four variants:
1. stealth-v1 (trinity embed layers, with `embed_tokens`)
2. stealth-v1-sft (trinity embed layers, with `embed_tokens`, LoRA on embed layers)
3. stealth-v1-no-embed (trinity embed layers, no `embed_tokens` merge)
4. stealth-mistral (mistral embed layer, no `embed_tokens` merge)
**NOTE**
- **Tokenizer Source Compatibility**: The `tokenizer_source` option is compatible with all merge methods, but `lm_head`/`embed_tokens` are always merged linearly when it is used.
- **Embedding SLERP Parameter**: For two-model merges, `embed_slerp` can be set to `true` to use spherical linear interpolation (SLERP) for the embeddings instead of linear merging.
- **Fallback Behavior Without Tokenizer Source**: If `tokenizer_source` is not set, mergekit reverts to legacy behavior:
- The tokenizer from the base model (or the first model in a merge without a specified base model) is copied to the output directory.
- The parameter matrices for `lm_head/embed_tokens` are truncated to the smallest size in the merge, typically corresponding to the base model's tokenizer.
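For illustration, the two options above can be combined in a merge config like the following. This is a hypothetical fragment, not one of the runs below; `union`, `base`, and `model:<path>` are the values mergekit accepts for `tokenizer_source`, while the exact placement of `embed_slerp` is an assumption here:

```yaml
# Hypothetical fragment, only to illustrate the two options above
merge_method: slerp
base_model: jan-hq/trinity-v1.2
tokenizer_source: union  # or "base", or "model:<path>"
embed_slerp: true        # SLERP the embeddings (two-model merges only)
```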
# Stealth v1 - WITH `embed_tokens`
## Merge
```yaml
slices:
  - sources:
      - model: jan-hq/trinity-v1.2
        layer_range: [0, 32]
      - model: WizardLM/WizardMath-7B-V1.1
        layer_range: [0, 32]
merge_method: slerp
base_model: jan-hq/trinity-v1.2
parameters:
  t:
    - filter: lm_head
      value: [0.55]
    - filter: embed_tokens
      value: [0.7]
    - filter: self_attn
      value: [0.65, 0.35]
    - filter: mlp
      value: [0.35, 0.65]
    - filter: layernorm
      value: [0.4, 0.6]
    - filter: modelnorm
      value: [0.6]
    - value: 0.5 # fallback for rest of tensors
dtype: bfloat16
```
## SFT without embed-layer focus
```yaml
# Model arguments
model_name_or_path: jan-hq/stealth-v1
torch_dtype: auto
use_flash_attention_2: true
# LoRA arguments
use_peft: true
lora_r: 16
lora_alpha: 32
lora_dropout: 0.1
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
- gate_proj
- up_proj
- down_proj
# Data training arguments
dataset_mixer:
  jan-hq/nectar_sft_binarized_subset: 1.0
dataset_splits:
- train
- test
preprocessing_num_workers: 12
# SFT trainer config
bf16: true
do_eval: true
evaluation_strategy: epoch
gradient_accumulation_steps: 32
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
hub_model_id: stealth-v1-adapter
hub_strategy: every_save
learning_rate: 2.0e-05
log_level: info
logging_steps: 5
logging_strategy: steps
lr_scheduler_type: cosine
max_seq_length: 2048
max_steps: -1
num_train_epochs: 1
output_dir: data/stealth-v1-adapter
overwrite_output_dir: true
per_device_eval_batch_size: 8
per_device_train_batch_size: 4
push_to_hub: true
report_to:
- wandb
save_strategy: "no"
save_total_limit: null
seed: 42
neftune_noise_alpha: 5
```
## SFT with embed-layer focus
```yaml
# Model arguments
model_name_or_path: jan-hq/stealth-v1
torch_dtype: auto
use_flash_attention_2: true
# LoRA arguments
use_peft: true
lora_r: 16
lora_alpha: 32
lora_dropout: 0.1
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
- gate_proj
- up_proj
- down_proj
lora_modules_to_save:
- embed_tokens
# Data training arguments
dataset_mixer:
  jan-hq/nectar_sft_binarized_subset: 1.0
dataset_splits:
- train
- test
preprocessing_num_workers: 12
# SFT trainer config
bf16: true
do_eval: true
evaluation_strategy: epoch
gradient_accumulation_steps: 32
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
hub_model_id: stealth-v1-adapter
hub_strategy: every_save
learning_rate: 2.0e-05
log_level: info
logging_steps: 5
logging_strategy: steps
lr_scheduler_type: cosine
max_seq_length: 2048
max_steps: -1
num_train_epochs: 1
output_dir: data/stealth-v1-adapter
overwrite_output_dir: true
per_device_eval_batch_size: 8
per_device_train_batch_size: 4
push_to_hub: true
report_to:
- wandb
save_strategy: "no"
save_total_limit: null
seed: 42
neftune_noise_alpha: 5
```
# Stealth no embed v1 - WITHOUT `embed_tokens`
## Merge
```yaml
slices:
  - sources:
      - model: jan-hq/trinity-v1.2
        layer_range: [0, 32]
      - model: WizardLM/WizardMath-7B-V1.1
        layer_range: [0, 32]
merge_method: slerp
base_model: jan-hq/trinity-v1.2
parameters:
  t:
    - filter: lm_head
      value: [0.55]
    - filter: self_attn
      value: [0.65, 0.35]
    - filter: mlp
      value: [0.35, 0.65]
    - filter: layernorm
      value: [0.4, 0.6]
    - filter: modelnorm
      value: [0.6]
    - value: 0.5 # fallback for rest of tensors
dtype: bfloat16
```
## SFT
```yaml
# Model arguments
model_name_or_path: jan-hq/stealth-v1-no_embed-v1
torch_dtype: auto
use_flash_attention_2: true
# LoRA arguments
use_peft: true
lora_r: 16
lora_alpha: 32
lora_dropout: 0.1
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
- gate_proj
- up_proj
- down_proj
# Data training arguments
dataset_mixer:
  jan-hq/nectar_sft_binarized_subset: 1.0
dataset_splits:
- train
- test
preprocessing_num_workers: 12
# SFT trainer config
bf16: true
do_eval: true
evaluation_strategy: epoch
gradient_accumulation_steps: 32
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
hub_model_id: stealth-v1-no_embed-adapter
hub_strategy: every_save
learning_rate: 2.0e-05
log_level: info
logging_steps: 5
logging_strategy: steps
lr_scheduler_type: cosine
max_seq_length: 2048
max_steps: -1
num_train_epochs: 1
output_dir: data/stealth-v1-no_embed-adapter
overwrite_output_dir: true
per_device_eval_batch_size: 8
per_device_train_batch_size: 4
push_to_hub: true
report_to:
- wandb
save_strategy: "no"
save_total_limit: null
seed: 42
neftune_noise_alpha: 5
```
# Stealth v1 mistral embed - WITHOUT `embed_tokens`
## Merge
```yaml
slices:
  - sources:
      - model: jan-hq/trinity-v1.2
        layer_range: [0, 32]
      - model: WizardLM/WizardMath-7B-V1.1
        layer_range: [0, 32]
merge_method: slerp
base_model: mistralai/Mistral-7B-Instruct-v0.2
parameters:
  t:
    - filter: lm_head
      value: [0.55]
    - filter: self_attn
      value: [0.65, 0.35]
    - filter: mlp
      value: [0.35, 0.65]
    - filter: layernorm
      value: [0.4, 0.6]
    - filter: modelnorm
      value: [0.6]
    - value: 0.5 # fallback for rest of tensors
dtype: bfloat16
```
# Results
1. GSM8k
|Model|Task|Version|Filter|n-shot|Metric|Value| |Stderr|
|----|-----|-------|----------|-----:|-----------|----:|---|-----:|
|stealth-v1|gsm8k|Yaml |get-answer| 5|exact_match|0.768|± |0.0116|
|stealth-v1-no_embed|gsm8k|Yaml |get-answer| 5|exact_match|0.7718|± |0.0116|
|stealth-v1-mistral-embed|gsm8k|Yaml |get-answer| 5|exact_match|0.2745|± |0.0123|
|stealth-v1-math-embed|gsm8k|Yaml |get-answer| 5|exact_match|0.765|± |0.0117|
|stealth-v1-math-sft|gsm8k|Yaml |get-answer| 5|exact_match|0.7665|± |0.0117|
|stealth-v1.1|gsm8k|Yaml |get-answer| 5|exact_match|0.7665|± |0.0117|
Peak GPU memory: 23.7 GB at batch size 8
Average eval time:
- 4090: 11 mins
- 3090: 40 mins
2. TruthfulQA
Peak GPU memory: 23.7 GB at batch size 8
3. Time taken for SFT
- With r=16, alpha=32, 700 samples, full layers
|GPU|Time|
|---|---|
|4090|16 mins|
|3090|110 mins|
- With r=256, alpha=512, 29,000 samples, full layers + DeepSpeed
|GPU|Time|
|---|---|
|4090|600 mins|
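Relative to the SFT configs above, the larger run changes only the LoRA rank and scaling (plus the DeepSpeed launch). A minimal sketch of the overridden fields, taken from the numbers above:

```yaml
# Sketch: overrides for the r=256 run; all other fields as in the SFT config above
lora_r: 256
lora_alpha: 512
```

The exact accelerate/DeepSpeed launch configuration for this run is not recorded here.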