def show_dataset(thumb_size, cols, rows, ds):
    """Display the first ``cols * rows`` images of ``ds`` as a thumbnail mosaic.

    Thumbnails are placed row by row, left to right, with a 1-pixel gap
    between neighbours where the canvas background shows through.

    Args:
        thumb_size: Edge length, in pixels, of each square thumbnail.
        cols: Number of thumbnails per mosaic row.
        rows: Number of mosaic rows.
        ds: Iterable yielding ``(img, target_or_imgid)`` pairs. ``img`` must
            expose ``.numpy()``; it is presumably a float HxWx3 image scaled
            to [0, 1], since values are multiplied by 255 and clipped --
            TODO confirm against the dataset pipeline.

    Returns:
        None. The finished mosaic is shown through ``display``.
    """
    mosaic = PIL.Image.new(
        mode='RGB',
        size=(thumb_size * cols + (cols - 1),
              thumb_size * rows + (rows - 1)),
    )

    # Cap iteration at the grid capacity. Items beyond it would land outside
    # the canvas anyway, and an unbounded (e.g. repeated) dataset would
    # otherwise never finish. ``range`` comes first so ``zip`` stops without
    # pulling an extra element from ``ds``.
    for idx, (img, _target_or_imgid) in zip(range(cols * rows), ds):
        iy, ix = divmod(idx, cols)

        # Scale to 8-bit and clamp so out-of-range values do not wrap around.
        pixels = np.clip(img.numpy() * 255, 0, 255).astype(np.uint8)
        thumb = PIL.Image.fromarray(pixels).resize(
            (thumb_size, thumb_size), resample=PIL.Image.BILINEAR)

        # The extra ``+ ix`` / ``+ iy`` offsets produce the 1-pixel separators.
        mosaic.paste(thumb, (ix * thumb_size + ix,
                             iy * thumb_size + iy))

    display(mosaic)
# Preview: build the dataset from files_train, split its batches into single
# elements, keep the first 12 * 5 = 60 of them, and render them as a
# 12-column x 5-row mosaic of 64-pixel thumbnails.
ds = get_dataset(files_train, CFG)
ds = ds.unbatch()
ds = ds.take(12 * 5)
show_dataset(thumb_size=64, cols=12, rows=5, ds=ds)