Drug-Deug Adverse Effect Prediction with Graph Co-Attention
Drug-Drug Adverse Effect Prediction with Graph Co-Attention
Andreea Deac, Yu-Hsiang Huang, Petar Velickovic, Pietro Lio, Jian Tang
Architecture
Inputs
- drug, $$d_x$$
- atom, $$a_i^{d_x}$$
- edge, $$\big( ai^{d_x} a_j^{d_x}\big)$$, $$e{ij}^{(dx)}$$
有两种预测的方法
- 将side-effect作为input特征之一,输出针对于该input的prediction
- 类似Decagon,同时输出所有side-effect的prediction
Message-Passing
$$^{(d_x)}h_i^t$$是atom i drug x在第t步上的特征。初始状态将其设定为projected input features
$$ ^{(d_x)}h_i^t0 = f_i (a_i^{(d_x)})
$$
其中$$f_i$$是一个小的MLP。我们所有的设定中,MLP会投影到32-D。
后续的message,沿着j到i的edge上的message表示为$$^{(dx)} m{ij}^t$$。message会同时考虑到j的特征,和edge的特征。
$$ ^{(dx)} m{ij}^t = fe^t(e{ij}^{(d_x)}) \odot f_v(^{(d_x)} h_j^{t-1})
$$
其中$$f_e^t, f_v^t$$是小的MLP。
然后每一个atom会将收到的message进行aggregate,产生新的internal message
$$ ^{(dx)} m_i^t = \sum{j \in N(i)} \ ^{(dx)m{ij}^t}
$$
Co-attention
为了给drug-drug interaction建模,我们考虑co-attention mechanism。
考虑两个drug,x和y,的各自一个atom,i和j。第t步的feature分别是$$^{(dx)}h_i^t$$和$$^{(d_y)}h_j^t$$。对于每一个pair,我们计算attentional coefficient,$$\alpha{ij}^t$$ 使用了简化的Transformer attention mechanism。
$$ \alpha_{ij}^t = softmax_j \big(\langle W_k^t \cdot ^{(d_x)}h_i^{t-1}, \langle W_k^t \cdot ^{(d_y)}h_j^{t-1} \rangle\big)
$$
softmax是第二个drug上所有的node上取得的。这个coefficient $$\alpha_{ij}$$可以理解为atom j对于atom i的重要性。
这个coefficient然后就用来计算drug x atom i的outer message,表示为$$^{(d_x)} n_i^t$$,是atom y所有atom features的linear combination。
$$ ^{(dx)} n_i^t = \sum{a \in dy} \alpha{ij}^t \cdot W_v^t \, \cdot ^{d_y} h_j^{t-1}
$$
最近一些work显示了multi-head attention可以稳定learning process,并且从不同的level学习信息。因此,公式5的机制在K个不同的head之间是相互独立的。生成的vector会concat一下,再用MLP来产生每一个atom最终的outer message
$$ ^{(dx)}n_i^t = f_0^t \big( \parallel{k=1}^K \sum_{j \in d_y} \ ^{(k)}W_v^t \cdot ^{(d_y)}h_j^{t-1} \big)
$$
Update function
每当每一个atom的inner message和outer message得到之后,就会用来node-level feature。
这一步aggregating (通过summation)如下
$$ ^{(d_x)} h_i^t = LayerNorm( ^{(d_x)}h_i^{t-1} + ^{(d_x)} m_i^t + ^{(d_x)}n_i^t )
$$
Readout and Scoring
$$ dx = \sum{i \in d_x} f_r(^{(d_x)}h_i^T)
$$
后续如何在$$d_x, d_y$$上进行prediction基于模型的构造机制。
Appendix
主要的区别点就在于将message passing分为inner message和outer message
- inner message 只考虑neighborhood aggregation (传统做法)
- outer message 会考虑到当前node和另外一个drug所有node之间的aggregation
一个问题是是否可以交换?或者是否ordering sensitive。