From 0d37b51dbca990a9a0b0ee80f356155a9d2c7f96 Mon Sep 17 00:00:00 2001 From: Anooshka Pendyal Date: Mon, 3 Nov 2025 19:28:39 -0500 Subject: [PATCH 1/2] Draft: moving over citation_exists.py requirement and corresponding test; also added eyecite useage for LLM output parsing --- mellea_contribs/reqlib/citation_exists.py | 186 ++++++++++++++++++++++ test/test_citation_exists.py | 52 ++++++ 2 files changed, 238 insertions(+) create mode 100644 mellea_contribs/reqlib/citation_exists.py create mode 100644 test/test_citation_exists.py diff --git a/mellea_contribs/reqlib/citation_exists.py b/mellea_contribs/reqlib/citation_exists.py new file mode 100644 index 0000000..40c5577 --- /dev/null +++ b/mellea_contribs/reqlib/citation_exists.py @@ -0,0 +1,186 @@ +from mellea.stdlib.requirement import Requirement, ValidationResult +from mellea.stdlib.base import Context, CBlock + +import json +import os +import re +from eyecite import get_citations, clean_text +from typing import Any + +# region: citation_exists function and helpers + +def normalize_case_name(name) -> str: + """ + Converts a case name to a standard format. + + Args: + name: A string representing the case name. + + Returns: + A normalized case name. + """ + # 1. Lowercase everything + name = name.lower() + + # 2. Normalize 'vs', 'vs.', 'v', 'versus' to 'v.' + name = re.sub(r'\b(vs\.?|versus|v)(?!\.)\b', 'v.', name) + + # 3. Remove all non-alphanumeric characters except periods, spaces, and apostrophes + name = re.sub(r"[^a-z0-9.& ']+", '', name) + + # 4. Replace multiple spaces with a single space + name = re.sub(r'\s+', ' ', name) + + return name.strip() + +# might not be needed +# def ensure_list_of_dicts(obj: Any) -> list[dict]: +# """ +# Normalize any JSON-like object into a list of dictionaries. + +# Accepts: +# - A JSON string (object or array) +# - A single dict +# - A list of dicts + +# Args: +# obj: Any data type, ideally something that can unpacked into a dictionary + +# Returns: +# The unpacked object in list of dictionary form or raises an error. +# """ +# # JSON string +# if isinstance(obj, str): +# try: +# obj = json.loads(obj) +# except json.JSONDecodeError as e: +# raise ValueError(f"Invalid JSON string: {e!s}") + +# # Single dict +# if isinstance(obj, dict): +# return [obj] + +# # List of dicts +# if isinstance(obj, list): +# if all(isinstance(item, dict) for item in obj): +# return obj +# else: +# raise ValueError("List contains non-dictionary elements") + +# raise TypeError(f"Unsupported metadata format: {type(obj)}") + +# alternatively: +# should this take in last_output instead of the whole context? +# get case name: take LLM output and extract case name --> a string which you get from ctx.last_output() is the input +# so the argument should be ctx.last_output.value: str + +def extract_case_names(ctx: Context) -> list[str]: + """ + Given an LLM output, use eyecite to parse the text and collect case names. + + Args: + ctx: An LLM output that may contain multiple citations. + + Returns: + A list of case names. + """ + # should i clean text?? + + # install hyperscan if not already installed + # !pip install hyperscan + # tokenizer = HyperscanTokenizer(cache_dir=".test_cache") + # citations = get_citations(cleaned_text, tokenizer=tokenizer) + + # or this? + # cleaned_text = clean_text(text, ["html", "all_whitespace"]) + # citations = get_citations(cleaned_text) + + # get_citations outputs a list of citations + citations = get_citations(ctx.last_output().value) + case_names = set() + + for citation in citations: + plaintiff = citation.metadata.get("plaintiff") + defendant = citation.metadata.get("defendant") + if plaintiff and defendant: + case_names.add(f"{plaintiff} v. {defendant}") + # name = citation.metadata['plaintiff'] + " v. " + citation.metadata['defendant'] + # case_names.add(name) + + return list(case_names) + +def citation_exists(ctx: Context, case_metadata: list[dict]) -> ValidationResult: + """ + Given an LLM output and a list of dictionaries, checks that list (which represents a collection of + case metadata json files) to see if the given case names can be found in it. + + Args: + ctx: Context that contains the case names we're checking for + case_metadata: a list of dictionaries which represents a collection of case metadata json files + + Returns: + A validation result indicating if a match was found between given case names and database + """ + if ctx is None: + return ValidationResult(False, reason="No context provided in output") + + # 1) this will spit out a bunch of words --> look through to extract case names + # 2) use eyecite (might have to do some conversion) + last_output = ctx.last_output() + + # if last_output is None or not getattr(output, "value", None): + if last_output is None: + return ValidationResult(False, reason="No last output found in context") + + # 3) run checking + # call get_case_name func + case_names = extract_case_names(ctx) + + if not case_names or not isinstance(case_names, list[str]): + return ValidationResult(False, reason="No case names provided in output") + + normalized_case_names = [normalize_case_name(case_name) for case_name in case_names] + + case_names = set() + case_name_abb = set() + + # add name and name_abbreviation from the database + for case in case_metadata: + if 'name' in case: + case_names.add(normalize_case_name(case['name'])) + if 'name_abbreviation' in case: + case_name_abb.add(normalize_case_name(case['name_abbreviation'])) + + # Check both name and name_abbreviation + for normalized_case_name in normalized_case_names: + if normalized_case_name not in case_names and normalized_case_name not in case_name_abb: + # probably want to change this to the actual case name at some point + # maybe keep a tuple structure or something + return ValidationResult(False, reason=f"'{normalized_case_name}' not found in database") + + return ValidationResult(True, reason="All case names found in database") + + # check if this code chunk is right later + # db_names = {normalize_case_name(c["name"]) for c in case_metadata if "name" in c} + # db_abbrevs = { + # normalize_case_name(c["name_abbreviation"]) for c in case_metadata if "name_abbreviation" in c + # } + + # for name in normalized_output_names: + # if name not in db_names and name not in db_abbrevs: + # return ValidationResult(False, reason=f"Case '{name}' not found in database") + + # return ValidationResult(True, reason="All case names found in database") + + +class CaseNameExistsInDatabase(Requirement): + """ + Checks if the output case name exists in the provided case metadata database. + """ + def __init__(self, case_metadata: str): + self._case_metadata = case_metadata + super().__init__( + description="The case name should exist in the provided case metadata database.", + validation_fn=lambda ctx: citation_exists(ctx, self._case_metadata), + ) +# endregion \ No newline at end of file diff --git a/test/test_citation_exists.py b/test/test_citation_exists.py new file mode 100644 index 0000000..3298b81 --- /dev/null +++ b/test/test_citation_exists.py @@ -0,0 +1,52 @@ +import pytest +from mellea.mellea.stdlib.reqlib.citation_exists import normalize_case_name, citation_exists + +# Mock context for testing citation_exists + +# make up my own model outputs + +# can just check if case names are in one json file +class MockContext: + def __init__(self, case_name): + self._case_name = case_name + + def last_output(self): + return type("MockOutput", (), {"value": self._case_name})() + + +# region: normalize_case_name tests +@pytest.mark.parametrize("raw_name,expected", [ + ("BOB VS SHMEEGUS", "bob v. shmeegus"), + ("William Payne, Executor of John Payne v. William Dudley Executor of Fleet", "william payne executor of john payne v. william dudley executor of fleet"), + ("Ozwald v. Dickinson's Ex'rs", "ozwald v. dickinson's ex'rs"), + ("Fox & al. v. Cosby", "fox & al. v. cosby"), + ("Groves v. Graves", "groves v. graves"), + ("Ozwald, Deniston, & Co. v. Dickinson's Ex'rs", "ozwald deniston & co. v. dickinson's ex'rs"), + ("Bobby- versus shmeegy", "bobby v. shmeegy") +]) + +def test_normalize_case_name(raw_name, expected): + assert normalize_case_name(raw_name) == expected +# endregion + +# region: citation_exists tests +@pytest.mark.parametrize("case_name,expected", [ + ("Bob v. Shmeegus", False), + ("Gimli versus Legolas", False), + ("Groves v. Graves", True), + ("William Payne, Executor of John Payne v. William Dudley Executor of Fleet", True), + ("Payne v. Dudley", True), + ("Fox & al. v. Cosby", True), + ("Fox v. Cosby", True), +]) + +def test_citation_exists(tmp_path, case_name, expected): + # create mock context + ctx = MockContext(case_name) + # path to metadata folder + # db_folder = "/Users/anooshkapendyal/Desktop/mellea/mellea/test/stdlib_basics/legal/cases_metadata" + + result = citation_exists(ctx, db_folder) + assert result.as_bool() == expected, result.reason + +# endregion \ No newline at end of file From a930aa5811097ecf6d24b3bb219e0b26207410ff Mon Sep 17 00:00:00 2001 From: Anooshka Pendyal Date: Mon, 1 Dec 2025 17:48:02 -0500 Subject: [PATCH 2/2] fix: using citeurl correctly in citation_exists --- .../__pycache__/__init__.cpython-313.pyc | Bin 0 -> 170 bytes .../citation_exists.cpython-313.pyc | Bin 0 -> 6894 bytes mellea_contribs/reqlib/citation_exists.py | 278 +++++++++--------- 3 files changed, 144 insertions(+), 134 deletions(-) create mode 100644 mellea_contribs/reqlib/__pycache__/__init__.cpython-313.pyc create mode 100644 mellea_contribs/reqlib/__pycache__/citation_exists.cpython-313.pyc diff --git a/mellea_contribs/reqlib/__pycache__/__init__.cpython-313.pyc b/mellea_contribs/reqlib/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..70ba4eadaaf5556fadcfc0513a29ceb7d1493c51 GIT binary patch literal 170 zcmey&%ge<81QAZ0nIQTxh=2h`DC08=kTI1Zok5e)ZzV$!6Oi{ABz4PEKeRZts8~NS zFF(IHBRjDmH7}(yF-Jc)H76%EQ8zh1ucRn5sTj(Shx3b43v)7)^yA|*^D;}~ literal 0 HcmV?d00001 diff --git a/mellea_contribs/reqlib/__pycache__/citation_exists.cpython-313.pyc b/mellea_contribs/reqlib/__pycache__/citation_exists.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..27973de95c9bf2a8ed19010590a24ebc822f6239 GIT binary patch literal 6894 zcmai2U2xl0cD?{X5(FucmS{fal$o8Wx3dq*)4J0ace>M`E@NJ@=k-zw@2D+0-Nvc(&qcBYA+3ukgcrLiNn%L5`4*h(Z)@f{b&N z8|Nu6<2OGM7#FB89;89`9heA>hiQ0Rq$2wkCL-e!l~_JF5gm`w7|Vwyn#P-H^LU)b z*>`v%F}{!PV|j6+WxSQPvV3GBNs};7WMV(v&)#jcjlJ7xdpA*}nGPj7b6}<;2%-)N zO01LI7HFr^#L7FMq?wg;DREYE5K0oPc4A)c0+gmWr)gdy<)y2HFyhM%pTLo-hr1{&HVNN-MJ2%w}MQUb2t^N~+94=|| zI#c!yzwTtciOFTS>H4DmUXaOs)OuO|K=jFU%7i4R0=R7dV1^t$G$FQh7(h2$PIazBo%2~QCn^r-F zMmki0aI#h?P-C`O(2vVYy6hSeRK0BG$Hufg*ew%KmTe1`qZQbi7Bvb7L%!pl3yaEG z&6MXfBQI;&tVMIM%=~g%asn`0sfQUkm^FxiqU;o5L@pgxtOebd*p_ofj!B$_Bc>T9 z+s_fn;iR!oIVVJQ&9=;xND=TH33mn0lx{m=7$28LTd*diEiUS2&Jn1-NF4#s4Hcjw z6{e_FC(QI^Q76JOAfN3BpcmT->6BX3c0{kjiLerN!JfDAe4)nRajp(#Y@OPJnrT_~ zt@kyYZ&}L^F6jBZuAM;nVV~{#x9WdW{e!$QJLnCs>UUugc6xEyiJ>8=1xrOiXdli6 zKl@`yR>;;t5^wv^So!mfMEiQ8uaf9{G*nF-TY39=v}xtyMpJxc?5pIVd*ZWb%fnNv zne}LYCEC9x{wnsX*hVbzVB-G7s`&S@Ph!7~9j=8)Q|p7t`;+B^r+rU@Ndiiy!@YS)gNB1B(L6!Jd5`{3a%NI_^GmZ>UlK2 zGO-Dw5l39q=5-W$2Noe7Aku6;gv0U?*^Na!VQ5rG6CBBVrhUBeN(k{0jL-1XsE3yx zzdrCMkj6(iz=OcFj~$I+0NFrNd;t4=#K^S~a>GXpg;xR@f+(X1!@SKaLHzR&Q6_*8 z!nBVY?r4}QA%E;(Cg`EjuCc@4KlY7%-iX)zIroyGmmK%8HwSB)<7LtV#w3&Z;MqVA znd4Go=?~I7hHW5xF;Gfp!A+&}+LDagtfSRc-vP;*nUldhUBp9Z>$_+znR!dg!TGm! zc|k8|;3l;`fGYF-9r)&%2VmSDQ|*E}1gQezjV_NXSEl@tb6;`g=dS1^KEXr`qWfMM>(`Fbg6`}(;D8IPmglDO>1n}et?pU}0dIV~1 zJeCZ=WZH|Z57d>a;;yQm7V$LGA3*&m_W4Ih0D+Q3I`A~oxjMVilKlA9pS`-?GEiw5 zsJ0xr7v6}st;c&R@t$hD_nxpJ#vVlOM;;z~Ds}^xgEA-l4yT;w;Dh4O)r$JM>^sEJG%p<{_wnf;myE*oE(7kvwjKFH>J^F*okdl z=?>fqKsWqtnvei)LFzu!dmvTccIJfK>U)q2aEJRo4oeqaNE0W9}X&> z(e2b>`zl6cw399a8Cayc4bB~z%C7U5 zmQsQvg8M4!;INzk@L4B8H3Q6s*(dPcB6v}pB*oL$VAIh3+KV>sn_M@yq$dno82UPF zPKTkzd+@VgfrMFFEcxKf-=0~GRilT&y9wgDaG)X_cq(*lMB@*J?+<@?x-7KOerWE9 z2s^W|Gdt9XiePGd12P|hnp(*_oX!7$*hp)^fmr(#4jUQ#BoI@ii8gO?1CYA8IS2(x zzm`TVaOXbDL!^8A;x>Y}NUb|Y4)z*VgXrn4Z-<&Z2@wH(CWmABn{R-|5lmquMjl;V zeV%Quqj^utJAQ@TH>YsZOVwv zL}`AV?K9f-0GWh2C(pU|zD*ZlCGJYwOKu>6W>1)C2VwNcwsyP{%>-f0SSA7_NWeZe z-vf6FX5@Utox#n|adV_jH4@qd5KgKG7#ONtGVLND13(Of6#5+ z1T|tJ8G;MM6U)p&S4-9auQZ6f@Ep|}?>LSAQItKy-f>j zAzm@()7MI^IE!~t;dW1#ns#W2f*FyLJO+Zn%0U?rM^Xb++1PYsYb72cyx=HmsMN;7 zI+sL&Fy8i+aWOM?^g_t`NJ#dHk5d?-> zhtq8gMO{XPsmrS3-u~6YJ30xC{t|vRo`n@s>mXwEr$1p4RXIIXiC-y;SDv@*|9J3c zgVmP)m5Jx!mP)v@UVo($zgiZrejSe9o2Z66em+_a_tdKGM0#nG=@ zUs*d*<{8haFUH}@$x>mE*Xm5=VHR;^5o>Na+t378Q#-4QzK8{tp&OQ`3Iw1!A zB>ga4lYoF#TLOqY^6Q7p2+=g zEr89wW1i`=WY=ht|2!#;js!pNj*p%Uetv?({K+7sS=ep=3Dk3&=nCHjJ}^Tlrx403 z91SQu6_mhCa3+LkI>cwl5T^)mTPVUkWLODOF+2?EX#et1UZQvpLBXb}2Q@C{)d7&1y=N*mp3w%%NDwn<8~;4VIE>g zk`ks$7p{4c7#>^MzGW5j7$M?46|7^2lx&%vKn+oM@HONSt?=9YGf2*H+b;|<v8tU(E91A%IT@6N3Jkdir2_F-Of^jA&=8kzb=ED zCUwqC1?V-DDrgtMJRH%R2g);0PX7WcBdVGO$g@?|M)Q47b_J(MOq|~TSLT0(W_CX$ ze_SCO-LKvgAQ*clHr)?b#Ey-QnP;6>YC#e?@INFFkzA2ECR9p7c~GgU>PV`(VC9N= z%*Rv}LyLM%Tvg``Y8N1~)J+Sj!>XFIvXBvw;W`3RTUXZoumZdk&#HdF3X5`@>g&sb z4_4I_M`uuCfJgEc(-6G^S(f0*vf&PyeET=#1s@2CH7OXB9_7}CesTL70@>eV!WO~W zr}~@&?loN6F67|0Esar2{YIbSX1X4wWoILdAt)XZiWJgm)|YkW^k4^cxRAq3Cn{o1 zxL!k%M6$a?ywi2>Btzhsz>QM8V4^XsZ*5%Tr`ff>wrIHFe++t|ba0~zGh+---^3og zWjj38=@6DYMJA0`--(;OxkG=T>M$BTTOS--nwXYms??-{zQ!gG{D8SbzBPmlnT za@?0>^h=Wdk_>-AdcPoDUlREXGWZ3Nzar89A`$iTqk#QEl?xydK}ooS8XK5wLgNU@LDTu;j7jyaVK1T)d*xCD_i)gbq;fr+-iG` h;Ct=j}vo3 literal 0 HcmV?d00001 diff --git a/mellea_contribs/reqlib/citation_exists.py b/mellea_contribs/reqlib/citation_exists.py index 40c5577..04eeaed 100644 --- a/mellea_contribs/reqlib/citation_exists.py +++ b/mellea_contribs/reqlib/citation_exists.py @@ -1,186 +1,196 @@ from mellea.stdlib.requirement import Requirement, ValidationResult from mellea.stdlib.base import Context, CBlock +from eyecite.models import FullCaseCitation, CitationBase +from eyecite import get_citations +from citeurl import Citator +from typing import Any, Optional +from playwright.sync_api import sync_playwright +from urllib.parse import urljoin import json import os import re -from eyecite import get_citations, clean_text -from typing import Any +import requests # region: citation_exists function and helpers -def normalize_case_name(name) -> str: +""" +Validator: Ensure that every case-law citation in an LLM output corresponds to a real case in the +provided case metadata database. + +Process: +1. Extract citations from LLM output using citeurl. +2. Convert citation objects to URLs. +3. For each cite.case.law URL: + - Use Playwright to extract metadata URL. + - Fetch JSON metadata. + - Compare its case ID against the known database. +4. If any citation fails, return ValidationResult(False). +5. If all succeed, return ValidationResult(True). +""" + +def text_to_urls(text: str) -> list[str]: """ - Converts a case name to a standard format. + Extracts all citation URLs from the given text using citeurl. Args: - name: A string representing the case name. + text: An LLM output Returns: - A normalized case name. + A list of citation URLs. + + Behavior: + - If a citation does not have a URL attribute, we return a ValidationResult(False) + so that the parent validator can fail accordingly. """ - # 1. Lowercase everything - name = name.lower() - - # 2. Normalize 'vs', 'vs.', 'v', 'versus' to 'v.' - name = re.sub(r'\b(vs\.?|versus|v)(?!\.)\b', 'v.', name) - - # 3. Remove all non-alphanumeric characters except periods, spaces, and apostrophes - name = re.sub(r"[^a-z0-9.& ']+", '', name) - - # 4. Replace multiple spaces with a single space - name = re.sub(r'\s+', ' ', name) - - return name.strip() - -# might not be needed -# def ensure_list_of_dicts(obj: Any) -> list[dict]: -# """ -# Normalize any JSON-like object into a list of dictionaries. - -# Accepts: -# - A JSON string (object or array) -# - A single dict -# - A list of dicts - -# Args: -# obj: Any data type, ideally something that can unpacked into a dictionary - -# Returns: -# The unpacked object in list of dictionary form or raises an error. -# """ -# # JSON string -# if isinstance(obj, str): -# try: -# obj = json.loads(obj) -# except json.JSONDecodeError as e: -# raise ValueError(f"Invalid JSON string: {e!s}") - -# # Single dict -# if isinstance(obj, dict): -# return [obj] - -# # List of dicts -# if isinstance(obj, list): -# if all(isinstance(item, dict) for item in obj): -# return obj -# else: -# raise ValueError("List contains non-dictionary elements") - -# raise TypeError(f"Unsupported metadata format: {type(obj)}") - -# alternatively: -# should this take in last_output instead of the whole context? -# get case name: take LLM output and extract case name --> a string which you get from ctx.last_output() is the input -# so the argument should be ctx.last_output.value: str - -def extract_case_names(ctx: Context) -> list[str]: + citator = Citator() + citations = citator.list_cites(text) + + urls = [] + errors = [] + + for citation in citations: + if hasattr(citation, "URL") and citation.URL: + urls.append(citation.URL) + else: + # Record a descriptive error about the invalid citation object + errors.append(f"Citation has no URL attribute: {repr(citation)}") + + if errors: + # Raise one combined error + error_msg = "Some citations did not contain URLs:\n" + "\n".join(errors) + return ValidationResult(False, reason=error_msg) + + return urls + + +def extract_case_metadata_url(page_url: str) -> str: """ - Given an LLM output, use eyecite to parse the text and collect case names. + Visits a cite.case.law page using Playwright and extracts the "Download case metadata" link. Args: - ctx: An LLM output that may contain multiple citations. + page_url: A cite.case.law page Returns: - A list of case names. + A URL to the JSON metadata for the case or a false ValidationResult if the link cannot be found """ - # should i clean text?? + with sync_playwright() as pw: + browser = pw.chromium.launch() + page = browser.new_page() + page.goto(page_url) + + # Wait for the metadata link to appear + link = page.wait_for_selector("a:has-text('Download case metadata')") + if not link: + return ValidationResult(False, reason=f"No metadata link found on page: {page_url}") + + # Extract relative href + href = link.get_attribute("href") + if not href: + return ValidationResult(False, reason=f"Metadata link missing href attribute on page: {page_url}") + + # Build the absolute metadata URL + return urljoin(page_url, href) + - # install hyperscan if not already installed - # !pip install hyperscan - # tokenizer = HyperscanTokenizer(cache_dir=".test_cache") - # citations = get_citations(cleaned_text, tokenizer=tokenizer) +def metadata_url_to_json(metadata_url: str) -> dict: + """ + Fetches JSON metadata for a case. - # or this? - # cleaned_text = clean_text(text, ["html", "all_whitespace"]) - # citations = get_citations(cleaned_text) + Args: + metadata_url: Fully-qualified URL to metadata.json - # get_citations outputs a list of citations - citations = get_citations(ctx.last_output().value) - case_names = set() + Returns: + A dictionary representing the JSON metadata. + """ + resp = requests.get(metadata_url) + resp.raise_for_status() + return resp.json() - for citation in citations: - plaintiff = citation.metadata.get("plaintiff") - defendant = citation.metadata.get("defendant") - if plaintiff and defendant: - case_names.add(f"{plaintiff} v. {defendant}") - # name = citation.metadata['plaintiff'] + " v. " + citation.metadata['defendant'] - # case_names.add(name) - - return list(case_names) -def citation_exists(ctx: Context, case_metadata: list[dict]) -> ValidationResult: +def collect_ids_in_database(database: list[dict]) -> set: """ - Given an LLM output and a list of dictionaries, checks that list (which represents a collection of - case metadata json files) to see if the given case names can be found in it. + Collects all case IDs from the provided caselaw metadata. Args: - ctx: Context that contains the case names we're checking for - case_metadata: a list of dictionaries which represents a collection of case metadata json files + database: A list of case dictionaries loaded from a caselaw JSON dataset. Returns: - A validation result indicating if a match was found between given case names and database + A set of all unique case IDs. + """ + return {case["id"] for case in database} + + +def citation_exists(ctx: Context, database: list[dict]) -> ValidationResult: + """ + Validator: + Ensures that every cite.case.law URL in the LLM output corresponds to a real case in the provided case metadata database. + + Args: + ctx: Mellea runtime context containing the last LLM output. + database: Parsed caselaw metadata database of JSON objects. + + Returns: + ValidationResult indicating pass/fail. """ if ctx is None: - return ValidationResult(False, reason="No context provided in output") + return ValidationResult(False, reason="No context provided in output.") - # 1) this will spit out a bunch of words --> look through to extract case names - # 2) use eyecite (might have to do some conversion) last_output = ctx.last_output() - # if last_output is None or not getattr(output, "value", None): if last_output is None: - return ValidationResult(False, reason="No last output found in context") + return ValidationResult(False, reason="No last output found in contex.") - # 3) run checking - # call get_case_name func - case_names = extract_case_names(ctx) - - if not case_names or not isinstance(case_names, list[str]): - return ValidationResult(False, reason="No case names provided in output") + if type(last_output) != str: + return ValidationResult(False, reason="Last output must be a string.") - normalized_case_names = [normalize_case_name(case_name) for case_name in case_names] + # List of urls of citations found in the LLM output + output_citation_urls = text_to_urls(last_output) + + # text_to_urls may return a ValidationResult (error condition) + if isinstance(output_citation_urls, ValidationResult): + return output_citation_urls - case_names = set() - case_name_abb = set() - - # add name and name_abbreviation from the database - for case in case_metadata: - if 'name' in case: - case_names.add(normalize_case_name(case['name'])) - if 'name_abbreviation' in case: - case_name_abb.add(normalize_case_name(case['name_abbreviation'])) - - # Check both name and name_abbreviation - for normalized_case_name in normalized_case_names: - if normalized_case_name not in case_names and normalized_case_name not in case_name_abb: - # probably want to change this to the actual case name at some point - # maybe keep a tuple structure or something - return ValidationResult(False, reason=f"'{normalized_case_name}' not found in database") + if output_citation_urls is None or output_citation_urls == []: + # No citations, so trivially valid + return ValidationResult(True, reason="No citations found.") + + database_ids = collect_ids_in_database(database) + + for url in output_citation_urls: + + # If this URL is Caselaw, do direct comparison within database by using case id + if "cite.case.law" in url: + try: + metadata_url = extract_case_metadata_url(url) + metadata = metadata_url_to_json(metadata_url) + case_id = metadata["id"] + + except Exception as e: + return ValidationResult(False, reason=f"Failed to retrieve metadata for {url}: {e}") + + if case_id not in database_ids: + return ValidationResult(False, reason=f"Case {case_id} not found in database") - return ValidationResult(True, reason="All case names found in database") - - # check if this code chunk is right later - # db_names = {normalize_case_name(c["name"]) for c in case_metadata if "name" in c} - # db_abbrevs = { - # normalize_case_name(c["name_abbreviation"]) for c in case_metadata if "name_abbreviation" in c - # } - - # for name in normalized_output_names: - # if name not in db_names and name not in db_abbrevs: - # return ValidationResult(False, reason=f"Case '{name}' not found in database") - - # return ValidationResult(True, reason="All case names found in database") + else: + # Non-caselaw citations (e.g., statutes): ignore + # Extending functionality to be done later: use LLM as judge to see if citations match + continue + return ValidationResult(True, reason="All case names found in database") + class CaseNameExistsInDatabase(Requirement): """ - Checks if the output case name exists in the provided case metadata database. + Requirement wrapper for Mellea that ensures case citations in LLM output + refer to real cases in the provided metadata database. """ - def __init__(self, case_metadata: str): + # is this taking in the right parameters? + def __init__(self, case_metadata: list[dict]): self._case_metadata = case_metadata super().__init__( description="The case name should exist in the provided case metadata database.", validation_fn=lambda ctx: citation_exists(ctx, self._case_metadata), ) + # endregion \ No newline at end of file