fix: remove function for loading dataset locally and add basic lines of code to load csv instead for practical 3b.
This commit is contained in:
+7
-28
@@ -38,34 +38,13 @@ from sklearn.metrics import confusion_matrix, classification_report
|
||||
# Fashion MNIST is built into Keras, downloads automatically on first run
|
||||
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()
|
||||
|
||||
'''
|
||||
import numpy as np
|
||||
import gzip
|
||||
import os
|
||||
|
||||
def load_fashion_mnist(path):
|
||||
"""Load Fashion MNIST from local .gz files (Kaggle Zalando format)."""
|
||||
files = {
|
||||
'X_train': 'train-images-idx3-ubyte.gz',
|
||||
'y_train': 'train-labels-idx1-ubyte.gz',
|
||||
'X_test': 't10k-images-idx3-ubyte.gz',
|
||||
'y_test': 't10k-labels-idx1-ubyte.gz',
|
||||
}
|
||||
|
||||
with gzip.open(os.path.join(path, files['X_train'])) as f:
|
||||
X_train = np.frombuffer(f.read(), np.uint8, offset=16).reshape(-1, 28, 28)
|
||||
with gzip.open(os.path.join(path, files['y_train'])) as f:
|
||||
y_train = np.frombuffer(f.read(), np.uint8, offset=8)
|
||||
with gzip.open(os.path.join(path, files['X_test'])) as f:
|
||||
X_test = np.frombuffer(f.read(), np.uint8, offset=16).reshape(-1, 28, 28)
|
||||
with gzip.open(os.path.join(path, files['y_test'])) as f:
|
||||
y_test = np.frombuffer(f.read(), np.uint8, offset=8)
|
||||
|
||||
return (X_train, y_train), (X_test, y_test)
|
||||
|
||||
# Replace the Keras load line with:
|
||||
(X_train, y_train), (X_test, y_test) = load_fashion_mnist('./fashion-mnist/')
|
||||
'''
|
||||
# import pandas as pd
|
||||
# train_df = pd.read_csv('fashion-mnist_train.csv')
|
||||
# test_df = pd.read_csv('fashion-mnist_test.csv')
|
||||
# y_train = train_df['label'].values
|
||||
# y_test = test_df['label'].values
|
||||
# X_train = train_df.drop('label', axis=1).values.reshape(-1, 28, 28) # unflatten pixels to 28x28
|
||||
# X_test = test_df.drop('label', axis=1).values.reshape(-1, 28, 28)
|
||||
|
||||
print("Training set shape:", X_train.shape) # (60000, 28, 28)
|
||||
print("Test set shape: ", X_test.shape) # (10000, 28, 28)
|
||||
|
||||
Reference in New Issue
Block a user