We already have the embeddings precomputed in `embeddings` and their matching `labels`. To train the siamese networks, we need to generate random pairs of embeddings, assigning as target `1` if the two come from the same class and `0` otherwise.
In order to keep the training balanced, we can't simply select two random `(embedding, label)` tuples from the dataset: random pairs are heavily skewed towards the negative class. To keep things simple, we'll randomly select two samples and then use `rejection_resample` to rebalance the classes.
## Training hyperparameters (values chosen arbitrarily for now; it would be easy to set up hyperparameter tuning with Keras Tuner)
TRAIN_BATCH_SIZE=128
STEPS_PER_EPOCH=1000
NUM_EPOCHS=25
**NOTE**: rejection resampling only works well if the number of classes is reasonably low: with 10 classes, a randomly drawn pair is negative about 90% of the time and most of those negatives get rejected, so it can get very inefficient very quickly as the number of classes grows.
The `projection_model` is the part of the network that generates the final image vector (currently, a simple Dense layer with tanh activation, but it can be as complex as needed).
The `siamese` model is the one we train. It applies the projection model to two embeddings, computes the Euclidean distance between the two generated image vectors, and feeds that distance to the contrastive loss.
As a note, [here](https://towardsdatascience.com/contrastive-loss-explaned-159f2d4a87ec) they mention that cosine distance is preferable to euclidean distance:
> in a large dimensional space, all points tend to be far apart by the euclidian measure. In higher dimensions, the angle between vectors is a more effective measure.
Note that, when using cosine distance, the margin needs to be reduced from its default value of 1 (see below).
__________________
### Contrastive Loss
$Loss = Y \cdot D^2 + (1-Y) \cdot \max(margin - D, 0)^2, \quad \text{where } D = Dist(v_1, v_2)$
$Y$ is the GT target (1 if $v_1$ and $v_2$ belong to the same class, 0 otherwise). If images are from the same class, use the squared distance as loss (you want to push the distance to be close to 0 for same-class couples), otherwise keep the (squared) maximum between 0 and $margin - D$.
For different-class couples, the distance should be pushed towards a high value. The **margin identifies a cone inside which vectors are considered the same**. For cosine distance, which has range $[0, 2]$, **the default margin of 1 is NOT an adequate value**.
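For reference, a cosine distance with that $[0, 2]$ range can be computed from L2-normalized vectors; this is a sketch (the function name and signature mirror the Euclidean version below):

```python
import tensorflow as tf


def cosine_distance(vects):
    """Cosine distance between two batches of vectors; range is [0, 2]."""
    x, y = vects
    x = tf.math.l2_normalize(x, axis=1)
    y = tf.math.l2_normalize(y, axis=1)
    # 1 - cos(angle): 0 for parallel vectors, 2 for opposite vectors.
    return 1.0 - tf.reduce_sum(x * y, axis=1, keepdims=True)
```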
**NOTE**: In the loss implementation below, we take the mean of the two terms, though this should not actually be necessary: the minimizer of the loss is the same whether or not the loss is divided by 2.
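A direct translation of the formula above into a Keras-compatible loss might look like this (a sketch; the wrapper name `contrastive_loss_with_margin` is illustrative, and `y_pred` is assumed to be the distance $D$ produced by the siamese model):

```python
import tensorflow as tf


def contrastive_loss_with_margin(margin=1.0):
    def contrastive_loss(y_true, y_pred):
        # y_pred is the distance D between the two projected vectors.
        y_true = tf.cast(y_true, y_pred.dtype)
        # Same-class pairs (Y=1): push D^2 towards 0.
        same_class_term = y_true * tf.square(y_pred)
        # Different-class pairs (Y=0): penalize only distances below the margin.
        diff_class_term = (1.0 - y_true) * tf.square(tf.maximum(margin - y_pred, 0.0))
        return tf.reduce_mean(same_class_term + diff_class_term)

    return contrastive_loss
```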
"""
## Model hyperparameters
EMBEDDING_VECTOR_DIMENSION=4096
# EMBEDDING_VECTOR_DIMENSION = int(1280/2)
IMAGE_VECTOR_DIMENSIONS=128
# IMAGE_VECTOR_DIMENSIONS = 3 # use for test visualization on tensorboard
ACTIVATION_FN = 'tanh'  # same as in paper
MARGIN=0.05
## These functions are straight from the Keras tutorial linked above
# Provided two tensors t1 and t2
# Euclidean distance = sqrt(sum(square(t1-t2)))
def euclidean_distance(vects):
    """Find the Euclidean distance between two vectors.
    Arguments:
        vects: List containing two tensors of same length.
    """
    x, y = vects
    sum_square = tf.math.reduce_sum(tf.math.square(x - y), axis=1, keepdims=True)
    return tf.math.sqrt(tf.math.maximum(sum_square, tf.keras.backend.epsilon()))
To validate the model, we load the validation chunk of the dataset and feed it into the network. We don't need to repeat the preprocessing steps on the dataset, because the preprocessing is embedded in the inference model by the `Rescaling` and `Resizing` layers we added above.
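The idea of baking preprocessing into the inference model can be sketched as below. This is hypothetical: `IMG_SIZE` is an assumed input size, and the `Flatten` + `Dense` layers stand in for the real feature extractor and trained projection head, which are not shown in this section:

```python
import tensorflow as tf
from tensorflow.keras import layers

IMG_SIZE = 224  # assumed target size, not from the original notebook

# Placeholder for the trained projection head described above.
projection_head = layers.Dense(128, activation="tanh")

inference_model = tf.keras.Sequential([
    layers.Resizing(IMG_SIZE, IMG_SIZE),  # resize raw images of any size
    layers.Rescaling(1.0 / 255),          # map pixel values to [0, 1]
    layers.Flatten(),                     # stand-in for the real feature extractor
    projection_head,
])

# Raw, unpreprocessed images can be fed directly to the model.
vec = inference_model(tf.zeros([1, 300, 300, 3]))
```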
____________
## Visualizing embeddings in TensorBoard
In the `metadata.tsv` file, we list the labels in the same order as they appear in the embeddings list.
We write out the embeddings list as a tf.Variable initialized to the embeddings values, using TensorBoard's writers to specify the metadata file to use and the name of the tensor to display.
Additionally, in the specification of ProjectorConfig's proto message, there is the possibility to pass the values as a second .tsv file (`values.tsv`) instead of having them loaded from the checkpoint file.
I don't know which of the two sources is actually being loaded at the moment, but since it works I won't change it further and will keep both the .tsv files and the checkpointed values.
(See https://stackoverflow.com/a/57230031/3214872)