Commit 2d3259f1 authored by Oleh Astappiev's avatar Oleh Astappiev
Browse files

fix: paths for new models

parent a9e489c0
......@@ -23,10 +23,11 @@ embedding_model = tf.keras.models.Sequential([
])
embedding_model.build(MODEL_INPUT_SIZE)
embedding_model.summary()
#DATASET_NAME = 'cats_vs_dogs'
DATASET_NAME = 'cifar10'
# DATASET_NAME = 'cars196'
#DATASET_NAME = 'cars196'
ds = tfds.load(DATASET_NAME, split='train')
......@@ -50,7 +51,7 @@ batched_ds = ds.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
# "label" has shape [BATCH_SIZE,1] and is an integer label (value between 0 and 9)
# Naming schema: <dataset_name>-<dataset_split>.<model-name>.embeddings.pickle
DST_FNAME = f'/content/drive/MyDrive/{DATASET_NAME}-train.efficientnet_v2_imagenet1k_s.embeddings.pickle'
DST_FNAME = get_datadir('efficientnet_v2_imagenet1k_s.embeddings.pkl')
if Path(DST_FNAME).exists():
# When you need to use the embeddings, upload the file (or store it on Drive and mount your drive folder in Colab), then run:
......@@ -79,8 +80,6 @@ embeddings_ds = tf.data.Dataset.zip((
tf.data.Dataset.from_tensor_slices(embeddings),
tf.data.Dataset.from_tensor_slices(labels)
)).cache().shuffle(1000).repeat()
# change for triplet loss implementation
@tf.function
def make_label_for_pair(embeddings, labels):
......@@ -99,13 +98,11 @@ train_ds = train_ds.map(make_label_for_pair, num_parallel_calls=tf.data.AUTOTUNE
## Model hyperparters
EMBEDDING_VECTOR_DIMENSION = 1280
# EMBEDDING_VECTOR_DIMENSION = int(1280/2)
IMAGE_VECTOR_DIMENSIONS = 128
ACTIVATION_FN = 'tanh' # same as in paper
MARGIN = 0.005
DST_MODEL_FNAME = f'/content/drive/MyDrive/trained_model.margin-{MARGIN}.{Path(Path(DST_FNAME).stem).stem}'
DST_MODEL_FNAME = get_modeldir('seamese_cifar10_' + str(Path(Path(DST_FNAME).stem).stem) + '_' + str(IMAGE_VECTOR_DIMENSIONS) + '.tf')
## These functions are straight from the Keras tutorial linked above
......@@ -257,7 +254,7 @@ inference_model = tf.keras.models.load_model(DST_MODEL_FNAME, compile=False)
# NUM_SAMPLES_TO_DISPLAY = 10000
NUM_SAMPLES_TO_DISPLAY = 3000
LOG_DIR=Path('logs')
LOG_DIR=Path('logs_efficientnet')
LOG_DIR.mkdir(exist_ok=True, parents=True)
val_ds = (tfds.load(DATASET_NAME, split='test')
......
import sys
sys.path.append("..")
from src.data.embeddings import *
from src.model.alexnet import AlexNetModel
from src.utils.common import get_modeldir
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_datasets as tfds
......@@ -52,9 +56,8 @@ batched_ds = ds.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
# "image" has shape [BATCH_SIZE,32,32,3] and is an RGB uint8 image
# "label" has shape [BATCH_SIZE,1] and is an integer label (value between 0 and 9)
# Naming schema: <dataset_name>-<dataset_split>.<model-name>.embeddings.pickle
DST_FNAME = f'/content/drive/MyDrive/{DATASET_NAME}-train.vit_s16_fe.embeddings.pickle'
DST_FNAME = get_datadir('vit_s16_fe.embeddings.pkl')
if Path(DST_FNAME).exists():
# When you need to use the embeddings, upload the file (or store it on Drive and mount your drive folder in Colab), then run:
......@@ -100,13 +103,12 @@ train_ds = train_ds.map(make_label_for_pair, num_parallel_calls=tf.data.AUTOTUNE
#train_ds = train_ds.map(lambda _, vals: vals) # discard the prepended "selected" class from the rejction resample, since we aleady have it available
## Model hyperparters
EMBEDDING_VECTOR_DIMENSION = 1280
EMBEDDING_VECTOR_DIMENSION = 384
IMAGE_VECTOR_DIMENSIONS = 128
ACTIVATION_FN = 'tanh' # same as in paper
MARGIN = 0.005
DST_MODEL_FNAME = f'/content/drive/MyDrive/trained_model.margin-{MARGIN}.{Path(Path(DST_FNAME).stem).stem}'
DST_MODEL_FNAME = get_modeldir('seamese_cifar10_' + str(Path(Path(DST_FNAME).stem).stem) + '_' + str(IMAGE_VECTOR_DIMENSIONS) + '.tf')
## These functions are straight from the Keras tutorial linked above
......@@ -257,7 +259,7 @@ def write_embeddings_for_tensorboard(image_vectors: list, labels: list, root_dir
inference_model = tf.keras.models.load_model(DST_MODEL_FNAME, compile=False)
NUM_SAMPLES_TO_DISPLAY = 10000
LOG_DIR=Path('logs')
LOG_DIR=Path('logs_vit')
LOG_DIR.mkdir(exist_ok=True, parents=True)
val_ds = (tfds.load(DATASET_NAME, split='test')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment