Commit 1fc2750c authored by Oleh Astappiev's avatar Oleh Astappiev
Browse files

feat: move trainable = False to model classes

parent c95f3b2a
......@@ -30,9 +30,6 @@ model.save_weights(get_modeldir(model_name + '.h5'))
print('evaluating...')
model.evaluate(test_ds)
for layer in model.layers:
layer.trainable = False
print('calculating embeddings...')
embedding_model = model.get_embedding_model()
embedding_model.summary()
......@@ -58,8 +55,5 @@ inference_model.save(get_modeldir(model_name + '_inference.tf'), save_format='tf
print('visualization')
# compute vectors of the images and their labels, store them in a tsv file for visualization
siamese_vectors, siamese_labels = calc_vectors(comb_ds, inference_model)
project_embeddings(siamese_vectors, siamese_labels, model_name + '_siamese')
projection_vectors = siamese.get_projection_model().predict(emb_vectors)
project_embeddings(projection_vectors, emb_labels, model_name + '_siamese2')
......@@ -30,9 +30,6 @@ model.save_weights(get_modeldir(model_name + '.h5'))
print('evaluating...')
model.evaluate(test_ds)
for layer in model.layers:
layer.trainable = False
print('calculating embeddings...')
embedding_model = model.get_embedding_model()
embedding_model.summary()
......
......@@ -31,9 +31,6 @@ model.load_weights(get_modeldir(model_name + '.h5'))
print('evaluating...')
model.evaluate(test_ds)
for layer in model.layers:
layer.trainable = False
print('calculating embeddings...')
embedding_model = model.get_embedding_model()
embedding_model.summary()
......
......@@ -30,9 +30,6 @@ model.save_weights(get_modeldir(model_name + '.h5'))
print('evaluating...')
model.evaluate(test_ds)
for layer in model.layers:
layer.trainable = False
print('calculating embeddings...')
embedding_model = model.get_embedding_model()
embedding_model.summary()
......
......@@ -54,7 +54,9 @@ class AlexNetModel(Sequential):
return super().fit(x=x, y=y, batch_size=batch_size, epochs=epochs, callbacks=callbacks, **kwargs)
def get_embedding_model(self):
    """Return the embedding sub-model of this network.

    The sub-model maps the network's input to the output of the
    penultimate layer (``self.layers[-2]``), i.e. the feature vector
    just before the classification head.

    All layers of the returned model are frozen (``trainable = False``)
    so that downstream training (e.g. of a Siamese head) does not update
    the backbone weights.

    Returns:
        A Keras ``Model`` producing embeddings, with every layer frozen.
    """
    # NOTE: the scraped diff left the old one-line implementation
    # (`return Model(...)`) as dead code before the new body; only the
    # frozen-core version is kept here.
    core = Model(inputs=self.input, outputs=self.layers[-2].output)
    for layer in core.layers:
        layer.trainable = False
    return core
@staticmethod
def preprocess_input(image, label):
......
......@@ -38,10 +38,9 @@ class MobileNetModel(Model):
def get_embedding_model(self):
    """Return the embedding sub-model of this network.

    Builds a core model from the input up to ``self.layers[-7]`` (the
    layer whose output serves as the feature map — presumably the last
    convolutional block; confirm against the full architecture), then
    flattens it into a 1-D embedding vector.

    All layers of the returned model are frozen (``trainable = False``)
    so that downstream training does not update the backbone weights.

    Returns:
        A Keras ``Sequential`` model ``[core, Flatten]`` producing
        embeddings, with every layer frozen.
    """
    # NOTE: the scraped diff interleaved the pre-commit multi-line
    # `return Sequential([...])` with the new body; only the frozen-core
    # version is kept here.
    core = Model(inputs=self.input, outputs=self.layers[-7].output)
    core = Sequential([core, layers.Flatten()])
    for layer in core.layers:
        layer.trainable = False
    return core
@staticmethod
def preprocess_input(image, label):
......
......@@ -57,10 +57,6 @@ class SiameseModel(Model):
super(SiameseModel, self).__init__(inputs=[emb_input_1, emb_input_2], outputs=computed_distance)
# def call(self, inputs):
# """ Projection model is a model from embeddings to image vector """
# return self.projection_model(inputs)
def get_projection_model(self):
    """ Projection model is a model from embeddings to image vector """
    # Simple accessor: exposes the internal projection sub-model so callers
    # can run it directly (e.g. to project embedding vectors for
    # visualization) without going through the full Siamese pair input.
    return self.projection_model
......@@ -108,3 +104,22 @@ class SiameseModel(Model):
# train_ds = train_ds.map(lambda _, vals: vals) # discard the prepended "selected" class from the rejction resample, since we aleady have it available
train_ds = train_ds.batch(TRAIN_BATCH_SIZE) # .prefetch(tf.data.AUTOTUNE)
return train_ds
@staticmethod
def prepare_tuples(embeddings, labels):
    """Build a paired training dataset from embeddings and their labels.

    Zips `embeddings` with `labels` into one stream, shuffles it, and
    windows the stream so that adjacent items form randomly matched
    pairs; `make_label_for_pair` (defined elsewhere — presumably emits a
    match/mismatch target from the two labels, verify) then produces the
    training target for each pair. The result is batched with
    TRAIN_BATCH_SIZE and repeats indefinitely.
    """
    # Cache + shuffle + repeat: shuffling is what makes adjacent items a
    # "random" pairing below; repeat() makes the dataset infinite.
    embeddings_ds = tf.data.Dataset.zip((
        tf.data.Dataset.from_tensor_slices(embeddings),
        tf.data.Dataset.from_tensor_slices(labels)
    )).cache().shuffle(1000).repeat()
    # TODO: change for triplet loss implementation
    # because of shuffling, we can take two adjacent tuples as a randomly matched pair
    # NOTE(review): window size is 3 but each window is batched as 2 below,
    # so the third element of every window appears dropped — confirm whether
    # window(2) was intended (the triplet-loss TODO above may explain the 3).
    train_ds = embeddings_ds.window(3, drop_remainder=True)
    train_ds = train_ds.flat_map(lambda w1, w2: tf.data.Dataset.zip((w1.batch(2), w2.batch(2)))) # see https://stackoverflow.com/questions/55429307/how-to-use-windows-created-by-the-dataset-window-method-in-tensorflow-2-0
    # generate the target label depending on whether the labels match or not
    train_ds = train_ds.map(make_label_for_pair, num_parallel_calls=tf.data.AUTOTUNE, deterministic=False)
    # resample to the desired distribution
    # train_ds = train_ds.rejection_resample(lambda embs, target: tf.cast(target, tf.int32), [0.5, 0.5], initial_dist=[0.9, 0.1])
    # train_ds = train_ds.map(lambda _, vals: vals) # discard the prepended "selected" class from the rejection resample, since we already have it available
    train_ds = train_ds.batch(TRAIN_BATCH_SIZE) # .prefetch(tf.data.AUTOTUNE)
    return train_ds
......@@ -51,7 +51,9 @@ class VGG16Model(Model):
return super().fit(x=x, y=y, batch_size=batch_size, epochs=epochs, callbacks=callbacks, **kwargs)
def get_embedding_model(self):
    """Return the embedding sub-model of this network.

    The sub-model maps the network's input to the output of the
    penultimate layer (``self.layers[-2]``), i.e. the feature vector
    just before the classification head.

    All layers of the returned model are frozen (``trainable = False``)
    so that downstream training does not update the backbone weights.

    Returns:
        A Keras ``Model`` producing embeddings, with every layer frozen.
    """
    # NOTE: the scraped diff left the old one-line implementation
    # (`return Model(...)`) as dead code before the new body; only the
    # frozen-core version is kept here.
    core = Model(inputs=self.input, outputs=self.layers[-2].output)
    for layer in core.layers:
        layer.trainable = False
    return core
@staticmethod
def preprocess_input(image, label):
......
......@@ -33,9 +33,6 @@ model.load_weights(get_modeldir(model_name + '.h5'))
# print('evaluating...')
# model.evaluate(test_ds)
for layer in model.layers:
layer.trainable = False
print('calculating embeddings...')
embedding_model = model.get_embedding_model()
embedding_model.summary()
......
......@@ -31,9 +31,6 @@ model.save_weights(get_modeldir(model_name + '.h5'))
print('evaluating...')
model.evaluate(test_ds)
for layer in model.layers:
layer.trainable = False
print('calculating embeddings...')
embedding_model = model.get_embedding_model()
embedding_model.summary()
......
......@@ -33,9 +33,6 @@ model.save_weights(get_modeldir(model_name + '.h5'))
print('evaluating...')
model.evaluate(test_ds)
for layer in model.layers:
layer.trainable = False
print('calculating embeddings...')
embedding_model = model.get_embedding_model()
embedding_model.summary()
......
......@@ -33,9 +33,6 @@ model.save_weights(get_modeldir(model_name + '.h5'))
print('evaluating...')
model.evaluate(test_ds)
for layer in model.layers:
layer.trainable = False
print('calculating embeddings...')
embedding_model = model.get_embedding_model()
embedding_model.summary()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment