from layoutlm import FunsdDataset, LayoutlmConfig, LayoutlmForTokenClassification
from transformers import BertTokenizer
from torch.nn import CrossEntropyLoss
import torch

MODEL_CLASSES = {
    "layoutlm": (LayoutlmConfig, LayoutlmForTokenClassification, BertTokenizer),
}


def main(args):
    if torch.cuda.is_available():
        device = torch.device("cuda")
        print("GPU is available")
    else:
        device = torch.device("cpu")
        print("GPU is not available, using CPU instead")
    args.device = device

    config_class, model_class, tokenizer_class = MODEL_CLASSES["layoutlm"]

    # get_labels, train and evaluate are helper functions defined elsewhere in the training script.
    # In our case the labels will be x-axis, y-axis and title.
    labels = get_labels(args.labels)
    num_labels = len(labels)
    # Use the cross-entropy ignore index as the padding label id so that
    # only real label ids contribute to the loss later.
    pad_token_label_id = CrossEntropyLoss().ignore_index

    config = config_class.from_pretrained(
        "layoutlm-base-uncased/",
        num_labels=num_labels,
        force_download=True,
        ignore_mismatched_sizes=True,
        cache_dir=args.cache_dir if args.cache_dir else None,
    )
    tokenizer = tokenizer_class.from_pretrained(
        "microsoft/layoutlm-base-uncased",
        do_lower_case=True,
        force_download=True,
        ignore_mismatched_sizes=True,
        cache_dir=args.cache_dir if args.cache_dir else None,
    )
    model = model_class.from_pretrained(
        "layoutlm-base-uncased/",
        config=config,
    )
    model.to(args.device)

    # Fine-tune on the FUNSD-style training split.
    train_dataset = FunsdDataset(
        args, tokenizer, labels, pad_token_label_id, mode="train"
    )
    global_step, tr_loss = train(
        args, train_dataset, model, tokenizer, labels, pad_token_label_id
    )

    # Reload the tokenizer and the fine-tuned model, then evaluate on the test split.
    tokenizer = tokenizer_class.from_pretrained(
        "microsoft/layoutlm-base-uncased",
        do_lower_case=args.do_lower_case,
        force_download=True,
        ignore_mismatched_sizes=True,
    )
    model = model_class.from_pretrained(args.output_dir)
    model.to(args.device)
    result, predictions = evaluate(
        args, model, tokenizer, labels, pad_token_label_id, mode="test"
    )
    return result, predictions
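The script assumes an args namespace carrying the run configuration (data paths, label file, output directory, and so on) that FunsdDataset, train and evaluate read from. As a minimal sketch of how main could be invoked, with entirely hypothetical field names and values (the real helpers may require additional fields):

from argparse import Namespace

# Hypothetical run configuration; every path and value below is a placeholder.
args = Namespace(
    data_dir="data/",            # preprocessed FUNSD-style annotations
    labels="data/labels.txt",    # file listing the label names (x-axis, y-axis, title, ...)
    output_dir="output/",        # where train() saves the fine-tuned checkpoint
    cache_dir="",                # optional cache directory for downloaded weights
    do_lower_case=True,
    max_seq_length=512,
    per_gpu_train_batch_size=8,
    num_train_epochs=5,
    learning_rate=5e-5,
)

result, predictions = main(args)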