Skip to content
GitLab
Projects
Groups
Snippets
Help
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Open sidebar
Oleh Astappiev
Near-Similar Image Recognition
Commits
5fb38925
Commit
5fb38925
authored
May 30, 2022
by
Oleh Astappiev
Browse files
feat: rename BaseDataset to AbsDataset
parent
279c21eb
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
73 additions
and
83 deletions
+73
-83
src/data/__init__.py
src/data/__init__.py
+60
-0
src/data/base.py
src/data/base.py
+0
-64
src/data/cifar10.py
src/data/cifar10.py
+4
-6
src/data/imagenette.py
src/data/imagenette.py
+3
-5
src/data/simple3.py
src/data/simple3.py
+3
-5
src/utils/embeddings.py
src/utils/embeddings.py
+3
-3
No files found.
src/data/__init__.py
View file @
5fb38925
from
abc
import
ABC
,
abstractmethod
from
typing
import
Tuple
,
Callable
,
List
PRINT_SIZE
=
True
DEFAULT_BATCH_SIZE
=
32
class AsbDataset(ABC):
    """Abstract base class for lazily-loaded image datasets.

    Subclasses implement ``_load_dataset``, which must return the
    ``(train, validation, test)`` splits; loading is deferred until one of
    the accessors is first called, and the result is cached.

    NOTE(review): the commit message says "rename BaseDataset to AbsDataset",
    so "AsbDataset" looks like a transposition typo — but the other changed
    modules import this exact spelling, so the name is kept for compatibility.
    """

    def __init__(self, name: str, classes: List[str],
                 image_size: Tuple[int, int],
                 batch_size: int = None, map_fn: Callable = None):
        """
        Args:
            name: human-readable dataset name (also used in log output).
            classes: list of class labels; ``num_classes`` is derived from it.
            image_size: target (height, width) forwarded to ``_load_dataset``.
            batch_size: optional batch size forwarded to ``_load_dataset``.
            map_fn: optional preprocessing function forwarded to
                ``_load_dataset``.
        """
        self.name = name
        self.classes = classes
        self.num_classes = len(classes)
        self._image_size = image_size
        self._batch_size = batch_size
        self._map_fn = map_fn
        # Per-split caches; populated on first access by __load().
        self._train_ds = None
        self._val_ds = None
        self._test_ds = None

    def get_classes(self):
        """Return the list of class labels."""
        return self.classes

    def get_num_classes(self):
        """Return the number of classes."""
        # Consistency fix: reuse the value computed once in __init__
        # instead of recomputing len(self.classes) on every call.
        return self.num_classes

    def get_train(self):
        """Return the training split, loading the dataset on first use."""
        if self._train_ds is None:
            self.__load()
        return self._train_ds

    def get_val(self):
        """Return the validation split, loading the dataset on first use."""
        if self._val_ds is None:
            self.__load()
        return self._val_ds

    def get_test(self):
        """Return the test split, loading the dataset on first use."""
        if self._test_ds is None:
            self.__load()
        return self._test_ds

    def get_combined(self):
        """Return train + validation + test concatenated into one dataset."""
        return self.get_train().concatenate(self.get_val()).concatenate(self.get_test())

    def __load(self):
        """Load all three splits via the subclass hook and cache them."""
        train_ds, val_ds, test_ds = self._load_dataset(
            self._image_size, self._batch_size, self._map_fn)
        self._train_ds = train_ds
        self._val_ds = val_ds
        self._test_ds = test_ds
        if PRINT_SIZE:
            print(self.name, "dataset loaded")
            print("Training size:", train_ds.cardinality().numpy())
            print("Validation size:", val_ds.cardinality().numpy())
            print("Evaluation size:", test_ds.cardinality().numpy())

    @abstractmethod
    def _load_dataset(self, image_size, batch_size, map_fn):
        """Return (train_ds, val_ds, test_ds) for the given configuration."""
        pass
src/data/base.py
deleted
100644 → 0
View file @
279c21eb
from
abc
import
ABC
,
abstractmethod
from
typing
import
Tuple
,
Callable
,
List
PRINT_SIZE
=
True
DEFAULT_BATCH_SIZE
=
32
class BaseDataset(ABC):
    """Common base for image datasets with lazy, cached split loading.

    Concrete subclasses supply two hooks: ``_load_dataset`` (raw loading)
    and ``_split_dataset`` (partitioning into train/validation/test).
    Nothing is loaded until one of the ``get_*`` accessors is called.
    """

    def __init__(self, name: str, classes: List[str],
                 image_size: Tuple[int, int],
                 batch_size: int = None, map_fn: Callable = None):
        """Store configuration only; actual loading is deferred."""
        self.name = name
        self.classes = classes
        self.num_classes = len(classes)
        self._image_size = image_size
        self._batch_size = batch_size
        self._map_fn = map_fn
        # Cached splits, filled in lazily by __load().
        self._train_ds = None
        self._val_ds = None
        self._test_ds = None

    def get_classes(self):
        """Return the list of class labels."""
        return self.classes

    def get_num_classes(self):
        """Return how many class labels this dataset has."""
        return len(self.classes)

    def get_train(self):
        """Training split (triggers loading on first access)."""
        if self._train_ds is None:
            self.__load()
        return self._train_ds

    def get_val(self):
        """Validation split (triggers loading on first access)."""
        if self._val_ds is None:
            self.__load()
        return self._val_ds

    def get_test(self):
        """Test split (triggers loading on first access)."""
        if self._test_ds is None:
            self.__load()
        return self._test_ds

    def get_combined(self):
        """All three splits concatenated: train, then val, then test."""
        combined = self.get_train().concatenate(self.get_val())
        return combined.concatenate(self.get_test())

    def __load(self):
        """Run both subclass hooks and cache the resulting splits."""
        raw = self._load_dataset(self._image_size, self._batch_size, self._map_fn)
        self._train_ds, self._val_ds, self._test_ds = self._split_dataset(*raw)
        if PRINT_SIZE:
            print(self.name, "dataset loaded")
            print("Training size:", self._train_ds.cardinality().numpy())
            print("Validation size:", self._val_ds.cardinality().numpy())
            print("Evaluation size:", self._test_ds.cardinality().numpy())

    @abstractmethod
    def _load_dataset(self, image_size, batch_size, map_fn):
        """Load the raw data; the result is passed to ``_split_dataset``."""
        pass

    @abstractmethod
    def _split_dataset(self, *args):
        """Partition the loaded data into (train_ds, val_ds, test_ds)."""
        pass
src/data/cifar10.py
View file @
5fb38925
import
tensorflow
as
tf
from
src.data
.base
import
Base
Dataset
from
src.data
import
Asb
Dataset
DEFAULT_BATCH_SIZE
=
32
DEFAULT_IMAGE_SIZE
=
(
32
,
32
)
CLASS_NAMES
=
[
'airplane'
,
'automobile'
,
'bird'
,
'cat'
,
'deer'
,
'dog'
,
'frog'
,
'horse'
,
'ship'
,
'truck'
]
class
Cifar10
(
BaseDataset
):
class
Cifar10
(
AsbDataset
):
def __init__(self, image_size=DEFAULT_IMAGE_SIZE, batch_size=DEFAULT_BATCH_SIZE, map_fn=None):
    """Configure the CIFAR-10 dataset wrapper; loading is deferred to the base class.

    Args:
        image_size: target (height, width), defaults to the native 32x32.
        batch_size: batch size, defaults to DEFAULT_BATCH_SIZE.
        map_fn: optional preprocessing function applied by the loader.
    """
    # Idiom: Python 3 zero-argument super() replaces super(Cifar10, self).
    super().__init__(name='cifar10', classes=CLASS_NAMES,
                     image_size=image_size, batch_size=batch_size, map_fn=map_fn)
...
...
@@ -33,10 +34,7 @@ class Cifar10(BaseDataset):
train_ds
=
train_ds
.
map
(
map_fn
).
prefetch
(
tf
.
data
.
AUTOTUNE
)
test_ds
=
test_ds
.
map
(
map_fn
).
prefetch
(
tf
.
data
.
AUTOTUNE
)
return
train_ds
,
test_ds
def
_split_dataset
(
self
,
train_ds
,
test_ds
):
train_ds_size
=
train_ds
.
cardinality
().
numpy
()
train_ds
=
train_ds
.
skip
(
train_ds_size
/
10
)
val_ds
=
train_ds
.
take
(
train_ds_size
/
10
)
return
train_ds
,
val_ds
,
test_ds
\ No newline at end of file
return
train_ds
,
val_ds
,
test_ds
src/data/imagenette.py
View file @
5fb38925
import
tensorflow
as
tf
from
src.data
.base
import
Base
Dataset
from
src.data
import
Asb
Dataset
DEFAULT_BATCH_SIZE
=
32
DEFAULT_IMAGE_SIZE
=
(
400
,
320
)
CLASS_NAMES
=
[
'fish'
,
'dog'
,
'player'
,
'saw'
,
'building'
,
'music'
,
'truck'
,
'gas'
,
'ball'
,
'parachute'
]
class
Imagenette
(
BaseDataset
):
class
Imagenette
(
AsbDataset
):
def __init__(self, image_size=DEFAULT_IMAGE_SIZE, batch_size=DEFAULT_BATCH_SIZE, map_fn=None):
    """Configure the Imagenette dataset wrapper; loading is deferred to the base class.

    Args:
        image_size: target (height, width), defaults to DEFAULT_IMAGE_SIZE.
        batch_size: batch size, defaults to DEFAULT_BATCH_SIZE.
        map_fn: optional preprocessing function applied by the loader.
    """
    # Idiom: Python 3 zero-argument super() replaces super(Imagenette, self).
    super().__init__(name='imagenette', classes=CLASS_NAMES,
                     image_size=image_size, batch_size=batch_size, map_fn=map_fn)
...
...
@@ -33,9 +34,6 @@ class Imagenette(BaseDataset):
train_ds
=
train_ds
.
map
(
map_fn
).
prefetch
(
tf
.
data
.
AUTOTUNE
)
test_ds
=
test_ds
.
map
(
map_fn
).
prefetch
(
tf
.
data
.
AUTOTUNE
)
return
train_ds
,
test_ds
def
_split_dataset
(
self
,
train_ds
,
test_ds
):
test_ds_size
=
test_ds
.
cardinality
().
numpy
()
val_ds
=
test_ds
.
take
(
test_ds_size
/
2
)
test_ds
=
test_ds
.
skip
(
test_ds_size
/
2
)
...
...
src/data/simple3.py
View file @
5fb38925
import
tensorflow
as
tf
from
src.data
.base
import
Base
Dataset
from
src.data
import
Asb
Dataset
DEFAULT_BATCH_SIZE
=
6
DEFAULT_IMAGE_SIZE
=
(
400
,
320
)
CLASS_NAMES
=
[
'building'
,
'dog'
,
'player'
]
class
Simple3
(
BaseDataset
):
class
Simple3
(
AsbDataset
):
def __init__(self, image_size=DEFAULT_IMAGE_SIZE, batch_size=DEFAULT_BATCH_SIZE, map_fn=None):
    """Configure the Simple3 dataset wrapper; loading is deferred to the base class.

    Args:
        image_size: target (height, width), defaults to DEFAULT_IMAGE_SIZE.
        batch_size: batch size, defaults to DEFAULT_BATCH_SIZE (6).
        map_fn: optional preprocessing function applied by the loader.
    """
    # Idiom: Python 3 zero-argument super() replaces super(Simple3, self).
    super().__init__(name='simple3', classes=CLASS_NAMES,
                     image_size=image_size, batch_size=batch_size, map_fn=map_fn)
...
...
@@ -22,9 +23,6 @@ class Simple3(BaseDataset):
if
map_fn
is
not
None
:
ds
=
ds
.
map
(
map_fn
).
prefetch
(
tf
.
data
.
AUTOTUNE
)
return
ds
def
_split_dataset
(
self
,
ds
):
ds_size
=
ds
.
cardinality
().
numpy
()
train_ds
=
ds
.
take
(
ds_size
*
0.6
)
val_ds
=
ds
.
skip
(
ds_size
*
0.6
).
take
(
ds_size
*
0.2
)
...
...
src/utils/embeddings.py
View file @
5fb38925
...
...
@@ -9,7 +9,7 @@ from tensorboard.plugins import projector
from
google.protobuf
import
text_format
from
src.utils.common
import
get_datadir
,
get_modeldir
,
get_logdir_root
from
src.data
.base
import
Base
Dataset
from
src.data
import
Asb
Dataset
def
_save_vectors_path
(
values
,
labels
,
path
):
...
...
@@ -117,7 +117,7 @@ def project_embeddings(image_vectors, labels, name='projection'):
projector
.
visualize_embeddings
(
root_dir
,
config
)
def
load_weights_of
(
model
:
tf
.
keras
.
Model
,
dataset
:
Base
Dataset
):
def
load_weights_of
(
model
:
tf
.
keras
.
Model
,
dataset
:
Asb
Dataset
):
model_file
=
get_modeldir
(
model
.
name
+
'_'
+
dataset
.
name
+
'.h5'
)
if
Path
(
model_file
).
exists
():
...
...
@@ -131,7 +131,7 @@ def load_weights_of(model: tf.keras.Model, dataset: BaseDataset):
model
.
evaluate
(
dataset
.
get_test
())
def
get_embeddings_of
(
model
:
tf
.
keras
.
Model
,
dataset
:
Base
Dataset
):
def
get_embeddings_of
(
model
:
tf
.
keras
.
Model
,
dataset
:
Asb
Dataset
):
embedding_file
=
get_datadir
(
model
.
name
+
'_'
+
dataset
.
name
+
'.pkl'
)
if
Path
(
embedding_file
).
exists
():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment