# Imports
from collections import namedtuple
from itertools import groupby
from pathlib import Path
import cv2
import matplotlib.pyplot as plt
import numpy as np
from openvino.runtime import Core
# Settings
# Directories where data will be placed
model_folder = "model"
data_folder = "data"
charlist_folder = f"{data_folder}/charlists"
# Precision used by model
precision = "FP16"
Language = namedtuple(
typename="Language", field_names=["model_name", "charlist_name", "demo_image_name"]
)
chinese_files = Language(
model_name="handwritten-simplified-chinese-recognition-0001",
charlist_name="chinese_charlist.txt",
demo_image_name="handwritten_chinese_test.jpg",
)
japanese_files = Language(
model_name="handwritten-japanese-recognition-0001",
charlist_name="japanese_charlist.txt",
demo_image_name="handwritten_japanese_test.png",
)
# Select Language
# Select language by using either language='chinese' or language='japanese'
language = "chinese"
languages = {"chinese": chinese_files, "japanese": japanese_files}
selected_language = languages.get(language)
# Download Model
path_to_model_weights = Path(f'{model_folder}/intel/{selected_language.model_name}/{precision}/{selected_language.model_name}.bin')
if not path_to_model_weights.is_file():
download_command = f'omz_downloader --name {selected_language.model_name} --output_dir {model_folder} --precision {precision}'
print(download_command)
! $download_command
# Load Network and Execute
ie = Core()
path_to_model = path_to_model_weights.with_suffix(".xml")
model = ie.read_model(model=path_to_model)
# Select Device Name
# To check available device names run the line below
# print(ie.available_devices)
compiled_model = ie.compile_model(model=model, device_name="CPU")
# Fetch Information About Input and Output Layers
recognition_output_layer = compiled_model.output(0)
recognition_input_layer = compiled_model.input(0)
# Load an Image
# Read file name of demo file based on the selected model
file_name = selected_language.demo_image_name
# Text detection models expects an image in grayscale format
# IMPORTANT!!! This model allows to read only one line at time
# Read image
image = cv2.imread(filename=f"{data_folder}/{file_name}", flags=cv2.IMREAD_GRAYSCALE)
# Fetch shape
image_height, _ = image.shape
# B,C,H,W = batch size, number of channels, height, width
_, _, H, W = recognition_input_layer.shape
# Calculate scale ratio between input shape height and image height to resize image
scale_ratio = H / image_height
# Resize image to expected input sizes
resized_image = cv2.resize(
image, None, fx=scale_ratio, fy=scale_ratio, interpolation=cv2.INTER_AREA
)
# Pad image to match input size, without changing aspect ratio
resized_image = np.pad(
resized_image, ((0, 0), (0, W - resized_image.shape[1])), mode="edge"
)
# Reshape to network the input shape
input_image = resized_image[None, None, :, :]
# Visualise Input Image
plt.figure(figsize=(20, 1))
plt.axis("off")
plt.imshow(resized_image, cmap="gray", vmin=0, vmax=255);
# Prepare Charlist
# Get dictionary to encode output, based on model documentation
used_charlist = selected_language.charlist_name
# With both models, there should be blank symbol added at index 0 of each charlist
blank_char = "~"
with open(f"{charlist_folder}/{used_charlist}", "r", encoding="utf-8") as charlist:
letters = blank_char + "".join(line.strip() for line in charlist)
# Run Inference
# Run inference on the model
predictions = compiled_model([input_image])[recognition_output_layer]
# Process Output Data
# Remove batch dimension
predictions = np.squeeze(predictions)
# Run argmax to pick the symbols with the highest probability
predictions_indexes = np.argmax(predictions, axis=1)
# Use groupby to remove concurrent letters, as required by CTC greedy decoding
output_text_indexes = list(groupby(predictions_indexes))
# Remove grouper objects
output_text_indexes, _ = np.transpose(output_text_indexes, (1, 0))
# Remove blank symbols
output_text_indexes = output_text_indexes[output_text_indexes != 0]
# Assign letters to indexes from output array
output_text = [letters[letter_index] for letter_index in output_text_indexes]
# Print Output
plt.figure(figsize=(20, 1))
plt.axis("off")
plt.imshow(resized_image, cmap="gray", vmin=0, vmax=255)
print("".join(output_text))