Commit 195641ef authored by Oleh Astappiev's avatar Oleh Astappiev
Browse files

feat: add num_classes to calculate weights inside siamese class

parent 04f1733a
...@@ -74,8 +74,12 @@ class SiameseModel(Model): ...@@ -74,8 +74,12 @@ class SiameseModel(Model):
**kwargs): **kwargs):
super().compile(optimizer=optimizer, loss=tfa.losses.ContrastiveLoss(margin=loss_margin), **kwargs) super().compile(optimizer=optimizer, loss=tfa.losses.ContrastiveLoss(margin=loss_margin), **kwargs)
def fit(self, x=None, y=None, batch_size=None, epochs=NUM_EPOCHS, steps_per_epoch=STEPS_PER_EPOCH, callbacks=[tensorboard_cb], **kwargs): def fit(self, x=None, y=None, epochs=NUM_EPOCHS, steps_per_epoch=STEPS_PER_EPOCH, num_classes=None, callbacks=[tensorboard_cb], **kwargs):
return super().fit(x=x, y=y, batch_size=batch_size, epochs=epochs, steps_per_epoch=steps_per_epoch, callbacks=callbacks, **kwargs)
if num_classes is not None and 'class_weight' not in kwargs:
kwargs = dict(kwargs, class_weight={0: 1 / num_classes, 1: (num_classes - 1) / num_classes})
return super().fit(x=x, y=y, epochs=epochs, steps_per_epoch=steps_per_epoch, callbacks=callbacks, **kwargs)
@staticmethod @staticmethod
def prepare_dataset(embeddings, labels): def prepare_dataset(embeddings, labels):
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment