Unverified Commit 6267e939 authored by xuty-007's avatar xuty-007 Committed by GitHub
Browse files

Merge pull request #4 from MihailSalnikov/master

Added tensorboard
parents 41bacc98 f4192a86
...@@ -20,6 +20,7 @@ requests==2.14.2 ...@@ -20,6 +20,7 @@ requests==2.14.2
scikit-learn==0.20.3 scikit-learn==0.20.3
scipy==1.2.1 scipy==1.2.1
six==1.10.0 six==1.10.0
torch==0.4.0 torch>=0.4.0,<0.5.0
torchvision==0.2.1 torchvision==0.2.1
tensorboard==2.1.0
tensorboardX==2.0
...@@ -8,6 +8,7 @@ import numpy as np ...@@ -8,6 +8,7 @@ import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import torch.optim as optim import torch.optim as optim
from tensorboardX import SummaryWriter
from earlystopping import EarlyStopping from earlystopping import EarlyStopping
from sample import Sampler from sample import Sampler
...@@ -47,6 +48,7 @@ parser.add_argument('--dataset', default="cora", help="The data set") ...@@ -47,6 +48,7 @@ parser.add_argument('--dataset', default="cora", help="The data set")
parser.add_argument('--datapath', default="data/", help="The data path.") parser.add_argument('--datapath', default="data/", help="The data path.")
parser.add_argument("--early_stopping", type=int, parser.add_argument("--early_stopping", type=int,
default=0, help="The patience of earlystopping. Do not adopt the earlystopping when it equals 0.") default=0, help="The patience of earlystopping. Do not adopt the earlystopping when it equals 0.")
parser.add_argument("--no_tensorboard", default=False, help="Disable writing logs to tensorboard")
# Model parameter # Model parameter
parser.add_argument('--type', parser.add_argument('--type',
...@@ -152,6 +154,10 @@ if args.early_stopping > 0: ...@@ -152,6 +154,10 @@ if args.early_stopping > 0:
early_stopping = EarlyStopping(patience=args.early_stopping, verbose=False) early_stopping = EarlyStopping(patience=args.early_stopping, verbose=False)
print("Model is saving to: %s" % (early_stopping.fname)) print("Model is saving to: %s" % (early_stopping.fname))
if args.no_tensorboard is False:
tb_writer = SummaryWriter(
comment=f"-dataset_{args.dataset}-type_{args.type}"
)
def get_lr(optimizer): def get_lr(optimizer):
for param_group in optimizer.param_groups: for param_group in optimizer.param_groups:
...@@ -263,6 +269,13 @@ for epoch in range(args.epochs): ...@@ -263,6 +269,13 @@ for epoch in range(args.epochs):
's_time: {:.4f}s'.format(sampling_t), 's_time: {:.4f}s'.format(sampling_t),
't_time: {:.4f}s'.format(outputs[5]), 't_time: {:.4f}s'.format(outputs[5]),
'v_time: {:.4f}s'.format(outputs[6])) 'v_time: {:.4f}s'.format(outputs[6]))
if args.no_tensorboard is False:
tb_writer.add_scalars('Loss', {'train': outputs[0], 'val': outputs[2]}, epoch)
tb_writer.add_scalars('Accuracy', {'train': outputs[1], 'val': outputs[3]}, epoch)
tb_writer.add_scalar('lr', outputs[4], epoch)
tb_writer.add_scalars('Time', {'train': outputs[5], 'val': outputs[6]}, epoch)
loss_train[epoch], acc_train[epoch], loss_val[epoch], acc_val[epoch] = outputs[0], outputs[1], outputs[2], outputs[ loss_train[epoch], acc_train[epoch], loss_val[epoch], acc_val[epoch] = outputs[0], outputs[1], outputs[2], outputs[
3] 3]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment