Commit 34bb9513 authored by Yin Jianwu's avatar Yin Jianwu

Upload New File

parent bb453766
Pipeline #623 canceled with stages
import os.path as osp
import time
import sys
import torch
from torch_geometric.datasets import Planetoid,CoraFull,Coauthor,Amazon
import torch.nn.functional as F
import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv, ChebConv,GATConv,SGConv
from sklearn.model_selection import KFold
import numpy as np
import torch_geometric.nn as nn
import torch.nn.init as init
# dataset = 'Cora'
# dataset = 'CiteSeer'
#---------get the name of dataset and model--------
dataset = sys.argv[1]
modelName = sys.argv[2]
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', dataset)
# dataset = 'PubMed'
# path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', dataset)
#---------choose dataset--------------
#--i make normalization here
if dataset == 'Cora':
dataset = Planetoid(path, dataset)
elif dataset == 'CiteSeer':
dataset = Planetoid(path, dataset)
elif dataset == 'PubMed':
dataset = Planetoid(path, dataset)
elif dataset == 'CoraFull':
dataset = CoraFull(path)
elif dataset == 'Physics':
dataset = Coauthor(path, dataset)
elif dataset == 'CS':
dataset = Coauthor(path, dataset)
elif dataset == 'Computers':
dataset = Amazon(path, dataset)
elif dataset == 'Photo':
dataset = Amazon(path, dataset)
# there is only one graph
data = dataset[0]
#------------------make 5 fold parts on dataset---------
kf = KFold(n_splits=5,shuffle=True)
for train, test in kf.split(data.y):
#---------gcn-structure of layer---------------------------------------
class GcnNet(torch.nn.Module):
#------- two layers--------
def __init__(self):
super(GcnNet, self).__init__()
self.conv1 = GCNConv(dataset.num_node_features, 16, cached=True)
self.conv2 = GCNConv(16, dataset.num_classes, cached=True)
def forward(self):
x, edge_index = data.x, data.edge_index
x = self.conv1(x, edge_index)
x = F.relu(x)
x = F.dropout(x,
x = self.conv2(x, edge_index)
return F.log_softmax(x, dim=1)
#-----gat-structure layer-----------
class GatNet(torch.nn.Module):
def __init__(self):
super(GatNet, self).__init__()
self.conv1 = GATConv(dataset.num_features, 8, heads=8, dropout=0.6)
# On the Pubmed dataset, use heads=8 in conv2.
self.conv2 = GATConv(
8 * 8, dataset.num_classes, heads=1, concat=True, dropout=0.6)
def forward(self):
x = F.dropout(data.x, p=0.6,
x = F.elu(self.conv1(x, data.edge_index))
x = F.dropout(x, p=0.6,
x = self.conv2(x, data.edge_index)
return F.log_softmax(x, dim=1)
#------sgc structure layer----
class SgcNet(torch.nn.Module):
def __init__(self):
super(SgcNet, self).__init__()
self.conv1 = SGConv(dataset.num_features, dataset.num_classes, K=2, cached=True)
def forward(self):
x, edge_index = data.x, data.edge_index
x = self.conv1(x, edge_index)
return F.log_softmax(x, dim=1)
#-------use GPU or CPU-----------------
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#------choose model and set some settings--------------
if modelName == 'gcn':
model, data = GcnNet().to(device),
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
if modelName == 'gat':
model, data = GatNet().to(device),
optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)
if modelName == 'sgc':
model, data = SgcNet().to(device),
optimizer = torch.optim.Adam(model.parameters(), lr=0.2, weight_decay=0.005)
def train():
def test():
logits = model()
pred = logits[test_idx].max(1)[1]
accs = pred.eq(data.y[test_idx]).sum().item() / len(test_idx)
return accs
#------------reset weights---------
def weights_init(m):
# SGConv GCNConv or GATConv
if isinstance(m, nn.SGConv):
# init.xavier_uniform_(
# init.xavier_uniform_(
results = []
#----------make 5fold-validation--------
for i in range(1,6):
train_idx = train_sets[i-1]
test_idx = test_sets[i-1]
for epoch in range(1, 201):
acc = test()
print('Accuracy: {:.4f}'.format(acc))
#-------compute means and std-----------
a = np.mean(results)
b = np.std(results,ddof=1)
#--------write results to txt file--------
file = open('validation.txt','w')
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment