def show_dataset(thumb_size, cols, rows, ds):
    """Render the first ``cols * rows`` images of *ds* as a thumbnail grid and display it.

    Each element of *ds* is unpacked as ``(image, target_or_imgid)``; only the
    image is used. Thumbnails are ``thumb_size`` x ``thumb_size`` pixels, laid
    out row-major, separated by a 1-pixel gap. The gap is left black, which is
    the default fill of ``PIL.Image.new``.

    Args:
        thumb_size: Edge length in pixels of each square thumbnail.
        cols: Number of thumbnails per row.
        rows: Number of rows in the grid.
        ds: Iterable of ``(image, target_or_imgid)`` pairs whose images expose
            ``.numpy()``. Presumably an unbatched ``tf.data.Dataset`` of float
            images scaled to [0, 1]; TODO confirm against ``get_dataset``.
    """
    import itertools

    gap = 1  # pixels between neighbouring thumbnails
    step = thumb_size + gap  # distance from one thumbnail's origin to the next
    mosaic = PIL.Image.new(
        mode='RGB',
        size=(thumb_size * cols + (cols - 1) * gap,
              thumb_size * rows + (rows - 1) * gap),
    )
    # Only pull as many elements as there are grid cells. Extra elements would
    # land off-canvas anyway, and an unbounded (e.g. repeated) dataset would
    # otherwise never let this loop finish.
    cells = itertools.islice(ds, cols * rows)
    for idx, (img, _target_or_imgid) in enumerate(cells):
        ix, iy = idx % cols, idx // cols
        # Scale to 8-bit. The clip keeps out-of-range values from wrapping
        # around on the uint8 cast.
        pixels = np.clip(img.numpy() * 255, 0, 255).astype(np.uint8)
        thumb = PIL.Image.fromarray(pixels).resize(
            (thumb_size, thumb_size), resample=PIL.Image.BILINEAR
        )
        mosaic.paste(thumb, (ix * step, iy * step))
    display(mosaic)  # presumably IPython.display.display (notebook global)


# Take exactly one grid's worth of samples (cols * rows = 12 * 5).
ds = get_dataset(files_train, CFG).unbatch().take(12 * 5)
show_dataset(64, 12, 5, ds)