|
|
""" |
|
|
Functions used in several different places. This file should not import from any other non-lib files to prevent |
|
|
circular dependencies. |
|
|
""" |
|
|
|
|
|
import json |
|
|
import logging |
|
|
from copy import copy |
|
|
from typing import Any, Callable, Dict, Optional, Tuple, Union |
|
|
|
|
|
# NOTE(review): presumably the keys allowed at the top level of a generated
# document - confirm against the modules that import this constant.
TOP_LEVEL_IDENTIFIERS = {
    "description",
    "links",
    "properties",
}
|
|
|
|
|
|
|
|
def get_json_from_model_output(input_generated_json: str):
    """
    Parses a string, potentially containing Markdown code fences, into a JSON object.

    The payload is the content of the first Markdown code fence when one is
    present (either ```json ... ``` or a plain ``` ... ```); otherwise the whole
    string is used. Two attempts are made:

    1. Parse the payload as-is with `json.loads`.
    2. If attempt 1 raised, produced an empty or non-container value, or
       `check_contents_valid` flagged one of its values, repair the payload
       (`balance_braces` for un-fenced text, then `_get_valid_json_from_string`)
       and parse it again.

    Progress is reported with `logging.debug`. If every attempt fails, an empty
    dictionary is returned.

    Args:
        input_generated_json: A string potentially containing a JSON object.

    Returns:
        A tuple containing:
        - The parsed JSON value (a dict or a list), or an empty dictionary if
          nothing could be parsed.
        - An integer that is 1 when attempt 1 was flagged by
          `check_contents_valid` and the repaired result is a dict that is
          flagged as well (that result is still returned), otherwise 0.
    """
    originally_invalid_json_count = 0

    # Non-text input can be neither split nor parsed: empty result, nothing counted.
    if not isinstance(input_generated_json, str):
        logging.debug("could not parse AI model generated output as JSON: input is not a string.")
        return {}, originally_invalid_json_count

    fenced_payload = _fenced_payload(input_generated_json)
    is_fenced = fenced_payload is not None
    candidate = fenced_payload if is_fenced else input_generated_json

    # Attempt 1: the payload is already valid JSON.
    try:
        generated_json_attempt_1 = json.loads(candidate)
    except Exception as exc:
        logging.debug(f"could not parse AI model generated output as JSON. Exc: {exc}.")
        generated_json_attempt_1 = {}
    if not isinstance(generated_json_attempt_1, (dict, list)):
        # Scalars and null are valid JSON but are not a usable result and
        # cannot be inspected by check_contents_valid.
        logging.debug("AI model generated output is not a JSON object or array; discarding it.")
        generated_json_attempt_1 = {}

    some_value_in_attempt_1_is_not_a_dict = check_contents_valid(
        generated_json_attempt_1
    )
    attempt_1_failed = bool(
        not generated_json_attempt_1 or some_value_in_attempt_1_is_not_a_dict
    )

    # Attempt 2: only when attempt 1 did not produce an acceptable result.
    generated_json_attempt_2 = {}
    if attempt_1_failed:
        logging.debug(
            "Attempting to make output valid to obtain better metrics (this works in limited cases where "
            "the model output was simply cut off)"
        )
        try:
            generated_json_attempt_2 = _parse_with_repairs(candidate, is_fenced)
            logging.debug(
                "Success! Reconstructed valid JSON from unparseable model output. Continuing metrics comparison..."
            )
        except Exception:
            logging.debug(
                "Failed. Setting model output as empty JSON to enable metrics comparison."
            )
            generated_json_attempt_2 = {}

    some_value_in_attempt_2_is_not_a_dict = (
        attempt_1_failed
        and isinstance(generated_json_attempt_2, dict)
        and check_contents_valid(generated_json_attempt_2)
    )
    if some_value_in_attempt_1_is_not_a_dict and some_value_in_attempt_2_is_not_a_dict:
        logging.debug("Could not recover model output json, aborting!")
        originally_invalid_json_count += 1

    generated_json = (
        generated_json_attempt_2 if attempt_1_failed else generated_json_attempt_1
    )
    return generated_json, originally_invalid_json_count


def _fenced_payload(model_output: str):
    """Return the text inside the first Markdown code fence, or None when there is no fence."""
    segments = model_output.split("```")
    if len(segments) < 2:
        return None
    payload = segments[1]
    # Drop the optional "json" language tag that follows the opening fence.
    if payload.startswith("json"):
        payload = payload[len("json"):]
    return payload


def _parse_with_repairs(payload: str, is_fenced: bool):
    """Parse `payload` after repairing truncation; raise if no JSON object or array results."""
    if is_fenced:
        repaired = json.loads(_get_valid_json_from_string(payload))
    else:
        try:
            # Cheap repair first: only add missing '{' / '}' characters.
            repaired = json.loads(balance_braces(payload))
        except ValueError:
            # json.JSONDecodeError is a ValueError; fall back to the fuller repair.
            repaired = json.loads(_get_valid_json_from_string(payload))
    if not isinstance(repaired, (dict, list)):
        raise ValueError("repaired model output is not a JSON object or array")
    return repaired
|
|
|
|
|
|
|
|
def check_contents_valid(generated_json_attempt_1: Union[list, dict]):
    """
    Looks for a direct child of the parsed JSON that is not a dict.

    Despite the name, a truthy return value means the contents are NOT valid.
    Which children are inspected depends on the input:

    - list: every element.
    - dict with a "nodes" key: every element of the "nodes" value (the other
      keys are ignored).
    - any other dict: every value.

    Args:
        generated_json_attempt_1 (Union[list, dict]): data to check

    Returns:
        None when every inspected child is a dict (including when there are no
        children). Otherwise a truthy value: the first offending child, or True
        when that child is itself falsy (0, "", None, [], ...) so that callers
        testing truthiness are not misled. Input that is neither a list nor a
        dict, and a "nodes" value that cannot be iterated, are reported the
        same way instead of raising.
    """
    if isinstance(generated_json_attempt_1, list):
        children = generated_json_attempt_1
    elif isinstance(generated_json_attempt_1, dict):
        if "nodes" in generated_json_attempt_1:
            children = generated_json_attempt_1["nodes"]
        else:
            children = generated_json_attempt_1.values()
    else:
        # Scalars / None have no children; the input itself is the offender.
        return generated_json_attempt_1 or True

    try:
        child_iter = iter(children)
    except TypeError:
        # e.g. {"nodes": None} or {"nodes": 3}
        return children or True

    for child in child_iter:
        if not isinstance(child, dict):
            # `or True` keeps the result truthy for falsy offenders such as 0 or "".
            return child or True
    return None
|
|
|
|
|
|
|
|
def _get_valid_json_from_string(s): |
|
|
""" |
|
|
Given a JSON string with potentially unclosed strings, arrays, or objects, close those things |
|
|
to hopefully be able to parse as valid JSON |
|
|
""" |
|
|
double_quotes = 0 |
|
|
single_quotes = 0 |
|
|
brackets = [] |
|
|
|
|
|
for i, c in enumerate(s): |
|
|
if c == '"': |
|
|
double_quotes = 1 - double_quotes |
|
|
elif c == "'": |
|
|
single_quotes = 1 - single_quotes |
|
|
elif c in "{[": |
|
|
brackets.append((i, c)) |
|
|
elif c in "}]": |
|
|
if double_quotes == 0 and single_quotes == 0: |
|
|
if brackets: |
|
|
last_opened = brackets.pop() |
|
|
if (c == "}" and last_opened[1] != "{") or ( |
|
|
c == "]" and last_opened[1] != "[" |
|
|
): |
|
|
raise ValueError( |
|
|
f"Mismatched brackets/quotes found: opened {last_opened[1]} @ {last_opened[0]} " |
|
|
f"but closed {c} @ {i}" |
|
|
) |
|
|
else: |
|
|
|
|
|
pass |
|
|
|
|
|
|
|
|
if s.strip().endswith(","): |
|
|
logging.debug("Removing ending ,") |
|
|
s = s.strip().rstrip(",") |
|
|
|
|
|
closing_chars = "" |
|
|
|
|
|
|
|
|
if double_quotes > 0: |
|
|
closing_chars += '"' |
|
|
if single_quotes > 0: |
|
|
closing_chars += "'" |
|
|
|
|
|
|
|
|
while brackets: |
|
|
last_opened = brackets.pop() |
|
|
if last_opened[1] == "{": |
|
|
closing_chars += "}" |
|
|
elif last_opened[1] == "[": |
|
|
closing_chars += "]" |
|
|
|
|
|
logging.debug(f"closing_chars: {closing_chars}") |
|
|
|
|
|
output_string = s + closing_chars |
|
|
|
|
|
try: |
|
|
json.loads(output_string) |
|
|
except Exception: |
|
|
logging.debug( |
|
|
"JSON string still fails to be parseable, attempting another modification..." |
|
|
) |
|
|
|
|
|
|
|
|
new_closing_chars = "" |
|
|
found_first_double_quote = False |
|
|
for char in closing_chars: |
|
|
if not found_first_double_quote and char == '"': |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
new_closing_chars += '": ""' |
|
|
else: |
|
|
new_closing_chars += char |
|
|
|
|
|
logging.debug(f"new closing_chars: {new_closing_chars}") |
|
|
output_string = s + new_closing_chars |
|
|
|
|
|
return output_string |
|
|
|
|
|
|
|
|
def on_fail(
    outcome: Union[Any, Dict[str, str]],
    fallback: Union[Any, Callable] = None,
):
    """
    Substitutes a fallback for an outcome that represents a failure.

    An outcome counts as a failure when it is a dict containing the key "error".

    Args:
        outcome: the value to inspect.
        fallback: replacement used for a failed outcome. A callable is invoked
            with the failed outcome and its result is returned; any other value
            is returned as it is.

    Returns:
        `outcome` unchanged when it is not a failure, otherwise the fallback
        (or the result of calling it).
    """
    failed = isinstance(outcome, dict) and "error" in outcome
    if not failed:
        return outcome
    if isinstance(fallback, Callable):
        return fallback(outcome)
    return fallback
|
|
|
|
|
|
|
|
def attempt(
    func: Callable,
    args: Tuple[Any, ...] = (),
    kwargs: Optional[Dict[str, Any]] = None,
) -> Union[Any, Dict[str, str]]:
    """
    Calls a function and converts any raised exception into an error dict.

    Args:
        func (Callable): The function to execute.
        args (Tuple[Any, ...], optional): Positional arguments passed to `func`.
        kwargs (Optional[Dict[str, Any]], optional): Keyword arguments passed to
            `func`; a falsy value means no keyword arguments.

    Returns:
        Whatever `func` returns, or {"error": <msg>} where <msg> is the string
        form of the exception that was raised.
    """
    keyword_args = kwargs if kwargs else {}
    try:
        result = func(*args, **keyword_args)
    except Exception as failure:
        return {"error": str(failure)}
    return result
|
|
|
|
|
|
|
|
def balance_braces(s: str) -> str:
    """
    Primitive repair that evens out the number of '{' and '}' characters.

    Every brace character in the string is counted, including any that sit
    inside string values. Missing closing braces are appended to the end and
    missing opening braces are prepended to the start.

    Args:
        s(str): string to balance braces on.

    Returns:
        The string with as many '{' as '}' characters; unchanged when the
        counts already match.
    """
    surplus_open = s.count("{") - s.count("}")
    if surplus_open > 0:
        return s + "}" * surplus_open
    if surplus_open < 0:
        return "{" * -surplus_open + s
    return s
|
|
|
|
|
|
|
|
def flatten_list(coll):
    """
    Flattens one level of nesting.

    Args:
        coll: an iterable whose elements are themselves iterables (lists,
            tuples, sets, ...).

    Returns:
        A new list holding the items of every inner iterable, in iteration
        order. Items that are themselves iterables are not flattened further.
    """
    # Single pass; repeated `list + list` would copy the accumulated result on
    # every step and turn the function quadratic.
    return [item for group in coll for item in group]
|
|
|
|
|
|
|
|
def keep_errors(collection):
    """
    Given a set of outcomes, keeps any that resulted in an error

    An outcome is kept when the membership test `"error" in outcome` succeeds,
    e.g. a dict with an "error" key. Falsy outcomes (None, {}, 0, ...) and
    outcomes that do not support membership tests (numbers, arbitrary objects)
    are skipped rather than raising.

    Args:
        collection (Collection): collection of outcomes to filter.

    Returns:
        All instances of the collection that contain an error response.
    """
    return [instance for instance in collection if _contains_error(instance)]


def _contains_error(instance):
    """Return True when `instance` supports membership tests and contains "error"."""
    try:
        return "error" in (instance or [])
    except TypeError:
        # e.g. an int or float outcome: it cannot carry an error response.
        return False
|
|
|