ImProver Overview

Background

ImProver is an LLM-powered AI agent for proof optimization, built around a general-purpose neural theorem proving framework.

It allows arbitrary formal Lean proofs to be optimized for an arbitrary metric, so users can automatically optimize their formal proofs simply by providing a metric to score against. It is built on a general framework that gives strict control over how LLMs interface with theorems and integrates the LLM pipeline directly with symbolic data from the Lean compiler, giving greater accuracy and control over how proofs are (re)written.

With the rise in popularity of modern interactive theorem provers, there is a need for greater control over how formal proofs are structured and developed. By letting users choose the metric a given proof is rewritten for, ImProver shows that modern language models can be taught to perform such optimization accurately and effectively.

Some metrics we explore include:

  • Length:

    Optimizing for length allows mathematicians to write more concise and more efficient proofs.

  • Readability:

    Optimizing for more readable proofs (according to a quantifiable standard of readability) facilitates better understanding of complex proofs and provides a distinct pedagogical advantage.

  • Completion:

    One may reframe the complete neural theorem proving problem (i.e. fully generating proofs from scratch) as a "metric" to optimize for.

Optimizing for these arbitrary auxiliary metrics, on top of simply generating semantically and syntactically correct proofs, gives mathematicians and ML researchers far greater control over neural theorem proving and proof optimization tasks, enabling more efficient, readable, intuitive, and reusable proof generation at scale.

Features

To build the ImProver agent, we integrate several symbolic and prompting features that give the model more detailed context, improving both the accuracy of the black-box generator model and its ability to improve proofs. Namely, we use:

  • Chain-of-States

    We extract and parse proof-tree data structures from the Lean elaborator to obtain the metavariable hypotheses and goal states after each tactic invocation. These states are interleaved with the tactics to forward this symbolic information effectively to the LLM.

    Without CoS
    
    example : s ∩ t ∪ s ∩ u ⊆ s ∩ (t ∪ u)  := by
      rintro x (⟨xs, xt⟩ | ⟨xs, xu⟩)
      · use xs; left; exact xt
      · use xs; right; exact xu
                    
    With CoS
    
    example : s ∩ t ∪ s ∩ u ⊆ s ∩ (t ∪ u)  := by
      rintro x (⟨xs, xt⟩ | ⟨xs, xu⟩)
      /-
      case inl.intro
      α : Type u_1
      s t u : Set α
      x : α
      xs : x ∈ s
      xt : x ∈ t
      ⊢ x ∈ s ∩ (t ∪ u)
      case inr.intro
      α : Type u_1
      s t u : Set α
      x : α
      xs : x ∈ s
      xu : x ∈ u
      ⊢ x ∈ s ∩ (t ∪ u)
      -/
      · use xs; left; exact xt
      /-
      Goals Solved!
      -/
      · use xs; right; exact xu
      /-
      Goals Solved!
      -/
                            
  • Symbolic dependency search

    Theorems often depend on definitions and lemmas outside the current module. We symbolically retrieve the types, statements, and full definitions of these dependencies, and filter the most relevant and important ones to include as context for the theorem.

  • Output Formatters

    We analyze the effect of enforced output schemas on LLM performance by representing proofs as plain strings, sequences of tactics, or trees of tactics.

  • Best-of-N sampling

    We apply a standard Best-of-N sampling method with the following score function:

    \[S(y,y')=\begin{cases} \arg\max_{x\in\{y,y'\}} \mu(x), & E(y)=E(y')=0\\ y, & E(y)=0,\ E(y')>0\\ y', & E(y)>0,\ E(y')=0\\ \arg\min_{x\in\{y,y'\}} E(x), & E(y)>0,\ E(y')>0 \end{cases}\]

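    Here, for an output \(y\), \(\mu(y)\) is its metric score and \(E(y)\) is the number of errors in the proof. In words, an error-free proof always beats one with errors, two error-free proofs are compared by metric score, and two erroneous proofs are compared by error count.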

  • Refinement and Error Correction

    We identify and correct errors in the generated proofs by iteratively refining the model's outputs. Each iteration carries information from the last prev_num iterations, including the input, output, metric score, correctness, and error messages.

    This iterative refinement is combined with the Best-of-N sampling to create compound sampling functions.

  • RAG

    We use RAG document retrieval based on MMR (maximal marginal relevance) to augment the prompt with examples relevant to optimizing for the specific metric, syntax help, and lemmas from Mathlib.
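A minimal Python sketch of the Chain-of-States rendering, assuming the (tactic, goal state) pairs have already been extracted from the elaborator: it interleaves each tactic with the goal state that follows it as a Lean block comment, writing "Goals Solved!" when no goals remain, as in the CoS example above. The function name `interleave_states` and the toy goal state are illustrative rather than ImProver's actual code, and nested focusing blocks are not handled.

```python
def interleave_states(header: str, steps: list[tuple[str, str]], indent: str = "  ") -> str:
    """Interleave tactics with the goal state recorded after each one.

    `steps` holds (tactic, goal_state) pairs; an empty goal_state means the
    tactic closed all remaining goals.
    """
    lines = [header]
    for tactic, state in steps:
        lines.append(indent + tactic)
        body = state.strip() or "Goals Solved!"
        lines.append(indent + "/-")
        lines.extend(indent + line for line in body.splitlines())
        lines.append(indent + "-/")
    return "\n".join(lines)


# Hypothetical extracted data for a small example proof.
proof = interleave_states(
    "example : s ∩ t ⊆ s := by",
    [
        ("rintro x ⟨xs, xt⟩",
         "α : Type u_1\ns t : Set α\nx : α\nxs : x ∈ s\nxt : x ∈ t\n⊢ x ∈ s"),
        ("exact xs", ""),
    ],
)
print(proof)
```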
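Because the Best-of-N score function S compares two outputs at a time, selecting among N samples can be sketched as a left fold of S over them. This Python sketch assumes each sample has already been compiled, so its metric score μ and error count E are known. The `Candidate` type and the function names are hypothetical, and, as in the formula, a higher μ is treated as better.

```python
from dataclasses import dataclass
from functools import reduce


@dataclass
class Candidate:
    proof: str
    metric: float  # mu(y): metric score, higher is better here
    errors: int    # E(y): number of errors reported for the proof


def score(y: Candidate, y2: Candidate) -> Candidate:
    """Pairwise selection S(y, y'): prefer correct proofs, then metric, then fewer errors."""
    if y.errors == 0 and y2.errors == 0:
        return max(y, y2, key=lambda c: c.metric)
    if y.errors == 0:   # only y is error-free
        return y
    if y2.errors == 0:  # only y' is error-free
        return y2
    return min(y, y2, key=lambda c: c.errors)  # both contain errors


def best_of_n(candidates: list[Candidate]) -> Candidate:
    """Fold S over the N sampled outputs to keep a single winner."""
    return reduce(score, candidates)


samples = [
    Candidate("proof A", metric=3.0, errors=2),
    Candidate("proof B", metric=5.0, errors=0),
    Candidate("proof C", metric=4.0, errors=0),
]
print(best_of_n(samples).proof)  # proof B
```

Since S induces a total preference order (error-free first, then by μ or by E), folding it pairwise returns the same winner as comparing all N samples at once, with ties going to the earlier sample.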
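The refinement and error-correction loop can be sketched in the same spirit. Here `generate` stands in for a hypothetical LLM call and `check` for a hypothetical Lean check that returns a metric score and a list of error messages. A `deque` with `maxlen=prev_num` keeps only the most recent iterations as feedback. The record layout and the fixed iteration count are assumptions made for illustration.

```python
from collections import deque
from typing import Callable

# Hypothetical record layout for one refinement step.
Record = dict


def refine(
    proof: str,
    generate: Callable[[str, list[Record]], str],
    check: Callable[[str], tuple[float, list[str]]],
    iterations: int = 3,
    prev_num: int = 2,
) -> list[Record]:
    """Iteratively refine `proof`, feeding back the last `prev_num` attempts."""
    history = deque(maxlen=prev_num)  # older records fall off automatically
    records: list[Record] = []
    current = proof
    for _ in range(iterations):
        output = generate(current, list(history))
        score, errors = check(output)
        record = {
            "input": current,
            "output": output,
            "score": score,
            "correct": not errors,
            "errors": errors,
        }
        history.append(record)
        records.append(record)
        current = output  # the next iteration refines the latest output
    return records


# Toy stand-ins: "generation" replaces one `sorry`, and "checking" reports
# one error per remaining `sorry` (real error messages come from Lean).
def toy_generate(src: str, history: list[Record]) -> str:
    return src.replace("sorry", "rfl", 1)


def toy_check(src: str) -> tuple[float, list[str]]:
    return -len(src), ["declaration uses 'sorry'"] * src.count("sorry")


records = refine("by sorry; sorry", toy_generate, toy_check)
print([r["correct"] for r in records])  # [False, True, True]
```

A compound sampler could, for example, draw several outputs at each iteration and keep the best one before recording it. That combination is omitted here to keep the sketch short.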
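MMR (maximal marginal relevance) retrieval picks, at each step, the document d that maximizes lam * sim(d, q) - (1 - lam) * max over already-selected s of sim(d, s). This trades relevance to the query q against redundancy with what has already been retrieved. The dependency-free Python sketch below assumes the query and documents are already embedded as vectors and uses cosine similarity. The toy embeddings, the weight `lam`, and the function names are assumptions, not ImProver's retrieval setup.

```python
import math


def cosine(a: list[float], b: list[float]) -> float:
    """Cosine similarity between two vectors (0.0 if either is zero)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def mmr(query: list[float], docs: list[list[float]], k: int, lam: float = 0.5) -> list[int]:
    """Return indices of up to k documents chosen by maximal marginal relevance."""
    selected: list[int] = []
    remaining = list(range(len(docs)))

    def marginal(i: int) -> float:
        relevance = cosine(query, docs[i])
        # Penalize similarity to anything already selected.
        redundancy = max((cosine(docs[i], docs[j]) for j in selected), default=0.0)
        return lam * relevance - (1 - lam) * redundancy

    while remaining and len(selected) < k:
        best = max(remaining, key=marginal)
        selected.append(best)
        remaining.remove(best)
    return selected


# Documents 0 and 1 are near-duplicates; document 2 is less relevant but distinct.
docs = [[1.0, 0.9], [1.0, 0.85], [0.5, 1.0]]
print(mmr([1.0, 1.0], docs, k=2))  # [0, 2]
```

Ranking by relevance alone would return documents 0 and 1. MMR instead skips the near-duplicate, which helps keep the retrieved examples and lemmas in the prompt varied.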

Installation and Usage

For a more detailed usage guide (including method and custom metric configurations), follow the instructions in the README.

  1. Clone the GitHub repository locally and install the Python packages listed in requirements.txt (requires Python 3.11+).

  2. Set up JSON build configuration in ./configs/. For example:

    [
      {
          "path": "/Users/user/Desktop/lean-project",
          "lean": "leanprover/lean4:v4.9.0",
          "name": "LeanProject",
          "import_file": "LeanProject.lean",
          "imports": ["LeanProject"]
      },
      {
          "repo": "https://github.com/leanprover-community/mathlib4",
          "commit": "v4.9.0",
          "lean": "leanprover/lean4:v4.9.0",
          "name": "mathlib",
          "import_file": "Mathlib.lean",
          "imports": ["Mathlib"],
          "build": false
      }
    ]

    Note that the Lean 4 projects being built must use Lean version 4.9.0 or later.

  3. Build the projects and cache the outputs in ./.cache/ by running the build script:

    python scripts/build.py --config configs/CONFIG_FILE_NAME.json

  4. In your script, import the get_methods or improver functions from ./benchmark/tools.py and use them to configure the run parameters.

    To configure these parameters with custom tunings and custom metrics, see the README or the demo.