web-dev-qa-db-fra.com

Classification binaire PyTorch - même structure de réseau, données «plus simples», mais performances moins bonnes?

Pour me familiariser avec PyTorch (et l'apprentissage en profondeur en général), j'ai commencé par travailler à travers quelques exemples de classification de base. Un de ces exemples était la classification d'un ensemble de données non linéaire créé à l'aide de sklearn (code complet disponible sous forme de bloc-notes ici )

n_pts = 500
X, y = datasets.make_circles(n_samples=n_pts, random_state=123, noise=0.1, factor=0.2)
x_data = torch.FloatTensor(X)
y_data = torch.FloatTensor(y.reshape(500, 1))

enter image description here

Ceci est ensuite classé avec précision à l'aide d'un réseau neuronal assez basique

class Model(nn.Module):
    def __init__(self, input_size, H1, output_size):
        super().__init__()
        self.linear = nn.Linear(input_size, H1)
        self.linear2 = nn.Linear(H1, output_size)

    def forward(self, x):
        x = torch.sigmoid(self.linear(x))
        x = torch.sigmoid(self.linear2(x))
        return x

    def predict(self, x):
        pred = self.forward(x)
        if pred >= 0.5:
            return 1
        else:
            return 0

Comme je m'intéresse aux données de santé, j'ai alors décidé d'essayer d'utiliser la même structure de réseau pour classer certains ensembles de données de base du monde réel. J'ai pris des données de fréquence cardiaque pour un patient de ici , et je les ai modifiées pour que toutes les valeurs> 91 soient étiquetées comme des anomalies (par exemple, un 1 et tout <= 91 étiqueté a 0). C'est complètement arbitraire, mais je voulais juste voir comment la classification fonctionnerait. Le cahier complet de cet exemple est ici .

enter image description here

Ce qui n'est pas intuitif pour moi, c'est pourquoi le premier exemple atteint une perte de 0,0016 après 1000 époques , alors que le deuxième exemple n'atteint que une perte de 0,4296 après 10 000 époques

Training Loss for Example 1

Training Loss for Heart Rate Example

Je suis peut-être naïf en pensant que l'exemple de la fréquence cardiaque serait beaucoup plus facile à classer. Toute idée pour m'aider à comprendre pourquoi ce n'est pas ce que je vois serait formidable!

14
Philip O'Brien

TL; DR

Vos données d'entrée ne sont pas normalisées.

  1. utilisez x_data = (x_data - x_data.mean()) / x_data.std()
  2. augmenter le taux d'apprentissage optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

Tu auras
enter image description here

convergence en seulement 1000 itérations.

Plus de détails

La principale différence entre les deux exemples que vous avez est que les données x dans le premier exemple sont centrées autour de (0, 0) et ont une très faible variance.
En revanche, les données du deuxième exemple sont centrées autour de 92 et présentent une variance relativement importante.

Ce biais initial dans les données n'est pas pris en compte lorsque vous aléatoirement initialisez les pondérations ce qui se fait en supposant que les entrées sont réparties à peu près normalement autour de zéro.
Il est presque impossible pour le processus d'optimisation de compenser cet écart brut - ainsi le modèle se retrouve coincé dans une solution sous-optimale.

Une fois que vous avez normalisé les entrées, en soustrayant la moyenne et en divisant par la std, le processus d'optimisation redevient stable et converge rapidement vers une bonne solution.

Pour plus de détails sur la normalisation des entrées et l'initialisation des poids, vous pouvez lire la section 2.2 dans He et al Plonger profondément dans les redresseurs: dépasser le niveau humain Performances sur la classification ImageNet (ICCV 2015).

Et si je ne peux pas normaliser les données?

Si, pour une raison quelconque, vous ne pouvez pas calculer à l'avance les données moyennes et standard, vous pouvez toujours utiliser nn.BatchNorm1d pour estimer et normaliser les données dans le cadre du processus de formation. Par exemple

class Model(nn.Module):
    def __init__(self, input_size, H1, output_size):
        super().__init__()
        self.bn = nn.BatchNorm1d(input_size)  # adding batchnorm
        self.linear = nn.Linear(input_size, H1)
        self.linear2 = nn.Linear(H1, output_size)

    def forward(self, x):
        x = torch.sigmoid(self.linear(self.bn(x)))  # batchnorm the input x
        x = torch.sigmoid(self.linear2(x))
        return x

Cette modification sans toute modification des données d'entrée, produit une convergence similaire après seulement 1000 époques:
enter image description here

Un petit commentaire

Pour la stabilité numérique, il est préférable d'utiliser nn.BCEWithLogitsLoss au lieu de nn.BCELoss . À cette fin, vous devez supprimer le torch.sigmoid De la sortie forward(), le sigmoid sera calculé à l'intérieur de la perte.
Voir, par exemple, ce fil concernant la perte sigmoïde + entropie croisée associée pour les prédictions binaires.

16
Shai