mirror of
https://github.com/Shawn-Shan/fawkes.git
synced 2024-11-13 08:11:30 +05:30
889fd933e8
Former-commit-id: 268fb7e6825ddfc1165fa7adc7c216f9d61005da [formerly 06376993a831c060c337ec6e7540252f0b2dfe09] Former-commit-id: c4812d40187a76a878e7d215d22ee84811b41896
76 lines
2.4 KiB
Python
76 lines
2.4 KiB
Python
import argparse
|
|
import glob
|
|
import os
|
|
import pickle
|
|
import random
|
|
import sys
|
|
|
|
import numpy as np
|
|
from keras.applications.vgg16 import preprocess_input
|
|
from keras.preprocessing import image
|
|
sys.path.append("../fawkes")
|
|
# from utils import load_extractor
|
|
import keras
|
|
|
|
def load_sample_dir(path, sample=10):
|
|
x_ls = []
|
|
image_paths = list(os.listdir(path))
|
|
random.shuffle(image_paths)
|
|
for i, file in enumerate(image_paths):
|
|
if i > sample:
|
|
break
|
|
cur_path = os.path.join(path, file)
|
|
im = image.load_img(cur_path, target_size=(224, 224))
|
|
im = image.img_to_array(im)
|
|
x_ls.append(im)
|
|
raw_x = np.array(x_ls)
|
|
return preprocess_input(raw_x)
|
|
|
|
|
|
def normalize(x):
|
|
return x / np.linalg.norm(x)
|
|
|
|
|
|
def main():
|
|
extractor = keras.models.load_model(args.feature_extractor)
|
|
|
|
path2emb = {}
|
|
model_dir = os.path.join(os.path.expanduser('~'), '.fawkes')
|
|
for path in glob.glob(os.path.join(model_dir, "target_data/*")):
|
|
print(path)
|
|
idx = int(path.split("/")[-1])
|
|
cur_image_paths = glob.glob(os.path.join(path, "*"))
|
|
imgs = np.array([image.img_to_array(image.load_img(p, target_size=(224, 224))) for p in cur_image_paths])
|
|
imgs = preprocess_input(imgs)
|
|
|
|
cur_feature = extractor.predict(imgs)
|
|
cur_feature = np.mean(cur_feature, axis=0)
|
|
path2emb[idx] = cur_feature
|
|
|
|
model_path = os.path.join(model_dir, "{}_extract.h5".format(args.feature_extractor_name))
|
|
emb_path = os.path.join(model_dir, "{}_emb.p".format(args.feature_extractor_name))
|
|
extractor.save(model_path)
|
|
pickle.dump(path2emb, open(emb_path, "wb"))
|
|
|
|
|
|
|
|
def parse_arguments(argv):
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument('--gpu', type=str,
|
|
help='GPU id', default='0')
|
|
parser.add_argument('--candidate-datasets', nargs='+',
|
|
help='path candidate datasets')
|
|
parser.add_argument('--feature-extractor', type=str,
|
|
help="path of the feature extractor used for optimization",
|
|
default="/home/shansixioing/fawkes/feature_extractors/high2_extract.h5")
|
|
parser.add_argument('--feature-extractor-name', type=str,
|
|
help="name of the feature extractor used for optimization",
|
|
default="high2")
|
|
|
|
return parser.parse_args(argv)
|
|
|
|
|
|
if __name__ == '__main__':
|
|
args = parse_arguments(sys.argv[1:])
|
|
main()
|