Commit 13bf6165 authored by Oleh Astappiev's avatar Oleh Astappiev
Browse files

feat: add embeddings

parent 9cf5261f
import numpy as np import sys
from keras import Model sys.path.append("..")
import tensorflow as tf
from tensorflow.keras import datasets from import *
from src.model.alexnet import AlexNetModel from src.model.alexnet import AlexNetModel
from src.utils.common import get_modeldir, process_images_couple, get_datadir from src.utils.common import get_modeldir
model_name = 'alexnet_cifar10-new' model_name = 'alexnet_cifar10-new'
train_ds, test_ds, validation_ds = AlexNetModel.x_dataset() train_ds, test_ds, validation_ds = AlexNetModel.x_dataset()
...@@ -27,7 +27,10 @@ alexnet.load_weights(get_modeldir(model_name + '.h5')) ...@@ -27,7 +27,10 @@ alexnet.load_weights(get_modeldir(model_name + '.h5'))
# + '.tf')) # + '.tf'))
# evaluate # evaluate
alexnet.evaluate(validation_ds) # alexnet.evaluate(validation_ds)
# res = alexnet.predict(validation_ds) # res = alexnet.predict(validation_ds)
embeddings, labels = calc_embeddings(alexnet)
save_embeddings(embeddings, labels)
print('done') print('done')
import numpy as np
import _pickle as pickle
from keras import Model
import tensorflow as tf
from tensorflow.keras import datasets
from src.utils.common import process_images_couple, get_datadir
def calc_embeddings(alexnet):
# remove the last layer
embedding_model = Model(inputs=alexnet.input, outputs=alexnet.layers[-2].output)
(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()
embedding_images = np.concatenate([train_images, test_images])
embedding_labels = np.concatenate([train_labels, test_labels])
embedding_vds =, embedding_labels))
embedding_vds = (, drop_remainder=False))
print('predicting embeddings')
embeddings = embedding_model.predict(embedding_vds)
return embeddings, embedding_labels
# # zip together embeddings and their labels, cache in memory (maybe not necessay or maybe faster this way), shuffle, repeat forever.
# embeddings_ds =
# ))
def save_embeddings(embeddings, labels):
data = [embeddings, labels]
with open(get_datadir('embeddings_labels.pkl'), 'wb') as outfile:
pickle.dump(data, outfile, -1)
def load_embeddings():
with open(get_datadir('embeddings_labels.pkl'), 'rb') as infile:
result = pickle.load(infile)
return result[0], result[1]
\ No newline at end of file
import sys
from utils.common import * from utils.common import *
from data.cifar10_tuples import * from data.cifar10_tuples import *
from utils.distance import * from utils.distance import *
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment