web-dev-qa-db-fra.com

Visualisez l'ensemble de données MNIST en utilisant OpenCV ou Matplotlib / Pyplot

j'ai un ensemble de données MNIST et j'essaie de le visualiser en utilisant pyplot. L'ensemble de données est au format cvs où chaque ligne est une image de 784 pixels. je veux le visualiser dans pyplot ou opencv au format d'image 28 * 28. J'essaie directement d'utiliser:

plt.imshow(X[2:],cmap =plt.cm.gray_r, interpolation = "nearest") 

mais je ne fonctionne pas? toutes les idées sur la façon dont je dois aborder cela.

10
decipher

En supposant que vous disposez d'un fichier CSV avec ce format, qui est un format dans lequel l'ensemble de données MNIST est disponible

label, pixel_1_1, pixel_1_2, ...

Voici comment vous pouvez le visualiser dans Python avec Matplotlib puis OpenCV

Matplotlib/Pyplot

import numpy as np
import csv
import matplotlib.pyplot as plt

with open('mnist_test_10.csv', 'r') as csv_file:
    for data in csv.reader(csv_file):
        # The first column is the label
        label = data[0]

        # The rest of columns are pixels
        pixels = data[1:]

        # Make those columns into a array of 8-bits pixels
        # This array will be of 1D with length 784
        # The pixel intensity values are integers from 0 to 255
        pixels = np.array(pixels, dtype='uint8')

        # Reshape the array into 28 x 28 array (2-dimensional array)
        pixels = pixels.reshape((28, 28))

        # Plot
        plt.title('Label is {label}'.format(label=label))
        plt.imshow(pixels, cmap='gray')
        plt.show()

        break # This stops the loop, I just want to see one

enter image description here

OpenCV

Vous pouvez prendre le tableau numpy pixels d'en haut qui est de dtype='uint8' (Entier 8 bits non signé) et de forme 28 x 28, et tracer avec cv2.imshow()

    title = 'Label is {label}'.format(label=label)

    cv2.imshow(title, pixels)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
28
bakkal

Pour tous comme moi qui veulent une solution rapide et sale, simplement pour avoir une idée approximative de ce qu'est une entrée donnée, dans la console et sans bibliothèques fantaisies:

def print_greyscale(pixels, width=28, height=28):
    def get_single_greyscale(pixel):
        val = 232 + round(pixel * 23)
        return '\x1b[48;5;{}m \x1b[0m'.format(int(val))

    for l in range(height):
        line_pixels = pixels[l * width:(l+1) * width]
        print(''.join(get_single_greyscale(p) for p in line_pixels))

(s'attend à ce que l'entrée ait la forme [784] et avec des valeurs flottantes de 0 à 1. Si ce n'est pas le cas, vous pouvez facilement convertir (par exemple pixels = pixels.reshape((784,)) ou pixels \= 255 )

Output

La sortie est un peu déformée mais vous avez l'idée.

4
cpury