import functools
import json
import os
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass
import numpy as np
import requests
from openelm.codegen import model_setup, sample, set_seed, truncate
from openelm.configs import SodaraceELMConfig
from openelm.environments.sodaracer import IMPORTS, SQUARE_PREREQ, Walker
from openelm.utils.code_eval import pool_exec_processes
from openelm.utils.diff_eval import apply_diff, split_diff
class MutationModel(ABC):
    """Base model class for all mutation models.

    Concrete mutation models must implement :meth:`generate_program`.
    """

    @abstractmethod
    def generate_program(self, code_batch: list[str]) -> list[dict]:
        """Produce mutated programs for a batch of source programs.

        Args:
            code_batch: The source code strings to mutate.

        Returns:
            A list of result dictionaries, as defined by the concrete model.
        """
@dataclass
class FunctionTemplate:
    """
    A function template for a mutation model.

    Attributes:
        func_name (str): The name of the function that we want to execute.
        import_line (str): The import lines we add to the code.
        func_preamble (str): The function definition, as well as potentially
            a few initial lines to generate code.
        instruction (str): The instruction we give to the model, before the
            preamble.
    """

    func_name: str
    import_line: str
    func_preamble: str
    instruction: str
class PromptMutationModel(MutationModel):
    """Mutation model that uses prompts to change a seed.

    Subclasses provide ``_get_response`` (how a program is sent to the
    sandbox server) and ``_post_process`` (how raw execution results are
    turned into the returned list).
    """

    def __init__(
        self,
        cfg: SodaraceELMConfig,
        function_template: FunctionTemplate,
        sandbox_server: str = "http://localhost:5000",
    ) -> None:
        """
        Set up the RNG, the language model and the function template.

        Args:
            cfg: The ELM configuration; ``seed``, ``sandbox``, ``timeout``,
                ``processes`` and ``debug`` are read by this class.
            function_template: Pieces used to build prompts and the
                executable preamble.
            sandbox_server: Base URL of the sandbox server.
        """
        self.cfg: SodaraceELMConfig = cfg
        seed: int = set_seed(self.cfg.seed)
        # Use RNG to rotate random seeds during inference.
        self.rng = np.random.default_rng(seed=seed)
        self.sandbox_server = sandbox_server
        # Must be in the environment before the tokenizer is created below.
        os.environ["TOKENIZERS_PARALLELISM"] = "false"
        self.model, self.tokenizer, self.device = model_setup(self.cfg)
        self.func_template: FunctionTemplate = function_template

    def construct_prompt(self, code: str) -> tuple[str, str]:
        """
        Construct a prompt from a code string.

        Args:
            code (str): The code string.

        Returns:
            A tuple ``(prompt, preamble)``: the prompt is the code followed
            by the instruction and function preamble; the preamble is the
            import lines followed by the instruction and function preamble.
        """
        template = self.func_template
        suffix = template.instruction + template.func_preamble
        return code + suffix, template.import_line + suffix

    def generate_program(self, code_batch: list[str]) -> list[dict]:
        """
        Generate new programs from a batch of programs.

        For each program: build a prompt, sample a completion, truncate it,
        prepend the executable preamble, execute the result (through the
        sandbox server when ``cfg.sandbox`` is set, otherwise in local
        worker processes) and post-process the execution results.

        The executed program strings are stored in ``self.truncations``.

        Args:
            code_batch: The full code strings to mutate.

        Returns:
            Whatever ``_post_process`` returns for the collected results.
        """
        if not code_batch:
            # zip(*map(...)) cannot be unpacked for an empty batch.
            self.truncations: list[str] = []
            return []

        prompts, preamble_strings = zip(*map(self.construct_prompt, code_batch))
        encodings = self.tokenizer(
            list(prompts),
            truncation=True,
            padding=True,
            return_tensors="pt",
        )
        completions: list[str] = sample(
            encodings,
            self.cfg,
            self.model,
            self.tokenizer,
            batch_size=1,
        )
        # With a function preamble the completion is a function body, so it
        # is truncated to the local scope only.
        local_scope_exec: bool = len(self.func_template.func_preamble) > 0
        trunc = functools.partial(truncate, only_local_scope=local_scope_exec)
        self.truncations = [
            preamble_strings[i] + trunc(completion)
            for i, completion in enumerate(completions)
        ]

        if self.cfg.sandbox:
            results = []
            for program in self.truncations:
                resp = self._get_response(program, self.cfg.timeout)
                # Responses with a non-200 status are dropped.
                if resp.status_code == 200:
                    results.append(json.loads(resp.text))
        else:
            results = pool_exec_processes(
                self.truncations,
                func_name=self.func_template.func_name,
                timeout=self.cfg.timeout,
                processes=self.cfg.processes,
                debug=self.cfg.debug,
            )
        return self._post_process(results)

    @abstractmethod
    def _get_response(self, code: str, timeout: float) -> requests.models.Response:
        """Send ``code`` to the sandbox server and return its response."""
        raise NotImplementedError

    @abstractmethod
    def _post_process(self, results: list) -> list:
        """Convert raw execution results into the list that is returned."""
        raise NotImplementedError
class PromptMutationForSodarace(PromptMutationModel):
    """Prompt mutation model whose programs define ``make_walker``."""

    def __init__(self, cfg, sandbox_server="http://localhost:5000") -> None:
        """
        Build the ``make_walker`` function template and initialise the model.

        Args:
            cfg: The ELM configuration passed on to the base class.
            sandbox_server: Base URL of the sandbox server.
        """
        function_template = FunctionTemplate(
            func_name="make_walker",
            import_line=IMPORTS + SQUARE_PREREQ,
            instruction="",
            func_preamble="def make_walker():\n",
        )
        super().__init__(cfg, function_template, sandbox_server)

    def _get_response(self, code: str, timeout: float) -> requests.models.Response:
        """POST ``code`` to the sandbox's ``/gen_racer`` endpoint.

        ``timeout`` is sent in the JSON payload and also used as the HTTP
        request timeout.
        """
        return requests.post(
            f"{self.sandbox_server}/gen_racer",
            json={"code": code, "timeout": timeout},
            timeout=timeout,
        )

    def _post_process(self, results: list) -> list:
        """
        Keep only results that are valid walkers.

        In sandbox mode the results are returned unchanged. Otherwise each
        result that is a ``Walker`` passing ``validate()`` becomes a dict
        with the program string and the walker's ``to_dict()`` output;
        everything else is dropped (and printed when ``cfg.debug`` is set).

        Args:
            results: Raw execution results, indexed like ``self.truncations``.

        Returns:
            The filtered list of result dictionaries.
        """
        if self.cfg.sandbox:
            return results

        valid_walkers: list = []
        for idx, outcome in enumerate(results):
            # Best effort: a failure while validating or serialising one
            # walker must not discard the rest of the batch.
            try:
                if isinstance(outcome, Walker) and outcome.validate():
                    valid_walkers.append(
                        {
                            "program_str": self.truncations[idx],
                            "result_obj": outcome.to_dict(),
                        }
                    )
                elif self.cfg.debug:
                    print("Failed execution, type:", outcome)
                    print(self.truncations[idx])
            except Exception as e:
                if self.cfg.debug:
                    print(type(e), e)
        return valid_walkers
class PromptMutationForImgTask(PromptMutationModel):
    """Prompt mutation model whose programs define an image-drawing ``draw``."""

    def __init__(self, cfg, sandbox_server="http://localhost:5000") -> None:
        """
        Build the ``draw`` function template and initialise the model.

        Args:
            cfg: The ELM configuration passed on to the base class.
            sandbox_server: Base URL of the sandbox server.
        """
        func_name = "draw"
        function_template = FunctionTemplate(
            func_name=func_name,
            # The trailing newline is required: construct_prompt concatenates
            # import_line directly with the preamble, so without it the
            # program would start with "import numpy as npdef draw():".
            import_line="import math\nimport numpy as np\n",
            func_preamble=self._build_preamble(func_name, (32, 32, 3)),
            instruction="",
        )
        super().__init__(cfg, function_template, sandbox_server)

    @staticmethod
    def _build_preamble(func_name: str, shape: tuple) -> str:
        """Return the function header that allocates a ``shape`` canvas."""
        return (
            f'def {func_name}():\n\t"""Draw a yellow circle.\n'
            f'\t"""\n\tpic = np.zeros({shape})\n'
        )

    def reset_shape(self, shape: tuple):
        """
        Change the canvas shape used in the function preamble.

        Args:
            shape: The shape passed to ``np.zeros`` in generated programs.
        """
        preamble = self._build_preamble(self.func_template.func_name, shape)
        # Prompts are built from the template, so the template itself has to
        # be updated for the new shape to take effect.
        self.func_template.func_preamble = preamble
        # Still set for code that reads the attribute directly.
        self.func_preamble = preamble

    def _get_response(self, code: str, timeout: float) -> requests.models.Response:
        """POST ``code`` to the sandbox's ``/eval_imageoptim_func`` endpoint.

        ``timeout`` is sent in the JSON payload and also used as the HTTP
        request timeout.
        """
        func_name = self.func_template.func_name
        return requests.post(
            f"{self.sandbox_server}/eval_imageoptim_func",
            json={"code": code, "func_name": func_name, "timeout": timeout},
            timeout=timeout,
        )

    def _post_process(self, results: list) -> list:
        """
        Convert each result's ``"result_obj"`` entry to a numpy array.

        The result dictionaries are modified in place.

        Args:
            results: Result dictionaries that each hold a ``"result_obj"``.

        Returns:
            The same list, with ``"result_obj"`` values as ``np.ndarray``.
        """
        # NOTE(review): this expects dictionaries, which is what the sandbox
        # branch produces; confirm the local-execution branch yields the
        # same structure.
        for result in results:
            result["result_obj"] = np.array(result["result_obj"])
        return results
class DiffModel(PromptMutationModel):
    """Mutation model whose completions are diffs applied to the prompt."""

    # A newline followed by text that cannot start a diff-hunk line, i.e.
    # that does not begin with " ", "+", "-" or "@". The "-" is escaped so
    # that "+-@" is not read as a character range (which would also cover
    # digits, "<", ":" and other characters).
    _END_OF_DIFF = re.compile(r"\n[^ +\-@]+")

    def __init__(
        self,
        cfg: SodaraceELMConfig,
        function_template: FunctionTemplate,
        sandbox_server: str = "http://localhost:5000",
    ) -> None:
        """
        Initialise the underlying prompt mutation model.

        Args:
            cfg: The ELM configuration.
            function_template: Pieces used to build prompts and the
                executable preamble.
            sandbox_server: Base URL of the sandbox server.
        """
        super().__init__(cfg, function_template, sandbox_server)

    def construct_prompt(self, code: str) -> tuple[str, str]:
        """
        Construct a prompt from a code string.

        Args:
            code (str): The code string.

        Returns:
            A tuple ``(prompt, preamble)``: the prompt is the code followed
            by the instruction and function preamble; the preamble is the
            import lines followed by the instruction and function preamble.
        """
        # NOTE(review): earlier code also assembled a diff-style prompt
        # ("<NME> walker.py\n<BEF> " + code + "\n<MSG> Fixed bugs") and then
        # immediately overwrote it with the plain prompt below. The dead
        # assignment is removed and the effective behaviour kept; confirm
        # which prompt format the diff model is meant to receive.
        template = self.func_template
        suffix = template.instruction + template.func_preamble
        return code + suffix, template.import_line + suffix

    def generate_program(self, code_batch: list[str]) -> list[dict]:
        """
        Generate new programs for a diff model from a batch of programs.

        For each program: build a prompt, sample a completion, truncate it
        and prepend the preamble, parse the text as a diff, apply the diff
        hunk to the prompt, execute the patched program (through the sandbox
        server when ``cfg.sandbox`` is set, otherwise in local worker
        processes) and post-process the execution results.

        ``self.truncations`` holds the generated texts whose diff could be
        parsed, in the same order as the executed programs.

        Args:
            code_batch: The full code strings to mutate.

        Returns:
            Whatever ``_post_process`` returns for the collected results.
        """
        if not code_batch:
            # zip(*map(...)) cannot be unpacked for an empty batch.
            self.truncations: list[str] = []
            return []

        prompts, preamble_strings = zip(*map(self.construct_prompt, code_batch))
        encodings = self.tokenizer(
            list(prompts),
            truncation=True,
            padding=True,
            return_tensors="pt",
        )
        completions: list[str] = sample(
            encodings,
            self.cfg,
            self.model,
            self.tokenizer,
            batch_size=1,
        )
        local_scope_exec: bool = len(self.func_template.func_preamble) > 0
        trunc = functools.partial(truncate, only_local_scope=local_scope_exec)
        candidates: list[str] = [
            preamble_strings[i] + trunc(completion)
            for i, completion in enumerate(completions)
        ]

        required_keys = ("name", "file", "message", "diff")
        kept: list[str] = []
        outputs: list[str] = []
        for i, candidate in enumerate(candidates):
            # split the diff text according to <NME>, <BEF>, <MSG>, <DFF>.
            parsed: dict = split_diff(candidate)
            if not parsed or not all(key in parsed for key in required_keys):
                continue
            # truncate the diff hunk at the first line not starting with " ",
            # "+", "-", or "@".
            diff_hunk: str = self._END_OF_DIFF.split(parsed["diff"])[0]
            # Also cut at a "<NME>" marker that appears inside a line.
            nme_idx: int = diff_hunk.find("<NME>")
            if nme_idx != -1:
                diff_hunk = diff_hunk[:nme_idx]
            kept.append(candidate)
            outputs.append(apply_diff(prompts[i], diff_hunk))
        # Only texts that produced a program are kept, so that entry i of
        # self.truncations belongs to entry i of the execution results.
        # NOTE(review): the stored text is the generated diff text, not the
        # patched program that is executed; confirm which one callers expect.
        self.truncations = kept

        if self.cfg.sandbox:
            results = []
            for program in outputs:
                resp = self._get_response(program, self.cfg.timeout)
                # Responses with a non-200 status are dropped.
                if resp.status_code == 200:
                    results.append(json.loads(resp.text))
        else:
            results = pool_exec_processes(
                outputs,
                func_name=self.func_template.func_name,
                timeout=self.cfg.timeout,
                processes=self.cfg.processes,
                debug=self.cfg.debug,
            )
        return self._post_process(results)
class DiffModelForSodarace(DiffModel):
    """Diff mutation model whose programs define ``make_walker``."""

    def __init__(self, cfg, sandbox_server="http://localhost:5000") -> None:
        """
        Build the ``make_walker`` function template and initialise the model.

        Args:
            cfg: The ELM configuration passed on to the base class.
            sandbox_server: Base URL of the sandbox server.
        """
        function_template = FunctionTemplate(
            func_name="make_walker",
            import_line=IMPORTS + SQUARE_PREREQ,
            instruction="",
            func_preamble="def make_walker():\n",
        )
        super().__init__(cfg, function_template, sandbox_server)

    def _get_response(self, code: str, timeout: float) -> requests.models.Response:
        """POST ``code`` to the sandbox's ``/gen_racer`` endpoint.

        ``timeout`` is sent in the JSON payload and also used as the HTTP
        request timeout.
        """
        return requests.post(
            f"{self.sandbox_server}/gen_racer",
            json={"code": code, "timeout": timeout},
            timeout=timeout,
        )

    def _post_process(self, results: list) -> list:
        """
        Keep only results that are valid walkers.

        In sandbox mode the results are returned unchanged. Otherwise each
        result that is a ``Walker`` passing ``validate()`` becomes a dict
        with a program string and the walker's ``to_dict()`` output;
        everything else is dropped (and printed when ``cfg.debug`` is set).

        Args:
            results: Raw execution results.

        Returns:
            The filtered list of result dictionaries.
        """
        if self.cfg.sandbox:
            return results

        valid_walkers: list = []
        for idx, outcome in enumerate(results):
            # NOTE(review): assumes self.truncations[idx] belongs to
            # results[idx] - verify against generate_program.
            # Best effort: a failure while validating or serialising one
            # walker must not discard the rest of the batch.
            try:
                if isinstance(outcome, Walker) and outcome.validate():
                    valid_walkers.append(
                        {
                            "program_str": self.truncations[idx],
                            "result_obj": outcome.to_dict(),
                        }
                    )
                elif self.cfg.debug:
                    print("Failed execution, type:", outcome)
                    print(self.truncations[idx])
            except Exception as e:
                if self.cfg.debug:
                    print(type(e), e)
        return valid_walkers