Commit c244a5aa authored by Vibhav Oswal's avatar Vibhav Oswal
Browse files

changes bernoulli

parent 7e3bbee1
......@@ -152,7 +152,7 @@
import gcn.metrics
from sklearn.neighbors import kneighbors_graph
import tensorflow_probability as tfp
from tfp.distributions import Bernoulli, Normal
#from tfp.distributions import Bernoulli, Normal
try:
from lds_gnn.data import ConfigData, UCI, EdgeDelConfigData
......@@ -271,10 +271,10 @@ def noisy_features(features, mask, noise, ogb):
else:
print("Invalid noise input encountered while adding noise to features")
else:
print(features)
print(features.shape)
print(mask)
print(mask.shape)
#print(features)
#print(features.shape)
#print(mask)
#print(type(mask))
masked_features = features * (1 - mask)
return masked_features
......@@ -302,7 +302,6 @@ def lds(data_conf: ConfigData, config: LDSConfig):
#nones = torch.sum(features > 0.0).float()
nones = tf.math.reduce_sum(tf.cast(features > 0.0, dtype=tf.float32))
print(features.shape)
nzeros = features.shape[0] * features.shape[1] - nones
pzeros = nones / nzeros / r * nr
......@@ -317,8 +316,15 @@ def lds(data_conf: ConfigData, config: LDSConfig):
#probs[features > 0.0] = 1 / r
b_broadcast = tf.ones(tf.shape(features), dtype=features.dtype) * (1 / r)
features = tf.where(tf.greater(features, 0.0), b_broadcast, features)
mask = Bernoulli(probs)
print("the shape of probs, temp are:")
mask = tfp.distributions.Bernoulli(probs).sample()
print(mask.shape)
#mask = Bernoulli(probs)
mask = tf.cast(mask, dtype=tf.float32)
print(mask.shape)
#mask = torch.bernoulli(probs)
ogb = False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment