DropEdge · Commit 1d555125
Authored Nov 18, 2020 by Joshua Ghost · parent dcb1718f

    finish GAT

Showing 3 changed files with 154 additions and 61 deletions:

    src/process.py    +63   -0
    src/sample.py     +30  -34
    src/train_new.py  +61  -27
src/process.py · new file (0 → 100644) · +63 -0

import numpy as np
import pickle as pkl
import networkx as nx
import scipy.sparse as sp
from scipy.sparse.linalg.eigen.arpack import eigsh
import sys


def preprocess_adj_bias(adj):
    num_nodes = adj.shape[0]
    adj = adj + sp.eye(num_nodes)  # self-loop
    adj[adj > 0.0] = 1.0
    if not sp.isspmatrix_coo(adj):
        adj = adj.tocoo()
    adj = adj.astype(np.float32)
    indices = np.vstack((adj.col, adj.row)).transpose()  # This is where I made a mistake, I used (adj.row, adj.col) instead
    # return tf.SparseTensor(indices=indices, values=adj.data, dense_shape=adj.shape)
    return indices, adj.data, adj.shape


def adj_to_bias(adj, sizes, nhood=1):
    nb_graphs = adj.shape[0]
    mt = np.empty(adj.shape)
    for g in range(nb_graphs):
        mt[g] = np.eye(adj.shape[1])
        for _ in range(nhood):
            mt[g] = np.matmul(mt[g], (adj[g] + np.eye(adj.shape[1])))
        for i in range(sizes[g]):
            for j in range(sizes[g]):
                if mt[g][i][j] > 0.0:
                    mt[g][i][j] = 1.0
    return -1e9 * (1.0 - mt)


def sparse_to_tuple(sparse_mx):
    """Convert sparse matrix to tuple representation."""
    def to_tuple(mx):
        if not sp.isspmatrix_coo(mx):
            mx = mx.tocoo()
        coords = np.vstack((mx.row, mx.col)).transpose()
        values = mx.data
        shape = mx.shape
        return coords, values, shape

    if isinstance(sparse_mx, list):
        for i in range(len(sparse_mx)):
            sparse_mx[i] = to_tuple(sparse_mx[i])
    else:
        sparse_mx = to_tuple(sparse_mx)

    return sparse_mx


def preprocess_features(features):
    """Row-normalize feature matrix and convert to tuple representation"""
    features = sp.lil_matrix(features.numpy())
    rowsum = np.array(features.sum(1))
    r_inv = np.power(rowsum, -1).flatten()
    r_inv[np.isinf(r_inv)] = 0.
    r_mat_inv = sp.diags(r_inv)
    features = r_mat_inv.dot(features)
    return features.todense(), sparse_to_tuple(features)
    # return features, sparse_to_tuple(features)
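These helpers handle the GAT-style preprocessing: adj_to_bias turns a batched adjacency matrix into the additive attention mask GAT expects (0.0 for pairs reachable within nhood hops, self-loops included, and -1e9 everywhere else), while preprocess_features row-normalizes the feature matrix and also returns its sparse tuple form. A minimal usage sketch on a made-up two-node graph (the values below are illustrative, not data from the repo):

```python
# Illustrative only: a made-up 2-node graph with one undirected edge.
import numpy as np
import torch

from src.process import adj_to_bias, preprocess_features

adj = np.array([[[0., 1.],
                 [1., 0.]]])            # shape (1 graph, 2 nodes, 2 nodes)
bias = adj_to_bias(adj, sizes=[2], nhood=1)
# bias[0][i][j] is 0.0 when j is reachable from i within one hop (self-loops count)
# and -1e9 otherwise; added to attention logits it masks out non-neighbours.

feats = torch.tensor([[1., 2.],
                      [3., 1.]])        # preprocess_features calls .numpy() on its input
dense, (coords, values, shape) = preprocess_features(feats)
# every row of `dense` now sums to 1; (coords, values, shape) is the sparse tuple form.
```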
src/sample.py · +30 -34

@@ -45,39 +45,34 @@ class Sampler:
         #print(type(self.train_adj))
         self.degree_p = None

-    def _preprocess_adj(self, normalization, adj, cuda):
+    def _preprocess_adj(self, normalization, adj, device):
         adj_normalizer = fetch_normalization(normalization)
         r_adj = adj_normalizer(adj)
         r_adj = sparse_mx_to_torch_sparse_tensor(r_adj).float()
-        if cuda:
-            r_adj = r_adj.cuda()
-        return r_adj
+        return r_adj.to(device=device)

-    def _preprocess_fea(self, fea, cuda):
-        if cuda:
-            return fea.cuda()
-        else:
-            return fea
+    def _preprocess_fea(self, fea, device):
+        return fea.to(device=device)

-    def stub_sampler(self, normalization, cuda):
+    def stub_sampler(self, normalization, device):
         """
         The stub sampler. Return the original data.
         """
         if normalization in self.trainadj_cache:
             r_adj = self.trainadj_cache[normalization]
         else:
-            r_adj = self._preprocess_adj(normalization, self.train_adj, cuda)
+            r_adj = self._preprocess_adj(normalization, self.train_adj, device)
             self.trainadj_cache[normalization] = r_adj
-        fea = self._preprocess_fea(self.train_features, cuda)
+        fea = self._preprocess_fea(self.train_features, device)
         return r_adj, fea

-    def randomedge_sampler(self, percent, normalization, cuda):
+    def randomedge_sampler(self, percent, normalization, device):
         """
         Randomly drop edge and preserve percent% edges.
         """
         "Opt here"
         if percent >= 1.0:
-            return self.stub_sampler(normalization, cuda)
+            return self.stub_sampler(normalization, device)

         nnz = self.train_adj.nnz
         perm = np.random.permutation(nnz)

@@ -87,16 +82,16 @@ class Sampler:
             (self.train_adj.row[perm], self.train_adj.col[perm])), shape=self.train_adj.shape)
-        r_adj = self._preprocess_adj(normalization, r_adj, cuda)
-        fea = self._preprocess_fea(self.train_features, cuda)
+        r_adj = self._preprocess_adj(normalization, r_adj, device)
+        fea = self._preprocess_fea(self.train_features, device)
         return r_adj, fea

-    def vertex_sampler(self, percent, normalization, cuda):
+    def vertex_sampler(self, percent, normalization, device):
         """
         Randomly drop vertexes.
         """
         if percent >= 1.0:
-            return self.stub_sampler(normalization, cuda)
+            return self.stub_sampler(normalization, device)
         self.learning_type = "inductive"
         pos_nnz = len(self.pos_train_idx)
         # neg_neighbor_nnz = 0.4 * percent

@@ -117,16 +112,16 @@ class Sampler:
         # print(r_fea.shape)
         # print(r_adj.shape)
         # print(len(all_samples))
-        r_adj = self._preprocess_adj(normalization, r_adj, cuda)
-        r_fea = self._preprocess_fea(r_fea, cuda)
+        r_adj = self._preprocess_adj(normalization, r_adj, device)
+        r_fea = self._preprocess_fea(r_fea, device)
         return r_adj, r_fea, all_samples

-    def degree_sampler(self, percent, normalization, cuda):
+    def degree_sampler(self, percent, normalization, device):
         """
         Randomly drop edge wrt degree (high degree, low probility).
         """
         if percent >= 0:
-            return self.stub_sampler(normalization, cuda)
+            return self.stub_sampler(normalization, device)
         if self.degree_p is None:
             degree_adj = self.train_adj.multiply(self.degree)
             self.degree_p = degree_adj.data / (1.0 * np.sum(degree_adj.data))

@@ -138,37 +133,38 @@ class Sampler:
             (self.train_adj.row[perm], self.train_adj.col[perm])), shape=self.train_adj.shape)
-        r_adj = self._preprocess_adj(normalization, r_adj, cuda)
-        fea = self._preprocess_fea(self.train_features, cuda)
+        r_adj = self._preprocess_adj(normalization, r_adj, device)
+        fea = self._preprocess_fea(self.train_features, device)
         return r_adj, fea

-    def get_test_set(self, normalization, cuda):
+    def get_test_set(self, normalization, device):
         """
         Return the test set.
         """
         if self.learning_type == "transductive":
-            return self.stub_sampler(normalization, cuda)
+            return self.stub_sampler(normalization, device)
         else:
             if normalization in self.adj_cache:
                 r_adj = self.adj_cache[normalization]
             else:
-                r_adj = self._preprocess_adj(normalization, self.adj, cuda)
+                r_adj = self._preprocess_adj(normalization, self.adj, device)
                 self.adj_cache[normalization] = r_adj
-            fea = self._preprocess_fea(self.features, cuda)
+            fea = self._preprocess_fea(self.features, device)
             return r_adj, fea

-    def get_val_set(self, normalization, cuda):
+    def get_val_set(self, normalization, device):
         """
         Return the validataion set. Only for the inductive task.
         Currently behave the same with get_test_set
         """
-        return self.get_test_set(normalization, cuda)
+        return self.get_test_set(normalization, device)

-    def get_label_and_idxes(self, cuda):
+    def get_label_and_idxes(self, device):
         """
         Return all labels and indexes.
         """
-        if cuda:
-            return self.labels_torch.cuda(), self.idx_train_torch.cuda(), self.idx_val_torch.cuda(), self.idx_test_torch.cuda()
-        return self.labels_torch, self.idx_train_torch, self.idx_val_torch, self.idx_test_torch
+        return self.labels_torch.to(device=device), \
+               self.idx_train_torch.to(device=device), \
+               self.idx_val_torch.to(device=device), \
+               self.idx_test_torch.to(device=device)
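The whole diff in this file is one mechanical change: every sampler method now takes an explicit device argument instead of a boolean cuda flag, and placement is done with tensor.to(device=...). A rough sketch of how a caller might drive the new signatures; the dataset name, data path, task type, sampling percent, and normalization key below are placeholders, not values taken from this commit:

```python
# Hypothetical caller of the refactored Sampler API (placeholder arguments).
import torch
from src.sample import Sampler

device = 'cuda' if torch.cuda.is_available() else 'cpu'

sampler = Sampler("cora", "./data", "full")        # dataset, datapath, task_type
labels, idx_train, idx_val, idx_test = sampler.get_label_and_idxes(device=device)

# DropEdge step: keep 70% of training edges, normalize, and land tensors on `device`.
train_adj, train_fea = sampler.randomedge_sampler(percent=0.7,
                                                  normalization="AugNormAdj",
                                                  device=device)
val_adj, val_fea = sampler.get_test_set(normalization="AugNormAdj", device=device)
```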
src/train_new.py · +61 -27

@@ -7,6 +7,10 @@ import numpy as np
 import torch
 import torch.nn.functional as F
+from src.process import adj_to_bias, preprocess_features
+from src.utils import sparse_mx_to_torch_sparse_tensor
+import scipy.sparse as sp
 import torch.optim as optim
 from tensorboardX import SummaryWriter

@@ -25,7 +29,7 @@ from src.sample import Sampler
 # Training settings
 parser = argparse.ArgumentParser()
 # Training parameter
-parser.add_argument('--no_cuda', action='store_true', default=False,
+parser.add_argument('--no_cuda', action='store_true', default=True,
                     help='Disables CUDA training.')
 parser.add_argument('--fastmode', action='store_true', default=False,
                     help='Disable validation during training.')

@@ -83,7 +87,7 @@ args = parser.parse_args()
 if args.debug:
     print(args)
 # pre setting
-args.cuda = not args.no_cuda and torch.cuda.is_available()
+args.device = 'cuda' if (not args.no_cuda and torch.cuda.is_available()) else 'cpu'
 args.mixmode = args.no_cuda and args.mixmode and torch.cuda.is_available()
 if args.aggrmethod == "default":
     if args.type == "resgcn":

@@ -101,33 +105,46 @@ if args.type == "mutigcn":
 # random seed setting
 np.random.seed(args.seed)
 torch.manual_seed(args.seed)
-if args.cuda or args.mixmode:
+if args.device == 'cuda' or args.mixmode:
     torch.cuda.manual_seed(args.seed)
 # should we need fix random seed here?
 sampler = Sampler(args.dataset, args.datapath, args.task_type)

 # get labels and indexes
-labels, idx_train, idx_val, idx_test = sampler.get_label_and_idxes(args.cuda)
+labels, idx_train, idx_val, idx_test = sampler.get_label_and_idxes(device=args.device)
 nfeat = sampler.nfeat
 nclass = sampler.nclass
 print("nclass: %d\tnfea:%d" % (nclass, nfeat))
 # The model
-model = GCNModel(nfeat=nfeat,
-                 nhid=args.hidden,
-                 nclass=nclass,
-                 nhidlayer=args.nhiddenlayer,
-                 dropout=args.dropout,
-                 baseblock=args.type,
-                 inputlayer=args.inputlayer,
-                 outputlayer=args.outputlayer,
-                 nbaselayer=args.nbaseblocklayer,
-                 activation=F.relu,
-                 withbn=args.withbn,
-                 withloop=args.withloop,
-                 aggrmethod=args.aggrmethod,
-                 mixmode=args.mixmode)
+if args.type == "GAT":
+    model = GAT(in_size=nfeat,
+                hidden_sizes=[8],
+                nbclasses=nclass,
+                attn_drop=0.6,
+                p_drop=0.6,
+                n_heads=[8, 1],
+                activation=F.elu,
+                residual=False)
+    args.lr = 0.005
+    #TODO: add l2 regularization# lossL2 = tf.add_n([tf.nn.l2_loss(v) for v in vars if v.name not
+    #                             in ['bias', 'gamma', 'b', 'g', 'beta']]) * l2_coef
+else:
+    model = GCNModel(nfeat=nfeat,
+                     nhid=args.hidden,
+                     nclass=nclass,
+                     nhidlayer=args.nhiddenlayer,
+                     dropout=args.dropout,
+                     baseblock=args.type,
+                     inputlayer=args.inputlayer,
+                     outputlayer=args.outputlayer,
+                     nbaselayer=args.nbaseblocklayer,
+                     activation=F.relu,
+                     withbn=args.withbn,
+                     withloop=args.withloop,
+                     aggrmethod=args.aggrmethod,
+                     mixmode=args.mixmode)
 print(args.aggrmethod)
 print(model)
 optimizer = optim.Adam(model.parameters(),

@@ -136,11 +153,10 @@ optimizer = optim.Adam(model.parameters(),
 # scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=50, factor=0.618)
 scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[200, 300, 400, 500, 600, 700], gamma=0.5)
 # convert to cuda
-if args.cuda:
-    model.cuda()
+model = model.to(device=args.device)

 # For the mix mode, lables and indexes are in cuda.
-if args.cuda or args.mixmode:
+if args.device == 'cuda' or args.mixmode:
     labels = labels.cuda()
     idx_train = idx_train.cuda()
     idx_val = idx_val.cuda()

@@ -167,7 +183,6 @@ def get_lr(optimizer):
         return param_group['lr']

-# define the training function.
 def train(epoch, train_adj, train_fea, idx_train, val_adj=None, val_fea=None):
     if val_adj is None:
         val_adj = train_adj

@@ -240,13 +255,28 @@ acc_val = np.zeros((args.epochs,))
 sampling_t = 0

+def gat_preprocess(adj, feature):
+    feature, _ = preprocess_features(feature)
+    feature = torch.Tensor(np.expand_dims(feature, 0)).to(args.device)
+    nb_nodes = feature.shape[1]
+
+    adj = adj.to_dense().numpy()
+    adj = sp.coo_matrix(adj_to_bias(np.expand_dims(adj, 0), [nb_nodes], nhood=1)[0])
+    adj = torch.unsqueeze(sparse_mx_to_torch_sparse_tensor(adj), dim=0).to(device=args.device)
+    return adj, feature
+
 for epoch in range(args.epochs):
     input_idx_train = idx_train
     sampling_t = time.time()
     # no sampling
     # randomedge sampling if args.sampling_percent >= 1.0, it behaves the same as stub_sampler.
     (train_adj, train_fea) = sampler.randomedge_sampler(percent=args.sampling_percent, normalization=args.normalization,
-                                                        cuda=args.cuda)
+                                                        device="cpu")
+    if args.type == 'GAT':
+        train_adj, train_fea = gat_preprocess(train_adj, train_fea)

     if args.mixmode:
         train_adj = train_adj.cuda()

@@ -257,7 +287,9 @@ for epoch in range(args.epochs):
     if False:
         outputs = train(epoch, train_adj, train_fea, input_idx_train)
     else:
-        (val_adj, val_fea) = sampler.get_test_set(normalization=args.normalization, cuda=args.cuda)
+        (val_adj, val_fea) = sampler.get_test_set(normalization=args.normalization, device="cpu")
+        if args.type == 'GAT':
+            val_adj, val_fea = gat_preprocess(val_adj, val_fea)
         if args.mixmode:
             val_adj = val_adj.cuda()
         outputs = train(epoch, train_adj, train_fea, input_idx_train, val_adj, val_fea)

@@ -279,8 +311,7 @@ for epoch in range(args.epochs):
         tb_writer.add_scalar('lr', outputs[4], epoch)
         tb_writer.add_scalars('Time', {'train': outputs[5], 'val': outputs[6]}, epoch)

-    loss_train[epoch], acc_train[epoch], loss_val[epoch], acc_val[epoch] = outputs[0], outputs[1], outputs[2], outputs[3]
+    loss_train[epoch], acc_train[epoch], loss_val[epoch], acc_val[epoch] = outputs[:4]

     if args.early_stopping > 0 and early_stopping.early_stop:
         print("Early stopping.")

@@ -295,7 +326,10 @@ if args.debug:
     print("Total time elapsed: {:.4f}s".format(time.time() - t_total))

 # Testing
-(test_adj, test_fea) = sampler.get_test_set(normalization=args.normalization, cuda=args.cuda)
+(test_adj, test_fea) = sampler.get_test_set(normalization=args.normalization, device="cpu")
+if args.type == 'GAT':
+    test_adj, test_fea = gat_preprocess(test_adj, test_fea)
 if args.mixmode:
     test_adj = test_adj.cuda()
 (loss_test, acc_test) = test(test_adj, test_fea)
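The reason gat_preprocess routes the adjacency through adj_to_bias is the masked-softmax trick: GAT adds the bias to the raw attention logits, so entries set to -1e9 get essentially zero weight after the softmax while neighbours (bias 0.0) compete normally. A self-contained toy illustration of that effect; the logits below are invented for illustration:

```python
# Toy masked-softmax check; the logits are made up, not taken from the model.
import torch
import torch.nn.functional as F

logits = torch.tensor([[0.5, 1.2, -0.3]])   # raw attention scores over 3 candidate neighbours
bias   = torch.tensor([[0.0, -1e9, 0.0]])   # the middle node is not a neighbour
attn = F.softmax(logits + bias, dim=-1)
print(attn)   # ~[0.69, 0.00, 0.31]: the masked entry contributes nothing
```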