41 lines
1.2 KiB
Python
41 lines
1.2 KiB
Python
import matplotlib.pyplot as plt
|
|
import matplotlib.colors as mcolors
|
|
import numpy as np
|
|
|
|
|
|
def plot_dataset(dataset):
    """Scatter the training points of *dataset*, coloured by class.

    Points whose ``classification`` is falsy are drawn in green and labelled
    'Class 0'; points whose ``classification`` is truthy are drawn in red and
    labelled 'Class 1'.

    Args:
        dataset: Object whose ``train`` attribute is an iterable of points,
            each exposing ``x``, ``y`` and ``classification`` attributes.
            ``train`` is iterated exactly once, so a one-shot iterable such
            as a generator is handled correctly.
    """
    class0_x, class0_y = [], []
    class1_x, class1_y = [], []
    # Single pass: each point's classification is evaluated once, and a
    # generator-backed ``train`` is not exhausted before all lists are built.
    for dot in dataset.train:
        if dot.classification:
            class1_x.append(dot.x)
            class1_y.append(dot.y)
        else:
            class0_x.append(dot.x)
            class0_y.append(dot.y)

    plt.scatter(class0_x, class0_y, color='green', label='Class 0')
    plt.scatter(class1_x, class1_y, color='red', label='Class 1')
|
|
|
|
def _axis_points(lo, hi, resolution):
    """Return evenly spaced samples covering [lo, hi] inclusive, ~resolution apart."""
    count = max(2, int(round(abs(hi - lo) / resolution)) + 1)
    return np.linspace(lo, hi, count)


def plot_decision_surface(network, resolution=0.02, *, x_range=(-1, 1),
                          y_range=(-1, 1), cmap='RdYlGn'):
    """Shade the plotting area by the network's output over a regular grid.

    Every grid node ``(x, y)`` is passed to ``network.predict([x, y])``; the
    outputs are drawn as 50 filled contour levels at 30% opacity.

    Args:
        network: Object exposing ``predict(point)``. Each call must produce a
            single value (or a one-element sequence) so the stacked results
            can be reshaped to the grid shape.
        resolution: Approximate spacing between neighbouring grid nodes.
            Must be positive.
        x_range: ``(min, max)`` extent of the grid along x.
        y_range: ``(min, max)`` extent of the grid along y.
        cmap: Matplotlib colormap name used for the shading.

    Raises:
        ValueError: If ``resolution`` is not positive.
    """
    if resolution <= 0:
        raise ValueError(f"resolution must be positive, got {resolution!r}")

    # linspace (unlike arange with a float step) includes the upper bound, so
    # the shading reaches the right/top edges of the plot, and the number of
    # samples is not subject to floating-point rounding of the step.
    xx, yy = np.meshgrid(_axis_points(*x_range, resolution),
                         _axis_points(*y_range, resolution))

    # Run the grid through the network, one node per predict() call.
    z = np.array([network.predict([x, y])
                  for x, y in zip(xx.ravel(), yy.ravel())])
    z = z.reshape(xx.shape)

    # Fill the background according to the predicted value.
    # NOTE(review): 'RdYlGn' maps low values to red and high values to green,
    # while plot_dataset draws Class 1 in red. If predict() returns the
    # probability of class 1, cmap='RdYlGn_r' would make the colours agree —
    # confirm before changing the default.
    plt.contourf(xx, yy, z, levels=50, cmap=cmap, alpha=0.3)
|
|
|
|
def plot_all(dataset, network):
    """Draw the network's decision surface and the training points together.

    Opens a new 6x6-inch figure, shades it with ``plot_decision_surface``,
    overlays the points from ``plot_dataset``, clamps both axes to [-1, 1]
    and adds a legend. The figure is not shown; call ``plt_show`` for that.
    """
    side = 6
    plt.figure(figsize=(side, side))

    # The surface is drawn before the points so the points are layered over it.
    plot_decision_surface(network)
    plot_dataset(dataset)

    lo, hi = -1, 1
    plt.xlim(lo, hi)
    plt.ylim(lo, hi)
    plt.legend()
|
|
|
|
def plt_show():
    """Display the current matplotlib figures via ``plt.show``."""
    display = plt.show
    display()
|
|
|
|
|
|
if __name__ == "__main__":
    # No standalone entry point: executing this file directly does nothing.
    ...