While I might be late to the dangerous language models party, I thought this blog was lacking a basic post about text generation. How are these models trained? How are they used? Are they really that good? And dangerous?
Scope
Language Models
Generating Text
Training a language model
Evaluating text generation
Are language models dangerous?
Scope
The reason everyone is talking about language models (LMs) lately is not so much that everyone is working on text generation, but that pre-trained LMs (like OpenAI's GPT-2 or Google's BERT) are used to produce text representations across various NLP applications, greatly improving their performance. The effect is similar to the effect that pre-trained word embeddings had on NLP in 2013. I recommend reading Sebastian Ruder's article NLP's ImageNet moment has arrived, which summarizes it very nicely. This blog post will focus on text generation.
There is an important distinction between two main types of applications of text generation:
1. Open-ended generation: the purpose is to generate any text. It could be on some specific topic or continuing a previous paragraph, but the model is given the artistic freedom to generate any text.
2. Non-open-ended (directed) generation: the output is constrained by an input text, as in machine translation or summarization, where the generated text must convey the content of the source.
This post focuses on open-ended text generation.
Language Models
I've discussed LMs in one of the earlier posts in this blog, in the context of machine translation. Simply put, a language model is a probability distribution over the next word in the text, given the previous words. The distribution is over all the words in the vocabulary, which is typically very large (possibly a few hundred thousand words or more).
For example, what could be the next word in the sentence "I'm tired, I want to"? A good language model would assign a high score to p(sleep|I'm tired, I want to). The probability of a word like "bed" should be low: although it is a related term, it doesn't form a grammatical sentence. The same goes for "party", which is syntactically correct but contradicts the meaning of the prefix. The probability of an entire sentence is the product of the conditional probabilities of each word given the previous words, using the chain rule:
p(I'm tired, I want to sleep) = p(I'm|<s>) * p(tired|<s> I'm) * p(,|<s> I'm tired) * p(I|<s> I'm tired,) * p(want|<s> I'm tired, I) * p(to|<s> I'm tired, I want) * p(sleep|<s> I'm tired, I want to) * p(</s>|<s> I'm tired, I want to sleep)
where <s> and </s> mark the beginning and the end of the sentence, respectively. Note that I used a word-based LM for demonstration purposes in this post; however, it's possible to define the basic token as a character or a "word piece" / "subword unit".
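Written in general form, the chain rule decomposition used in the example above is:

p(w_1, ..., w_n) = \prod_{i=1}^{n} p(w_i \mid w_1, ..., w_{i-1})

where w_1 ... w_n are the tokens of the sentence (including the boundary markers), and each factor is exactly the next-word distribution that the LM estimates.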
Generating Text
While LMs can be used to score a given text on how likely it is in the language, in this post we will discuss another common use: generating new text. Assuming we've already trained a language model, how do we generate text? We will demonstrate with this very simple toy LM, which has a tiny vocabulary and very few probable utterances:
import numpy as np

# Create a very simple bigram LM with a small vocabulary.
# Each row is the distribution of the next word given the previous word.
vocab = ['<s>', 'This', 'LM', 'is', 'stupid', 'cool', '</s>']
word2index = {w: i for i, w in enumerate(vocab)}
probabilities = np.array([[0.01, 0.94, 0.01, 0.01, 0.01, 0.01, 0.01],
                          [0.01, 0.01, 0.94, 0.01, 0.01, 0.01, 0.01],
                          [0.01, 0.01, 0.01, 0.94, 0.01, 0.01, 0.01],
                          [0.01, 0.01, 0.01, 0.01, 0.47, 0.48, 0.01],
                          [0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.94],
                          [0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.94],
                          [0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.94]])

# The LM looks at the last word of the prefix and returns the distribution of the next word
# (unknown words fall back to the last row).
stupid_lm = lambda s: probabilities[word2index.get(s.split()[-1], -1), :]
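As a quick sanity check (this snippet is an illustration, not part of the original gist), we can query the toy LM for the distribution of the next word given a prefix:

next_word_distribution = stupid_lm('<s> This')           # distribution of the word following 'This'
print(next_word_distribution)                            # the second row of the matrix
print(vocab[int(np.argmax(next_word_distribution))])     # 'LM', the most probable next word (0.94)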
The simplest approach is to take the most probable word at each time step (greedy decoding):
def generate_most_probable(lm, index2word):
    """
    Generates a string, taking the most probable word at each time step.
    :param lm - the language model: a function that gets a string and returns a distribution over the next word
    :param index2word - a mapping from the index of a word in the vocabulary to the word itself
    """
    generated_sentence = '<s>'
    curr_token = None

    while curr_token != '</s>':
        curr_distribution = lm(generated_sentence)  # vector of probabilities
        sorted_by_probability = np.argsort(curr_distribution)
        curr_token = index2word[int(sorted_by_probability[-1])]  # the last index is the most probable
        generated_sentence += ' ' + curr_token

    return generated_sentence

generated_str = generate_most_probable(stupid_lm, vocab)
print(generated_str)
An alternative is to sample from the distribution, i.e., randomly select a word from the vocabulary, proportional to its probability given the previous words, according to the language model. The code will look something like this:
def generate_sample(lm, index2word, max_tokens=25):
    """
    Generates a string, sampling a word from the distribution at each time step.
    :param lm - the language model
    :param index2word - a mapping from the index of a word in the vocabulary to the word itself
    """
    generated_sentence = '<s>'
    generated_tokens = 0
    curr_token = None

    while curr_token != '</s>' and generated_tokens < max_tokens:
        curr_distribution = lm(generated_sentence)  # vector of probabilities
        selected_index = np.random.choice(range(len(index2word)), p=curr_distribution)
        curr_token = index2word[int(selected_index)]
        generated_sentence += ' ' + curr_token
        generated_tokens += 1

    return generated_sentence

generated_str = generate_sample(stupid_lm, vocab)
print(generated_str)
Each of these approaches has a drawback: always taking the most probable word tends to produce dull and repetitive text, while sampling from the entire distribution occasionally picks a very unlikely word that derails the text. A simple solution is to combine the two approaches and sample only from the top k most probable words in the distribution, for some pre-defined k (as done in this work). This is what it would look like:
def generate_sample_top_k(lm, index2word, k=5, max_tokens=25):
    """
    Generates a string, sampling a word from the top k most probable words in the distribution at each time step.
    :param lm - the language model
    :param index2word - a mapping from the index of a word in the vocabulary to the word itself
    :param k - how many words to keep in the distribution
    """
    generated_sentence = '<s>'
    curr_token = None
    generated_tokens = 0

    while curr_token != '</s>' and generated_tokens < max_tokens:
        curr_distribution = lm(generated_sentence)  # vector of probabilities
        sorted_by_probability = np.argsort(curr_distribution)  # sort by probability
        top_k_indices = set(sorted_by_probability[-k:])  # keep the top k words
        top_k = np.array([curr_distribution[i] if i in top_k_indices else 0.0
                          for i in range(len(index2word))])

        # Re-normalize to make it a valid probability distribution again
        top_k = top_k / top_k.sum()

        selected_index = np.random.choice(range(len(index2word)), p=top_k)
        curr_token = index2word[int(selected_index)]
        generated_sentence += ' ' + curr_token
        generated_tokens += 1

    return generated_sentence

generated_str = generate_sample_top_k(stupid_lm, vocab)
print(generated_str)
Notice that after keeping only k words in the distribution, we need to make sure that they still form a valid probability distribution, i.e. each entry is between 0 and 1 and the entries sum to 1; here this is done by dividing each kept probability by their sum.
An alternative way to sample from the top of the distribution is top p: sort the tokens by their probability from highest to lowest, and take tokens until the sum of their probabilities (which is exactly the probability of generating any of these tokens) reaches some pre-defined value p between 0 and 1. A value of p close to 0 is similar to always taking the most probable token, while a value close to 1 is similar to sampling from the entire distribution. This method is more flexible than top k because the number of candidate tokens may change according to the generated prefix. For example, a general text like I want to may have many valid continuations (each with a relatively small probability), while a more specific text like The bride and the groom got will have far fewer, with the obvious next token married taking most of the probability mass.
Update 01/11/21: a top p snippet is now available - thanks to Saptarshi Sengupta for the contribution!
# top-p (nucleus) sampling
def generate_sample_top_p(lm, index2word, max_tokens=5, p=0.96):
    """
    Generates a string, sampling a word at each time step from the smallest set of words whose cumulative probability reaches p.
    :param lm - the language model
    :param index2word - a mapping from the index of a word in the vocabulary to the word itself
    :param p - the cumulative probability mass to keep
    """
    generated_sentence = '<s>'
    curr_token = None
    generated_tokens = 0

    while curr_token != '</s>' and generated_tokens < max_tokens:
        curr_distribution = lm(generated_sentence)  # vector of probabilities
        sorted_by_probability = np.argsort(curr_distribution)  # sorted from least to most probable

        # Collect tokens from the most probable down, until the cumulative probability reaches p.
        # At least the single most probable token is always kept.
        total_prob = 0.0
        prob_dist = []
        prob_indices = []
        for curr_index in sorted_by_probability[::-1]:
            prob_dist.append(curr_distribution[curr_index])
            prob_indices.append(curr_index)
            total_prob += curr_distribution[curr_index]
            if total_prob >= p:
                break

        # Re-normalize the kept probabilities to make them a valid distribution again
        prob_dist = np.array(prob_dist) / np.sum(prob_dist)

        selected_index = np.random.choice(prob_indices, p=prob_dist)
        curr_token = index2word[int(selected_index)]
        generated_sentence += ' ' + curr_token
        generated_tokens += 1

    return generated_sentence

generated_str = generate_sample_top_p(stupid_lm, vocab)
print(generated_str)
Training a language model
I've already discussed N-gram language models, but by the time I wrote that post (4 years ago), they were already obsolete and replaced by neural language models. The basic algorithm for training a neural LM is as follows:
A large amount of text is dedicated to training (the training corpus).
The model goes over the corpus, sentence by sentence.
For a given sentence w1 ... wn, for each word wi:
- The preceding words w1 ... wi-1 are encoded into a single vector (see the note on encoders below).
- Based on this vector, the model predicts a probability distribution over the vocabulary for the next word.
- The model's parameters are updated to increase the probability assigned to the actual word wi (i.e., to minimize the cross-entropy loss).
The various neural LMs differ in their choice of basic token (e.g. word, character, word piece) and encoder. The encoder takes a sequence of word embeddings and returns a single vector representing the corresponding sequence of words (e.g. ... tired, I want to). I may have a separate post in the future that focuses on ways to encode text into a vector; for the purpose of this post, let's treat the encoder as a black-box function. The following figure illustrates the training (specifically for an encoder based on an RNN):
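To make this loop more concrete, here is a minimal sketch of a word-level RNN LM and its training loop. It assumes PyTorch and an LSTM encoder; the names (RnnLm, train_lm, corpus_batches) are made up for illustration, and a real implementation would add details such as padding, validation, and regularization.

import torch
import torch.nn as nn

class RnnLm(nn.Module):
    def __init__(self, vocab_size, embedding_dim=128, hidden_dim=256):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.encoder = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.output = nn.Linear(hidden_dim, vocab_size)

    def forward(self, token_ids):
        # token_ids: a (batch, seq_len) tensor of indices into the vocabulary
        embedded = self.embedding(token_ids)
        hidden_states, _ = self.encoder(embedded)  # one vector per prefix of the sentence
        return self.output(hidden_states)          # logits over the vocabulary for each position

def train_lm(model, corpus_batches, epochs=1, lr=1e-3):
    """corpus_batches: an iterable of (batch, seq_len) LongTensors of token indices."""
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    for _ in range(epochs):
        for batch in corpus_batches:
            inputs, targets = batch[:, :-1], batch[:, 1:]  # each position predicts the next token
            logits = model(inputs)
            loss = loss_fn(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()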
The two main advantages of neural LMs over N-gram LMs are:
(1) N-gram LMs predict the next word based on a history of only N-1 words, e.g. given I'm tired, I want to, a 4-gram LM will predict the next word based only on the last 3 words "I want to", completely ignoring the crucial word "tired". N-gram LMs were usually based on small values of N (2-4) (see the post about N-gram language models for an explanation).
(2) N-gram LMs are based on counting how many times each sequence of words appeared in the training data, and these counts require exact matches, i.e. the occurrences of I'm tired are disjoint from those of I'm exhausted. Neural LMs, on the other hand, learn to represent a fragment of text as a vector and to predict the next word based on it. They can generalize across semantically similar texts by assigning them similar vector representations (resulting in similar predictions).
Important note: some LMs today are trained with a different training objective, i.e. not by predicting the next word in the sentence. Specifically, BERT has a "masked LM" objective: random words in the sentence are hidden, and the model guesses them from their surrounding context - the tokens before and after the hidden words. Text GANs (Generative Adversarial Networks) consist of two competing components: a generator that generates human-like text and a discriminator trained to distinguish between human-written and generator-generated texts. In practice, current GAN-based text generation doesn't perform as well as generation from language models (see here and here).
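For intuition only, here is a rough sketch of how training examples for a masked LM objective could be created. The 15% masking rate follows BERT's setup, but the helper below (mask_tokens) is a made-up illustration, not BERT's actual preprocessing code (which, for example, sometimes replaces a chosen token with a random word instead of the mask symbol).

import random

def mask_tokens(tokens, mask_token='[MASK]', mask_prob=0.15):
    """Randomly hide some tokens; the model is trained to guess the hidden ones from the rest."""
    masked, targets = [], []
    for token in tokens:
        if random.random() < mask_prob:
            masked.append(mask_token)
            targets.append(token)   # the model should predict this token
        else:
            masked.append(token)
            targets.append(None)    # no prediction needed for unmasked tokens
    return masked, targets

print(mask_tokens("I'm tired , I want to sleep".split()))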
Evaluating text generation
Comparing the performance of two classifiers that were trained to solve the same task is easy - we have a test set with the true label of each data point; we predict the test labels using each model, and compute the accuracy of each model compared to the true labels. We then have exactly two numeric values - the higher the accuracy, the better the model. This is not the case for text generation.
Since we are talking about open-ended generation, there is no gold standard text the model is expected to produce (we have a test set, but we really just want to make sure the generated text looks like it), so how can we judge the model's quality? The best we can do is to manually examine some of the model outputs and decide whether we think they are good (human-like?) texts or not. To do so more systematically, we can perform a proper human evaluation: show people texts generated by our model vs. texts generated by some baseline model (or written by humans...), ask them to rate which is better, and aggregate across multiple judgements on multiple texts. While this is probably the best evaluation method, it is costly and takes a long time to obtain. As a result, it is usually applied to a relatively small number of texts at the final stages of model development, and isn't used to validate texts in the intermediate steps (which could potentially help improve the model).
The alternative and commonly used metric is perplexity: by definition, it is the inverse probability of the test set, normalized by the number of words. So we want as low a perplexity score as possible, which means the probability of the test set is maximized - i.e., the LM learned a probability distribution which is similar to the "truth". The test set is just a collection of texts which the LM has not seen before, and its probability is computed by going over it word by word and computing the LM probability of predicting each word given its past. A good LM will assign a high probability to the "correct" (actual) next word and a low probability to other words.
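As an illustration (reusing the toy LM and word2index from the snippets above; this helper is not from the original post), perplexity can be computed as the exponent of the average negative log-probability that the model assigns to each actual next word:

def perplexity(lm, tokens):
    """Computes the perplexity of a language model on a tokenized text (ending with the </s> token)."""
    log_probs = []
    prefix = '<s>'
    for token in tokens:
        curr_distribution = lm(prefix)  # distribution over the next word given the prefix so far
        log_probs.append(np.log(curr_distribution[word2index[token]]))
        prefix += ' ' + token
    return np.exp(-np.mean(log_probs))

print(perplexity(stupid_lm, ['This', 'LM', 'is', 'cool', '</s>']))   # low perplexity: a likely sentence
print(perplexity(stupid_lm, ['This', 'is', 'LM', 'cool', '</s>']))   # higher perplexity: an unlikely sentence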
Although perplexity is the most common evaluation metric for text generation, it is criticized for various reasons. Mainly, because it has been shown that an improvement in perplexity doesn't always mean an improvement in applications using the language model (it's basically not a good indicator of quality), and also because perplexity can't be used to evaluate text generation models that don't produce a distribution over words under the hood, like GANs. And if you thought that evaluation metrics for non open-ended generation are better, think twice!1
Are language models dangerous?
In the previous post I discussed the potential misuses of machine learning models, so the starting point should be that yes, if used by people with malicious intentions, LMs may pose a danger. More specifically, the announcement from OpenAI expressed the concern that such a model, if released, may be used to generate fake news at scale. While this is not completely unreasonable, there are currently two limitations of text generation that may help reduce the fear of LMs enhancing disinformation, at least temporarily.
Yes, generated text today is quite impressive. It is grammatical and in most cases it doesn't deviate from the topic. But it is not fact-aware (see how it continues the following sentence: GPT-2 is a language model ___), it has little common sense (and this one: she fell and broke her leg because someone left a banana peel ____), and as previously mentioned, often just doesn't read "human-like". Even when it does and humans can't tell that it's machine-generated, there are models that are good at detecting that. The robots may fail us humans, but not each other 🤖
Fear of disinformation is justified, but at least at its current state, I'm more concerned about the humans involved in it: those who initiate and generate it, those who spread it with evil intentions, and especially the many others who spread it ignorantly and naively. Perhaps, in parallel to the race between the technology used to create disinformation and the technology developed to fight it, we can also train humans to think more critically?
I learned a lot of what I know about text generation pretty recently, thanks to my awesome collaborators on the text GAN evaluation paper and my teammates at AI2/UW (especially Ari Holtzman and Rowan Zellers). Thanks!
1 The evaluation of non open-ended generation depends on the task, yet suffers from a major issue: the gold standard is a given text, but it may not be the only correct text due to variability in language. In machine translation, for example, the standard evaluation metric is BLEU, which basically compares chunks of text in the reference (gold standard) translation to the system predicted translation. Various correct translations may differ in their syntactic structure or in the choice of words. Penalizing a model for not predicting the exact sentence that the human translators suggested (and which is found in the test set) is unfair, yet this is the standard way to evaluate machine translation models today. The same issue exists for summarization with the ROUGE metric. For a much more elaborate discussion on this topic, see Rachael Tatman's blog post. ↩