Commit 844716cb authored by Oleh Astappiev
Browse files

fix: exports of the datasets

parent 542dfe7d
...@@ -58,7 +58,7 @@ def export_embeddings(): ...@@ -58,7 +58,7 @@ def export_embeddings():
# write the header # write the header
writer.writerow(header) writer.writerow(header)
seamese = models.load_model(get_modeldir('seamese1.tf')) seamese = models.load_model(get_modeldir('seamese_cifar10.tf'))
embedding_vds = (cifar10_vds.map(process_images_couple).batch(batch_size=32, drop_remainder=False)) embedding_vds = (cifar10_vds.map(process_images_couple).batch(batch_size=32, drop_remainder=False))
print('predicting embeddings') print('predicting embeddings')
...@@ -66,12 +66,12 @@ def export_embeddings(): ...@@ -66,12 +66,12 @@ def export_embeddings():
print('embeddings done') print('embeddings done')
for i, (label) in enumerate(cifar10_labels): for i, (label) in enumerate(cifar10_labels):
label_str = label label_str = ','.join(map(str, label))
value_str = ','.join(map(str, embeddings[i])) value_str = ','.join(map(str, embeddings[i]))
writer.writerow([i, label_str, value_str]) writer.writerow([i, label_str, value_str])
# export_hsv() # export_hsv()
export_sift() # export_sift()
# export_embeddings() export_embeddings()
print('done') print('done')
...@@ -9,7 +9,9 @@ from tensorflow.keras import datasets ...@@ -9,7 +9,9 @@ from tensorflow.keras import datasets
# Load dataset # Load dataset
(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data() (train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()
cifar10_vds = tf.data.Dataset.from_tensor_slices((np.concatenate([train_images, test_images]), np.concatenate([train_labels, test_labels]))) cifar10_images = np.concatenate([train_images, test_images])
cifar10_labels = np.concatenate([train_labels, test_labels])
cifar10_vds = tf.data.Dataset.from_tensor_slices((cifar10_images, cifar10_labels))
# test HSV # test HSV
print('test HSV') print('test HSV')
...@@ -18,4 +20,14 @@ plot_hsv(cifar10_vds) ...@@ -18,4 +20,14 @@ plot_hsv(cifar10_vds)
print('test SIFT') print('test SIFT')
plot_sift(cifar10_vds) plot_sift(cifar10_vds)
# 906, 1692, 1711, 2610, 3259, 3418, 3789, 4277, 4975, 5010, 5255, 5867, 5988, 6406, 7089, 7365, 8072
# 8443, 8998, 9008, 9323, 9664, 9881, 9903, 9985, 10095, 11650, 13043, 13075, 13841, 14698, 15443
# 16004, 16733, 16888, 18948, 19378, 20015, 20233, 20467, 20621, 20696, 20778, 22672, 22804, 22904
# 23252, 23654, 23985, 25236, 25734, 25931, 27596, 27931, 28016, 28300, 28387, 28807, 30029, 31581
# 32024, 32117, 32629, 32861, 33328, 33489, 33589, 34466, 35063, 35202, 35719, 35877, 35985, 36560
# 36777, 37358, 37439, 38224, 38345, 39942, 40389, 40621, 40864, 41454, 41902, 42017, 43593, 44207
# 44226, 44257, 45801, 47091, 47375, 48663, 48690, 48884, 52366, 52622, 52847, 53227, 53248, 53423
# 53429, 53444, 53660, 53759, 53952, 54957, 55164, 55189, 55762, 56549, 56574, 57105, 57171, 58485
# 58572, 58826, 59318, 59970
print('done') print('done')
...@@ -79,4 +79,4 @@ embedding = alexnet(im_input) ...@@ -79,4 +79,4 @@ embedding = alexnet(im_input)
image_vector = projection_model(embedding) image_vector = projection_model(embedding)
inference_model = Model(inputs=im_input, outputs=image_vector) inference_model = Model(inputs=im_input, outputs=image_vector)
inference_model.save(get_modeldir('seamese1.tf'), save_format='tf', include_optimizer=False) inference_model.save(get_modeldir('seamese_cifar10.tf'), save_format='tf', include_optimizer=False)
...@@ -2,7 +2,7 @@ import matplotlib.pyplot as plt ...@@ -2,7 +2,7 @@ import matplotlib.pyplot as plt
import numpy as np import numpy as np
import cv2 import cv2
from src.utils.common import subplot_image from src.utils.common import *
def extract_hsv(image): def extract_hsv(image):
...@@ -10,14 +10,15 @@ def extract_hsv(image): ...@@ -10,14 +10,15 @@ def extract_hsv(image):
hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV) hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
# The ranges of the 3 HSV channels in opencv are 0-180, 0-256, 0-256 respectively # The ranges of the 3 HSV channels in opencv are 0-180, 0-256, 0-256 respectively
# Bins is set to 1365, so that each picture can be represented by a 4000-dimensional vector # Bins are set to 170/171/171 (H/S/V), so that each picture is represented by a 512-dimensional vector
histh = cv2.calcHist([hsv], [0], None, [1365], [0, 180]) histh = cv2.calcHist([hsv], [0], None, [170], [0, 180])
hists = cv2.calcHist([hsv], [1], None, [1365], [0, 256]) hists = cv2.calcHist([hsv], [1], None, [171], [0, 256])
histv = cv2.calcHist([hsv], [2], None, [1365], [0, 256]) histv = cv2.calcHist([hsv], [2], None, [171], [0, 256])
# normalize the histogram # normalize the histogram
histh /= histh.sum() histh /= histh.sum()
hists /= hists.sum() hists /= hists.sum()
histv /= histv.sum() histv /= histv.sum()
hist_array = np.array([histh, hists, histv]) hist_array = np.append(histh, hists)
hist_array = np.append(hist_array, histv)
# return the flattened histogram as the feature vector # return the flattened histogram as the feature vector
return histh, hists, histv, hist_array.flatten() return histh, hists, histv, hist_array.flatten()
...@@ -25,9 +26,12 @@ def extract_hsv(image): ...@@ -25,9 +26,12 @@ def extract_hsv(image):
def plot_hsv(dataset): def plot_hsv(dataset):
plt.figure(figsize=(20, 20)) plt.figure(figsize=(20, 20))
for i, (image, label) in enumerate(dataset.take(3)): for i, (image, label) in enumerate(dataset.take(3)):
subplot_image(3, 2, i * 2 + 1, image, "Original image") # from smaller image only smaller number of key points can be extracted
img = cv2.resize(image.numpy(), target_shape)
hist0_s, hist1_s, hist2_s, hist_s = extract_hsv(image.numpy()) subplot_image(3, 2, i * 2 + 1, img, "Original image")
hist0_s, hist1_s, hist2_s, hist_s = extract_hsv(img)
# print('the length of histogram of the sample', len(hist_s)) # print('the length of histogram of the sample', len(hist_s))
# subplot_image(3, 2, i * 2 + 2, image, "HSV Histogram") # subplot_image(3, 2, i * 2 + 2, image, "HSV Histogram")
......
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import cv2 import cv2
from src.utils.common import subplot_image from src.utils.common import *
def extract_sift(image): def extract_sift(image):
sift = cv2.SIFT_create(32) sift = cv2.SIFT_create(8)
# Calculate the keypoint and each point description of the image # Calculate the keypoint and each point description of the image
keypoints, features = sift.detectAndCompute(image, None) keypoints, features = sift.detectAndCompute(image, None)
return keypoints, features return keypoints, features
...@@ -14,10 +14,10 @@ def extract_sift(image): ...@@ -14,10 +14,10 @@ def extract_sift(image):
def plot_sift(dataset): def plot_sift(dataset):
plt.figure(figsize=(20, 20)) plt.figure(figsize=(20, 20))
for i, (image, label) in enumerate(dataset.take(3)): for i, (image, label) in enumerate(dataset.take(3)):
subplot_image(3, 2, i * 2 + 1, image, "Original image")
# fewer key points can be extracted from a smaller image # fewer key points can be extracted from a smaller image
img = cv2.resize(image.numpy(), (230, 230)) img = cv2.resize(image.numpy(), target_shape)
subplot_image(3, 2, i * 2 + 1, img, "Original image")
keypoints, features = extract_sift(img) keypoints, features = extract_sift(img)
img_kp = img.copy() img_kp = img.copy()
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment