Skip to content
GitLab
Projects
Groups
Snippets
Help
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Open sidebar
Vibhav Oswal
LDS-GNN
Commits
c244a5aa
Commit
c244a5aa
authored
May 29, 2021
by
Vibhav Oswal
Browse files
changes bernoulli
parent
7e3bbee1
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
14 additions
and
8 deletions
+14
-8
lds_gnn/.nfs0000000e0034b68c000001ec
lds_gnn/.nfs0000000e0034b68c000001ec
+0
-0
lds_gnn/lds.py
lds_gnn/lds.py
+14
-8
No files found.
lds_gnn/.nfs0000000e0034b68c000001ec
0 → 100644
View file @
c244a5aa
File added
lds_gnn/lds.py
View file @
c244a5aa
...
...
@@ -152,7 +152,7 @@
import
gcn.metrics
from
sklearn.neighbors
import
kneighbors_graph
import
tensorflow_probability
as
tfp
from
tfp.distributions
import
Bernoulli
,
Normal
#
from tfp.distributions import Bernoulli, Normal
try
:
from
lds_gnn.data
import
ConfigData
,
UCI
,
EdgeDelConfigData
...
...
@@ -271,10 +271,10 @@ def noisy_features(features, mask, noise, ogb):
else
:
print
(
"Invalid noise input encountered while adding noise to features"
)
else
:
print
(
features
)
print
(
features
.
shape
)
print
(
mask
)
print
(
mask
.
shape
)
#
print(features)
#
print(features.shape)
#
print(mask)
#
print(
type(mask)
)
masked_features
=
features
*
(
1
-
mask
)
return
masked_features
...
...
@@ -302,7 +302,6 @@ def lds(data_conf: ConfigData, config: LDSConfig):
#nones = torch.sum(features > 0.0).float()
nones
=
tf
.
math
.
reduce_sum
(
tf
.
cast
(
features
>
0.0
,
dtype
=
tf
.
float32
))
print
(
features
.
shape
)
nzeros
=
features
.
shape
[
0
]
*
features
.
shape
[
1
]
-
nones
pzeros
=
nones
/
nzeros
/
r
*
nr
...
...
@@ -317,8 +316,15 @@ def lds(data_conf: ConfigData, config: LDSConfig):
#probs[features > 0.0] = 1 / r
b_broadcast
=
tf
.
ones
(
tf
.
shape
(
features
),
dtype
=
features
.
dtype
)
*
(
1
/
r
)
features
=
tf
.
where
(
tf
.
greater
(
features
,
0.0
),
b_broadcast
,
features
)
mask
=
Bernoulli
(
probs
)
print
(
"the shape of probs, temp are:"
)
mask
=
tfp
.
distributions
.
Bernoulli
(
probs
).
sample
()
print
(
mask
.
shape
)
#mask = Bernoulli(probs)
mask
=
tf
.
cast
(
mask
,
dtype
=
tf
.
float32
)
print
(
mask
.
shape
)
#mask = torch.bernoulli(probs)
ogb
=
False
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment