{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "ahIDpzwk4Pff"
},
"source": [
"# Παραγωγικά Δίκτυα Μάθησης με Αντιπαλότητα (ΠΔΜΑ) - Generative Adversarial Networks (GAN)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ylNh3I_84oWp"
},
"source": [
"Όπως είπαμε και στη διάλεξη, ένα Παραγωγικό Δίκτυο Μάθησης με Αντιπαλότητα (ΓΔΜΑ) - Generative Adversarial Network (GAN) είναι ένα παραγωγικό (generative) σύστημα μηχανικής μάθησης το οποίο [προτάθηκε από τον Goodfellow το 2014](https://arxiv.org/abs/1406.2661), και αποτελείται από δύο διακριτά νευρωνικά δίκτυα:\n",
"1. Τον γεννήτορα (generator), ο οποίος προσπαθεί να δημιουργήσει νέα δείγματα δεδομένων, όμοια προς τα υπάρχοντα\n",
"2. Τον διευκρινιστή (discriminator), ο οποίος καλείται να αναγνωρίσει αν ένα δείγμα δεδομένων είναι \"αληθινό\" ή \"ψεύτικο\"\n",
"\n",
"Η είσοδος του δικτύου του γεννήτορα είναι μια σειρά τυχαίων αριθμών (θορύβου), τα οποία καλούνται λανθάνοντα δείγματα (latent samples). Όπως είπαμε και παραπάνω, ο γεννήτορας προσπαθεί να δημιουργήσει δείγματα, τα οποία τελικά προέρχονται από την επιθυμητή κατανομή των δεδομένων. Ο θόρυβος που λαμβάνει στην είσοδό του διέρχεται μέσω των διαφορίσημων συναρτήσεων ενεργοποίησης του δικτύου και μετασχηματίζεται, μέσω της εκπαίδευσης, με τέτοιο τρόπο έτσι ώστε στην έξοδό του να παράξει \"ρεαλιστικά δεδομένα\". Άρα, στη συγκεκριμένη περίπτωση, ο ρόλος της εισόδου είναι η εισαγωγή τυχαιότητας στο σύστημα, η οποία επιτρέπει στον γεννήτορα να παράξει εξόδους που θα καλύπτουν όλο το εύρος των υπό εξέταση δεδομένων.\n",
"\n",
"Το δίκτυο του διευκρινιστή από την άλλη αποτελείται από έναν ταξινομητή, ο οποίος εκπαιδεύεται μέσω επιβλεπόμενης μάθησης σε ένα dataset και μαθαίνει να αναγνωρίζει αν η έξοδος που παράγει ο γεννήτορας ανταποκρίνεται στα δεδομένα στα οποία έχει εκπαιδευτεί ο διευκρινιστής ή όχι.\n",
"\n",
"Η εκπαίδευση ενός GAN προσομοιάζει με ένα παιχνίδι minmax. Ο γεννήτορας προσπαθεί να μάθει να δημιουργεί δεδομένα με τέτοιο τρόπο που ο διευκρινιστής δε θα μπορεί πλέον να αναγνωρίζει αν είναι ψεύτικα ή όχι. Ο ανταγωνισμός αυτός μεταξύ των δύο δικτύων βελτιώνει τη δυνατότητα μάθησης του συνολικού δικτύου, μέχρις ότου ο γεννήτορας μάθει να δημιουργεί \"ρεαλιστικά\" δεδομένα.\n",
"\n",
"Στο συγκεκριμένο notebook θα δούμε πως μπορούμε να χρησιμοποιήσουμε τα GANs για την παραγωγή εικόνων ψηφίων, παρόμοιων με αυτά που υπάρχουν στο γνωστό μας [MNIST dataset](http://yann.lecun.com/exdb/mnist/).\n",
"\n",
"Για την ταχύτερη εκτέλεση του κώδικα, στην περίπτωση που χρησιμοποιείτε Google Colaboratory, συστήνεται να επιλέξετε τις κάρτες γραφικών ως τον επιταχυντή υλικού (hardware accelerator). Αυτό επιτυγχάνεται επιλέγοντας το μενού Runtime και στη συνέχεια το υπομενού Change Runtime type. Στο νέο παράθυρο που θα εμφανιστεί επιλέγουμε GPU στο hardware accelerator και κατόπιν Save.\n",
"\n",
"Η εκπαίδευση των GANs στα κελιά που θα ακολουθήσουν θα χρειαστεί κάποιο χρόνο, οπότε ενδέχεται να εμφανιστεί παράθυρο σφάλματος με τίτλο “Runtime disconnected” και μήνυμα “The connection to the runtime has timed out”. Σε αυτή την περίπτωση πατάμε “Reconnect” και συνδεόμαστε ξανά με το περιβάλλον εκτέλεση του κώδικα."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "cGnrlzdR4o2E"
},
"source": [
"## Βιβλιοθήκες"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kwWhAx4x4rwA"
},
"source": [
"Ξεκινάμε εισάγοντας τις βιβλιοθήκες που πρόκειται να χρησιμοποιήσουμε (κυρίως το *keras* και δευτερευόντως το *scikit-learn,* το *numpy* και το *matplotlib*)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "KMUh7XVTB9cZ"
},
"outputs": [],
"source": [
"from keras import backend as K\n",
"from tensorflow.keras.datasets import mnist\n",
"from tensorflow.keras.models import Sequential\n",
"from tensorflow.keras.layers import BatchNormalization, Conv2D, Conv2DTranspose, Dense\n",
"from tensorflow.keras.layers import Flatten, Input, LeakyReLU, ReLU, Reshape\n",
"from tensorflow.keras.optimizers import Adam, RMSprop\n",
"from tensorflow.keras.initializers import RandomNormal\n",
"from sklearn.preprocessing import MinMaxScaler\n",
"\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# Θέτουμε συγκεκριμένη «σπορά» στη γεννήτρια ψευδοτυχαίων αριθμών για να \n",
"# μπορούμε να αναπαράξουμε τα αποτελέσματα\n",
"np.random.seed(2022)\n",
"%matplotlib inline"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mNiHYFnC46mx"
},
"source": [
"### Dataset"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ykgfTWAg4_Fz"
},
"source": [
"Αρχικά φορτώνουμε το MNIST Dataset από το keras"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1bBkp_s_CXSI"
},
"outputs": [],
"source": [
"(X_train, y_train), (X_test, y_test) = mnist.load_data()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RuFbCgzw5F0d"
},
"source": [
"Ας δούμε μια χαρακτηριστική εικόνα (ψηφίο) για την κάθε ετικέτα"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "blRzwHQyCaik"
},
"outputs": [],
"source": [
"fig = plt.figure()\n",
"for i in range(10):\n",
" plt.subplot(2, 5, i+1)\n",
" x_y = X_train[y_train == i]\n",
" plt.imshow(x_y[0], cmap='gray', interpolation='none')\n",
" plt.title(\"Κλάση %d\" % (i))\n",
" plt.xticks([])\n",
" plt.yticks([])\n",
" \n",
"plt.tight_layout()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6ZtELO055Vbj"
},
"source": [
"Μετασχηματίζουμε τις εικόνες από πίνακα $28\\times28$ pixels σε να διάνυσμα $784$ χαρακτηριστικών. Επίσης τους αλλάζουμε κλίμακα, απεικονίζοντάς τες γραμμικά από το $[0, 255]$ στο $[-1,1]$ με την χρήση του [MinMaxScaler](https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.MinMaxScaler.html) του scikit-learn"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "AIKGvszwClO3"
},
"outputs": [],
"source": [
"print('Διαστάσεις X_train:', X_train.shape)\n",
"\n",
"# μετασχηματισμός εικόνας σε διάνυσμα\n",
"X_train = X_train.reshape(60000, 28*28)\n",
"# κανονικοποίηση στο [-1,1]\n",
"scaler = MinMaxScaler(feature_range=(-1,1))\n",
"\n",
"X_train_scaled = scaler.fit_transform(X_train)\n",
"\n",
"print('Διαστάσεις X_train_scaled:', X_train_scaled.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HFFToN2aza_T"
},
"source": [
"## Simple GAN"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Juam1nPdTv_N"
},
"source": [
""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3hSHeVNOVGuM"
},
"source": [
"Η αντικειμενική συνάρτηση που προσπαθεί να βελτιστοποιήσει το απλό GAN εναι η παρακάτω:\n",
"\n",
" $ \\underset{\\theta_{g}}{min} \\: \\underset{\\theta_{d}}{max} V(D,G) = \\mathbb{E}_{x\\sim p_{data}(x)} [log D_{\\theta_{d}}(x)] + \\mathbb{E}_{z\\sim p_{z}(z)}[log(1 - D_{\\theta_{d}}(G_{\\theta_{g}}(z)))]$\n",
" \n",
"\n",
"* Ο διευκρινιστής $D$ θέλει να μεγιστοποιήσει την αντικειμενική συνάρτηση όσον αφορά τις παραμέτρους του ($\\theta_d$), έτσι ώστε το $D(x)$ να είναι κατά το δυνατόν εγγύτερα στο $1$ (αληθινά δεδομένα) και το $D(G(z))$ κοντά στο $0$ (ψευδή δεδομένα)\n",
"* Ο γεννήτορας $G$ θέλει να ελαχιστοποιήσει την αντικειμενική συνάρτηση όσον αφορά τις παραμέτρους του ($\\theta_g$), έτσι ώστε το $D(G(z))$ να είναι εγγύτερα στο $1$\n",
"\n",
"Ξεκινάμε ορίζοντας το μέγεθος του χώρου των λανθανουσών μεταβλητών $z$, ο οποίος πρόκειται να είναι ένα διάνυσμα μεγέθους $100$ χαρακτηριστικών "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "t1uF-PKzCwDQ"
},
"outputs": [],
"source": [
"# διάσταση χώρου λανθανουσών μεταβλητών z\n",
"latent_dim = 100\n",
"\n",
"# διάσταση εικόνας \n",
"img_dim = 28*28\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pr8P9d7JZd9n"
},
"source": [
"Το δίκτυο του γεννήτορα είναι ένα πλήρως διασυνδεδεμένο βαθύ δίκτυο $3$ κρυφών επιπέδων, μεγέθους $128, 256$ και $512$ νευρώνων αντίστοιχα."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "IQQCQqH1xxu3"
},
"outputs": [],
"source": [
"# Δίκτυο γεννήτορα\n",
"generator = Sequential([\n",
" # Επίπεδο εισόδου και 1ο κρυφό επίπεδο\n",
" Dense(128, input_shape=(latent_dim,), \n",
" kernel_initializer=RandomNormal(stddev=0.02)),\n",
" LeakyReLU(alpha=0.2),\n",
" BatchNormalization(momentum=0.8),\n",
" \n",
" # 2ο κρυφό επίπεδο\n",
" Dense(256),\n",
" LeakyReLU(alpha=0.2),\n",
" BatchNormalization(momentum=0.8),\n",
" \n",
" # 3ο κρυφό επίπεδο\n",
" Dense(512),\n",
" LeakyReLU(alpha=0.2),\n",
" BatchNormalization(momentum=0.8),\n",
" \n",
" # Επίπεδο εξόδου \n",
" Dense(img_dim, activation='tanh')\n",
"])\n",
"\n",
"# Σύνοψη δικτύου\n",
"generator.summary()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1XHoltUOZy-e"
},
"source": [
"Το δίκτυο του διευκρινιστή έχει αντίστοιχη δομή με αυτή του γεννήτορα, με τη διαφορά πως αλλάζει η είσοδος και η έξοδος."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "_Tg3cZJuDeGD"
},
"outputs": [],
"source": [
"# Δίκτυο Διευκρινιστή\n",
"discriminator = Sequential([\n",
" # Επίπεδο εισόδου και 1ο κρυφό επίπεδο \n",
" Dense(128, input_shape=(img_dim,), \n",
" kernel_initializer=RandomNormal(stddev=0.02)),\n",
" LeakyReLU(alpha=0.2),\n",
" \n",
" # 2ο κρυφό επίπεδο\n",
" Dense(256),\n",
" LeakyReLU(alpha=0.2),\n",
"\n",
" # 3ο κρυφό επίπεδο\n",
" Dense(512),\n",
" LeakyReLU(alpha=0.2),\n",
"\n",
" # Επίπεδο εξόδου\n",
" Dense(1, activation='sigmoid')\n",
"])\n",
"\n",
"# Σύνοψη μοντέλου\n",
"discriminator.summary()\n",
"\n",
"discriminator.compile(optimizer=Adam(learning_rate=2e-4, beta_1=0.5), \n",
" loss='binary_crossentropy', metrics=['binary_accuracy'])\n",
"\n",
"discriminator.trainable = False"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "gk7BRvfBaCR6"
},
"source": [
"Πλέον, είμαστε έτοιμοι να ορίσουμε το συνολικό GAN"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ff1pGuYiEOrx"
},
"outputs": [],
"source": [
"gan = Sequential([\n",
" generator,\n",
" discriminator\n",
"])\n",
"\n",
"# Σύνοψη μοντέλου\n",
"gan.summary()\n",
"\n",
"gan.compile(optimizer=Adam(learning_rate=2e-4, beta_1=0.5), \n",
" loss='binary_crossentropy', metrics=['binary_accuracy'])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "B-179zf0aN4n"
},
"source": [
"Με βάση την αντικειμενική συνάρτηση που ορίσαμε παραπάνω, έχουμε δύο εναλλασόμενα στάδια εκπαίδευσης.\n",
"\n",
"1. Ανάβαση κλίσης (μεγιστοποίηση) για τον διευκρινιστή
\n",
"$ \\underset{\\theta_{d}}{max} \\left[\\mathbb{E}_{x\\sim p_{data}(x)} log D_{\\theta_{d}}(x) + \\mathbb{E}_{z\\sim p_{z}(z)}\\left(log(1 - D_{\\theta_{d}}(G_{\\theta_{g}}(z)))\\right)\\right]$\n",
"\n",
"2. Κατάβαση κλίσης (ελαχιστοποίηση) για τον γεννήτορα
\n",
"$ \\underset{\\theta_{g}}{min} \\left[\\mathbb{E}_{z\\sim p_{z}(z)}\\left(log(1 - D_{\\theta_{d}}(G_{\\theta_{g}}(z)))\\right)\\right]$\n",
"\n",
"Ο γεννήτορας δηλαδή, προσπαθεί να ελαχιστοποιήσει την πιθανότητα ο διευκρινιστής να είναι σωστός. Ωστόσο, στη συγκεκριμένη περίπτωση, θα χρησιμοποιήσουμε έναν εναλλακτικό τρόπο: ο γεννήτορας θα προσπαθεί να μεγιστοποιήσει την πιθανότητα ο διευκρινιστής να είναι λάθος, δηλαδή:
\n",
"$ \\underset{\\theta_{g}}{max} \\left[\\mathbb{E}_{z\\sim p_{z}(z)}\\left(log(D_{\\theta_{d}}(G_{\\theta_{g}}(z)))\\right)\\right]$\n",
"\n",
"Κατ' αυτόν τον τρόπο, ενώ το αντικείμενο της εκπαίδευσης δεν αλλάζει, εντούτοις διασφαλίζουμε υψηλότερες κλίσεις για την εκπαίδευση του γεννήτορα, πράγμα που οδηγεί σε πιο γρήγορη εκπαίδευση.\n",
"\n",
"Συμπερασματικά, ο διευκρινιστής εκπαιδεύεται κανονικά όπως ένας ταξινομητής, στο να αναγνωρίζει αν η είσοδός του είναι αληθής η ψευδής. Ο γεννήτορας από την άλλη εκπαιδεύεται στο να παράγει κατά το δυνατόν ρεαλιστικά δεδομένα, έτσι ώστε να \"ξεγελά\" τον διευκρινιστή. Συνεπώς, σε αυτή την περίπτωση, η συνάρτηση απώλειας που χρησιμοποιείται είναι η $D(G(z))$\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ODztnfhJEqdD"
},
"outputs": [],
"source": [
"epochs = 50\n",
"batch_size = 64\n",
"smooth = 0.1\n",
"\n",
"real = np.ones(shape=(batch_size, 1))\n",
"fake = np.zeros(shape=(batch_size, 1))\n",
"\n",
"d_loss_gan = []\n",
"gan_loss = []\n",
"\n",
"# Δημιουργία 10 δειγμάτων από το γεννήτορα πριν την εκπαίδευση του GAN\n",
"samples = 10\n",
"x_fake = generator.predict(\n",
" np.random.normal(loc=0, scale=1, size=(samples, latent_dim))\n",
")\n",
"\n",
"for k in range(samples):\n",
" plt.subplot(2, 5, k+1)\n",
" plt.imshow(x_fake[k].reshape(28, 28), cmap='gray')\n",
" plt.xticks([])\n",
" plt.yticks([])\n",
"\n",
"plt.tight_layout()\n",
"plt.show()\n",
"\n",
"for e in range(epochs):\n",
" for i in range(len(X_train) // batch_size):\n",
" \n",
" # \"Ξεπάγωμα\" (εκπαίδευση) βαρών διευκρινιστή\n",
" discriminator.trainable = True\n",
" \n",
" # Δείγματα αληθινών εικόνων\n",
" X_batch = X_train_scaled[i*batch_size:(i+1)*batch_size]\n",
" d_loss_gan_real = discriminator.train_on_batch(\n",
" x=X_batch, y=real * (1 - smooth)\n",
" )\n",
" \n",
" # Δείγματα \"ψεύτικων\" εικόνων (παραγόμενα από τον γεννήτορα)\n",
" z = np.random.normal(loc=0, scale=1, size=(batch_size, latent_dim))\n",
" X_fake = generator.predict_on_batch(z)\n",
" d_loss_gan_fake = discriminator.train_on_batch(x=X_fake, y=fake)\n",
" \n",
" # Υπολογίσμος συνάρτησης απώλειας GAN\n",
" d_loss_gan_batch = 0.5 * (d_loss_gan_real[0] + d_loss_gan_fake[0])\n",
" \n",
" # \"Πάγωμα\" των βαρών του διευκρινιστή\n",
" discriminator.trainable = False\n",
" gan_loss_batch = gan.train_on_batch(x=z, y=real)\n",
" \n",
"\n",
" d_loss_gan.append(d_loss_gan_batch)\n",
" gan_loss.append(gan_loss_batch[0])\n",
" print('Εποχή: %d/%d, Απώλεια Διευκρινιστή: %.3f, Απώλεια GAN: %.3f' % \n",
" (e + 1, epochs, d_loss_gan[-1], gan_loss[-1]))\n",
"\n",
" if (e + 1) % 10 == 0:\n",
" samples = 10\n",
" x_fake = generator.predict(\n",
" np.random.normal(loc=0, scale=1, size=(samples, latent_dim))\n",
" )\n",
"\n",
" for k in range(samples):\n",
" plt.subplot(2, 5, k+1)\n",
" plt.imshow(x_fake[k].reshape(28, 28), cmap='gray')\n",
" plt.xticks([])\n",
" plt.yticks([])\n",
"\n",
" plt.tight_layout()\n",
" plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vtHrts8jeZxP"
},
"source": [
"Τέλος, ας σχεδιάσουμε τη μεταβολή των συναρτήσεων απώλειας για τον διευκρινιστή καθώς και για το συνολικό δίκτυο"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "KFiIX4T2y-cX"
},
"outputs": [],
"source": [
"plt.plot(d_loss_gan)\n",
"plt.plot(gan_loss)\n",
"plt.title('Απώλεια Μοντέλων')\n",
"plt.ylabel('Απώλεια')\n",
"plt.xlabel('Εποχή')\n",
"plt.legend(['Διευκρινιστής', 'GAN'], loc='upper right')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mhFB52wVLSjg"
},
"source": [
"Παρατηρούμε ότι οι συναρτήσεις απώλειας του διευκρινιστή και του GAN σταθεροποιούνται μετά την 20ή εποχή."
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"name": "GAN-MNIST.ipynb",
"provenance": [],
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}