mirror of
https://github.com/Shawn-Shan/fawkes.git
synced 2026-06-12 21:50:46 +05:30
local evaluation code
Former-commit-id: 7f9e32ef86adc18f33704b8824367a610df80e69 [formerly 30a0d4935de4da1c762819eab2bfc2e5411ff6a1] Former-commit-id: 157928d16012db8f7470abb7a297c9740860b5b7
This commit is contained in:
+39
-23
@@ -8,6 +8,8 @@ import shutil
|
||||
import sys
|
||||
import tarfile
|
||||
import zipfile
|
||||
|
||||
import PIL
|
||||
import six
|
||||
from six.moves.urllib.error import HTTPError, URLError
|
||||
|
||||
@@ -25,7 +27,7 @@ from keras.models import Model
|
||||
from keras.preprocessing import image
|
||||
from skimage.transform import resize
|
||||
|
||||
from fawkes.align_face import align, aligner
|
||||
from fawkes.align_face import align
|
||||
from six.moves.urllib.request import urlopen
|
||||
|
||||
if sys.version_info[0] == 2:
|
||||
@@ -62,8 +64,10 @@ def clip_img(X, preprocessing='raw'):
|
||||
|
||||
|
||||
def load_image(path):
|
||||
|
||||
img = Image.open(path)
|
||||
try:
|
||||
img = Image.open(path)
|
||||
except PIL.UnidentifiedImageError:
|
||||
return None
|
||||
|
||||
if img._getexif() is not None:
|
||||
for orientation in ExifTags.TAGS.keys():
|
||||
@@ -87,7 +91,7 @@ def load_image(path):
|
||||
|
||||
|
||||
class Faces(object):
|
||||
def __init__(self, image_paths, aligner, verbose=1):
|
||||
def __init__(self, image_paths, aligner, verbose=1, eval_local=False):
|
||||
model_dir = os.path.join(os.path.expanduser('~'), '.fawkes')
|
||||
if not os.path.exists(os.path.join(model_dir, "mtcnn.p.gz")):
|
||||
os.makedirs(model_dir, exist_ok=True)
|
||||
@@ -105,8 +109,17 @@ class Faces(object):
|
||||
print("Identify {} images".format(len(image_paths)))
|
||||
for i, p in enumerate(image_paths):
|
||||
cur_img = load_image(p)
|
||||
if cur_img is None:
|
||||
continue
|
||||
|
||||
self.org_faces.append(cur_img)
|
||||
align_img = align(cur_img, self.aligner, margin=0.7)
|
||||
|
||||
if eval_local:
|
||||
margin = 0
|
||||
else:
|
||||
margin = 0.7
|
||||
align_img = align(cur_img, self.aligner, margin=margin)
|
||||
|
||||
if align_img is None:
|
||||
print("Find 0 face(s) in {}".format(p.split("/")[-1]))
|
||||
continue
|
||||
@@ -120,9 +133,12 @@ class Faces(object):
|
||||
print("Find {} face(s) in {}".format(len(cur_faces), p.split("/")[-1]))
|
||||
|
||||
for img in cur_faces:
|
||||
long_size = max([img.shape[1], img.shape[0]])
|
||||
base = np.zeros((long_size, long_size, 3))
|
||||
base[0:img.shape[0], 0:img.shape[1], :] = img
|
||||
if eval_local:
|
||||
base = resize(img, (224, 224))
|
||||
else:
|
||||
long_size = max([img.shape[1], img.shape[0]])
|
||||
base = np.zeros((long_size, long_size, 3))
|
||||
base[0:img.shape[0], 0:img.shape[1], :] = img
|
||||
cur_faces_square.append(base)
|
||||
|
||||
cur_index = align_img[1]
|
||||
@@ -170,21 +186,6 @@ def dump_dictionary_as_json(dict, outfile):
|
||||
f.write(j.encode())
|
||||
|
||||
|
||||
def fix_gpu_memory(mem_fraction=1):
|
||||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
|
||||
tf_config = None
|
||||
if tf.test.is_gpu_available():
|
||||
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=mem_fraction)
|
||||
tf_config = tf.ConfigProto(gpu_options=gpu_options)
|
||||
tf_config.gpu_options.allow_growth = True
|
||||
tf_config.log_device_placement = False
|
||||
init_op = tf.global_variables_initializer()
|
||||
sess = tf.Session(config=tf_config)
|
||||
sess.run(init_op)
|
||||
K.set_session(sess)
|
||||
return sess
|
||||
|
||||
|
||||
def load_victim_model(number_classes, teacher_model=None, end2end=False, dropout=0):
|
||||
for l in teacher_model.layers:
|
||||
l.trainable = end2end
|
||||
@@ -212,6 +213,21 @@ def init_gpu(gpu_index, force=False):
|
||||
return sess
|
||||
|
||||
|
||||
def fix_gpu_memory(mem_fraction=1):
|
||||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
|
||||
tf_config = None
|
||||
if tf.test.is_gpu_available():
|
||||
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=mem_fraction)
|
||||
tf_config = tf.ConfigProto(gpu_options=gpu_options)
|
||||
tf_config.gpu_options.allow_growth = True
|
||||
tf_config.log_device_placement = False
|
||||
init_op = tf.global_variables_initializer()
|
||||
sess = tf.Session(config=tf_config)
|
||||
sess.run(init_op)
|
||||
K.set_session(sess)
|
||||
return sess
|
||||
|
||||
|
||||
def preprocess(X, method):
|
||||
assert method in {'raw', 'imagenet', 'inception', 'mnist'}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user