diff --git a/auto_diff/auto_diff.py b/auto_diff/auto_diff.py index 052f7f7..5a1c6bc 100644 --- a/auto_diff/auto_diff.py +++ b/auto_diff/auto_diff.py @@ -63,21 +63,45 @@ class Node: def __rtruediv__(self, other): return Node._to_node(other) / self - def __pow__(self, p: float): - # степень фиксируем скаляром p - out = Node(self.val**p, parents=[(self, lambda g: g * p * (self.val ** (p-1)))], op=f"**{p}") + def __pow__(self, other): + other = Node._to_node(other) + out = Node(self.val**other.val, parents=[(self, lambda g: g * other.val * (self.val ** (other.val-1))), (other, lambda g: g * (self.val ** other.val) * math.log(self.val))], op=f"**{other.val}") def _backward(): - self.grad += out.grad * p * (self.val ** (p-1)) + self.grad += out.grad * other.val * (self.val ** (other.val-1)) + other.grad += out.grad * (self.val ** other.val) * math.log(self.val) out._backward = _backward return out def __rpow__(self, p: float): - # степень фиксируем скаляром p + # основание фиксируем скаляром p out = Node(p**self.val, parents=[(self, lambda g: g * p ** self.val * math.log(p))], op=f"{p}**") def _backward(): self.grad += out.grad * p ** self.val * math.log(p) out._backward = _backward return out + + def exp(self): return math.exp(1) ** self + + def sin(self): + out = Node(math.sin(self.val), parents=[(self, lambda g: g * math.cos(self.val))], op=f"sin") + def _backward(): + self.grad += out.grad * math.cos(self.val) + out._backward = _backward + return out + + def cos(self): + out = Node(math.cos(self.val), parents=[(self, lambda g: -g * math.sin(self.val))], op=f"cos") + def _backward(): + self.grad += -out.grad * math.sin(self.val) + out._backward = _backward + return out + + def log(self): + out = Node(math.log(self.val), parents=[(self, lambda g: g / self.val)], op=f"log") + def _backward(): + self.grad += out.grad / self.val + out._backward = _backward + return out def backward(loss: Node): # 1) топологическая сортировка (DFS) @@ -103,7 +127,7 @@ if __name__ == "__main__": x = Node(2, label="x") y = Node(3, label="y") print(x) - f = 8**x + f = (2*x).exp() print(f) backward(f) print(x.grad)