Visualization of MLP weights on MNIST
Looking at the learned coefficients of a neural network can sometimes provide insight into its learning behavior. For example, if the weights look unstructured, some of them may not have been used at all; if very large coefficients exist, regularization may have been too weak or the learning rate too high.
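A rough way to quantify these two symptoms, sketched here in NumPy, is to inspect each hidden unit's incoming weights: a near-zero norm hints that the unit is barely used, and a very large absolute weight may point at weak regularization or a high learning rate. The helper name summarize_first_layer and both thresholds are illustrative assumptions, not scikit-learn settings; on a fitted MLPClassifier one would pass its coefs_[0] attribute instead of the toy matrix.

```python
import numpy as np


def summarize_first_layer(W, small_norm=1e-3, large_abs=1.0):
    """Flag hidden units whose incoming weights look unused or very large.

    W has shape (n_features, n_hidden): column j holds the weights feeding
    hidden unit j. The thresholds are hypothetical, chosen for illustration.
    """
    norms = np.linalg.norm(W, axis=0)   # L2 norm of each unit's weights
    max_abs = np.abs(W).max(axis=0)     # largest |weight| of each unit
    return {
        "possibly_unused": np.flatnonzero(norms < small_norm).tolist(),
        "very_large": np.flatnonzero(max_abs > large_abs).tolist(),
        "global_max_abs": float(np.abs(W).max()),
    }


# Toy weight matrix: unit 0 is all zeros, unit 2 has one huge weight.
rng = np.random.default_rng(0)
W = rng.normal(scale=0.05, size=(784, 4))
W[:, 0] = 0.0
W[10, 2] = 5.0

report = summarize_first_layer(W)
print(report)
```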
This example shows how to plot some of the first-layer weights of an MLPClassifier trained on the MNIST dataset.
The input data consists of 28x28-pixel handwritten digits, giving 784 features in the dataset. The first-layer weight matrix therefore has the shape (784, hidden_layer_sizes[0]), so a single column of it can be visualized as a 28x28-pixel image.
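The column-to-image mapping can be checked on a stand-in matrix. This sketch assumes, as the reshape(28, 28) in the script does, that the 784 pixel features are stored row by row; the matrix W below is synthetic, not learned weights, and n_hidden = 50 simply mirrors hidden_layer_sizes=(50,) in the script.

```python
import numpy as np

n_hidden = 50  # mirrors hidden_layer_sizes=(50,) in the script
# Synthetic stand-in for coefs_[0]: one row per pixel, one column per unit.
W = np.arange(784 * n_hidden, dtype=float).reshape(784, n_hidden)

# Column j holds the 784 input weights of hidden unit j ...
col = W[:, 3]
# ... and a row-major reshape restores the 28x28 pixel layout.
img = col.reshape(28, 28)

# Iterating over W.T yields the same columns, one per hidden unit,
# which is how the plotting loop in the script walks over the units.
images = [c.reshape(28, 28) for c in W.T]
print(W.shape, img.shape, len(images))
```

Feature 28 (the first pixel of the second image row) lands at img[1, 0], which is what makes each column readable as a picture.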
To make the example run faster, we use very few hidden units and train for only a very short time. Training longer would produce weights with a much smoother spatial appearance. Because training stops after max_iter=10 epochs, before the optimizer has converged, scikit-learn would normally emit a ConvergenceWarning. That is acceptable here given the time constraints of the CI build, so the script catches and ignores this warning.
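To see the ConvergenceWarning itself without downloading MNIST, the following sketch trains a deliberately under-iterated MLPClassifier on scikit-learn's bundled 8x8 digits (load_digits, used here only as a small stand-in). It first records the warnings to confirm one is raised, then ignores ConvergenceWarning with a filter like the one in the script. The layer size and max_iter=2 are arbitrary choices for speed.

```python
import warnings

from sklearn.datasets import load_digits
from sklearn.exceptions import ConvergenceWarning
from sklearn.neural_network import MLPClassifier

# load_digits ships with scikit-learn, so no network access is needed.
X, y = load_digits(return_X_y=True)
X = X / 16.0  # pixel values in this dataset range from 0 to 16

# Deliberately too few epochs, mirroring the short training in the example.
mlp = MLPClassifier(hidden_layer_sizes=(10,), max_iter=2, solver="sgd",
                    learning_rate_init=0.1, random_state=1)

# record=True collects warnings in a list instead of printing them.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    mlp.fit(X, y)
n_convergence = sum(issubclass(w.category, ConvergenceWarning)
                    for w in caught)
print("ConvergenceWarnings raised:", n_convergence)

# The suppression pattern: ignore ConvergenceWarning only, keep the rest.
with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=ConvergenceWarning)
    mlp.fit(X, y)
```

Filtering by category rather than ignoring all warnings keeps unrelated problems visible while silencing the one that is expected.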
import warnings
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_openml
from sklearn.exceptions import ConvergenceWarning
from sklearn.neural_network import MLPClassifier
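# Load data from https://www.openml.org/d/554
# as_frame=False returns NumPy arrays (newer scikit-learn versions may
# otherwise return a pandas DataFrame)
X, y = fetch_openml('mnist_784', version=1, return_X_y=True, as_frame=False)
# rescale the pixel values from [0, 255] to [0, 1]
X = X / 255.
# use the traditional train/test split: 60,000 training and 10,000 test images
X_train, X_test = X[:60000], X[60000:]
y_train, y_test = y[:60000], y[60000:]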
mlp = MLPClassifier(hidden_layer_sizes=(50,), max_iter=10, alpha=1e-4,
                    solver='sgd', verbose=10, random_state=1,
                    learning_rate_init=.1)
# this example won't converge because of CI's time constraints, so we catch
# the warning and ignore it here
with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=ConvergenceWarning,
                            module="sklearn")
    mlp.fit(X_train, y_train)
print("Training set score: %f" % mlp.score(X_train, y_train))
print("Test set score: %f" % mlp.score(X_test, y_test))
fig, axes = plt.subplots(4, 4)
# use global min / max to ensure all weights are shown on the same scale
vmin, vmax = mlp.coefs_[0].min(), mlp.coefs_[0].max()
for coef, ax in zip(mlp.coefs_[0].T, axes.ravel()):
    ax.matshow(coef.reshape(28, 28), cmap=plt.cm.gray, vmin=.5 * vmin,
               vmax=.5 * vmax)
    ax.set_xticks(())
    ax.set_yticks(())

plt.show()
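fetch_openml has to download MNIST from OpenML the first time it runs, so the script fails without network access. A hedged offline sketch, assuming scikit-learn's bundled 8x8 digits (load_digits) are an acceptable stand-in, applies the same plotting approach with reshape(8, 8). The helper name plot_first_layer is hypothetical. Like the script, it sets the color limits to half the global minimum and maximum, which appears intended to clip the most extreme weights so that mid-range structure shows more contrast.

```python
import warnings

import matplotlib.pyplot as plt
from sklearn.datasets import load_digits
from sklearn.exceptions import ConvergenceWarning
from sklearn.neural_network import MLPClassifier


def plot_first_layer(mlp, image_shape, n_rows=4, n_cols=4):
    """Plot one first-layer weight column per subplot on a shared scale."""
    W = mlp.coefs_[0]                  # shape (n_features, n_hidden)
    vmin, vmax = W.min(), W.max()      # global limits shared by all panels
    fig, axes = plt.subplots(n_rows, n_cols)
    # zip stops at the shorter sequence: hidden units or subplot axes
    for coef, ax in zip(W.T, axes.ravel()):
        ax.matshow(coef.reshape(image_shape), cmap="gray",
                   vmin=0.5 * vmin, vmax=0.5 * vmax)
        ax.set_xticks(())
        ax.set_yticks(())
    return fig


X, y = load_digits(return_X_y=True)
X = X / 16.0  # scale pixel values from [0, 16] to [0, 1]
mlp = MLPClassifier(hidden_layer_sizes=(16,), max_iter=20, solver="sgd",
                    learning_rate_init=0.1, random_state=1)
with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=ConvergenceWarning)
    mlp.fit(X, y)

fig = plot_first_layer(mlp, image_shape=(8, 8))
plt.show()
```

With 64 input features the filters are much coarser than the 28x28 MNIST ones, so this is only a way to exercise the code path, not a replacement for the figure the script produces.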