Commit fedb9912 authored by Oleh Astappiev's avatar Oleh Astappiev
Browse files

feat: load alexnet weights

parent 181ad96e
from tensorflow.keras import models
from model.alexnet import AlexNetModel
from common import get_modeldir
model_name = 'alexnet_cifar10-new'
train_ds, test_ds, validation_ds = AlexNetModel.x_dataset()
# load model
# alexnet = models.load_model(get_modeldir(model_name + '.tf'))
# create model
alexnet = AlexNetModel()
alexnet.compile()
# alexnet.summary()
# load weights
alexnet.load_weights(get_modeldir(model_name + '.h5'))
# train
train_ds, test_ds, validation_ds = alexnet.x_dataset()
alexnet.x_train(train_ds, test_ds)
alexnet.evaluate(validation_ds)
# alexnet.fit(train_ds, validation_data=test_ds)
# save
alexnet.save_weights(get_modeldir(model_name + '.h5'))
alexnet.save(get_modeldir(model_name + '.tf'))
# alexnet.save_weights(get_modeldir(model_name + '.h5'))
# alexnet.save(get_modeldir(model_name + '.tf'))
# print('evaluate')
# evaluate
alexnet.evaluate(validation_ds)
# res = alexnet.predict(validation_ds)
print('done')
......@@ -2,6 +2,7 @@ from src.common import *
import tensorflow as tf
from tensorflow.keras import layers, callbacks, datasets, Sequential
tensorboard_cb = callbacks.TensorBoard(get_logdir("alexnet/fit"))
class AlexNetModel(Sequential):
def __init__(self):
......@@ -36,6 +37,20 @@ class AlexNetModel(Sequential):
layers.Dense(name='unfreeze', units=10, activation='softmax')
])
def compile(self, optimizer=tf.optimizers.SGD(learning_rate=0.001),
loss='sparse_categorical_crossentropy',
metrics=['accuracy'],
loss_weights=None, weighted_metrics=None, run_eagerly=None, steps_per_execution=None, **kwargs):
super().compile(optimizer, loss, metrics, loss_weights, weighted_metrics, run_eagerly, steps_per_execution, **kwargs)
def fit(self, x=None, y=None, batch_size=None, epochs=50, verbose='auto', callbacks=[tensorboard_cb], validation_split=0.,
validation_data=None, shuffle=True, class_weight=None, sample_weight=None, initial_epoch=0,
steps_per_epoch=None, validation_steps=None, validation_batch_size=None, validation_freq=1,
max_queue_size=10, workers=1, use_multiprocessing=False):
return super().fit(x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle,
class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps,
validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
@staticmethod
def x_dataset():
(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()
......@@ -62,11 +77,3 @@ class AlexNetModel(Sequential):
test_ds = (test_ds.map(process_images_couple).shuffle(buffer_size=train_ds_size).batch(batch_size=32, drop_remainder=True))
validation_ds = (validation_ds.map(process_images_couple).shuffle(buffer_size=train_ds_size).batch(batch_size=32, drop_remainder=True))
return train_ds, test_ds, validation_ds
def x_train(self, train_ds, validation_ds):
tensorboard_cb = callbacks.TensorBoard(get_logdir("alexnet/fit"))
# optimizer='adam', SGD W
self.compile(loss='sparse_categorical_crossentropy', optimizer=tf.optimizers.SGD(learning_rate=0.001), metrics=['accuracy'])
self.summary()
self.fit(train_ds, epochs=50, validation_data=validation_ds, validation_freq=1, callbacks=[tensorboard_cb])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment