Incorporate Knowledge Graph with Attention in EHR — A Case Study
In this article, the main idea of the paper “GRAM: Graph-based Attention Model for Healthcare Representation Learning” (2017) [1] is explained (source code at [8]). I read it, investigated it, and made sense of it, so you don’t have to. 😊 Disclaimer: there is no Graph Neural Network discussed here.
Introduction
EHR (Electronic Health Record) is everywhere. As far as I know, it was originally created to facilitate billing, but it is now an essential tool for keeping track of individuals’ health information. Challenges still remain due to differing formats and terminology systems, as well as data exchange and interoperability [2]. An EHR can contain anything from blood test results to X-ray images, but here we focus on the Diagnosis Codes recorded for a patient at every doctor visit. Such (concise) Diagnosis Codes are actually standardized: ICD [3]. Imagine that at every doctor visit, the patient ends up with a series of Diagnosis Codes for that specific visit. Over time, the patient accumulates a collection of such codes, and an idea comes up: can we predict the patient’s future Diagnosis Codes based on their code history? That makes sense. For example, the Diagnosis Code I10 [4] for Hypertension and E78.00 [5] for High Cholesterol may predict the chance of I63.9 [6] for Stroke in the future, as the relation is well known [7].
This can be framed as a prediction task for an RNN (Recurrent Neural Network). Despite being powerful at mapping arbitrary inputs to arbitrary outputs, a Neural Network needs our help with feature engineering so that it can converge faster and generalize better. Paper [1] proposes enriching the Diagnosis Codes, as input to an RNN, with a Knowledge Graph and Attention.
Paper Summary
We start with the Knowledge Graph. Where does it come from, and how does it fit into the whole scheme here? The ICD Codes [3] are structured as a tree, schematically encoding the parent-child-sibling relations between diseases. Doctors are advised to note down the patient’s diagnosed diseases at as granular a level as possible: typically the leaf nodes. For example, Hypertension (I10) has the parent “Hypertensive diseases” and grandparent “Diseases of the circulatory system”.
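To make this structure concrete, here is a minimal Python sketch (not taken from the paper’s code [8]) of holding the ICD hierarchy as parent pointers and walking up from a leaf code. The `PARENT_OF` mapping and the node labels are illustrative placeholders, not the full ICD tree.

```python
# Illustrative parent pointers for a tiny slice of the ICD tree.
PARENT_OF = {
    "I10": "I10-I1A",        # Hypertension -> Hypertensive diseases
    "I10-I1A": "I00-I99",    # Hypertensive diseases -> Diseases of the circulatory system
    "I00-I99": "ROOT",       # chapter -> artificial root node
}

def ancestors(code: str) -> list[str]:
    """Walk up the tree and return [code, parent, grandparent, ...]."""
    chain = [code]
    while chain[-1] in PARENT_OF:
        chain.append(PARENT_OF[chain[-1]])
    return chain

print(ancestors("I10"))  # ['I10', 'I10-I1A', 'I00-I99', 'ROOT']
```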
Based on this ICD tree, we have a Knowledge Graph. The missing piece of this graph is the Node Features, which are not available in ICD. So, how do we create the Node Features in a meaningful way? The paper leverages the GloVe technique [11], which first requires the construction of a Co-occurrence matrix for Diagnosis Codes. As an example, assume that in a Visit Vₜ, the patient is diagnosed with Codes cᵢ, cⱼ, cₖ:
Those Codes are Leaf Nodes within the ICD Tree. For each of them, we also collect all the Codes corresponding to their Ancestor Nodes, and thus end up with the augmented visit V’ₜ:
The Co-occurrence measure for cᵢ and cⱼ within V’ₜ:
This process is repeated for all pairs of Codes in all augmented visits of all patients to obtain the co-occurrence matrix. Then, following the guidance in [11], the embedding vectors for the Node Features are trained.
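Below is a hedged sketch of how such a co-occurrence matrix could be assembled from augmented visits. The function name `cooccurrence_matrix` and the assumption that every unordered pair of codes in the same augmented visit contributes a count of 1 are mine; the exact weighting should follow GloVe [11] and the paper’s implementation [8].

```python
from collections import defaultdict
from itertools import combinations

def cooccurrence_matrix(patients, ancestor_fn):
    """patients: per patient, a list of visits; each visit is a list of leaf codes.
    ancestor_fn: maps a leaf code to [code, parent, grandparent, ...]
    (e.g. the ancestors() helper sketched earlier).
    Returns {(c_i, c_j): count} accumulated over augmented visits V'_t.
    Assumption: every unordered pair in the same augmented visit adds 1."""
    counts = defaultdict(float)
    for visits in patients:
        for visit in visits:
            augmented = set()
            for code in visit:
                augmented.update(ancestor_fn(code))   # augment the visit with ancestors
            for c_i, c_j in combinations(sorted(augmented), 2):
                counts[(c_i, c_j)] += 1.0
                counts[(c_j, c_i)] += 1.0             # keep the matrix symmetric
    return counts
```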
An implication of this setup: Diagnosis Codes for more general medical concepts (i.e. the Ancestor Nodes in ICD) appear more frequently than the others, and thus contribute more to medical events (i.e. Visits).
We now have the Knowledge Graph (Nodes with Connections, plus Node Features). It’s worth pointing out that we do not intend to use a Graph Neural Network; instead, we embed the Knowledge Graph into the Input of our Model, which currently looks like this:
The paper’s proposal for embedding the Knowledge Graph is as follows:
- In each Visit, there is a series of Diagnosis Codes, which are among the Leaf Nodes in ICD. They are encoded as a multi-hot vector (image above).
- For each Leaf Node, gather all of its Ancestor Nodes. E.g. Leaf Node (aka Diagnosis Code) c₁ has parent-of-c₁ and grandparent-of-c₁ Nodes, each of which has its own Feature vector, thanks to GloVe [11] as mentioned above.
- A summation of the Feature Vectors corresponding to c₁, parent-of-c₁, grandparent-of-c₁… can better represent the Diagnosis Code c₁. However, sum at what ratio? This is where Attention comes into play.
- The attention technique is called Additive Attention, influenced by Bahdanau et al. [13]. To make sure all the Attention Weights sum to 1, Softmax is used. Each weight is computed as the output of a Feed-Forward Network with one hidden layer, which in turn takes as input the concatenation of c₁’s embedding and that of one of its Ancestors. Note: we consider c₁ itself to be among its Ancestors. A minimal sketch of this step follows below.
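Here is a minimal NumPy sketch of that attention step for a single Leaf Node, assuming a one-hidden-layer scorer with a tanh activation. The parameter shapes, the hidden size, and the helper name `additive_attention_embedding` are illustrative choices, not the paper’s exact implementation [8].

```python
import numpy as np

rng = np.random.default_rng(0)
m = 8                                   # embedding size (illustrative)

def additive_attention_embedding(e_leaf, e_ancestors, W_a, b_a, u_a):
    """Sketch of GRAM-style additive attention for one leaf code.
    e_leaf: (m,) basic embedding of the leaf code c_1.
    e_ancestors: (k, m) basic embeddings of c_1 and its ancestors
                 (the leaf itself is counted among its 'ancestors').
    W_a, b_a, u_a: parameters of a one-hidden-layer feed-forward scorer."""
    scores = []
    for e_anc in e_ancestors:
        concat = np.concatenate([e_leaf, e_anc])          # (2m,)
        hidden = np.tanh(W_a @ concat + b_a)              # one hidden layer
        scores.append(u_a @ hidden)                       # scalar score
    scores = np.array(scores)
    alpha = np.exp(scores - scores.max())
    alpha = alpha / alpha.sum()                           # softmax: weights sum to 1
    return alpha @ e_ancestors                            # weighted sum -> g_1

# Toy usage: a leaf node plus its parent and grandparent.
e = rng.normal(size=(3, m))                               # [c1, parent, grandparent]
W_a, b_a, u_a = rng.normal(size=(16, 2 * m)), rng.normal(size=16), rng.normal(size=16)
g_1 = additive_attention_embedding(e[0], e, W_a, b_a, u_a)
print(g_1.shape)                                          # (8,)
```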
Let’s call these steps the construction of the Embedding Matrix G (below).
The embedding matrix G: G ∈ Rᵐˣᶜ, where C is the number of Diagnosis Codes (i.e. Leaf Nodes) and m is the number of dimensions of the Node Features in the Knowledge Graph. The Model Input (as in Figure 2): x ∈ Rᶜˣᵀ. Doing the matrix multiplication, we get the Input for our RNN, and the rest of the operations are standard RNN computation.
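As a toy illustration of that multiplication, here is a small NumPy sketch; the sizes C, m, T and the code assignments per visit are made up for demonstration.

```python
import numpy as np

rng = np.random.default_rng(1)
C, m, T = 5, 8, 3          # illustrative sizes: 5 leaf codes, m=8 dims, 3 visits

# G stacks the attention-weighted embeddings g_i, one column per leaf code.
G = rng.normal(size=(m, C))          # G in R^(m x C)

# x holds one multi-hot column per visit: 1 where the code was diagnosed.
x = np.zeros((C, T))                 # x in R^(C x T)
x[[0, 2], 0] = 1                     # visit 1: codes c_1, c_3
x[[1, 2, 4], 1] = 1                  # visit 2: codes c_2, c_3, c_5
x[[3], 2] = 1                        # visit 3: code c_4

rnn_input = G @ x                    # R^(m x T): one enriched vector per visit
print(rnn_input.shape)               # (8, 3)
```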
Implications of this setup:
- The Model Input is now more expressive (instead of being just a multi-hot vector as in Figure 2): it contains not only information about the direct Diagnosis Code, but also information about more general medical concepts through the parent/grandparent relations in ICD. The proportional contribution of child/parent/grandparent is determined via Attention.
- The possibility of transferring knowledge from one patient to another. An intuitive example: consider Patients L, M, and N, respectively diagnosed with Codes cₗ, cₘ, cₙ, which share a parent or grandparent node (in ICD). Patients L and M are later diagnosed with Code cₓ. Reasonably, Patient N may later also get cₓ! Without embedding the Knowledge Graph, this transferable knowledge stays buried.
Some other details not discussed so far:
- The Diagnosis Codes in EHR follow ICD, but in the paper, they convert such ICD Codes (Leaf Nodes) to CCS Codes: the former are too fine-grained, which may hinder performance, while the latter provide a better summary. Find out more in [10].
- The data used in the paper is coded in ICD-9, but the Diagnosis Code examples in this article are in ICD-10.
- The basic embeddings of the Diagnosis Codes, initialized with GloVe, are also fine-tuned during model training via backpropagation.
- The final hidden state of the RNN is put through an FC (Fully Connected) layer and Softmax to make the prediction.
- The RNN can be an LSTM, GRU, IRNN, etc.
- Instead of GloVe [11], Skip-gram [12] may be used.
- The loss function is binary cross entropy. In the actual implementation, the authors take the average of the individual losses over multiple patients due to batch operation (see the sketch after this list).
- As noted in the paper, real-life EHR data sometimes contains Diagnosis Codes not at the most granular level but at a parent level (in the ICD tree). In this case, the paper treats such a Diagnosis Code as just another Leaf Node.
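As a rough sketch of the loss described above, under my own assumptions about the exact averaging (binary cross entropy summed over codes, averaged over a patient’s visits, then averaged over the patients in a batch):

```python
import numpy as np

def batch_bce_loss(y_true_list, y_pred_list, eps=1e-8):
    """y_true_list / y_pred_list: one (T_patient, C) array per patient, holding
    multi-hot targets and predicted probabilities respectively.
    Assumption: sum over codes, mean over visits, then mean over patients."""
    patient_losses = []
    for y_true, y_pred in zip(y_true_list, y_pred_list):
        bce = -(y_true * np.log(y_pred + eps)
                + (1 - y_true) * np.log(1 - y_pred + eps))
        patient_losses.append(bce.sum(axis=1).mean())   # average over visits
    return float(np.mean(patient_losses))               # average over patients
```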
Final Remark
In my observation, most papers try to solve specific, narrow tasks under specific constraints (e.g. data, computation, assumptions). One may wonder whether age and sex, or even BMI (Body Mass Index), are also relevant in predicting a specific patient’s future Diagnosis Codes; the model in this paper [1] did not factor them in. Nevertheless, we have learned a technique for making the Model Input more expressive and for enabling knowledge transfer across data points. These improvements potentially help the learning converge faster and generalize better.
Last but not least, reading papers is great; at some point, one can come to the realization that:
References
[1] GRAM: Graph-based Attention Model for Healthcare Representation Learning (2017) (https://arxiv.org/abs/1611.07012)
[2] Attention-based neural networks for clinical prediction modelling on electronic health records (2023) (https://bmcmedresmethodol.biomedcentral.com/articles/10.1186/s12874-023-02112-2)
[3] International Statistical Classification of Diseases (ICD) (https://www.who.int/standards/classifications/classification-of-diseases)
[4] ICD I10 (https://www.icd10data.com/ICD10CM/Codes/I00-I99/I10-I1A/I10-/I10)
[5] E78.00 (https://www.icd10data.com/ICD10CM/Codes/E00-E89/E70-E88/E78-/E78.00)
[6] I63.9 (https://www.icd10data.com/ICD10CM/Codes/I00-I99/I60-I69/I63-/I63.9)
[7] Risk Factors for Stroke (https://www.cdc.gov/stroke/risk-factors/index.html)
[8] GRAM Source Code (https://github.com/mp2893/gram)
[9] ICD Code Browser (https://icd.who.int/browse10/2019/en#/I10)
[10] Representing EHR data: ICD codes vs. CCS codes (https://glassboxmedicine.com/2018/11/13/representing-ehr-data-icd-codes-vs-ccs-codes/)
[11] GloVe: Global Vectors for Word Representation (https://nlp.stanford.edu/pubs/glove.pdf)
[12] Efficient Estimation of Word Representations in Vector Space (https://arxiv.org/abs/1301.3781)
[13] Neural Machine Translation by Jointly Learning to Align and Translate (https://arxiv.org/abs/1409.0473)