Производительность больше вам не помешает
This commit is contained in:
10
main.py
10
main.py
@@ -9,12 +9,16 @@ dataset = generate.generate_dataset(100)
|
||||
|
||||
# Создаём и обучаем сеть
|
||||
nn = neuro_defs.SimpleNN()
|
||||
nn.train(dataset.train, dataset.train_answs, epochs=100)
|
||||
epoch = 100
|
||||
for i in range(epoch):
|
||||
nn.train(dataset.train, dataset.train_answs, epochs=1)
|
||||
|
||||
if epoch % 10 == 0:
|
||||
print("*"*(i//10) + "-"*((epoch-i)//10))
|
||||
|
||||
# Проверяем на новой точке
|
||||
for dot in range(len(dataset.test)):
|
||||
print(nn.forward(dataset.test[dot]).val, dataset.test_answs[dot])
|
||||
print(nn.forward(*dataset.test[dot]).val, dataset.test_answs[dot])
|
||||
print()
|
||||
print(nn.w_out.val, nn.b_out.val)
|
||||
# visual.plot_dataset(dataset)
|
||||
# visual.plt_show()
|
||||
Reference in New Issue
Block a user