1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64
import torch
from torch import nn
from torch import optim
from matplotlib import pyplot as plt
from matplotlib.animation import ArtistAnimation
# Synthetic 1-D regression data: 600 evenly spaced samples on [-4, 4],
# reshaped into a (600, 1) column so each row is one single-feature sample.
inputs = torch.linspace(-4, 4, 600).unsqueeze(1)

# Noise-free ground truth y = x**3 + 1.89 (used only for plotting).
real = inputs.pow(3) + 1.89

# Training targets: the ground truth plus Gaussian noise scaled by 3.
targets = real + 3 * torch.randn(inputs.shape)
class NetWork(nn.Module):
    """Single-hidden-layer perceptron: Linear -> ReLU -> Linear.

    Args:
        n_feature: Number of input features per sample.
        n_hidden: Width of the hidden ReLU layer.
        n_output: Number of outputs per sample.
    """

    def __init__(self, n_feature, n_hidden, n_output):
        super().__init__()
        # Layers are listed in forward order; constructing them in this
        # order also fixes the order of their random initialisation.
        layers = [
            torch.nn.Linear(n_feature, n_hidden),
            torch.nn.ReLU(),
            torch.nn.Linear(n_hidden, n_output),
        ]
        # Kept under the attribute name `m` so state_dict keys stay
        # "m.0.weight", "m.0.bias", "m.2.weight", "m.2.bias".
        self.m = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` (rows of ``n_feature`` values) through the stack."""
        return self.m(x)
# Model: 1 input feature -> 100 hidden ReLU units -> 1 output.
network = NetWork(n_feature=1, n_hidden=100, n_output=1)

# Mean-squared-error objective for the regression fit.
loss_fn = nn.MSELoss()

# Adam over every trainable parameter of the model, learning rate 0.1.
optimizer = optim.Adam(params=network.parameters(), lr=0.1)
# Train for 200 full-batch steps and record every second step as one
# frame of an animated GIF showing the prediction converging.
fig, ax = plt.subplots()

# The ground-truth curve and the noisy targets never change, so draw them
# once. ArtistAnimation only toggles the visibility of artists listed in
# its frames; artists outside every frame remain visible as a static
# background. They are added first, so the red prediction draws on top.
ax.plot(inputs.numpy(), real.numpy(), c='blue', lw=3)
ax.plot(inputs.numpy(), targets.numpy(), c='orange')

ims = []
for i in range(200):
    out = network(inputs)
    loss = loss_fn(out, targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if i % 2 == 0:
        # `out` and `loss` were computed before this step's update, so the
        # plotted curve and the printed loss describe the same parameters.
        # detach() drops the autograd graph before converting to NumPy.
        prediction, = ax.plot(inputs.numpy(), out.detach().numpy(), c='red', lw=3)
        # Place the status in axes-relative coordinates (upper-left corner)
        # so it does not overlap the plotted data, whatever the data range.
        status = ax.text(
            0.02, 0.95, 'Time=%d Loss=%.4f' % (i, loss.item()),
            transform=ax.transAxes, va='top',
            fontdict={'size': 15, 'color': 'red'},
        )
        ims.append([prediction, status])

ani = ArtistAnimation(fig, ims, interval=100, blit=False, repeat_delay=1000)
# Name the writer explicitly: Pillow is a Matplotlib dependency and writes
# GIFs without requiring an external ffmpeg/ImageMagick binary.
ani.save('x_3.gif', writer='pillow')
|