2
0
mirror of https://github.com/Shawn-Shan/fawkes.git synced 2024-12-22 07:09:33 +05:30

local evaluation code

Former-commit-id: 7f9e32ef86adc18f33704b8824367a610df80e69 [formerly 30a0d4935de4da1c762819eab2bfc2e5411ff6a1]
Former-commit-id: 157928d16012db8f7470abb7a297c9740860b5b7
This commit is contained in:
Shawn-Shan 2020-07-13 15:34:10 -05:00
parent 5e8800ca3d
commit 30d8aba8be
7 changed files with 223 additions and 624 deletions

184
evaluation/eval_local.py Normal file
View File

@ -0,0 +1,184 @@
import argparse
import glob
import os
import random
import sys
import keras
import numpy as np
random.seed(1000)
from fawkes.utils import init_gpu, load_extractor, load_victim_model, get_file, preprocess, Faces
from keras.preprocessing import image
from keras.utils import to_categorical
from fawkes.align_face import aligner
def select_samples(data_dir):
all_data_path = []
for cls in os.listdir(data_dir):
cls_dir = os.path.join(data_dir, cls)
for data_path in os.listdir(cls_dir):
all_data_path.append(os.path.join(cls_dir, data_path))
return all_data_path
class DataGenerator(object):
def __init__(self, original_images, protect_images):
l = int(len(original_images) * 0.7)
self.original_images_test = original_images[l:]
self.protect_images_train = protect_images[:l]
other_classes = range(0, 20946)
selected_classes = random.sample(other_classes, args.num_other_classes)
print("Downloading additional data...")
model_dir = os.path.join(os.path.expanduser('~'), '.fawkes')
self.id2label = {-1: 0}
self.id2path = {}
self.id2pathtest = {}
idx = 1
for target_data_id in selected_classes:
image_dir = os.path.join(model_dir, "target_data/{}".format(target_data_id))
os.makedirs(os.path.join(model_dir, "target_data"), exist_ok=True)
os.makedirs(image_dir, exist_ok=True)
self.id2label[target_data_id] = idx
idx += 1
for i in range(10):
if os.path.exists(os.path.join(model_dir, "target_data/{}/{}.jpg".format(target_data_id, i))):
continue
try:
get_file("{}.jpg".format(i),
"http://sandlab.cs.uchicago.edu/fawkes/files/target_data/{}/{}.jpg".format(target_data_id,
i),
cache_dir=model_dir, cache_subdir='target_data/{}/'.format(target_data_id))
except Exception:
print("error getting http://sandlab.cs.uchicago.edu/fawkes/files/target_data/{}/{}.jpg".format(
target_data_id,
i))
pass
all_pathes = glob.glob(os.path.join(model_dir, 'target_data/{}/*.jpg'.format(target_data_id)))
test_path = random.sample(all_pathes, 2)
train_path = [p for p in all_pathes if p not in test_path]
self.id2path[target_data_id] = train_path
self.id2pathtest[target_data_id] = test_path
self.num_classes = 1 + len(self.id2path)
np.random.seed(12345)
self.all_id = selected_classes + [-1]
def generate(self, test=False):
while True:
batch_X = []
batch_Y = []
cur_batch_path = np.random.choice(self.all_id, 32)
for p in cur_batch_path:
cur_y = self.id2label[p]
if test and p == -1:
continue
# protect class images in train dataset
elif p == -1:
cur_x = random.choice(self.protect_images_train)
else:
if test:
cur_path = random.choice(self.id2pathtest[p])
else:
cur_path = random.choice(self.id2path[p])
im = image.load_img(cur_path, target_size=(224, 224))
cur_x = image.img_to_array(im)
cur_x = preprocess(cur_x, 'imagenet')
batch_X.append(cur_x)
batch_Y.append(cur_y)
batch_X = np.array(batch_X)
batch_Y = to_categorical(np.array(batch_Y), num_classes=self.num_classes)
yield batch_X, batch_Y
def test_original(self):
original_y = to_categorical([0] * len(self.original_images_test), num_classes=self.num_classes)
return self.original_images_test, original_y
class CallbackGenerator(keras.callbacks.Callback):
def __init__(self, original_imgs, protect_imgs, original_y, original_protect_y, test_gen):
self.original_imgs = original_imgs
self.protect_imgs = protect_imgs
self.original_y = original_y
self.original_protect_y = original_protect_y
self.test_gen = test_gen
def on_epoch_end(self, epoch, logs=None):
_, original_acc = self.model.evaluate(self.original_imgs, self.original_y, verbose=0)
print("Epoch: {} - Protection Success Rate {:.4f}".format(epoch, 1 - original_acc))
def main():
sess = init_gpu(args.gpu)
ali = aligner(sess)
print("Build attacker's model")
image_paths = glob.glob(os.path.join(args.directory, "*"))
cloak_file_name = "low_cloaked"
original_image_paths = sorted([path for path in image_paths if "cloaked" not in path.split("/")[-1]])
protect_image_paths = sorted([path for path in image_paths if cloak_file_name in path.split("/")[-1]])
original_faces = Faces(original_image_paths, ali, verbose=1, eval_local=True)
original_faces = original_faces.cropped_faces
cloaked_faces = Faces(protect_image_paths, ali, verbose=1, eval_local=True)
cloaked_faces = cloaked_faces.cropped_faces
if len(original_faces) <= 10:
raise Exception("Must have more than 10 protected images to run the evaluation")
num_classes = args.num_other_classes + 1
datagen = DataGenerator(original_faces, cloaked_faces)
original_test_X, original_test_Y = datagen.test_original()
print("{} Training Images | {} Testing Images".format(len(datagen.protect_images_train), len(original_test_X)))
train_generator = datagen.generate()
test_generator = datagen.generate(test=True)
base_model = load_extractor(args.base_model)
model = load_victim_model(teacher_model=base_model, number_classes=num_classes)
cb = CallbackGenerator(original_imgs=original_test_X, protect_imgs=cloaked_faces, original_y=original_test_Y,
original_protect_y=None,
test_gen=test_generator)
model.fit_generator(train_generator, steps_per_epoch=num_classes * 10 // 32,
epochs=args.n_epochs,
verbose=2,
callbacks=[cb]
)
_, acc_original = model.evaluate(original_test_X, original_test_Y, verbose=0)
print("Protection Success Rate: {:.4f}".format(1 - acc_original))
def parse_arguments(argv):
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', type=str,
help='GPU id', default='0')
parser.add_argument('--dataset', type=str,
help='name of dataset', default='scrub')
parser.add_argument('--num_other_classes', type=int,
help='name of dataset', default=1000)
parser.add_argument('--directory', '-d', type=str,
help='name of the cloak result directory', required=True)
parser.add_argument('--base_model', type=str,
help='the feature extractor used for tracker model training. ', default='low_extract')
parser.add_argument('--n_epochs', type=int, default=5)
return parser.parse_args(argv)
if __name__ == '__main__':
args = parse_arguments(sys.argv[1:])
main()

View File

@ -8,6 +8,8 @@ import shutil
import sys
import tarfile
import zipfile
import PIL
import six
from six.moves.urllib.error import HTTPError, URLError
@ -25,7 +27,7 @@ from keras.models import Model
from keras.preprocessing import image
from skimage.transform import resize
from fawkes.align_face import align, aligner
from fawkes.align_face import align
from six.moves.urllib.request import urlopen
if sys.version_info[0] == 2:
@ -62,8 +64,10 @@ def clip_img(X, preprocessing='raw'):
def load_image(path):
img = Image.open(path)
try:
img = Image.open(path)
except PIL.UnidentifiedImageError:
return None
if img._getexif() is not None:
for orientation in ExifTags.TAGS.keys():
@ -87,7 +91,7 @@ def load_image(path):
class Faces(object):
def __init__(self, image_paths, aligner, verbose=1):
def __init__(self, image_paths, aligner, verbose=1, eval_local=False):
model_dir = os.path.join(os.path.expanduser('~'), '.fawkes')
if not os.path.exists(os.path.join(model_dir, "mtcnn.p.gz")):
os.makedirs(model_dir, exist_ok=True)
@ -105,8 +109,17 @@ class Faces(object):
print("Identify {} images".format(len(image_paths)))
for i, p in enumerate(image_paths):
cur_img = load_image(p)
if cur_img is None:
continue
self.org_faces.append(cur_img)
align_img = align(cur_img, self.aligner, margin=0.7)
if eval_local:
margin = 0
else:
margin = 0.7
align_img = align(cur_img, self.aligner, margin=margin)
if align_img is None:
print("Find 0 face(s) in {}".format(p.split("/")[-1]))
continue
@ -120,9 +133,12 @@ class Faces(object):
print("Find {} face(s) in {}".format(len(cur_faces), p.split("/")[-1]))
for img in cur_faces:
long_size = max([img.shape[1], img.shape[0]])
base = np.zeros((long_size, long_size, 3))
base[0:img.shape[0], 0:img.shape[1], :] = img
if eval_local:
base = resize(img, (224, 224))
else:
long_size = max([img.shape[1], img.shape[0]])
base = np.zeros((long_size, long_size, 3))
base[0:img.shape[0], 0:img.shape[1], :] = img
cur_faces_square.append(base)
cur_index = align_img[1]
@ -170,21 +186,6 @@ def dump_dictionary_as_json(dict, outfile):
f.write(j.encode())
def fix_gpu_memory(mem_fraction=1):
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
tf_config = None
if tf.test.is_gpu_available():
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=mem_fraction)
tf_config = tf.ConfigProto(gpu_options=gpu_options)
tf_config.gpu_options.allow_growth = True
tf_config.log_device_placement = False
init_op = tf.global_variables_initializer()
sess = tf.Session(config=tf_config)
sess.run(init_op)
K.set_session(sess)
return sess
def load_victim_model(number_classes, teacher_model=None, end2end=False, dropout=0):
for l in teacher_model.layers:
l.trainable = end2end
@ -212,6 +213,21 @@ def init_gpu(gpu_index, force=False):
return sess
def fix_gpu_memory(mem_fraction=1):
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
tf_config = None
if tf.test.is_gpu_available():
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=mem_fraction)
tf_config = tf.ConfigProto(gpu_options=gpu_options)
tf_config.gpu_options.allow_growth = True
tf_config.log_device_placement = False
init_op = tf.global_variables_initializer()
sess = tf.Session(config=tf_config)
sess.run(init_op)
K.set_session(sess)
return sess
def preprocess(X, method):
assert method in {'raw', 'imagenet', 'inception', 'mnist'}

View File

@ -1,65 +0,0 @@
# Fawkes
Code implementation of the paper "[Fawkes: Protecting Personal Privacy against Unauthorized Deep Learning Models](https://arxiv.org/pdf/2002.08327.pdf)", at *USENIX Security 2020*.
### BEFORE YOU RUN OUR CODE
We appreciate your interest in our work and for trying out our code. We've noticed several cases where incorrect configuration leads to poor performances of protection. If you also observe low detection performance far away from what we presented in the paper, please feel free to open an issue in this repo or contact any of the authors directly. We are more than happy to help you debug your experiment and find out the correct configuration.
### ABOUT
This repository contains code implementation of the paper "[Fawkes: Protecting Personal Privacy against Unauthorized Deep Learning Models](https://arxiv.org/pdf/2002.08327.pdf)", at *USENIX Security 2020*.
### DEPENDENCIES
Our code is implemented and tested on Keras with TensorFlow backend. Following packages are used by our code.
- `keras==2.3.1`
- `numpy==1.18.4`
- `tensorflow-gpu==1.13.1`
Our code is tested on `Python 3.6.8`
### HOWTO
#### Download and Config Datasets
The first step is to download several datasets for protection and target selection.
1. Download the following dataset to your local machine. After downloading the datasets, restructure it the same way as the FaceScrub dataset downloaded.
- FaceScrub -- used for protection evaluation (link)
- VGGFace1 -- used for target select (link)
- VGGFace2 -- used for target select (link)
- WebFace -- used for target select (link)
2. Config datasets
open `fawkes/config.py` and update the `DATASETS` dictionary with the path to each dataset. Then run `python fawkes/config.py`. Every time the datasets are updated or moved, remember to rerun the command with the updated path.
3. Calculate embeddings using feature extractor.
Run `python3 fawkes/prepare_feature_extractor.py --candidate-datasets scrub vggface1 vggface2 webface`. This will calculate and cache the embeddings using the default feature extractor we provide. To use a customized feature extractor, please look at the Advance section at the end.
#### Generate Cloak for Images
To generate cloak, run
`python3 fawkes/protection.py --gpu 0 --dataset scrub --feature-extractor webface_dense_robust_extract`
For more information about the detailed parameters, please read `fawkes/protection.py`.
The code will output a directory in `results/` with `cloak_data.p` inside. You can check the cloaked images or inspect the changes in `this notebook`.
#### Evaluate Cloak Effectiveness
To evaluate the cloak, run `python3 fawkes/eval_cloak.py --gpu 0 --cloak_data PATH-TO-RESULT-DIRECTORY --transfer_model vggface2_inception_extract`.
The code will print out the tracker model accuracy on uncloaked/original test images of the protected user, which should be close to 0.
#### Exisiting Feature extractors
We shared three different feature extractors under feature_extractors/
1. low_extract.h5: trained on WebFace dataset with DenseNet architecture.
2. mid_extract.h5: VGGFace2 dataset with DenseNet architecture. Trained with PGD adversarial training for 5 epochs.
3. high_extract.h5: WebFace dataset with DenseNet architecture. Trained with PGD adversarial training for 20 epochs.
4. high2_extract.h5: VGGFace2 dataset with DenseNet architecture. Trained with PGD adversarial training for 20 epochs.
### Citation
```
@inproceedings{shan2020fawkes,
title={Fawkes: Protecting Personal Privacy against Unauthorized Deep Learning Models},
author={Shan, Shawn and Wenger, Emily and Zhang, Jiayun and Li, Huiying and Zheng, Haitao and Zhao, Ben Y},
booktitle="Proc. of USENIX Security",
year={2020}
}
```

View File

View File

@ -1,420 +0,0 @@
import http.client
import json
import random
import time
import urllib.error
import urllib.parse
import urllib.request
import requests
# Face API Key and Endpoint
f = open('api_key.txt', 'r')
data = f.read().split("\n")
subscription_key = data[0]
uri_base = data[1]
cloak_image_base = 'http://sandlab.cs.uchicago.edu/fawkes/files/cloak/{}_high_cloaked.png'
original_image_base = 'http://sandlab.cs.uchicago.edu/fawkes/files/cloak/{}.png'
def test_cloak():
NUM_TRAIN = 5
total_idx = range(0, 82)
TRAIN_RANGE = random.sample(total_idx, NUM_TRAIN)
TEST_RANGE = random.sample([i for i in total_idx if i not in TRAIN_RANGE], 20)
personGroupId = 'all'
# delete_personGroup(personGroupId)
# create_personGroupId(personGroupId, personGroupId)
with open("protect_personId.txt", 'r') as f:
protect_personId = f.read()
print(protect_personId)
delete_personGroupPerson(personGroupId, protect_personId)
protect_personId = create_personId(personGroupId, 'Emily')
with open("protect_personId.txt", 'w') as f:
f.write(protect_personId)
print("Created protect personId: {}".format(protect_personId))
for idx in TRAIN_RANGE:
image_url = cloak_image_base.format(idx)
r = add_persistedFaceId(personGroupId, protect_personId, image_url)
if r is not None:
print("Added {}".format(idx))
else:
print("Unable to add {}-th image of protect person".format(idx))
# add other people
# for idx_person in range(5000, 15000):
# personId = create_personId(personGroupId, str(idx_person))
# print("Created personId: {}".format(idx_person))
# for idx_image in range(10):
# image_url = "http://sandlab.cs.uchicago.edu/fawkes/files/target_data/{}/{}.jpg".format(
# idx_person, idx_image)
# r = add_persistedFaceId(personGroupId, personId, image_url)
# if r is not None:
# print("Added {}".format(idx_image))
# else:
# print("Unable to add {}-th image".format(idx_image))
# train model based on personGroup
train_personGroup(personGroupId)
while json.loads(get_trainStatus(personGroupId))['status'] != 'succeeded':
time.sleep(2)
# list_personGroupPerson(personGroupId)
# test original image
idx_range = TEST_RANGE
acc = 0.
tot = 0.
for idx in idx_range:
original_image_url = original_image_base.format(idx)
faceId = detect_face(original_image_url)
if faceId is None:
print("{} does not exist".format(idx))
continue
original_faceIds = [faceId]
# verify
res = eval(original_faceIds, personGroupId, protect_personId)
if res:
acc += 1.
tot += 1.
acc /= tot
print(acc) # 1.0
def list_personGroups():
headers = {
'Ocp-Apim-Subscription-Key': subscription_key,
}
params = urllib.parse.urlencode({
})
body = json.dumps({})
conn = http.client.HTTPSConnection(uri_base)
conn.request("GET", "/face/v1.0/persongroups?%s" % params, body, headers)
response = conn.getresponse()
data = response.read()
print(data)
conn.close()
def detect_face(image_url):
r = requests.get(image_url)
if r.status_code != 200:
return None
headers = {
# Request headers
'Content-Type': 'application/json',
'Ocp-Apim-Subscription-Key': subscription_key,
}
params = urllib.parse.urlencode({
# Request parameters
'returnFaceId': 'true',
'returnFaceLandmarks': 'false',
'recognitionModel': 'recognition_01',
'returnRecognitionModel': 'false',
'detectionModel': 'detection_01',
})
body = json.dumps({
'url': image_url
})
conn = http.client.HTTPSConnection(uri_base)
conn.request("POST", "/face/v1.0/detect?%s" % params, body, headers)
response = conn.getresponse()
data = json.loads(response.read())
#
# print(data)
conn.close()
return data[0]["faceId"]
def verify_face(faceId, personGroupId, personId):
# html header
headers = {
'Content-Type': 'application/json',
'Ocp-Apim-Subscription-Key': subscription_key,
}
params = urllib.parse.urlencode({
})
# image URL
body = json.dumps({
"faceId": faceId,
"personId": personId,
"PersonGroupId": personGroupId
})
# Call Face API
conn = http.client.HTTPSConnection(uri_base)
conn.request("POST", "/face/v1.0/verify?%s" % params, body, headers)
response = conn.getresponse()
data = json.loads(response.read())
conn.close()
return data
def create_personGroupId(personGroupId, personGroupName):
headers = {
# Request headers
'Content-Type': 'application/json',
'Ocp-Apim-Subscription-Key': subscription_key,
}
params = urllib.parse.urlencode({
})
body = json.dumps({
"name": personGroupName
})
conn = http.client.HTTPSConnection(uri_base)
conn.request("PUT", "/face/v1.0/persongroups/{}?%s".format(personGroupId) % params, body, headers)
response = conn.getresponse()
data = response.read()
print(data)
conn.close()
def create_personId(personGroupId, personName):
headers = {
# Request headers
'Content-Type': 'application/json',
'Ocp-Apim-Subscription-Key': subscription_key,
}
params = urllib.parse.urlencode({
})
body = json.dumps({
"name": personName
})
conn = http.client.HTTPSConnection(uri_base)
conn.request("POST", "/face/v1.0/persongroups/{}/persons?%s".format(personGroupId) % params, body, headers)
response = conn.getresponse()
data = json.loads(response.read())
# print(data)
conn.close()
return data["personId"]
def add_persistedFaceId(personGroupId, personId, image_url):
r = requests.get(image_url)
if r.status_code != 200:
return None
headers = {
'Content-Type': 'application/json',
'Ocp-Apim-Subscription-Key': subscription_key,
}
params = urllib.parse.urlencode({
'personGroupId': personGroupId,
'personId': personId
})
body = json.dumps({
'url': image_url
})
conn = http.client.HTTPSConnection(uri_base)
conn.request("POST",
"/face/v1.0/persongroups/{}/persons/{}/persistedFaces?%s".format(personGroupId, personId) % params,
body, headers)
response = conn.getresponse()
data = json.loads(response.read())
conn.close()
if "persistedFaceId" not in data:
return None
return data["persistedFaceId"]
def list_personGroupPerson(personGroupId):
headers = {
'Ocp-Apim-Subscription-Key': subscription_key,
}
params = urllib.parse.urlencode({
})
body = json.dumps({})
conn = http.client.HTTPSConnection(uri_base)
conn.request("GET", "/face/v1.0/persongroups/{}/persons?%s".format(personGroupId) % params, body, headers)
response = conn.getresponse()
data = json.loads(response.read())
conn.close()
for person in data:
print(person["personId"], len(person["persistedFaceIds"]))
def get_personGroupPerson(personGroupId, personId):
headers = {
'Ocp-Apim-Subscription-Key': subscription_key,
}
params = urllib.parse.urlencode({
})
body = json.dumps({})
conn = http.client.HTTPSConnection(uri_base)
conn.request("GET", "/face/v1.0/persongroups/{}/persons/{}?%s".format(personGroupId, personId) % params, body,
headers)
response = conn.getresponse()
data = json.loads(response.read())
print(data)
conn.close()
def train_personGroup(personGroupId):
headers = {
'Ocp-Apim-Subscription-Key': subscription_key,
}
params = urllib.parse.urlencode({
})
body = json.dumps({})
conn = http.client.HTTPSConnection(uri_base)
conn.request("POST", "/face/v1.0/persongroups/{}/train?%s".format(personGroupId) % params, body, headers)
response = conn.getresponse()
data = response.read()
print(data)
conn.close()
def eval(original_faceIds, personGroupId, protect_personId):
headers = {
'Content-Type': 'application/json',
'Ocp-Apim-Subscription-Key': subscription_key,
}
params = urllib.parse.urlencode({
})
body = json.dumps({
'faceIds': original_faceIds,
'personGroupId': personGroupId,
'maxNumOfCandidatesReturned': 1
})
conn = http.client.HTTPSConnection(uri_base)
conn.request("POST", "/face/v1.0/identify?%s" % params, body, headers)
response = conn.getresponse()
data = json.loads(response.read())
conn.close()
face = data[0]
print(face)
if len(face["candidates"]) and face["candidates"][0]["personId"] == protect_personId:
return True
else:
return False
def delete_personGroupPerson(personGroupId, personId):
headers = {
'Ocp-Apim-Subscription-Key': subscription_key,
}
params = urllib.parse.urlencode({
})
body = json.dumps({})
conn = http.client.HTTPSConnection(uri_base)
conn.request("DELETE", "/face/v1.0/persongroups/{}/persons/{}?%s".format(personGroupId, personId) % params, body,
headers)
response = conn.getresponse()
data = response.read()
print(data)
conn.close()
def get_trainStatus(personGroupId):
headers = {
'Ocp-Apim-Subscription-Key': subscription_key,
}
params = urllib.parse.urlencode({})
body = json.dumps({})
conn = http.client.HTTPSConnection(uri_base)
conn.request("GET", "/face/v1.0/persongroups/{}/training?%s".format(personGroupId) % params, body, headers)
response = conn.getresponse()
data = response.read()
conn.close()
return data
def delete_personGroup(personGroupId):
headers = {
'Ocp-Apim-Subscription-Key': subscription_key,
}
params = urllib.parse.urlencode({
})
body = json.dumps({})
conn = http.client.HTTPSConnection(uri_base)
conn.request("DELETE", "/face/v1.0/persongroups/{}?%s".format(personGroupId) % params, body, headers)
response = conn.getresponse()
data = response.read()
print(data)
conn.close()
def main():
test_cloak()
# delete_personGroup('cloaking')
# delete_personGroup('cloaking-emily')
# delete_personGroup('pubfig')
# list_personGroups()
# exit()
# personGroupId = 'cloaking'
# create_personGroupId(personGroupId, 'cloaking')
# delete_personGroupPerson(personGroupId, '0ac606cd-24b3-440f-866a-31adf2a1b446')
# add_protect_person(personGroupId, 'Emily')
# protect_personId = create_personId(personGroupId, 'Emily')
# add_sybil_person(personGroupId, 'sybil')
#
# # train model based on personGroup
# train_personGroup(personGroupId)
# get_trainStatus(personGroupId)
# add_other_person(personGroupId)
# list_personGroupPerson(personGroupId)
#
# idx_range = range(72, 82)
# original_faceIds = []
# for idx in idx_range:
# original_image_url = "https://super.cs.uchicago.edu/~shawn/cloaked/{}_o.png".format(idx)
# faceId = detect_face(original_image_url)
# original_faceIds.append(faceId)
#
# # verify
# eval(original_faceIds, personGroupId, protect_personId)
if __name__ == '__main__':
test_cloak()

View File

@ -1,41 +0,0 @@
import glob
import json
import os
DATASETS = {
"pubfig": "../data/pubfig",
"scrub": "/home/shansixioing/fawkes/data/scrub/",
"vggface2": "/mnt/data/sixiongshan/data/vggface2/",
"webface": "/mnt/data/sixiongshan/data/webface/",
"youtubeface": "/mnt/data/sixiongshan/data/youtubeface/keras_flow_data/",
}
def main():
config = {}
for dataset in DATASETS.keys():
path = DATASETS[dataset]
if not os.path.exists(path):
print("Dataset path for {} does not exist, skipped".format(dataset))
continue
train_dir = os.path.join(path, "train")
test_dir = os.path.join(path, "test")
if not os.path.exists(train_dir):
print("Training dataset path for {} does not exist, skipped".format(dataset))
continue
num_classes = len(os.listdir(train_dir))
num_images = len(glob.glob(os.path.join(train_dir, "*/*")))
if num_images == 0 or num_classes == 0 or num_images == num_classes:
raise Exception("Dataset {} is not setup as detailed in README.".format(dataset))
config[dataset] = {"train_dir": train_dir, "test_dir": test_dir, "num_classes": num_classes,
"num_images": num_images}
print("Successfully config {}".format(dataset))
j = json.dumps(config)
model_dir = os.path.join(os.path.expanduser('~'), '.fawkes')
with open(os.path.join(model_dir, "config.json"), "wb") as f:
f.write(j.encode())
if __name__ == '__main__':
main()

View File

@ -1,75 +0,0 @@
import argparse
import glob
import os
import pickle
import random
import sys
import numpy as np
from keras.applications.vgg16 import preprocess_input
from keras.preprocessing import image
sys.path.append("../fawkes")
# from utils import load_extractor
import keras
def load_sample_dir(path, sample=10):
x_ls = []
image_paths = list(os.listdir(path))
random.shuffle(image_paths)
for i, file in enumerate(image_paths):
if i > sample:
break
cur_path = os.path.join(path, file)
im = image.load_img(cur_path, target_size=(224, 224))
im = image.img_to_array(im)
x_ls.append(im)
raw_x = np.array(x_ls)
return preprocess_input(raw_x)
def normalize(x):
return x / np.linalg.norm(x)
def main():
extractor = keras.models.load_model(args.feature_extractor)
path2emb = {}
model_dir = os.path.join(os.path.expanduser('~'), '.fawkes')
for path in glob.glob(os.path.join(model_dir, "target_data/*")):
print(path)
idx = int(path.split("/")[-1])
cur_image_paths = glob.glob(os.path.join(path, "*"))
imgs = np.array([image.img_to_array(image.load_img(p, target_size=(224, 224))) for p in cur_image_paths])
imgs = preprocess_input(imgs)
cur_feature = extractor.predict(imgs)
cur_feature = np.mean(cur_feature, axis=0)
path2emb[idx] = cur_feature
model_path = os.path.join(model_dir, "{}_extract.h5".format(args.feature_extractor_name))
emb_path = os.path.join(model_dir, "{}_emb.p".format(args.feature_extractor_name))
extractor.save(model_path)
pickle.dump(path2emb, open(emb_path, "wb"))
def parse_arguments(argv):
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', type=str,
help='GPU id', default='0')
parser.add_argument('--candidate-datasets', nargs='+',
help='path candidate datasets')
parser.add_argument('--feature-extractor', type=str,
help="path of the feature extractor used for optimization",
default="/home/shansixioing/fawkes/feature_extractors/high2_extract.h5")
parser.add_argument('--feature-extractor-name', type=str,
help="name of the feature extractor used for optimization",
default="high2")
return parser.parse_args(argv)
if __name__ == '__main__':
args = parse_arguments(sys.argv[1:])
main()