from flask import Flask, redirect, url_for, request, render_template, send_from_directory
import urllib.request
import numpy as np
from PIL import ImageFile
from PIL import Image
import requests
from detect_model_data import display_object_classes
import json, io, gc, random, sys, os, warnings
from flask import Flask, redirect, url_for, request, render_template, send_from_directory
from flask_caching import Cache
from flask_cors import CORS

# Flask application object that serves the detection endpoint below.
api = Flask(__name__)
# NOTE(review): debug mode is forced on at import time, so it applies to every
# way this app is served (including a WSGI server), not only the `__main__`
# dev server — confirm this is intended outside development.
api.debug = True
# Enable CORS on all routes using flask_cors's default options (no arguments
# are passed) — presumably allowing any origin; tighten if the API is private.
CORS(api)

import tensorflow as tf
import keras_retinanet
from keras_retinanet import models
from keras_retinanet.utils.image import read_image_bgr, preprocess_image, resize_image, compute_resize_scale
import keras
from keras import backend as K
from keras.models import load_model

# Hide TensorFlow C++ INFO and WARNING log lines.
# NOTE(review): tensorflow is already imported above; if the native library
# reads this variable at load time, setting it here may not take full effect —
# consider moving it before `import tensorflow`.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# Let Pillow load image files whose data is truncated instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True
# Silence all Python warnings process-wide (deprecation notices included).
warnings.filterwarnings("ignore")
# TF1 graph captured at import time; the request handler re-enters it with
# `graph.as_default()` so work done per request runs against this graph.
graph = tf.get_default_graph()

# Minimum detection score. NOTE(review): not referenced in this file —
# presumably consumed elsewhere; confirm before changing.
CONFIDENCE_THRESHOLD = 0.5
# Class-index -> display-name maps. Presumably label_to_name_model1 pairs with
# model1 (15 classes) and label_to_name_model2 with model2 (5 classes), going
# by the class counts in the model filenames loaded further down.
label_to_name_model1 = {0: 'Bumper', 1: 'Door handle', 2: 'Engine Hood', 3: 'Front Head lights', 4: 'Gas cap', 5: 'Grill', 6: 'Navigation', 7: 'Rear Lights', 8: 'Rear Seats', 9: 'Shifter', 10: 'Steering Wheel', 11: 'Vents', 12: 'Wheel', 13: 'Windows', 14: 'Windshield'}
label_to_name_model2 = {0: 'Exhaust', 1: 'Fog Lights', 2: 'Logo', 3: 'Side Mirror', 4: 'Towing Hitch'}


def get_session_detection(memory_fraction=0.8, allow_growth=True):
    """Create a TensorFlow 1.x session with bounded GPU memory usage.

    Args:
        memory_fraction: upper bound on the fraction of GPU memory this
            process may claim (``per_process_gpu_memory_fraction``).
            Defaults to 0.8, the value that was previously hard-coded.
        allow_growth: when True, GPU memory is allocated as needed rather
            than reserving the whole fraction up front. Defaults to True,
            the value that was previously hard-coded.

    Returns:
        A new ``tf.Session`` built with the above GPU options.
    """
    config = tf.ConfigProto()
    config.gpu_options.per_process_gpu_memory_fraction = memory_fraction
    config.gpu_options.allow_growth = allow_growth
    return tf.Session(config=config)


# Install a GPU-memory-bounded session as Keras's backend session. This is done
# before the models below are loaded — presumably so their weights live in this
# session; keep the order.
keras.backend.tensorflow_backend.set_session(get_session_detection())

# Load two RetinaNet detectors (ResNet-50 backbone) from disk at import time.
# NOTE(review): model1/model2 are not referenced anywhere in this file (the
# /detect handler delegates to detect_model_data.display_object_classes), so
# unless another module imports them these loads only cost memory and startup
# time — confirm before removing.
model1 = models.load_model('/var/www/models/model2/Concat_resnet50_Classes15_8716.h5', backbone_name='resnet50')
model2 = models.load_model('/var/www/models/model2/Concat_resnet50_Classes5_8120.h5', backbone_name='resnet50')


@api.route('/detect')
def render():
    """Run object detection on the image named by the ``filename`` query arg.

    Query parameters:
        filename: image location, forwarded unchanged to
            ``display_object_classes``. NOTE(review): it is treated as a URL
            (``image_url``); if the helper fetches it, this endpoint accepts
            arbitrary user-supplied URLs (possible SSRF) — confirm and
            validate if so.

    Returns:
        200 with the JSON-encoded result of ``display_object_classes`` and an
        ``application/json`` content type, or 400 with a plain-text message
        when ``filename`` is missing or empty.
    """
    image_url = request.args.get('filename')
    if not image_url:
        # Falling through would return None, which Flask reports as a 500
        # server error; this is a client mistake, so say so explicitly.
        return 'Missing required query parameter: filename', 400

    # Run detection inside the module-level graph, presumably so models built
    # in that graph are usable from Flask's request threads (TF1 default
    # graphs are per-thread). No extra tf.Session is created here: one that is
    # never used and never closed would leak on every request.
    with graph.as_default():
        prediction = display_object_classes(image_url)
    return json.dumps(prediction), 200, {'Content-Type': 'application/json'}


if __name__ == '__main__':
    # Development server entry point. PORT from the environment is a string,
    # so convert it explicitly instead of relying on the framework to coerce it.
    # NOTE(review): debug=True on 0.0.0.0 exposes the Werkzeug interactive
    # debugger (guarded only by its PIN) to the network. use_reloader=True
    # re-runs this module in a child process, so the models are loaded in both
    # processes. Confirm both are wanted.
    api.run('0.0.0.0', int(os.environ.get('PORT', 5000)), debug=True, use_reloader=True)
