Drug-Deug Adverse Effect Prediction with Graph Co-Attention

Drug-Drug Adverse Effect Prediction with Graph Co-Attention

Andreea Deac, Yu-Hsiang Huang, Petar Velickovic, Pietro Lio, Jian Tang

Architecture

Inputs

  • drug, $$d_x$$
  • atom, $$a_i^{d_x}$$
  • edge, $$\big( ai^{d_x} a_j^{d_x}\big)$$, $$e{ij}^{(dx)}$$

有两种预测的方法

  1. 将side-effect作为input特征之一,输出针对于该input的prediction
  2. 类似Decagon,同时输出所有side-effect的prediction

Message-Passing

$$^{(d_x)}h_i^t$$是atom i drug x在第t步上的特征。初始状态将其设定为projected input features

$$ ^{(d_x)}h_i^t0 = f_i (a_i^{(d_x)})

$$

其中$$f_i$$是一个小的MLP。我们所有的设定中,MLP会投影到32-D。

后续的message,沿着j到i的edge上的message表示为$$^{(dx)} m{ij}^t$$。message会同时考虑到j的特征,和edge的特征。

$$ ^{(dx)} m{ij}^t = fe^t(e{ij}^{(d_x)}) \odot f_v(^{(d_x)} h_j^{t-1})

$$

其中$$f_e^t, f_v^t$$是小的MLP。

然后每一个atom会将收到的message进行aggregate,产生新的internal message

$$ ^{(dx)} m_i^t = \sum{j \in N(i)} \ ^{(dx)m{ij}^t}

$$

Co-attention

为了给drug-drug interaction建模,我们考虑co-attention mechanism。

考虑两个drug,x和y,的各自一个atom,i和j。第t步的feature分别是$$^{(dx)}h_i^t$$和$$^{(d_y)}h_j^t$$。对于每一个pair,我们计算attentional coefficient,$$\alpha{ij}^t$$ 使用了简化的Transformer attention mechanism。

$$ \alpha_{ij}^t = softmax_j \big(\langle W_k^t \cdot ^{(d_x)}h_i^{t-1}, \langle W_k^t \cdot ^{(d_y)}h_j^{t-1} \rangle\big)

$$

softmax是第二个drug上所有的node上取得的。这个coefficient $$\alpha_{ij}$$可以理解为atom j对于atom i的重要性。

这个coefficient然后就用来计算drug x atom i的outer message,表示为$$^{(d_x)} n_i^t$$,是atom y所有atom features的linear combination。

$$ ^{(dx)} n_i^t = \sum{a \in dy} \alpha{ij}^t \cdot W_v^t \, \cdot ^{d_y} h_j^{t-1}

$$

最近一些work显示了multi-head attention可以稳定learning process,并且从不同的level学习信息。因此,公式5的机制在K个不同的head之间是相互独立的。生成的vector会concat一下,再用MLP来产生每一个atom最终的outer message

$$ ^{(dx)}n_i^t = f_0^t \big( \parallel{k=1}^K \sum_{j \in d_y} \ ^{(k)}W_v^t \cdot ^{(d_y)}h_j^{t-1} \big)

$$

Update function

每当每一个atom的inner message和outer message得到之后,就会用来node-level feature。

这一步aggregating (通过summation)如下

$$ ^{(d_x)} h_i^t = LayerNorm( ^{(d_x)}h_i^{t-1} + ^{(d_x)} m_i^t + ^{(d_x)}n_i^t )

$$

Readout and Scoring

$$ dx = \sum{i \in d_x} f_r(^{(d_x)}h_i^T)

$$

后续如何在$$d_x, d_y$$上进行prediction基于模型的构造机制。

Appendix

主要的区别点就在于将message passing分为inner message和outer message

  • inner message 只考虑neighborhood aggregation (传统做法)
  • outer message 会考虑到当前node和另外一个drug所有node之间的aggregation

一个问题是是否可以交换?或者是否ordering sensitive。

results matching ""

    No results matching ""