Commit 5fb38925 authored by Oleh Astappiev

feat: rename BaseDataset to AbsDataset

parent 279c21eb
from abc import ABC, abstractmethod
from typing import Tuple, Callable, List

PRINT_SIZE = True
DEFAULT_BATCH_SIZE = 32


class AbsDataset(ABC):
    def __init__(self, name: str, classes: List[str], image_size: Tuple[int, int], batch_size: int = None, map_fn: Callable = None):
        self.name = name
        self.classes = classes
        self.num_classes = len(classes)
        self._image_size = image_size
        self._batch_size = batch_size
        self._map_fn = map_fn
        self._train_ds = None
        self._val_ds = None
        self._test_ds = None

    def get_classes(self):
        return self.classes

    def get_num_classes(self):
        return len(self.classes)

    def get_train(self):
        if self._train_ds is None:
            self.__load()
        return self._train_ds

    def get_val(self):
        if self._val_ds is None:
            self.__load()
        return self._val_ds

    def get_test(self):
        if self._test_ds is None:
            self.__load()
        return self._test_ds

    def get_combined(self):
        return self.get_train().concatenate(self.get_val()).concatenate(self.get_test())

    def __load(self):
        # Lazily loads and caches the train/val/test splits on first access.
        train_ds, val_ds, test_ds = self._load_dataset(self._image_size, self._batch_size, self._map_fn)
        self._train_ds = train_ds
        self._val_ds = val_ds
        self._test_ds = test_ds

        if PRINT_SIZE:
            print(self.name, "dataset loaded")
            print("Training size:", train_ds.cardinality().numpy())
            print("Validation size:", val_ds.cardinality().numpy())
            print("Evaluation size:", test_ds.cardinality().numpy())

    @abstractmethod
    def _load_dataset(self, image_size, batch_size, map_fn):
        # Must return the (train, val, test) datasets.
        pass
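
Under the renamed AbsDataset, _load_dataset is expected to return the already-split (train, val, test) tuple; the separate _split_dataset hook of the old BaseDataset shown below is gone. The following is a minimal sketch of a subclass under the new contract, assuming the class is importable as "from src.data import AbsDataset" as in the other files of this commit; the InMemoryDataset name and the fake data are purely illustrative and not part of this commit.

import numpy as np
import tensorflow as tf

from src.data import AbsDataset  # import path assumed from the other files in this commit


class InMemoryDataset(AbsDataset):
    # Hypothetical subclass, not part of this commit, used only to illustrate the contract.

    def __init__(self, image_size=(32, 32), batch_size=8, map_fn=None):
        super(InMemoryDataset, self).__init__(name='in_memory', classes=['a', 'b'],
                                              image_size=image_size, batch_size=batch_size, map_fn=map_fn)

    def _load_dataset(self, image_size, batch_size, map_fn):
        # Fake tensors just to exercise the pipeline.
        images = np.random.rand(100, image_size[0], image_size[1], 3).astype('float32')
        labels = np.random.randint(0, 2, size=(100,))
        ds = tf.data.Dataset.from_tensor_slices((images, labels)).batch(batch_size)
        if map_fn is not None:
            ds = ds.map(map_fn).prefetch(tf.data.AUTOTUNE)
        # Return the split directly; there is no _split_dataset step anymore.
        return ds.take(10), ds.skip(10).take(2), ds.skip(12)


dataset = InMemoryDataset()
dataset.get_train()  # first access triggers the lazy load and prints the split sizes
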
from abc import ABC, abstractmethod
from typing import Tuple, Callable, List

PRINT_SIZE = True
DEFAULT_BATCH_SIZE = 32


class BaseDataset(ABC):
    def __init__(self, name: str, classes: List[str], image_size: Tuple[int, int], batch_size: int = None, map_fn: Callable = None):
        self.name = name
        self.classes = classes
        self.num_classes = len(classes)
        self._image_size = image_size
        self._batch_size = batch_size
        self._map_fn = map_fn
        self._train_ds = None
        self._val_ds = None
        self._test_ds = None

    def get_classes(self):
        return self.classes

    def get_num_classes(self):
        return len(self.classes)

    def get_train(self):
        if self._train_ds is None:
            self.__load()
        return self._train_ds

    def get_val(self):
        if self._val_ds is None:
            self.__load()
        return self._val_ds

    def get_test(self):
        if self._test_ds is None:
            self.__load()
        return self._test_ds

    def get_combined(self):
        return self.get_train().concatenate(self.get_val()).concatenate(self.get_test())

    def __load(self):
        args = self._load_dataset(self._image_size, self._batch_size, self._map_fn)
        train_ds, val_ds, test_ds = self._split_dataset(*args)
        self._train_ds = train_ds
        self._val_ds = val_ds
        self._test_ds = test_ds

        if PRINT_SIZE:
            print(self.name, "dataset loaded")
            print("Training size:", train_ds.cardinality().numpy())
            print("Validation size:", val_ds.cardinality().numpy())
            print("Evaluation size:", test_ds.cardinality().numpy())

    @abstractmethod
    def _load_dataset(self, image_size, batch_size, map_fn):
        pass

    @abstractmethod
    def _split_dataset(self, *args):
        pass
import tensorflow as tf

-from src.data.base import BaseDataset
+from src.data import AbsDataset

DEFAULT_BATCH_SIZE = 32
DEFAULT_IMAGE_SIZE = (32, 32)
CLASS_NAMES = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

-class Cifar10(BaseDataset):
+class Cifar10(AbsDataset):
    def __init__(self, image_size=DEFAULT_IMAGE_SIZE, batch_size=DEFAULT_BATCH_SIZE, map_fn=None):
        super(Cifar10, self).__init__(name='cifar10', classes=CLASS_NAMES, image_size=image_size, batch_size=batch_size, map_fn=map_fn)
@@ -33,9 +34,6 @@ class Cifar10(BaseDataset):
        train_ds = train_ds.map(map_fn).prefetch(tf.data.AUTOTUNE)
        test_ds = test_ds.map(map_fn).prefetch(tf.data.AUTOTUNE)
        return train_ds, test_ds

-    def _split_dataset(self, train_ds, test_ds):
-        train_ds_size = train_ds.cardinality().numpy()
-        train_ds = train_ds.skip(train_ds_size / 10)
-        val_ds = train_ds.take(train_ds_size / 10)
import tensorflow as tf

-from src.data.base import BaseDataset
+from src.data import AbsDataset

DEFAULT_BATCH_SIZE = 32
DEFAULT_IMAGE_SIZE = (400, 320)
CLASS_NAMES = ['fish', 'dog', 'player', 'saw', 'building', 'music', 'truck', 'gas', 'ball', 'parachute']

-class Imagenette(BaseDataset):
+class Imagenette(AbsDataset):
    def __init__(self, image_size=DEFAULT_IMAGE_SIZE, batch_size=DEFAULT_BATCH_SIZE, map_fn=None):
        super(Imagenette, self).__init__(name='imagenette', classes=CLASS_NAMES, image_size=image_size, batch_size=batch_size, map_fn=map_fn)
@@ -33,9 +34,6 @@ class Imagenette(BaseDataset):
        train_ds = train_ds.map(map_fn).prefetch(tf.data.AUTOTUNE)
        test_ds = test_ds.map(map_fn).prefetch(tf.data.AUTOTUNE)
        return train_ds, test_ds

-    def _split_dataset(self, train_ds, test_ds):
-        test_ds_size = test_ds.cardinality().numpy()
-        val_ds = test_ds.take(test_ds_size / 2)
-        test_ds = test_ds.skip(test_ds_size / 2)
import tensorflow as tf

-from src.data.base import BaseDataset
+from src.data import AbsDataset

DEFAULT_BATCH_SIZE = 6
DEFAULT_IMAGE_SIZE = (400, 320)
CLASS_NAMES = ['building', 'dog', 'player']

-class Simple3(BaseDataset):
+class Simple3(AbsDataset):
    def __init__(self, image_size=DEFAULT_IMAGE_SIZE, batch_size=DEFAULT_BATCH_SIZE, map_fn=None):
        super(Simple3, self).__init__(name='simple3', classes=CLASS_NAMES, image_size=image_size, batch_size=batch_size, map_fn=map_fn)
@@ -22,9 +23,6 @@ class Simple3(BaseDataset):
        if map_fn is not None:
            ds = ds.map(map_fn).prefetch(tf.data.AUTOTUNE)
        return ds

-    def _split_dataset(self, ds):
-        ds_size = ds.cardinality().numpy()
-        train_ds = ds.take(ds_size * 0.6)
-        val_ds = ds.skip(ds_size * 0.6).take(ds_size * 0.2)
@@ -9,7 +9,7 @@ from tensorboard.plugins import projector
from google.protobuf import text_format

from src.utils.common import get_datadir, get_modeldir, get_logdir_root
-from src.data.base import BaseDataset
+from src.data import AbsDataset


def _save_vectors_path(values, labels, path):
@@ -117,7 +117,7 @@ def project_embeddings(image_vectors, labels, name='projection'):
    projector.visualize_embeddings(root_dir, config)


-def load_weights_of(model: tf.keras.Model, dataset: BaseDataset):
+def load_weights_of(model: tf.keras.Model, dataset: AbsDataset):
    model_file = get_modeldir(model.name + '_' + dataset.name + '.h5')

    if Path(model_file).exists():
@@ -131,7 +131,7 @@ def load_weights_of(model: tf.keras.Model, dataset: BaseDataset):
    model.evaluate(dataset.get_test())


-def get_embeddings_of(model: tf.keras.Model, dataset: BaseDataset):
+def get_embeddings_of(model: tf.keras.Model, dataset: AbsDataset):
    embedding_file = get_datadir(model.name + '_' + dataset.name + '.pkl')

    if Path(embedding_file).exists():
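
For completeness, a usage sketch of how an AbsDataset instance flows through helpers like the ones above; the src.data.cifar10 module path is assumed for illustration and is not confirmed by this diff.

from src.data import AbsDataset        # renamed base class (was src.data.base.BaseDataset)
from src.data.cifar10 import Cifar10   # module path assumed, not shown in this diff


def describe(dataset: AbsDataset) -> None:
    # Any AbsDataset subclass exposes the same lazy accessors; the first call
    # triggers the private __load() and, with PRINT_SIZE enabled, prints the split sizes.
    print(dataset.name, dataset.get_num_classes(), "classes")
    print("train batches:", dataset.get_train().cardinality().numpy())
    print("val batches:", dataset.get_val().cardinality().numpy())
    print("test batches:", dataset.get_test().cardinality().numpy())


describe(Cifar10())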