1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62
import matplotlib.animation as animation
import torch
from matplotlib import pyplot as plt
from torch import nn
from torch import optim
# 600 evenly spaced x values in [-10, 10], shaped as a (600, 1) column.
inputs = torch.linspace(-10, 10, 600).unsqueeze(1)
# Noise-free reference line y = 2.55 * x + 1.89.
real = 2.55 * inputs + 1.89
# Training targets: the reference line plus unit-variance Gaussian noise.
targets = real + torch.normal(0, 1, inputs.shape)
class NetWork(nn.Module):
    """Single fully connected layer used as a linear-regression model.

    Args:
        in_features: Size of each input sample. Defaults to 1.
        out_features: Size of each output sample. Defaults to 1.
    """

    def __init__(self, in_features: int = 1, out_features: int = 1) -> None:
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the linear layer to ``x`` and return the result."""
        return self.linear(x)
# Model, objective and optimiser for the regression fit.
network = NetWork()
# Mean squared error between predictions and the noisy targets.
loss_fn = nn.MSELoss(reduction='mean')
# Plain stochastic gradient descent over the model's parameters.
optimizer = optim.SGD(params=network.parameters(), lr=1e-3)
fig, ax = plt.subplots()

# The reference line and the noisy targets never change during training, so
# draw them once and reuse the same artists in every animation frame instead
# of re-plotting them on each recorded step.
x_values = inputs.numpy()
real_line, = ax.plot(x_values, real.numpy(), c='blue', lw=3)
target_line, = ax.plot(x_values, targets.numpy(), c='orange')

ims = []
for i in range(101):
    # One full-batch gradient step.
    out = network(inputs)
    loss = loss_fn(out, targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Record every second iteration as one animation frame. `out` and `loss`
    # are the values computed before this iteration's parameter update.
    if i % 2 == 0:
        # detach() drops the autograd graph so the tensor can be converted
        # to NumPy; item() extracts the scalar loss as a Python float.
        fit_line, = ax.plot(x_values, out.detach().numpy(), c='red', lw=3)
        title = ax.text(
            1,
            -6,
            'Time=%d Loss=%.4f' % (i, loss.item()),
            fontdict={'size': 15, 'color': 'red'},
        )
        ims.append([real_line, target_line, fit_line, title])

# Each entry of `ims` is the list of artists visible in that frame.
ani = animation.ArtistAnimation(
    fig, ims, interval=50, blit=False, repeat_delay=1000
)
# No writer is specified, so matplotlib chooses the movie writer itself.
ani.save('dynamic_images.gif')
|