Monday, August 10, 2015

Supervised Learning

If mathematics is both the queen and servant of science, then machine learning must be the princess and the maid of AI (artificial intelligence). The main goal in AI is to develop software capable of intelligent behavior, for example, a self-driving car. One of the definitions of intelligent behavior is the ability to learn from "experience" (given data) and develop a model that can understand and react to new data. This goal is achieved with machine learning.

This post is a high-level overview of supervised learning, which is the simplest category of machine learning. In future posts, I can elaborate on the actual algorithms and I might also write about other categories, such as unsupervised learning and deep learning -- but I have to admit that I'll have to study them in more depth for that.

I think the best way to explain supervised learning is with a motivating example: spam detection. Your mailbox contains a spam folder, and when spam is detected, it is stored in the spam folder rather than in the inbox. How does your mail provider recognize that a certain email is spam?

Suppose you had to manually decide whether a certain email is spam or not. You would probably check whether the sender is known or unknown and whether the message contains suspicious words and phrases such as "free", "cash bonus", and other spam-triggering words. Then you can define rules on top of these observations, for instance: "classify as spam if the message is sent from an unknown sender and contains at least 2 spam-triggering words, and as non-spam otherwise". In the same way, you can define these rules and let software apply them automatically to new mails. This approach is called rule-based. While it can lead to accurate results, it requires the effort of defining the rules, which in some tasks must be done by experts (spam detection is a relatively easy task).
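
To make the rule above concrete, here is a minimal sketch of such a rule-based filter. The trigger-word list, the addresses and the function name are all made up for illustration, and for simplicity it only counts single words (so "cash bonus" counts as two triggers).

```python
# Hypothetical list of spam-triggering words (an assumption, not a real filter's list).
SPAM_TRIGGERS = {"free", "cash", "bonus"}


def is_spam(sender, message, known_senders, min_triggers=2):
    """Rule: unknown sender AND at least `min_triggers` spam-triggering words."""
    words = message.lower().split()
    trigger_count = sum(1 for word in words if word in SPAM_TRIGGERS)
    return sender not in known_senders and trigger_count >= min_triggers


known = {"alice@example.com"}
print(is_spam("promo@example.net", "free cash bonus inside", known))  # True
print(is_spam("alice@example.com", "free cash bonus inside", known))  # False
```

Every change in spammer behavior would require someone to edit the word list or the threshold by hand, which is exactly the effort the rule-based approach demands.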

An example of accurate spam detection

The common solution is to let the machine learn a model (function) that receives an email as input and returns whether it is a spam message or not, based on observed features of the message, such as the sender and content.

In supervised learning, the machine is provided with a set of labeled examples, called the training set. This is a small set of data items whose classification is known in advance (e.g. annotated by humans). Each instance describes one data item (e.g. an email message), using predefined features that are relevant for the task at hand (e.g. the sender address, each of the words that occur in the subject of the message, etc.). In addition, each item in the training set has a matching true label, which is the item's known class (e.g. spam / non-spam). The machine performs a learning phase, during which it learns a function (model) that receives an unlabeled instance (e.g. a new email message) and returns its predicted label (e.g. spam / non-spam).

The learning phase is performed once, and then the model is ready to use for inference as many times as you want. You can give it a new unlabeled instance (e.g. a new email message that just arrived) and it will predict its class (e.g. spam / non-spam) by applying the learned function.
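
Here is a small sketch of the two phases, learning once and then predicting many times. It assumes a very simple learner (a perceptron with one weight per subject word) and a made-up training set; real spam filters use richer features and other algorithms, so treat it as an illustration only.

```python
from collections import defaultdict


def extract_features(subject):
    """Represent an email by the set of words in its subject."""
    return set(subject.lower().split())


def train(training_set, epochs=10):
    """Learning phase: a basic perceptron that learns one weight per word."""
    weights = defaultdict(float)
    for _ in range(epochs):
        for subject, true_label in training_set:
            features = extract_features(subject)
            score = sum(weights[f] for f in features)
            predicted = "spam" if score > 0 else "non-spam"
            if predicted != true_label:
                # Nudge the weights of this instance's words toward the true label.
                step = 1.0 if true_label == "spam" else -1.0
                for f in features:
                    weights[f] += step
    return weights


def predict(weights, subject):
    """Inference: apply the learned function to an unlabeled instance."""
    score = sum(weights.get(f, 0.0) for f in extract_features(subject))
    return "spam" if score > 0 else "non-spam"


# Hypothetical labeled examples (the training set).
training_set = [
    ("win a free prize", "spam"),
    ("lunch tomorrow", "non-spam"),
    ("free cash offer", "spam"),
    ("project status update", "non-spam"),
]

model = train(training_set)                   # performed once
print(predict(model, "free prize inside"))    # spam
print(predict(model, "status of the lunch"))  # non-spam
```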

Supervised learning pipeline (picture taken from here).

As you may have noticed, spam detection does not perform perfectly; sometimes a spam message is missed and stays in the inbox (the model classifies a "positive" spam as "negative" non-spam - false negative). At other times, a valid message unjustifiably finds its way to the spam folder (the model classifies a "negative" non-spam as "positive" spam - false positive).

In order for the algorithm to perform well, it needs to learn a model that best describes the training set, under the assumption that the training set is representative of real-world instances. To assess how successful a learned model is (in comparison with other models or in general), an evaluation is performed. This requires an additional set of labeled examples, called the test set, which is used to test the model. This set is disjoint from the training set and is not used during the learning phase. The model is applied to each of the instances in the test set, and the predicted label is compared with the true label (gold standard) given in the test set. An evaluation measure is then computed - for example precision¹, recall² or F1³.
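
The three measures are easy to compute once you have the true and predicted labels of the test set side by side. The sketch below follows the footnote definitions; the function name and the tiny label lists are made up for illustration.

```python
def evaluate(true_labels, predicted_labels, positive="spam"):
    """Compute precision, recall and F1 for the positive class."""
    pairs = list(zip(true_labels, predicted_labels))
    true_pos = sum(1 for t, p in pairs if t == positive and p == positive)
    false_pos = sum(1 for t, p in pairs if t != positive and p == positive)
    false_neg = sum(1 for t, p in pairs if t == positive and p != positive)

    precision = true_pos / (true_pos + false_pos) if true_pos + false_pos else 0.0
    recall = true_pos / (true_pos + false_neg) if true_pos + false_neg else 0.0
    # F1 is the harmonic mean of precision and recall.
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


# Hypothetical test set: 4 spam messages and 2 valid ones.
gold = ["spam", "spam", "spam", "spam", "non-spam", "non-spam"]
predicted = ["spam", "spam", "non-spam", "non-spam", "non-spam", "spam"]

precision, recall, f1 = evaluate(gold, predicted)
print(f"{precision:.2f} {recall:.2f} {f1:.2f}")  # 0.67 0.50 0.57
```

Here the model made one false positive and two false negatives, so it is more precise than it is complete.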

Of course, spam detection is only one of many examples of supervised learning. Other examples are:
  • Medical diagnosis - predict whether a patient suffers from a certain disease, based on his symptoms
  • Detecting fraudulent credit card transactions
  • Lexical inference - predict whether two terms hold a certain semantic relation, based on the relations between them in knowledge resources

In addition, there are more complex variations of supervised learning. The examples I gave were of binary classification, where each instance is classified into one of two classes: either positive (e.g. spam) or negative (e.g. non-spam). Other tasks require multi-class classification, in which every instance is classified into exactly one of several predefined classes; for instance, in optical character recognition (OCR), each hand-written character should be classified as one of the possible letters or digits.

In other tasks, an instance can be assigned several classes from a predefined set; for instance, determining the topics of a document from a predefined set of topics (this post could be classified as computer science, machine learning, supervised learning, etc.). This is called multi-label classification.
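
One simple way to think about multi-label classification is as a separate yes/no decision per class (this is sometimes called binary relevance). The sketch below uses hand-written keyword sets as stand-ins for learned binary classifiers; the topics and keywords are made up for illustration.

```python
# Hypothetical stand-ins for learned binary classifiers: one keyword set per topic.
TOPIC_KEYWORDS = {
    "machine learning": {"training", "classifier", "model"},
    "computer science": {"algorithm", "software", "model"},
    "cooking": {"recipe", "oven"},
}


def predict_topics(document):
    """Return every topic whose (stand-in) binary classifier says yes."""
    words = set(document.lower().split())
    return {topic for topic, keywords in TOPIC_KEYWORDS.items() if words & keywords}


print(sorted(predict_topics("the model is fit on a training set")))
# ['computer science', 'machine learning']
```

Unlike multi-class classification, the output is a set: it can contain several topics, one topic, or none at all.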

More complex tasks require outputting a structure rather than a class - this is called structured prediction. One such task is part-of-speech (POS) tagging: given a sentence, predict the part-of-speech of every word in the sentence (e.g. noun, verb, adjective). Rather than predicting the POS tag of every word separately, the sequence is predicted together, taking advantage of dependencies between preceding POS tags; e.g. if the previous word is tagged as a determiner, it is more likely that this word is a noun.
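
To show what "predicting the sequence together" can look like, here is a sketch that uses the Viterbi algorithm over hand-picked scores. This is one classic approach, not necessarily what any particular tagger does; the tag set, the words and all the numbers are assumptions made up for the example, whereas a real tagger would learn them from a labeled corpus.

```python
import math

TAGS = ["DET", "NOUN", "VERB"]
UNSEEN = 1e-6  # tiny score for a word never seen with a tag

# Hypothetical scores, hand-picked for illustration (not learned from data).
START = {"DET": 0.6, "NOUN": 0.3, "VERB": 0.1}
TRANSITION = {  # TRANSITION[previous_tag][tag]
    "DET": {"DET": 0.05, "NOUN": 0.9, "VERB": 0.05},
    "NOUN": {"DET": 0.1, "NOUN": 0.3, "VERB": 0.6},
    "VERB": {"DET": 0.5, "NOUN": 0.4, "VERB": 0.1},
}
EMISSION = {  # EMISSION[tag][word]
    "DET": {"the": 0.7, "a": 0.3},
    "NOUN": {"dogs": 0.4, "walk": 0.3, "park": 0.3},
    "VERB": {"walk": 0.6, "bark": 0.4},
}


def viterbi(words):
    """Return the highest-scoring tag sequence for a non-empty list of words."""
    def emit(tag, word):
        return math.log(EMISSION[tag].get(word, UNSEEN))

    # best[i][tag] = (log-score of the best sequence ending in `tag` at word i, previous tag)
    best = [{tag: (math.log(START[tag]) + emit(tag, words[0]), None) for tag in TAGS}]
    for word in words[1:]:
        column = {}
        for tag in TAGS:
            score, prev = max(
                (best[-1][p][0] + math.log(TRANSITION[p][tag]), p) for p in TAGS
            )
            column[tag] = (score + emit(tag, word), prev)
        best.append(column)

    # Backtrack from the best final tag to recover the whole sequence.
    tag = max(TAGS, key=lambda t: best[-1][t][0])
    path = [tag]
    for column in reversed(best[1:]):
        tag = column[tag][1]
        path.append(tag)
    return list(reversed(path))


print(viterbi(["the", "walk"]))   # ['DET', 'NOUN']
print(viterbi(["dogs", "walk"]))  # ['NOUN', 'VERB']
```

On its own, "walk" scores higher as a verb, but after a determiner the sequence as a whole scores higher with "walk" as a noun.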

An example of POS tagging, from Stanford Parser

No post about machine learning is complete without mentioning overfitting and regularization. During the learning phase, the machine tries to learn a model that fits the training set. However, it might overfit the training set, by memorizing all the instances instead of learning the main trends in the data. In this case, the evaluation results on the training set itself (trying to predict the labels of the instances without looking at the true labels) will be very good. On a separate test set, however, the results are expected to be worse, since the algorithm learned a very specific function which is not good at handling unseen data. For example, suppose that our training set contains the following 6 emails (training sets are usually much larger; this one is small for simplicity):

Subject                  | True Label
-------------------------|-----------
earn extra cash          | spam
our meeting on Monday    | non-spam
the slides you requested | non-spam
get cash today           | spam
hi                       | non-spam
cash bonus               | spam

A good algorithm will learn that "cash" in the mail's subject is indicative of spam. A bad algorithm will only recognize emails with the exact subjects "earn extra cash", "get cash today" and "cash bonus" as spam. Then, if it sees a new mail with the subject "get your cash immediately", it won't know it is also spam.
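
Here is a small sketch of the two behaviors on the six emails above. Both "algorithms" are toy constructions of my own for illustration: the first memorizes exact subjects, and the second keeps only words that appear in at least two spam subjects and in no non-spam subject.

```python
from collections import Counter

TRAINING_SET = [
    ("earn extra cash", "spam"),
    ("our meeting on Monday", "non-spam"),
    ("the slides you requested", "non-spam"),
    ("get cash today", "spam"),
    ("hi", "non-spam"),
    ("cash bonus", "spam"),
]


def train_memorizer(training_set):
    """Overfitted model: remembers the exact spam subjects and nothing else."""
    spam_subjects = {subject for subject, label in training_set if label == "spam"}
    return lambda subject: "spam" if subject in spam_subjects else "non-spam"


def train_word_model(training_set, min_spam_count=2):
    """More general model: keeps words that recur in spam and never appear in non-spam."""
    spam_counts, other_words = Counter(), set()
    for subject, label in training_set:
        words = set(subject.lower().split())
        if label == "spam":
            spam_counts.update(words)
        else:
            other_words |= words
    indicative = {
        word for word, count in spam_counts.items()
        if count >= min_spam_count and word not in other_words
    }
    return lambda subject: "spam" if indicative & set(subject.lower().split()) else "non-spam"


memorizer = train_memorizer(TRAINING_SET)
word_model = train_word_model(TRAINING_SET)

new_subject = "get your cash immediately"
print(memorizer(new_subject))   # non-spam
print(word_model(new_subject))  # spam
```

Both models label all six training emails correctly, so the training set alone cannot tell them apart; only the unseen subject exposes the difference.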

The solution is to apply regularization. Without going into too much technical detail, regularization is used to punish the algorithm for overfitting the training set, causing it to prefer learning a more general model.
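
As a rough sketch of the idea, the cost below adds a penalty on the size of the weights to the number of training errors. I'm assuming an L2-style penalty (the sum of squared weights), which is one common choice among several; the two weight dictionaries and the penalty strength are hand-written for illustration.

```python
TRAINING_SET = [
    ("earn extra cash", "spam"),
    ("our meeting on Monday", "non-spam"),
    ("the slides you requested", "non-spam"),
    ("get cash today", "spam"),
    ("hi", "non-spam"),
    ("cash bonus", "spam"),
]

# Two hypothetical models that both fit the training set perfectly.
SPECIFIC_MODEL = {"earn": 1.0, "extra": 1.0, "cash": 1.0, "get": 1.0, "today": 1.0, "bonus": 1.0}
GENERAL_MODEL = {"cash": 1.0}


def predict(weights, subject):
    score = sum(weights.get(word, 0.0) for word in subject.lower().split())
    return "spam" if score > 0 else "non-spam"


def regularized_cost(weights, training_set, strength=0.1):
    """Training errors plus a penalty that grows with the squared weights."""
    errors = sum(1 for subject, label in training_set if predict(weights, subject) != label)
    penalty = strength * sum(w * w for w in weights.values())
    return errors + penalty


print(f"{regularized_cost(SPECIFIC_MODEL, TRAINING_SET):.2f}")  # 0.60
print(f"{regularized_cost(GENERAL_MODEL, TRAINING_SET):.2f}")   # 0.10
```

With `strength=0` both models cost the same (zero errors), so nothing discourages the specific one; with the penalty switched on, the model that relies only on "cash" is cheaper and therefore preferred.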

This was just the tip of the iceberg of machine learning. Stay tuned for more about it!

1 The fraction of instances that were classified as positive (e.g. predicted to be spam) that are actually positive (e.g. actual spam messages). A numeric value between 0 and 1, 1 being the best precision, in which there are no negative instances falsely classified as positive. 
2 The fraction of positive instances (e.g. spam messages) that were also classified as positive (e.g. predicted to be spam). A numeric value between 0 and 1, 1 being the best recall, in which there are no positive instances falsely classified as negative. 
3 The harmonic mean of precision and recall, a measure that balances between the two. A numeric value between 0 and 1, 1 being the best F1.

4 comments:

  1. So how do you punish the algorithm?

    1. You hit it very hard :)

      It depends on the specific algorithm (there are several such algorithms for learning a binary model). In general, these algorithms start with some random model and try to improve it with every training instance they see. A model is usually a weight given to every feature according to how indicative it is of the class (e.g. the word "cash" is highly indicative of spam, so its weight will be high).

      For every training instance, if the current model predicts its label incorrectly (different from the true label), the algorithm will change the model a bit in the "direction" of the current training instance (so that the model could predict it "more correctly"). Every change is applied to all the features (to strengthen the feature values of the current instance).

      Overfitting happens when the model considers all the features of the training instances as indicative, even though some of them are non-indicative (specific to some training instances, and not generally to most negative / positive instances). Applying regularization means reducing the score of a model that has too many indicative features. Penalizing the model this way pushes the algorithm to increase only the weights of features that occurred in many training instances (of the same class) -- the indicative features.

      I hope it's not too complicated :)

  2. "Medical diagnosis - predict whether a patient suffers from a certain disease, based on his symptoms"

    What if they don't run all tests? Is there machine learning where some items don't have information regarding some features?

    1. First of all, it's just an aid to the doctor - no machine learning algorithm is accurate enough to diagnose a patient and prescribe a drug :)

      As for your question, some models can handle missing values (and for others, I assume there are heuristics about how to complete them, but I never experimented with that).
