Experiments with the Edward probabilistic programming library
Setup
%matplotlib inline %config InlineBackend.figure_formats = set(['svg']) import edward as ed import matplotlib.pyplot as plt import numpy as np import pandas as pd import scipy.special as sp import tensorflow as tf
Learn directed Bayesian network
# Simulate N observations from the classic cloudy/sprinkler/rain/wet-grass
# Bayesian network, keeping only (cloudy, wet) as the observed columns.
N = 1000
np.random.seed(0)

def _bernoulli(p):
    # Draw N Bernoulli(p) indicators; p may be a scalar or a length-N array.
    return np.random.uniform(size=N) < p

cloudy = _bernoulli(.5)
rain = _bernoulli(np.where(cloudy, .5, .6))
sprinkler = _bernoulli(np.where(cloudy, .8, .3))
# P(wet) = .99 when both rain and sprinkler fire, .1 otherwise
# (bool * bool is AND; bool + bool is OR under numpy).
wet = _bernoulli(.99 * rain * sprinkler + .1 * (~rain + ~sprinkler))
data = np.vstack((cloudy, wet)).astype(int).T
MLE in Edward
\[ x_i \sim \mathcal{N}(\mu, 1) \]
# MLE of the mean of a unit-variance Gaussian, done in Edward by wrapping
# the location in a raw tf.Variable (no prior, no approximating family):
# KLqp with no latent variables reduces to maximizing the data likelihood.
N = 100
np.random.seed(0)
mu = np.random.normal()
x = np.random.normal(loc=mu, scale=1, size=(N, 1)).astype(np.float32)
# Broadcasting the scalar variable across N rows ties every row to one mean.
px = ed.models.Normal(loc=tf.Variable(tf.zeros([1])) * tf.ones([N, 1]),
                      scale=tf.ones([1]))
inf = ed.KLqp(data={px: x})
inf.run()
# Compare true mu, the sample mean, and the fitted location.
print(mu, x.mean(), ed.get_session().run(px.mean()[0, 0]))
1.764052345967664 1.82505 1.82505
Mixture of regressions
\[ y \mid z \sim \mathcal{N}\!\left(x\,\beta^\top z + \mu^\top z,\ \sigma^\top z\right) \]
# Simulate a two-component mixture of linear regressions: membership z is
# Bernoulli(.7), and each component has its own slope, intercept, and
# residual scale.
N = 500
np.random.seed(0)
x = np.random.normal(size=N)
z = np.random.uniform(size=N) < .7
# Per-observation parameters chosen by component membership.
beta, mu, sigma = (np.where(z, in_comp, out_comp)
                   for in_comp, out_comp in ((1, 5), (3, 1), (.2, .3)))
y = x * beta + mu + np.random.normal(scale=sigma, size=N)
# Mixture-of-regressions model: a K-way one-hot indicator per observation
# selects that observation's slope, intercept, and residual scale.
K = 2
x_ph = tf.placeholder(tf.float32, [N, 1])
# Uniform prior over components, independently for each of the N rows.
p_z = ed.models.OneHotCategorical(logits=tf.ones([N, K]))
p_beta = ed.models.Normal(loc=tf.zeros([K, 1]), scale=tf.ones([K, 1]))
p_mu = ed.models.Normal(loc=tf.zeros([K, 1]), scale=tf.ones([K, 1]))
p_sigma = ed.models.InverseGamma(concentration=tf.ones([K, 1]), rate=tf.ones([K, 1]))
# Casting the one-hot z to float and matmul-ing picks each row's component
# parameters: loc = x * beta_z + mu_z, scale = sigma_z.
p_y = ed.models.Normal(loc=x_ph * tf.matmul(tf.cast(p_z, tf.float32), p_beta) + tf.matmul(tf.cast(p_z, tf.float32), p_mu),
                       scale=tf.matmul(tf.cast(p_z, tf.float32), p_sigma))
# Variational family: free per-row component probabilities, and point
# masses for the regression parameters (softplus keeps sigma positive).
q_z = ed.models.OneHotCategorical(logits=tf.Variable(tf.ones([N, K])))
q_beta = ed.models.PointMass(params=tf.Variable(tf.zeros([K, 1])))
q_mu = ed.models.PointMass(params=tf.Variable(tf.zeros([K, 1])))
q_sigma = ed.models.PointMass(params=tf.nn.softplus(tf.Variable(tf.ones([K, 1]))))
# EM-style split: E updates component assignments with the point-mass
# parameter estimates held fixed as data; M does MAP for the parameters
# with the assignments held fixed as data.
# NOTE(review): Edward's latent_vars maps model RV -> approximating RV, so
# `{q_z: p_z}` below looks reversed — confirm `{p_z: q_z}` was intended.
E = ed.KLqp(
    data={
        x_ph: x.astype(np.float32).reshape(-1, 1),
        p_y: y.astype(np.float32).reshape(-1, 1),
        p_beta: q_beta,
        p_mu: q_mu,
        p_sigma: q_sigma,
    },
    latent_vars={
        q_z: p_z,
    })
M = ed.MAP(
    data={
        x_ph: x.astype(np.float32).reshape(-1, 1),
        p_y: y.astype(np.float32).reshape(-1, 1),
        p_z: q_z,
    },
    latent_vars={
        p_beta: q_beta,
        p_mu: q_mu,
        p_sigma: q_sigma,
    })
# Alternate E and M updates.
# NOTE(review): the notebook cell's original line breaks were lost; the
# nesting below (print once per outer sweep) is the most plausible
# reading — confirm. As written, the outer loop runs exactly once.
E.initialize()
M.initialize()
ed.get_session().run(tf.global_variables_initializer())
for i in range(1):
    for _ in range(100):
        res0 = E.update()
        res = M.update()
    print(res0['loss'], res['loss'])
# Inspect the fitted per-observation component probabilities and the point
# estimates of beta, mu, and sigma (flattened for readability).
ed.get_session().run(
    [q_z.probs, tf.reshape(q_beta, [-1]), tf.reshape(q_mu, [-1]), tf.reshape(q_sigma, [-1])])
[array([[ 0.71002579, 0.28997418], [ 0.48043731, 0.51956272], [ 0.38829982, 0.61170018], [ 0.45394057, 0.54605943], [ 0.60275078, 0.39724925], [ 0.56599832, 0.43400165], [ 0.42639849, 0.57360154], [ 0.58124381, 0.41875616], [ 0.55708236, 0.44291764], [ 0.39320481, 0.60679525], [ 0.45562571, 0.54437423], [ 0.53851324, 0.46148679], [ 0.43844658, 0.56155342], [ 0.46309105, 0.53690892], [ 0.4733164 , 0.52668363], [ 0.38441676, 0.61558324], [ 0.63565713, 0.3643429 ], [ 0.59845626, 0.40154377], [ 0.40672451, 0.59327549], [ 0.42584515, 0.57415485], [ 0.38219059, 0.61780947], [ 0.2090683 , 0.7909317 ], [ 0.57429117, 0.4257088 ], [ 0.57058483, 0.4294152 ], [ 0.61770642, 0.38229361], [ 0.51058275, 0.48941731], [ 0.56709599, 0.43290398], [ 0.41292256, 0.58707744], [ 0.50020838, 0.49979153], [ 0.57302105, 0.42697892], [ 0.51298302, 0.48701701], [ 0.66017324, 0.33982676], [ 0.54463935, 0.45536068], [ 0.51573318, 0.48426682], [ 0.32967111, 0.67032892], [ 0.29039744, 0.70960259], [ 0.50654632, 0.49345359], [ 0.53788632, 0.46211365], [ 0.61870939, 0.38129056], [ 0.48693788, 0.51306206], [ 0.53016818, 0.46983185], [ 0.48576871, 0.51423126], [ 0.44574535, 0.55425465], [ 0.48011023, 0.51988977], [ 0.4652572 , 0.53474277], [ 0.54088908, 0.45911092], [ 0.5020507 , 0.49794933], [ 0.38352886, 0.61647111], [ 0.47106972, 0.52893025], [ 0.15485103, 0.84514898], [ 0.4664011 , 0.53359896], [ 0.30208117, 0.69791883], [ 0.50660372, 0.49339628], [ 0.47084692, 0.52915305], [ 0.64215273, 0.3578473 ], [ 0.57530797, 0.42469206], [ 0.57450587, 0.42549413], [ 0.51293218, 0.48706779], [ 0.54589957, 0.45410049], [ 0.42113671, 0.57886332], [ 0.66159713, 0.33840284], [ 0.53784293, 0.46215701], [ 0.45368984, 0.54631019], [ 0.60997987, 0.3900201 ], [ 0.65123248, 0.34876749], [ 0.37624544, 0.62375456], [ 0.53248215, 0.46751788], [ 0.40773889, 0.59226114], [ 0.29519737, 0.70480263], [ 0.61828786, 0.38171217], [ 0.56785941, 0.43214053], [ 0.43748084, 0.56251913], [ 0.61715895, 0.38284105], [ 0.61459643, 
0.38540357], [ 0.43189934, 0.56810063], [ 0.44795257, 0.55204749], [ 0.3988784 , 0.6011216 ], [ 0.59767079, 0.40232918], [ 0.53826112, 0.46173885], [ 0.4576849 , 0.54231513], [ 0.42446387, 0.57553613], [ 0.39003873, 0.60996127], [ 0.38686582, 0.61313421], [ 0.36126259, 0.63873738], [ 0.5862332 , 0.41376686], [ 0.52707517, 0.47292483], [ 0.60705346, 0.39294654], [ 0.47495189, 0.52504814], [ 0.32205719, 0.67794281], [ 0.47547859, 0.52452147], [ 0.51039815, 0.48960185], [ 0.38047826, 0.61952174], [ 0.30331373, 0.69668621], [ 0.64933407, 0.35066593], [ 0.58514827, 0.4148517 ], [ 0.58274013, 0.4172599 ], [ 0.33704156, 0.66295844], [ 0.51428324, 0.48571673], [ 0.33593464, 0.66406536], [ 0.65539223, 0.34460768], [ 0.54574114, 0.45425886], [ 0.58412308, 0.4158769 ], [ 0.53556174, 0.46443826], [ 0.46937943, 0.53062057], [ 0.61514378, 0.38485622], [ 0.41186836, 0.58813161], [ 0.47821105, 0.52178895], [ 0.55419874, 0.44580123], [ 0.44436356, 0.55563647], [ 0.45664942, 0.54335058], [ 0.4024387 , 0.5975613 ], [ 0.48182285, 0.51817721], [ 0.44133377, 0.55866629], [ 0.43799591, 0.56200403], [ 0.39235443, 0.60764557], [ 0.31296471, 0.68703526], [ 0.34281242, 0.65718758], [ 0.43191525, 0.56808472], [ 0.30023715, 0.69976282], [ 0.44088694, 0.55911309], [ 0.52743107, 0.47256896], [ 0.53777021, 0.46222973], [ 0.75503856, 0.24496143], [ 0.41529727, 0.58470267], [ 0.5416109 , 0.45838913], [ 0.60855252, 0.39144745], [ 0.56409019, 0.43590978], [ 0.60207796, 0.39792207], [ 0.43606761, 0.56393248], [ 0.60557753, 0.39442244], [ 0.63035196, 0.36964798], [ 0.42812946, 0.57187057], [ 0.40494242, 0.59505755], [ 0.41584125, 0.58415878], [ 0.52344775, 0.47655222], [ 0.42652997, 0.57347006], [ 0.61050302, 0.38949701], [ 0.46699005, 0.53300995], [ 0.54979461, 0.45020539], [ 0.62898892, 0.37101102], [ 0.61520046, 0.38479954], [ 0.37691894, 0.62308109], [ 0.59981608, 0.40018392], [ 0.34312534, 0.65687466], [ 0.39728737, 0.60271257], [ 0.52067167, 0.4793283 ], [ 0.5152083 , 0.48479176], [ 0.76850319, 
0.2314968 ], [ 0.5312919 , 0.46870807], [ 0.29801649, 0.70198351], [ 0.51338738, 0.48661265], [ 0.40474135, 0.59525859], [ 0.5791117 , 0.42088836], [ 0.55248088, 0.44751909], [ 0.51236618, 0.48763379], [ 0.35830188, 0.64169812], [ 0.47314191, 0.52685809], [ 0.48809928, 0.51190078], [ 0.4083834 , 0.59161657], [ 0.27255264, 0.72744727], [ 0.5082413 , 0.49175876], [ 0.46433172, 0.53566825], [ 0.49951553, 0.50048453], [ 0.41685823, 0.58314174], [ 0.40450916, 0.59549081], [ 0.63181913, 0.36818093], [ 0.5749262 , 0.42507377], [ 0.59248257, 0.40751746], [ 0.51416624, 0.4858337 ], [ 0.4571057 , 0.5428943 ], [ 0.60158944, 0.39841059], [ 0.49728325, 0.50271678], [ 0.44007769, 0.55992228], [ 0.58551878, 0.41448122], [ 0.46919104, 0.53080893], [ 0.50912935, 0.49087059], [ 0.40775898, 0.59224099], [ 0.5088917 , 0.49110833], [ 0.38495445, 0.61504549], [ 0.47667813, 0.52332181], [ 0.63478494, 0.36521503], [ 0.47245151, 0.52754849], [ 0.35618833, 0.6438117 ], [ 0.564062 , 0.43593797], [ 0.3959046 , 0.6040954 ], [ 0.65318727, 0.3468127 ], [ 0.42544624, 0.57455379], [ 0.33444902, 0.66555095], [ 0.43941447, 0.56058556], [ 0.49924177, 0.50075823], [ 0.46403575, 0.53596431], [ 0.58585149, 0.41414857], [ 0.44646743, 0.55353248], [ 0.54713833, 0.45286173], [ 0.43269864, 0.56730133], [ 0.36501101, 0.63498908], [ 0.39428982, 0.60571021], [ 0.76552308, 0.23447697], [ 0.567922 , 0.43207803], [ 0.59970969, 0.40029037], [ 0.58460778, 0.41539231], [ 0.24526666, 0.75473338], [ 0.26229888, 0.73770118], [ 0.53816956, 0.46183044], [ 0.41838986, 0.58161014], [ 0.28172556, 0.71827441], [ 0.73562717, 0.26437283], [ 0.57271713, 0.42728293], [ 0.67609668, 0.32390332], [ 0.60700214, 0.39299786], [ 0.54531962, 0.45468035], [ 0.79055327, 0.2094468 ], [ 0.50417173, 0.49582824], [ 0.36504081, 0.63495922], [ 0.67351651, 0.32648349], [ 0.48000541, 0.51999462], [ 0.54124218, 0.45875776], [ 0.49603325, 0.50396675], [ 0.52238524, 0.47761476], [ 0.68409765, 0.31590229], [ 0.64417231, 0.35582769], [ 0.35448414, 
0.64551586], [ 0.42745626, 0.5725438 ], [ 0.50095332, 0.49904668], [ 0.63306636, 0.36693358], [ 0.50398809, 0.49601182], [ 0.4958204 , 0.50417954], [ 0.47330362, 0.52669644], [ 0.70182741, 0.29817262], [ 0.28064489, 0.71935511], [ 0.2785818 , 0.72141826], [ 0.50518888, 0.49481112], [ 0.55749619, 0.44250387], [ 0.48519689, 0.51480311], [ 0.50366759, 0.49633238], [ 0.53342515, 0.46657488], [ 0.62646013, 0.37353992], [ 0.59017068, 0.40982935], [ 0.51482952, 0.48517048], [ 0.48711106, 0.51288891], [ 0.68457484, 0.31542522], [ 0.34068 , 0.65932 ], [ 0.43009904, 0.56990105], [ 0.6076529 , 0.3923471 ], [ 0.42084491, 0.57915503], [ 0.47324702, 0.52675301], [ 0.37014487, 0.6298551 ], [ 0.48246294, 0.51753706], [ 0.5963726 , 0.4036274 ], [ 0.62080538, 0.37919465], [ 0.50523442, 0.49476561], [ 0.52521145, 0.47478855], [ 0.59784442, 0.40215558], [ 0.49823123, 0.50176883], [ 0.6080004 , 0.39199957], [ 0.42829162, 0.57170838], [ 0.67309278, 0.32690722], [ 0.51378357, 0.48621637], [ 0.63700259, 0.36299741], [ 0.54894733, 0.45105267], [ 0.43539476, 0.56460518], [ 0.46971148, 0.53028852], [ 0.64895523, 0.35104477], [ 0.42190167, 0.57809836], [ 0.49398458, 0.50601542], [ 0.51754951, 0.48245054], [ 0.49798828, 0.50201166], [ 0.41089019, 0.58910984], [ 0.27574155, 0.72425842], [ 0.46945363, 0.53054637], [ 0.38797656, 0.61202347], [ 0.42296138, 0.57703859], [ 0.41252366, 0.58747631], [ 0.47734672, 0.52265328], [ 0.36938927, 0.63061076], [ 0.32669395, 0.67330605], [ 0.60283309, 0.39716697], [ 0.38714966, 0.61285037], [ 0.50990999, 0.49009001], [ 0.5869168 , 0.4130832 ], [ 0.497078 , 0.50292206], [ 0.49703607, 0.5029639 ], [ 0.40899321, 0.59100676], [ 0.65545309, 0.34454691], [ 0.3524608 , 0.64753926], [ 0.31158391, 0.68841612], [ 0.47008419, 0.52991581], [ 0.42623898, 0.57376099], [ 0.55362576, 0.44637421], [ 0.68412167, 0.31587836], [ 0.35418135, 0.64581865], [ 0.25532928, 0.74467069], [ 0.43706328, 0.56293672], [ 0.66076267, 0.33923739], [ 0.4919304 , 0.50806957], [ 0.45944887, 
0.54055119], [ 0.63495106, 0.36504894], [ 0.67940694, 0.32059309], [ 0.62442881, 0.37557113], [ 0.51191801, 0.48808205], [ 0.37476775, 0.62523222], [ 0.74282789, 0.25717211], [ 0.42770943, 0.57229054], [ 0.43511936, 0.56488067], [ 0.44826511, 0.55173486], [ 0.31072369, 0.68927634], [ 0.41287073, 0.58712924], [ 0.59510809, 0.40489194], [ 0.70356369, 0.29643634], [ 0.37164 , 0.62835997], [ 0.3765277 , 0.62347233], [ 0.58470893, 0.41529107], [ 0.59956676, 0.40043321], [ 0.51642925, 0.48357072], [ 0.55386239, 0.44613764], [ 0.68319321, 0.31680682], [ 0.23681407, 0.76318586], [ 0.3577981 , 0.6422019 ], [ 0.55373985, 0.44626021], [ 0.69725144, 0.3027485 ], [ 0.55120802, 0.44879198], [ 0.49777493, 0.50222498], [ 0.4483172 , 0.55168277], [ 0.50391287, 0.49608716], [ 0.41267136, 0.58732873], [ 0.50214213, 0.49785787], [ 0.63410002, 0.36590004], [ 0.33282155, 0.66717845], [ 0.30768725, 0.69231272], [ 0.49322575, 0.50677419], [ 0.35773757, 0.64226246], [ 0.54091048, 0.45908952], [ 0.7552458 , 0.2447542 ], [ 0.5050723 , 0.49492779], [ 0.45320013, 0.5467999 ], [ 0.32543126, 0.67456871], [ 0.35045615, 0.64954382], [ 0.42995608, 0.57004386], [ 0.44718689, 0.55281317], [ 0.42096293, 0.57903707], [ 0.37363988, 0.62636012], [ 0.43621966, 0.56378037], [ 0.3766996 , 0.62330037], [ 0.65387052, 0.34612948], [ 0.51497632, 0.48502368], [ 0.50727457, 0.49272546], [ 0.51152873, 0.48847124], [ 0.45450982, 0.54549009], [ 0.51226872, 0.48773128], [ 0.39616162, 0.60383838], [ 0.49920642, 0.50079358], [ 0.55853719, 0.44146281], [ 0.55030149, 0.44969845], [ 0.31780696, 0.68219304], [ 0.42395264, 0.5760473 ], [ 0.52481049, 0.47518945], [ 0.39745817, 0.60254186], [ 0.44410568, 0.55589437], [ 0.57883829, 0.42116168], [ 0.57630509, 0.42369485], [ 0.4114745 , 0.58852559], [ 0.27175623, 0.72824377], [ 0.28102168, 0.71897835], [ 0.57849431, 0.42150572], [ 0.49072948, 0.50927049], [ 0.53144264, 0.46855742], [ 0.63396198, 0.36603802], [ 0.57280034, 0.42719963], [ 0.49170384, 0.50829613], [ 0.53744221, 
0.46255773], [ 0.37506071, 0.6249392 ], [ 0.6688059 , 0.33119407], [ 0.51032859, 0.48967141], [ 0.52401632, 0.47598365], [ 0.54270416, 0.45729586], [ 0.60252625, 0.39747372], [ 0.34316862, 0.65683138], [ 0.49852461, 0.50147545], [ 0.43459946, 0.5654006 ], [ 0.54910624, 0.45089373], [ 0.58483511, 0.41516492], [ 0.37603599, 0.62396401], [ 0.38383642, 0.61616361], [ 0.46504709, 0.53495294], [ 0.49452743, 0.50547254], [ 0.45659745, 0.54340255], [ 0.6947999 , 0.30520016], [ 0.34478918, 0.65521085], [ 0.76010448, 0.23989552], [ 0.52163577, 0.47836426], [ 0.49834624, 0.50165367], [ 0.57489723, 0.4251028 ], [ 0.70049012, 0.29950988], [ 0.49580365, 0.50419629], [ 0.48210365, 0.51789629], [ 0.43704611, 0.56295389], [ 0.24105078, 0.75894922], [ 0.46682742, 0.53317261], [ 0.40607652, 0.59392345], [ 0.6035645 , 0.3964355 ], [ 0.50269431, 0.49730566], [ 0.40303802, 0.59696203], [ 0.46847877, 0.53152126], [ 0.50786275, 0.49213725], [ 0.42432216, 0.57567793], [ 0.45127711, 0.54872292], [ 0.57181513, 0.42818487], [ 0.42151466, 0.57848537], [ 0.43607289, 0.56392717], [ 0.45397151, 0.54602844], [ 0.43076172, 0.56923831], [ 0.68091029, 0.31908965], [ 0.48943397, 0.51056612], [ 0.55664223, 0.44335777], [ 0.72561008, 0.27438992], [ 0.47178993, 0.52821004], [ 0.51078868, 0.48921132], [ 0.45309776, 0.54690224], [ 0.4728457 , 0.52715427], [ 0.47694656, 0.52305341], [ 0.50450379, 0.4954963 ], [ 0.55760914, 0.44239086], [ 0.46785116, 0.53214884], [ 0.50984597, 0.490154 ], [ 0.41991714, 0.58008289], [ 0.4738881 , 0.52611196], [ 0.42884421, 0.57115573], [ 0.67225224, 0.32774773], [ 0.45308289, 0.5469172 ], [ 0.55035019, 0.44964984], [ 0.44749242, 0.55250758], [ 0.45291388, 0.54708612], [ 0.1964258 , 0.80357426], [ 0.55424738, 0.44575259], [ 0.43516347, 0.5648365 ], [ 0.59923226, 0.40076771], [ 0.51310593, 0.4868941 ], [ 0.625319 , 0.37468103], [ 0.37344772, 0.62655228], [ 0.44075784, 0.55924207], [ 0.34691727, 0.65308279], [ 0.33231381, 0.66768616], [ 0.45853186, 0.54146808], [ 0.58474028, 
0.41525975], [ 0.46883255, 0.53116739], [ 0.584396 , 0.41560403], [ 0.42681029, 0.57318974], [ 0.55868268, 0.44131738], [ 0.65386695, 0.34613308], [ 0.28968138, 0.71031857], [ 0.65495002, 0.34504998], [ 0.686454 , 0.313546 ], [ 0.46859944, 0.5314005 ], [ 0.44909573, 0.55090433], [ 0.63541949, 0.36458051], [ 0.30905333, 0.69094664], [ 0.51697308, 0.48302689], [ 0.66233909, 0.33766091], [ 0.6001389 , 0.39986119], [ 0.40635231, 0.59364772], [ 0.46703085, 0.53296912], [ 0.40248233, 0.59751767], [ 0.42627519, 0.57372481], [ 0.34972823, 0.65027183], [ 0.51255959, 0.48744044], [ 0.57648027, 0.42351976], [ 0.53118485, 0.46881515], [ 0.44005758, 0.55994248], [ 0.41567108, 0.58432889], [ 0.37844101, 0.62155902], [ 0.62293488, 0.37706509], [ 0.55375665, 0.44624338], [ 0.49558094, 0.50441915], [ 0.46740434, 0.53259563], [ 0.45204943, 0.54795063], [ 0.42922086, 0.57077909], [ 0.45012459, 0.54987544], [ 0.52738768, 0.47261229], [ 0.68961829, 0.31038171], [ 0.42298055, 0.57701945], [ 0.49630976, 0.50369024], [ 0.31605911, 0.68394089], [ 0.66076434, 0.33923566], [ 0.48949784, 0.51050222], [ 0.64120275, 0.35879728], [ 0.42200032, 0.57799971], [ 0.48909491, 0.51090509], [ 0.36553371, 0.63446629], [ 0.49850166, 0.50149834], [ 0.51601541, 0.48398453], [ 0.36317748, 0.63682252], [ 0.40290901, 0.59709102], [ 0.58002025, 0.41997972], [ 0.68031371, 0.31968635], [ 0.51672632, 0.48327366], [ 0.29602346, 0.70397657], [ 0.42756143, 0.57243854], [ 0.48997787, 0.5100221 ], [ 0.33808461, 0.66191542], [ 0.43681616, 0.56318378]], dtype=float32), array([ 1.93082905, 2.06234193], dtype=float32), array([ 2.3322289, 2.3039763], dtype=float32), array([ 2.00949097, 2.00899529], dtype=float32)]
Logistic
# Bayesian logistic regression, to be fit with HMC.
from edward.models import *
# Start from a clean session/graph (earlier cells built their own graphs).
ed.get_session().close()
tf.reset_default_graph()
ed.set_seed(0)
N = 400   # observations
D = 10    # predictors
T = 5000  # number of HMC draws kept in the Empirical approximations
noise_std = 0.1  # NOTE(review): unused in this cell — presumably leftover
np.random.seed(0)
# Simulate covariates and true effects, then threshold uniform draws
# against the logistic mean to produce binary outcomes.
X = np.random.normal(size=(N, D))
w0 = np.random.normal(size=(D, 1))
b0 = 0.0
y = sp.expit(X.dot(w0) + b0)
threshold = np.random.uniform(size=(N, 1))
y = np.less(threshold, y)
y = y.ravel().astype(int)
X = X.reshape((N, D))
# Model: w, b ~ Normal(0, 3^2); y | X, w, b ~ Bernoulli(logits = Xw + b).
X_ph = tf.placeholder(tf.float32, [N, D])
w = Normal(loc=tf.zeros([D]), scale=3 * tf.ones([D]))
b = Normal(loc=tf.zeros([]), scale=3 * tf.ones([]))
py = Bernoulli(logits=ed.dot(X_ph, w) + b)
# INFERENCE: Empirical distributions hold the T retained HMC draws.
qw = Empirical(params=tf.Variable(tf.random_normal([T, D])))
qb = Empirical(params=tf.Variable(tf.random_normal([T])))
# Run HMC, then summarize the posterior of w by its sample mean and a
# 1.96-sd interval half-width over the T retained draws.
inference = ed.HMC({w: qw, b: qb}, data={X_ph: X, py: y})
inference.run()
# NOTE(review): qw.stddev() is taken over *all* T stored samples,
# including any warm-up — confirm whether burn-in should be discarded.
w_hat, w_ci = ed.get_session().run([qw.mean(), 1.96 * qw.stddev()])
# Plot estimated effect sizes (with ~95% intervals) against the truth.
plt.clf()
plt.errorbar(np.arange(D), y=w_hat, yerr=w_ci, c='k', fmt='+', label='Estimated')
plt.scatter(np.arange(D), w0, c='r', s=16, label='True')
plt.legend()
plt.xlabel('Predictor')
plt.xticks(np.arange(D), np.arange(D))
plt.ylabel('Effect size')
<matplotlib.text.Text at 0x7fc3c94f7b70>
Residual variance
# Linear regression with an explicit latent log residual-sd, fit by
# variational inference; checks whether the residual sd is recovered.
ed.get_session().close()
tf.reset_default_graph()
Normal = ed.models.Normal

def build_toy_dataset(N, w, noise_sd=0.1, data_sd=1):
    """Simulate N rows with x ~ N(0, data_sd^2) and y = x.w + N(0, noise_sd^2)."""
    D = len(w)
    x = np.random.normal(0, data_sd, size=(N, D))
    y = np.dot(x, w) + np.random.normal(0, noise_sd, size=N)
    return x, y

### Generate the data
# Note that data_sd >> noise_sd
N = 1000
D = 5
w_true = np.random.randn(D)
noise_sd = 0.1
data_sd = 5
X_train, y_train = build_toy_dataset(N, w_true, noise_sd=noise_sd, data_sd=data_sd)
### Define the model
X = tf.placeholder(tf.float32, [N, D])
w = Normal(loc=tf.zeros(D), scale=tf.ones(D))
b = Normal(loc=tf.zeros(1), scale=tf.ones(1))
# The residual sd is parameterized on the log scale so it stays positive.
log_sd = Normal(loc=tf.zeros(1), scale=tf.ones(1))
y = Normal(loc=ed.dot(X, w) + b, scale=tf.exp(log_sd))
# Mean-field normal families; softplus keeps the scale variables positive.
qw = Normal(loc=tf.get_variable("qw/loc", [D]),
            scale=tf.nn.softplus(tf.get_variable("qw/scale", [D])))
qb = Normal(loc=tf.get_variable("qb/loc", [1]),
            scale=tf.nn.softplus(tf.get_variable("qb/scale", [1])))
qlog_sd = Normal(loc=tf.get_variable("qlog_sd/loc", [1]),
                 scale=tf.nn.softplus(tf.get_variable("qlog_sd/scale", [1])))
### Variational Inference
# Many samples and iterations, just to be sure
inference = ed.ReparameterizationKLKLqp({w: qw, b: qb, log_sd: qlog_sd}, data={X: X_train, y: y_train})
inference.run()
### Print estimates
sess = ed.get_session()
# w is unbiased...
print("w hat: {}".format(sess.run(qw.mean())))
print("w true: {}".format(w_true))
# sd is overestimated..
print("sd hat: {}".format(sess.run(tf.exp(qlog_sd.mean()))))
print("sd true: {}".format(noise_sd))
# 95% interval for the residual sd, mapped back from the log scale.
print('sd CI: {}'.format(sess.run([tf.exp(qlog_sd.mean() - 1.96 * qlog_sd.stddev()), tf.exp(qlog_sd.mean() + 1.96 * qlog_sd.stddev())])))
Bayesian RNN
# Download the NASA C-MAPSS turbofan engine degradation data set.
curl -OL https://ti.arc.nasa.gov/m/project/prognostic-repository/CMAPSSData.zip
# Extract it into a local cmaps/ directory.
unzip -d cmaps CMAPSSData.zip
In the training data, the time series ends at failure. The units of the predicted output should be "Remaining Useful Life" (the number of cycles before the unit fails).
# Load the FD001 training trajectories; column 0 is the unit number.
x_train = pd.read_table('cmaps/train_FD001.txt', header=None, sep=' ')

def _remaining_useful_life(unit):
    # Each training series ends at failure, so the remaining useful life
    # at each row is just a countdown from the series length to 1.
    return pd.Series(np.arange(unit.shape[0], 0, -1))

y_train = x_train.groupby(0).apply(_remaining_useful_life)
Construct an Elman network with a Bayesian prior on the weight matrices.
# Elman RNN with standard-normal priors on all weights; maps a sequence of
# D-dimensional sensor vectors to a scalar output at every time step.
ed.get_session().close()
tf.reset_default_graph()
ed.set_seed(0)
Normal = ed.models.Normal
H = 50 # number of hidden units
D = 22 # number of features

def rnn_cell(hprev, xt):
    # One Elman step: h_t = tanh(h_{t-1} Wh + x_t Wx + bh).
    return tf.tanh(ed.dot(hprev, Wh) + ed.dot(xt, Wx) + bh)

Wh = Normal(loc=tf.zeros([H, H]), scale=tf.ones([H, H]))
Wx = Normal(loc=tf.zeros([D, H]), scale=tf.ones([D, H]))
Wy = Normal(loc=tf.zeros([H, 1]), scale=tf.ones([H, 1]))
bh = Normal(loc=tf.zeros(H), scale=tf.ones(H))
by = Normal(loc=tf.zeros(1), scale=tf.ones(1))
# One series at a time: (time steps, features).
x = tf.placeholder(tf.float32, [None, D])
# Unroll the recurrence; h stacks the hidden state at every time step.
h = tf.scan(rnn_cell, x, initializer=tf.zeros(H))
# Gaussian observation model with unit noise around the linear readout.
y = Normal(loc=tf.matmul(h, Wy) + by, scale=1.0)
# Mean-field variational families for the RNN weights.
# NormalWithSoftplusScale applies softplus to the raw scale tensor, so the
# scale variables below may be left unconstrained.
_N = ed.models.NormalWithSoftplusScale
qWh = _N(loc=tf.get_variable('qW_h/loc', [H, H]), scale=tf.get_variable('qW_h/scale', [H, H]))
qWx = _N(loc=tf.get_variable('qW_x/loc', [D, H]), scale=tf.get_variable('qW_x/scale', [D, H]))
qWy = _N(loc=tf.get_variable('qW_y/loc', [H, 1]), scale=tf.get_variable('qW_y/scale', [H, 1]))
qbh = _N(loc=tf.get_variable('qb_h/loc', [H]), scale=tf.get_variable('qb_h/scale', [H]))
qby = _N(loc=tf.get_variable('qb_y/loc', [1]), scale=tf.get_variable('qb_y/scale', [1]))
The training data are arranged as variable-length time series, indexed by column 0.
# Train by cycling through the units, feeding one whole series per update.
# Unit ids appear to start at 1, hence the "+ 1" when indexing `series`.
series = x_train.groupby(0).groups
# NOTE(review): KLqp is constructed with no `data=` binding for y; the
# observations are supplied per-step via feed_dict instead — confirm
# Edward treats the fed `y` as observed rather than sampled here.
inf = ed.KLqp(latent_vars={Wh: qWh, Wx: qWx, Wy: qWy, bh: qbh, by: qby})
inf.initialize(logdir='log')
ed.get_session().run(tf.global_variables_initializer())
for i in range(inf.n_iter):
    # Columns 2:24 are the D = 22 feature columns for this unit's series.
    res = inf.update(
        feed_dict={x: x_train.iloc[series[(i % len(series)) + 1],2:24].values,
                   y: y_train.iloc[series[(i % len(series)) + 1]].values.reshape(-1, 1)})
    inf.print_progress(res)
The test data are partial time series, plus reported number of cycles until failure.
# Test inputs are truncated series; RUL_FD001.txt holds the true remaining
# cycles at each test unit's cutoff point.
x_test = pd.read_table('cmaps/test_FD001.txt', header=None, sep=' ')
y_test = pd.read_table('cmaps/RUL_FD001.txt', header=None)
Get the posterior distribution of predictions by sampling from the variational approximation.
def predict(x_test):
    """Draw one posterior-predictive value of the final-step output for one series.

    Substitutes the fitted variational approximations for the weight priors
    via ed.copy, then evaluates y[-1] (the last time step's output RV) with
    the series' feature columns (2:24) fed in.
    """
    return ed.get_session().run(
        ed.copy(y[-1], dict_swap={Wh: qWh, Wx: qWx, Wy: qWy, bh: qbh, by: qby}),
        feed_dict={x: x_test.iloc[:,2:24].values})
Bayesian NN
Sample from a non-linear decision problem.
# Two-moons binary classification data: scaled features, 50/50 train/test split.
import sklearn.model_selection as skms
import sklearn.datasets as skd
import sklearn.preprocessing as skp
X, Y = skd.make_moons(noise=0.2, random_state=0, n_samples=1000)
X = skp.scale(X)
X_train, X_test, Y_train, Y_test = skms.train_test_split(X, Y, test_size=.5)
n, p = X_train.shape  # n training examples, p features
Plot the training data.
# Training data, colored by class (blue = 0, red = 1).
plt.clf()
plt.gcf().set_size_inches(6, 6)
plt.scatter(X_train[Y_train == 0,0], X_train[Y_train == 0,1], c='b', s=3)
plt.scatter(X_train[Y_train == 1,0], X_train[Y_train == 1,1], c='r', s=3)
<matplotlib.collections.PathCollection at 0x7fba916b2d68>
Check that an MLP can successfully classify the data.
# Sanity check: a small deterministic MLP (tanh -> tanh -> linear) trained
# with Adam on the sigmoid cross-entropy loss should separate the moons.
tf.reset_default_graph()
with tf.variable_scope('weights'):
    w0 = tf.get_variable('w0', [p, 5])
    w1 = tf.get_variable('w1', [5, 5])
    w2 = tf.get_variable('w2', [5, 1])
x = tf.placeholder(tf.float32, [None, p], name="X")
y = tf.placeholder(tf.float32, [n, 1], name="Y")
h = tf.tanh(tf.matmul(x, w0))
h = tf.tanh(tf.matmul(h, w1))
h = tf.matmul(h, w2)  # h holds logits, not probabilities
loss = tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(labels=y, logits=h))
train = tf.train.AdamOptimizer().minimize(loss)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(1000):
        _, f = sess.run([train, loss], {x: X_train, y: Y_train.reshape(-1, 1)})
        if not i % 100:
            print(i, f)
    # BUG FIX: h is a logit, so the 0.5-probability decision boundary is
    # h > 0 (sigmoid(0) == 0.5); the original compared logits against 0.5,
    # which corresponds to a probability threshold of ~0.62.
    yhat = sess.run(tf.reshape(h > 0, [-1]), {x: X_test})
# Test points: black = classified correctly, red = misclassified.
plt.gcf().set_size_inches(6, 6)
plt.scatter(X_test[Y_test == yhat,0], X_test[Y_test == yhat,1], c='k', s=3)
plt.scatter(X_test[Y_test != yhat,0], X_test[Y_test != yhat,1], c='r', s=3)
<matplotlib.collections.PathCollection at 0x7fba68dc9c18>
Train a Bayesian MLP.
# Bayesian MLP: same architecture as the deterministic sanity check, but
# with standard-normal priors on the weight matrices and mean-field normal
# approximating families fit by KLqp.
ed.get_session().close()
tf.reset_default_graph()
ed.set_seed(0)

def neural_network(Xp):
    """tanh -> tanh -> sigmoid network; returns a flat vector of P(Y = 1)."""
    h = tf.tanh(tf.matmul(Xp, W_0))
    h = tf.tanh(tf.matmul(h, W_1))
    h = tf.sigmoid(tf.matmul(h, W_2))
    return tf.reshape(h, [-1])

n_hidden = 5
# MODEL
Normal = ed.models.Normal
Bernoulli = ed.models.Bernoulli
with tf.name_scope("model"):
    W_0 = Normal(loc=tf.zeros([X.shape[1], n_hidden]), scale=tf.ones([X.shape[1], n_hidden]), name="W_0")
    W_1 = Normal(loc=tf.zeros([n_hidden, n_hidden]), scale=tf.ones([n_hidden, n_hidden]), name="W_1")
    W_2 = Normal(loc=tf.zeros([n_hidden, 1]), scale=tf.ones([n_hidden, 1]), name="W_2")
    Xp = tf.placeholder(tf.float32, [None, X.shape[1]], name="X")
    f = neural_network(Xp)
    # BUG FIX: f is a probability (the network ends in a sigmoid), but
    # Bernoulli's first positional argument is `logits`; pass `probs=`
    # explicitly so the sigmoid output is not re-squashed as a logit.
    Y = Bernoulli(probs=f, name="Y")
with tf.variable_scope("posterior"):
    # One (loc, softplus(scale)) normal family per weight matrix.
    with tf.variable_scope("qW_0"):
        loc = tf.get_variable("loc", [X.shape[1], n_hidden])
        scale = tf.nn.softplus(tf.get_variable("scale", [X.shape[1], n_hidden]))
        qW_0 = Normal(loc=loc, scale=scale)
    with tf.variable_scope("qW_1"):
        loc = tf.get_variable("loc", [n_hidden, n_hidden])
        scale = tf.nn.softplus(tf.get_variable("scale", [n_hidden, n_hidden]))
        qW_1 = Normal(loc=loc, scale=scale)
    with tf.variable_scope("qW_2"):
        loc = tf.get_variable("loc", [n_hidden, 1])
        scale = tf.nn.softplus(tf.get_variable("scale", [n_hidden, 1]))
        qW_2 = Normal(loc=loc, scale=scale)
# Fit the weight posteriors with 3000 KLqp iterations.
# NOTE(review): the logdir is a hardcoded cluster scratch path — make it
# configurable or relative before reuse.
inference = ed.KLqp({W_0: qW_0, W_1: qW_1, W_2: qW_2}, data={Xp: X_train, Y: Y_train})
inference.run(n_iter=3000, logdir='/scratch/midway2/aksarkar/nwas/log/')
# Posterior-predictive labels: swap the weight priors for their fitted
# posteriors and evaluate Y on the test inputs.
# NOTE(review): running the copied RV appears to yield a *single* sampled
# label per point (one draw of the weights), not a posterior mean —
# confirm whether averaging over several draws was intended.
py = ed.copy(Y, {W_0: qW_0, W_1: qW_1, W_2: qW_2})
yhat = ed.get_session().run(py, {Xp: X_test})
# Test points: black = classified correctly, red = misclassified.
plt.gcf().set_size_inches(6, 6)
plt.scatter(X_test[Y_test == yhat,0], X_test[Y_test == yhat,1], c='k', s=3)
plt.scatter(X_test[Y_test != yhat,0], X_test[Y_test != yhat,1], c='r', s=3)
<matplotlib.collections.PathCollection at 0x7fba68267240>
Convnet
from edward.models import *
Generate some data \(x_i \sim \text{Multinomial}(1, \mathbf{p})\).
# Draw N categorical observations over three classes with fixed
# probabilities (no seed is set, so this cell is not reproducible).
N = 5000
true_param1 = np.array([.15, .35, .5])
data = np.random.choice(a=true_param1.size, size=N, p=true_param1)
Fit the model.
# Fit a Dirichlet posterior over the categorical class probabilities.
ed.get_session().close()
tf.reset_default_graph()
ed.set_seed(0)
param1 = Dirichlet(tf.ones([3]))
# NOTE(review): param1 lies on the simplex but is passed as `logits` —
# confirm `probs=param1` was not intended.
w = Categorical(logits=param1, sample_shape=[N])
# NOTE(review): the concentration Variable is unconstrained and can go
# negative during optimization; the convnet cell below wraps the analogous
# Variable in softplus for exactly this reason.
qparam1 = Dirichlet(tf.Variable(tf.ones([3])))
inference = ed.KLqp({param1: qparam1}, data={w: data})
inference.run(n_iter=2000)
ed.get_session().run(qparam1.mean())
array([ 0.14636859, 0.33761254, 0.51601881], dtype=float32)
Generate data \(x_i \sim \text{Multinomial}(1, h(z_1, z_2))\).
# "Convnet" generative mapping: a fixed weight tensor convolved with a
# (1, 1, 2, 10) parameter tensor produces the 10 categorical logits.
ed.get_session().close()
tf.reset_default_graph()
ed.set_seed(0)
# True parameter: one-hot slices selecting index 2 (channel 0) and
# index 5 (channel 1).
true_param1 = np.zeros((1,1,2,10))
true_param1[0,0,0,2] = 1
true_param1[0,0,1,5] = 1
true_param1 = tf.constant(true_param1, dtype=tf.float32)
# Fixed weight "image", shape (batch=1, height=5, width=2, channels=2).
true_param2 = np.zeros((1,5,2,2))
A = np.array([[1,1,0,0,0],[0,4,30,1,0]])
B = np.array([[2,1,0,0,3],[4,5,0,0,0]])
true_param2[0,:,:,0] = np.transpose(A)
true_param2[0,:,:,1] = np.transpose(B)
true_param2 = tf.constant(true_param2, dtype=tf.float32)

def model_latent(z):
    # z acts as a 1x1 conv filter (in=2 channels, out=10), giving a
    # (1, 5, 2, 10) output; summing over the 10 output channels leaves the
    # 5x2 spatial grid, whose 10 cells are flattened into the logits.
    latent = tf.nn.conv2d(true_param2, z, strides=[1,1,1,1], padding="VALID")
    latent = tf.reduce_sum(latent, 3)
    latent = tf.reshape(latent, [10])
    return latent

logits = model_latent(true_param1)
# Sanity check against an all-ones filter.
ed.get_session().run(model_latent(tf.ones([1, 1, 2, 10])))
array([ 3., 4., 2., 9., 0., 30., 0., 1., 3., 0.], dtype=float32)
# Draw N observations from the categorical defined by the true logits.
N = 5000
data = ed.get_session().run(Categorical(logits=logits).sample(N))
# Infer the (1, 1, 2, 10) parameter tensor from the categorical data.
param1 = Dirichlet(tf.ones([1,1,2,10]), name='param1')
w = Categorical(logits=model_latent(param1), sample_shape=[N])
# softplus keeps the variational Dirichlet concentrations positive.
qparam1 = Dirichlet(tf.nn.softplus(tf.Variable(tf.ones([1,1,2,10]))), name="qparam1")
inference = ed.KLqp({param1: qparam1}, data={w: data})
inference.run()
# NOTE(review): this evaluates the mapping at a fresh *sample* of qparam1,
# and the q sample printed below looks identical to the uniform
# initialization — confirm the inference actually moved the posterior.
ed.get_session().run(model_latent(qparam1))
array([ 1.90809596, -0.6436035 , 1.85443068, 2.44416738, 1.73977566, -0.3565155 , 2.36256814, 0.20984885, 2.31142139, -0.44349724], dtype=float32)
# Evaluating the RV directly draws one sample from the fitted q(param1).
ed.get_session().run(qparam1)
array([[[[ 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], [ 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]]], dtype=float32)