Real Relationships & Overfitting
- 08:31
How real relationships are used in machine learning, as well as the challenges of overfitting and underfitting models to complex data relationships.
Transcript
Let's talk about the real relationships we're trying to model and the problem of overfitting in machine learning. The purpose of machine learning is to find and quantify real relationships, because if you understand real relationships, you can predict the future, and if you can predict the future, you'll make better decisions. For example, you learned about the relationship between price and demand in your first economics course: as price decreases, quantity demanded increases, and vice versa. You can see it in the graph right here. If you're pricing a product, your understanding of this relationship affects your decisions and, ultimately, whether you make money or lose money. For example, if TD Ameritrade is charging $6.95 per trade, they might process 650,000 trades per day. TD Ameritrade knows that if they decreased their fee to $4.95 per trade, they would process more trades, and if they increased their fee to $9.95 per trade, they would process fewer trades. Seems pretty obvious, but how many trades exactly? If they drop the price, would volume increase by enough to increase their total revenue? Or if they raise the price, would the decrease in volume cause a loss in revenue? To answer this question and make an accurate prediction, which will determine the pricing decision, TD Ameritrade could use a regression model. Not all regression models require machine learning. For a basic relationship, a simple one-variable linear regression might be sufficient to make predictions, and you see a graph of a one-variable linear regression model right here. However, if you're trying to model more complex relationships, you're going to need something more powerful. And as a quick tip: many people assume that there's a hard boundary between basic analyses, like a simple linear regression, and this magical realm of machine learning. The truth is that these methods exist on a spectrum, and it's difficult to point to the one method where "basic" stops and machine learning begins.
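As a sketch of what that one-variable regression could look like in practice, here's a minimal NumPy example. The $6.95 fee and the 650,000-trades figure come from the lesson; the other volumes, and the helper name `predict_trades`, are hypothetical, made up purely for illustration.

```python
import numpy as np

# Hypothetical price/volume observations. Only the 650,000 figure at
# $6.95 comes from the lesson; the other volumes are illustrative.
price = np.array([4.95, 6.95, 9.95])
trades = np.array([780_000, 650_000, 480_000])

# One-variable linear regression: trades = slope * price + intercept
slope, intercept = np.polyfit(price, trades, 1)

def predict_trades(p):
    """Predicted daily trade volume at a given per-trade fee."""
    return slope * p + intercept

# Revenue at each candidate price = fee * predicted volume
for p in (4.95, 6.95, 9.95):
    print(f"price ${p:.2f}: ~{predict_trades(p):,.0f} trades, "
          f"revenue ~${p * predict_trades(p):,.0f}")
```

The slope is negative, matching the price-demand relationship; comparing predicted revenue across the candidate fees is exactly the pricing decision the lesson describes.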
We're going to start simple in this lesson and then move toward more advanced methods. Visually, if you look at the graph right here, you can clearly see that there's a relationship in this data, but it's also clear that the relationship is not linear. You can see it curve up, then curve down, then move back up; there are different inflection points. In this situation, you would say that the model, the simple linear regression shown as the dotted line cutting across the data, is underfit to the data. That means that the real relationship is more complex than the model you're fitting to the data, and the model is too simple to describe the real relationship. Therefore, if you make a decision based on a prediction from the model, it's likely that you're going to make a bad decision. This table is a good way to define some of the language that we're talking about. If your real relationship complexity is greater than your model complexity, which is what you see in the graph, your model is underfit: your model is simpler than the relationship you're trying to model. If your real relationship complexity is equal to your model complexity, that's a good fit. That's what you want. You want the complexity of your model to be just about equal to the complexity of the relationship in your data. However, if your model complexity is greater than the real relationship complexity, your model is overfit: your model is reading complexity into the real relationship that's not actually there. With machine learning, there's no limit to the complexity of the models you can create, which means that it's easy to correct underfitting, but there's a big danger of overfitting. Your goal is to match the complexity of your model to the complexity of the real relationship that you're modeling. So you see here, this is a good fit. The data is moving, and it's a complex relationship.
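One way to see underfitting numerically is to fit both a straight line and a cubic to synthetic data whose real relationship is cubic. In this sketch (synthetic data, and an `r_squared` helper I've written for illustration), the linear model's R squared comes out much lower because the model is simpler than the relationship:

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic data: the real relationship is cubic, plus some noise.
x = np.linspace(-3, 3, 60)
y = x**3 - 4 * x + rng.normal(scale=1.0, size=x.size)

def r_squared(y_true, y_pred):
    """Share of variance in y_true explained by the predictions."""
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    return 1 - ss_res / ss_tot

linear = np.polyval(np.polyfit(x, y, 1), x)  # too simple: underfit
cubic = np.polyval(np.polyfit(x, y, 3), x)   # matches the real relationship

print(f"linear R^2: {r_squared(y, linear):.2f}")
print(f"cubic  R^2: {r_squared(y, cubic):.2f}")
```

The cubic model's complexity matches the real relationship, so its R squared approaches the best achievable given the noise; the line's does not.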
It's not just a simple linear relationship, but the model is moving along nicely, and it seems to correctly capture the general trend of where the data is moving over time. But if you're not careful, machine learning model complexity can quickly get out of hand, like you see right here, where the model is just jumping from data point to data point to data point. Obviously, a model like this is not modeling real relationships. There's a lot of random fluctuation in the data, and the model is reading too much into that random fluctuation. What you ultimately want is something that captures the general trend without getting too complex. Returning to the example of the price-demand relationship, you know that it's true that variation in price significantly influences variation in demand. However, you also know that it's false to say that variation in price is the only cause of variation in demand. There are countless other factors at play influencing demand: the state of the economy, the success of your advertising campaigns, the actions of your competitors, way too many variables to account for in your model. The variation that can't be predicted by your model is called error. It's possible to develop a model that has very little error, and that's the goal, but you can't develop a model with zero error.
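That "jumping from data point to data point" behavior is easy to reproduce: give a polynomial as many parameters as there are data points, and it will pass through every point, noise and all. A small synthetic sketch:

```python
import numpy as np

rng = np.random.default_rng(2)

# Five noisy observations of what is really just a linear trend.
x = np.arange(5.0)
y = 2 * x + rng.normal(scale=2.0, size=x.size)

# A degree-4 polynomial has five parameters: enough to pass through
# all five points exactly, chasing the noise instead of the trend.
coeffs = np.polyfit(x, y, 4)
fitted = np.polyval(coeffs, x)

print(np.max(np.abs(fitted - y)))       # essentially zero: "perfect" fit
print(np.polyval(coeffs, 5.0), 2 * 5.0)  # off-grid prediction vs. true trend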
A good model fully describes the real relationship between its input variables and its target variable, but it also acknowledges that some variation in the target variable is caused by factors that are not inputs in the model. In other words, a good model is reasonably specific and reasonably confident. For example, imagine that the model below estimates the height of the ocean tide, which is your target variable, based on the time of day, which is your input variable. The real relationship might be that you can predict 80% of the tide variation based on the time of day, but 20% of the variation is caused by other factors. This is a good model because it allows you to make predictions with a reasonable amount of accuracy, but it also gives you a sense of how confident you should be in each prediction, meaning the risk of the prediction being wrong. You can make reasonably good guesses, but you also know that the model is not 100% accurate, so you can make better decisions using your model and account for some of that risk in your decisions. As a quick aside, you might remember from your last statistics class that model fit is measured with R squared, which is between 0 and 1. The imaginary ocean tide model that you see here has an R squared of 0.80, which means that 80% of the variation in the target variable is explained by the variation in the input variable.
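The imaginary tide model can be simulated. In this sketch, the tide is a daily sine cycle plus unexplained variation, with the noise level chosen (hypothetically) so that roughly 80% of the variance is explained by time of day; the names and numbers are illustrative, not real tide data:

```python
import numpy as np

rng = np.random.default_rng(1)

# Hypothetical tide model: tide height driven by time of day (a daily
# sine cycle) plus variation from other factors, treated as noise.
hours = np.linspace(0, 24, 500)
signal = np.sin(2 * np.pi * hours / 24)           # part explained by time
noise = rng.normal(scale=0.35, size=hours.size)   # rainfall, wind, etc.
tide = signal + noise

# R^2: share of tide variance explained by the time-of-day component.
ss_res = np.sum((tide - signal) ** 2)
ss_tot = np.sum((tide - tide.mean()) ** 2)
r2 = 1 - ss_res / ss_tot
print(f"R^2 = {r2:.2f}")
```

The noise scale of 0.35 was picked so the signal carries about 80% of the total variance, matching the R squared of 0.80 in the lesson's example.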
An overfit model is overly specific, overly confident, and when applied to make predictions for new observations, it's usually wrong. In an attempt to minimize error, an overfit model claims to explain 100% of the variation of the target variable with variation in the input variables. This results in a model with zero error, but it's based on exaggerated relationships. Using the example of the ocean tide, the model above indicates that 100% of the variation in the ocean tide can be predicted with the time of day. High tide will be at exactly 2:13 PM every day. Low tide will occur at exactly 4:06 PM every day and so on. The problem is that this exaggerates the real relationship between the tide and the time of day. The truth is that there are other factors such as the amount of rainfall that affect the target variable.
One red flag that should warn you of an overfit model is a suspiciously high R squared. For example, the imaginary ocean tide model that we're looking at above has an R squared of 1: perfect prediction. You should be a little bit suspicious of an R squared like that.
Ultimately, you want a model that's generalizable, which means that it performs as well on new, unseen data as it does on your training data. That is to say, if a model explains 80% of the variance in 100 training observations where you know the answer, it should also be able to explain 80% of the variance in 100 brand-new observations that it's never seen before. You can identify underfitting and overfitting by comparing R squared on your training data with R squared on new data. If your model is underfit, meaning the model is simpler than the relationship, then your R squared with the training data is going to be low. You'll make bad predictions with your training data using that very simple model, and your R squared with new data will also be low. It's too simple to predict the real relationship. If your model is a good fit, the complexity of your model is about equal to the complexity of the real relationship you're modeling. Then the R squared with your training data will be high, and when you apply it to new data it has never seen before, your R squared will also be high. You'll be able to make good predictions in the real world. However, if your model is overfit, you're still going to get a high R squared with your training data, maybe even higher than if the model were a good fit. But the problem is that when you apply that model to real-world data the model has never seen before, your R squared with the new data is going to be terrible. The real danger of using overfit models is that you're going to feel highly confident in your predictions, but they're going to be wrong.
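That train-versus-new-data comparison can be sketched directly. In this synthetic example (my own data and helper names), the real relationship is quadratic, so a degree-1 polynomial underfits, degree 2 is a good fit, and degree 12 overfits; comparing training R squared with R squared on fresh data shows the pattern described above:

```python
import numpy as np

rng = np.random.default_rng(3)

def r_squared(y_true, y_pred):
    """Share of variance in y_true explained by the predictions."""
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    return 1 - ss_res / ss_tot

# Synthetic data where the real relationship is quadratic plus noise.
def sample(n):
    x = rng.uniform(-3, 3, n)
    return x, x**2 + rng.normal(scale=1.0, size=n)

x_train, y_train = sample(30)   # data the model learns from
x_test, y_test = sample(200)    # new data the model has never seen

results = {}
for degree, label in [(1, "underfit"), (2, "good fit"), (12, "overfit")]:
    coeffs = np.polyfit(x_train, y_train, degree)
    results[degree] = (
        r_squared(y_train, np.polyval(coeffs, x_train)),
        r_squared(y_test, np.polyval(coeffs, x_test)),
    )
    print(f"{label:8s} (degree {degree:2d}): "
          f"train R^2 = {results[degree][0]:.2f}, "
          f"test R^2 = {results[degree][1]:.2f}")
```

The underfit model scores low on both sets, the good fit scores high on both, and the overfit model scores highest on the training data but drops sharply on the data it has never seen, exactly the diagnostic pattern the lesson lays out.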