import numpy as np
from timeit import default_timer as timer
import urllib.request
from PIL import ImageFile
from PIL import Image
import requests
import os, sys, time, stat, argparse, glob, csv, threading, warnings
import json, io, gc, random, cv2
import tensorflow as tf
from keras_retinanet import models
from keras_retinanet.utils.image import read_image_bgr, preprocess_image, resize_image, compute_resize_scale
import keras
from keras import backend as K
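
# Car-part detection helpers: fetch an image from a URL, run two keras-retinanet
# (ResNet-50) models over it (a 15-class model and a 5-class model), and return
# the detections as a JSON string.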

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# gc.collect()
ImageFile.LOAD_TRUNCATED_IMAGES = True
warnings.filterwarnings("ignore")

# K.clear_session()
# tf.keras.backend.clear_session()
# del model1
# del model2

# Minimum detection score for a box to be reported.
CONFIDENCE_THRESHOLD = 0.5

# Class-id -> class-name maps for the two RetinaNet models.
label_to_name_model1 = {
    0: 'Bumper', 1: 'Door handle', 2: 'Engine Hood', 3: 'Front Head lights',
    4: 'Gas cap', 5: 'Grill', 6: 'Navigation', 7: 'Rear Lights', 8: 'Rear Seats',
    9: 'Shifter', 10: 'Steering Wheel', 11: 'Vents', 12: 'Wheel', 13: 'Windows',
    14: 'Windshield'
}
label_to_name_model2 = {0: 'Exhaust', 1: 'Fog Lights', 2: 'Logo', 3: 'Side Mirror', 4: 'Towing Hitch'}


def get_session_detection():
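    # TF 1.x-style GPU setup: cap per-process GPU memory at 80% and let the
    # allocation grow on demand instead of reserving it all up front.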
    config = tf.ConfigProto()
    config.gpu_options.per_process_gpu_memory_fraction = 0.8
    config.gpu_options.allow_growth = True
    return tf.Session(config=config)

keras.backend.tensorflow_backend.set_session(get_session_detection())

# Load both RetinaNet detection models (ResNet-50 backbone); model1 covers the
# 15-class map above, model2 the 5-class map.
model1 = models.load_model('/var/www/models/model2/Concat_resnet50_Classes15_8716.h5', backbone_name='resnet50')
model2 = models.load_model('/var/www/models/model2/Concat_resnet50_Classes5_8120.h5', backbone_name='resnet50')

def display_object_classes_model1(image, scale, scale_1, shapes, reshapes):
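    """Run model1 on an image that has already been preprocessed and resized,
    map the boxes back to the original image coordinates via `scale`, and return
    the detections as a comma-terminated run of JSON object fragments.
    `scale_1` is accepted only for signature compatibility with the commented-out
    custom resize path used by the model2 wrapper."""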
    points_string = ''
    # image = read_image_bgr(image_name)
    # shapes = str(image.shape)
    # image = preprocess_image(image)
    # image, scale = resize_image(image)
    # reshapes = str(image.shape)

    boxes, scores, labels = model1.predict_on_batch(np.expand_dims(image, axis=0))
    boxes /= scale
    for box, score, label in zip(boxes[0], scores[0], labels[0]):

        # Detections come back sorted by score, so the first one below the
        # threshold means every remaining one is below it as well.
        if score < CONFIDENCE_THRESHOLD:
            break

        b = box.astype(int)
        center_x = int(b[0] + (b[2] - b[0]) / 2)
        center_y = int(b[1] + (b[3] - b[1]) / 2)
        current_string = (
            '{ "X1" : "' + str(b[0]) + '", "Y1" : "' + str(b[1]) +
            '", "X2" : "' + str(b[2]) + '", "Y2" : "' + str(b[3]) +
            '", "CenterX" : "' + str(center_x) + '", "CenterY" : "' + str(center_y) +
            '", "LabelID" : "' + str(label) + '", "Label" : "' + label_to_name_model1[label] +
            '", "Score" : "' + str(score) + '" ,"Shape": "' + shapes +
            '","ReShape": "' + reshapes + '"},'
        )
        # print(current_string)
        points_string += current_string

    # points_string_json = '{"image_name" : "'+images_filename+'","object" : ['+points_string.rstrip(",")+']}'
    return points_string  # json.loads(json.dumps(points_string_json))


def display_object_classes_model2(image, scale, scale_1, shapes, reshapes):
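    """Run model2 on an image that has already been preprocessed and resized,
    map the boxes back to the original image coordinates via `scale`, and return
    the detections as a comma-terminated run of JSON object fragments.
    `scale_1` is only used by the commented-out custom resize path below."""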
    points_string = ''
    # image = read_image_bgr(image_name)
    # shapes = str(image.shape)
    # image = preprocess_image(image)

    ##custom code to match the training data image size
    # rows,cols,_ = image.shape
    # scale_1 = 640/cols
    # image = cv2.resize(image,None, fx=scale_1, fy=scale_1)
    ###

    # image, scale = resize_image(image)
    # reshapes = str(image.shape);
    boxes, scores, labels = model2.predict_on_batch(np.expand_dims(image, axis=0))
    boxes /= scale
    # custom
    # boxes /= scale_1
    ##
    for box, score, label in zip(boxes[0], scores[0], labels[0]):

        # Detections come back sorted by score, so the first one below the
        # threshold means every remaining one is below it as well.
        if score < CONFIDENCE_THRESHOLD:
            break

        b = box.astype(int)
        center_x = int(b[0] + (b[2] - b[0]) / 2)
        center_y = int(b[1] + (b[3] - b[1]) / 2)
        current_string = (
            '{ "X1" : "' + str(b[0]) + '", "Y1" : "' + str(b[1]) +
            '", "X2" : "' + str(b[2]) + '", "Y2" : "' + str(b[3]) +
            '", "CenterX" : "' + str(center_x) + '", "CenterY" : "' + str(center_y) +
            '", "LabelID" : "' + str(label) + '", "Label" : "' + label_to_name_model2[label] +
            '", "Score" : "' + str(score) + '" ,"Shape": "' + shapes +
            '","ReShape": "' + reshapes + '"},'
        )
        # print(current_string)
        points_string += current_string

    # points_string_json = '{"image_name" : "'+images_filename+'","object" : ['+points_string.rstrip(",")+']}'
    return points_string  # json.loads(json.dumps(points_string_json))
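

# The two wrappers above differ only in which model and label map they use. A
# possible consolidation (a sketch only; the helper name `_detect_objects` is
# illustrative and it is not wired into display_object_classes below -- note
# that json.dumps spaces the fields slightly differently than the hand-built
# strings, although the parsed result carries the same keys and values):
def _detect_objects(model, label_to_name, image, scale, shapes, reshapes):
    points_string = ''
    boxes, scores, labels = model.predict_on_batch(np.expand_dims(image, axis=0))
    boxes /= scale  # map boxes back to the original image coordinates
    for box, score, label in zip(boxes[0], scores[0], labels[0]):
        if score < CONFIDENCE_THRESHOLD:
            break
        b = box.astype(int)
        center_x = int(b[0] + (b[2] - b[0]) / 2)
        center_y = int(b[1] + (b[3] - b[1]) / 2)
        points_string += json.dumps({
            "X1": str(b[0]), "Y1": str(b[1]), "X2": str(b[2]), "Y2": str(b[3]),
            "CenterX": str(center_x), "CenterY": str(center_y),
            "LabelID": str(label), "Label": label_to_name[label],
            "Score": str(score), "Shape": shapes, "ReShape": reshapes,
        }) + ','
    return points_string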


def display_object_classes(image_name):
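    """Fetch the image at `image_name` (a URL), run both detection models on it,
    and return the combined detections as a JSON-formatted string."""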
    points_string_json = points_string = model1_string = model2_string = ''
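    # Download the image bytes over HTTP and decode them with OpenCV.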
    response = requests.get(image_name, stream=True).raw
    image_data = np.asarray(bytearray(response.read()), dtype="uint8")
    image = cv2.imdecode(image_data, cv2.IMREAD_UNCHANGED)

    #image = read_image_bgr(image_data)
    shapes = str(image.shape)
    image = preprocess_image(image)

    # Placeholder for the commented-out fixed-width resize path below; passed to
    # the per-model helpers only for signature compatibility.
    scale_1 = ''
    # rows,cols,_ = image.shape
    # scale_1 = 640/cols
    # image = cv2.resize(image,None, fx=scale_1, fy=scale_1)

    image, scale = resize_image(image)
    reshapes = str(image.shape)
    #print(reshapes)
    model1_string = display_object_classes_model1(image, scale, scale_1, shapes, reshapes)
    model2_string = display_object_classes_model2(image, scale, scale_1, shapes, reshapes)
    points_string = model1_string + model2_string
    points_string_json = '{"image_name" : "' + image_name + '","object" : [' + points_string.rstrip(",") + ']}'

    return json.loads(json.dumps(points_string_json))


#prediction = display_object_classes('/var/www/py_jobs/images/single_images/00821549-4245-1BDF-3408-1C4FF2BCD8F9/00821549-4245-1BDF-3408-1C4FF2BCD8F9_00821549-4245-1BDF-3408-1C4FF2BCD8F9.jpg', 'abc')
#print(prediction)

#prediction = display_object_classes('https://gcbimages.storage.googleapis.com/segmentation_images/nonhybrid/E09CB207-4E55-D0F2-CBE1-576F8B56E7BB/E09CB207-4E55-D0F2-CBE1-576F8B56E7BB_revised_anchor_hybrid.jpg', 'abc')
#print(prediction)
# start = timer()
# exterior_json = display_object_classes('/var/www/py_jobs/images/AC66D1E5-3700-4F22-883D-222E830DACA0_f00002.jpg', 'AC66D1E5-3700-4F22-883D-222E830DACA0_f00002.jpg')
# print(exterior_json)
# end = timer()
# print('Time for Loading Data : ' + str(end - start))
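
# Minimal manual-test entry point (a sketch; expects an HTTP(S) image URL on the
# command line and prints the detection JSON plus the elapsed time):
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Run both car-part detectors on an image URL.')
    parser.add_argument('image_url', help='HTTP(S) URL of the image to analyse')
    args = parser.parse_args()

    start = timer()
    print(display_object_classes(args.image_url))
    end = timer()
    print('Time for Detection : ' + str(end - start))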
