Skip to content
GitLab
Projects
Groups
Snippets
Help
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Open sidebar
Oleh Astappiev
Near-Similar Image Recognition
Commits
fedb9912
Commit
fedb9912
authored
Nov 24, 2021
by
Oleh Astappiev
Browse files
feat: load alexnet weights
parent
181ad96e
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
33 additions
and
14 deletions
+33
-14
src/alexnet.py
src/alexnet.py
+18
-6
src/model/alexnet.py
src/model/alexnet.py
+15
-8
No files found.
src/alexnet.py
View file @
fedb9912
from
tensorflow.keras
import
models
from
model.alexnet
import
AlexNetModel
from
common
import
get_modeldir
model_name
=
'alexnet_cifar10-new'
train_ds
,
test_ds
,
validation_ds
=
AlexNetModel
.
x_dataset
()
# load model
# alexnet = models.load_model(get_modeldir(model_name + '.tf'))
# create model
alexnet
=
AlexNetModel
()
alexnet
.
compile
()
# alexnet.summary()
# load weights
alexnet
.
load_weights
(
get_modeldir
(
model_name
+
'.h5'
))
# train
train_ds
,
test_ds
,
validation_ds
=
alexnet
.
x_dataset
()
alexnet
.
x_train
(
train_ds
,
test_ds
)
alexnet
.
evaluate
(
validation_ds
)
# alexnet.fit(train_ds, validation_data=test_ds)
# save
alexnet
.
save_weights
(
get_modeldir
(
model_name
+
'.h5'
))
alexnet
.
save
(
get_modeldir
(
model_name
+
'.tf'
))
#
alexnet.save_weights(get_modeldir(model_name + '.h5'))
#
alexnet.save(get_modeldir(model_name + '.tf'))
# print('evaluate')
# evaluate
alexnet
.
evaluate
(
validation_ds
)
# res = alexnet.predict(validation_ds)
print
(
'done'
)
src/model/alexnet.py
View file @
fedb9912
...
...
@@ -2,6 +2,7 @@ from src.common import *
import
tensorflow
as
tf
from
tensorflow.keras
import
layers
,
callbacks
,
datasets
,
Sequential
tensorboard_cb
=
callbacks
.
TensorBoard
(
get_logdir
(
"alexnet/fit"
))
class
AlexNetModel
(
Sequential
):
def
__init__
(
self
):
...
...
@@ -36,6 +37,20 @@ class AlexNetModel(Sequential):
layers
.
Dense
(
name
=
'unfreeze'
,
units
=
10
,
activation
=
'softmax'
)
])
def
compile
(
self
,
optimizer
=
tf
.
optimizers
.
SGD
(
learning_rate
=
0.001
),
loss
=
'sparse_categorical_crossentropy'
,
metrics
=
[
'accuracy'
],
loss_weights
=
None
,
weighted_metrics
=
None
,
run_eagerly
=
None
,
steps_per_execution
=
None
,
**
kwargs
):
super
().
compile
(
optimizer
,
loss
,
metrics
,
loss_weights
,
weighted_metrics
,
run_eagerly
,
steps_per_execution
,
**
kwargs
)
def
fit
(
self
,
x
=
None
,
y
=
None
,
batch_size
=
None
,
epochs
=
50
,
verbose
=
'auto'
,
callbacks
=
[
tensorboard_cb
],
validation_split
=
0.
,
validation_data
=
None
,
shuffle
=
True
,
class_weight
=
None
,
sample_weight
=
None
,
initial_epoch
=
0
,
steps_per_epoch
=
None
,
validation_steps
=
None
,
validation_batch_size
=
None
,
validation_freq
=
1
,
max_queue_size
=
10
,
workers
=
1
,
use_multiprocessing
=
False
):
return
super
().
fit
(
x
,
y
,
batch_size
,
epochs
,
verbose
,
callbacks
,
validation_split
,
validation_data
,
shuffle
,
class_weight
,
sample_weight
,
initial_epoch
,
steps_per_epoch
,
validation_steps
,
validation_batch_size
,
validation_freq
,
max_queue_size
,
workers
,
use_multiprocessing
)
@
staticmethod
def
x_dataset
():
(
train_images
,
train_labels
),
(
test_images
,
test_labels
)
=
datasets
.
cifar10
.
load_data
()
...
...
@@ -62,11 +77,3 @@ class AlexNetModel(Sequential):
test_ds
=
(
test_ds
.
map
(
process_images_couple
).
shuffle
(
buffer_size
=
train_ds_size
).
batch
(
batch_size
=
32
,
drop_remainder
=
True
))
validation_ds
=
(
validation_ds
.
map
(
process_images_couple
).
shuffle
(
buffer_size
=
train_ds_size
).
batch
(
batch_size
=
32
,
drop_remainder
=
True
))
return
train_ds
,
test_ds
,
validation_ds
def
x_train
(
self
,
train_ds
,
validation_ds
):
tensorboard_cb
=
callbacks
.
TensorBoard
(
get_logdir
(
"alexnet/fit"
))
# optimizer='adam', SGD W
self
.
compile
(
loss
=
'sparse_categorical_crossentropy'
,
optimizer
=
tf
.
optimizers
.
SGD
(
learning_rate
=
0.001
),
metrics
=
[
'accuracy'
])
self
.
summary
()
self
.
fit
(
train_ds
,
epochs
=
50
,
validation_data
=
validation_ds
,
validation_freq
=
1
,
callbacks
=
[
tensorboard_cb
])
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment