Approach for Easy Visual Comparison between ground-truth and predicted classes

Although classification metrics are good for summarizing a model’s performance on a dataset, they disconnect the user from the data itself. Similarly, a confusion matrix might tell us that performance is suffering because of false positives, but it obscures information about what patterns may have caused those misclassifications and what types of false positives there might be. 

One way to gain interpretability is to group sampled images by their classification outcome (true negative, false negative, false positive, true positive) and display them in a PowerPoint file for easy review. These visual groupings make it easy to identify patterns in misclassified data that can be exploited to improve performance (e.g., hard negative mining, or image-analysis-based filtering).

This blog post describes and demonstrates a workflow that automatically produces such a PowerPoint slide deck for review, as shown below:

The complete workflow consists of two parts:

  1. Perform image prediction
  2. PowerPoint slide deck generation

For some users, Part 1 of the workflow might not be necessary. This is the case if you already have a .pytable file containing the following datasets (pytable terminology for named stored arrays): “labels”, “filenames”, “imgs”, and “predictions”. If not, Part 1 provides a script that generates the “predictions” dataset from the other three. Users who have completed the “Digital Pathology Classification Using PyTorch + DenseNet” blog post will find that both parts work seamlessly with their existing workflow.

Part 1: Perform Image Prediction

The generate_densenet_predictions.py script creates a “predictions” dataset in the provided .pytable file. This dataset contains the classification result produced by the user-provided model for each image patch.

Here is a template for running the script at the command line:

python3 generate_densenet_predictions.py [pytable path] [model checkpoint path] --patch_size [patch size] --gpuid [gpu id] --batch_size [batch size]
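
For example, a hypothetical invocation (the file names below are placeholders, not files produced by this workflow) could look like:

python3 generate_densenet_predictions.py ./lymphoma.pytable ./densenet_best_model.pth --patch_size 224 --gpuid 0 --batch_size 32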

Script Requirements

  1. A .pytable file with the following keys (a quick inspection sketch follows this list): 
    • “labels”: stores an array of integers in the range [0, #classes – 1]
    • “filenames”: stores an array of strings corresponding to image filepaths
    • “imgs”: stores an array of images, each with dimensions (height, width, n_channels)
  2. A PyTorch model checkpoint file.
    • To use this script with a model other than DenseNet, pay attention to the model initialization step in the code walkthrough below.
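
To quickly confirm that a .pytable file contains the required datasets, here is a minimal inspection sketch (the file path is a placeholder):

import tables

# open the pytable read-only and print its object tree
with tables.open_file('my_data.pytable', 'r') as f:  # placeholder path
    print(f)  # lists every dataset (node) stored in the file
    # confirm the three required datasets are present
    for required in ('labels', 'filenames', 'imgs'):
        assert required in f.root, f'missing dataset: {required}'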

Code Walkthrough

This script is similar to the train_densenet_albumentations.py script in the aforementioned DenseNet blog post.

We begin by defining a Dataset class. This class is the interface between the user’s data storage and PyTorch’s DataLoader class. We will initialize a Dataset object later; for now, note that the __init__() method takes a .pytable file path and an optional image transform object.

Note that the Dataset class defined below is different from the pytable dataset that we introduced earlier, and for that reason we format the two differently.

class Dataset(object):
    def __init__(self, fname, img_transform=None):
        # nothing special here, just internalizing the constructor parameters
        self.fname = fname
        self.img_transform = img_transform

        with tables.open_file(self.fname, 'r') as db:
            self.nitems = db.root.imgs.shape[0]

        self.imgs = None
        self.labels = None

    def __getitem__(self, index):
        # opening should be done in __init__, but there seems to be an issue
        # with multithreading, so it is done here instead. it needs to be done
        # every time, otherwise hdf5 crashes
        with tables.open_file(self.fname, 'r') as db:
            self.imgs = db.root.imgs
            self.labels = db.root.labels
            self.fnames = db.root.filenames

            # get the requested image and label from the pytable
            img = self.imgs[index, :, :, :]
            label = self.labels[index]

        img_new = img

        if self.img_transform:
            img_new = self.img_transform(image=img)['image']

        return img_new, label, img

    def __len__(self):
        return self.nitems

Next, we check the .pytable file for an existing “predictions” dataset. If one already exists, the script pauses and warns the user that it will be overwritten.

with tables.open_file(args.pytable_path, 'r') as f:
    if 'predictions' in f.root:
        input(f'{args.pytable_path} already contains a "predictions" dataset.\nPress [ENTER] to overwrite the existing dataset, or ctrl+C to safely kill this script.')

Then we set the device to the specified GPU, if available:

# set device: use the specified GPU if available, otherwise fall back to the CPU
if torch.cuda.is_available():
    print(torch.cuda.get_device_properties(args.gpuid))
    torch.cuda.set_device(args.gpuid)
device = torch.device(f'cuda:{args.gpuid}' if torch.cuda.is_available() else 'cpu')

Next, we compose an image transform object. Only two transforms are applied: CenterCrop, which crops each image to patch_size × patch_size (224×224 in the DenseNet blog post) to match the input size the model was trained on, and ToTensor, which converts the image to a PyTorch tensor. The training-time augmentations are left commented out, since inference should be deterministic.

# compose image transforms
img_transform = Compose([
    # VerticalFlip(p=.5),
    # HorizontalFlip(p=.5),
    # HueSaturationValue(hue_shift_limit=(-25,0), sat_shift_limit=0, val_shift_limit=0, p=1),
    # Rotate(p=1, border_mode=cv2.BORDER_CONSTANT, value=0),
    # RandomSizedCrop((args.patch_size, args.patch_size), args.patch_size, args.patch_size),
    CenterCrop(args.patch_size, args.patch_size),
    ToTensor()
])

Then we initialize a Dataset and pass it into a PyTorch DataLoader. Note that the DataLoader takes args.batch_size as an argument. A larger batch size speeds up inference, but its upper bound depends on the memory available on your machine. The default batch size in the code is 32.

# initialize dataset and dataloader
dset = Dataset(args.pytable_path, img_transform)
dloader = DataLoader(dset, batch_size=args.batch_size,
                     shuffle=False, num_workers=8, pin_memory=True)

Now we have reached the model initialization step. Here we load the model checkpoint and pass its hyperparameters into a DenseNet. If you are using this script with a different kind of PyTorch model, import your model and set model = YourModel(…); a sketch of this substitution follows the block below.

# initialize DenseNet model
checkpoint = torch.load(args.model_checkpoint)
model = DenseNet(growth_rate=checkpoint['growth_rate'],
                 block_config=checkpoint['block_config'],
                 num_init_features=checkpoint['num_init_features'],
                 bn_size=checkpoint['bn_size'],
                 drop_rate=checkpoint['drop_rate'],
                 num_classes=checkpoint['num_classes']).to(device)
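
If your checkpoint belongs to a different architecture, this is the only block that needs to change. As a minimal sketch (assuming your checkpoint stores a num_classes entry and a compatible weights dictionary), a torchvision ResNet could be substituted as follows:

# hypothetical substitution: a torchvision ResNet-18 instead of DenseNet
from torchvision.models import resnet18

model = resnet18(num_classes=checkpoint['num_classes']).to(device)  # 'num_classes' assumed to be stored in the checkpoint

The remaining steps (loading the weights, switching to evaluation mode, and running inference) stay the same.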

We finish the model initialization by loading the weights dictionary, setting the model to evaluation mode, and sending it to the preset device:

model.load_state_dict(checkpoint["model_dict"])
model.eval()
model.to(device)
print(model)
print(f"total params: \t{sum([np.prod(p.size()) for p in model.parameters()])}")

We are ready to generate an array of predictions! We pass each image batch through the model, take the argmax over the class scores, and extend the predictions list with the resulting batch of predicted labels.

# generate predictions
dtype = tables.UInt8Atom()  # atom type for the "predictions" carray created below
predictions = []
for img, _, _ in tqdm(dloader):
    img = img.to(device)
    pred = model(img)                          # forward pass: (batch_size, num_classes) scores
    p = pred.detach().cpu().numpy()
    predflat = np.argmax(p, axis=1).flatten()  # predicted class label for each image in the batch
    predictions.extend(predflat)

Once all of the predictions have been generated, we check for a pre-existing “predictions” dataset in the pytable and remove it. Finally, we create a new “predictions” dataset in the pytable and fill it with the contents of our predictions array.

with tables.open_file(args.pytable_path, 'a') as f:
    if 'predictions' in f.root:  # remove the current predictions dataset and start fresh, in case the number of predictions has changed
        f.root.predictions.remove()

    f.create_carray(f.root, "predictions", dtype, np.array(predictions).shape)
    f.root.predictions[:] = predictions
    print(f'Predictions have been saved to {args.pytable_path}')
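
As a quick sanity check (a sketch, not part of the original script), the saved predictions can be reloaded and compared against the ground-truth labels:

# optional sanity check: reload the saved predictions and compute overall accuracy
with tables.open_file(args.pytable_path, 'r') as f:
    preds = np.array(f.root.predictions[:]).flatten()
    labels = np.array(f.root.labels[:]).flatten()
print(f'overall accuracy: {(preds == labels).mean():.3f}')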

With this set of predictions we can use the second script.

Part 2: PowerPoint Slide Deck Generation

The visualize_classification_groups.py script outputs a .pptx file at the path specified by the ppt_save argument. Here is a template for running the script at the command line:

python3 visualize_classification_groups.py [pytable path] [ppt_save] --str_criteria [“criteria1_name” “criteria2_name” … “criteriaN_name”] --criteria [d1d2 d3d4 …] --num_rows [a] --num_cols [b]

Importantly, each criteria item corresponds to one slide in the output .pptx file. Each item should be a pair of integer digits; for example, the item “01” tells the script to produce a slide of images with a ground-truth label of 0 and a prediction of 1.
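
For instance, for a binary classification problem, a hypothetical invocation (the slide names and grid size below are placeholders) might look like:

python3 visualize_classification_groups.py [pytable path] [ppt_save] --str_criteria "false positives" "false negatives" --criteria 01 10 --num_rows 4 --num_cols 6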

Script Requirements

  1. A .pytable file with the following keys: 
    • “labels”: stores an array of integers in the range [0, #classes – 1]
    • “filenames”: stores an array of strings corresponding to image filepaths
    • “imgs”: stores an array of images, each with dimensions (height, width, n_channels)
    • “predictions”: stores an array of integers in the range [0, #classes – 1]
  2. For parsing simplicity, we assume that class labels are single-digit integers (0 through 9).

Code Walkthrough

We begin by defining the function addimagetoslide(). It reads an image from a byte-stream object and adds it, along with an optional comment, to a PowerPoint slide.

def addimagetoslide(slide, image_stream, left, top, height, width, resize=1.0, comment=''):
    # add the picture from the in-memory image stream at the requested position
    slide.shapes.add_picture(image_stream, left, top, height, width)
    # add a textbox holding the optional comment
    txBox = slide.shapes.add_textbox(left, Inches(1), width, height)
    tf = txBox.text_frame
    tf.text = comment

Then we parse the criteria arguments and create a dict with the following structure:

criteria_dict = {
        criteria1_name: [d1, d2],
        criteria2_name: [d3, d4],
        …
}
##### INITIALIZE CRITERIA DICT #####
criteria_dict = {args.str_criteria[i]: [int(j) for j in args.criteria[i]]
                 for i in range(len(args.str_criteria))}
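
For example, with the hypothetical arguments --str_criteria "false positives" "false negatives" --criteria 01 10, this comprehension produces:

criteria_dict = {
    'false positives': [0, 1],
    'false negatives': [1, 0],
}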

Then we read in the ground truth labels and predictions from the provided .pytable file.

# get predictions and ground truths
with tables.open_file(args.pytable_path, 'r') as f:
    gts = np.array(f.root.labels[:]).flatten()
    preds = np.array(f.root.predictions[:]).flatten()
    print(gts.shape)

We initialize a PowerPoint Presentation object and define a coordinate grid using np.mgrid. The dimensions of the grid are specified by the num_rows and num_cols arguments.

# init presentation and compute coordinate grid
ppt = Presentation()
grid_width = args.num_cols
grid_height = args.num_rows
grid = np.mgrid[0:grid_height, 0:grid_width]
cartesian_coords = np.vstack([grid[0].ravel(), grid[1].ravel()]).T
print(cartesian_coords)
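
For example, with num_rows=2 and num_cols=3 (hypothetical values), cartesian_coords contains one (row, column) pair per grid cell, in row-major order:

# cartesian_coords for a 2 x 3 grid
# [[0 0]
#  [0 1]
#  [0 2]
#  [1 0]
#  [1 1]
#  [1 2]]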

Now we loop through each classification group in the criteria_dict. For each group, we find all of the images that match the criterion and randomly choose at most num_rows * num_cols of them.

slidenames = criteria_dict.keys()
for slidename in slidenames:
    criterion = criteria_dict[slidename]

    # find indices where the criterion is true
    inds = np.argwhere(np.logical_and(gts == criterion[0], preds == criterion[1])).flatten()
    grid_size = grid_width * grid_height
    num_imgs = min(grid_size, len(inds))
    selected_inds = np.random.choice(inds, num_imgs, replace=False)
    print(f'Number of images in {slidename}: {len(inds)}')

For each classification group, we also append a new blank slide to the PowerPoint and add the slide name as a title.

# create new slide
blank_slide_layout = ppt.slide_layouts[6]
slide = ppt.slides.add_slide(blank_slide_layout)
txBox = slide.shapes.add_textbox(Inches(ppt.slide_height.inches/2), Inches(0.2), Inches(2), Inches(1))  # add slide title
txBox.text_frame.text = slidename

For each new slide, we loop through the selected images. Each image is rendered with matplotlib using its file name as the title, saved to an in-memory buffer, and added to the slide at its grid coordinate.

# fill slide grid with selected images
for j, ind in tqdm(enumerate(selected_inds)):
    coord = cartesian_coords[j]

    with tables.open_file(args.pytable_path, 'r') as f:
        img_filename = f.root.filenames[ind]
        label = f.root.labels[ind]
        img = f.root.imgs[ind, :, :, :]

    plt.imshow(img)
    title = img_filename.decode("utf-8").split('/')[-1]
    plt.title(title, wrap=True)
    plt.tick_params(color='b', bottom=False, left=False)
    plt.xticks([])
    plt.yticks([])
    with BytesIO() as img_buf:
        plt.savefig(img_buf, format='png', bbox_inches='tight')
        plt.close()
        im = Image.open(img_buf)
        addimagetoslide(slide, img_buf, top=Inches(coord[0]*1.2 + 1), left=Inches(coord[1]*1.2 + 0.25),
                        width=Inches(im.height/im.width), height=Inches(1))
        im.close()

We’re done! The PowerPoint is ready to be saved.

ppt.save(args.ppt_save)

Let’s test this script on the DenseNet workflow. Assuming that we are using the subtyping dataset from the DenseNet blog post and that the data-storage pytable has a predictions array, use this command to generate a four-slide PowerPoint:

python3 visualize_classification_groups.py [pytable path] [ppt_save] --str_criteria "gt:CLL pred:CLL" "gt:CLL pred:FL" "gt:FL pred:FL" "gt:FL pred:MCL" --criteria 00 01 11 12

The script saved a PowerPoint with the following slides:

Conclusion

Now that we have a slide deck that groups several combinations of ground truth labels and predictions, we can hunt for patterns in the data that may lead our model to misclassify images!

I hope that this computational method for visualizing classification groups will be helpful in your own workflows. Feel free to leave a comment if you have any questions.

The code is freely available for both generate_densenet_predictions.py and visualize_classification_groups.py.

A Note on the Author and Contributions

Jackson Jacobs is a 4th year undergraduate student at Case Western Reserve University expecting to graduate in the spring of 2023. The focus of his research is to improve Multiple Instance Learning techniques for detecting cancer in Whole Slide Images.

This blog post was conceived of and edited by Andrew Janowczyk, and was written by Jackson Jacobs.
