Edward examples

Setup

%matplotlib inline
%config InlineBackend.figure_formats = set(['svg'])

import edward as ed
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.special as sp
import tensorflow as tf

Learn a directed Bayesian network

# Sprinkler network: cloudy -> {rain, sprinkler} -> wet. Each node is one
# uniform draw per observation compared against its conditional probability.
N = 1000
np.random.seed(0)
cloudy = np.random.uniform(size=N) < .5
# P(rain | cloudy) = .5, P(rain | not cloudy) = .6
rain = np.random.uniform(size=N) < np.where(cloudy, .5, .6)
# P(sprinkler | cloudy) = .8, P(sprinkler | not cloudy) = .3
sprinkler = np.random.uniform(size=N) < np.where(cloudy, .8, .3)
# P(wet) is .99 when it rains AND the sprinkler runs, and .1 in every other case.
wet = np.random.uniform(size=N) < np.where(rain & sprinkler, .99, .1)
# Keep only the cloudy and wet columns, as 0/1 integers, one row per observation.
data = np.column_stack((cloudy, wet)).astype(int)

MLE in Edward

\[ x_i \sim N(\mu, 1) \]

# Maximum-likelihood estimate of mu for x_i ~ N(mu, 1), N = 100 simulated points.
N = 100
np.random.seed(0)
mu = np.random.normal()
x = np.random.normal(loc=mu, scale=1, size=(N, 1)).astype(np.float32)

# The location is a single trainable tf.Variable broadcast to the (N, 1) data
# shape; no latent_vars are passed, so presumably KLqp's objective reduces to
# the data log-likelihood (confirm against the Edward KLqp docs).
px = ed.models.Normal(loc=tf.Variable(tf.zeros([1])) * tf.ones([N, 1]), scale=tf.ones([1]))
inf = ed.KLqp(data={px: x})
inf.run()
# Print the true mu, the sample mean (the closed-form MLE), and the fitted location.
print(mu, x.mean(), ed.get_session().run(px.mean()[0, 0]))

1.764052345967664 1.82505 1.82505

Mixture of regressions

\[ y_i \mid z_i \sim N\left(x_i\, z_i^\top \beta + z_i^\top \mu,\ (z_i^\top \sigma)^2\right), \qquad z_i \in \{0, 1\}^K \text{ one-hot} \]

# Mixture of K=2 linear regressions: simulate data, then alternate an E-step
# (KLqp over the one-hot assignments z) with an M-step (MAP over beta, mu, sigma).
N = 500
np.random.seed(0)
x = np.random.normal(size=N)
z = np.random.uniform(size=N) < .7

# Component parameters: z=True -> (beta, mu, sigma) = (1, 3, .2); else (5, 1, .3).
beta = np.where(z, 1, 5)
mu = np.where(z, 3, 1)
sigma = np.where(z, .2, .3)

y = x * beta + mu + np.random.normal(scale=sigma, size=N)
K = 2
x_ph = tf.placeholder(tf.float32, [N, 1])

# Model. Multiplying the one-hot z by a (K, 1) parameter selects the active
# component's value for each observation.
p_z = ed.models.OneHotCategorical(logits=tf.ones([N, K]))
p_beta = ed.models.Normal(loc=tf.zeros([K, 1]), scale=tf.ones([K, 1]))
p_mu = ed.models.Normal(loc=tf.zeros([K, 1]), scale=tf.ones([K, 1]))
p_sigma = ed.models.InverseGamma(concentration=tf.ones([K, 1]), rate=tf.ones([K, 1]))
z_float = tf.cast(p_z, tf.float32)
p_y = ed.models.Normal(loc=x_ph * tf.matmul(z_float, p_beta) + tf.matmul(z_float, p_mu),
                       scale=tf.matmul(z_float, p_sigma))

# Approximations: a free categorical q(z) per observation and point estimates
# for the regression parameters (softplus keeps sigma positive).
q_z = ed.models.OneHotCategorical(logits=tf.Variable(tf.ones([N, K])))
q_beta = ed.models.PointMass(params=tf.Variable(tf.zeros([K, 1])))
q_mu = ed.models.PointMass(params=tf.Variable(tf.zeros([K, 1])))
q_sigma = ed.models.PointMass(params=tf.nn.softplus(tf.Variable(tf.ones([K, 1]))))

# Observed data, cast once to float32 to match x_ph and p_y.
x_feed = x.astype(np.float32).reshape(-1, 1)
y_feed = y.astype(np.float32).reshape(-1, 1)

# E-step: condition on the current point estimates of beta, mu, sigma and fit
# q(z). Edward's latent_vars maps each prior to its approximation: {p_z: q_z}.
E = ed.KLqp(
  data={
    x_ph: x_feed,
    p_y: y_feed,
    p_beta: q_beta,
    p_mu: q_mu,
    p_sigma: q_sigma,
  },
  latent_vars={
    p_z: q_z,
  })
# M-step: condition on q(z) and take MAP estimates of beta, mu, sigma.
M = ed.MAP(
  data={
    x_ph: x_feed,
    p_y: y_feed,
    p_z: q_z,
  },
  latent_vars={
    p_beta: q_beta,
    p_mu: q_mu,
    p_sigma: q_sigma,
  })
E.initialize()
M.initialize()
ed.get_session().run(tf.global_variables_initializer())
# One EM sweep: 100 E-step updates followed by a single M-step update.
for i in range(1):
  for _ in range(100):
    res0 = E.update()
  res = M.update()
  print(res0['loss'], res['loss'])
# Fitted assignment probabilities and per-component (beta, mu, sigma).
ed.get_session().run(
  [q_z.probs,
   tf.reshape(q_beta, [-1]),
   tf.reshape(q_mu, [-1]),
   tf.reshape(q_sigma, [-1])])
[array([[ 0.71002579,  0.28997418],
[ 0.48043731,  0.51956272],
[ 0.38829982,  0.61170018],
[ 0.45394057,  0.54605943],
[ 0.60275078,  0.39724925],
[ 0.56599832,  0.43400165],
[ 0.42639849,  0.57360154],
[ 0.58124381,  0.41875616],
[ 0.55708236,  0.44291764],
[ 0.39320481,  0.60679525],
[ 0.45562571,  0.54437423],
[ 0.53851324,  0.46148679],
[ 0.43844658,  0.56155342],
[ 0.46309105,  0.53690892],
[ 0.4733164 ,  0.52668363],
[ 0.38441676,  0.61558324],
[ 0.63565713,  0.3643429 ],
[ 0.59845626,  0.40154377],
[ 0.40672451,  0.59327549],
[ 0.42584515,  0.57415485],
[ 0.38219059,  0.61780947],
[ 0.2090683 ,  0.7909317 ],
[ 0.57429117,  0.4257088 ],
[ 0.57058483,  0.4294152 ],
[ 0.61770642,  0.38229361],
[ 0.51058275,  0.48941731],
[ 0.56709599,  0.43290398],
[ 0.41292256,  0.58707744],
[ 0.50020838,  0.49979153],
[ 0.57302105,  0.42697892],
[ 0.51298302,  0.48701701],
[ 0.66017324,  0.33982676],
[ 0.54463935,  0.45536068],
[ 0.51573318,  0.48426682],
[ 0.32967111,  0.67032892],
[ 0.29039744,  0.70960259],
[ 0.50654632,  0.49345359],
[ 0.53788632,  0.46211365],
[ 0.61870939,  0.38129056],
[ 0.48693788,  0.51306206],
[ 0.53016818,  0.46983185],
[ 0.48576871,  0.51423126],
[ 0.44574535,  0.55425465],
[ 0.48011023,  0.51988977],
[ 0.4652572 ,  0.53474277],
[ 0.54088908,  0.45911092],
[ 0.5020507 ,  0.49794933],
[ 0.38352886,  0.61647111],
[ 0.47106972,  0.52893025],
[ 0.15485103,  0.84514898],
[ 0.4664011 ,  0.53359896],
[ 0.30208117,  0.69791883],
[ 0.50660372,  0.49339628],
[ 0.47084692,  0.52915305],
[ 0.64215273,  0.3578473 ],
[ 0.57530797,  0.42469206],
[ 0.57450587,  0.42549413],
[ 0.51293218,  0.48706779],
[ 0.54589957,  0.45410049],
[ 0.42113671,  0.57886332],
[ 0.66159713,  0.33840284],
[ 0.53784293,  0.46215701],
[ 0.45368984,  0.54631019],
[ 0.60997987,  0.3900201 ],
[ 0.65123248,  0.34876749],
[ 0.37624544,  0.62375456],
[ 0.53248215,  0.46751788],
[ 0.40773889,  0.59226114],
[ 0.29519737,  0.70480263],
[ 0.61828786,  0.38171217],
[ 0.56785941,  0.43214053],
[ 0.43748084,  0.56251913],
[ 0.61715895,  0.38284105],
[ 0.61459643,  0.38540357],
[ 0.43189934,  0.56810063],
[ 0.44795257,  0.55204749],
[ 0.3988784 ,  0.6011216 ],
[ 0.59767079,  0.40232918],
[ 0.53826112,  0.46173885],
[ 0.4576849 ,  0.54231513],
[ 0.42446387,  0.57553613],
[ 0.39003873,  0.60996127],
[ 0.38686582,  0.61313421],
[ 0.36126259,  0.63873738],
[ 0.5862332 ,  0.41376686],
[ 0.52707517,  0.47292483],
[ 0.60705346,  0.39294654],
[ 0.47495189,  0.52504814],
[ 0.32205719,  0.67794281],
[ 0.47547859,  0.52452147],
[ 0.51039815,  0.48960185],
[ 0.38047826,  0.61952174],
[ 0.30331373,  0.69668621],
[ 0.64933407,  0.35066593],
[ 0.58514827,  0.4148517 ],
[ 0.58274013,  0.4172599 ],
[ 0.33704156,  0.66295844],
[ 0.51428324,  0.48571673],
[ 0.33593464,  0.66406536],
[ 0.65539223,  0.34460768],
[ 0.54574114,  0.45425886],
[ 0.58412308,  0.4158769 ],
[ 0.53556174,  0.46443826],
[ 0.46937943,  0.53062057],
[ 0.61514378,  0.38485622],
[ 0.41186836,  0.58813161],
[ 0.47821105,  0.52178895],
[ 0.55419874,  0.44580123],
[ 0.44436356,  0.55563647],
[ 0.45664942,  0.54335058],
[ 0.4024387 ,  0.5975613 ],
[ 0.48182285,  0.51817721],
[ 0.44133377,  0.55866629],
[ 0.43799591,  0.56200403],
[ 0.39235443,  0.60764557],
[ 0.31296471,  0.68703526],
[ 0.34281242,  0.65718758],
[ 0.43191525,  0.56808472],
[ 0.30023715,  0.69976282],
[ 0.44088694,  0.55911309],
[ 0.52743107,  0.47256896],
[ 0.53777021,  0.46222973],
[ 0.75503856,  0.24496143],
[ 0.41529727,  0.58470267],
[ 0.5416109 ,  0.45838913],
[ 0.60855252,  0.39144745],
[ 0.56409019,  0.43590978],
[ 0.60207796,  0.39792207],
[ 0.43606761,  0.56393248],
[ 0.60557753,  0.39442244],
[ 0.63035196,  0.36964798],
[ 0.42812946,  0.57187057],
[ 0.40494242,  0.59505755],
[ 0.41584125,  0.58415878],
[ 0.52344775,  0.47655222],
[ 0.42652997,  0.57347006],
[ 0.61050302,  0.38949701],
[ 0.46699005,  0.53300995],
[ 0.54979461,  0.45020539],
[ 0.62898892,  0.37101102],
[ 0.61520046,  0.38479954],
[ 0.37691894,  0.62308109],
[ 0.59981608,  0.40018392],
[ 0.34312534,  0.65687466],
[ 0.39728737,  0.60271257],
[ 0.52067167,  0.4793283 ],
[ 0.5152083 ,  0.48479176],
[ 0.76850319,  0.2314968 ],
[ 0.5312919 ,  0.46870807],
[ 0.29801649,  0.70198351],
[ 0.51338738,  0.48661265],
[ 0.40474135,  0.59525859],
[ 0.5791117 ,  0.42088836],
[ 0.55248088,  0.44751909],
[ 0.51236618,  0.48763379],
[ 0.35830188,  0.64169812],
[ 0.47314191,  0.52685809],
[ 0.48809928,  0.51190078],
[ 0.4083834 ,  0.59161657],
[ 0.27255264,  0.72744727],
[ 0.5082413 ,  0.49175876],
[ 0.46433172,  0.53566825],
[ 0.49951553,  0.50048453],
[ 0.41685823,  0.58314174],
[ 0.40450916,  0.59549081],
[ 0.63181913,  0.36818093],
[ 0.5749262 ,  0.42507377],
[ 0.59248257,  0.40751746],
[ 0.51416624,  0.4858337 ],
[ 0.4571057 ,  0.5428943 ],
[ 0.60158944,  0.39841059],
[ 0.49728325,  0.50271678],
[ 0.44007769,  0.55992228],
[ 0.58551878,  0.41448122],
[ 0.46919104,  0.53080893],
[ 0.50912935,  0.49087059],
[ 0.40775898,  0.59224099],
[ 0.5088917 ,  0.49110833],
[ 0.38495445,  0.61504549],
[ 0.47667813,  0.52332181],
[ 0.63478494,  0.36521503],
[ 0.47245151,  0.52754849],
[ 0.35618833,  0.6438117 ],
[ 0.564062  ,  0.43593797],
[ 0.3959046 ,  0.6040954 ],
[ 0.65318727,  0.3468127 ],
[ 0.42544624,  0.57455379],
[ 0.33444902,  0.66555095],
[ 0.43941447,  0.56058556],
[ 0.49924177,  0.50075823],
[ 0.46403575,  0.53596431],
[ 0.58585149,  0.41414857],
[ 0.44646743,  0.55353248],
[ 0.54713833,  0.45286173],
[ 0.43269864,  0.56730133],
[ 0.36501101,  0.63498908],
[ 0.39428982,  0.60571021],
[ 0.76552308,  0.23447697],
[ 0.567922  ,  0.43207803],
[ 0.59970969,  0.40029037],
[ 0.58460778,  0.41539231],
[ 0.24526666,  0.75473338],
[ 0.26229888,  0.73770118],
[ 0.53816956,  0.46183044],
[ 0.41838986,  0.58161014],
[ 0.28172556,  0.71827441],
[ 0.73562717,  0.26437283],
[ 0.57271713,  0.42728293],
[ 0.67609668,  0.32390332],
[ 0.60700214,  0.39299786],
[ 0.54531962,  0.45468035],
[ 0.79055327,  0.2094468 ],
[ 0.50417173,  0.49582824],
[ 0.36504081,  0.63495922],
[ 0.67351651,  0.32648349],
[ 0.48000541,  0.51999462],
[ 0.54124218,  0.45875776],
[ 0.49603325,  0.50396675],
[ 0.52238524,  0.47761476],
[ 0.68409765,  0.31590229],
[ 0.64417231,  0.35582769],
[ 0.35448414,  0.64551586],
[ 0.42745626,  0.5725438 ],
[ 0.50095332,  0.49904668],
[ 0.63306636,  0.36693358],
[ 0.50398809,  0.49601182],
[ 0.4958204 ,  0.50417954],
[ 0.47330362,  0.52669644],
[ 0.70182741,  0.29817262],
[ 0.28064489,  0.71935511],
[ 0.2785818 ,  0.72141826],
[ 0.50518888,  0.49481112],
[ 0.55749619,  0.44250387],
[ 0.48519689,  0.51480311],
[ 0.50366759,  0.49633238],
[ 0.53342515,  0.46657488],
[ 0.62646013,  0.37353992],
[ 0.59017068,  0.40982935],
[ 0.51482952,  0.48517048],
[ 0.48711106,  0.51288891],
[ 0.68457484,  0.31542522],
[ 0.34068   ,  0.65932   ],
[ 0.43009904,  0.56990105],
[ 0.6076529 ,  0.3923471 ],
[ 0.42084491,  0.57915503],
[ 0.47324702,  0.52675301],
[ 0.37014487,  0.6298551 ],
[ 0.48246294,  0.51753706],
[ 0.5963726 ,  0.4036274 ],
[ 0.62080538,  0.37919465],
[ 0.50523442,  0.49476561],
[ 0.52521145,  0.47478855],
[ 0.59784442,  0.40215558],
[ 0.49823123,  0.50176883],
[ 0.6080004 ,  0.39199957],
[ 0.42829162,  0.57170838],
[ 0.67309278,  0.32690722],
[ 0.51378357,  0.48621637],
[ 0.63700259,  0.36299741],
[ 0.54894733,  0.45105267],
[ 0.43539476,  0.56460518],
[ 0.46971148,  0.53028852],
[ 0.64895523,  0.35104477],
[ 0.42190167,  0.57809836],
[ 0.49398458,  0.50601542],
[ 0.51754951,  0.48245054],
[ 0.49798828,  0.50201166],
[ 0.41089019,  0.58910984],
[ 0.27574155,  0.72425842],
[ 0.46945363,  0.53054637],
[ 0.38797656,  0.61202347],
[ 0.42296138,  0.57703859],
[ 0.41252366,  0.58747631],
[ 0.47734672,  0.52265328],
[ 0.36938927,  0.63061076],
[ 0.32669395,  0.67330605],
[ 0.60283309,  0.39716697],
[ 0.38714966,  0.61285037],
[ 0.50990999,  0.49009001],
[ 0.5869168 ,  0.4130832 ],
[ 0.497078  ,  0.50292206],
[ 0.49703607,  0.5029639 ],
[ 0.40899321,  0.59100676],
[ 0.65545309,  0.34454691],
[ 0.3524608 ,  0.64753926],
[ 0.31158391,  0.68841612],
[ 0.47008419,  0.52991581],
[ 0.42623898,  0.57376099],
[ 0.55362576,  0.44637421],
[ 0.68412167,  0.31587836],
[ 0.35418135,  0.64581865],
[ 0.25532928,  0.74467069],
[ 0.43706328,  0.56293672],
[ 0.66076267,  0.33923739],
[ 0.4919304 ,  0.50806957],
[ 0.45944887,  0.54055119],
[ 0.63495106,  0.36504894],
[ 0.67940694,  0.32059309],
[ 0.62442881,  0.37557113],
[ 0.51191801,  0.48808205],
[ 0.37476775,  0.62523222],
[ 0.74282789,  0.25717211],
[ 0.42770943,  0.57229054],
[ 0.43511936,  0.56488067],
[ 0.44826511,  0.55173486],
[ 0.31072369,  0.68927634],
[ 0.41287073,  0.58712924],
[ 0.59510809,  0.40489194],
[ 0.70356369,  0.29643634],
[ 0.37164   ,  0.62835997],
[ 0.3765277 ,  0.62347233],
[ 0.58470893,  0.41529107],
[ 0.59956676,  0.40043321],
[ 0.51642925,  0.48357072],
[ 0.55386239,  0.44613764],
[ 0.68319321,  0.31680682],
[ 0.23681407,  0.76318586],
[ 0.3577981 ,  0.6422019 ],
[ 0.55373985,  0.44626021],
[ 0.69725144,  0.3027485 ],
[ 0.55120802,  0.44879198],
[ 0.49777493,  0.50222498],
[ 0.4483172 ,  0.55168277],
[ 0.50391287,  0.49608716],
[ 0.41267136,  0.58732873],
[ 0.50214213,  0.49785787],
[ 0.63410002,  0.36590004],
[ 0.33282155,  0.66717845],
[ 0.30768725,  0.69231272],
[ 0.49322575,  0.50677419],
[ 0.35773757,  0.64226246],
[ 0.54091048,  0.45908952],
[ 0.7552458 ,  0.2447542 ],
[ 0.5050723 ,  0.49492779],
[ 0.45320013,  0.5467999 ],
[ 0.32543126,  0.67456871],
[ 0.35045615,  0.64954382],
[ 0.42995608,  0.57004386],
[ 0.44718689,  0.55281317],
[ 0.42096293,  0.57903707],
[ 0.37363988,  0.62636012],
[ 0.43621966,  0.56378037],
[ 0.3766996 ,  0.62330037],
[ 0.65387052,  0.34612948],
[ 0.51497632,  0.48502368],
[ 0.50727457,  0.49272546],
[ 0.51152873,  0.48847124],
[ 0.45450982,  0.54549009],
[ 0.51226872,  0.48773128],
[ 0.39616162,  0.60383838],
[ 0.49920642,  0.50079358],
[ 0.55853719,  0.44146281],
[ 0.55030149,  0.44969845],
[ 0.31780696,  0.68219304],
[ 0.42395264,  0.5760473 ],
[ 0.52481049,  0.47518945],
[ 0.39745817,  0.60254186],
[ 0.44410568,  0.55589437],
[ 0.57883829,  0.42116168],
[ 0.57630509,  0.42369485],
[ 0.4114745 ,  0.58852559],
[ 0.27175623,  0.72824377],
[ 0.28102168,  0.71897835],
[ 0.57849431,  0.42150572],
[ 0.49072948,  0.50927049],
[ 0.53144264,  0.46855742],
[ 0.63396198,  0.36603802],
[ 0.57280034,  0.42719963],
[ 0.49170384,  0.50829613],
[ 0.53744221,  0.46255773],
[ 0.37506071,  0.6249392 ],
[ 0.6688059 ,  0.33119407],
[ 0.51032859,  0.48967141],
[ 0.52401632,  0.47598365],
[ 0.54270416,  0.45729586],
[ 0.60252625,  0.39747372],
[ 0.34316862,  0.65683138],
[ 0.49852461,  0.50147545],
[ 0.43459946,  0.5654006 ],
[ 0.54910624,  0.45089373],
[ 0.58483511,  0.41516492],
[ 0.37603599,  0.62396401],
[ 0.38383642,  0.61616361],
[ 0.46504709,  0.53495294],
[ 0.49452743,  0.50547254],
[ 0.45659745,  0.54340255],
[ 0.6947999 ,  0.30520016],
[ 0.34478918,  0.65521085],
[ 0.76010448,  0.23989552],
[ 0.52163577,  0.47836426],
[ 0.49834624,  0.50165367],
[ 0.57489723,  0.4251028 ],
[ 0.70049012,  0.29950988],
[ 0.49580365,  0.50419629],
[ 0.48210365,  0.51789629],
[ 0.43704611,  0.56295389],
[ 0.24105078,  0.75894922],
[ 0.46682742,  0.53317261],
[ 0.40607652,  0.59392345],
[ 0.6035645 ,  0.3964355 ],
[ 0.50269431,  0.49730566],
[ 0.40303802,  0.59696203],
[ 0.46847877,  0.53152126],
[ 0.50786275,  0.49213725],
[ 0.42432216,  0.57567793],
[ 0.45127711,  0.54872292],
[ 0.57181513,  0.42818487],
[ 0.42151466,  0.57848537],
[ 0.43607289,  0.56392717],
[ 0.45397151,  0.54602844],
[ 0.43076172,  0.56923831],
[ 0.68091029,  0.31908965],
[ 0.48943397,  0.51056612],
[ 0.55664223,  0.44335777],
[ 0.72561008,  0.27438992],
[ 0.47178993,  0.52821004],
[ 0.51078868,  0.48921132],
[ 0.45309776,  0.54690224],
[ 0.4728457 ,  0.52715427],
[ 0.47694656,  0.52305341],
[ 0.50450379,  0.4954963 ],
[ 0.55760914,  0.44239086],
[ 0.46785116,  0.53214884],
[ 0.50984597,  0.490154  ],
[ 0.41991714,  0.58008289],
[ 0.4738881 ,  0.52611196],
[ 0.42884421,  0.57115573],
[ 0.67225224,  0.32774773],
[ 0.45308289,  0.5469172 ],
[ 0.55035019,  0.44964984],
[ 0.44749242,  0.55250758],
[ 0.45291388,  0.54708612],
[ 0.1964258 ,  0.80357426],
[ 0.55424738,  0.44575259],
[ 0.43516347,  0.5648365 ],
[ 0.59923226,  0.40076771],
[ 0.51310593,  0.4868941 ],
[ 0.625319  ,  0.37468103],
[ 0.37344772,  0.62655228],
[ 0.44075784,  0.55924207],
[ 0.34691727,  0.65308279],
[ 0.33231381,  0.66768616],
[ 0.45853186,  0.54146808],
[ 0.58474028,  0.41525975],
[ 0.46883255,  0.53116739],
[ 0.584396  ,  0.41560403],
[ 0.42681029,  0.57318974],
[ 0.55868268,  0.44131738],
[ 0.65386695,  0.34613308],
[ 0.28968138,  0.71031857],
[ 0.65495002,  0.34504998],
[ 0.686454  ,  0.313546  ],
[ 0.46859944,  0.5314005 ],
[ 0.44909573,  0.55090433],
[ 0.63541949,  0.36458051],
[ 0.30905333,  0.69094664],
[ 0.51697308,  0.48302689],
[ 0.66233909,  0.33766091],
[ 0.6001389 ,  0.39986119],
[ 0.40635231,  0.59364772],
[ 0.46703085,  0.53296912],
[ 0.40248233,  0.59751767],
[ 0.42627519,  0.57372481],
[ 0.34972823,  0.65027183],
[ 0.51255959,  0.48744044],
[ 0.57648027,  0.42351976],
[ 0.53118485,  0.46881515],
[ 0.44005758,  0.55994248],
[ 0.41567108,  0.58432889],
[ 0.37844101,  0.62155902],
[ 0.62293488,  0.37706509],
[ 0.55375665,  0.44624338],
[ 0.49558094,  0.50441915],
[ 0.46740434,  0.53259563],
[ 0.45204943,  0.54795063],
[ 0.42922086,  0.57077909],
[ 0.45012459,  0.54987544],
[ 0.52738768,  0.47261229],
[ 0.68961829,  0.31038171],
[ 0.42298055,  0.57701945],
[ 0.49630976,  0.50369024],
[ 0.31605911,  0.68394089],
[ 0.66076434,  0.33923566],
[ 0.48949784,  0.51050222],
[ 0.64120275,  0.35879728],
[ 0.42200032,  0.57799971],
[ 0.48909491,  0.51090509],
[ 0.36553371,  0.63446629],
[ 0.49850166,  0.50149834],
[ 0.51601541,  0.48398453],
[ 0.36317748,  0.63682252],
[ 0.40290901,  0.59709102],
[ 0.58002025,  0.41997972],
[ 0.68031371,  0.31968635],
[ 0.51672632,  0.48327366],
[ 0.29602346,  0.70397657],
[ 0.42756143,  0.57243854],
[ 0.48997787,  0.5100221 ],
[ 0.33808461,  0.66191542],
[ 0.43681616,  0.56318378]], dtype=float32),
array([ 1.93082905,  2.06234193], dtype=float32),
array([ 2.3322289,  2.3039763], dtype=float32),
array([ 2.00949097,  2.00899529], dtype=float32)]

Logistic

from edward.models import *

# Bayesian logistic regression fit with HMC. Start from a fresh session and
# graph with a fixed seed.
ed.get_session().close()
tf.reset_default_graph()
ed.set_seed(0)

N = 400  # observations
D = 10  # predictors
T = 5000  # number of stored samples in each Empirical approximation below
noise_std = 0.1  # NOTE(review): not used anywhere in this cell

# Simulate y_i ~ Bernoulli(sigmoid(x_i . w0 + b0)) by comparing uniforms
# against the success probability.
np.random.seed(0)
X = np.random.normal(size=(N, D))
w0 = np.random.normal(size=(D, 1))
b0 = 0.0
y = sp.expit(X.dot(w0) + b0)
threshold = np.random.uniform(size=(N, 1))
y = np.less(threshold, y)
y = y.ravel().astype(int)
X = X.reshape((N, D))  # X is already (N, D); this reshape does not change it

# Model: N(0, 3^2) priors on weights and intercept, logit link.
X_ph = tf.placeholder(tf.float32, [N, D])
w = Normal(loc=tf.zeros([D]), scale=3 * tf.ones([D]))
b = Normal(loc=tf.zeros([]), scale=3 * tf.ones([]))
py = Bernoulli(logits=ed.dot(X_ph, w) + b)

# INFERENCE
qw = Empirical(params=tf.Variable(tf.random_normal([T, D])))
qb = Empirical(params=tf.Variable(tf.random_normal([T])))
inference = ed.HMC({w: qw, b: qb}, data={X_ph: X, py: y})
inference.run()
# Posterior mean and half-width of a +/-1.96 sd interval per weight.
# NOTE(review): mean/stddev are over all T stored samples; no burn-in is discarded.
w_hat, w_ci = ed.get_session().run([qw.mean(), 1.96 * qw.stddev()])
# Plot the estimated effects with intervals against the true w0.
plt.clf()
plt.errorbar(np.arange(D), y=w_hat, yerr=w_ci, c='k', fmt='+', label='Estimated')
plt.scatter(np.arange(D), w0, c='r', s=16, label='True')
plt.legend()
plt.xlabel('Predictor')
plt.xticks(np.arange(D), np.arange(D))
plt.ylabel('Effect size')
<matplotlib.text.Text at 0x7fc3c94f7b70>

(SVG figure not rendered in this export.)

Residual variance

# Start from a fresh session and graph, and fix the seed (as the other sections
# do) so the simulated w_true / training data, and hence the printed estimates,
# are reproducible from run to run.
ed.get_session().close()
tf.reset_default_graph()
ed.set_seed(0)
Normal = ed.models.Normal

def build_toy_dataset(N, w, noise_sd=0.1, data_sd=1):
    """Simulate a noisy linear-regression dataset from NumPy's global RNG.

    The design matrix has i.i.d. N(0, data_sd**2) entries; responses are
    ``x @ w`` plus i.i.d. N(0, noise_sd**2) noise. The design matrix is drawn
    before the noise.

    Returns:
        ``(x, y)``: the ``(N, len(w))`` design matrix and the length-N responses.
    """
    n_features = len(w)
    design = np.random.normal(loc=0, scale=data_sd, size=(N, n_features))
    noise = np.random.normal(loc=0, scale=noise_sd, size=N)
    return design, design.dot(w) + noise


### Generate the data
# Note that data_sd >> noise_sd
N = 1000
D = 5

w_true = np.random.randn(D)
noise_sd = 0.1
data_sd = 5
X_train, y_train = build_toy_dataset(N, w_true, noise_sd=noise_sd, data_sd=data_sd)


### Define the model
# Linear regression with an unknown noise scale, modelled on the log scale.
X = tf.placeholder(tf.float32, [N, D])
w = Normal(loc=tf.zeros(D), scale=tf.ones(D))
b = Normal(loc=tf.zeros(1), scale=tf.ones(1))
log_sd = Normal(loc=tf.zeros(1), scale=tf.ones(1))
y = Normal(loc=ed.dot(X, w) + b, scale=tf.exp(log_sd))

# Mean-field Gaussian approximations; softplus keeps each scale positive.
qw = Normal(loc=tf.get_variable("qw/loc", [D]),
            scale=tf.nn.softplus(tf.get_variable("qw/scale", [D])))
qb = Normal(loc=tf.get_variable("qb/loc", [1]),
            scale=tf.nn.softplus(tf.get_variable("qb/scale", [1])))
qlog_sd = Normal(loc=tf.get_variable("qlog_sd/loc", [1]),
            scale=tf.nn.softplus(tf.get_variable("qlog_sd/scale", [1])))


### Variational Inference
# run() is called without arguments, so it uses Edward's default number of
# samples and iterations; pass n_samples=/n_iter= for a longer, less noisy fit.
inference = ed.ReparameterizationKLKLqp({w: qw, b: qb, log_sd: qlog_sd}, data={X: X_train, y: y_train})
inference.run()


### Print estimates
sess = ed.get_session()

# Observed in the author's run (not checked by code): the posterior mean of w
# matches w_true closely...
print("w hat: {}".format(sess.run(qw.mean())))
print("w true: {}".format(w_true))

# ...while the noise sd is overestimated. "sd hat" is exp(E_q[log_sd]), i.e. the
# median of the implied log-normal, not its mean.
print("sd hat: {}".format(sess.run(tf.exp(qlog_sd.mean()))))
print("sd true: {}".format(noise_sd))

# Approximate 95% interval for the sd: mean +/- 1.96 sd on the log scale, mapped through exp.
print('sd CI: {}'.format(sess.run([tf.exp(qlog_sd.mean() - 1.96 * qlog_sd.stddev()),
   tf.exp(qlog_sd.mean() + 1.96 * qlog_sd.stddev())])))

Bayesian RNN

# Download the CMAPSSData.zip archive and unpack it into ./cmaps; the cells below
# read cmaps/train_FD001.txt, cmaps/test_FD001.txt and cmaps/RUL_FD001.txt.
# NOTE(review): confirm this URL still resolves.
curl -OL https://ti.arc.nasa.gov/m/project/prognostic-repository/CMAPSSData.zip
unzip -d cmaps CMAPSSData.zip

In the training data, each unit's time series ends at failure. The predicted output is the Remaining Useful Life (RUL): the number of cycles remaining before the unit fails.

# Training trajectories; column 0 identifies the unit, and each unit's series
# runs to failure.
x_train = pd.read_table('cmaps/train_FD001.txt', header=None, sep=' ')
# Remaining-useful-life label per row: the number of rows from this row to the
# end of its unit's series, inclusive (n, n-1, ..., 1 within each unit).
# cumcount keeps the labels aligned with x_train's own index. A per-group
# apply returning Series instead yields a MultiIndex Series (or a DataFrame
# when every unit has the same length) whose order matches x_train only if the
# rows are sorted by unit.
y_train = x_train.groupby(0).cumcount(ascending=False) + 1

Construct an Elman network with a Bayesian prior on the weight matrices.

# Fresh session and graph for the RNN section, with a fixed seed.
ed.get_session().close()
tf.reset_default_graph()
ed.set_seed(0)

Normal = ed.models.Normal

H = 50  # number of hidden units
D = 22  # number of features (the 22 columns selected by iloc[:, 2:24] below)

def rnn_cell(hprev, xt):
  """One Elman-network step: the next hidden state from the previous one and the input.

  Computes tanh(hprev . Wh + xt . Wx + bh) with the module-level weights Wh, Wx
  and bias bh.
  """
  recurrent = ed.dot(hprev, Wh)
  driven = ed.dot(xt, Wx)
  return tf.tanh(recurrent + driven + bh)

# Standard-normal priors on every weight matrix and bias.
Wh = Normal(loc=tf.zeros([H, H]), scale=tf.ones([H, H]))
Wx = Normal(loc=tf.zeros([D, H]), scale=tf.ones([D, H]))
Wy = Normal(loc=tf.zeros([H, 1]), scale=tf.ones([H, 1]))
bh = Normal(loc=tf.zeros(H), scale=tf.ones(H))
by = Normal(loc=tf.zeros(1), scale=tf.ones(1))

# One series at a time: rows of x are time steps. tf.scan applies rnn_cell
# along the rows starting from a zero hidden state, and the output at each
# step is Gaussian with unit scale around a linear readout of the hidden state.
x = tf.placeholder(tf.float32, [None, D])
h = tf.scan(rnn_cell, x, initializer=tf.zeros(H))
y = Normal(loc=tf.matmul(h, Wy) + by, scale=1.0)
# Variational family. Presumably NormalWithSoftplusScale applies softplus to
# `scale`, so the unconstrained variables below are valid (confirm in Edward docs).
_N = ed.models.NormalWithSoftplusScale

qWh = _N(loc=tf.get_variable('qW_h/loc', [H, H]),
         scale=tf.get_variable('qW_h/scale', [H, H]))
qWx = _N(loc=tf.get_variable('qW_x/loc', [D, H]),
         scale=tf.get_variable('qW_x/scale', [D, H]))
qWy = _N(loc=tf.get_variable('qW_y/loc', [H, 1]),
         scale=tf.get_variable('qW_y/scale', [H, 1]))
qbh = _N(loc=tf.get_variable('qb_h/loc', [H]),
         scale=tf.get_variable('qb_h/scale', [H]))
qby = _N(loc=tf.get_variable('qb_y/loc', [1]),
         scale=tf.get_variable('qb_y/scale', [1]))

The training data are arranged as variable-length time series, one per unit, indexed by column 0.

# Map each unit id (column 0) to the row labels of its series.
series = x_train.groupby(0).groups

# Observed RUL values enter through a placeholder bound to y via `data`. KLqp
# scores the likelihood only of random variables listed in `data`. Feeding y
# directly (with no `data` entry) would leave the loss with only the KL terms,
# so the weights would never see the labels.
y_ph = tf.placeholder(tf.float32, [None, 1])
inf = ed.KLqp(latent_vars={Wh: qWh, Wx: qWx, Wy: qWy, bh: qbh, by: qby},
              data={y: y_ph})
inf.initialize(logdir='log')
ed.get_session().run(tf.global_variables_initializer())
# Cycle through the units, one full run-to-failure series per update.
units = sorted(series)
for i in range(inf.n_iter):
  rows = series[units[i % len(units)]]
  res = inf.update(
    feed_dict={x: x_train.iloc[rows, 2:24].values,
               y_ph: y_train.iloc[rows].values.reshape(-1, 1).astype(np.float32)})
  inf.print_progress(res)

The test data are partial time series, plus reported number of cycles until failure.

# Test trajectories (partial series) and the reported remaining cycles until failure.
x_test = pd.read_table('cmaps/test_FD001.txt', header=None, sep=' ')
y_test = pd.read_table('cmaps/RUL_FD001.txt', header=None)

Get the posterior distribution of predictions by sampling from the variational approximation.

def predict(x_test):
  """Evaluate the model's last-time-step output under the variational posterior.

  ``y`` is copied with each weight/bias prior swapped for its approximation and
  run on feature columns 2:24 of ``x_test``. Every call builds a fresh
  ``ed.copy``, so the graph grows with each call.
  """
  posterior = {Wh: qWh, Wx: qWx, Wy: qWy, bh: qbh, by: qby}
  last_step = ed.copy(y[-1], dict_swap=posterior)
  features = x_test.iloc[:, 2:24].values
  return ed.get_session().run(last_step, feed_dict={x: features})

Bayesian NN

Sample from a non-linear decision problem.

import sklearn.model_selection as skms
import sklearn.datasets as skd
import sklearn.preprocessing as skp

# Two interleaving half-moons: a binary problem with a non-linear decision boundary.
X, Y = skd.make_moons(noise=0.2, random_state=0, n_samples=1000)
X = skp.scale(X)  # standardize each feature to zero mean, unit variance
# Seed the split too (make_moons is already seeded) so the train/test partition,
# and every downstream result, is reproducible.
X_train, X_test, Y_train, Y_test = skms.train_test_split(X, Y, test_size=.5, random_state=0)

n, p = X_train.shape  # number of training examples, number of features

Plot the training data.

plt.clf()
plt.gcf().set_size_inches(6, 6)
# One scatter per class: label 0 in blue, then label 1 in red.
for label, color in ((0, 'b'), (1, 'r')):
  in_class = Y_train == label
  plt.scatter(X_train[in_class, 0], X_train[in_class, 1], c=color, s=3)
<matplotlib.collections.PathCollection at 0x7fba916b2d68>

(SVG figure not rendered in this export.)

Check that an MLP can successfully classify the data.

# Deterministic baseline: a 2-hidden-layer tanh MLP trained by maximum
# likelihood, to check that the problem is learnable.
tf.reset_default_graph()

with tf.variable_scope('weights'):
  w0 = tf.get_variable('w0', [p, 5])
  w1 = tf.get_variable('w1', [5, 5])
  w2 = tf.get_variable('w2', [5, 1])

x = tf.placeholder(tf.float32, [None, p], name="X")
y = tf.placeholder(tf.float32, [n, 1], name="Y")

h = tf.tanh(tf.matmul(x, w0))
h = tf.tanh(tf.matmul(h, w1))
h = tf.matmul(h, w2)  # logits: no output nonlinearity

loss = tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(labels=y, logits=h))
train = tf.train.AdamOptimizer().minimize(loss)
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  for i in range(1000):
    _, f = sess.run([train, loss], {x: X_train, y: Y_train.reshape(-1, 1)})
    if not i % 100:
      print(i, f)
  # h holds logits, so the sigmoid(h) = 0.5 decision boundary is at h = 0.
  yhat = sess.run(tf.reshape(h > 0.0, [-1]), {x: X_test})
# Correctly classified test points in black, misclassified ones in red.
plt.gcf().set_size_inches(6, 6)
plt.scatter(X_test[Y_test == yhat,0], X_test[Y_test == yhat,1], c='k', s=3)
plt.scatter(X_test[Y_test != yhat,0], X_test[Y_test != yhat,1], c='r', s=3)
<matplotlib.collections.PathCollection at 0x7fba68dc9c18>

(SVG figure not rendered in this export.)

Train a Bayesian MLP.

# Fresh session and graph for the Bayesian MLP, with a fixed seed.
ed.get_session().close()
tf.reset_default_graph()
ed.set_seed(0)

def neural_network(Xp):
  """Forward pass of the Bayesian MLP: two tanh layers, then a sigmoid output.

  Uses the module-level weights W_0, W_1, W_2 and returns the sigmoid outputs
  flattened to a vector.
  """
  hidden = Xp
  for weight in (W_0, W_1):
    hidden = tf.tanh(tf.matmul(hidden, weight))
  return tf.reshape(tf.sigmoid(tf.matmul(hidden, W_2)), [-1])

n_hidden = 5  # units per hidden layer

# MODEL
Normal = ed.models.Normal
Bernoulli = ed.models.Bernoulli

with tf.name_scope("model"):
  # Standard-normal priors on all weight matrices (the network has no biases).
  W_0 = Normal(loc=tf.zeros([X.shape[1], n_hidden]), scale=tf.ones([X.shape[1], n_hidden]), name="W_0")
  W_1 = Normal(loc=tf.zeros([n_hidden, n_hidden]), scale=tf.ones([n_hidden, n_hidden]), name="W_1")
  W_2 = Normal(loc=tf.zeros([n_hidden, 1]), scale=tf.ones([n_hidden, 1]), name="W_2")
  Xp = tf.placeholder(tf.float32, [None, X.shape[1]], name="X")
  f = neural_network(Xp)
  # neural_network ends in a sigmoid, so f is a probability. Bernoulli's first
  # positional parameter is `logits`, so pass f explicitly as `probs`.
  Y = Bernoulli(probs=f, name="Y")

# Mean-field Gaussian approximation for each weight matrix; softplus keeps scales positive.
with tf.variable_scope("posterior"):
  with tf.variable_scope("qW_0"):
    loc = tf.get_variable("loc", [X.shape[1], n_hidden])
    scale = tf.nn.softplus(tf.get_variable("scale", [X.shape[1], n_hidden]))
    qW_0 = Normal(loc=loc, scale=scale)
  with tf.variable_scope("qW_1"):
    loc = tf.get_variable("loc", [n_hidden, n_hidden])
    scale = tf.nn.softplus(tf.get_variable("scale", [n_hidden, n_hidden]))
    qW_1 = Normal(loc=loc, scale=scale)
  with tf.variable_scope("qW_2"):
    loc = tf.get_variable("loc", [n_hidden, 1])
    scale = tf.nn.softplus(tf.get_variable("scale", [n_hidden, 1]))
    qW_2 = Normal(loc=loc, scale=scale)
inference = ed.KLqp({W_0: qW_0, W_1: qW_1, W_2: qW_2}, data={Xp: X_train, Y: Y_train})
# NOTE(review): machine-specific logdir; adjust when running elsewhere.
inference.run(n_iter=3000, logdir='/scratch/midway2/aksarkar/nwas/log/')
# Posterior predictive: swap each weight prior for its approximation and draw labels for X_test.
py = ed.copy(Y, {W_0: qW_0, W_1: qW_1, W_2: qW_2})
yhat = ed.get_session().run(py, {Xp: X_test})
# Correctly classified test points in black, misclassified ones in red.
plt.gcf().set_size_inches(6, 6)
plt.scatter(X_test[Y_test == yhat,0], X_test[Y_test == yhat,1], c='k', s=3)
plt.scatter(X_test[Y_test != yhat,0], X_test[Y_test != yhat,1], c='r', s=3)
<matplotlib.collections.PathCollection at 0x7fba68267240>

(SVG figure not rendered in this export.)

Convnet

from edward.models import *

Generate some data \(x_i \sim \text{Multinomial}(1, \mathbf{p})\).

# x_i ~ Multinomial(1, true_param1), coded as the integer index of the drawn category.
N = 5000
true_param1 = np.array([.15, .35, .5])
# Seed before sampling so the simulated data are reproducible (the ed.set_seed
# call in the next cell runs only after these draws).
np.random.seed(0)
data = np.random.choice(a=len(true_param1), p=true_param1, size=N)

Fit the model.

# Fresh session and graph, with a fixed seed.
ed.get_session().close()
tf.reset_default_graph()
ed.set_seed(0)

# Conjugate model: p ~ Dirichlet(1, 1, 1), x_i | p ~ Categorical(p).
param1 = Dirichlet(tf.ones([3]))
# param1 lies on the probability simplex, so it parameterizes the categorical
# through `probs`. Passing it as `logits` would apply a softmax on top, and the
# fitted param1 would no longer estimate true_param1.
w = Categorical(probs=param1, sample_shape=[N])

# Variational Dirichlet; softplus keeps the trainable concentration positive.
qparam1 = Dirichlet(tf.nn.softplus(tf.Variable(tf.ones([3]))))

inference = ed.KLqp({param1: qparam1}, data={w: data})
inference.run(n_iter=2000)
# Posterior mean of p; compare with true_param1.
ed.get_session().run(qparam1.mean())
array([ 0.14636859,  0.33761254,  0.51601881], dtype=float32)

Generate data \(x_i \sim \text{Multinomial}(1, h(z_1, z_2))\).

# Fresh session and graph for the convolutional example, with a fixed seed.
ed.get_session().close()
tf.reset_default_graph()  
ed.set_seed(0)

# True filter, shape (1, 1, 2, 10), used as the conv2d filter in model_latent:
# a single 1 at [0,0,0,2] and at [0,0,1,5], zeros elsewhere.
true_param1 = np.zeros((1,1,2,10))
true_param1[0,0,0,2] = 1
true_param1[0,0,1,5] = 1
true_param1 = tf.constant(true_param1, dtype=tf.float32)

# Fixed conv2d input, shape (1, 5, 2, 2): channel 0 holds A transposed and
# channel 1 holds B transposed (each A/B row becomes a column).
true_param2  =  np.zeros((1,5,2,2))
A = np.array([[1,1,0,0,0],[0,4,30,1,0]])
B = np.array([[2,1,0,0,3],[4,5,0,0,0]])
true_param2[0,:,:,0] = np.transpose(A)
true_param2[0,:,:,1] = np.transpose(B)
true_param2 = tf.constant(true_param2, dtype=tf.float32)

def model_latent(z):
  """Map a convolution filter ``z`` to a length-10 vector of categorical logits.

  Convolves the fixed input ``true_param2`` with ``z`` (unit strides, VALID
  padding), sums over the output-channel axis, and flattens to length 10.
  Callers in this section pass ``z`` of shape (1, 1, 2, 10).
  """
  conv = tf.nn.conv2d(true_param2, z, strides=[1,1,1,1], padding="VALID")
  return tf.reshape(tf.reduce_sum(conv, 3), [10])

# Logits of the data-generating categorical distribution.
logits = model_latent(true_param1)
# Sanity check: model_latent evaluated on an all-ones filter.
ed.get_session().run(model_latent(tf.ones([1, 1, 2, 10])))
array([  3.,   4.,   2.,   9.,   0.,  30.,   0.,   1.,   3.,   0.], dtype=float32)

N = 5000
# Simulate N category indices from the true logits.
data = ed.get_session().run(Categorical(logits=logits).sample(N))
# Prior on the filter; presumably a Dirichlet over the last (10-element) axis
# for each of the 2 input channels (TF's Dirichlet event axis) — confirm.
param1 = Dirichlet(tf.ones([1,1,2,10]), name='param1')
w = Categorical(logits=model_latent(param1), sample_shape=[N])

# Variational Dirichlet; softplus keeps the trainable concentration positive.
qparam1 = Dirichlet(tf.nn.softplus(tf.Variable(tf.ones([1,1,2,10]))), name="qparam1") 

inference = ed.KLqp({param1: qparam1}, data={w: data})
inference.run()
# Logits implied by qparam1 (evaluated on a draw of qparam1, presumably) — compare with `logits`.
ed.get_session().run(model_latent(qparam1))
array([ 1.90809596, -0.6436035 ,  1.85443068,  2.44416738,  1.73977566,
-0.3565155 ,  2.36256814,  0.20984885,  2.31142139, -0.44349724], dtype=float32)
# Inspect qparam1 directly (presumably one draw from the fitted Dirichlet — confirm).
ed.get_session().run(qparam1)
array([[[[ 0.1,  0.1,  0.1,  0.1,  0.1,  0.1,  0.1,  0.1,  0.1,  0.1],
[ 0.1,  0.1,  0.1,  0.1,  0.1,  0.1,  0.1,  0.1,  0.1,  0.1]]]], dtype=float32)

Author: Abhishek Sarkar

Created: 2018-05-08 Tue 23:15

Validate