I looked today at the sMMM e-graph array optimization I first started in #44.
```python
from collections.abc import Callable

import numpy as np


def sum_over(arr: np.ndarray, body: Callable[[int, np.ndarray], float]) -> float:
    # Sum body(k, arr[..., k]) over every index k of the last axis of arr.
    n = arr.shape[-1]
    return float(sum((float(body(k, arr[..., k])) for k in range(n)), 0.0))


# ---------------- ΣMMM (scalar) ----------------
# Q = sum_{i,j,k} A[i,k] * B[k,j]
# A ∈ R^{I×K}, B ∈ R^{K×J}


def sMMM_triple_struct(A: np.ndarray, B: np.ndarray) -> float:
    """
    Mirrors the IR:
    (sum k Ak (var A)
      (sum i Aik (var Ak)
        (sum j Bkj (get B k)
          Aik * Bkj)))
    """
    return sum_over(
        A,
        lambda k, Ak: sum_over(
            Ak,
            lambda i, Aik: sum_over(
                B[k, ...],
                lambda j, Bkj: float(Aik) * float(Bkj),
            ),
        ),
    )


def sMMM_factored_struct(A: np.ndarray, B: np.ndarray) -> float:
    """
    Mirrors the equality-saturated IR:
    (sum k Ak (var A)
      (* (sum i Aik (var Ak) Aik)
         (sum j Bkj (get B k) Bkj)))
    """
    return sum_over(
        A,
        lambda k, Ak: (
            sum_over(Ak, lambda i, Aik: float(Aik))
            * sum_over(B[k], lambda j, Bkj: float(Bkj))
        ),
    )


# --------- (optional) quick check ----------
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    I, K, J = 4, 5, 3
    A = rng.integers(-2, 3, size=(I, K)).astype(float)
    B = rng.integers(-2, 3, size=(K, J)).astype(float)

    q1 = sMMM_triple_struct(A, B)
    q2 = sMMM_factored_struct(A, B)
    q_ref = float((A @ B).sum())  # same as np.einsum('ik,kj->', A, B)

    assert np.isclose(q1, q_ref)
    assert np.isclose(q2, q_ref)
    print(q1, q2, q_ref)
```
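We can see here that the core rewrite being implemented is lifting a loop-invariant multiplication out of a loop, i.e. rewriting

```python
sum_over(x, lambda a, b: c * d(a, b))
```

to

```python
c * sum_over(x, lambda a, b: d(a, b))
```

which is only valid when `c` depends on neither the loop index `a` nor the element `b`.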
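As a quick numeric sanity check of this rule, the snippet below evaluates both sides on a small array. It repeats the `sum_over` helper from the sketch above so it runs on its own; the body `d`, the constant `c` and the array `x` are made-up illustrations, with `c` chosen so it mentions neither loop variable.

```python
from collections.abc import Callable

import numpy as np


def sum_over(arr: np.ndarray, body: Callable[[int, np.ndarray], float]) -> float:
    # Same helper as above: sum body(k, arr[..., k]) over the last axis.
    n = arr.shape[-1]
    return float(sum((float(body(k, arr[..., k])) for k in range(n)), 0.0))


def d(a: int, b: np.ndarray) -> float:
    # Hypothetical loop body that does depend on the loop index and element.
    return (a + 1) * float(b)


x = np.array([1.0, -2.0, 3.0, 0.5])
c = 2.5  # loop-invariant factor: mentions neither a nor b

before = sum_over(x, lambda a, b: c * d(a, b))
after = c * sum_over(x, lambda a, b: d(a, b))

assert np.isclose(before, after)
print(before, after)  # 20.0 20.0
```

Because the rewrite changes the order of floating-point operations, the two sides should in general be compared with `np.isclose` rather than `==`; these particular inputs happen to make both exactly 20.0. If the factor depended on `a` or `b`, it could not even be written outside the lambda, which is the side condition an e-graph version of this rule has to enforce.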
The same rewrite shows up in the "Optimizing Tensor Programs on Flexible Storage" paper as well.
The question then might be, how would we represent functions in egglog in such a way that we could write this rewrite and then extract out the optimized function?
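One way to make the question concrete, outside egglog, is a tiny hypothetical term language where `sum` is a binder that introduces an index and an element name, loosely mirroring the s-expressions in the docstrings above. The rewrite then needs a free-variable side condition: the lifted factor must not mention either name bound by the `sum`. The sketch below is plain Python with destructive bottom-up rewriting (not an e-graph and not egglog's API); the tuple encoding and the names `free_vars`, `lift_invariant_factor` and `rewrite_to_fixpoint` are made up for illustration.

```python
# Hypothetical term encoding (tuples), loosely mirroring the IR above:
#   ("var", name)
#   ("get", array_term, index_term)
#   ("*", left, right)
#   ("sum", index_name, elem_name, source_term, body)  # binds index_name, elem_name
Term = tuple


def free_vars(t: Term) -> set[str]:
    tag = t[0]
    if tag == "var":
        return {t[1]}
    if tag in ("get", "*"):
        return free_vars(t[1]) | free_vars(t[2])
    if tag == "sum":
        _, idx, elem, src, body = t
        # The source is evaluated outside the binder; the body is inside it.
        return free_vars(src) | (free_vars(body) - {idx, elem})
    raise ValueError(f"unknown term: {t!r}")


def lift_invariant_factor(t: Term) -> Term:
    """One bottom-up pass of sum(x, c * d) -> c * sum(x, d), plus the mirrored
    sum(x, d * c) -> sum(x, d) * c, when c mentions no name bound by the sum."""
    tag = t[0]
    if tag == "var":
        return t
    if tag in ("get", "*"):
        return (tag, lift_invariant_factor(t[1]), lift_invariant_factor(t[2]))
    if tag == "sum":
        _, idx, elem, src, body = t
        src, body = lift_invariant_factor(src), lift_invariant_factor(body)
        bound = {idx, elem}
        if body[0] == "*":
            _, left, right = body
            if not (free_vars(left) & bound):
                return ("*", left, ("sum", idx, elem, src, right))
            if not (free_vars(right) & bound):
                return ("*", ("sum", idx, elem, src, left), right)
        return ("sum", idx, elem, src, body)
    raise ValueError(f"unknown term: {t!r}")


def rewrite_to_fixpoint(t: Term) -> Term:
    # Repeat passes until nothing changes (terminates: factors only move outward).
    while True:
        new = lift_invariant_factor(t)
        if new == t:
            return t
        t = new


# The triple-sum sMMM term from the first docstring.
triple = (
    "sum", "k", "Ak", ("var", "A"),
    ("sum", "i", "Aik", ("var", "Ak"),
     ("sum", "j", "Bkj", ("get", ("var", "B"), ("var", "k")),
      ("*", ("var", "Aik"), ("var", "Bkj")))),
)

# The factored term from the second docstring.
expected = (
    "sum", "k", "Ak", ("var", "A"),
    ("*",
     ("sum", "i", "Aik", ("var", "Ak"), ("var", "Aik")),
     ("sum", "j", "Bkj", ("get", ("var", "B"), ("var", "k")), ("var", "Bkj"))),
)

factored = rewrite_to_fixpoint(triple)
print(factored == expected)  # True
```

In an e-graph the same side condition would have to be checked without discarding the original term. Tracking free variables as an e-class analysis, or switching to a nameless (de Bruijn) encoding of the binders, are two plausible options; how well either maps onto egglog is exactly the open question here.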
Next I might look at some more examples, in particular the BATAX one, which their model errored out on because of an associativity blow-up. That would be a good chance to explore this as well, and it could also be nice to try to address the blow-up through a multiset operation…
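To get a feel for the associativity blow-up and why a multiset might help, the sketch below enumerates every way commutativity and associativity can rearrange a product of four scalar factors, then collapses each arrangement into a `Counter` of its factors. The factor names are hypothetical stand-ins for the kind of scalar product I'd expect a BATAX-style expression to produce once written in index form, and this is plain Python, not an egglog multiset primitive.

```python
from collections import Counter
from itertools import permutations
from math import factorial


def flatten_product(t) -> Counter:
    """Collapse a nested binary product ("*", l, r) into a multiset of its leaf
    factors. Every association and ordering of the same factors gives the same Counter."""
    if isinstance(t, tuple) and t[0] == "*":
        return flatten_product(t[1]) + flatten_product(t[2])
    return Counter([t])


def all_trees(factors: tuple):
    """Yield every binary product tree over `factors` in this fixed left-to-right order."""
    if len(factors) == 1:
        yield factors[0]
        return
    for split in range(1, len(factors)):
        for left in all_trees(factors[:split]):
            for right in all_trees(factors[split:]):
                yield ("*", left, right)


# Hypothetical scalar factors of one product term.
factors = ("beta", "A_ik", "A_ij", "x_j")
n = len(factors)

# Every ordering (commutativity) times every bracketing (associativity).
trees = {tree for perm in permutations(factors) for tree in all_trees(perm)}
# Keyed by multiset of factors, all of them collapse to one canonical form.
canonical = {frozenset(flatten_product(t).items()) for t in trees}

expected = factorial(2 * n - 2) // factorial(n - 1)  # n! * Catalan(n - 1)
print(len(trees), expected)  # 120 120
print(len(canonical))        # 1
```

An e-graph shares subterms, so it would hold something closer to one e-class per subset of the factors than 120 separate trees, but that is still exponential in the number of factors. Keying products by their multiset of factors, as `flatten_product` does, leaves a single canonical form no matter how many factors there are.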
So the paper could also focus on these array operations, implementing them with a couple of different new techniques.