A variety of deep learning approaches have been proposed to automatically classify Alzheimer’s disease (AD) from medical images. However, common approaches such as traditional convolutional neural networks (CNNs) lack interpretability and are prone to overfitting when trained on small datasets. As an alternative, significantly less work has explored applying deep learning approaches to region-based features that are commonly obtained from atlas partitions of known regions of interest (ROIs). In this work, we combine CNNs with graph neural networks (GNNs) to jointly learn an adjacency matrix of connectivities between ROIs as a prior for learning meaningful features for AD prediction. We apply our method to the ADNI dataset and systematically inspect the different intermediate layers of our network using t-SNE projections, which show strong separation on out-of-sample data. Finally, we show that the edge probabilities alone are sufficient to reach high classification accuracy by training a secondary random forest classifier on the adjacency matrices output by our network, and we illustrate the interpretability of the learned graphs by visualizing the feature importance of every edge.
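The core idea of learning edge probabilities between ROIs and propagating ROI features through them can be sketched minimally as follows. This is a hypothetical illustration, not the authors' implementation: the sigmoid parameterisation of the adjacency, the single message-passing step, and all names (`edge_logits`, `roi_features`) are assumptions for the sketch.

```python
import numpy as np

rng = np.random.default_rng(0)

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

n_roi, d = 4, 3  # toy sizes: 4 ROIs, 3-dimensional ROI features

# Hypothetical: per-edge logits would be learned jointly with the network;
# here they are random placeholders.
edge_logits = rng.normal(size=(n_roi, n_roi))

# Edge probabilities in (0, 1) for every ROI pair (the learned adjacency).
A = sigmoid(edge_logits)

# Hypothetical ROI features, e.g. produced by a CNN over each region.
roi_features = rng.normal(size=(n_roi, d))

# One graph message-passing step: aggregate neighbour features weighted
# by edge probability, followed by a ReLU nonlinearity.
H = np.maximum(A @ roi_features, 0.0)

# The flattened adjacency A.ravel() is the per-subject feature vector that
# a secondary classifier (e.g. a random forest) could be trained on.
```

In this framing, each subject yields one adjacency matrix, so a dataset of subjects becomes a matrix of flattened edge probabilities on which a standard classifier and its feature importances can be computed.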