import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv, Linear
class GAT(torch.nn.Module):
    """Two-layer graph attention network with a sigmoid output head.

    Architecture: dropout -> GATConv (multi-head, concatenated) -> ELU ->
    dropout -> GATConv (single head) -> ELU -> Linear -> Sigmoid.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Per-head output size of the first attention layer.
            The second layer uses ``int(hidden_channels / 4)`` channels.
        out_channels: Size of the final per-node output.
        heads: Number of attention heads in the first layer.
        edge_dim: Edge feature dimensionality forwarded to both ``GATConv``
            layers. Defaults to ``None``, which reproduces the previous
            construction. NOTE(review): according to the PyG ``GATConv``
            documentation edge features are only taken into account when
            ``edge_dim`` is set, so pass it if ``edge_attr`` should influence
            attention.
        dropout: Dropout probability applied to the node features before each
            attention layer and passed to the attention layers themselves.

    Raises:
        ValueError: If ``hidden_channels`` is so small that the second layer
            would have zero channels.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, heads,
                 edge_dim=None, dropout=0.6):
        super().__init__()
        # Width of the second attention layer; computed once so that conv2
        # and the linear head can never disagree.
        bottleneck_channels = int(hidden_channels / 4)
        if bottleneck_channels < 1:
            raise ValueError(
                "hidden_channels must be at least 4 so that the second "
                f"attention layer is non-empty, got {hidden_channels}"
            )

        self.dropout = dropout
        self.conv1 = GATConv(
            in_channels,
            hidden_channels,
            heads,
            dropout=dropout,
            edge_dim=edge_dim,
        )
        # conv1 concatenates its heads, hence hidden_channels * heads inputs.
        self.conv2 = GATConv(
            hidden_channels * heads,
            bottleneck_channels,
            heads=1,
            concat=False,
            dropout=dropout,
            edge_dim=edge_dim,
        )
        self.lin = Linear(bottleneck_channels, out_channels)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, edge_index, edge_attr=None):
        """Return per-node outputs squashed into (0, 1) by a sigmoid.

        Args:
            x: Node feature matrix.
            edge_index: Graph connectivity passed to both attention layers.
            edge_attr: Optional edge features passed to both attention layers.

        Returns:
            Tensor with ``out_channels`` sigmoid-activated values per node.
        """
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = F.elu(self.conv1(x, edge_index, edge_attr))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = F.elu(self.conv2(x, edge_index, edge_attr))
        x = self.lin(x)
        return self.sigmoid(x)
注意:该模型使用了两层 GATConv(图注意力卷积),之后接一个线性层和 Sigmoid 输出。
|
import torch
import torch_geometric.transforms as T
from torch_geometric.loader import NeighborLoader
from tpf.datasets import AMLtoGraph
from tpf.gnn import GAT
# Make sure the data is re-processed on every run (uncomment to delete the
# cached processed file before the dataset is loaded).
# import os
# dataset_root = "/wks/datasets/data_tmp"
# if os.path.exists(os.path.join(dataset_root, "processed/data.pt")):
#     os.remove(os.path.join(dataset_root, "processed/data.pt"))

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dataset = AMLtoGraph(
    root='/wks/datasets/data_tmp',
    raw_file_names='HI-Small_Trans.csv',
    processed_file_names='data.pt',
)
data = dataset[0]
epoch = 100  # total number of training epochs

model = GAT(in_channels=data.num_features, hidden_channels=16, out_channels=1, heads=8)
model = model.to(device)
# The model already applies a sigmoid, so plain BCELoss (not the logits
# variant) is the matching criterion.
criterion = torch.nn.BCELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)

# Hold out 10% of the nodes for validation; everything else is training data.
split = T.RandomNodeSplit(split='train_rest', num_val=0.1, num_test=0)
data = split(data)

# Two sampling hops to match the two GATConv layers of the model.
train_loader = NeighborLoader(
    data,
    num_neighbors=[30] * 2,
    batch_size=256,
    input_nodes=data.train_mask,
)
test_loader = NeighborLoader(
    data,
    num_neighbors=[30] * 2,
    batch_size=256,
    input_nodes=data.val_mask,
)

for i in range(epoch):
    total_loss = 0.0
    model.train()
    # `batch` (not `data`) so the full graph object is not shadowed.
    for batch in train_loader:
        optimizer.zero_grad()
        batch = batch.to(device)
        pred = model(batch.x, batch.edge_index, batch.edge_attr)
        # NeighborLoader places the seed nodes first in each mini-batch; the
        # remaining nodes are sampled neighbours that only provide context
        # and may belong to the validation split, so they must not
        # contribute to the loss.
        num_seeds = batch.batch_size
        target = batch.y[:num_seeds].unsqueeze(1).to(pred.dtype)
        loss = criterion(pred[:num_seeds], target)
        loss.backward()
        optimizer.step()
        total_loss += float(loss)

    # Report every 10th epoch (keyed on the loop index, not on the constant
    # epoch count) and always on the final epoch.
    if i % 10 == 0 or i == epoch - 1:
        print(f"Epoch: {i:03d}, Loss: {total_loss:.4f}")
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():  # no autograd bookkeeping during evaluation
            for test_data in test_loader:
                test_data = test_data.to(device)
                pred = model(test_data.x, test_data.edge_index, test_data.edge_attr)
                num_seeds = test_data.batch_size
                target = test_data.y[:num_seeds].unsqueeze(1)
                # Turn probabilities into hard 0/1 predictions; comparing raw
                # sigmoid outputs to the labels with == almost never matches.
                predicted = (pred[:num_seeds] >= 0.5).to(target.dtype)
                correct += (predicted == target).sum().item()
                total += num_seeds
        acc = correct / total if total else float('nan')
        print('accuracy:', acc)
Epoch: 000, Loss: 1959.1876
accuracy: 0.5208217933939149
Epoch: 001, Loss: 1925.4403
accuracy: 0.5655238465797704
...
...
Epoch: 058, Loss: 1528.0990
accuracy: 0.9186715859789615
Epoch: 059, Loss: 1537.8420
accuracy: 0.9203916262675181
...
...
accuracy: 0.9383725399178434
Epoch: 098, Loss: 1504.6287
accuracy: 0.9375358709593704
Epoch: 099, Loss: 1505.1355
accuracy: 0.9435815246968131
|
num_neighbors=[30] * 2 的含义是:
第1跳采样:为每个目标(种子)节点最多采样 30 个邻居;这些邻居的表示由最后一层 GATConv 聚合到目标节点。
第2跳采样:为每个第1跳邻居再最多采样 30 个邻居;这些节点的特征先由第一层 GATConv 聚合到第1跳邻居。
这样,采样深度 = 2,正好匹配 2 个 GATConv 层的计算需求(每一层消息传递向外扩展一跳)。
如果模型有3层 GATConv,则应该设为 num_neighbors=[k1, k2, k3],以此类推。
参数调整建议
平衡效率与性能:
增大 num_neighbors(如 [50, 50])会捕获更广的邻域,但增加计算负担。
减小它(如 [10, 10])会降低内存占用,但可能丢失信息。
非均匀采样:
例如 [30, 10] 表示第一层采样30邻居,第二层从每个第一层邻居再采样10个。
如果你的模型增加到了3层 GATConv,则需要调整采样:
num_neighbors = [30] * 3 # 对应3层GATConv
|