# 🤗 XLSR Wav2Vec2 Fine-Tuning Week

This document contains all the discussions, ideas, links, resources and questions from the Slack channel #xlsr-fine-tuning-week. It is slightly unorganized. The organized version will be added to the Fine-Tuning Week Doc after today.

## Examples and Instructions from 🤗

* Fine-Tuning Week Doc: https://github.com/huggingface/transformers/blob/master/examples/research_projects/wav2vec2/FINE_TUNE_XLSR_WAV2VEC2.md
* Fine-Tuning Wav2Vec2 for Turkish: https://discuss.huggingface.co/t/turkish-asr-fine-tuning-wav2vec2/4556
* Fine-Tuning Wav2Vec2 on TIMIT: https://huggingface.co/blog/fine-tune-wav2vec2-english
* Combining local datasets with an official dataset: https://discuss.huggingface.co/t/how-to-combine-local-data-files-with-an-official-dataset/4685
* Training with TPUs using PyTorch: https://huggingface.co/blog/pytorch-xla
* Video on fine-tuning XLSR-Wav2Vec2 for low-resource ASR: https://youtu.be/UynYn2C3tI0
* Sharing your model via Google Colab: https://huggingface.co/transformers/model_sharing.html#workflow-in-a-colab-notebook
* An example of a model card: https://huggingface.co/patrickvonplaten/wav2vec2-large-xlsr-turkish-demo/blob/main/README.md
* W&B integration documentation with HF: https://docs.wandb.ai/integrations/huggingface
* W&B Colab: https://colab.research.google.com/drive/1oqMZAtRYkurKqePGxFKpnU9n6L8Qk9hM?usp=sharing
* Video on using OVH for fine-tuning XLSR Wav2Vec2: https://www.youtube.com/watch?v=2hlkWAESMk8
* Video 2 on using OVH for fine-tuning XLSR Wav2Vec2: https://www.youtube.com/watch?v=z3fk5h9Ce0g&t=2s
* Enable and disable caching in datasets: https://huggingface.co/docs/datasets/processing.html?highlight=cache#enable-or-disable-caching
* One important note: the WER that you see as the validation WER during training is often significantly worse than the actual "test WER". During training, the data is processed in a way that helps the model learn (appending a " " at the end of every sample), but this is not the correct preprocessing for the actual evaluation. So it's important to always run the eval script (which can be copy-pasted from the template) once and use this final WER as the result.

## Issues

* Errors
    * CUDA out of memory / insufficient GPU memory error:
        * Reduce the batch size :)
        * Try filtering out examples longer than a particular sequence length. See this PR: https://github.com/huggingface/transformers/pull/10581
        * Reduce the evaluation steps.
    * `trainer.train()` gives `CUDNN_STATUS_NOT_INITIALIZED`:
        * Fix the torch/CUDA versions. Ensure that they are compatible.
    * The following warning appears during training:
        > UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. In PyTorch 1.1.0 and later, you should call them in the opposite order: `optimizer.step()` before `lr_scheduler.step()`. Failure to do this will result in PyTorch skipping the first value of the learning rate schedule. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
        * It's just a warning which can be ignored. The training happens correctly.
    * `trainer.train()` leads to "blank must be in label range":
        * Try using a delimiter which is not present in the dataset text.
    * Caching bug in Hugging Face `datasets`: in some cases, the cached version of a dataset is used even though the dataset has changed during `.map()`:
        * Reset the cache: https://huggingface.co/docs/datasets/package_reference/main_classes.html?highlight=cache#datasets.Dataset.cleanup_cache_files
        * OR use `load_from_cache_file=False`.
* OVH issues
    * The job fails during the pre-processing:
        * A possible reason is that your local/ephemeral storage exceeds the limit. This happens because of the caching in Hugging Face `datasets`. When you are using the `finetune.sh` script, set the cache dir to the volume you have mounted. If you are using the notebook, you have to specify `cache_file_name` everywhere you use `map`. Also, don't forget to mount both object storage containers via the CLI.
        * How to fix this is explained here: https://www.youtube.com/watch?v=z3fk5h9Ce0g
* Misc:
    * Unable to find `huggingface-cli`:
        1. Install `huggingface-cli`.
        2. Search for the `huggingface-cli` executable using the `find` command.
        3. After finding it, add its location to the global path (something like `export PATH=...:$PATH`).
    * Unable to run `git lfs install`: https://stackoverflow.com/questions/48734119/git-lfs-is-not-a-git-command-unclear
    * Did anyone get a WER bigger than 100%? :confused: I tried to run the evaluation script on my checkpoint model (which showed ~65% WER during training) but I got ~114%. To make sure that I didn't use a broken script, I ran it on a random model from the Hugging Face Hub and got a correct result. Am I missing anything, or is this some sort of bug?
        * a) Make sure that all files are exactly the same as during training: use the output files of `processor.save_pretrained("./")` together with the `pytorch_model.bin` and `config.json` found in the corresponding `checkpoint-...` directory for eval.
        * b) Make sure that the data preprocessing is exactly the same.

## Questions

1. Training loss appears to be `nan`:
    1. If you encounter `nan`, it means some of your audio files have a problem. It becomes much more problematic when you have a long audio clip with a very short transcription, or a very short audio clip with a very long transcription. Wav2Vec2 has a built-in loss scale that can handle `nan` loss. The loss scale starts at 128 and goes down by a factor of 2 every time an overflow is detected, to a minimum of 0.06.
    2. Normally the model is not able to recover from a `nan` training loss, so I'd be surprised if the eval score stays good. `learning_rate` and the dropout values are probably the most important hyperparameters to tune if the training loss goes `nan`.
    3. See this post.
    4. For everyone struggling with the `nan` loss: try setting `ctc_zero_infinity=True` in the model config to avoid infinite NLL for misaligned or poorly annotated examples.
2. `trainer.train()` takes a very long time to cache the dataset:
    1. Try setting `dataloader_num_workers` to 1.
3. The audio file is not `.wav` but `.mp3`:
    1. Usually `torchaudio.load(...)` works well with `.mp3`.
4. I cleaned the Common Voice dataset by editing the tsv files by hand. How do I change the code to use those new labels? Deleting caches doesn't work.
    1. You should add `cache_dir="./common_voice/"`, assuming that the `common_voice` directory contains `cv-corpus-6.1-2020-12-11/{your_lang_id}`.
5. Hi. Just to make sure I do the right thing when I upload my trained checkpoint: should I include the `preprocessor_config.json` file from my Google Drive folder `wav2vec2-large-xlsr-{my favorite name}`, and NOT the file with the same name under `wav2vec2-large-xlsr-{my favorite name}/checkpoint-XXXX`?
    1. Yes. The `preprocessor_config.json` saved using `processor.save_pretrained()` should be used.
6. The notebook crashes on OVH during the eval for large datasets:
    1. Try using the chunked WER from here: https://discuss.huggingface.co/t/spanish-asr-fine-tuning-wav2vec2/4586/6
    2. Set `eval_accumulation_steps` to a reasonable value. Without it, the default is to accumulate all predictions on the GPU.
7. What might be the reason for the trainer to have far fewer training steps than I calculate, i.e. `steps_per_epoch * num_train_epochs` > the trainer's `total_steps`, with `steps_per_epoch = trainset_size / batch`?
    1. Solved: the number of GPUs! The batch size is per GPU, so the effective batch size is the per-device batch size times the number of GPUs.
8. Did anyone solve the problem of OOM when using multiple GPUs?
    1. `export CUDA_VISIBLE_DEVICES=1,2,3`
9. `load_dataset` takes a lot of time for large datasets, even after the download is complete:
    1. This can be due to the extraction taking place. Check using `du -h cache_dir`.
    2. Alternatively, download and extract using `aria2c`: https://aria2.github.io/

## Training Ideas

* Persistent storage for fast re-loading of the dataset: https://gist.github.com/flozi00/c8ce65dca6717797202a4734099dfd26
* Based on the validation loss the model is overfitting, but the WER keeps decreasing: on small datasets it's completely normal, and sometimes even preferred, for the model to overfit the data. The important metric to monitor here is the WER on the validation set, and if it goes down nicely, as it does in your script, that is a very good sign! The reason why the validation loss tends to go down and then up again on small datasets is often the following. In the beginning, the model is not very confident about its predictions and gives every output class roughly the same probability, say 1/vocab_size ~ 0.025. The model will make many **wrong** predictions on the validation set, but those wrong predictions don't increase the validation loss very much, e.g. -log(0.025) ~ 3.7. The more you train the model, the more confident it becomes in its predictions: it gives high probabilities to just a few classes and almost 0 probability to the others. What often happens now is that the model makes far fewer wrong predictions, but when it is wrong, it is very confident about its wrong prediction and thus increases the loss heavily, e.g. if the model gives the correct answer only 1e-10 probability, the loss for that classification is much higher, ~23. This means a single confident wrong prediction can be as bad as ~6 not-so-confident incorrect predictions. This shows up more often on small datasets because there, a few confidently misclassified samples can really drag down the overall score. However, this is not necessarily bad, as the model achieves a better WER (it gets more answers correct), and this is essentially what we care about.
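The loss arithmetic in the overfitting explanation can be checked with a few lines of Python. This is a minimal sketch that treats each prediction as a single classification (a simplification of the actual CTC loss, in the same spirit as the explanation) and assumes a 40-symbol vocabulary, so that 1/vocab_size = 0.025.

```python
import math

# Assumption: a 40-symbol vocabulary, so an unsure (uniform) guess gives 1/40 = 0.025 per class.
vocab_size = 40
uniform_prob = 1 / vocab_size

# Probability that a confidently wrong model assigns to the correct class.
confident_wrong_prob = 1e-10

# Cross-entropy is the negative log-likelihood of the correct class, so it is positive.
uniform_loss = -math.log(uniform_prob)
confident_wrong_loss = -math.log(confident_wrong_prob)
ratio = confident_wrong_loss / uniform_loss

print(f"loss of an unsure wrong guess:    {uniform_loss:.2f}")          # 3.69
print(f"loss of a confident wrong guess:  {confident_wrong_loss:.2f}")  # 23.03
print(f"one confident mistake ~ {ratio:.1f} unsure mistakes")          # 6.2
```

With these numbers, one confidently wrong prediction costs roughly as much as six unsure ones, which is why a handful of such samples can push the validation loss up while the WER still improves.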
* Use `save_model` to save the last step along with the checkpoints, because sometimes the final model is better since it has been trained for longer.
* How can we calculate the phoneme error rate (PER), in case we want to compare our results to Table 1 of the paper?
    * That's only possible if you preprocess your text into a sequence of phonemes which defines your vocabulary. This is super language-specific, so I didn't try to make a "one-size-fits-all" script. The closest thing to PER given our current setup would be to calculate the character error rate (CER).
    * In general one can say CER < WER < SER, and the phoneme error rate is often similar to the CER.
* I tried XLSR fine-tuning for Japanese, but it didn't give me good results because Common Voice contains only 30h of Japanese audio. I then tried transfer learning (further pretraining) from the Facebook model without fine-tuning, and got pretty good results using my own pretrained model. So I recommend pretraining for your language. Here is the progress of my pretraining (960h of Japanese audio). And this is the result of fine-tuning on only 10h of Japanese (WER is 0.12). The Facebook paper shows a lot of time and resources spent on training, but I found that I could get good results in about 40 hours with 8 Tesla V100s for transfer learning. This training used a SageMaker spot instance, so the price was 1/4 of the regular price.
* In the paper, for fine-tuning they mention "The learning rate schedule has three phases: warm up for the first 10% of updates, keep constant for 40% and then linearly decay for the remainder". Will I have to create my own LR scheduler for this, or is it present in HF Transformers?
    * Sharing Trainer code for the same learning rate schedule as the paper (10% warmup, 40% flat, 50% linear decay):
        * https://huggingface.slack.com/archives/C01QZ90Q83Z/p1616535314247300?thread_ts=1616498495.144000&cid=C01QZ90Q83Z

        ```python
        from typing import Optional, Union

        from torch.optim import Optimizer
        from torch.optim.lr_scheduler import LambdaLR
        from transformers.trainer_utils import SchedulerType


        def get_flat_linear_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, last_epoch: int = -1):
            # Note: num_warmup_steps is ignored; warmup is fixed at 10% of num_training_steps.
            def lr_lambda(current_step):
                warmup_steps = int(num_training_steps * 0.1)    # phase 1: linear warmup
                constant_steps = int(num_training_steps * 0.4)  # phase 2: constant LR
                if current_step < warmup_steps:
                    return float(current_step) / float(max(1, warmup_steps))
                elif current_step < warmup_steps + constant_steps:
                    return 1.0
                else:
                    # phase 3: linear decay to 0 over the remaining steps
                    return max(
                        0.0,
                        float(num_training_steps - current_step)
                        / float(max(1, num_training_steps - (warmup_steps + constant_steps))),
                    )

            return LambdaLR(optimizer, lr_lambda, last_epoch)


        def get_flat_scheduler(
            name: Union[str, SchedulerType] = None,  # ignored; mirrors transformers' get_scheduler signature
            optimizer: Optimizer = None,
            num_warmup_steps: Optional[int] = None,
            num_training_steps: Optional[int] = None,
        ):
            return get_flat_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
        ```
    * And creating a Trainer wrapper:

        ```python
        from transformers import Trainer


        class FlatTrainer(Trainer):
            def create_flat_scheduler(self, num_training_steps: int):
                # get_flat_scheduler is defined in the snippet above
                self.lr_scheduler = get_flat_scheduler(optimizer=self.optimizer, num_training_steps=num_training_steps)

            def create_optimizer_and_scheduler(self, num_training_steps):
                # keep the default optimizer, but swap in the warmup/flat/decay schedule
                self.create_optimizer()
                self.create_flat_scheduler(num_training_steps)
        ```
* Hi, one doubt for Wav2Vec2 experts: since the model is trained to transform audio input into a suitable vector space, shouldn't we use the normalisers, noise and augmentations used in the original training process? Since at fine-tuning we freeze the vectoriser (the feature extractor), won't adding any new normalisation/transformation lead to OOV / incorrect sampling of audio? Please correct me if I am wrong.
    * In pretraining, the network is made to learn phonemes while ignoring noise and background, which makes it robust to any type of noise; it should be speaker-independent as well.
    * Even in the fine-tuning step, if you introduce noise, the model will be able to differentiate the phonemes from it, but the training process will be a lot slower and may sometimes get stuck in a local minimum if excessive noise is added. Other augmentations like pitch and pace basically create new voices, i.e. they increase your speaker diversity.
* Use an n-gram language model for inference to improve performance.
* Improve training start time using the Group Length Trainer mentioned in: https://discuss.huggingface.co/t/spanish-asr-fine-tuning-wav2vec2/4586/6
* Augmenting data with real-life noisy datasets will probably not improve the Common Voice WER: LibriSpeech, MLS (Multilingual LibriSpeech) and Common Voice are all read-speech corpora, which have a very different distribution from "real" dialogue data. It's much harder for speech models to be as generally applicable as NLP models precisely because the input distribution of speech can vary far more drastically than that of text input.
* The GPU might switch between 100% and 0% utilization by default; use a larger `dataloader_num_workers` to make the usage more consistent.
* Try training a multilingual model for languages of the same family.

## Augmentation Ideas

* Another tip, but no warranty. From our time working with Mozilla DeepSpeech, we got pretty good results with audio data augmentation. Many datasets, especially the low-resource languages from Common Voice, are unbalanced between men and women. Taking the men's audio and making it faster and/or pitching it up, and taking the women's audio and pitching it down and/or making it slower, makes the dataset much bigger and the model more stable in production or on audio with a little noise.
    * Example: 10 hours of raw audio
        * Pitching up / down = 20 hours of audio
        * Making all faster = 40 hours of audio
        * Making all slower = 80 hours of audio
    * For the German dataset, I am working with around 80 hours of raw data (10% of the full dataset) while searching for the best parameters, and reach a WER as low as 0.23. Not sure about its performance with Wav2Vec2; if anyone is interested, I can write some code and share it.
* You can start with augmentations on pace, pitch and volume. You can also add white/pink noise to make the model more robust. However, I have observed instances where, if the training text is repeated across sentences (i.e. the same text read by various speakers), the model tends to overfit.
* Here is an example of augmentation in general. Depending on the language and dataset, you would need to change the parameters or use other augmentation methods. The `p` parameter is the probability that a transform is applied, here 80% for each one, so the output varies. It should be easy to duplicate this code to make a separate run for every single filter, as described before, making the dataset 8 times bigger instead of doing it once and only doubling its size.

    ```python
    from audiomentations import Compose, AddGaussianNoise, Gain, PitchShift
    from datasets import concatenate_datasets
    import soundfile as sf
    import librosa

    # each transform is applied independently with probability p=0.8
    augment = Compose([
        AddGaussianNoise(min_amplitude=0.0001, max_amplitude=0.001, p=0.8),
        PitchShift(min_semitones=-1, max_semitones=1, p=0.8),
        Gain(min_gain_in_db=-6, max_gain_in_db=6, p=0.8)
    ])

    def augmented_speech_file_to_array_fn(batch):
        try:
            # reuse a previously written augmented file if it exists
            speech_array, sampling_rate = sf.read(batch["path"] + "augmented.wav")
        except Exception:
            # otherwise load at 16 kHz, augment, and cache the result next to the original file
            speech_array, sampling_rate = librosa.load(batch["path"], sr=16000, res_type='zero_order_hold')
            speech_array = augment(samples=speech_array, sample_rate=sampling_rate)
            sf.write(batch["path"] + "augmented.wav", speech_array, sampling_rate, subtype='PCM_24')
        batch["speech"] = speech_array
        batch["sampling_rate"] = sampling_rate
        batch["target_text"] = batch["text"]
        return batch

    common_voice_train_augmented = common_voice_train_augmented.map(augmented_speech_file_to_array_fn, remove_columns=common_voice_train_augmented.column_names, num_proc=1)
    print("merging")
    common_voice_train = concatenate_datasets([common_voice_train, common_voice_train_augmented])
    ```
* What insight I can give is that if your data is balanced in terms of gender, number of samples and speakers, then you can augment using pace, pitch, volume and noise!
* Can someone who usually works on audio data weigh in on `loudness normalization` (at least I think that is what it is called, from my googling)? I'm curious whether the differing audio levels (some too loud, some too quiet) are something we should normalize, or whether the normalization done by torchaudio takes care of that too.
    * https://github.com/RubensZimbres/Repo-2018/blob/master/Google-Cloud-Speech-API/ffmpeg_converter - ffmpeg for handling the conversion.
    * One way to calculate loudness:

        ```python
        import librosa
        import numpy as np

        def get_loudness_stats(sa, sr):
            # Return mean and max loudness given a speech array and sample rate
            # Credit: https://stackoverflow.com/questions/64913424/how-to-compute-loudness-from-audio-signal
            # Compute the spectrogram (magnitude)
            n_fft = 2048
            hop_length = 1024
            spec_mag = abs(librosa.stft(sa, n_fft=n_fft, hop_length=hop_length))
            # Convert the spectrogram into dB
            spec_db = librosa.amplitude_to_db(spec_mag)
            # Compute A-weighting values
            freqs = librosa.fft_frequencies(sr=sr, n_fft=n_fft)
            a_weights = librosa.A_weighting(freqs)
            a_weights = np.expand_dims(a_weights, axis=1)
            # Apply the A-weighting to the spectrogram in dB
            spec_dba = spec_db + a_weights
            # Compute the "loudness" value
            loudness = librosa.feature.rms(S=librosa.db_to_amplitude(spec_dba))
            return np.mean(loudness[0]), np.max(loudness[0])
        ```
    * I used the max loudness to do some EDA here, and I'm planning on doing some normalisation like you suggested when I learn a little more!: https://wandb.ai/wandb/xlsr/artifacts/audio-file/common-voice-tr-train/f368150e87dc099e9559/files/train_samples.table.json
    * I normalised loudness using pyloudnorm to start with: the quieter samples now sound good, and there is no change to the loud samples, but the Irish dataset didn't have any crazy loud ones to begin with.

        ```python
        import pyloudnorm as pyln

        def get_loudness_normalised(sa, sr):
            # peak normalize audio to -1 dB
            # (this result is not used below; loudness normalization is applied to `sa` directly)
            peak_normalized_audio = pyln.normalize.peak(sa, -1.0)
            # measure the loudness first
            meter = pyln.Meter(sr)  # create BS.1770 meter
            loudness = meter.integrated_loudness(sa)
            # loudness normalize audio to -12 dB LUFS
            loudness_normalized_audio = pyln.normalize.loudness(sa, loudness, -12.0)
            return loudness_normalized_audio
        ```

________________

## Other Ideas

* Explore unfreezing the feature extractor for better performance.
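Unfreezing the feature extractor can be sketched as the reverse of `model.freeze_feature_extractor()`, which the fine-tuning notebook calls. This is a minimal sketch, not an official API: the helper name `set_feature_extractor_trainable` is hypothetical, and it assumes a `Wav2Vec2ForCTC`-style model exposing `model.wav2vec2.feature_extractor` (the same attribute used elsewhere in this doc). Whether unfreezing actually improves the WER for your language is untested here.

```python
import torch.nn as nn


def set_feature_extractor_trainable(model: nn.Module, trainable: bool = True) -> int:
    """Toggle gradient updates for the convolutional feature extractor.

    Hypothetical helper. Assumes the model exposes `model.wav2vec2.feature_extractor`.
    Returns the number of parameters whose `requires_grad` flag was set.
    """
    feature_extractor = model.wav2vec2.feature_extractor
    n_params = 0
    for param in feature_extractor.parameters():
        param.requires_grad = trainable
        n_params += param.numel()
    # Recent transformers versions also keep a private `_requires_grad` flag that
    # freezing sets to False; keep it in sync if it exists (version-dependent).
    if hasattr(feature_extractor, "_requires_grad"):
        feature_extractor._requires_grad = trainable
    return n_params
```

For example, call `set_feature_extractor_trainable(model, True)` after loading the model (instead of freezing it) and before creating the `Trainer`; expect higher GPU memory usage, since the convolutional layers now need gradients.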
* Using GPT-2 with Wav2Vec2 to improve WER: https://huggingface.co/voidful/wav2vec2-large-xlsr-53-tw-gpt. We feed the greedy-decoded sentence to a GPT model, which means GPT will try to predict what every token should be based on its left context. For example, CTC may generate "9 a clock"; when this sentence is fed to GPT, it will probably predict "o" instead of "a" given the left context "9". A simple way to combine this information is to multiply the two probabilities after the softmax, which means we want our result to have a high probability under both the GPT and the CTC model.
* On using LMs with ASR: use beam search with partial scores from the ASR model and the LM: https://github.com/espnet/espnet/blob/master/espnet/nets/beam_search.py
* Phonemes in ASR: https://assets.amazon.science/64/94/639ae0c44890837b0f1fbf11ef77/using-phoneme-representations-to-build-predictive-models-robust-to-asr-errors.pdf, https://arxiv.org/pdf/1904.02210.pdf
    * Mixed results, but it was mostly useful when combined with an adversarial objective. It also really depends on the quality of your grapheme-to-phoneme conversion and the amount of training data you have for the language. If you have lots of data, you can get away with non-transparent orthographies like English. I'd say the phoneme objectives would be better in limited-data scenarios when the orthography isn't transparent.
    * My experience with phonemes is that they give worse results than graphemes (at least for English), so unless you want a phoneme mapping in your output, you are better off with normal graphemes (at least when you are not doing some kind of combination of the two). Moreover, using phonemes introduces other issues that are not necessarily straightforward to solve (speaking for English again). For example: correctly mapping graphemes to phonemes for audio with accented speech (a speaker from the US, UK, Ireland, Australia, South Africa etc. can use different sounds for the same word); words that can have multiple pronunciations; and translating the generated phonemes back to graphemes so that someone can read the text easily, since there is no single mapping that gets you from one side to the other. Also, CMU is limited in size, so for OOV words you will have to use either some rule-based method and make sure that it sort of aligns with CMU, or train a model on the CMU dictionary. CMU also doesn't use schwa, which is more or less the most common sound in English, and assumes that pronunciation is literal (if that is the right word), so no assimilation, I think, etc. However, for languages that have a more straightforward mapping to sounds than English, it can probably work quite well.

## Notebook Examples

* Using GPT-2 with Wav2Vec2 to improve WER: https://colab.research.google.com/drive/1nBRLf4Pwiply_y5rXWoaIB8LxX41tfEI?usp=sharing
* Using audiomentations to augment the data: https://colab.research.google.com/drive/1pNdyUaG75leBri3DMcvHm2VFN91Yo7Oq?usp=sharing

## Libraries

* Audiomentations: https://github.com/iver56/audiomentations, for augmentation of audio data.
* NLP libraries for Indian languages:
    * https://www.analyticsvidhya.com/blog/2020/01/3-important-nlp-libraries-indian-languages-python/
    * http://anoopkunchukuttan.github.io/indic_nlp_library/
    * https://pypi.org/project/indic-transliteration/
    * https://pypi.org/project/ftfy/
* Transformations, e.g. converting 2 to two: https://pypi.org/project/inflect/
* AudioSegment: https://github.com/jiaaro/pydub
* Forced aligner: https://montreal-forced-aligner.readthedocs.io/en/latest/
* SoX with `trim` can be used for splitting.
* `split_on_silence` from PyDub (`from pydub.silence import split_on_silence`) can be used for splitting on silence.

## Datasets

* LibriSpeech/OpenSLR (especially for Indic languages): http://openslr.org/resources.php
* ArabicSpeech: https://arabicspeech.org/resources/
* BembaSpeech: https://arxiv.org/abs/2102.04889
* Open Speech Corpora: https://github.com/coqui-ai/open-speech-corpora
* Hindi, Marathi, Odia: http://www.openslr.org/103/ (for Interspeech 2021 participants only)
* Microsoft Research Open Data: https://msropendata.com/datasets/7230b4b1-912d-400e-be58-f84e0512985e

## Random Helpers

* Regex for Indic languages (Unicode block ranges):
    * `reg['Hindi'] = re.compile(r"[\u0900-\u097F]")`
    * `reg['Bengali'] = re.compile(r"[\u0980-\u09FF]")`
    * `reg['Tamil'] = re.compile(r"[\u0B80-\u0BFF]")`
    * `reg['Telugu'] = re.compile(r"[\u0C00-\u0C7F]")`
    * `reg['Kannada'] = re.compile(r"[\u0C80-\u0CFF]")`
    * `reg['Malayalam'] = re.compile(r"[\u0D00-\u0D7F]")`
    * `reg['Gujarati'] = re.compile(r"[\u0A80-\u0AFF]")`
    * `reg['Punjabi'] = re.compile(r"[\u0A00-\u0A7F]")`
    * `reg['Oriya'] = re.compile(r"[\u0B00-\u0B7F]")`
* Using W&B: to log to one of these W&B projects, just add the project name to your W&B logging; see the Colab here for an example of setting up and logging to W&B. You need to set a project environment variable for the Trainer's `WandbCallback` to log to your project:
    * In a notebook you can set the project environment variable like so:
        * `%env WANDB_PROJECT=xlsr-french`
    * Or in a script you can do:
        * `os.environ['WANDB_PROJECT'] = 'xlsr-french'`
* For cleaning Persian:

    ```python
    import re

    persian_alpha_codepoints = '\u0621-\u0628\u062A-\u063A\u0641-\u0642\u0644-\u0648\u064E-\u0651\u0655\u067E\u0686\u0698\u06A9\u06AF\u06BE\u06CC'
    persian_num_codepoints = '\u06F0-\u06F9'
    arabic_numbers_codepoints = '\u0660-\u0669'
    space_codepoints = '\u0020\u2000-\u200F\u2028-\u202F'
    additional_arabic_characters_codepoints = '\u0629\u0643\u0649-\u064B\u064D\u06D5'

    def clean_row(line):
        # remove text in parentheses
        line = re.sub(re.compile(r'\([^)]*\)'), '', line)
        # replace nim-fasele (ZWNJ, U+200C) with a space
        line = line.replace('\u200c', ' ')
        # keep only Persian letters, Persian/Arabic/Latin digits and spaces
        line = re.sub(r"[^" + persian_alpha_codepoints + persian_num_codepoints + additional_arabic_characters_codepoints + arabic_numbers_codepoints + space_codepoints + '1234567890\n' + "]", "", line)
        line = re.sub(r"[" + space_codepoints + "]", " ", line)
        # remove or substitute some characters (e.g. the tatweel ـ)
        # tanvin (ً) and hamza (ء) are kept
        line = re.sub(r"[" + 'ّ' + 'ٌ' + 'ـ' + 'َ' + 'ِ' + 'ٕ' + 'ٍ' + 'ُ' + 'ْ' + "]", '', line)
        # careful: some editors (e.g. VS Code) display these characters in the wrong order
        line = re.sub('ؤ', 'و', line)
        line = re.sub('ة', 'ه', line)
        line = re.sub('ك', 'ک', line)
        line = re.sub('ى', 'ی', line)
        line = re.sub('ي', 'ی', line)
        line = re.sub('ە', 'ه', line)
        line = re.sub('ئ', 'ی', line)
        line = re.sub('أ', 'ا', line)
        line = re.sub('إ', 'ا', line)
        # collapse multiple spaces into one
        line = re.sub(' +', ' ', line)
        # strip leading and trailing whitespace
        line = line.strip()
        return line
    ```
* Boris Dayma, helper to decide which letters to keep: https://huggingface.slack.com/archives/C01QZ90Q83Z/p1616786995219700

    ```python
    import csv
    import operator

    # count how often each (lowercased) character appears in the train and test sentences
    count_char = {}
    for d in [common_voice_test, common_voice_train]:
        for x in d:
            for c in x['sentence'].lower():
                if c in count_char:
                    count_char[c] += 1
                else:
                    count_char[c] = 1

    # most frequent characters first
    sorted_char = sorted(count_char.items(), key=operator.itemgetter(1), reverse=True)
    with open('letters.csv', 'w', newline='') as csvfile:
        writer = csv.writer(csvfile)
        for x in sorted_char:
            writer.writerow(x)
    ```

    Then you can change the regex to only keep the top characters, e.g.

    ```python
    import re

    # negated character class: matches every character NOT in the keep list
    # (keep the space so word boundaries survive)
    characters_to_keep = '[^abcdefg ]'
    # lowercase first, so uppercase versions of kept letters are not stripped
    batch["sentence"] = re.sub(characters_to_keep, '', batch["sentence"].lower()) + " "
    ```
* Hyperparameter settings from Boris Dayma. Based on my parameter search on the Turkish dataset for 10 epochs, I would try the following strategies (on any dataset):
    * Option 1: I can only do one run
        * batch size: 32 (for example a batch size of 16 with 2 gradient accumulation steps)
        * activation_dropout: 0.055
        * attention_dropout: 0.094
        * feat_proj_dropout: 0.04
        * hidden_dropout: 0.047
        * layerdrop: 0.041
        * learning_rate: 2.34e-4
        * mask_time_prob: 0.082
    * Option 2: I can do a few runs
        * Bayesian search with the following range of parameters (adjust language and batch size based on your setup).
        * If your dataset is very small, you could start a new random search (use `sweep.yaml` from the repo).
    * Option 3: I can handle a much larger batch size (with no accumulation)
        * Filter my sweep results to the runs with larger batch sizes to decide your parameters.
* Using on-the-fly augmentation with Wav2Vec2:

    ```python
    import torch
    import torch.nn as nn
    from torch_audiomentations import Compose, Gain, PolarityInversion

    class AUG(nn.Module):
        def __init__(self):
            super().__init__()
            # Initialize augmentation callable
            self.aug = Compose(transforms=[
                Gain(min_gain_in_db=-15.0, max_gain_in_db=5.0, p=0.5),
                PolarityInversion(p=0.5),
            ])

        def forward(self, x):
            # x has shape (batch, 1, samples) here, since it runs before the first conv layer
            if not self.training:
                # skip augmentation during evaluation/inference
                return x
            return self.aug(x, sample_rate=16000)

    ...
    aug = AUG()
    model.wav2vec2.feature_extractor.conv_layers.insert(0, aug)
    ```

## Forums

* Spanish: https://discuss.huggingface.co/t/spanish-asr-fine-tuning-wav2vec2/4586/3?u=pcuenq
* German: https://discuss.huggingface.co/t/german-asr-fine-tuning-wav2vec2/4558
* Persian: https://discuss.huggingface.co/t/persian-asr-fine-tuning-wav2vec2/4559
* Swedish: https://discuss.huggingface.co/t/swedish-asr-fine-tuning-wav2vec2/4560
* Indonesian: https://discuss.huggingface.co/t/indonesian-asr-fine-tuning-wav2vec2/4564
* Arabic: https://discuss.huggingface.co/t/arabic-asr-fine-tuning-wav2vec2/4608
* Hindi: https://discuss.huggingface.co/t/hindi-asr-fine-tuning-wav2vec2/4582
* Spanish: https://discuss.huggingface.co/t/spanish-asr-fine-tuning-wav2vec2/4586
* Portuguese: https://discuss.huggingface.co/t/portuguese-asr-fine-tuning-xlsr-wav2vec2/4653
* Turkish: https://discuss.huggingface.co/t/turkish-asr-fine-tuning-wav2vec2/4556
* MultiGPU Training: https://discuss.huggingface.co/t/wav2vec-fine-tuning-with-multigpu/4894?u=gorodecki
* Bemba: https://discuss.huggingface.co/t/bemba-asr-fine-tuning-wav2vec2/4774?u=claytone
* Marathi: https://discuss.huggingface.co/t/marathi-asr-fine-tuning-wav2vec2/4943?u=gchhablani

## Other Links

* Contributors list: https://docs.google.com/spreadsheets/d/1GSzqhu2ysIhPzMtfWZb9UOiQI3fkaGoF8GSZbrITr98/edit?usp=sharing
* OVH list: https://docs.google.com/spreadsheets/d/1Lk4MU6aVMnsan5Z1nudqwukUb1nvTFMFgFfcl3GsKso/edit?usp=sharing
* OVH Discord link: https://discord.gg/HaNEhBax
* 16GB Colab link: https://colab.research.google.com/drive/1D6krVG0PPJR2Je9g5eN_2h6JP73_NUXz
* Speech crawler for YouTube: https://github.com/Prem-kumar27/Fast-KTSpeechCrawler
* YouTube explanation of Wav2Vec2: https://www.youtube.com/watch?v=aUSXvoWfy3w
* PapersWithCode link: https://paperswithcode.com/dataset/common-voice
* W&B link for language-specific projects: https://docs.google.com/spreadsheets/d/1ZlpldJ4DBqRk5IR7jNU1gJY_eciKADlwUgkgEAfD9PU/edit?usp=sharing
* W&B data visualization: https://vimeo.com/526950141
* Add existing logs to the W&B projects: https://docs.wandb.ai/app/features/sidebar#move-runs-between-projects
* Video from the authors of the paper: https://youtu.be/u5Bldiey4zc
* XLSR sweeps using W&B: https://github.com/borisdayma/xlsr-sweeps
* Link to CER metric: https://huggingface.co/ctl/wav2vec2-large-xlsr-cantonese/blob/main/cer.py
* Compressing Wav2Vec2: https://medium.com/georgian-impact-blog/compressing-wav2vec-2-0-f41166e82dc2
* Common Voice Explorer: https://github.com/cceyda/common-voice-explorer, https://28a3b43b-3997-4649-adda-93f52445562a.job.gra.training.ai.cloud.ovh.net/
* ELDA for datasets: http://elda.org/
* Repo for optimal training on large datasets: https://github.com/maxidl/wav2vec2
* Library for training: https://github.com/morganmcg1/xlsr_finetune
    * Checks for corrupted audio, cleaner functions, etc.
* Gist for checkpoint averaging: https://gist.github.com/ArtVanderlay/f08fc36cf79e53fbe03804697735c8f4#file-average_checkpoints-py
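The checkpoint-averaging idea from the gist linked above can be sketched roughly as follows. This is not the gist's code; it is a minimal sketch assuming every checkpoint is a plain PyTorch state dict (e.g. the `pytorch_model.bin` inside each `checkpoint-...` directory) with identical keys and tensor shapes. The function name `average_state_dicts` is hypothetical.

```python
from typing import Dict, List

import torch


def average_state_dicts(state_dicts: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    """Element-wise mean of several state dicts that share the same keys and shapes."""
    if not state_dicts:
        raise ValueError("need at least one state dict to average")
    averaged = {}
    for key, first in state_dicts[0].items():
        if first.is_floating_point():
            # average in float32, then cast back to the original dtype
            stacked = torch.stack([sd[key].float() for sd in state_dicts])
            averaged[key] = stacked.mean(dim=0).to(first.dtype)
        else:
            # integer buffers (e.g. step counters) cannot be meaningfully averaged;
            # copy them from the first checkpoint
            averaged[key] = first.clone()
    return averaged


# Hypothetical usage with checkpoint directories written by the Trainer:
# paths = ["checkpoint-400/pytorch_model.bin", "checkpoint-600/pytorch_model.bin"]
# state_dicts = [torch.load(p, map_location="cpu") for p in paths]
# model.load_state_dict(average_state_dicts(state_dicts))
```

Averaging usually only makes sense for checkpoints from the same run that are close together in training; whether it lowers the WER for your model is something to measure with the eval script.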