From 214163d8227195df301e479f4f2b4c495354d476 Mon Sep 17 00:00:00 2001 From: HypoxiE Date: Mon, 18 Aug 2025 14:30:25 +0700 Subject: [PATCH] =?UTF-8?q?=D0=94=D0=BE=D0=B1=D0=B0=D0=B2=D0=BB=D0=B5?= =?UTF-8?q?=D0=BD=D0=B8=D0=B5=20=D0=B0=D0=B2=D1=82=D0=BE=20=D0=B4=D0=B8?= =?UTF-8?q?=D1=84=D1=84=D0=B5=D1=80=D0=B5=D0=BD=D1=86=D0=B8=D1=80=D0=BE?= =?UTF-8?q?=D0=B2=D0=B0=D0=BD=D0=B8=D1=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- auto_diff/auto_diff.py | 109 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 109 insertions(+) create mode 100644 auto_diff/auto_diff.py diff --git a/auto_diff/auto_diff.py b/auto_diff/auto_diff.py new file mode 100644 index 0000000..052f7f7 --- /dev/null +++ b/auto_diff/auto_diff.py @@ -0,0 +1,109 @@ +import math + +class Node: + def __init__(self, val, parents=(), op="", label=None): + self.grad = 0.0 + self.val = float(val) + self.parents = parents + self.op = op + self.label = label + + self._backward = lambda: None + + def __repr__(self): + name = f"{self.label}:" if self.label else "" + return f"<{name}{self.op or 'var'} val={self.val:.6g} grad={self.grad:.6g}>" + + @staticmethod + def _to_node(x): + return x if isinstance(x, Node) else Node(x) + + def __add__(self, other): + other = Node._to_node(other) + out = Node(self.val + other.val, parents=[(self, lambda g: g), (other, lambda g: g)], op="+") + + def _backward(): + self.grad += out.grad * 1.0 + other.grad += out.grad * 1.0 + + out._backward = _backward + return out + + def __radd__(self, other): return self + other + + def __neg__(self): + out = Node(-self.val, parents=[(self, lambda g: -g)], op="neg") + def _backward(): + self.grad += -out.grad + out._backward = _backward + return out + + def __sub__(self, other): return self + (-other) + def __rsub__(self, other): return Node._to_node(other) + (-self) + + def __mul__(self, other): + other = Node._to_node(other) + out = Node(self.val * other.val, parents=[(self, lambda g: g * other.val), (other, lambda g: g * self.val)], op="*") + def _backward(): + self.grad += out.grad * other.val + other.grad += out.grad * self.val + out._backward = _backward + return out + + def __rmul__(self, other): return self * other + + def __truediv__(self, other): + other = Node._to_node(other) + out = Node(self.val / other.val, parents=[(self, lambda g: g / other.val), (other, lambda g: -g * self.val / (other.val**2))], op="/") + def _backward(): + self.grad += out.grad / other.val + other.grad += -out.grad * self.val / (other.val**2) + out._backward = _backward + return out + + def __rtruediv__(self, other): return Node._to_node(other) / self + + def __pow__(self, p: float): + # степень фиксируем скаляром p + out = Node(self.val**p, parents=[(self, lambda g: g * p * (self.val ** (p-1)))], op=f"**{p}") + def _backward(): + self.grad += out.grad * p * (self.val ** (p-1)) + out._backward = _backward + return out + + def __rpow__(self, p: float): + # степень фиксируем скаляром p + out = Node(p**self.val, parents=[(self, lambda g: g * p ** self.val * math.log(p))], op=f"{p}**") + def _backward(): + self.grad += out.grad * p ** self.val * math.log(p) + out._backward = _backward + return out + +def backward(loss: Node): + # 1) топологическая сортировка (DFS) + topo, visited = [], set() + def build(u: Node): + if u not in visited: + visited.add(u) + for p, _ in u.parents: + build(p) + topo.append(u) + build(loss) + print(topo, list(reversed(topo)), visited) + # 2) инициализируем dL/dL = 1 и идём в обратном порядке + for n in topo: + n.grad = 0.0 + loss.grad = 1.0 + for node in reversed(topo): + node._backward() + print(node) + +if __name__ == "__main__": + + x = Node(2, label="x") + y = Node(3, label="y") + print(x) + f = 8**x + print(f) + backward(f) + print(x.grad)