
🐢 Pretrain Transformers Models in PyTorch using Hugging Face Transformers

Pretrain 67 transformers models on your custom dataset.



Disclaimer: The format of this tutorial notebook is very similar to my other tutorial notebooks. This is done intentionally to keep readers familiar with my format.


This notebook is used to pretrain transformers models using Huggingface on your own custom dataset.

What do I mean by pretrain transformers? The definition of pretraining is to train in advance. That is exactly what I mean! Train a transformer model to use it as a pretrained transformers model which can be used to fine-tune it on a specific task!

I also use the term fine-tune where I mean to continue training a pretrained model on a custom dataset. I know it is confusing and I hope I'm not making it worse. At the end of the day you are training a transformer model that was previously trained or not!

With the AutoClasses functionality we can reuse the code on a large number of transformers models!

This notebook is designed to:

  • Use an already pretrained transformers model and fine-tune (continue training) it on your custom dataset.

  • Train a transformer model from scratch on a custom dataset. This requires an already trained (pretrained) tokenizer. By default, this notebook uses the pretrained model's tokenizer if a separately trained tokenizer is not provided.

This notebook is heavily inspired by the Hugging Face script used for training language models: transformers/tree/master/examples/language-modeling. I basically adapted that script to work nicely in a notebook with a lot more comments.

Notes from transformers/tree/master/examples/language-modeling: Fine-tuning the library models for language modeling on a text file (GPT, GPT-2, CTRL, BERT, RoBERTa, XLNet). GPT, GPT-2 and CTRL are fine-tuned using a causal language modeling (CLM) loss. BERT and RoBERTa are fine-tuned using a masked language modeling (MLM) loss. XLNet is fine-tuned using a permutation language modeling (PLM) loss.
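To make the difference between the CLM and MLM objectives more concrete, here is a minimal sketch in plain Python, not the library's implementation. The helper names `clm_pairs` and `mlm_example` are my own for illustration. The sketch always substitutes the mask token, which is a simplification of what the real data collators do. The value -100 is the label that PyTorch's cross-entropy loss ignores by default.

```python
import random

IGNORE = -100  # label value skipped by PyTorch cross-entropy by default


def clm_pairs(token_ids):
    """Causal LM: every position is trained to predict the next token."""
    return list(token_ids[:-1]), list(token_ids[1:])


def mlm_example(token_ids, mask_id, mlm_probability=0.15, rng=None):
    """Masked LM (simplified): hide some tokens and predict only those."""
    rng = rng or random.Random(0)
    inputs, labels = [], []
    for token in token_ids:
        if rng.random() < mlm_probability:
            # Hidden token: the model sees the mask and must recover it.
            inputs.append(mask_id)
            labels.append(token)
        else:
            # Visible token: kept as input and excluded from the loss.
            inputs.append(token)
            labels.append(IGNORE)
    return inputs, labels
```

In the causal case the loss covers every position, while in the masked case only the hidden positions contribute to the loss.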


What should I know for this notebook?

Since I am using PyTorch to fine-tune our transformers models any knowledge on PyTorch is very useful.

Knowing a little bit about the transformers library helps too.

In this notebook I am using raw text data to pretrain / train / fine-tune transformers models. There is no need for labeled data since we are not doing classification. The Transformers library handles the text files in the same way as the original implementation of each model did.


How to use this notebook?

Like with every project, I built this notebook with reusability in mind. This notebook uses a custom dataset from .txt files. Since the dataset does not come in a single .txt file I created a custom function movie_reviews_to_file that reads the dataset and creates the text file. The way I load the .txt files can be easily reused for any other dataset.

The only modifications needed to use your own dataset will be in the paths provided to the train .txt file and evaluation .txt file.
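If your own data is not yet in a single .txt file, something along these lines could produce one. This is a minimal sketch: `texts_to_file` is a hypothetical helper that is not part of the notebook, and it assumes your examples are already loaded as a list of strings.

```python
import io
import os
import tempfile


def texts_to_file(texts, path_texts_file):
    """Write one example per line and return the number of examples."""
    # Collapse all whitespace runs (including line breaks) into single
    # spaces so that every example stays on exactly one line.
    lines = [' '.join(text.split()) for text in texts]
    with io.open(path_texts_file, mode='w', encoding='utf-8') as texts_file:
        texts_file.write('\n'.join(lines))
    return len(lines)


# Usage with a temporary folder and made-up examples.
example_path = os.path.join(tempfile.mkdtemp(), 'train.txt')
texts_to_file(['first review', 'second\nreview  with breaks'], example_path)
```

The resulting path is what you would pass as the train or evaluation file.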

All parameters that need to be changed are under the Parameters Setup section. Each parameter is nicely commented and structured to be as intuitive as possible.


What transformers models work with this notebook?

A lot of people will probably use it for BERT. When there is a need to run a different transformer model architecture, which one would work with this code? Since the name of the notebook is pretrain_transformers, it should work with more than one type of transformer.

I ran this notebook across all the pretrained models found on Hugging Face Transformers. This way you know ahead of time if the model you plan to use works with this code without any modifications.

The list of pretrained transformers models that work with this notebook can be found here. There are 67 models that worked 😄 and 39 models that failed to work 😢 with this notebook. Remember, these are pretrained models that were fine-tuned on a custom dataset.


Dataset

This notebook will cover pretraining transformers on a custom dataset. I will use the well-known Large Movie Review Dataset of movie reviews labeled positive or negative.

The description provided on the Stanford website:

This is a dataset for binary sentiment classification containing substantially more data than previous benchmark datasets. We provide a set of 25,000 highly polar movie reviews for training, and 25,000 for testing. There is additional unlabeled data for use as well. Raw text and already processed bag of words formats are provided. See the README file contained in the release for more details.

Why this dataset? I believe it is an easy dataset to understand and use for classification. I think sentiment data is always fun to work with.


Coding

Now let's do some coding! We will go through each coding cell in the notebook and describe what it does, what the code is, and, when relevant, show the output.

I made this format to be easy to follow if you decide to run each code cell in your own python notebook.

When I learn from a tutorial I always try to replicate the results. I believe it's easy to follow along if you have the code next to the explanations.


Downloads

Download the Large Movie Review Dataset and unzip it locally.

# Download the dataset.
!wget -q -nc http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
# Unzip the dataset.
!tar -zxf /content/aclImdb_v1.tar.gz

Installs

  • transformers library needs to be installed to use all the awesome code from Hugging Face. To get the latest version I will install it straight from GitHub.

  • ml_things library used for various machine learning related tasks. I created this library to reduce the amount of code I need to write for each machine learning project.

# Install transformers library.
!pip install -q git+https://github.com/huggingface/transformers.git
# Install helper functions.
!pip install -q git+https://github.com/gmihaila/ml_things.git
Installing build dependencies ... done
Getting requirements to build wheel ... done
Preparing wheel metadata ... done
 |████████████████████████████████| 2.9MB 6.7MB/s
 |████████████████████████████████| 890kB 48.9MB/s
 |████████████████████████████████| 1.1MB 49.0MB/s
Building wheel for transformers (PEP 517) ... done
Building wheel for sacremoses (setup.py) ... done
 |████████████████████████████████| 71kB 5.2MB/s
Building wheel for ml-things (setup.py) ... done
Building wheel for ftfy (setup.py) ... done

Imports

Import all needed libraries for this notebook.

Declare basic parameters used for this notebook:

  • set_seed(123) - Always good to set a fixed seed for reproducibility.

  • device - Look for gpu to use. I will use cpu by default if no gpu found.

import io
import os
import math
import torch
import warnings
from tqdm.notebook import tqdm
from ml_things import plot_dict, fix_text
from transformers import (
                          CONFIG_MAPPING,
                          MODEL_FOR_MASKED_LM_MAPPING,
                          MODEL_FOR_CAUSAL_LM_MAPPING,
                          PreTrainedTokenizer,
                          TrainingArguments,
                          AutoConfig,
                          AutoTokenizer,
                          AutoModelWithLMHead,
                          AutoModelForCausalLM,
                          AutoModelForMaskedLM,
                          LineByLineTextDataset,
                          TextDataset,
                          DataCollatorForLanguageModeling,
                          DataCollatorForWholeWordMask,
                          DataCollatorForPermutationLanguageModeling,
                          PretrainedConfig,
                          Trainer,
                          set_seed,
                          )

# Set seed for reproducibility,
set_seed(123)

# Look for gpu to use. Will use `cpu` by default if no gpu found.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Helper Functions

I like to keep all Classes and functions that will be used in this notebook under this section to help maintain a clean look of the notebook:

movie_reviews_to_file(path_data: str, path_texts_file: str)

As I mentioned before, we will need .txt files to run this notebook. Since the Large Movie Review Dataset comes in multiple files with different labels I created this function to put together all data in a single .txt file. Examples are saved on each line of the file. The path_data points to the path where data files are present and path_texts_file will be the .txt file containing all data.


ModelDataArguments

This class follows a format similar to the transformers (huggingface/transformers) library. The main difference is the way I combined multiple types of arguments into one and used rules to make sure the arguments are set correctly. Here are all the arguments in detail (they are also mentioned in the class documentation):

  • train_data_file: Path to your .txt file dataset. If you have an example on each line of the file make sure to use line_by_line=True. If the data file contains all text data without any special grouping use line_by_line=False to move a block_size window across the text file.

  • eval_data_file: Path to evaluation .txt file. It has the same format as train_data_file.

  • line_by_line: If the train_data_file and eval_data_file contain separate examples on each line, set line_by_line=True. If there is no separation between examples and train_data_file and eval_data_file contain continuous text, set line_by_line=False and a window of block_size will be moved across the files to acquire examples.

  • mlm: Is a flag that changes loss function depending on model architecture. This variable needs to be set to True when working with masked language models like bert or roberta and set to False otherwise. There are functions that will raise ValueError if this argument is not set accordingly.

  • whole_word_mask: Used as flag to determine if we decide to use whole word masking or not. Whole word masking means that whole words will be masked during training instead of tokens which can be chunks of words.

  • mlm_probability: Used when training masked language models. Needs to have mlm=True. It represents the probability of masking tokens when training model.

  • plm_probability: Flag to define the ratio of length of a span of masked tokens to surrounding context length for permutation language modeling. Used for XLNet.

  • max_span_length: Flag may also be used to limit the length of a span of masked tokens used for permutation language modeling. Used for XLNet.

  • block_size: It refers to the window size that is moved across the text file. Set to -1 to use maximum allowed length.

  • overwrite_cache: If there are any cached files, overwrite them.

  • model_type: Type of model used: bert, roberta, gpt2. More details here.

  • model_config_name: Config of model used: bert, roberta, gpt2. More details here.

  • tokenizer_name: Tokenizer used to process data for training the model. It usually has same name as model_name_or_path: bert-base-cased, roberta-base, gpt2 etc.

  • model_name_or_path: Path to existing transformers model or name of transformer model to be used: bert-base-cased, roberta-base, gpt2 etc. More details here.

  • model_cache_dir: Path to cache files. It helps to save time when re-running code.
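To illustrate how line_by_line and block_size interact, here is a minimal sketch in plain Python. It is not the library code: the function names are my own, the `tokenize` argument stands in for a real tokenizer (the usage further down assumes a simple whitespace split), and special tokens are ignored. It assumes the continuous-text case uses non-overlapping windows and drops a trailing chunk shorter than block_size.

```python
def line_by_line_examples(lines, tokenize, block_size):
    """One example per non-empty line, truncated to block_size tokens."""
    return [tokenize(line)[:block_size] for line in lines if line.strip()]


def block_examples(text, tokenize, block_size):
    """Cut continuous text into non-overlapping block_size windows."""
    token_ids = tokenize(text)
    # A trailing chunk shorter than block_size is dropped.
    return [token_ids[start:start + block_size]
            for start in range(0, len(token_ids) - block_size + 1, block_size)]
```

For example, with `str.split` as the toy tokenizer and block_size=3, the lines `["a b c d e", "", "f g"]` give two examples (`a b c` and `f g`), while the continuous text `"a b c d e f g"` gives the windows `a b c` and `d e f`.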


get_model_config(args: ModelDataArguments)

Get model configuration. Using the ModelDataArguments, return the model configuration. Here are all the arguments in detail:

  • args: Model and data configuration arguments needed to perform pretraining.

  • Returns: Model transformers configuration.

  • Raises: ValueError: If model_type is one of ["bert", "roberta", "distilbert", "camembert"] and mlm=False. These are masked language models, so they need mlm=True.
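The rule enforced in the body of get_model_config can be looked at in isolation. This is a minimal sketch that mirrors the check in the function; `check_mlm_flag` and `MASKED_LM_TYPES` are hypothetical names used only for illustration.

```python
MASKED_LM_TYPES = ("bert", "roberta", "distilbert", "camembert")


def check_mlm_flag(model_type, mlm):
    """Raise if a masked language model type is paired with mlm=False."""
    if model_type in MASKED_LM_TYPES and not mlm:
        raise ValueError(
            '`%s` is a masked language model; set `mlm=True`' % model_type)
    return True
```

Calling `check_mlm_flag("roberta", False)` raises a ValueError, while `check_mlm_flag("gpt2", False)` passes.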


get_tokenizer(args: ModelDataArguments)

Get model tokenizer. Using the ModelDataArguments, return the model tokenizer and change block_size from args if needed. Here are all the arguments in detail:

  • args: Model and data configuration arguments needed to perform pretraining.

  • Returns: Model transformers tokenizer.
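The way block_size is adjusted inside get_tokenizer can be sketched on its own. `resolve_block_size` is a hypothetical helper name, and `model_max_length` stands for the maximum length reported by the tokenizer.

```python
def resolve_block_size(block_size, model_max_length):
    """Return the block size that will actually be used."""
    if block_size <= 0:
        # No explicit size requested: use the tokenizer maximum.
        return model_max_length
    # Never go beyond the tokenizer maximum length.
    return min(block_size, model_max_length)
```

With a tokenizer maximum of 512, a requested block_size of -1 resolves to 512, 128 stays 128, and 1024 is capped at 512.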


get_model(args: ModelDataArguments, model_config)

Get model. Using the ModelDataArguments, return the actual model. Here are all the arguments in detail:

  • args: Model and data configuration arguments needed to perform pretraining.

  • model_config: Model transformers configuration.

  • Returns: PyTorch model.


get_dataset(args: ModelDataArguments, tokenizer: PreTrainedTokenizer, evaluate: bool=False)

Process dataset file into PyTorch Dataset. Using the ModelDataArguments, return the dataset. Here are all the arguments in detail:

  • args: Model and data configuration arguments needed to perform pretraining.

  • tokenizer: Model transformers tokenizer.

  • evaluate: If set to True the test / validation file is being handled. If set to False the train file is being handled.

  • Returns: PyTorch Dataset that contains file's data.


get_collator(args: ModelDataArguments, model_config: PretrainedConfig, tokenizer: PreTrainedTokenizer)

Get appropriate collator function. Collator function will be used to collate a PyTorch Dataset object. Here are all the arguments in detail:

  • args: Model and data configuration arguments needed to perform pretraining.

  • model_config: Model transformers configuration.

  • tokenizer: Model transformers tokenizer.

  • Returns: Transformers specific data collator.
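One of the collators returned here performs whole word masking. To show how that differs from masking individual tokens, here is a minimal sketch in plain Python, not the library's collator. It assumes WordPiece-style tokens where continuation pieces start with `##`, as BERT tokenizers produce, and it masks each word independently, which is a simplification. The function names are my own.

```python
import random


def group_word_pieces(tokens):
    """Group token indices so that '##' pieces stay with their word."""
    groups = []
    for index, token in enumerate(tokens):
        if token.startswith('##') and groups:
            groups[-1].append(index)
        else:
            groups.append([index])
    return groups


def whole_word_mask(tokens, mlm_probability=0.15, mask_token='[MASK]',
                    rng=None):
    """Mask all pieces of a word together, or none of them."""
    rng = rng or random.Random(0)
    masked = list(tokens)
    for group in group_word_pieces(tokens):
        if rng.random() < mlm_probability:
            for index in group:
                masked[index] = mask_token
    return masked
```

For the tokens `['the', 'un', '##believ', '##able', 'movie']`, the three pieces of "unbelievable" form one group, so they are either all replaced by the mask token or all left intact.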

def movie_reviews_to_file(path_data: str, path_texts_file: str):
  r"""Reading in all data from path and saving it into a single `.txt` file.
  
  In the pretraining process of our transformers model we require a text file.

  This function is designed to work for the Movie Reviews Dataset. 
  You will have to create your own function to move all examples into a text 
  file if you don't already have a text file with all your unlabeled data.

  Arguments:

      path_data (:obj:`str`):
        Path to the Movie Review Dataset partition. We only have `train` and 
        `test` partitions.

      path_texts_file (:obj:`str`):
        File path of the generated `.txt` file that contains one example / line.

  """

  # Check if path exists.
  if not os.path.isdir(path_data):
    # Raise error if path is invalid.
    raise ValueError('Invalid `path` variable! Needs to be a directory')
  # Collect the text of every review.
  texts = []
  print('Reading `%s` partition...' % (os.path.basename(path_data)))
  # Since the labels are defined by folders with data we loop 
  # through each label.
  for label in ['neg', 'pos']:
    sentiment_path = os.path.join(path_data, label)

    # Get all files from path.
    files_names = os.listdir(sentiment_path)#[:30] # SAMPLE FOR DEBUGGING.
    # Go through each file and read its content.
    for file_name in tqdm(files_names, desc=label, unit='files'):
      file_path = os.path.join(sentiment_path, file_name)

      # Read content. The `with` block closes the file afterwards.
      with io.open(file_path, mode='r', encoding='utf-8') as review_file:
        content = review_file.read()
      # Fix any unicode issues.
      content = fix_text(content)
      # Save content.
      texts.append(content)
  # Move list to single string.
  all_texts = '\n'.join(texts)
  # Send all texts string to single file.
  with io.open(file=path_texts_file, mode='w', encoding='utf-8') as texts_file:
    texts_file.write(all_texts)
  # Print when done.
  print('`.txt` file saved in `%s`\n' % path_texts_file)

  return


class ModelDataArguments(object):
  r"""Define model and data configuration needed to perform pretraining.

  Even though all arguments are optional there still needs to be a certain 
  number of arguments that require values attributed.
  
  Arguments:

    train_data_file (:obj:`str`, `optional`): 
      Path to your .txt file dataset. If you have an example on each line of 
      the file make sure to use line_by_line=True. If the data file contains 
      all text data without any special grouping use line_by_line=False to move 
      a block_size window across the text file.
      This argument is optional and it will have a `None` value attributed 
      inside the function.

    eval_data_file (:obj:`str`, `optional`): 
      Path to evaluation .txt file. It has the same format as train_data_file.
      This argument is optional and it will have a `None` value attributed 
      inside the function.

    line_by_line (:obj:`bool`, `optional`, defaults to :obj:`False`): 
      If the train_data_file and eval_data_file contains separate examples on 
      each line then line_by_line=True. If there is no separation between 
      examples and train_data_file and eval_data_file contains continuous text 
      then line_by_line=False and a window of block_size will be moved across 
      the files to acquire examples.
      This argument is optional and it has a default value.

    mlm (:obj:`bool`, `optional`, defaults to :obj:`False`): 
      Is a flag that changes loss function depending on model architecture. 
      This variable needs to be set to True when working with masked language 
      models like bert or roberta and set to False otherwise. There are 
      functions that will raise ValueError if this argument is 
      not set accordingly.
      This argument is optional and it has a default value.

    whole_word_mask (:obj:`bool`, `optional`, defaults to :obj:`False`):
      Used as flag to determine if we decide to use whole word masking or not. 
      Whole word masking means that whole words will be masked during training 
      instead of tokens which can be chunks of words.
      This argument is optional and it has a default value.

    mlm_probability(:obj:`float`, `optional`, defaults to :obj:`0.15`): 
      Used when training masked language models. Needs to have mlm set to True. 
      It represents the probability of masking tokens when training model.
      This argument is optional and it has a default value.

    plm_probability (:obj:`float`, `optional`, defaults to :obj:`float(1/6)`): 
      Flag to define the ratio of length of a span of masked tokens to 
      surrounding context length for permutation language modeling. 
      Used for XLNet.
      This argument is optional and it has a default value.

    max_span_length (:obj:`int`, `optional`, defaults to :obj:`5`): 
      Flag may also be used to limit the length of a span of masked tokens used 
      for permutation language modeling. Used for XLNet.
      This argument is optional and it has a default value.

    block_size (:obj:`int`, `optional`, defaults to :obj:`-1`): 
      It refers to the windows size that is moved across the text file. 
      Set to -1 to use maximum allowed length.
      This argument is optional and it has a default value.

    overwrite_cache (:obj:`bool`, `optional`, defaults to :obj:`False`): 
      If there are any cached files, overwrite them.
      This argument is optional and it has a default value.

    model_type (:obj:`str`, `optional`): 
      Type of model used: bert, roberta, gpt2. 
      More details: https://huggingface.co/transformers/pretrained_models.html
      This argument is optional and it will have a `None` value attributed 
      inside the function.

    model_config_name (:obj:`str`, `optional`):
      Config of model used: bert, roberta, gpt2. 
      More details: https://huggingface.co/transformers/pretrained_models.html
      This argument is optional and it will have a `None` value attributed 
      inside the function.

    tokenizer_name: (:obj:`str`, `optional`)
      Tokenizer used to process data for training the model. 
      It usually has same name as model_name_or_path: bert-base-cased, 
      roberta-base, gpt2 etc.
      This argument is optional and it will have a `None` value attributed 
      inside the function.

    model_name_or_path (:obj:`str`, `optional`): 
      Path to existing transformers model or name of 
      transformer model to be used: bert-base-cased, roberta-base, gpt2 etc. 
      More details: https://huggingface.co/transformers/pretrained_models.html
      This argument is optional and it will have a `None` value attributed 
      inside the function.

    model_cache_dir (:obj:`str`, `optional`): 
      Path to cache files to save time when re-running code.
      This argument is optional and it will have a `None` value attributed 
      inside the function.

  Raises:

        ValueError: If `CONFIG_MAPPING` is not loaded in global variables.

        ValueError: If `model_type` is not present in `CONFIG_MAPPING.keys()`.

        ValueError: If `model_type`, `model_config_name` and 
          `model_name_or_path` variables are all `None`. At least one of them 
          needs to be set.

        warnings: If `model_config_name` and `model_name_or_path` are both 
          `None`, the model will be trained from scratch.

        ValueError: If `tokenizer_name` and `model_name_or_path` are both 
          `None`. We need at least one of them set to load tokenizer.
    
  """

  def __init__(self, train_data_file=None, eval_data_file=None, 
               line_by_line=False, mlm=False, mlm_probability=0.15, 
               whole_word_mask=False, plm_probability=float(1/6), 
               max_span_length=5, block_size=-1, overwrite_cache=False, 
               model_type=None, model_config_name=None, tokenizer_name=None, 
               model_name_or_path=None, model_cache_dir=None):
    
    # Make sure CONFIG_MAPPING is imported from transformers module.
    if 'CONFIG_MAPPING' not in globals():
      raise ValueError('Could not find `CONFIG_MAPPING` imported! Make sure' \
                       ' to import it from `transformers` module!')

    # Make sure model_type is valid.
    if (model_type is not None) and (model_type not in CONFIG_MAPPING.keys()):
      raise ValueError('Invalid `model_type`! Use one of the following: %s' % 
                       (str(list(CONFIG_MAPPING.keys()))))
      
    # Make sure that model_type, model_config_name and model_name_or_path 
    # variables are not all `None`.
    if not any([model_type, model_config_name, model_name_or_path]):
      raise ValueError('You can`t have all `model_type`, `model_config_name`,' \
                       ' `model_name_or_path` be `None`! You need to have' \
                       ' at least one of them set!')
    
    # Check if a new model will be loaded from scratch.
    if not any([model_config_name, model_name_or_path]):
      # Setup warning to show pretty. This is an overkill
      warnings.formatwarning = lambda message,category,*args,**kwargs: \
                               '%s: %s\n' % (category.__name__, message)
      # Display warning.
      warnings.warn('You are planning to train a model from scratch! 🙀')

    # Check if a new tokenizer wants to be loaded.
    # This feature is not supported!
    if not any([tokenizer_name, model_name_or_path]):
      # Can't train tokenizer from scratch here! Raise error.
      raise ValueError('You want to train tokenizer from scratch! ' \
                    'That is not possible yet! You can train your own ' \
                    'tokenizer separately and use path here to load it!')
      
    # Set all data related arguments.
    self.train_data_file = train_data_file
    self.eval_data_file = eval_data_file
    self.line_by_line = line_by_line
    self.mlm = mlm
    self.whole_word_mask = whole_word_mask
    self.mlm_probability = mlm_probability
    self.plm_probability = plm_probability
    self.max_span_length = max_span_length
    self.block_size = block_size
    self.overwrite_cache = overwrite_cache

    # Set all model and tokenizer arguments.
    self.model_type = model_type
    self.model_config_name = model_config_name
    self.tokenizer_name = tokenizer_name
    self.model_name_or_path = model_name_or_path
    self.model_cache_dir = model_cache_dir
    return


def get_model_config(args: ModelDataArguments):
  r"""
  Get model configuration.

  Using the ModelDataArguments return the model configuration.

  Arguments:

    args (:obj:`ModelDataArguments`):
      Model and data configuration arguments needed to perform pretraining.

  Returns:

    :obj:`PretrainedConfig`: Model transformers configuration.

  Raises:

    ValueError: If `model_type` is one of ["bert", "roberta", "distilbert", 
          "camembert"] and `mlm=False`. These are masked language models 
          and must be run with `mlm=True`.
  """

  # Check model configuration.
  if args.model_config_name is not None:
    # Use model configure name if defined.
    model_config = AutoConfig.from_pretrained(args.model_config_name, 
                                      cache_dir=args.model_cache_dir)

  elif args.model_name_or_path is not None:
    # Use model name or path if defined.
    model_config = AutoConfig.from_pretrained(args.model_name_or_path, 
                                      cache_dir=args.model_cache_dir)

  else:
    # Use config mapping if building model from scratch.
    model_config = CONFIG_MAPPING[args.model_type]()

  # Make sure `mlm` flag is set for Masked Language Models (MLM).
  if (model_config.model_type in ["bert", "roberta", "distilbert", 
                                  "camembert"]) and (args.mlm is False):
    raise ValueError('BERT and RoBERTa-like models do not have LM heads ' \
                    'but masked LM heads. They must be run setting `mlm=True`')
  
  # Adjust block size for xlnet.
  if model_config.model_type == "xlnet":
    # xlnet used 512 tokens when training.
    args.block_size = 512
    # setup memory length
    model_config.mem_len = 1024
  
  return model_config


def get_tokenizer(args: ModelDataArguments):
  r"""
  Get model tokenizer.

  Using the ModelDataArguments return the model tokenizer and change 
  `block_size` from `args` if needed.

  Arguments:

    args (:obj:`ModelDataArguments`):
      Model and data configuration arguments needed to perform pretraining.

  Returns:

    :obj:`PreTrainedTokenizer`: Model transformers tokenizer.

  """

  # Check tokenizer configuration.
  if args.tokenizer_name:
    # Use tokenizer name if defined.
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, 
                                              cache_dir=args.model_cache_dir)

  elif args.model_name_or_path:
    # Use model name or path if defined.
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, 
                                              cache_dir=args.model_cache_dir)
    
  # Set up data block size.
  if args.block_size <= 0:
    # Set block size to maximum length of tokenizer.
    # Input block size will be the max possible for the model.
    # Some max lengths are very large and will cause a
    args.block_size = tokenizer.model_max_length
  else:
    # Never go beyond tokenizer maximum length.
    args.block_size = min(args.block_size, tokenizer.model_max_length)

  return tokenizer
  

def get_model(args: ModelDataArguments, model_config):
  r"""
  Get model.

  Using the ModelDataArguments return the actual model.

  Arguments:

    args (:obj:`ModelDataArguments`):
      Model and data configuration arguments needed to perform pretraining.

    model_config (:obj:`PretrainedConfig`):
      Model transformers configuration.

  Returns:

    :obj:`torch.nn.Module`: PyTorch model.

  """

  # Make sure MODEL_FOR_MASKED_LM_MAPPING and MODEL_FOR_CAUSAL_LM_MAPPING are 
  # imported from transformers module.
  if ('MODEL_FOR_MASKED_LM_MAPPING' not in globals()) or \
                ('MODEL_FOR_CAUSAL_LM_MAPPING' not in globals()):
    raise ValueError('Could not find `MODEL_FOR_MASKED_LM_MAPPING` and' \
                     ' `MODEL_FOR_CAUSAL_LM_MAPPING` imported! Make sure to' \
                     ' import them from `transformers` module!')
    
  # Check if using pre-trained model or train from scratch.
  if args.model_name_or_path:
    # Use pre-trained model.
    if type(model_config) in MODEL_FOR_MASKED_LM_MAPPING.keys():
      # Masked language modeling head.
      return AutoModelForMaskedLM.from_pretrained(
                        args.model_name_or_path,
                        from_tf=bool(".ckpt" in args.model_name_or_path),
                        config=model_config,
                        cache_dir=args.model_cache_dir,
                        )
    elif type(model_config) in MODEL_FOR_CAUSAL_LM_MAPPING.keys():
      # Causal language modeling head.
      return AutoModelForCausalLM.from_pretrained(
                                          args.model_name_or_path, 
                                          from_tf=bool(".ckpt" in 
                                                        args.model_name_or_path),
                                          config=model_config, 
                                          cache_dir=args.model_cache_dir)
    else:
      raise ValueError(
          'Invalid `model_name_or_path`! It should be in %s or %s!' % 
          (str(MODEL_FOR_MASKED_LM_MAPPING.keys()), 
           str(MODEL_FOR_CAUSAL_LM_MAPPING.keys())))
    
  else:
    # Use model from configuration - train from scratch.
    print("Training new model from scratch!")
    return AutoModelWithLMHead.from_config(model_config)


def get_dataset(args: ModelDataArguments, tokenizer: PreTrainedTokenizer, 
                evaluate: bool=False):
  r"""
  Process dataset file into PyTorch Dataset.

  Using the ModelDataArguments return the dataset.

  Arguments:

    args (:obj:`ModelDataArguments`):
      Model and data configuration arguments needed to perform pretraining.

    tokenizer (:obj:`PreTrainedTokenizer`):
      Model transformers tokenizer.

    evaluate (:obj:`bool`, `optional`, defaults to :obj:`False`):
      If set to `True` the test / validation file is being handled.
      If set to `False` the train file is being handled.

  Returns:

    :obj:`Dataset`: PyTorch Dataset that contains file's data.

  """

  # Get file path for either train or evaluate.
  file_path = args.eval_data_file if evaluate else args.train_data_file

  # Check if `line_by_line` flag is set to `True`.
  if args.line_by_line:
    # Each example in data file is on each line.
    return LineByLineTextDataset(tokenizer=tokenizer, file_path=file_path, 
                                 block_size=args.block_size)
    
  else:
    # All data in file is put together without any separation.
    return TextDataset(tokenizer=tokenizer, file_path=file_path, 
                       block_size=args.block_size, 
                       overwrite_cache=args.overwrite_cache)


def get_collator(args: ModelDataArguments, model_config: PretrainedConfig, 
                 tokenizer: PreTrainedTokenizer):
  r"""
  Get appropriate collator function.

  Collator function will be used to collate a PyTorch Dataset object.

  Arguments:

    args (:obj:`ModelDataArguments`):
      Model and data configuration arguments needed to perform pretraining.

    model_config (:obj:`PretrainedConfig`):
      Model transformers configuration.

    tokenizer (:obj:`PreTrainedTokenizer`):
      Model transformers tokenizer.

  Returns:

    :obj:`data_collator`: Transformers specific data collator.

  """

  # Special dataset handle depending on model type.
  if model_config.model_type == "xlnet":
    # Configure collator for XLNET.
    return DataCollatorForPermutationLanguageModeling(
                                          tokenizer=tokenizer,
                                          plm_probability=args.plm_probability,
                                          max_span_length=args.max_span_length,
                                          )
  else:
    # Configure data for rest of model types.
    if args.mlm and args.whole_word_mask:
      # Use whole word masking.
      return DataCollatorForWholeWordMask(
                                          tokenizer=tokenizer, 
                                          mlm_probability=args.mlm_probability,
                                          )
    else:
      # Regular language modeling.
      return DataCollatorForLanguageModeling(
                                          tokenizer=tokenizer, 
                                          mlm=args.mlm, 
                                          mlm_probability=args.mlm_probability,
                                          )

Parameters Setup

Declare the rest of the parameters used for this notebook:

  • model_data_args contains all arguments needed to set up the dataset, model configuration, model tokenizer and the actual model. This is created using the ModelDataArguments class.

  • training_args contains all arguments needed to use the Trainer functionality from Transformers, which lets us train transformers models in PyTorch very easily. You can find the complete documentation here. There are a lot of parameters that can be set to enable multiple functionalities. I only used the following parameters (the comments are inspired by the Hugging Face documentation of TrainingArguments):

  • output_dir: The output directory where the model predictions and checkpoints will be written. I set it to pretrain_bert, where the model and checkpoints will be saved.

  • overwrite_output_dir: Overwrite the content of the output directory. I set it to True because when I run the notebook multiple times I only care about the last run.

  • do_train: Whether to run training or not. I set this parameter to True because I want to train the model on my custom dataset.

  • do_eval: Whether to run evaluation on the evaluation files or not. I set it to True since I have a test data file and I want to evaluate how well the model trains.

  • per_device_train_batch_size: Batch size per GPU/TPU core/CPU for training. I set it to 10 for this example. I recommend setting it as high as your GPU memory allows.

  • per_device_eval_batch_size: Batch size per GPU/TPU core/CPU for evaluation. I set this value to 100 since evaluation does not compute gradients and needs less memory.

  • evaluation_strategy: Evaluation strategy to adopt during training: `no`: No evaluation during training; `steps`: Evaluate every eval_steps; `epoch`: Evaluate at the end of every epoch. I set it to 'steps' since I want to evaluate the model more often.

  • logging_steps: How often to show logs. I will use this to plot the loss history and calculate perplexity. I set this to 700 just as an example. If your evaluation data is large you might not want to run it that often because it will significantly slow down training.

  • eval_steps: Number of update steps between two evaluations if evaluation_strategy="steps". Will default to the same value as logging_steps if not set. Since I want to evaluate the model every logging_steps, I set this to None so it inherits the value of logging_steps.

  • prediction_loss_only: Set prediction loss to True in order to return loss for perplexity calculation. I set this to True since I want to monitor loss and perplexity (which is exp(loss)).

  • learning_rate: The initial learning rate for Adam. Defaults to 5e-5.

  • weight_decay: The weight decay to apply (if not zero). Defaults to 0.

  • adam_epsilon: Epsilon for the Adam optimizer. Defaults to 1e-8.

  • max_grad_norm: Maximum gradient norm (for gradient clipping). Defaults to 1.0.

  • num_train_epochs: Total number of training epochs to perform (if not an integer, will perform the decimal part percents of the last epoch before stopping training). I set it to 2 at most. Since the custom dataset will be a lot smaller than the original dataset the model was trained on we don't want to overfit.

  • save_steps: Number of updates steps before two checkpoint saves. Defaults to 500.
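To get a feel for how logging_steps interacts with the dataset size and batch size, the arithmetic can be sketched in a few lines. This is a rough estimate under stated assumptions, not what Trainer computes internally: a single device, no gradient accumulation, and one training example per line of the train file. The helper name estimate_steps is hypothetical, and the numbers mirror this notebook's run (25,000 training reviews, batch size 10, 2 epochs, logging every 700 steps).

```python
import math


def estimate_steps(num_examples, batch_size, num_epochs, logging_steps):
    """Rough estimate of optimizer steps and log events.

    Hypothetical helper, not part of the transformers library.
    Assumes one device and no gradient accumulation.
    """
    steps_per_epoch = math.ceil(num_examples / batch_size)
    total_steps = steps_per_epoch * num_epochs
    # One log entry is produced every `logging_steps` optimizer steps.
    num_logs = total_steps // logging_steps
    return total_steps, num_logs


# 25,000 train reviews, batch size 10, 2 epochs, logging every 700 steps.
total_steps, num_logs = estimate_steps(25000, 10, 2, 700)
print('total steps:', total_steps, '| log entries:', num_logs)
```

A small logging_steps on a large dataset multiplies both the number of printed rows and the number of evaluation passes, which is why it is worth checking before starting a long run.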

# Define arguments for data, tokenizer and model arguments.
# See comments in `ModelDataArguments` class.
model_data_args = ModelDataArguments(
                                    train_data_file='/content/train.txt', 
                                    eval_data_file='/content/test.txt', 
                                    line_by_line=True, 
                                    mlm=True,
                                    whole_word_mask=True,
                                    mlm_probability=0.15,
                                    plm_probability=float(1/6), 
                                    max_span_length=5,
                                    block_size=50, 
                                    overwrite_cache=False, 
                                    model_type='bert', 
                                    model_config_name='bert-base-cased', 
                                    tokenizer_name='bert-base-cased', 
                                    model_name_or_path='bert-base-cased', 
                                    model_cache_dir=None,
                                    )

# Define arguments for training
# Note: I only used the arguments I care about. `TrainingArguments` contains
# a lot more arguments. For more details check the awesome documentation:
# https://huggingface.co/transformers/main_classes/trainer.html#trainingarguments
training_args = TrainingArguments(
                          # The output directory where the model predictions 
                          # and checkpoints will be written.
                          output_dir='pretrain_bert',

                          # Overwrite the content of the output directory.
                          overwrite_output_dir=True,

                          # Whether to run training or not.
                          do_train=True, 
                          
                          # Whether to run evaluation on the dev set or not.
                          do_eval=True,
                          
                          # Batch size GPU/TPU core/CPU training.
                          per_device_train_batch_size=10,
                          
                          # Batch size  GPU/TPU core/CPU for evaluation.
                          per_device_eval_batch_size=100,

                          # evaluation strategy to adopt during training
                          # `no`: No evaluation during training.
                          # `steps`: Evaluate every `eval_steps`.
                          # `epoch`: Evaluate every end of epoch.
                          evaluation_strategy='steps',

                          # How often to show logs. I will use this to 
                          # plot history loss and calculate perplexity.
                          logging_steps=700,

                          # Number of update steps between two 
                          # evaluations if evaluation_strategy="steps".
                          # Will default to the same value as 
                          # logging_steps if not set.
                          eval_steps = None,
                          
                          # Set prediction loss to `True` in order to 
                          # return loss for perplexity calculation.
                          prediction_loss_only=True,

                          # The initial learning rate for Adam. 
                          # Defaults to 5e-5.
                          learning_rate = 5e-5,

                          # The weight decay to apply (if not zero).
                          weight_decay=0,

                          # Epsilon for the Adam optimizer. 
                          # Defaults to 1e-8
                          adam_epsilon = 1e-8,

                          # Maximum gradient norm (for gradient 
                          # clipping). Defaults to 1.0.
                          max_grad_norm = 1.0,

                          # Total number of training epochs to perform 
                          # (if not an integer, will perform the 
                          # decimal part percents of
                          # the last epoch before stopping training).
                          num_train_epochs = 2,

                          # Number of updates steps before two checkpoint saves. 
                          # Defaults to 500
                          save_steps = -1,
                          )

Load Configuration, Tokenizer and Model

Loading the three essential parts of the pretrained transformers: configuration, tokenizer and model.

Since I use the AutoClass functionality from Hugging Face I only need to worry about the model's name as input and the rest is handled by the transformers library.

I will call each of the three functions created in the Helper Functions tab, which return the model's config, the model's tokenizer and the actual PyTorch model.

After the model is loaded, it is always good practice to resize the model depending on the tokenizer size. This means that the tokenizer's vocabulary will be aligned with the model's embedding layer. This is very useful when we have a different tokenizer than the pretrained one or when we train a transformer model from scratch.
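As a conceptual illustration of what "aligned" means here, the toy sketch below treats the embedding layer as a list of rows, one row per token id. It is not how the transformers library implements resizing (real models hold a PyTorch tensor and may initialize new rows differently); the helper name resize_embedding_rows and the initialization scale are assumptions made for the example.

```python
import random


def resize_embedding_rows(rows, new_vocab_size, seed=0):
    """Return a copy of `rows` with exactly `new_vocab_size` rows.

    Existing rows are kept in order; missing rows are filled with small
    random values. Hypothetical helper for illustration only.
    """
    rng = random.Random(seed)
    hidden_size = len(rows[0])
    # Keep (or truncate to) the rows that already exist.
    resized = [list(row) for row in rows[:new_vocab_size]]
    # Add freshly initialized rows for token ids the model has never seen.
    while len(resized) < new_vocab_size:
        resized.append([rng.gauss(0.0, 0.02) for _ in range(hidden_size)])
    return resized


# Toy embedding matrix: vocabulary of 3 tokens, hidden size 2.
old_rows = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
# Suppose the tokenizer gained two tokens, so the matrix needs 5 rows.
new_rows = resize_embedding_rows(old_rows, 5)
print(len(new_rows), new_rows[:3] == old_rows)
```

When the tokenizer and the model already agree on the vocabulary size, the resize changes nothing, so calling it is a cheap safety check.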

# Load model configuration.
print('Loading model configuration...')
config = get_model_config(model_data_args)

# Load model tokenizer.
print('Loading model`s tokenizer...')
tokenizer = get_tokenizer(model_data_args)

# Loading model.
print('Loading actual model...')
model = get_model(model_data_args, config)

# Resize model to fit all tokens in tokenizer.
model.resize_token_embeddings(len(tokenizer))
Loading model configuration...
Downloading: 100%|████████████████████████████████|433/433 [00:01<00:00, 285B/s]

Loading model`s tokenizer...
Downloading: 100%|████████████████████████████████|433/433 [00:01<00:00, 285B/s]

Loading actual model...
Downloading: 100% |████████████████████████████████|436M/436M [00:36<00:00, 11.9MB/s]

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Embedding(28996, 768, padding_idx=0)


Dataset and Collator

This is where I create the PyTorch Dataset and data collator objects that will be used to feed data into our model.

This is where I use the MovieReviewsDataset text files created with the movie_reviews_to_file function. Since data is partitioned for both train and test I will create two text files: one used for train and one used for evaluation.

I strongly recommend using a validation text file to determine how much training is needed and to avoid overfitting. After you figure out which parameters yield the best results, the validation file can be merged into the train file for a final training run on the whole dataset.

The data collator is used to format the PyTorch Dataset outputs to match the inputs expected by our specific transformers model: e.g. for BERT it will create the masked tokens needed to train.
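To make the masking step more concrete, here is a simplified sketch of the idea behind a masked language modeling collator. It is not the library's implementation: it works on a single list of token ids, ignores special tokens and padding, and the helper name mask_tokens is hypothetical. It assumes the usual BERT-style rule where each selected token is replaced by the mask token 80% of the time, by a random token 10% of the time, and left unchanged 10% of the time. The mask_id value below is a placeholder; in practice it comes from the tokenizer.

```python
import random

IGNORE_INDEX = -100  # label value that PyTorch's cross-entropy loss skips


def mask_tokens(token_ids, mask_id, vocab_size, mlm_probability=0.15, seed=0):
    """Toy masked-language-modeling corruption of one sequence.

    Returns (inputs, labels). Hypothetical helper; it ignores special
    tokens and padding, which a real collator must leave untouched.
    """
    rng = random.Random(seed)
    inputs, labels = list(token_ids), []
    for position, token_id in enumerate(token_ids):
        if rng.random() >= mlm_probability:
            # Not selected: no loss is computed at this position.
            labels.append(IGNORE_INDEX)
            continue
        # Selected: the model must predict the original token here.
        labels.append(token_id)
        roll = rng.random()
        if roll < 0.8:
            inputs[position] = mask_id  # 80%: replace with the mask token
        elif roll < 0.9:
            inputs[position] = rng.randrange(vocab_size)  # 10%: random token
        # Remaining 10%: keep the original token unchanged.
    return inputs, labels


# Placeholder ids: 20 made-up token ids, an assumed mask id, BERT-sized vocab.
inputs, labels = mask_tokens(list(range(1000, 1020)), mask_id=103,
                             vocab_size=28996)
print(inputs)
print(labels)
```

Positions labeled -100 contribute nothing to the loss, so the model is only graded on the tokens that were selected for masking.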

# Create texts file from train data.
movie_reviews_to_file(path_data='/content/aclImdb/train', path_texts_file='/content/train.txt')
# Create texts file from test data.
movie_reviews_to_file(path_data='/content/aclImdb/test', path_texts_file='/content/test.txt')


# Setup train dataset if `do_train` is set.
print('Creating train dataset...')
train_dataset = get_dataset(model_data_args, tokenizer=tokenizer, evaluate=False) if training_args.do_train else None

# Setup evaluation dataset if `do_eval` is set.
print('Creating evaluate dataset...')
eval_dataset = get_dataset(model_data_args, tokenizer=tokenizer, evaluate=True) if training_args.do_eval else None

# Get data collator to modify data format depending on type of model used.
data_collator = get_collator(model_data_args, config, tokenizer)

# Check how many logging prints you'll have. This is to avoid overflowing the 
# notebook with a lot of prints. Display warning to user if the logging steps 
# that will be displayed is larger than 100.
if (len(train_dataset) // training_args.per_device_train_batch_size \
    // training_args.logging_steps * training_args.num_train_epochs) > 100:
  # Display warning.
  warnings.warn('Your `logging_steps` value will do a lot of printing!' \
                ' Consider increasing `logging_steps` to avoid overflowing' \
                ' the notebook with a lot of prints!')
Reading `train` partition...
neg: 100%|████████████████████████████████|12500/12500 [00:55<00:00, 224.11files/s]
pos: 100%|████████████████████████████████|12500/12500 [00:55<00:00, 224.11files/s]
`.txt` file saved in `/content/train.txt`

Reading `test` partition...
neg: 100%|████████████████████████████████|12500/12500 [00:55<00:00, 224.11files/s]
pos: 100%|████████████████████████████████|12500/12500 [00:55<00:00, 224.11files/s]
`.txt` file saved in `/content/train.txt`

Creating train dataset...
Creating evaluate dataset...


Train

Hugging Face was very nice to us for creating the Trainer class. This helps make PyTorch model training of transformers very easy! We just need to make sure we loaded the proper parameters and everything else is taken care of!

At the end of the training the tokenizer is saved along with the model so you can easily re-use it later or even upload it to Hugging Face Models.

I configured the arguments to display both train and validation loss at every logging_steps. It gives us a sense of how well the model is trained.

# Initialize Trainer.
print('Loading `trainer`...')
trainer = Trainer(model=model,
                  args=training_args,
                  data_collator=data_collator,
                  train_dataset=train_dataset,
                  eval_dataset=eval_dataset,
                  )


# Check model path to save.
if training_args.do_train:
  print('Start training...')

  # Setup model path if the model to train loaded from a local path.
  model_path = (model_data_args.model_name_or_path 
                if model_data_args.model_name_or_path is not None and 
                os.path.isdir(model_data_args.model_name_or_path) 
                else None
                )
  # Run training.
  trainer.train(model_path=model_path)
  # Save model.
  trainer.save_model()

  # For convenience, we also re-save the tokenizer to the same directory,
  # so that you can share your model easily on huggingface.co/models =).
  if trainer.is_world_process_zero():
    tokenizer.save_pretrained(training_args.output_dir)
Loading `trainer`...
Start training...
|████████████████████████████████|[5000/5000 09:43, Epoch 2/2]

Step    Training Loss   Validation Loss
700     2.804672    2.600590
1400    2.666996    2.548267
2100    2.625075    2.502431
2800    2.545872    2.485056
3500    2.470102    2.444808
4200    2.464950    2.420487
4900    2.436973    2.410310


Plot Train

The Trainer class is so useful that it will record the log history for us. I use this to access the train and validation losses recorded at each logging_steps during training.

Since we are training / fine-tuning / extended training or pretraining (depending on what terminology you use) a language model, we want to compute the perplexity.

This is what Wikipedia says about perplexity: In information theory, perplexity is a measurement of how well a probability distribution or probability model predicts a sample. It may be used to compare probability models. A low perplexity indicates the probability distribution is good at predicting the sample.

We can look at the perplexity plot in the same way we look at the loss plot: the lower the better and if the validation perplexity starts to increase we are starting to overfit the model.
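As a small worked example, the snippet below turns the validation losses from the training log above into perplexities and looks for the first point where the value goes up. It is only a sketch: the helper names perplexity and first_increase are hypothetical, and a single increase between two evaluations is a hint of overfitting, not proof of it.

```python
import math

# Validation losses per step, copied from the training log of this run.
eval_losses = {700: 2.600590, 1400: 2.548267, 2100: 2.502431, 2800: 2.485056,
               3500: 2.444808, 4200: 2.420487, 4900: 2.410310}


def perplexity(loss):
    """Perplexity is the exponential of the average cross-entropy loss."""
    return math.exp(loss)


def first_increase(values):
    """Return the index of the first value larger than its predecessor.

    Returns None if the sequence never goes up (hypothetical helper).
    """
    for index in range(1, len(values)):
        if values[index] > values[index - 1]:
            return index
    return None


steps = sorted(eval_losses)
eval_perplexities = [perplexity(eval_losses[step]) for step in steps]
for step, value in zip(steps, eval_perplexities):
    print(f'step {step}: perplexity {value:.2f}')

turning_point = first_increase(eval_perplexities)
if turning_point is None:
    print('Validation perplexity never increased.')
else:
    print(f'First increase at step {steps[turning_point]}.')
```

Because the exponential is monotonic, the perplexity curve always moves in the same direction as the loss curve; it mostly rescales the values into a more interpretable unit.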

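Note: The plots show that the train loss is higher than the validation loss. This can mean that the validation data is easier for the model than the training data, in which case a different validation dataset would be a better check. It can also happen because the train loss is measured with dropout active and averaged over the preceding logging_steps, while the validation loss is measured with dropout disabled. Since the purpose of this notebook is to show how to train transformers models and provide tools to evaluate such a process, I will leave the results as is.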

# Keep track of train and evaluate loss.
loss_history = {'train_loss':[], 'eval_loss':[]}

# Keep track of train and evaluate perplexity.
# This is a metric useful to track for language models.
perplexity_history = {'train_perplexity':[], 'eval_perplexity':[]}

# Loop through each log history.
for log_history in trainer.state.log_history:

  if 'loss' in log_history.keys():
    # Deal with training loss.
    loss_history['train_loss'].append(log_history['loss'])
    perplexity_history['train_perplexity'].append(math.exp(log_history['loss']))
    
  elif 'eval_loss' in log_history.keys():
    # Deal with eval loss.
    loss_history['eval_loss'].append(log_history['eval_loss'])
    perplexity_history['eval_perplexity'].append(math.exp(log_history['eval_loss']))

# Plot Losses.
plot_dict(loss_history, start_step=training_args.logging_steps, 
          step_size=training_args.logging_steps, use_title='Loss', 
          use_xlabel='Train Steps', use_ylabel='Values', magnify=2)

print()

# Plot Perplexities.
plot_dict(perplexity_history, start_step=training_args.logging_steps, 
          step_size=training_args.logging_steps, use_title='Perplexity', 
          use_xlabel='Train Steps', use_ylabel='Values', magnify=2)

[Figure: train and validation loss]

[Figure: train and validation perplexity]

Evaluate

For the final evaluation we can have a separate test set that we use to do our final perplexity evaluation. For simplicity I used the same validation text file for the final evaluation. That is the reason I get the same results as the last validation perplexity plot value.

# check if `do_eval` flag is set.
if training_args.do_eval:
  
  # capture the output of `trainer.evaluate`.
  eval_output = trainer.evaluate()
  # compute perplexity from model loss.
  perplexity = math.exp(eval_output["eval_loss"])
  print('\nEvaluate Perplexity: {:10,.2f}'.format(perplexity))
else:
  print('No evaluation needed. No evaluation data provided, `do_eval=False`!')
|████████████████████████████████|[250/250 00:25]

Evaluate Perplexity:      11.01

Final Note

If you made it this far Congrats! 🎊 and Thank you! 🙏 for your interest in my tutorial!

I've been using this code for a while now and I feel it got to a point where it is nicely documented and easy to follow.

Of course it is easy for me to follow because I built it. That is why any feedback is welcome and it helps me improve my future tutorials!

If you see something wrong please let me know by opening an issue on my ml_things GitHub repository!

A lot of tutorials out there are mostly a one-time thing and are not being maintained. I plan on keeping my tutorials up to date as much as I can.

Contact 🎣

🦊 GitHub: gmihaila

🌐 Website: gmihaila.github.io

👔 LinkedIn: mihailageorge

📬 Email: georgemihaila@my.unt.edu.com