Commit d997f56f authored by Oleh Astappiev's avatar Oleh Astappiev
Browse files

chore: minor cosmetics

parent 13bf6165
...@@ -5,7 +5,7 @@ from src.data.embeddings import * ...@@ -5,7 +5,7 @@ from src.data.embeddings import *
from src.model.alexnet import AlexNetModel from src.model.alexnet import AlexNetModel
from src.utils.common import get_modeldir from src.utils.common import get_modeldir
model_name = 'alexnet_cifar10-new' model_name = 'alexnet_cifar10'
train_ds, test_ds, validation_ds = AlexNetModel.x_dataset() train_ds, test_ds, validation_ds = AlexNetModel.x_dataset()
# load model # load model
......
...@@ -41,4 +41,4 @@ def load_embeddings(): ...@@ -41,4 +41,4 @@ def load_embeddings():
with open(get_datadir('embeddings_labels.pkl'), 'rb') as infile: with open(get_datadir('embeddings_labels.pkl'), 'rb') as infile:
result = pickle.load(infile) result = pickle.load(infile)
return result[0], result[1] return result[0], result[1]
\ No newline at end of file
...@@ -2,21 +2,19 @@ import sys ...@@ -2,21 +2,19 @@ import sys
sys.path.append("..") sys.path.append("..")
from utils.common import * from utils.common import *
from data.cifar10_tuples import *
from utils.distance import * from utils.distance import *
from src.data.embeddings import *
from src.model.alexnet import AlexNetModel from src.model.alexnet import AlexNetModel
from tensorflow.keras import layers, Model from tensorflow.keras import layers, Model
model_suffix = '-new'
alexnet = AlexNetModel() alexnet = AlexNetModel()
alexnet.compile() alexnet.compile()
alexnet.load_weights(get_modeldir('alexnet_cifar10-new.h5')) alexnet.load_weights(get_modeldir('alexnet_cifar10.h5'))
for layer in alexnet.layers: for layer in alexnet.layers:
layer.trainable = False layer.trainable = False
# Philipo Siemese model # Filippo's Siemese model
## Model hyperparters ## Model hyperparters
EMBEDDING_VECTOR_DIMENSION = 4096 EMBEDDING_VECTOR_DIMENSION = 4096
...@@ -46,7 +44,11 @@ NUM_EPOCHS = 3 ...@@ -46,7 +44,11 @@ NUM_EPOCHS = 3
siamese.compile(loss=loss(margin=0.05), optimizer="RMSprop") siamese.compile(loss=loss(margin=0.05), optimizer="RMSprop")
siamese.summary() siamese.summary()
embeddings_ds = tf.data.experimental.load(get_datadir('embeddings')) embeddings, embedding_labels = load_embeddings()
embeddings_ds = tf.data.Dataset.zip((
tf.data.Dataset.from_tensor_slices(embeddings),
tf.data.Dataset.from_tensor_slices(embedding_labels)
))
embeddings_ds = embeddings_ds.cache().shuffle(1000).repeat() embeddings_ds = embeddings_ds.cache().shuffle(1000).repeat()
@tf.function @tf.function
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment