# Imports
import json
import subprocess
import time

import numpy as np
from openvino.runtime import Core
from openvino.runtime import Dimension

import tokens_bert as tokens

# Download the model
# directory where model will be downloaded (also passed as the download cache)
base_model_dir = "model"

# desired precision
precision = "FP16-INT8"

# model name as named in Open Model Zoo
model_name = "bert-small-uncased-whole-word-masking-squad-int8-0002"

# locations of the network description (.xml) and weights (.bin); built from
# base_model_dir so they stay in sync with the --output_dir passed below
model_path = f"{base_model_dir}/intel/{model_name}/{precision}/{model_name}.xml"
model_weights_path = f"{base_model_dir}/intel/{model_name}/{precision}/{model_name}.bin"

# downloader invocation as an argument list, so no shell parsing is involved
download_args = [
    "omz_downloader",
    "--name", model_name,
    "--precision", precision,
    "--output_dir", base_model_dir,
    "--cache_dir", base_model_dir,
]
# same command as a single string (kept for display and for existing users
# of this name)
download_command = " ".join(download_args)

# Run the downloader with plain Python instead of the IPython-only
# "! $download_command" syntax. Like the shell escape, a failure is reported
# but does not raise here.
try:
    download_exit_code = subprocess.run(download_args, check=False).returncode
except FileNotFoundError:
    download_exit_code = None
    print("Warning: omz_downloader executable was not found")
else:
    if download_exit_code != 0:
        print(f"Warning: omz_downloader exited with code {download_exit_code}")

# Load the model for Entity Extraction with Dynamic Shape
# initialize inference engine (OpenVINO runtime Core object)
ie_core = Core()
# read the network and corresponding weights from file
model = ie_core.read_model(model=model_path, weights=model_weights_path)

# make dimension 1 of every model input dynamic, described by Dimension(1, 384)
# (presumably a sequence length between 1 and 384 tokens - TODO confirm);
# the model is reshaped once per input
for input_layer in model.inputs:
    input_shape = input_layer.partial_shape
    input_shape[1] = Dimension(1, 384)
    model.reshape({input_layer: input_shape})

# compile the model for the CPU
compiled_model = ie_core.compile_model(model=model, device_name="CPU")

# keep the compiled model's inputs; prepare_input() reads their names
# through the `any_name` attribute
input_keys = list(compiled_model.inputs)

# Processing
# location of the vocabulary file
vocab_file_path = "data/vocab.txt"

# vocabulary loaded through the tokens helper module; indexed by token text
vocab = tokens.load_vocab_file(vocab_file_path)

# vocabulary entries of the special classification / separator tokens
cls_token, sep_token = vocab["[CLS]"], vocab["[SEP]"]

# answers scoring below this value are dropped from the extraction
confidence_threshold = 0.4

# Preprocessing
# build the dictionary of input arrays for one entity/context pair
def prepare_input(entity_tokens, context_tokens):
    """Build the network inputs for one entity/context pair.

    The token sequence is laid out as
    ``[CLS] entity tokens [SEP] context tokens [SEP]``.

    Args:
        entity_tokens: token ids of the entity query.
        context_tokens: token ids of the context text.

    Returns:
        dict mapping input names ("input_ids", "attention_mask",
        "token_type_ids" and, when the model declares it, "position_ids")
        to int32 arrays of shape (1, sequence length).
    """
    input_ids = [cls_token, *entity_tokens, sep_token,
                 *context_tokens, sep_token]
    # 1 for any index
    attention_mask = [1] * len(input_ids)
    # 0 for [CLS] + entity tokens + first [SEP], 1 for context + last [SEP]
    token_type_ids = [0] * (len(entity_tokens) + 2) + \
        [1] * (len(context_tokens) + 1)

    # create input to feed the model; each list is wrapped once more so the
    # resulting arrays carry a leading dimension of size 1
    input_dict = {
        "input_ids": np.array([input_ids], dtype=np.int32),
        "attention_mask": np.array([attention_mask], dtype=np.int32),
        "token_type_ids": np.array([token_type_ids], dtype=np.int32),
    }

    # some models require additional position_ids
    if any(i_key.any_name == "position_ids" for i_key in input_keys):
        position_ids = np.arange(len(input_ids))
        input_dict["position_ids"] = np.array([position_ids], dtype=np.int32)

    return input_dict

# Postprocessing
def postprocess(output_start, output_end, entity_tokens,
                context_tokens_start_end, input_size):
    """Turn the raw start/end outputs into the best answer span.

    Args:
        output_start: raw network output (logits) for the answer start.
        output_end: raw network output (logits) for the answer end.
        entity_tokens: token ids of the entity query; only its length is
            used, to locate the first context token.
        context_tokens_start_end: per context token, a pair whose item 0 and
            item 1 are used as the start and end of the answer in the
            context text (presumably character offsets - TODO confirm).
        input_size: total number of tokens fed to the network.

    Returns:
        Tuple ``(max_score, max_start, max_end)``: the score of the best
        span and its start/end taken from ``context_tokens_start_end``.
    """

    def get_score(logits):
        # softmax over the last axis; subtracting the maximum first keeps
        # np.exp from overflowing and does not change the result
        logits = np.asarray(logits)
        out = np.exp(logits - logits.max(axis=-1, keepdims=True))
        return out / out.sum(axis=-1, keepdims=True)

    # get start-end scores for context
    score_start = get_score(output_start)
    score_end = get_score(output_end)

    # index of first context token in tensor ([CLS] + entity + [SEP] precede)
    context_start_idx = len(entity_tokens) + 2
    # index of last+1 context token in tensor (the final [SEP] is excluded)
    context_end_idx = input_size - 1

    # find product of all start-end combinations to find the best one;
    # the returned indices are relative to the first context token
    max_score, max_start, max_end = find_best_entity_window(
        start_score=score_start, end_score=score_end,
        context_start_idx=context_start_idx, context_end_idx=context_end_idx
    )

    # convert to context text start-end index
    max_start = context_tokens_start_end[max_start][0]
    max_end = context_tokens_start_end[max_end][1]

    return max_score, max_start, max_end

def find_best_entity_window(start_score, end_score,
                            context_start_idx, context_end_idx):
    """Find the best-scoring (start, end) token pair inside the context.

    Entry ``(i, j)`` of the score matrix is
    ``start_score[context_start_idx + i] * end_score[context_start_idx + j]``.

    Args:
        start_score: 1-D array of start scores for every input token.
        end_score: 1-D array of end scores for every input token.
        context_start_idx: index of the first context token.
        context_end_idx: index one past the last context token.

    Returns:
        Tuple ``(max_score, max_s, max_e)`` where ``max_s`` and ``max_e``
        are indices relative to ``context_start_idx``.
    """
    context_len = context_end_idx - context_start_idx
    # column of start scores times row of end scores: all pair products
    score_mat = np.matmul(
        start_score[context_start_idx:context_end_idx].reshape(
            (context_len, 1)),
        end_score[context_start_idx:context_end_idx].reshape(
            (1, context_len)),
    )
    # reset candidates with end before start
    score_mat = np.triu(score_mat)
    # reset long candidates (end more than 16 tokens after start)
    score_mat = np.tril(score_mat, 16)
    # find the best start-end pair: flat argmax converted back to (row, col)
    max_s, max_e = divmod(score_mat.flatten().argmax(), score_mat.shape[1])
    max_score = score_mat[max_s, max_e]

    return max_score, max_s, max_e

def get_best_entity(entity, context, vocab):
    """Run the model for one entity query and return the best answer.

    Args:
        entity: entity query text.
        context: text in which the answer is searched.
        vocab: vocabulary passed on to ``tokens.text_to_tokens``.

    Returns:
        Tuple ``(answer, score)``: the slice of the original ``context``
        selected by the model and the score of that span.
    """
    # convert context string to tokens; the second value is handed to
    # postprocess() to map token indices back to positions in the text
    context_tokens, context_tokens_end = tokens.text_to_tokens(
        text=context.lower(), vocab=vocab)
    # convert entity string to tokens
    entity_tokens, _ = tokens.text_to_tokens(text=entity.lower(), vocab=vocab)

    network_input = prepare_input(entity_tokens, context_tokens)
    # [CLS] + entity + [SEP] + context + [SEP]
    input_size = len(context_tokens) + len(entity_tokens) + 3

    # openvino inference
    output_start_key = compiled_model.output("output_s")
    output_end_key = compiled_model.output("output_e")
    result = compiled_model(network_input)

    # postprocess the result getting the score and context range for the
    # answer; [0] selects the first item along the leading dimension
    score_start_end = postprocess(output_start=result[output_start_key][0],
                                  output_end=result[output_end_key][0],
                                  entity_tokens=entity_tokens,
                                  context_tokens_start_end=context_tokens_end,
                                  input_size=input_size)

    # return the part of the context, which is already an answer
    return context[score_start_end[1]:score_start_end[2]], score_start_end[0]

# Set the Entity Recognition Template
# entity types that are looked up, in the order they are queried
template = [
    "building",
    "company",
    "persons",
    "city",
    "state",
    "height",
    "floor",
    "address",
]

def run_analyze_entities(context):
    """Extract every entity type of ``template`` from ``context`` and print it.

    For each template field the question ``"<field>?"`` is sent through
    ``get_best_entity``; answers whose score reaches
    ``confidence_threshold`` are collected. The result is printed as JSON
    together with the elapsed processing time.

    Args:
        context: text to analyze.

    Returns:
        None. An empty or too long context only prints an error message.
    """
    print(f"Context: {context}\n", flush=True)

    if not context:
        print("Error: Empty context or outside paragraphs")
        # nothing to analyze; stop instead of running the model
        return

    # NOTE(review): for a str this compares the number of characters,
    # although the message talks about words - confirm the intended unit
    if len(context) > 380:
        print("Error: The context is too long for this particular model. "
              "Try with context shorter than 380 words.")
        return

    # measure processing time
    start_time = time.perf_counter()
    extract = []
    for field in template:
        entity_to_find = field + "?"
        entity, score = get_best_entity(entity=entity_to_find,
                                        context=context,
                                        vocab=vocab)
        if score >= confidence_threshold:
            extract.append({"Entity": entity, "Type": field,
                            "Score": f"{score:.2f}"})
    end_time = time.perf_counter()
    res = {"Extraction": extract, "Time": f"{end_time - start_time:.2f}s"}
    print("\nJSON Output:")
    print(json.dumps(res, sort_keys=False, indent=4))

# Run on Simple Text
# Each sample below is bound to the same name, so a later assignment
# replaces the earlier one.

# Sample 1
source_text = (
    "Intel Corporation is an American multinational and technology"
    " company headquartered in Santa Clara, California."
)

# Sample 2
source_text = (
    "Intel was founded in Mountain View, California, "
    "in 1968 by Gordon E. Moore, a chemist, and Robert Noyce, "
    "a physicist and co-inventor of the integrated circuit."
)

# Sample 3
source_text = (
    "The Robert Noyce Building in Santa Clara, California, "
    "is the headquarters for Intel Corporation. It was constructed in 1992 "
    "and is located at 2200 Mission College Boulevard - 95054. It has an "
    "estimated height of 22.20 meters and 6 floors above ground."
)