From c1549941bd9940291e4498153b2ed64deba70a3f Mon Sep 17 00:00:00 2001 From: Masataro Asai Date: Wed, 5 Nov 2025 15:58:19 -0500 Subject: [PATCH 01/40] gitignore --- .gitignore | 442 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 442 insertions(+) create mode 100644 .gitignore diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..6b5814f --- /dev/null +++ b/.gitignore @@ -0,0 +1,442 @@ +# Python-generated files +__pycache__/ +*.py[oc] +build/ +dist/ +wheels/ +scratchpad/ +*.egg-info +.vscode/ + +# IDE +.idea + +# Virtual environments +.venv + +# emacs backup +*~ +\#* + +vllm_backup/ + +# Created by https://www.toptal.com/developers/gitignore/api/python,direnv,visualstudiocode,pycharm,macos,jetbrains +# Edit at https://www.toptal.com/developers/gitignore?templates=python,direnv,visualstudiocode,pycharm,macos,jetbrains + +### direnv ### +.direnv +.envrc + +### JetBrains ### +# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider +# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 + +# User-specific stuff +.idea/**/workspace.xml +.idea/**/tasks.xml +.idea/**/usage.statistics.xml +.idea/**/dictionaries +.idea/**/shelf + +# AWS User-specific +.idea/**/aws.xml + +# Generated files +.idea/**/contentModel.xml + +# Sensitive or high-churn files +.idea/**/dataSources/ +.idea/**/dataSources.ids +.idea/**/dataSources.local.xml +.idea/**/sqlDataSources.xml +.idea/**/dynamic.xml +.idea/**/uiDesigner.xml +.idea/**/dbnavigator.xml + +# Gradle +.idea/**/gradle.xml +.idea/**/libraries + +# Gradle and Maven with auto-import +# When using Gradle or Maven with auto-import, you should exclude module files, +# since they will be recreated, and may cause churn. Uncomment if using +# auto-import. 
+# .idea/artifacts +# .idea/compiler.xml +# .idea/jarRepositories.xml +# .idea/modules.xml +# .idea/*.iml +# .idea/modules +# *.iml +# *.ipr + +# CMake +cmake-build-*/ + +# Mongo Explorer plugin +.idea/**/mongoSettings.xml + +# File-based project format +*.iws + +# IntelliJ +out/ + +# mpeltonen/sbt-idea plugin +.idea_modules/ + +# JIRA plugin +atlassian-ide-plugin.xml + +# Cursive Clojure plugin +.idea/replstate.xml + +# SonarLint plugin +.idea/sonarlint/ + +# Crashlytics plugin (for Android Studio and IntelliJ) +com_crashlytics_export_strings.xml +crashlytics.properties +crashlytics-build.properties +fabric.properties + +# Editor-based Rest Client +.idea/httpRequests + +# Android studio 3.1+ serialized cache file +.idea/caches/build_file_checksums.ser + +### JetBrains Patch ### +# Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721 + +# *.iml +# modules.xml +# .idea/misc.xml +# *.ipr + +# Sonarlint plugin +# https://plugins.jetbrains.com/plugin/7973-sonarlint +.idea/**/sonarlint/ + +# SonarQube Plugin +# https://plugins.jetbrains.com/plugin/7238-sonarqube-community-plugin +.idea/**/sonarIssues.xml + +# Markdown Navigator plugin +# https://plugins.jetbrains.com/plugin/7896-markdown-navigator-enhanced +.idea/**/markdown-navigator.xml +.idea/**/markdown-navigator-enh.xml +.idea/**/markdown-navigator/ + +# Cache file creation bug +# See https://youtrack.jetbrains.com/issue/JBR-2257 +.idea/$CACHE_FILE$ + +# CodeStream plugin +# https://plugins.jetbrains.com/plugin/12206-codestream +.idea/codestream.xml + +# Azure Toolkit for IntelliJ plugin +# https://plugins.jetbrains.com/plugin/8053-azure-toolkit-for-intellij +.idea/**/azureSettings.xml + +### macOS ### +# General +.DS_Store +.AppleDouble +.LSOverride + +# Icon must end with two \r +Icon + + +# Thumbnails +._* + +# Files that might appear in the root of a volume +.DocumentRevisions-V100 +.fseventsd +.Spotlight-V100 +.TemporaryItems +.Trashes +.VolumeIcon.icns +.com.apple.timemachine.donotpresent + +# Directories potentially created on remote AFP share +.AppleDB +.AppleDesktop +Network Trash Folder +Temporary Items +.apdisk + +### macOS Patch ### +# iCloud generated files +*.icloud + +### PyCharm ### +# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider +# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 + +# User-specific stuff + +# AWS User-specific + +# Generated files + +# Sensitive or high-churn files + +# Gradle + +# Gradle and Maven with auto-import +# When using Gradle or Maven with auto-import, you should exclude module files, +# since they will be recreated, and may cause churn. Uncomment if using +# auto-import. 
+# .idea/artifacts +# .idea/compiler.xml +# .idea/jarRepositories.xml +# .idea/modules.xml +# .idea/*.iml +# .idea/modules +# *.iml +# *.ipr + +# CMake + +# Mongo Explorer plugin + +# File-based project format + +# IntelliJ + +# mpeltonen/sbt-idea plugin + +# JIRA plugin + +# Cursive Clojure plugin + +# SonarLint plugin + +# Crashlytics plugin (for Android Studio and IntelliJ) + +# Editor-based Rest Client + +# Android studio 3.1+ serialized cache file + +### PyCharm Patch ### +# Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721 + +# *.iml +# modules.xml +# .idea/misc.xml +# *.ipr + +# Sonarlint plugin +# https://plugins.jetbrains.com/plugin/7973-sonarlint + +# SonarQube Plugin +# https://plugins.jetbrains.com/plugin/7238-sonarqube-community-plugin + +# Markdown Navigator plugin +# https://plugins.jetbrains.com/plugin/7896-markdown-navigator-enhanced + +# Cache file creation bug +# See https://youtrack.jetbrains.com/issue/JBR-2257 + +# CodeStream plugin +# https://plugins.jetbrains.com/plugin/12206-codestream + +# Azure Toolkit for IntelliJ plugin +# https://plugins.jetbrains.com/plugin/8053-azure-toolkit-for-intellij + +### Python ### +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +### Python Patch ### +# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration +poetry.toml + +# ruff +.ruff_cache/ + +# LSP config files +pyrightconfig.json + +### VisualStudioCode ### +.vscode/* +!.vscode/settings.json +!.vscode/tasks.json +!.vscode/launch.json +!.vscode/extensions.json +!.vscode/*.code-snippets + +# Local History for Visual Studio Code +.history/ + +# Built Visual Studio Code Extensions +*.vsix + +### VisualStudioCode Patch ### +# Ignore all local history of files +.history +.ionide + +# End of https://www.toptal.com/developers/gitignore/api/python,direnv,visualstudiocode,pycharm,macos,jetbrains From 48fa4211af5c092fa10a5a2ff6edec452a79f70b Mon Sep 17 00:00:00 2001 From: Masataro Asai Date: Wed, 5 Nov 2025 15:59:22 -0500 Subject: [PATCH 02/40] VA init --- mellea_contribs/va/__init__.py | 38 ++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 mellea_contribs/va/__init__.py diff --git a/mellea_contribs/va/__init__.py b/mellea_contribs/va/__init__.py new file mode 100644 index 0000000..10db3e1 --- /dev/null +++ b/mellea_contribs/va/__init__.py @@ -0,0 +1,38 @@ + + +from mellea import MelleaSession + +from pydantic import BaseModel + +from typing import Literal + +class YesNo(BaseModel): + answer : Literal["yes","no"] + +class Core: + + def binary(m:MelleaSession, prompt): + + output = m.instruct(f"{prompt} Answer yes or no.", + format=YesNo) + + yesno = YesNo.model_validate_json(output.value) + + return yesno.answer == "yes" + + + def choice(self:MelleaSession, prompt, choices:list[str]): + + class Choice(BaseModel): + answer : Literal[choices] + + output = self.instruct(f"{prompt} Respond with one of the following answers: " + ",".join([f"'{c}'" for c in choices]) + ".", + format=Choice) + + choice = Choice.model_validate_json(output.value) + + return choice.answer + + + + From 74e7179164eae0faad229b8a063167a8d7ac5189 Mon Sep 17 00:00:00 2001 From: Masataro Asai Date: Wed, 5 Nov 2025 16:04:14 -0500 Subject: [PATCH 03/40] str --- mellea_contribs/va/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mellea_contribs/va/__init__.py b/mellea_contribs/va/__init__.py index 10db3e1..91feb33 100644 --- a/mellea_contribs/va/__init__.py +++ b/mellea_contribs/va/__init__.py @@ -11,7 +11,7 @@ class YesNo(BaseModel): class Core: - def binary(m:MelleaSession, prompt): + def binary(m:MelleaSession, prompt:str): output = m.instruct(f"{prompt} Answer yes or no.", format=YesNo) @@ -21,7 +21,7 @@ def binary(m:MelleaSession, prompt): return yesno.answer == "yes" - def 
choice(self:MelleaSession, prompt, choices:list[str]): + def choice(self:MelleaSession, prompt:str, choices:list[str]): class Choice(BaseModel): answer : Literal[choices] From af72392b6c625373550307a14fa491a0ae85b3eb Mon Sep 17 00:00:00 2001 From: Masataro Asai Date: Wed, 5 Nov 2025 16:05:37 -0500 Subject: [PATCH 04/40] async --- mellea_contribs/va/__init__.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/mellea_contribs/va/__init__.py b/mellea_contribs/va/__init__.py index 91feb33..61b123b 100644 --- a/mellea_contribs/va/__init__.py +++ b/mellea_contribs/va/__init__.py @@ -21,6 +21,16 @@ def binary(m:MelleaSession, prompt:str): return yesno.answer == "yes" + async def abinary(m:MelleaSession, prompt:str): + + output = await m.ainstruct(f"{prompt} Answer yes or no.", + format=YesNo) + + yesno = YesNo.model_validate_json(output.value) + + return yesno.answer == "yes" + + def choice(self:MelleaSession, prompt:str, choices:list[str]): class Choice(BaseModel): @@ -34,5 +44,16 @@ class Choice(BaseModel): return choice.answer + async def achoice(self:MelleaSession, prompt:str, choices:list[str]): + + class Choice(BaseModel): + answer : Literal[choices] + + output = await self.ainstruct(f"{prompt} Respond with one of the following answers: " + ",".join([f"'{c}'" for c in choices]) + ".", + format=Choice) + + choice = Choice.model_validate_json(output.value) + + return choice.answer From 2ecb726da74dde347e15e3a37aa3f9e2d776da61 Mon Sep 17 00:00:00 2001 From: Masataro Asai Date: Wed, 5 Nov 2025 16:12:11 -0500 Subject: [PATCH 05/40] kwargs --- mellea_contribs/va/__init__.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/mellea_contribs/va/__init__.py b/mellea_contribs/va/__init__.py index 61b123b..07f5129 100644 --- a/mellea_contribs/va/__init__.py +++ b/mellea_contribs/va/__init__.py @@ -11,46 +11,46 @@ class YesNo(BaseModel): class Core: - def binary(m:MelleaSession, prompt:str): + def binary(m:MelleaSession, prompt:str, **kwargs): output = m.instruct(f"{prompt} Answer yes or no.", - format=YesNo) + format=YesNo, **kwargs) yesno = YesNo.model_validate_json(output.value) return yesno.answer == "yes" - async def abinary(m:MelleaSession, prompt:str): + async def abinary(m:MelleaSession, prompt:str, **kwargs): output = await m.ainstruct(f"{prompt} Answer yes or no.", - format=YesNo) + format=YesNo, **kwargs) yesno = YesNo.model_validate_json(output.value) return yesno.answer == "yes" - def choice(self:MelleaSession, prompt:str, choices:list[str]): + def choice(self:MelleaSession, prompt:str, choices:list[str], **kwargs): class Choice(BaseModel): answer : Literal[choices] output = self.instruct(f"{prompt} Respond with one of the following answers: " + ",".join([f"'{c}'" for c in choices]) + ".", - format=Choice) + format=Choice, **kwargs) choice = Choice.model_validate_json(output.value) return choice.answer - async def achoice(self:MelleaSession, prompt:str, choices:list[str]): + async def achoice(self:MelleaSession, prompt:str, choices:list[str], **kwargs): class Choice(BaseModel): answer : Literal[choices] output = await self.ainstruct(f"{prompt} Respond with one of the following answers: " + ",".join([f"'{c}'" for c in choices]) + ".", - format=Choice) + format=Choice, **kwargs) choice = Choice.model_validate_json(output.value) From fe7dba8f1d678cbfd4f23e8231ff6d315da6ca78 Mon Sep 17 00:00:00 2001 From: Masataro Asai Date: Wed, 5 Nov 2025 16:32:57 -0500 Subject: [PATCH 06/40] WIP --- mellea_contribs/va/__init__.py | 76 
++++++++++++++++++++++++---------- 1 file changed, 54 insertions(+), 22 deletions(-) diff --git a/mellea_contribs/va/__init__.py b/mellea_contribs/va/__init__.py index 07f5129..d92da9f 100644 --- a/mellea_contribs/va/__init__.py +++ b/mellea_contribs/va/__init__.py @@ -1,27 +1,24 @@ - +import functools +import asyncio from mellea import MelleaSession +from mellea.helpers.fancy_logger import FancyLogger +from mellea.helpers.event_loop_helper import _run_async_in_thread from pydantic import BaseModel from typing import Literal -class YesNo(BaseModel): - answer : Literal["yes","no"] - -class Core: - def binary(m:MelleaSession, prompt:str, **kwargs): - output = m.instruct(f"{prompt} Answer yes or no.", - format=YesNo, **kwargs) - yesno = YesNo.model_validate_json(output.value) - return yesno.answer == "yes" +class YesNo(BaseModel): + answer : Literal["yes","no"] +class Core: - async def abinary(m:MelleaSession, prompt:str, **kwargs): + async def abool(m:MelleaSession, prompt:str, **kwargs) -> bool: output = await m.ainstruct(f"{prompt} Answer yes or no.", format=YesNo, **kwargs) @@ -30,30 +27,65 @@ async def abinary(m:MelleaSession, prompt:str, **kwargs): return yesno.answer == "yes" - - def choice(self:MelleaSession, prompt:str, choices:list[str], **kwargs): + async def achoice(self:MelleaSession, prompt:str, choices:list[str], **kwargs) -> str: class Choice(BaseModel): answer : Literal[choices] - output = self.instruct(f"{prompt} Respond with one of the following answers: " + ",".join([f"'{c}'" for c in choices]) + ".", - format=Choice, **kwargs) + output = await self.ainstruct(f"{prompt} Respond with one of the following answers: " + ",".join([f"'{c}'" for c in choices]) + ".", + format=Choice, **kwargs) choice = Choice.model_validate_json(output.value) return choice.answer + def bool(m:MelleaSession, prompt:str, **kwargs) -> bool: + return _run_async_in_thread(abool(m, prompt, **kwargs)) - async def achoice(self:MelleaSession, prompt:str, choices:list[str], **kwargs): + def choice(m:MelleaSession, prompt:str, choices:list[str], **kwargs) -> str: + return _run_async_in_thread(achoice(m, prompt, **kwargs)) - class Choice(BaseModel): - answer : Literal[choices] - output = await self.ainstruct(f"{prompt} Respond with one of the following answers: " + ",".join([f"'{c}'" for c in choices]) + ".", - format=Choice, **kwargs) - choice = Choice.model_validate_json(output.value) - return choice.answer +class Arity(Core): + + async def abinary(m:MelleaSession, criteria:str, x:str, y:str, symmetric:bool=True, vote:int=3, **kwargs) -> bool: + """Evaluates a binary boolean function. + """ + + if vote % 2 == 0: + FancyLogger.get_logger().warning( + "the specified number of votes in a majority vote is even, making ties possible. Increasing the value by one to avoid this." + ) + vote += 1 + + if symmetric: + tasks = [ + m.abool(f"Do X and Y satisfy the following criteria? \nCriteria: {criteria}\nX:{x}\nY:{y}") + for _ in range(vote // 2 + 1) + ] + [ + m.abool(f"Do X and Y satisfy the following criteria? \nCriteria: {criteria}\nX:{y}\nY:{x}") + for _ in range(vote // 2) + ] + + else: + tasks = [ + m.abool(f"Do X and Y satisfy the following criteria? 
\nCriteria: {criteria}\nX:{x}\nY:{y}") + for _ in range(vote) + ] + + answers = asyncio.gather(*tasks) + + return (answers[:majority].count(True) + answers[majority:].count(False)) >= majority + + + def binary(m:MelleaSession, criteria:str, x:str, y:str, symmetric:bool=True, vote:int=3, **kwargs) -> bool: + return _run_async_in_thread(abinary(m, criteria, x, y, symmetric, vote, **kwargs)) + + + +class Sorting(Arity): + From 3c03b32e854c3c57f2d16ba465044a28b8474622 Mon Sep 17 00:00:00 2001 From: Masataro Asai Date: Thu, 6 Nov 2025 14:23:00 -0500 Subject: [PATCH 07/40] added positional voting --- mellea_contribs/va/__init__.py | 55 +++++++++++++++++++++++++--------- 1 file changed, 41 insertions(+), 14 deletions(-) diff --git a/mellea_contribs/va/__init__.py b/mellea_contribs/va/__init__.py index d92da9f..cce822f 100644 --- a/mellea_contribs/va/__init__.py +++ b/mellea_contribs/va/__init__.py @@ -1,5 +1,6 @@ import functools +import itertools import asyncio from mellea import MelleaSession from mellea.helpers.fancy_logger import FancyLogger @@ -52,8 +53,19 @@ def choice(m:MelleaSession, prompt:str, choices:list[str], **kwargs) -> str: class Arity(Core): - async def abinary(m:MelleaSession, criteria:str, x:str, y:str, symmetric:bool=True, vote:int=3, **kwargs) -> bool: - """Evaluates a binary boolean function. + async def abinary(m:MelleaSession, criteria:str, x:str, y:str, vote:int=3, symmetric:bool=True, positional:bool=True, **kwargs) -> bool: + """Evaluates a query that corresponds to a binary boolean function. + + Args: + criteria: A natural language statement on variables X and Y. LLM decides if X and Y satisfy this critria, and this function returns yes if so. + x: the first element + y: the second element + vote: an odd integer specifying the number of queries to make. The final result is a majority vote over the results. Since the LLM answers "yes"/"no", by default it counts "yes". If it is even, we add 1 to make it an odd number. + symmetric: If True, half of the queries swap x and y, and count the number of "no" for majority voting instead. This mitigates LLM's psycophancy bias toward answering "yes". + shuffle: If True, shuffles the order of presenting x and y. This mitigates the positional bias. + + Returns: + bool. """ if vote % 2 == 0: @@ -63,19 +75,34 @@ async def abinary(m:MelleaSession, criteria:str, x:str, y:str, symmetric:bool=Tr vote += 1 if symmetric: - tasks = [ - m.abool(f"Do X and Y satisfy the following criteria? \nCriteria: {criteria}\nX:{x}\nY:{y}") - for _ in range(vote // 2 + 1) - ] + [ - m.abool(f"Do X and Y satisfy the following criteria? \nCriteria: {criteria}\nX:{y}\nY:{x}") - for _ in range(vote // 2) - ] - + if positional: + prompts = [ + f"Do X and Y satisfy the following criteria? \nCriteria: {criteria}\nX:{x}\nY:{y}", + f"Do X and Y satisfy the following criteria? \nCriteria: {criteria}\nX:{y}\nY:{x}", + f"Do X and Y satisfy the following criteria? \nCriteria: {criteria}\nY:{y}\nX:{x}", + f"Do X and Y satisfy the following criteria? \nCriteria: {criteria}\nY:{x}\nX:{y}", + ] + else: + prompts = [ + f"Do X and Y satisfy the following criteria? \nCriteria: {criteria}\nX:{x}\nY:{y}", + f"Do X and Y satisfy the following criteria? \nCriteria: {criteria}\nX:{y}\nY:{x}", + ] else: - tasks = [ - m.abool(f"Do X and Y satisfy the following criteria? \nCriteria: {criteria}\nX:{x}\nY:{y}") - for _ in range(vote) - ] + if positional: + prompts = [ + f"Do X and Y satisfy the following criteria? 
\nCriteria: {criteria}\nX:{x}\nY:{y}", + f"Do X and Y satisfy the following criteria? \nCriteria: {criteria}\nY:{y}\nX:{x}", + ] + else: + prompts = [ + f"Do X and Y satisfy the following criteria? \nCriteria: {criteria}\nX:{x}\nY:{y}", + ] + + + tasks = [ + m.abool(p) + for i, p in zip(range(vote),itertools.cycle(prompts)) + ] answers = asyncio.gather(*tasks) From 79f18a7fddb118b7fc0a44bd34cdf7b88ed7b29e Mon Sep 17 00:00:00 2001 From: Masataro Asai Date: Thu, 6 Nov 2025 14:44:44 -0500 Subject: [PATCH 08/40] more --- mellea_contribs/va/__init__.py | 45 ++++++++++++++++++++++++++++------ 1 file changed, 37 insertions(+), 8 deletions(-) diff --git a/mellea_contribs/va/__init__.py b/mellea_contribs/va/__init__.py index cce822f..983af0e 100644 --- a/mellea_contribs/va/__init__.py +++ b/mellea_contribs/va/__init__.py @@ -1,4 +1,5 @@ +import random import functools import itertools import asyncio @@ -51,23 +52,32 @@ def choice(m:MelleaSession, prompt:str, choices:list[str], **kwargs) -> str: -class Arity(Core): +class Relation(Core): - async def abinary(m:MelleaSession, criteria:str, x:str, y:str, vote:int=3, symmetric:bool=True, positional:bool=True, **kwargs) -> bool: - """Evaluates a query that corresponds to a binary boolean function. + async def abinary(m:MelleaSession, criteria:str, x:str, y:str, vote:int=3, + symmetric:bool=False, + antisymmetric:bool=False, + positional:bool=True, + shuffle:bool=True, **kwargs) -> bool: + """Evaluates a query that evaluates a binary relation. Args: criteria: A natural language statement on variables X and Y. LLM decides if X and Y satisfy this critria, and this function returns yes if so. x: the first element y: the second element vote: an odd integer specifying the number of queries to make. The final result is a majority vote over the results. Since the LLM answers "yes"/"no", by default it counts "yes". If it is even, we add 1 to make it an odd number. - symmetric: If True, half of the queries swap x and y, and count the number of "no" for majority voting instead. This mitigates LLM's psycophancy bias toward answering "yes". - shuffle: If True, shuffles the order of presenting x and y. This mitigates the positional bias. - + symmetric: Declares the relation to be symmetric. Half of the queries swap x and y. + antisymmetric: Declares the relation to be antisymmetric. Half of the queries swap x and y, and asks if they violate the criteria. This mitigates LLM's psycophancy bias toward answering "yes". + positional: Half of the queries shuffle the order of presenting x and y. This mitigates the positional bias. + shuffle: It shuffles the variation of queries (symmetric/positional variations). + This helps when you are making multiple queries with a small vote count (less than 2*2=4 variations). + For example, when shuffle = False and vote = 1, the query always contains the original x y in the x y order. Returns: bool. """ + assert not (symmetric and antisymmetric), "symmetric and antisymmetric flags are mutually exclusive" + if vote % 2 == 0: FancyLogger.get_logger().warning( "the specified number of votes in a majority vote is even, making ties possible. Increasing the value by one to avoid this." @@ -87,6 +97,19 @@ async def abinary(m:MelleaSession, criteria:str, x:str, y:str, vote:int=3, symme f"Do X and Y satisfy the following criteria? \nCriteria: {criteria}\nX:{x}\nY:{y}", f"Do X and Y satisfy the following criteria? 
\nCriteria: {criteria}\nX:{y}\nY:{x}", ] + elif antisymmetric: + if positional: + prompts = [ + f"Do X and Y satisfy the following criteria? \nCriteria: {criteria}\nX:{x}\nY:{y}", + f"Do X and Y violate the following criteria? \nCriteria: {criteria}\nX:{y}\nY:{x}", + f"Do X and Y satisfy the following criteria? \nCriteria: {criteria}\nY:{y}\nX:{x}", + f"Do X and Y violate the following criteria? \nCriteria: {criteria}\nY:{x}\nX:{y}", + ] + else: + prompts = [ + f"Do X and Y satisfy the following criteria? \nCriteria: {criteria}\nX:{x}\nY:{y}", + f"Do X and Y violate the following criteria? \nCriteria: {criteria}\nX:{y}\nY:{x}", + ] else: if positional: prompts = [ @@ -98,6 +121,8 @@ async def abinary(m:MelleaSession, criteria:str, x:str, y:str, vote:int=3, symme f"Do X and Y satisfy the following criteria? \nCriteria: {criteria}\nX:{x}\nY:{y}", ] + if shuffle: + random.shuffle(prompts) tasks = [ m.abool(p) @@ -106,10 +131,14 @@ async def abinary(m:MelleaSession, criteria:str, x:str, y:str, vote:int=3, symme answers = asyncio.gather(*tasks) - return (answers[:majority].count(True) + answers[majority:].count(False)) >= majority + return answers.count(True) >= (vote // 2) + 1 - def binary(m:MelleaSession, criteria:str, x:str, y:str, symmetric:bool=True, vote:int=3, **kwargs) -> bool: + def binary(m:MelleaSession, criteria:str, x:str, y:str, vote:int=3, + symmetric:bool=False, + antisymmetric:bool=False, + positional:bool=True, + shuffle:bool=True, **kwargs) -> bool: return _run_async_in_thread(abinary(m, criteria, x, y, symmetric, vote, **kwargs)) From 2b03c0b4921862473b6b006b12d1f79f1ecc7ef3 Mon Sep 17 00:00:00 2001 From: Masataro Asai Date: Thu, 6 Nov 2025 14:48:58 -0500 Subject: [PATCH 09/40] fix --- mellea_contribs/va/__init__.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/mellea_contribs/va/__init__.py b/mellea_contribs/va/__init__.py index 983af0e..50ff53e 100644 --- a/mellea_contribs/va/__init__.py +++ b/mellea_contribs/va/__init__.py @@ -56,7 +56,7 @@ class Relation(Core): async def abinary(m:MelleaSession, criteria:str, x:str, y:str, vote:int=3, symmetric:bool=False, - antisymmetric:bool=False, + asymmetric:bool=False, positional:bool=True, shuffle:bool=True, **kwargs) -> bool: """Evaluates a query that evaluates a binary relation. @@ -67,7 +67,7 @@ async def abinary(m:MelleaSession, criteria:str, x:str, y:str, vote:int=3, y: the second element vote: an odd integer specifying the number of queries to make. The final result is a majority vote over the results. Since the LLM answers "yes"/"no", by default it counts "yes". If it is even, we add 1 to make it an odd number. symmetric: Declares the relation to be symmetric. Half of the queries swap x and y. - antisymmetric: Declares the relation to be antisymmetric. Half of the queries swap x and y, and asks if they violate the criteria. This mitigates LLM's psycophancy bias toward answering "yes". + asymmetric: Declares the relation to be asymmetric. Half of the queries swap x and y, and asks if they violate the criteria. This mitigates LLM's psycophancy bias toward answering "yes". positional: Half of the queries shuffle the order of presenting x and y. This mitigates the positional bias. shuffle: It shuffles the variation of queries (symmetric/positional variations). This helps when you are making multiple queries with a small vote count (less than 2*2=4 variations). @@ -76,7 +76,7 @@ async def abinary(m:MelleaSession, criteria:str, x:str, y:str, vote:int=3, bool. 
""" - assert not (symmetric and antisymmetric), "symmetric and antisymmetric flags are mutually exclusive" + assert not (symmetric and asymmetric), "symmetric and asymmetric flags are mutually exclusive" if vote % 2 == 0: FancyLogger.get_logger().warning( @@ -97,7 +97,7 @@ async def abinary(m:MelleaSession, criteria:str, x:str, y:str, vote:int=3, f"Do X and Y satisfy the following criteria? \nCriteria: {criteria}\nX:{x}\nY:{y}", f"Do X and Y satisfy the following criteria? \nCriteria: {criteria}\nX:{y}\nY:{x}", ] - elif antisymmetric: + elif asymmetric: if positional: prompts = [ f"Do X and Y satisfy the following criteria? \nCriteria: {criteria}\nX:{x}\nY:{y}", @@ -136,7 +136,7 @@ async def abinary(m:MelleaSession, criteria:str, x:str, y:str, vote:int=3, def binary(m:MelleaSession, criteria:str, x:str, y:str, vote:int=3, symmetric:bool=False, - antisymmetric:bool=False, + asymmetric:bool=False, positional:bool=True, shuffle:bool=True, **kwargs) -> bool: return _run_async_in_thread(abinary(m, criteria, x, y, symmetric, vote, **kwargs)) From d58fb06ba3aa86d8a6516ca48ca5a41d01ddf903 Mon Sep 17 00:00:00 2001 From: Masataro Asai Date: Thu, 6 Nov 2025 14:53:12 -0500 Subject: [PATCH 10/40] ir/reflexive --- mellea_contribs/va/__init__.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/mellea_contribs/va/__init__.py b/mellea_contribs/va/__init__.py index 50ff53e..025e6a1 100644 --- a/mellea_contribs/va/__init__.py +++ b/mellea_contribs/va/__init__.py @@ -57,6 +57,8 @@ class Relation(Core): async def abinary(m:MelleaSession, criteria:str, x:str, y:str, vote:int=3, symmetric:bool=False, asymmetric:bool=False, + reflexive:bool=False, + irreflexive:bool=False, positional:bool=True, shuffle:bool=True, **kwargs) -> bool: """Evaluates a query that evaluates a binary relation. @@ -68,6 +70,8 @@ async def abinary(m:MelleaSession, criteria:str, x:str, y:str, vote:int=3, vote: an odd integer specifying the number of queries to make. The final result is a majority vote over the results. Since the LLM answers "yes"/"no", by default it counts "yes". If it is even, we add 1 to make it an odd number. symmetric: Declares the relation to be symmetric. Half of the queries swap x and y. asymmetric: Declares the relation to be asymmetric. Half of the queries swap x and y, and asks if they violate the criteria. This mitigates LLM's psycophancy bias toward answering "yes". + reflexive: Declares the relation to be reflexive, i.e., if x == y, returns True immediately. + irreflexive: Declares the relation to be irreflexive, i.e., if x == y, returns False immediately. positional: Half of the queries shuffle the order of presenting x and y. This mitigates the positional bias. shuffle: It shuffles the variation of queries (symmetric/positional variations). This helps when you are making multiple queries with a small vote count (less than 2*2=4 variations). @@ -78,6 +82,12 @@ async def abinary(m:MelleaSession, criteria:str, x:str, y:str, vote:int=3, assert not (symmetric and asymmetric), "symmetric and asymmetric flags are mutually exclusive" + if x == y: + if reflexive: + return True + if irreflexive: + return False + if vote % 2 == 0: FancyLogger.get_logger().warning( "the specified number of votes in a majority vote is even, making ties possible. Increasing the value by one to avoid this." 
@@ -137,6 +147,8 @@ async def abinary(m:MelleaSession, criteria:str, x:str, y:str, vote:int=3, def binary(m:MelleaSession, criteria:str, x:str, y:str, vote:int=3, symmetric:bool=False, asymmetric:bool=False, + reflexive:bool=False, + irreflexive:bool=False, positional:bool=True, shuffle:bool=True, **kwargs) -> bool: return _run_async_in_thread(abinary(m, criteria, x, y, symmetric, vote, **kwargs)) From 7e630b10bbc8303ab85a32b0d3b7656e411da561 Mon Sep 17 00:00:00 2001 From: Masataro Asai Date: Thu, 6 Nov 2025 15:36:41 -0500 Subject: [PATCH 11/40] sync_wrapper --- mellea_contribs/va/__init__.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/mellea_contribs/va/__init__.py b/mellea_contribs/va/__init__.py index 025e6a1..0289ae4 100644 --- a/mellea_contribs/va/__init__.py +++ b/mellea_contribs/va/__init__.py @@ -12,7 +12,12 @@ from typing import Literal - +def sync_wrapper(async_fn): + """Wrap an async function so it can be called synchronously.""" + @functools.wraps(async_fn) + def wrapper(*args, **kwargs): + return _run_async_in_thread(async_fn(*args, **kwargs)) + return wrapper class YesNo(BaseModel): @@ -41,15 +46,10 @@ class Choice(BaseModel): return choice.answer - def bool(m:MelleaSession, prompt:str, **kwargs) -> bool: - return _run_async_in_thread(abool(m, prompt, **kwargs)) - - def choice(m:MelleaSession, prompt:str, choices:list[str], **kwargs) -> str: - return _run_async_in_thread(achoice(m, prompt, **kwargs)) - - - + pass +Core.bool = sync_wrapper(Core.abool) +Core.choice = sync_wrapper(Core.achoice) class Relation(Core): From 0ed872a60a9bfa2c5fdecc4906009b30ce09aecb Mon Sep 17 00:00:00 2001 From: Masataro Asai Date: Thu, 6 Nov 2025 15:37:10 -0500 Subject: [PATCH 12/40] compress using permutations --- mellea_contribs/va/__init__.py | 62 ++++++++-------------------------- 1 file changed, 15 insertions(+), 47 deletions(-) diff --git a/mellea_contribs/va/__init__.py b/mellea_contribs/va/__init__.py index 0289ae4..e36d0fb 100644 --- a/mellea_contribs/va/__init__.py +++ b/mellea_contribs/va/__init__.py @@ -72,7 +72,7 @@ async def abinary(m:MelleaSession, criteria:str, x:str, y:str, vote:int=3, asymmetric: Declares the relation to be asymmetric. Half of the queries swap x and y, and asks if they violate the criteria. This mitigates LLM's psycophancy bias toward answering "yes". reflexive: Declares the relation to be reflexive, i.e., if x == y, returns True immediately. irreflexive: Declares the relation to be irreflexive, i.e., if x == y, returns False immediately. - positional: Half of the queries shuffle the order of presenting x and y. This mitigates the positional bias. + positional: Permute the order of presenting x and y. This mitigates the positional bias. shuffle: It shuffles the variation of queries (symmetric/positional variations). This helps when you are making multiple queries with a small vote count (less than 2*2=4 variations). For example, when shuffle = False and vote = 1, the query always contains the original x y in the x y order. @@ -95,65 +95,33 @@ async def abinary(m:MelleaSession, criteria:str, x:str, y:str, vote:int=3, vote += 1 if symmetric: - if positional: - prompts = [ - f"Do X and Y satisfy the following criteria? \nCriteria: {criteria}\nX:{x}\nY:{y}", - f"Do X and Y satisfy the following criteria? \nCriteria: {criteria}\nX:{y}\nY:{x}", - f"Do X and Y satisfy the following criteria? \nCriteria: {criteria}\nY:{y}\nX:{x}", - f"Do X and Y satisfy the following criteria? 
\nCriteria: {criteria}\nY:{x}\nX:{y}", - ] - else: - prompts = [ - f"Do X and Y satisfy the following criteria? \nCriteria: {criteria}\nX:{x}\nY:{y}", - f"Do X and Y satisfy the following criteria? \nCriteria: {criteria}\nX:{y}\nY:{x}", - ] + args = [(x,y),(y,x)] + target = [True,True] elif asymmetric: - if positional: - prompts = [ - f"Do X and Y satisfy the following criteria? \nCriteria: {criteria}\nX:{x}\nY:{y}", - f"Do X and Y violate the following criteria? \nCriteria: {criteria}\nX:{y}\nY:{x}", - f"Do X and Y satisfy the following criteria? \nCriteria: {criteria}\nY:{y}\nX:{x}", - f"Do X and Y violate the following criteria? \nCriteria: {criteria}\nY:{x}\nX:{y}", - ] - else: - prompts = [ - f"Do X and Y satisfy the following criteria? \nCriteria: {criteria}\nX:{x}\nY:{y}", - f"Do X and Y violate the following criteria? \nCriteria: {criteria}\nX:{y}\nY:{x}", - ] + args = [(x,y),(y,x)] + target = [True,False] else: + args = [(x,y)] + target = [True] + + prompts = [] + for (x, y), t in zip(args, target): + prompts.append((f"Do X and Y satisfy the following criteria? \nCriteria: {criteria}\nX:{x}\nY:{y}", t)) if positional: - prompts = [ - f"Do X and Y satisfy the following criteria? \nCriteria: {criteria}\nX:{x}\nY:{y}", - f"Do X and Y satisfy the following criteria? \nCriteria: {criteria}\nY:{y}\nX:{x}", - ] - else: - prompts = [ - f"Do X and Y satisfy the following criteria? \nCriteria: {criteria}\nX:{x}\nY:{y}", - ] + prompts.append((f"Do X and Y satisfy the following criteria? \nCriteria: {criteria}\nY:{y}\nX:{x}", t)) if shuffle: random.shuffle(prompts) tasks = [ m.abool(p) - for i, p in zip(range(vote),itertools.cycle(prompts)) + for i, (p, t) in zip(range(vote),itertools.cycle(prompts)) ] answers = asyncio.gather(*tasks) - return answers.count(True) >= (vote // 2) + 1 - - - def binary(m:MelleaSession, criteria:str, x:str, y:str, vote:int=3, - symmetric:bool=False, - asymmetric:bool=False, - reflexive:bool=False, - irreflexive:bool=False, - positional:bool=True, - shuffle:bool=True, **kwargs) -> bool: - return _run_async_in_thread(abinary(m, criteria, x, y, symmetric, vote, **kwargs)) - + answers = [ t == a for (p, t), a in zip(itertools.cycle(prompts), answers) ] + return answers.count(True) >= (vote // 2) + 1 -class Sorting(Arity): From 5082a0eeac016451eb53e87414da63ab18195fa4 Mon Sep 17 00:00:00 2001 From: Masataro Asai Date: Thu, 6 Nov 2025 15:44:13 -0500 Subject: [PATCH 13/40] ternary functions --- mellea_contribs/va/__init__.py | 70 ++++++++++++++++++++++++++++++++++ 1 file changed, 70 insertions(+) diff --git a/mellea_contribs/va/__init__.py b/mellea_contribs/va/__init__.py index e36d0fb..48583d4 100644 --- a/mellea_contribs/va/__init__.py +++ b/mellea_contribs/va/__init__.py @@ -125,3 +125,73 @@ async def abinary(m:MelleaSession, criteria:str, x:str, y:str, vote:int=3, return answers.count(True) >= (vote // 2) + 1 + async def aternary(m:MelleaSession, criteria:str, x:str, y:str, z:str, vote:int=3, + symmetric:bool=False, + asymmetric:bool=False, + positional:bool=True, + shuffle:bool=True, + **kwargs) -> bool: + """Evaluates a query that evaluates a ternary relation. + + Args: + criteria: A natural language statement on variables X and Y. LLM decides if X and Y satisfy this critria, and this function returns yes if so. + x: the first element + y: the second element + z: the third element + vote: an odd integer specifying the number of queries to make. The final result is a majority vote over the results. Since the LLM answers "yes"/"no", by default it counts "yes". 
If it is even, we add 1 to make it an odd number. + symmetric: Declares the relation to be symmetric wrto x and y. Half of the queries swap x and y. + asymmetric: Declares the relation to be asymmetric wrto x and y. Half of the queries swap x and y, and asks if they violate the criteria. This mitigates LLM's psycophancy bias toward answering "yes". + positional: The queries permutes the order of presenting x, y, z. This mitigates the positional bias. + shuffle: It shuffles the variation of queries (symmetric/positional variations). + This helps when you are making multiple queries with a small vote count (less than 2*2=4 variations). + For example, when shuffle = False and vote = 1, the query always contains the original x y in the x y order. + Returns: + bool. + """ + + assert not (symmetric and asymmetric), "symmetric and asymmetric flags are mutually exclusive" + + if vote % 2 == 0: + FancyLogger.get_logger().warning( + "the specified number of votes in a majority vote is even, making ties possible. Increasing the value by one to avoid this." + ) + vote += 1 + + if symmetric: + args = [(x,y,z),(y,x,z)] + target = [True,True] + elif asymmetric: + args = [(x,y,z),(y,x,z)] + target = [True,False] + else: + args = [(x,y,z)] + target = [True] + + prompts = [] + for (x, y, z), t in zip(args, target): + parts = [f"X:{x}", f"Y:{y}", f"Z:{z}"] + if positional: + for _parts in itertools.permutations(parts): + prompts.append(("\n".join([f"Do X, Y and Z satisfy the following criteria?", f"Criteria: {criteria}", *_parts]), t)) + else: + prompts.append(("\n".join([f"Do X, Y and Z satisfy the following criteria?", f"Criteria: {criteria}", *parts]), t)) + + if shuffle: + random.shuffle(prompts) + + tasks = [ + m.abool(p) + for i, (p, t) in zip(range(vote),itertools.cycle(prompts)) + ] + + answers = asyncio.gather(*tasks) + + answers = [ t == a for (p, t), a in zip(itertools.cycle(prompts), answers) ] + + return answers.count(True) >= (vote // 2) + 1 + + +Relation.binary = sync_wrapper(Relation.abinary) +Relation.ternary = sync_wrapper(Relation.aternary) + + From 56c2009d512342d840175e9b81dbbcb07a5a034c Mon Sep 17 00:00:00 2001 From: Masataro Asai Date: Thu, 6 Nov 2025 15:52:41 -0500 Subject: [PATCH 14/40] forced keywords --- mellea_contribs/va/__init__.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/mellea_contribs/va/__init__.py b/mellea_contribs/va/__init__.py index 48583d4..d9f9cdd 100644 --- a/mellea_contribs/va/__init__.py +++ b/mellea_contribs/va/__init__.py @@ -54,7 +54,8 @@ class Choice(BaseModel): class Relation(Core): - async def abinary(m:MelleaSession, criteria:str, x:str, y:str, vote:int=3, + async def abinary(m:MelleaSession, criteria:str, x:str, y:str, *, + vote:int=3, symmetric:bool=False, asymmetric:bool=False, reflexive:bool=False, @@ -125,7 +126,8 @@ async def abinary(m:MelleaSession, criteria:str, x:str, y:str, vote:int=3, return answers.count(True) >= (vote // 2) + 1 - async def aternary(m:MelleaSession, criteria:str, x:str, y:str, z:str, vote:int=3, + async def aternary(m:MelleaSession, criteria:str, x:str, y:str, z:str, *, + vote:int=3, symmetric:bool=False, asymmetric:bool=False, positional:bool=True, From cb56dd969ca2b0b161783e9b1f48a5feba591ebe Mon Sep 17 00:00:00 2001 From: Masataro Asai Date: Thu, 6 Nov 2025 15:53:26 -0500 Subject: [PATCH 15/40] (a)gt,ge,eq --- mellea_contribs/va/__init__.py | 84 ++++++++++++++++++++++++++++++++++ 1 file changed, 84 insertions(+) diff --git a/mellea_contribs/va/__init__.py b/mellea_contribs/va/__init__.py index 
d9f9cdd..8418d22 100644 --- a/mellea_contribs/va/__init__.py +++ b/mellea_contribs/va/__init__.py @@ -193,7 +193,91 @@ async def aternary(m:MelleaSession, criteria:str, x:str, y:str, z:str, *, return answers.count(True) >= (vote // 2) + 1 + async def agt(m:MelleaSession, criteria:str, x:str, y:str, *, + vote:int=3, + positional:bool=True, + shuffle:bool=True, **kwargs) -> bool: + """Evaluates a query that evaluates a "greater-than" relation. + + Args: + criteria: A natural language statement on variables X and Y. LLM decides if X and Y satisfy this critria, and this function returns yes if so. + x: the first element + y: the second element + vote: an odd integer specifying the number of queries to make. The final result is a majority vote over the results. Since the LLM answers "yes"/"no", by default it counts "yes". If it is even, we add 1 to make it an odd number. + positional: Permute the order of presenting x and y. This mitigates the positional bias. + shuffle: It shuffles the variation of queries (symmetric/positional variations). + This helps when you are making multiple queries with a small vote count (less than 2*2=4 variations). + For example, when shuffle = False and vote = 1, the query always contains the original x y in the x y order. + Returns: + bool. + """ + return await m.abinary(criteria, x, y, + vote=vote, + symmetric=False, + asymmetric=True, + reflexive=False, + irreflexive=True, + shuffle=shuffle, **kwargs) + + async def age(m:MelleaSession, criteria:str, x:str, y:str, *, + vote:int=3, + positional:bool=True, + shuffle:bool=True, **kwargs) -> bool: + """Evaluates a query that evaluates a "greater-than-equal" relation. + + Args: + criteria: A natural language statement on variables X and Y. LLM decides if X and Y satisfy this critria, and this function returns yes if so. + x: the first element + y: the second element + vote: an odd integer specifying the number of queries to make. The final result is a majority vote over the results. Since the LLM answers "yes"/"no", by default it counts "yes". If it is even, we add 1 to make it an odd number. + positional: Permute the order of presenting x and y. This mitigates the positional bias. + shuffle: It shuffles the variation of queries (symmetric/positional variations). + This helps when you are making multiple queries with a small vote count (less than 2*2=4 variations). + For example, when shuffle = False and vote = 1, the query always contains the original x y in the x y order. + Returns: + bool. + """ + return await m.abinary(criteria, x, y, + vote=vote, + symmetric=False, + asymmetric=True, + reflexive=True, + irreflexive=False, + shuffle=shuffle, **kwargs) + + + async def aeq(m:MelleaSession, criteria:str, x:str, y:str, *, + vote:int=3, + positional:bool=True, + shuffle:bool=True, **kwargs) -> bool: + """Evaluates a query that evaluates an equivalence relation. + + Args: + criteria: A natural language statement on variables X and Y. LLM decides if X and Y satisfy this critria, and this function returns yes if so. + x: the first element + y: the second element + vote: an odd integer specifying the number of queries to make. The final result is a majority vote over the results. Since the LLM answers "yes"/"no", by default it counts "yes". If it is even, we add 1 to make it an odd number. + positional: Permute the order of presenting x and y. This mitigates the positional bias. + shuffle: It shuffles the variation of queries (symmetric/positional variations). 
+ This helps when you are making multiple queries with a small vote count (less than 2*2=4 variations). + For example, when shuffle = False and vote = 1, the query always contains the original x y in the x y order. + Returns: + bool. + """ + return await m.abinary(criteria, x, y, + vote=vote, + symmetric=True, + asymmetric=False, + reflexive=True, + irreflexive=False, + shuffle=shuffle, **kwargs) + + + Relation.binary = sync_wrapper(Relation.abinary) Relation.ternary = sync_wrapper(Relation.aternary) +Relation.gt = sync_wrapper(Relation.agt) +Relation.ge = sync_wrapper(Relation.age) +Relation.eq = sync_wrapper(Relation.aeq) From 9015ce2bf1e11709bc9b73dd7286e8a15186fa7c Mon Sep 17 00:00:00 2001 From: Masataro Asai Date: Thu, 6 Nov 2025 16:02:58 -0500 Subject: [PATCH 16/40] sort --- mellea_contribs/va/__init__.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/mellea_contribs/va/__init__.py b/mellea_contribs/va/__init__.py index 8418d22..167c67c 100644 --- a/mellea_contribs/va/__init__.py +++ b/mellea_contribs/va/__init__.py @@ -281,3 +281,35 @@ async def aeq(m:MelleaSession, criteria:str, x:str, y:str, *, Relation.eq = sync_wrapper(Relation.aeq) +async def async_merge_sort(lst, cmp): + if len(lst) <= 1: + return lst + mid = len(lst) // 2 + left = await async_merge_sort(lst[:mid], cmp) + right = await async_merge_sort(lst[mid:], cmp) + return await async_merge(left, right, cmp) + +async def async_merge(left, right, cmp): + result = [] + while left and right: + if await cmp(left[0], right[0]): + result.append(left.pop(0)) + else: + result.append(right.pop(0)) + return result + left + right + + +class Sort(Relation): + + async def asort(m:MelleaSession, criteria:str, elems:list[str], *, + vote:int=3, + positional:bool=True, + shuffle:bool=True, **kwargs) -> bool: + + async def cmp(x, y): + return await m.agt(criteria, x, y, vote=vote, positional=positional, shuffle=shuffle, **kwargs) + + return async_merge_sort(elems, cmp) + + +Sort.sort = sync_wrapper(Sort.asort) From adec3a7d5af98c52b18c0b161149d885b054d6d0 Mon Sep 17 00:00:00 2001 From: Masataro Asai Date: Fri, 7 Nov 2025 16:46:16 -0500 Subject: [PATCH 17/40] maximum: elimination tournament --- mellea_contribs/va/__init__.py | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/mellea_contribs/va/__init__.py b/mellea_contribs/va/__init__.py index 167c67c..110f44d 100644 --- a/mellea_contribs/va/__init__.py +++ b/mellea_contribs/va/__init__.py @@ -298,8 +298,19 @@ async def async_merge(left, right, cmp): result.append(right.pop(0)) return result + left + right +async def async_max(lst, cmp): + if len(lst) <= 1: + return lst[0] + mid = len(lst) // 2 + left = await async_max(lst[:mid], cmp) + right = await async_max(lst[mid:], cmp) + if await cmp(left, right): + return left + else: + return right + -class Sort(Relation): +class Sequence(Relation): async def asort(m:MelleaSession, criteria:str, elems:list[str], *, vote:int=3, @@ -311,5 +322,18 @@ async def cmp(x, y): return async_merge_sort(elems, cmp) + async def amax(m:MelleaSession, criteria:str, elems:list[str], *, + vote:int=3, + positional:bool=True, + shuffle:bool=True, **kwargs) -> bool: + + async def cmp(x, y): + return await m.agt(criteria, x, y, vote=vote, positional=positional, shuffle=shuffle, **kwargs) + + return async_max(elems, cmp) + Sort.sort = sync_wrapper(Sort.asort) +Sort.max = sync_wrapper(Sort.amax) + + From 4b90859934930053d2ace6f4dfe0402c2652deea Mon Sep 17 00:00:00 2001 From: Masataro Asai Date: 
Mon, 10 Nov 2025 11:56:13 -0500 Subject: [PATCH 18/40] renamed, typed --- mellea_contribs/va/__init__.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/mellea_contribs/va/__init__.py b/mellea_contribs/va/__init__.py index 110f44d..b8ca7ae 100644 --- a/mellea_contribs/va/__init__.py +++ b/mellea_contribs/va/__init__.py @@ -281,30 +281,30 @@ async def aeq(m:MelleaSession, criteria:str, x:str, y:str, *, Relation.eq = sync_wrapper(Relation.aeq) -async def async_merge_sort(lst, cmp): +async def async_merge_sort(lst:list[str], acmp): if len(lst) <= 1: return lst mid = len(lst) // 2 - left = await async_merge_sort(lst[:mid], cmp) - right = await async_merge_sort(lst[mid:], cmp) - return await async_merge(left, right, cmp) + left = await async_merge_sort(lst[:mid], acmp) + right = await async_merge_sort(lst[mid:], acmp) + return await async_merge(left, right, acmp) -async def async_merge(left, right, cmp): +async def async_merge(left:list[str], right:list[str], acmp): result = [] while left and right: - if await cmp(left[0], right[0]): + if await acmp(left[0], right[0]): result.append(left.pop(0)) else: result.append(right.pop(0)) return result + left + right -async def async_max(lst, cmp): +async def async_max(lst:list[str], acmp): if len(lst) <= 1: return lst[0] mid = len(lst) // 2 - left = await async_max(lst[:mid], cmp) - right = await async_max(lst[mid:], cmp) - if await cmp(left, right): + left = await async_max(lst[:mid], acmp) + right = await async_max(lst[mid:], acmp) + if await acmp(left, right): return left else: return right @@ -317,20 +317,20 @@ async def asort(m:MelleaSession, criteria:str, elems:list[str], *, positional:bool=True, shuffle:bool=True, **kwargs) -> bool: - async def cmp(x, y): + async def acmp(x, y): return await m.agt(criteria, x, y, vote=vote, positional=positional, shuffle=shuffle, **kwargs) - return async_merge_sort(elems, cmp) + return async_merge_sort(elems, acmp) async def amax(m:MelleaSession, criteria:str, elems:list[str], *, vote:int=3, positional:bool=True, shuffle:bool=True, **kwargs) -> bool: - async def cmp(x, y): + async def acmp(x, y): return await m.agt(criteria, x, y, vote=vote, positional=positional, shuffle=shuffle, **kwargs) - return async_max(elems, cmp) + return async_max(elems, acmp) Sort.sort = sync_wrapper(Sort.asort) From 4fa47c91e35c4a8c16bd8893817a390916e9cc5f Mon Sep 17 00:00:00 2001 From: Masataro Asai Date: Mon, 10 Nov 2025 12:28:29 -0500 Subject: [PATCH 19/40] added amedian --- mellea_contribs/va/__init__.py | 70 ++++++++++++++++++++++++++++++++++ 1 file changed, 70 insertions(+) diff --git a/mellea_contribs/va/__init__.py b/mellea_contribs/va/__init__.py index b8ca7ae..2f7ccbf 100644 --- a/mellea_contribs/va/__init__.py +++ b/mellea_contribs/va/__init__.py @@ -310,6 +310,51 @@ async def async_max(lst:list[str], acmp): return right +async def async_mom(seq:list[str], acmp, asort, block_size=5): + """ + Median of medians algorithm for finding an approximate median. 
Worst-case runtime O(n) + """ + + async def median_fixed(seq): + return await asort(seq)[len(seq)//2] + + if len(seq) <= block_size: + return await median_fixed(seq) + + # Step 1: Divide into groups of block_size + groups = itertools.batched(seq, block_size) + + # Step 2: Find median of each group + medians = asyncio.gather(*[median_fixed(g) for g in groups]) + + # Step 3: Recursively find the pivot + return await async_mom(medians, acmp, asort, block_size=block_size) + + +async def async_quickselect(seq:list[str], k, acmp, asort, block_size=5): + """ + Quickselect algorithm that uses median-of-medians for pivot selection. Worst-case runtime O(n^2) + """ + + pivot = await async_mom(medians, acmp, asort, block_size=block_size) + + # Step 4: Partition + lows, highs = [], [] + for x in seq: + if await acmp(x, pivot): + lows.append(x) + else: + highs.append(x) + + # Step block_size: Recurse + if k < len(lows): + return await async_quickselect(lows, k, acmp, asort, block_size=block_size) + elif k == len(lows): + return pivot + else: + return await async_quickselect(highs, k - len(lows), acmp, asort, block_size=block_size) + + class Sequence(Relation): async def asort(m:MelleaSession, criteria:str, elems:list[str], *, @@ -332,8 +377,33 @@ async def acmp(x, y): return async_max(elems, acmp) + async def amedian(m:MelleaSession, criteria:str, elems:list[str], *, + exact = False, + vote:int=3, + positional:bool=True, + shuffle:bool=True, + block_size:int=5, + **kwargs) -> bool: + """ + If exact = True, use quickselect. + Otherwise, return the approximate median returned by median of medians. + """ + + async def acmp(x, y): + return await m.agt(criteria, x, y, vote=vote, positional=positional, shuffle=shuffle, **kwargs) + + async def asort(elems:list[str]): + return await m.asort(criteria, x, y, vote=vote, positional=positional, shuffle=shuffle, **kwargs) + + if exact: + return await async_quickselect(elems, len(elems)//2, acmp, asort, block_size=block_size) + else: + return await async_mom(elems, acmp, asort, block_size=block_size) + + Sort.sort = sync_wrapper(Sort.asort) Sort.max = sync_wrapper(Sort.amax) +Sort.median = sync_wrapper(Sort.amedian) From 3a4e3b206e082b2f150f22ca7ae34257e6a8fc2b Mon Sep 17 00:00:00 2001 From: Masataro Asai Date: Mon, 10 Nov 2025 13:07:19 -0500 Subject: [PATCH 20/40] refactor: cleanup --- mellea_contribs/va/__init__.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/mellea_contribs/va/__init__.py b/mellea_contribs/va/__init__.py index 2f7ccbf..5757058 100644 --- a/mellea_contribs/va/__init__.py +++ b/mellea_contribs/va/__init__.py @@ -125,7 +125,6 @@ async def abinary(m:MelleaSession, criteria:str, x:str, y:str, *, return answers.count(True) >= (vote // 2) + 1 - async def aternary(m:MelleaSession, criteria:str, x:str, y:str, z:str, *, vote:int=3, symmetric:bool=False, @@ -192,7 +191,6 @@ async def aternary(m:MelleaSession, criteria:str, x:str, y:str, z:str, *, return answers.count(True) >= (vote // 2) + 1 - async def agt(m:MelleaSession, criteria:str, x:str, y:str, *, vote:int=3, positional:bool=True, @@ -245,7 +243,6 @@ async def age(m:MelleaSession, criteria:str, x:str, y:str, *, irreflexive=False, shuffle=shuffle, **kwargs) - async def aeq(m:MelleaSession, criteria:str, x:str, y:str, *, vote:int=3, positional:bool=True, @@ -273,7 +270,6 @@ async def aeq(m:MelleaSession, criteria:str, x:str, y:str, *, shuffle=shuffle, **kwargs) - Relation.binary = sync_wrapper(Relation.abinary) Relation.ternary = sync_wrapper(Relation.aternary) Relation.gt = 
sync_wrapper(Relation.agt) @@ -309,7 +305,6 @@ async def async_max(lst:list[str], acmp): else: return right - async def async_mom(seq:list[str], acmp, asort, block_size=5): """ Median of medians algorithm for finding an approximate median. Worst-case runtime O(n) @@ -330,7 +325,6 @@ async def median_fixed(seq): # Step 3: Recursively find the pivot return await async_mom(medians, acmp, asort, block_size=block_size) - async def async_quickselect(seq:list[str], k, acmp, asort, block_size=5): """ Quickselect algorithm that uses median-of-medians for pivot selection. Worst-case runtime O(n^2) From 145eb027e7c169efc83d8dbfa6b468cc0ed6e09e Mon Sep 17 00:00:00 2001 From: Masataro Asai Date: Mon, 10 Nov 2025 15:25:05 -0500 Subject: [PATCH 21/40] use module-global logger --- mellea_contribs/va/__init__.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mellea_contribs/va/__init__.py b/mellea_contribs/va/__init__.py index 5757058..cf11eaa 100644 --- a/mellea_contribs/va/__init__.py +++ b/mellea_contribs/va/__init__.py @@ -11,6 +11,7 @@ from typing import Literal +logger = FancyLogger.get_logger() def sync_wrapper(async_fn): """Wrap an async function so it can be called synchronously.""" @@ -90,7 +91,7 @@ async def abinary(m:MelleaSession, criteria:str, x:str, y:str, *, return False if vote % 2 == 0: - FancyLogger.get_logger().warning( + logger.warning( "the specified number of votes in a majority vote is even, making ties possible. Increasing the value by one to avoid this." ) vote += 1 @@ -153,7 +154,7 @@ async def aternary(m:MelleaSession, criteria:str, x:str, y:str, z:str, *, assert not (symmetric and asymmetric), "symmetric and asymmetric flags are mutually exclusive" if vote % 2 == 0: - FancyLogger.get_logger().warning( + logger.warning( "the specified number of votes in a majority vote is even, making ties possible. Increasing the value by one to avoid this." ) vote += 1 From d19cc1ca614323152ac32e5048e9a8cb9a3c6c4a Mon Sep 17 00:00:00 2001 From: Masataro Asai Date: Mon, 10 Nov 2025 15:30:47 -0500 Subject: [PATCH 22/40] fix --- mellea_contribs/va/__init__.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mellea_contribs/va/__init__.py b/mellea_contribs/va/__init__.py index cf11eaa..990d86b 100644 --- a/mellea_contribs/va/__init__.py +++ b/mellea_contribs/va/__init__.py @@ -355,7 +355,7 @@ class Sequence(Relation): async def asort(m:MelleaSession, criteria:str, elems:list[str], *, vote:int=3, positional:bool=True, - shuffle:bool=True, **kwargs) -> bool: + shuffle:bool=True, **kwargs) -> list[str]: async def acmp(x, y): return await m.agt(criteria, x, y, vote=vote, positional=positional, shuffle=shuffle, **kwargs) @@ -365,7 +365,7 @@ async def acmp(x, y): async def amax(m:MelleaSession, criteria:str, elems:list[str], *, vote:int=3, positional:bool=True, - shuffle:bool=True, **kwargs) -> bool: + shuffle:bool=True, **kwargs) -> str: async def acmp(x, y): return await m.agt(criteria, x, y, vote=vote, positional=positional, shuffle=shuffle, **kwargs) @@ -378,7 +378,7 @@ async def amedian(m:MelleaSession, criteria:str, elems:list[str], *, positional:bool=True, shuffle:bool=True, block_size:int=5, - **kwargs) -> bool: + **kwargs) -> str: """ If exact = True, use quickselect. Otherwise, return the approximate median returned by median of medians. 
From af834e7167960cb2a293003d44404b2c9093b847 Mon Sep 17 00:00:00 2001 From: Masataro Asai Date: Mon, 10 Nov 2025 15:45:01 -0500 Subject: [PATCH 23/40] triplet clustering --- mellea_contribs/va/__init__.py | 265 +++++++++++++++++++++++++++++++++ 1 file changed, 265 insertions(+) diff --git a/mellea_contribs/va/__init__.py b/mellea_contribs/va/__init__.py index 990d86b..e6829f1 100644 --- a/mellea_contribs/va/__init__.py +++ b/mellea_contribs/va/__init__.py @@ -3,6 +3,7 @@ import functools import itertools import asyncio +import dataclasses from mellea import MelleaSession from mellea.helpers.fancy_logger import FancyLogger from mellea.helpers.event_loop_helper import _run_async_in_thread @@ -11,6 +12,10 @@ from typing import Literal +import numpy as np + +from sklearn.base import ClusterMixin + logger = FancyLogger.get_logger() def sync_wrapper(async_fn): @@ -402,3 +407,263 @@ async def asort(elems:list[str]): Sort.median = sync_wrapper(Sort.amedian) + +@dataclasses.dataclass +class Triplet: + z : str # anchor + x : str + y : str + z_index : int # z's index in items + x_index : int # x's index in items + y_index : int # y's index in items + + def swap(self): + "swaps x and y" + return Triplet(self.z, self.y, self.x, self.z_index, self.y_index, self.x_index) + + +def sample_triplets(items: list[str], + triplets_per_item: int | None = None, + num_triplets: int | None = None, + repeat_x: int = 1, + ) -> list[Triplet]: + """Randomly sample a list of triplet comparison queries. + + Args: + items : input + triplets_per_item : how many triplets to generate relative to the number of items. + num_triplets : the number of triplets to generate. + repeat_x : how many times we reuse the same z, x for sampling y. + + Either triplets_per_item or num_triplets must be specified. + triplets_per_item and num_triplets are mutually exclusive (cannot be specified at the same time). + + """ + N = len(items) + + assert (num_triplets is not None) or (triplets_per_item is not None), \ + "Specify either num_triplets and triplets_per_item." + assert (num_triplets is None) or (triplets_per_item is None), \ + "num_triplets and triplets_per_item are mutually exclusive; do not specify both." + if num_triplets is None: + assert isinstance(triplets_per_item, int) + logger.info(f"num_triplets = triplets_per_item * len(items) = {triplets_per_item} * {N} = {triplets_per_item * N}") + num_triplets = triplets_per_item * N + + # make sure z covers all elements + assert num_triplets / N >= 1, \ + ("Some items are never used as an anchor z. Increase num_triplets or triplets_per_item: " + f"num_triplets / len(items) = {num_triplets} / {N} = {num_triplets / N}") + + # make sure z covers all elements even if we sample multiple triplets with the same x + if repeat_x > num_triplets / N: + logger.warning(f"Some items are never used as an anchor z because of too large repeat_x. Overriding it with {num_triplets / N}: " + f"repeat_x = {repeat_x} > " + f"num_triplets / len(items) = {num_triplets} / {N} = {num_triplets / N}.") + repeat_x = num_triplets // N + + # switch to the exhaustive mode if num_triplets is large enough + all_triplets = N * (N-1) * (N-2) + logger.info(f"all_triplets = N * (N-1) * (N-2) = {all_triplets}, where N = {N}") + if num_triplets > all_triplets: + logger.warning(f"num_triplets = {num_triplets} is large enough to enumerate all triplets (> {N} * {(N-1)} * {(N-2)} = {all_triplets}). 
" + f"Switching to the exhaustive mode.") + exhaustive = True + num_triplets = all_triplets + else: + exhaustive = False + + triplets: list[Triplet] = [] + + bar = tqdm(total=num_triplets, desc="sampling triplets") + + if exhaustive: + for i, z in enumerate(items): + for j, x in enumerate(items): + if i == j: + continue + for k, y in enumerate(items): + if k == i or k == j: + continue + triplets.append(Triplet(z, x, y, i, j, k)) + bar.update() + assert len(triplets) == all_triplets + return triplets + + def sample_except(blacklist:set[str]): + while True: + sample = random.choice(items) + if sample not in blacklist: + return sample + + for z in cycle(items): + x = sample_except({z}) + for _ in range(repeat_x): + y = sample_except({z,x}) + triplets.append(Triplet(z, x, y, items.index(z), items.index(x), items.index(y))) + bar.update() + if len(triplets) >= num_triplets: + return triplets + +def update(embeddings: np.ndarray, triplets: list[Triplet], alpha: float, lr: float) -> int: + """ Update embeddings using the t-STE gradient for each triplet. """ + violations_fixed: int = 0 + for idx, t in enumerate(triplets): + xi = embeddings[t.z_index] + xj = embeddings[t.x_index] + xl = embeddings[t.y_index] + + # Squared distances + dij = np.sum((xi - xj) ** 2) + dil = np.sum((xi - xl) ** 2) + + # Student-t similarities + sij = (1 + dij / alpha) ** (-(alpha + 1) / 2) + sil = (1 + dil / alpha) ** (-(alpha + 1) / 2) + pijl = sij / (sij + sil) + + # Gradients (see t-STE paper) + grad_coeff = (alpha + 1) / alpha + grad_xi = grad_coeff * ( + (1 - pijl) * (xj - xi) / (1 + dij / alpha) + - (1 - pijl) * (xl - xi) / (1 + dil / alpha) + ) + grad_xj = grad_coeff * (1 - pijl) * (xi - xj) / (1 + dij / alpha) + grad_xl = -grad_coeff * (1 - pijl) * (xi - xl) / (1 + dil / alpha) + + # Update embeddings + embeddings[t.z_index] = (xi + lr * grad_xi) + embeddings[t.x_index] = (xj + lr * grad_xj) + embeddings[t.y_index] = (xl + lr * grad_xl) + violations_fixed += 1 + return violations_fixed + +default_prompt = "Considering the nature of X, Y and Z, is X more similar to Z than Y is to Z? " + +class Cluster(Relation): + + def query_triplets(m:MelleaSession, triplets: list[Triplet], prompt: str) -> list[Triplet]: + """Given a triplet comparison query, perform the query using the LLM. """ + + answers = asyncio.gather(*[m.achoice(prompt + f"\nZ: {t.z}\nX: {t.x}\nY: {t.y}\n" , ["X", "Y"]) + for t in triplets]) + + logger.info(f"Queried {len(triplets)} triplets.") + + for idx, (t, a) in enumerate(zip(triplets, answers)): + logger.debug(f"Triplet {idx+1}: Z(anchor): {t.z} X: {t.x} Y: {t.y} result: {a}") + + return [t if a == "X" else t.swap() + for t, a in zip(triplets, answers) ] + + async def acluster(m:MelleaSession, criteria:str, elems:list[str], + model : ClusterMixin, + *, + vote:int=3, + positional:bool=True, + shuffle:bool=True, + # + ndims: int = 2, + lr: float = 0.020, + max_iterations: int = 100, + tolerance: float = 1e-4, + alpha: float | None = None, + num_triplets: int | None = None, + triplets_per_item: int | None = None, + repeat_x: int = 3, + # + **kwargs): + """ + Generate Triplet Embeddings of the given strings, and run clustering + + Args: + criteria: triplet comparison criteria + elems: list of strings to cluster + model: an instance of sklearn.base.ClusterMixin, such as sklearn.cluster.KMEANS, sklearn.cluster.AgglomerativeClustering + + vote: an odd integer specifying the number of queries to make. The final result is a majority vote over the results. 
Since the LLM answers "yes"/"no", by default it counts "yes". If it is even, we add 1 to make it an odd number. + positional: Permute the order of presenting x and y. This mitigates the positional bias. + shuffle: It shuffles the variation of queries (symmetric/positional variations). + This helps when you are making multiple queries with a small vote count (less than 2*2=4 variations). + For example, when shuffle = False and vote = 1, the query always contains the original x y in the x y order. + + ndims: number of dimensions for embeddings + lr: weight to give each triplet when updating embeddings + max_iterations: number of times to use LLM triplets to update embeddings + tolerance: hyperparamater; ??? + alpha: hyperparameter; ??? + num_triplets: the number of triplets to generate. + triplets_per_item: number of triplets to sample per item (will result in len(items) * triplets_per_item triplets) + repeat_x: how many times we reuse the same z, x for sampling y. + verbose: boolean to determine whether or not to provide verbose output + clustering_method: clustering method + + Returns: + Dictionary representing each label in items to its associated coordinate + + Either triplets_per_item or num_triplets must be specified. + triplets_per_item and num_triplets are mutually exclusive (cannot be specified at the same time). + """ + + if verbose: + logger.setLevel(logging.DEBUG) + else: + logger.setLevel(logging.INFO) + + start_time = datetime.now() + # Set alpha default based on ndims if not provided + if alpha is None: + alpha = ndims - 1 + if criteria is None: + criteria = default_criteria + + N: int = len(items) + + logger.info(f"Starting triplet embedding with N={N} items...") + logger.info(f"Algorithm parameters:") + logger.info(f" Embedding dimensions (r): {ndims}") + logger.info(f" Learning rate: {lr}") + logger.info(f" Max iterations: {max_iterations}") + logger.info(f" Triplets per item: {triplets_per_item}") + logger.info(f" Reuse the same z and x for: {repeat_x} times") + logger.info(f" Tolerance: {tolerance}") + logger.info(f" Alpha (DoF): {alpha}") + + # Initialize random embeddings + embeddings = np.random.normal(0, 0.1, (len(items), ndims)) + + # Show initial positions + logger.debug(f"Generated initial random embeddings in {ndims}D space") + + # Generate LLM triplets ONCE + triplets = sample_triplets(items, + triplets_per_item=triplets_per_item, + num_triplets=num_triplets, + repeat_x=repeat_x,) + logger.debug(f"Using {len(triplets)} LLM-judged triplets for all iterations") + + # swap X/Y of triplets using LLM. 
Now X is always closer to anchor Z than Y is to anchor Z + triplets = m.query_triplets(triplets, criteria) + + # Iterative improvement + stat = { + "violations_fixed": 0, + "convergence_ratio": 0.0, + } + for iteration in tqdm(range(max_iterations), desc="updating the embedding (outer loop)", position=0, postfix=stat): + # Use the same triplets every iteration + violations_fixed: int = update(embeddings, triplets, alpha, lr) + convergence_ratio: float = violations_fixed / len(triplets) if len(triplets) > 0 else 0 + + stat["violations_fixed"] = violations_fixed + stat["convergence_ratio"] = convergence_ratio + + if convergence_ratio < tolerance: + logger.debug(f"Converged early at iteration {iteration + 1} (ratio < {tolerance})") + break + + elapsed_time = datetime.now() - start_time + formatted = str(elapsed_time).split('.')[0] + logger.debug(f"Elapsed time: {formatted}") + + return model.fit_predict(embeddings) From ae308bcffc0092298a7a75b14eb480fd1c9efcfa Mon Sep 17 00:00:00 2001 From: Masataro Asai Date: Wed, 12 Nov 2025 13:13:04 -0500 Subject: [PATCH 24/40] added cluster from acluster --- mellea_contribs/va/__init__.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mellea_contribs/va/__init__.py b/mellea_contribs/va/__init__.py index e6829f1..8b74df7 100644 --- a/mellea_contribs/va/__init__.py +++ b/mellea_contribs/va/__init__.py @@ -667,3 +667,6 @@ async def acluster(m:MelleaSession, criteria:str, elems:list[str], logger.debug(f"Elapsed time: {formatted}") return model.fit_predict(embeddings) + + +Cluster.cluster = sync_wrapper(Sort.acluster) From 5700e053557f5aaa59d03b261c632c3acec2f9f3 Mon Sep 17 00:00:00 2001 From: Masataro Asai Date: Wed, 12 Nov 2025 13:44:07 -0500 Subject: [PATCH 25/40] answer choice with the index --- mellea_contribs/va/__init__.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/mellea_contribs/va/__init__.py b/mellea_contribs/va/__init__.py index 8b74df7..b89c4a7 100644 --- a/mellea_contribs/va/__init__.py +++ b/mellea_contribs/va/__init__.py @@ -43,14 +43,16 @@ async def abool(m:MelleaSession, prompt:str, **kwargs) -> bool: async def achoice(self:MelleaSession, prompt:str, choices:list[str], **kwargs) -> str: class Choice(BaseModel): - answer : Literal[choices] + answer : Literal[*[ str(i) for i in range(len(choices))]] - output = await self.ainstruct(f"{prompt} Respond with one of the following answers: " + ",".join([f"'{c}'" for c in choices]) + ".", + output = await self.ainstruct(f"{prompt}\n" + + f"Answer the index (0-{len(choices)-1}) of one of the following choices: \n" + + "\n".join([f"index {i}: {c}" for i, c in enumerate(choices)]), format=Choice, **kwargs) choice = Choice.model_validate_json(output.value) - return choice.answer + return int(choice.answer) pass From 42708fe7a90676748b6f6defc625b710b062024c Mon Sep 17 00:00:00 2001 From: Masataro Asai Date: Wed, 12 Nov 2025 14:07:36 -0500 Subject: [PATCH 26/40] choice supports voting and shuffling --- mellea_contribs/va/__init__.py | 36 ++++++++++++++++++++++++++-------- 1 file changed, 28 insertions(+), 8 deletions(-) diff --git a/mellea_contribs/va/__init__.py b/mellea_contribs/va/__init__.py index b89c4a7..2b42c6b 100644 --- a/mellea_contribs/va/__init__.py +++ b/mellea_contribs/va/__init__.py @@ -40,19 +40,39 @@ async def abool(m:MelleaSession, prompt:str, **kwargs) -> bool: return yesno.answer == "yes" - async def achoice(self:MelleaSession, prompt:str, choices:list[str], **kwargs) -> str: + async def achoice(self:MelleaSession, prompt:str, choices:list[str], *, 
vote:int=3, positional:bool=True, **kwargs) -> str: + # note: constraint decoding does not respect pydantic.conint + L = len(choices) class Choice(BaseModel): - answer : Literal[*[ str(i) for i in range(len(choices))]] + answer : Literal[*[ str(i) for i in range(L)]] + + async def choose(choices:list[str]) -> str: + output = await self.ainstruct(f"{prompt}\n" + + f"Answer the index (0-{L-1}) of one of the following choices: \n" + + "\n".join([f"index {i}: {c}" for i, c in enumerate(_choices)]), + format=Choice, **kwargs) + index = int(Choice.model_validate_json(output.value)) + return choices[index] + + if positional: + # enumerate random permutations while avoiding duplicaes + shuffled = set() + while len(shuffled) < vote: + _choices = choices.copy() + random.shuffle(_choices) + shuffled.add(tuple(choices)) + inputs = list(shuffled) + else: + inputs = [ choices for _ in range(vote) ] + + tasks = [choose(_choices) for _choices in inputs] - output = await self.ainstruct(f"{prompt}\n" + - f"Answer the index (0-{len(choices)-1}) of one of the following choices: \n" + - "\n".join([f"index {i}: {c}" for i, c in enumerate(choices)]), - format=Choice, **kwargs) + choices = asyncio.gather(*tasks) - choice = Choice.model_validate_json(output.value) + counter = Counter(choices) - return int(choice.answer) + return counter.most_common(1)[0][0] pass From 4ecdb14cca0912af1d70be2e193e62ee6c6faa0d Mon Sep 17 00:00:00 2001 From: Masataro Asai Date: Wed, 12 Nov 2025 14:16:53 -0500 Subject: [PATCH 27/40] Submodular powerup --- mellea_contribs/va/__init__.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/mellea_contribs/va/__init__.py b/mellea_contribs/va/__init__.py index 2b42c6b..dc9b3c0 100644 --- a/mellea_contribs/va/__init__.py +++ b/mellea_contribs/va/__init__.py @@ -692,3 +692,34 @@ async def acluster(m:MelleaSession, criteria:str, elems:list[str], Cluster.cluster = sync_wrapper(Sort.acluster) + + + +class Submodular(Core): + async def agreedy_submodular_maximization(m:MelleaSession, + criteria: str, + elems:list[str], + k:int, + *, + vote:int=3, + positional:bool=True, + **kwargs): + + current = [] + remaining = elems.copy() + + for _ in range(k): + chosen = await m.achoice(f"{criteria}\n" + + "The current set:\n" + + "\n".join(current) + "\n", + remaining, + vote=vote, + positional=positional, + **kwargs) + current.append(chosen) + remaining.remove(chosen) + + return current + + +Submodular.greedy_submodular_maximization = sync_wrapper(Submodular.agreedy_submodular_maximization) From bfbb491e45067f2a7f9e74e58b90bdaacaaad29a Mon Sep 17 00:00:00 2001 From: Masataro Asai Date: Wed, 12 Nov 2025 15:17:18 -0500 Subject: [PATCH 28/40] renamed Submodular -> Subset --- mellea_contribs/va/__init__.py | 55 ++++++++++++++++++++++++++-------- 1 file changed, 43 insertions(+), 12 deletions(-) diff --git a/mellea_contribs/va/__init__.py b/mellea_contribs/va/__init__.py index dc9b3c0..a06afd9 100644 --- a/mellea_contribs/va/__init__.py +++ b/mellea_contribs/va/__init__.py @@ -695,22 +695,53 @@ async def acluster(m:MelleaSession, criteria:str, elems:list[str], -class Submodular(Core): - async def agreedy_submodular_maximization(m:MelleaSession, - criteria: str, - elems:list[str], - k:int, - *, - vote:int=3, - positional:bool=True, - **kwargs): +class Subset(Core): + async def asubset(m:MelleaSession, + description:str, + criteria: str, + elems:list[str], + k:int, + *, + vote:int=3, + positional:bool=True, + **kwargs): + """ + Greedily select a k-elements subset from elems, 
maximizing the given criteria. + + Args: + description: A decription of what the current and the output subset is meant to represent. + criteria: A decription of the desired property of the returned subset. + elems: The universe to select the subset from. + k: The number of elements to select from elems. + vote: When >=1, it samples multiple selections in each turn, and perform a majority voting. + positional: Shuffle the order to present the elements to the LLM in order to mitigate the positional bias. + + The criteria is assumed to be contain a modular or submodular aspect. + + Example 1: + + description = "We are building a team of culturally diverse members." + + criteria = "Maximize the cultural diversity among the members." + + Example 2: + + description = ("We need set of past legal cases that helps defending our case. " + "In our case, the defandant has ..." + "We want to see a variety of cases that are relevant to ours but" + "are also different from each other.") + + criteria = "Minimize the ovelap with the documents in the current set while staying relevant to our case." + """ current = [] remaining = elems.copy() for _ in range(k): - chosen = await m.achoice(f"{criteria}\n" + - "The current set:\n" + + chosen = await m.achoice(f"{description}\n" + "Choose the best element to add to the current set following the criteria:\n" + f"Criteria: {criteria}\n" + + "Current set:\n" + "\n".join(current) + "\n", remaining, vote=vote, @@ -722,4 +753,4 @@ async def agreedy_submodular_maximization(m:MelleaSession, return current -Submodular.greedy_submodular_maximization = sync_wrapper(Submodular.agreedy_submodular_maximization) +Subset.subset = sync_wrapper(Subset.asubset) From 5adf2292aafb5935c33cfc5ef8a041b09502dfdf Mon Sep 17 00:00:00 2001 From: Masataro Asai Date: Wed, 12 Nov 2025 15:17:51 -0500 Subject: [PATCH 29/40] edit --- mellea_contribs/va/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mellea_contribs/va/__init__.py b/mellea_contribs/va/__init__.py index a06afd9..2df1585 100644 --- a/mellea_contribs/va/__init__.py +++ b/mellea_contribs/va/__init__.py @@ -718,12 +718,14 @@ async def asubset(m:MelleaSession, The criteria is assumed to be contain a modular or submodular aspect. + Example 1: description = "We are building a team of culturally diverse members." criteria = "Maximize the cultural diversity among the members." + Example 2: description = ("We need set of past legal cases that helps defending our case. 
" From 551ba26e9ac977516d0375cbf83a874918b122b8 Mon Sep 17 00:00:00 2001 From: Masataro Asai Date: Wed, 12 Nov 2025 15:22:46 -0500 Subject: [PATCH 30/40] refactor: split into files --- mellea_contribs/va/__init__.py | 760 +-------------------------------- mellea_contribs/va/cluster.py | 285 +++++++++++++ mellea_contribs/va/core.py | 67 +++ mellea_contribs/va/relation.py | 239 +++++++++++ mellea_contribs/va/sequence.py | 139 ++++++ mellea_contribs/va/subset.py | 76 ++++ mellea_contribs/va/util.py | 12 + 7 files changed, 823 insertions(+), 755 deletions(-) create mode 100644 mellea_contribs/va/cluster.py create mode 100644 mellea_contribs/va/core.py create mode 100644 mellea_contribs/va/relation.py create mode 100644 mellea_contribs/va/sequence.py create mode 100644 mellea_contribs/va/subset.py create mode 100644 mellea_contribs/va/util.py diff --git a/mellea_contribs/va/__init__.py b/mellea_contribs/va/__init__.py index 2df1585..3a37459 100644 --- a/mellea_contribs/va/__init__.py +++ b/mellea_contribs/va/__init__.py @@ -1,758 +1,8 @@ -import random -import functools -import itertools -import asyncio -import dataclasses -from mellea import MelleaSession -from mellea.helpers.fancy_logger import FancyLogger -from mellea.helpers.event_loop_helper import _run_async_in_thread -from pydantic import BaseModel +from .core import Core +from .relation import Relation +from .sequence import Sequence +from .cluster import Cluster +from .subset import Subset -from typing import Literal - -import numpy as np - -from sklearn.base import ClusterMixin - -logger = FancyLogger.get_logger() - -def sync_wrapper(async_fn): - """Wrap an async function so it can be called synchronously.""" - @functools.wraps(async_fn) - def wrapper(*args, **kwargs): - return _run_async_in_thread(async_fn(*args, **kwargs)) - return wrapper - - -class YesNo(BaseModel): - answer : Literal["yes","no"] - -class Core: - - async def abool(m:MelleaSession, prompt:str, **kwargs) -> bool: - - output = await m.ainstruct(f"{prompt} Answer yes or no.", - format=YesNo, **kwargs) - - yesno = YesNo.model_validate_json(output.value) - - return yesno.answer == "yes" - - async def achoice(self:MelleaSession, prompt:str, choices:list[str], *, vote:int=3, positional:bool=True, **kwargs) -> str: - - # note: constraint decoding does not respect pydantic.conint - L = len(choices) - class Choice(BaseModel): - answer : Literal[*[ str(i) for i in range(L)]] - - async def choose(choices:list[str]) -> str: - output = await self.ainstruct(f"{prompt}\n" + - f"Answer the index (0-{L-1}) of one of the following choices: \n" + - "\n".join([f"index {i}: {c}" for i, c in enumerate(_choices)]), - format=Choice, **kwargs) - index = int(Choice.model_validate_json(output.value)) - return choices[index] - - if positional: - # enumerate random permutations while avoiding duplicaes - shuffled = set() - while len(shuffled) < vote: - _choices = choices.copy() - random.shuffle(_choices) - shuffled.add(tuple(choices)) - inputs = list(shuffled) - else: - inputs = [ choices for _ in range(vote) ] - - tasks = [choose(_choices) for _choices in inputs] - - choices = asyncio.gather(*tasks) - - counter = Counter(choices) - - return counter.most_common(1)[0][0] - - pass - -Core.bool = sync_wrapper(Core.abool) -Core.choice = sync_wrapper(Core.achoice) - - -class Relation(Core): - - async def abinary(m:MelleaSession, criteria:str, x:str, y:str, *, - vote:int=3, - symmetric:bool=False, - asymmetric:bool=False, - reflexive:bool=False, - irreflexive:bool=False, - positional:bool=True, - 
shuffle:bool=True, **kwargs) -> bool: - """Evaluates a query that evaluates a binary relation. - - Args: - criteria: A natural language statement on variables X and Y. LLM decides if X and Y satisfy this critria, and this function returns yes if so. - x: the first element - y: the second element - vote: an odd integer specifying the number of queries to make. The final result is a majority vote over the results. Since the LLM answers "yes"/"no", by default it counts "yes". If it is even, we add 1 to make it an odd number. - symmetric: Declares the relation to be symmetric. Half of the queries swap x and y. - asymmetric: Declares the relation to be asymmetric. Half of the queries swap x and y, and asks if they violate the criteria. This mitigates LLM's psycophancy bias toward answering "yes". - reflexive: Declares the relation to be reflexive, i.e., if x == y, returns True immediately. - irreflexive: Declares the relation to be irreflexive, i.e., if x == y, returns False immediately. - positional: Permute the order of presenting x and y. This mitigates the positional bias. - shuffle: It shuffles the variation of queries (symmetric/positional variations). - This helps when you are making multiple queries with a small vote count (less than 2*2=4 variations). - For example, when shuffle = False and vote = 1, the query always contains the original x y in the x y order. - Returns: - bool. - """ - - assert not (symmetric and asymmetric), "symmetric and asymmetric flags are mutually exclusive" - - if x == y: - if reflexive: - return True - if irreflexive: - return False - - if vote % 2 == 0: - logger.warning( - "the specified number of votes in a majority vote is even, making ties possible. Increasing the value by one to avoid this." - ) - vote += 1 - - if symmetric: - args = [(x,y),(y,x)] - target = [True,True] - elif asymmetric: - args = [(x,y),(y,x)] - target = [True,False] - else: - args = [(x,y)] - target = [True] - - prompts = [] - for (x, y), t in zip(args, target): - prompts.append((f"Do X and Y satisfy the following criteria? \nCriteria: {criteria}\nX:{x}\nY:{y}", t)) - if positional: - prompts.append((f"Do X and Y satisfy the following criteria? \nCriteria: {criteria}\nY:{y}\nX:{x}", t)) - - if shuffle: - random.shuffle(prompts) - - tasks = [ - m.abool(p) - for i, (p, t) in zip(range(vote),itertools.cycle(prompts)) - ] - - answers = asyncio.gather(*tasks) - - answers = [ t == a for (p, t), a in zip(itertools.cycle(prompts), answers) ] - - return answers.count(True) >= (vote // 2) + 1 - - async def aternary(m:MelleaSession, criteria:str, x:str, y:str, z:str, *, - vote:int=3, - symmetric:bool=False, - asymmetric:bool=False, - positional:bool=True, - shuffle:bool=True, - **kwargs) -> bool: - """Evaluates a query that evaluates a ternary relation. - - Args: - criteria: A natural language statement on variables X and Y. LLM decides if X and Y satisfy this critria, and this function returns yes if so. - x: the first element - y: the second element - z: the third element - vote: an odd integer specifying the number of queries to make. The final result is a majority vote over the results. Since the LLM answers "yes"/"no", by default it counts "yes". If it is even, we add 1 to make it an odd number. - symmetric: Declares the relation to be symmetric wrto x and y. Half of the queries swap x and y. - asymmetric: Declares the relation to be asymmetric wrto x and y. Half of the queries swap x and y, and asks if they violate the criteria. This mitigates LLM's psycophancy bias toward answering "yes". 
- positional: The queries permutes the order of presenting x, y, z. This mitigates the positional bias. - shuffle: It shuffles the variation of queries (symmetric/positional variations). - This helps when you are making multiple queries with a small vote count (less than 2*2=4 variations). - For example, when shuffle = False and vote = 1, the query always contains the original x y in the x y order. - Returns: - bool. - """ - - assert not (symmetric and asymmetric), "symmetric and asymmetric flags are mutually exclusive" - - if vote % 2 == 0: - logger.warning( - "the specified number of votes in a majority vote is even, making ties possible. Increasing the value by one to avoid this." - ) - vote += 1 - - if symmetric: - args = [(x,y,z),(y,x,z)] - target = [True,True] - elif asymmetric: - args = [(x,y,z),(y,x,z)] - target = [True,False] - else: - args = [(x,y,z)] - target = [True] - - prompts = [] - for (x, y, z), t in zip(args, target): - parts = [f"X:{x}", f"Y:{y}", f"Z:{z}"] - if positional: - for _parts in itertools.permutations(parts): - prompts.append(("\n".join([f"Do X, Y and Z satisfy the following criteria?", f"Criteria: {criteria}", *_parts]), t)) - else: - prompts.append(("\n".join([f"Do X, Y and Z satisfy the following criteria?", f"Criteria: {criteria}", *parts]), t)) - - if shuffle: - random.shuffle(prompts) - - tasks = [ - m.abool(p) - for i, (p, t) in zip(range(vote),itertools.cycle(prompts)) - ] - - answers = asyncio.gather(*tasks) - - answers = [ t == a for (p, t), a in zip(itertools.cycle(prompts), answers) ] - - return answers.count(True) >= (vote // 2) + 1 - - async def agt(m:MelleaSession, criteria:str, x:str, y:str, *, - vote:int=3, - positional:bool=True, - shuffle:bool=True, **kwargs) -> bool: - """Evaluates a query that evaluates a "greater-than" relation. - - Args: - criteria: A natural language statement on variables X and Y. LLM decides if X and Y satisfy this critria, and this function returns yes if so. - x: the first element - y: the second element - vote: an odd integer specifying the number of queries to make. The final result is a majority vote over the results. Since the LLM answers "yes"/"no", by default it counts "yes". If it is even, we add 1 to make it an odd number. - positional: Permute the order of presenting x and y. This mitigates the positional bias. - shuffle: It shuffles the variation of queries (symmetric/positional variations). - This helps when you are making multiple queries with a small vote count (less than 2*2=4 variations). - For example, when shuffle = False and vote = 1, the query always contains the original x y in the x y order. - Returns: - bool. - """ - return await m.abinary(criteria, x, y, - vote=vote, - symmetric=False, - asymmetric=True, - reflexive=False, - irreflexive=True, - shuffle=shuffle, **kwargs) - - async def age(m:MelleaSession, criteria:str, x:str, y:str, *, - vote:int=3, - positional:bool=True, - shuffle:bool=True, **kwargs) -> bool: - """Evaluates a query that evaluates a "greater-than-equal" relation. - - Args: - criteria: A natural language statement on variables X and Y. LLM decides if X and Y satisfy this critria, and this function returns yes if so. - x: the first element - y: the second element - vote: an odd integer specifying the number of queries to make. The final result is a majority vote over the results. Since the LLM answers "yes"/"no", by default it counts "yes". If it is even, we add 1 to make it an odd number. - positional: Permute the order of presenting x and y. 
This mitigates the positional bias. - shuffle: It shuffles the variation of queries (symmetric/positional variations). - This helps when you are making multiple queries with a small vote count (less than 2*2=4 variations). - For example, when shuffle = False and vote = 1, the query always contains the original x y in the x y order. - Returns: - bool. - """ - return await m.abinary(criteria, x, y, - vote=vote, - symmetric=False, - asymmetric=True, - reflexive=True, - irreflexive=False, - shuffle=shuffle, **kwargs) - - async def aeq(m:MelleaSession, criteria:str, x:str, y:str, *, - vote:int=3, - positional:bool=True, - shuffle:bool=True, **kwargs) -> bool: - """Evaluates a query that evaluates an equivalence relation. - - Args: - criteria: A natural language statement on variables X and Y. LLM decides if X and Y satisfy this critria, and this function returns yes if so. - x: the first element - y: the second element - vote: an odd integer specifying the number of queries to make. The final result is a majority vote over the results. Since the LLM answers "yes"/"no", by default it counts "yes". If it is even, we add 1 to make it an odd number. - positional: Permute the order of presenting x and y. This mitigates the positional bias. - shuffle: It shuffles the variation of queries (symmetric/positional variations). - This helps when you are making multiple queries with a small vote count (less than 2*2=4 variations). - For example, when shuffle = False and vote = 1, the query always contains the original x y in the x y order. - Returns: - bool. - """ - return await m.abinary(criteria, x, y, - vote=vote, - symmetric=True, - asymmetric=False, - reflexive=True, - irreflexive=False, - shuffle=shuffle, **kwargs) - - -Relation.binary = sync_wrapper(Relation.abinary) -Relation.ternary = sync_wrapper(Relation.aternary) -Relation.gt = sync_wrapper(Relation.agt) -Relation.ge = sync_wrapper(Relation.age) -Relation.eq = sync_wrapper(Relation.aeq) - - -async def async_merge_sort(lst:list[str], acmp): - if len(lst) <= 1: - return lst - mid = len(lst) // 2 - left = await async_merge_sort(lst[:mid], acmp) - right = await async_merge_sort(lst[mid:], acmp) - return await async_merge(left, right, acmp) - -async def async_merge(left:list[str], right:list[str], acmp): - result = [] - while left and right: - if await acmp(left[0], right[0]): - result.append(left.pop(0)) - else: - result.append(right.pop(0)) - return result + left + right - -async def async_max(lst:list[str], acmp): - if len(lst) <= 1: - return lst[0] - mid = len(lst) // 2 - left = await async_max(lst[:mid], acmp) - right = await async_max(lst[mid:], acmp) - if await acmp(left, right): - return left - else: - return right - -async def async_mom(seq:list[str], acmp, asort, block_size=5): - """ - Median of medians algorithm for finding an approximate median. Worst-case runtime O(n) - """ - - async def median_fixed(seq): - return await asort(seq)[len(seq)//2] - - if len(seq) <= block_size: - return await median_fixed(seq) - - # Step 1: Divide into groups of block_size - groups = itertools.batched(seq, block_size) - - # Step 2: Find median of each group - medians = asyncio.gather(*[median_fixed(g) for g in groups]) - - # Step 3: Recursively find the pivot - return await async_mom(medians, acmp, asort, block_size=block_size) - -async def async_quickselect(seq:list[str], k, acmp, asort, block_size=5): - """ - Quickselect algorithm that uses median-of-medians for pivot selection. 
Worst-case runtime O(n^2) - """ - - pivot = await async_mom(medians, acmp, asort, block_size=block_size) - - # Step 4: Partition - lows, highs = [], [] - for x in seq: - if await acmp(x, pivot): - lows.append(x) - else: - highs.append(x) - - # Step block_size: Recurse - if k < len(lows): - return await async_quickselect(lows, k, acmp, asort, block_size=block_size) - elif k == len(lows): - return pivot - else: - return await async_quickselect(highs, k - len(lows), acmp, asort, block_size=block_size) - - -class Sequence(Relation): - - async def asort(m:MelleaSession, criteria:str, elems:list[str], *, - vote:int=3, - positional:bool=True, - shuffle:bool=True, **kwargs) -> list[str]: - - async def acmp(x, y): - return await m.agt(criteria, x, y, vote=vote, positional=positional, shuffle=shuffle, **kwargs) - - return async_merge_sort(elems, acmp) - - async def amax(m:MelleaSession, criteria:str, elems:list[str], *, - vote:int=3, - positional:bool=True, - shuffle:bool=True, **kwargs) -> str: - - async def acmp(x, y): - return await m.agt(criteria, x, y, vote=vote, positional=positional, shuffle=shuffle, **kwargs) - - return async_max(elems, acmp) - - async def amedian(m:MelleaSession, criteria:str, elems:list[str], *, - exact = False, - vote:int=3, - positional:bool=True, - shuffle:bool=True, - block_size:int=5, - **kwargs) -> str: - """ - If exact = True, use quickselect. - Otherwise, return the approximate median returned by median of medians. - """ - - async def acmp(x, y): - return await m.agt(criteria, x, y, vote=vote, positional=positional, shuffle=shuffle, **kwargs) - - async def asort(elems:list[str]): - return await m.asort(criteria, x, y, vote=vote, positional=positional, shuffle=shuffle, **kwargs) - - if exact: - return await async_quickselect(elems, len(elems)//2, acmp, asort, block_size=block_size) - else: - return await async_mom(elems, acmp, asort, block_size=block_size) - - - -Sort.sort = sync_wrapper(Sort.asort) -Sort.max = sync_wrapper(Sort.amax) -Sort.median = sync_wrapper(Sort.amedian) - - - -@dataclasses.dataclass -class Triplet: - z : str # anchor - x : str - y : str - z_index : int # z's index in items - x_index : int # x's index in items - y_index : int # y's index in items - - def swap(self): - "swaps x and y" - return Triplet(self.z, self.y, self.x, self.z_index, self.y_index, self.x_index) - - -def sample_triplets(items: list[str], - triplets_per_item: int | None = None, - num_triplets: int | None = None, - repeat_x: int = 1, - ) -> list[Triplet]: - """Randomly sample a list of triplet comparison queries. - - Args: - items : input - triplets_per_item : how many triplets to generate relative to the number of items. - num_triplets : the number of triplets to generate. - repeat_x : how many times we reuse the same z, x for sampling y. - - Either triplets_per_item or num_triplets must be specified. - triplets_per_item and num_triplets are mutually exclusive (cannot be specified at the same time). - - """ - N = len(items) - - assert (num_triplets is not None) or (triplets_per_item is not None), \ - "Specify either num_triplets and triplets_per_item." - assert (num_triplets is None) or (triplets_per_item is None), \ - "num_triplets and triplets_per_item are mutually exclusive; do not specify both." 
- if num_triplets is None: - assert isinstance(triplets_per_item, int) - logger.info(f"num_triplets = triplets_per_item * len(items) = {triplets_per_item} * {N} = {triplets_per_item * N}") - num_triplets = triplets_per_item * N - - # make sure z covers all elements - assert num_triplets / N >= 1, \ - ("Some items are never used as an anchor z. Increase num_triplets or triplets_per_item: " - f"num_triplets / len(items) = {num_triplets} / {N} = {num_triplets / N}") - - # make sure z covers all elements even if we sample multiple triplets with the same x - if repeat_x > num_triplets / N: - logger.warning(f"Some items are never used as an anchor z because of too large repeat_x. Overriding it with {num_triplets / N}: " - f"repeat_x = {repeat_x} > " - f"num_triplets / len(items) = {num_triplets} / {N} = {num_triplets / N}.") - repeat_x = num_triplets // N - - # switch to the exhaustive mode if num_triplets is large enough - all_triplets = N * (N-1) * (N-2) - logger.info(f"all_triplets = N * (N-1) * (N-2) = {all_triplets}, where N = {N}") - if num_triplets > all_triplets: - logger.warning(f"num_triplets = {num_triplets} is large enough to enumerate all triplets (> {N} * {(N-1)} * {(N-2)} = {all_triplets}). " - f"Switching to the exhaustive mode.") - exhaustive = True - num_triplets = all_triplets - else: - exhaustive = False - - triplets: list[Triplet] = [] - - bar = tqdm(total=num_triplets, desc="sampling triplets") - - if exhaustive: - for i, z in enumerate(items): - for j, x in enumerate(items): - if i == j: - continue - for k, y in enumerate(items): - if k == i or k == j: - continue - triplets.append(Triplet(z, x, y, i, j, k)) - bar.update() - assert len(triplets) == all_triplets - return triplets - - def sample_except(blacklist:set[str]): - while True: - sample = random.choice(items) - if sample not in blacklist: - return sample - - for z in cycle(items): - x = sample_except({z}) - for _ in range(repeat_x): - y = sample_except({z,x}) - triplets.append(Triplet(z, x, y, items.index(z), items.index(x), items.index(y))) - bar.update() - if len(triplets) >= num_triplets: - return triplets - -def update(embeddings: np.ndarray, triplets: list[Triplet], alpha: float, lr: float) -> int: - """ Update embeddings using the t-STE gradient for each triplet. """ - violations_fixed: int = 0 - for idx, t in enumerate(triplets): - xi = embeddings[t.z_index] - xj = embeddings[t.x_index] - xl = embeddings[t.y_index] - - # Squared distances - dij = np.sum((xi - xj) ** 2) - dil = np.sum((xi - xl) ** 2) - - # Student-t similarities - sij = (1 + dij / alpha) ** (-(alpha + 1) / 2) - sil = (1 + dil / alpha) ** (-(alpha + 1) / 2) - pijl = sij / (sij + sil) - - # Gradients (see t-STE paper) - grad_coeff = (alpha + 1) / alpha - grad_xi = grad_coeff * ( - (1 - pijl) * (xj - xi) / (1 + dij / alpha) - - (1 - pijl) * (xl - xi) / (1 + dil / alpha) - ) - grad_xj = grad_coeff * (1 - pijl) * (xi - xj) / (1 + dij / alpha) - grad_xl = -grad_coeff * (1 - pijl) * (xi - xl) / (1 + dil / alpha) - - # Update embeddings - embeddings[t.z_index] = (xi + lr * grad_xi) - embeddings[t.x_index] = (xj + lr * grad_xj) - embeddings[t.y_index] = (xl + lr * grad_xl) - violations_fixed += 1 - return violations_fixed - -default_prompt = "Considering the nature of X, Y and Z, is X more similar to Z than Y is to Z? " - -class Cluster(Relation): - - def query_triplets(m:MelleaSession, triplets: list[Triplet], prompt: str) -> list[Triplet]: - """Given a triplet comparison query, perform the query using the LLM. 
""" - - answers = asyncio.gather(*[m.achoice(prompt + f"\nZ: {t.z}\nX: {t.x}\nY: {t.y}\n" , ["X", "Y"]) - for t in triplets]) - - logger.info(f"Queried {len(triplets)} triplets.") - - for idx, (t, a) in enumerate(zip(triplets, answers)): - logger.debug(f"Triplet {idx+1}: Z(anchor): {t.z} X: {t.x} Y: {t.y} result: {a}") - - return [t if a == "X" else t.swap() - for t, a in zip(triplets, answers) ] - - async def acluster(m:MelleaSession, criteria:str, elems:list[str], - model : ClusterMixin, - *, - vote:int=3, - positional:bool=True, - shuffle:bool=True, - # - ndims: int = 2, - lr: float = 0.020, - max_iterations: int = 100, - tolerance: float = 1e-4, - alpha: float | None = None, - num_triplets: int | None = None, - triplets_per_item: int | None = None, - repeat_x: int = 3, - # - **kwargs): - """ - Generate Triplet Embeddings of the given strings, and run clustering - - Args: - criteria: triplet comparison criteria - elems: list of strings to cluster - model: an instance of sklearn.base.ClusterMixin, such as sklearn.cluster.KMEANS, sklearn.cluster.AgglomerativeClustering - - vote: an odd integer specifying the number of queries to make. The final result is a majority vote over the results. Since the LLM answers "yes"/"no", by default it counts "yes". If it is even, we add 1 to make it an odd number. - positional: Permute the order of presenting x and y. This mitigates the positional bias. - shuffle: It shuffles the variation of queries (symmetric/positional variations). - This helps when you are making multiple queries with a small vote count (less than 2*2=4 variations). - For example, when shuffle = False and vote = 1, the query always contains the original x y in the x y order. - - ndims: number of dimensions for embeddings - lr: weight to give each triplet when updating embeddings - max_iterations: number of times to use LLM triplets to update embeddings - tolerance: hyperparamater; ??? - alpha: hyperparameter; ??? - num_triplets: the number of triplets to generate. - triplets_per_item: number of triplets to sample per item (will result in len(items) * triplets_per_item triplets) - repeat_x: how many times we reuse the same z, x for sampling y. - verbose: boolean to determine whether or not to provide verbose output - clustering_method: clustering method - - Returns: - Dictionary representing each label in items to its associated coordinate - - Either triplets_per_item or num_triplets must be specified. - triplets_per_item and num_triplets are mutually exclusive (cannot be specified at the same time). 
- """ - - if verbose: - logger.setLevel(logging.DEBUG) - else: - logger.setLevel(logging.INFO) - - start_time = datetime.now() - # Set alpha default based on ndims if not provided - if alpha is None: - alpha = ndims - 1 - if criteria is None: - criteria = default_criteria - - N: int = len(items) - - logger.info(f"Starting triplet embedding with N={N} items...") - logger.info(f"Algorithm parameters:") - logger.info(f" Embedding dimensions (r): {ndims}") - logger.info(f" Learning rate: {lr}") - logger.info(f" Max iterations: {max_iterations}") - logger.info(f" Triplets per item: {triplets_per_item}") - logger.info(f" Reuse the same z and x for: {repeat_x} times") - logger.info(f" Tolerance: {tolerance}") - logger.info(f" Alpha (DoF): {alpha}") - - # Initialize random embeddings - embeddings = np.random.normal(0, 0.1, (len(items), ndims)) - - # Show initial positions - logger.debug(f"Generated initial random embeddings in {ndims}D space") - - # Generate LLM triplets ONCE - triplets = sample_triplets(items, - triplets_per_item=triplets_per_item, - num_triplets=num_triplets, - repeat_x=repeat_x,) - logger.debug(f"Using {len(triplets)} LLM-judged triplets for all iterations") - - # swap X/Y of triplets using LLM. Now X is always closer to anchor Z than Y is to anchor Z - triplets = m.query_triplets(triplets, criteria) - - # Iterative improvement - stat = { - "violations_fixed": 0, - "convergence_ratio": 0.0, - } - for iteration in tqdm(range(max_iterations), desc="updating the embedding (outer loop)", position=0, postfix=stat): - # Use the same triplets every iteration - violations_fixed: int = update(embeddings, triplets, alpha, lr) - convergence_ratio: float = violations_fixed / len(triplets) if len(triplets) > 0 else 0 - - stat["violations_fixed"] = violations_fixed - stat["convergence_ratio"] = convergence_ratio - - if convergence_ratio < tolerance: - logger.debug(f"Converged early at iteration {iteration + 1} (ratio < {tolerance})") - break - - elapsed_time = datetime.now() - start_time - formatted = str(elapsed_time).split('.')[0] - logger.debug(f"Elapsed time: {formatted}") - - return model.fit_predict(embeddings) - - -Cluster.cluster = sync_wrapper(Sort.acluster) - - - -class Subset(Core): - async def asubset(m:MelleaSession, - description:str, - criteria: str, - elems:list[str], - k:int, - *, - vote:int=3, - positional:bool=True, - **kwargs): - """ - Greedily select a k-elements subset from elems, maximizing the given criteria. - - Args: - description: A decription of what the current and the output subset is meant to represent. - criteria: A decription of the desired property of the returned subset. - elems: The universe to select the subset from. - k: The number of elements to select from elems. - vote: When >=1, it samples multiple selections in each turn, and perform a majority voting. - positional: Shuffle the order to present the elements to the LLM in order to mitigate the positional bias. - - The criteria is assumed to be contain a modular or submodular aspect. - - - Example 1: - - description = "We are building a team of culturally diverse members." - - criteria = "Maximize the cultural diversity among the members." - - - Example 2: - - description = ("We need set of past legal cases that helps defending our case. " - "In our case, the defandant has ..." - "We want to see a variety of cases that are relevant to ours but" - "are also different from each other.") - - criteria = "Minimize the ovelap with the documents in the current set while staying relevant to our case." 
- """ - - current = [] - remaining = elems.copy() - - for _ in range(k): - chosen = await m.achoice(f"{description}\n" - "Choose the best element to add to the current set following the criteria:\n" - f"Criteria: {criteria}\n" + - "Current set:\n" + - "\n".join(current) + "\n", - remaining, - vote=vote, - positional=positional, - **kwargs) - current.append(chosen) - remaining.remove(chosen) - - return current - - -Subset.subset = sync_wrapper(Subset.asubset) diff --git a/mellea_contribs/va/cluster.py b/mellea_contribs/va/cluster.py new file mode 100644 index 0000000..9f6a7e4 --- /dev/null +++ b/mellea_contribs/va/cluster.py @@ -0,0 +1,285 @@ +import random +import functools +import itertools +import asyncio +from mellea import MelleaSession +from mellea.helpers.fancy_logger import FancyLogger +from mellea.helpers.event_loop_helper import _run_async_in_thread + +from pydantic import BaseModel + +from typing import Literal + +import numpy as np + +from sklearn.base import ClusterMixin + +from .util import sync_wrapper +from .relation import Relation + + + +@dataclasses.dataclass +class Triplet: + z : str # anchor + x : str + y : str + z_index : int # z's index in items + x_index : int # x's index in items + y_index : int # y's index in items + + def swap(self): + "swaps x and y" + return Triplet(self.z, self.y, self.x, self.z_index, self.y_index, self.x_index) + + +def sample_triplets(items: list[str], + triplets_per_item: int | None = None, + num_triplets: int | None = None, + repeat_x: int = 1, + ) -> list[Triplet]: + """Randomly sample a list of triplet comparison queries. + + Args: + items : input + triplets_per_item : how many triplets to generate relative to the number of items. + num_triplets : the number of triplets to generate. + repeat_x : how many times we reuse the same z, x for sampling y. + + Either triplets_per_item or num_triplets must be specified. + triplets_per_item and num_triplets are mutually exclusive (cannot be specified at the same time). + + """ + N = len(items) + + assert (num_triplets is not None) or (triplets_per_item is not None), \ + "Specify either num_triplets and triplets_per_item." + assert (num_triplets is None) or (triplets_per_item is None), \ + "num_triplets and triplets_per_item are mutually exclusive; do not specify both." + if num_triplets is None: + assert isinstance(triplets_per_item, int) + logger.info(f"num_triplets = triplets_per_item * len(items) = {triplets_per_item} * {N} = {triplets_per_item * N}") + num_triplets = triplets_per_item * N + + # make sure z covers all elements + assert num_triplets / N >= 1, \ + ("Some items are never used as an anchor z. Increase num_triplets or triplets_per_item: " + f"num_triplets / len(items) = {num_triplets} / {N} = {num_triplets / N}") + + # make sure z covers all elements even if we sample multiple triplets with the same x + if repeat_x > num_triplets / N: + logger.warning(f"Some items are never used as an anchor z because of too large repeat_x. Overriding it with {num_triplets / N}: " + f"repeat_x = {repeat_x} > " + f"num_triplets / len(items) = {num_triplets} / {N} = {num_triplets / N}.") + repeat_x = num_triplets // N + + # switch to the exhaustive mode if num_triplets is large enough + all_triplets = N * (N-1) * (N-2) + logger.info(f"all_triplets = N * (N-1) * (N-2) = {all_triplets}, where N = {N}") + if num_triplets > all_triplets: + logger.warning(f"num_triplets = {num_triplets} is large enough to enumerate all triplets (> {N} * {(N-1)} * {(N-2)} = {all_triplets}). 
" + f"Switching to the exhaustive mode.") + exhaustive = True + num_triplets = all_triplets + else: + exhaustive = False + + triplets: list[Triplet] = [] + + bar = tqdm(total=num_triplets, desc="sampling triplets") + + if exhaustive: + for i, z in enumerate(items): + for j, x in enumerate(items): + if i == j: + continue + for k, y in enumerate(items): + if k == i or k == j: + continue + triplets.append(Triplet(z, x, y, i, j, k)) + bar.update() + assert len(triplets) == all_triplets + return triplets + + def sample_except(blacklist:set[str]): + while True: + sample = random.choice(items) + if sample not in blacklist: + return sample + + for z in cycle(items): + x = sample_except({z}) + for _ in range(repeat_x): + y = sample_except({z,x}) + triplets.append(Triplet(z, x, y, items.index(z), items.index(x), items.index(y))) + bar.update() + if len(triplets) >= num_triplets: + return triplets + +def update(embeddings: np.ndarray, triplets: list[Triplet], alpha: float, lr: float) -> int: + """ Update embeddings using the t-STE gradient for each triplet. """ + violations_fixed: int = 0 + for idx, t in enumerate(triplets): + xi = embeddings[t.z_index] + xj = embeddings[t.x_index] + xl = embeddings[t.y_index] + + # Squared distances + dij = np.sum((xi - xj) ** 2) + dil = np.sum((xi - xl) ** 2) + + # Student-t similarities + sij = (1 + dij / alpha) ** (-(alpha + 1) / 2) + sil = (1 + dil / alpha) ** (-(alpha + 1) / 2) + pijl = sij / (sij + sil) + + # Gradients (see t-STE paper) + grad_coeff = (alpha + 1) / alpha + grad_xi = grad_coeff * ( + (1 - pijl) * (xj - xi) / (1 + dij / alpha) + - (1 - pijl) * (xl - xi) / (1 + dil / alpha) + ) + grad_xj = grad_coeff * (1 - pijl) * (xi - xj) / (1 + dij / alpha) + grad_xl = -grad_coeff * (1 - pijl) * (xi - xl) / (1 + dil / alpha) + + # Update embeddings + embeddings[t.z_index] = (xi + lr * grad_xi) + embeddings[t.x_index] = (xj + lr * grad_xj) + embeddings[t.y_index] = (xl + lr * grad_xl) + violations_fixed += 1 + return violations_fixed + +default_prompt = "Considering the nature of X, Y and Z, is X more similar to Z than Y is to Z? " + +class Cluster(Relation): + + def query_triplets(m:MelleaSession, triplets: list[Triplet], prompt: str) -> list[Triplet]: + """Given a triplet comparison query, perform the query using the LLM. """ + + answers = asyncio.gather(*[m.achoice(prompt + f"\nZ: {t.z}\nX: {t.x}\nY: {t.y}\n" , ["X", "Y"]) + for t in triplets]) + + logger.info(f"Queried {len(triplets)} triplets.") + + for idx, (t, a) in enumerate(zip(triplets, answers)): + logger.debug(f"Triplet {idx+1}: Z(anchor): {t.z} X: {t.x} Y: {t.y} result: {a}") + + return [t if a == "X" else t.swap() + for t, a in zip(triplets, answers) ] + + async def acluster(m:MelleaSession, criteria:str, elems:list[str], + model : ClusterMixin, + *, + vote:int=3, + positional:bool=True, + shuffle:bool=True, + # + ndims: int = 2, + lr: float = 0.020, + max_iterations: int = 100, + tolerance: float = 1e-4, + alpha: float | None = None, + num_triplets: int | None = None, + triplets_per_item: int | None = None, + repeat_x: int = 3, + # + **kwargs): + """ + Generate Triplet Embeddings of the given strings, and run clustering + + Args: + criteria: triplet comparison criteria + elems: list of strings to cluster + model: an instance of sklearn.base.ClusterMixin, such as sklearn.cluster.KMEANS, sklearn.cluster.AgglomerativeClustering + + vote: an odd integer specifying the number of queries to make. The final result is a majority vote over the results. 
Since the LLM answers "yes"/"no", by default it counts "yes". If it is even, we add 1 to make it an odd number. + positional: Permute the order of presenting x and y. This mitigates the positional bias. + shuffle: It shuffles the variation of queries (symmetric/positional variations). + This helps when you are making multiple queries with a small vote count (less than 2*2=4 variations). + For example, when shuffle = False and vote = 1, the query always contains the original x y in the x y order. + + ndims: number of dimensions for embeddings + lr: weight to give each triplet when updating embeddings + max_iterations: number of times to use LLM triplets to update embeddings + tolerance: hyperparamater; ??? + alpha: hyperparameter; ??? + num_triplets: the number of triplets to generate. + triplets_per_item: number of triplets to sample per item (will result in len(items) * triplets_per_item triplets) + repeat_x: how many times we reuse the same z, x for sampling y. + verbose: boolean to determine whether or not to provide verbose output + clustering_method: clustering method + + Returns: + Dictionary representing each label in items to its associated coordinate + + Either triplets_per_item or num_triplets must be specified. + triplets_per_item and num_triplets are mutually exclusive (cannot be specified at the same time). + """ + + if verbose: + logger.setLevel(logging.DEBUG) + else: + logger.setLevel(logging.INFO) + + start_time = datetime.now() + # Set alpha default based on ndims if not provided + if alpha is None: + alpha = ndims - 1 + if criteria is None: + criteria = default_criteria + + N: int = len(items) + + logger.info(f"Starting triplet embedding with N={N} items...") + logger.info(f"Algorithm parameters:") + logger.info(f" Embedding dimensions (r): {ndims}") + logger.info(f" Learning rate: {lr}") + logger.info(f" Max iterations: {max_iterations}") + logger.info(f" Triplets per item: {triplets_per_item}") + logger.info(f" Reuse the same z and x for: {repeat_x} times") + logger.info(f" Tolerance: {tolerance}") + logger.info(f" Alpha (DoF): {alpha}") + + # Initialize random embeddings + embeddings = np.random.normal(0, 0.1, (len(items), ndims)) + + # Show initial positions + logger.debug(f"Generated initial random embeddings in {ndims}D space") + + # Generate LLM triplets ONCE + triplets = sample_triplets(items, + triplets_per_item=triplets_per_item, + num_triplets=num_triplets, + repeat_x=repeat_x,) + logger.debug(f"Using {len(triplets)} LLM-judged triplets for all iterations") + + # swap X/Y of triplets using LLM. 
Now X is always closer to anchor Z than Y is to anchor Z + triplets = m.query_triplets(triplets, criteria) + + # Iterative improvement + stat = { + "violations_fixed": 0, + "convergence_ratio": 0.0, + } + for iteration in tqdm(range(max_iterations), desc="updating the embedding (outer loop)", position=0, postfix=stat): + # Use the same triplets every iteration + violations_fixed: int = update(embeddings, triplets, alpha, lr) + convergence_ratio: float = violations_fixed / len(triplets) if len(triplets) > 0 else 0 + + stat["violations_fixed"] = violations_fixed + stat["convergence_ratio"] = convergence_ratio + + if convergence_ratio < tolerance: + logger.debug(f"Converged early at iteration {iteration + 1} (ratio < {tolerance})") + break + + elapsed_time = datetime.now() - start_time + formatted = str(elapsed_time).split('.')[0] + logger.debug(f"Elapsed time: {formatted}") + + return model.fit_predict(embeddings) + + +Cluster.cluster = sync_wrapper(Sort.acluster) + + diff --git a/mellea_contribs/va/core.py b/mellea_contribs/va/core.py new file mode 100644 index 0000000..2a99e83 --- /dev/null +++ b/mellea_contribs/va/core.py @@ -0,0 +1,67 @@ +import random +import functools +import itertools +import asyncio +from mellea import MelleaSession +from mellea.helpers.fancy_logger import FancyLogger +from mellea.helpers.event_loop_helper import _run_async_in_thread + +from pydantic import BaseModel + +from typing import Literal + +from .util import sync_wrapper + +class YesNo(BaseModel): + answer : Literal["yes","no"] + +class Core: + + async def abool(m:MelleaSession, prompt:str, **kwargs) -> bool: + + output = await m.ainstruct(f"{prompt} Answer yes or no.", + format=YesNo, **kwargs) + + yesno = YesNo.model_validate_json(output.value) + + return yesno.answer == "yes" + + async def achoice(self:MelleaSession, prompt:str, choices:list[str], *, vote:int=3, positional:bool=True, **kwargs) -> str: + + # note: constraint decoding does not respect pydantic.conint + L = len(choices) + class Choice(BaseModel): + answer : Literal[*[ str(i) for i in range(L)]] + + async def choose(choices:list[str]) -> str: + output = await self.ainstruct(f"{prompt}\n" + + f"Answer the index (0-{L-1}) of one of the following choices: \n" + + "\n".join([f"index {i}: {c}" for i, c in enumerate(_choices)]), + format=Choice, **kwargs) + index = int(Choice.model_validate_json(output.value)) + return choices[index] + + if positional: + # enumerate random permutations while avoiding duplicaes + shuffled = set() + while len(shuffled) < vote: + _choices = choices.copy() + random.shuffle(_choices) + shuffled.add(tuple(choices)) + inputs = list(shuffled) + else: + inputs = [ choices for _ in range(vote) ] + + tasks = [choose(_choices) for _choices in inputs] + + choices = asyncio.gather(*tasks) + + counter = Counter(choices) + + return counter.most_common(1)[0][0] + + pass + +Core.bool = sync_wrapper(Core.abool) +Core.choice = sync_wrapper(Core.achoice) + diff --git a/mellea_contribs/va/relation.py b/mellea_contribs/va/relation.py new file mode 100644 index 0000000..c2f348e --- /dev/null +++ b/mellea_contribs/va/relation.py @@ -0,0 +1,239 @@ +import random +import functools +import itertools +import asyncio +from mellea import MelleaSession +from mellea.helpers.fancy_logger import FancyLogger +from mellea.helpers.event_loop_helper import _run_async_in_thread + +from pydantic import BaseModel + +from typing import Literal + +from .util import sync_wrapper +from .core import Core + +class Relation(Core): + + async def 
abinary(m:MelleaSession, criteria:str, x:str, y:str, *, + vote:int=3, + symmetric:bool=False, + asymmetric:bool=False, + reflexive:bool=False, + irreflexive:bool=False, + positional:bool=True, + shuffle:bool=True, **kwargs) -> bool: + """Evaluates a query that evaluates a binary relation. + + Args: + criteria: A natural language statement on variables X and Y. LLM decides if X and Y satisfy this critria, and this function returns yes if so. + x: the first element + y: the second element + vote: an odd integer specifying the number of queries to make. The final result is a majority vote over the results. Since the LLM answers "yes"/"no", by default it counts "yes". If it is even, we add 1 to make it an odd number. + symmetric: Declares the relation to be symmetric. Half of the queries swap x and y. + asymmetric: Declares the relation to be asymmetric. Half of the queries swap x and y, and asks if they violate the criteria. This mitigates LLM's psycophancy bias toward answering "yes". + reflexive: Declares the relation to be reflexive, i.e., if x == y, returns True immediately. + irreflexive: Declares the relation to be irreflexive, i.e., if x == y, returns False immediately. + positional: Permute the order of presenting x and y. This mitigates the positional bias. + shuffle: It shuffles the variation of queries (symmetric/positional variations). + This helps when you are making multiple queries with a small vote count (less than 2*2=4 variations). + For example, when shuffle = False and vote = 1, the query always contains the original x y in the x y order. + Returns: + bool. + """ + + assert not (symmetric and asymmetric), "symmetric and asymmetric flags are mutually exclusive" + + if x == y: + if reflexive: + return True + if irreflexive: + return False + + if vote % 2 == 0: + logger.warning( + "the specified number of votes in a majority vote is even, making ties possible. Increasing the value by one to avoid this." + ) + vote += 1 + + if symmetric: + args = [(x,y),(y,x)] + target = [True,True] + elif asymmetric: + args = [(x,y),(y,x)] + target = [True,False] + else: + args = [(x,y)] + target = [True] + + prompts = [] + for (x, y), t in zip(args, target): + prompts.append((f"Do X and Y satisfy the following criteria? \nCriteria: {criteria}\nX:{x}\nY:{y}", t)) + if positional: + prompts.append((f"Do X and Y satisfy the following criteria? \nCriteria: {criteria}\nY:{y}\nX:{x}", t)) + + if shuffle: + random.shuffle(prompts) + + tasks = [ + m.abool(p) + for i, (p, t) in zip(range(vote),itertools.cycle(prompts)) + ] + + answers = asyncio.gather(*tasks) + + answers = [ t == a for (p, t), a in zip(itertools.cycle(prompts), answers) ] + + return answers.count(True) >= (vote // 2) + 1 + + async def aternary(m:MelleaSession, criteria:str, x:str, y:str, z:str, *, + vote:int=3, + symmetric:bool=False, + asymmetric:bool=False, + positional:bool=True, + shuffle:bool=True, + **kwargs) -> bool: + """Evaluates a query that evaluates a ternary relation. + + Args: + criteria: A natural language statement on variables X and Y. LLM decides if X and Y satisfy this critria, and this function returns yes if so. + x: the first element + y: the second element + z: the third element + vote: an odd integer specifying the number of queries to make. The final result is a majority vote over the results. Since the LLM answers "yes"/"no", by default it counts "yes". If it is even, we add 1 to make it an odd number. + symmetric: Declares the relation to be symmetric wrto x and y. Half of the queries swap x and y. 
+ asymmetric: Declares the relation to be asymmetric wrto x and y. Half of the queries swap x and y, and asks if they violate the criteria. This mitigates LLM's psycophancy bias toward answering "yes". + positional: The queries permutes the order of presenting x, y, z. This mitigates the positional bias. + shuffle: It shuffles the variation of queries (symmetric/positional variations). + This helps when you are making multiple queries with a small vote count (less than 2*2=4 variations). + For example, when shuffle = False and vote = 1, the query always contains the original x y in the x y order. + Returns: + bool. + """ + + assert not (symmetric and asymmetric), "symmetric and asymmetric flags are mutually exclusive" + + if vote % 2 == 0: + logger.warning( + "the specified number of votes in a majority vote is even, making ties possible. Increasing the value by one to avoid this." + ) + vote += 1 + + if symmetric: + args = [(x,y,z),(y,x,z)] + target = [True,True] + elif asymmetric: + args = [(x,y,z),(y,x,z)] + target = [True,False] + else: + args = [(x,y,z)] + target = [True] + + prompts = [] + for (x, y, z), t in zip(args, target): + parts = [f"X:{x}", f"Y:{y}", f"Z:{z}"] + if positional: + for _parts in itertools.permutations(parts): + prompts.append(("\n".join([f"Do X, Y and Z satisfy the following criteria?", f"Criteria: {criteria}", *_parts]), t)) + else: + prompts.append(("\n".join([f"Do X, Y and Z satisfy the following criteria?", f"Criteria: {criteria}", *parts]), t)) + + if shuffle: + random.shuffle(prompts) + + tasks = [ + m.abool(p) + for i, (p, t) in zip(range(vote),itertools.cycle(prompts)) + ] + + answers = asyncio.gather(*tasks) + + answers = [ t == a for (p, t), a in zip(itertools.cycle(prompts), answers) ] + + return answers.count(True) >= (vote // 2) + 1 + + async def agt(m:MelleaSession, criteria:str, x:str, y:str, *, + vote:int=3, + positional:bool=True, + shuffle:bool=True, **kwargs) -> bool: + """Evaluates a query that evaluates a "greater-than" relation. + + Args: + criteria: A natural language statement on variables X and Y. LLM decides if X and Y satisfy this critria, and this function returns yes if so. + x: the first element + y: the second element + vote: an odd integer specifying the number of queries to make. The final result is a majority vote over the results. Since the LLM answers "yes"/"no", by default it counts "yes". If it is even, we add 1 to make it an odd number. + positional: Permute the order of presenting x and y. This mitigates the positional bias. + shuffle: It shuffles the variation of queries (symmetric/positional variations). + This helps when you are making multiple queries with a small vote count (less than 2*2=4 variations). + For example, when shuffle = False and vote = 1, the query always contains the original x y in the x y order. + Returns: + bool. + """ + return await m.abinary(criteria, x, y, + vote=vote, + symmetric=False, + asymmetric=True, + reflexive=False, + irreflexive=True, + shuffle=shuffle, **kwargs) + + async def age(m:MelleaSession, criteria:str, x:str, y:str, *, + vote:int=3, + positional:bool=True, + shuffle:bool=True, **kwargs) -> bool: + """Evaluates a query that evaluates a "greater-than-equal" relation. + + Args: + criteria: A natural language statement on variables X and Y. LLM decides if X and Y satisfy this critria, and this function returns yes if so. + x: the first element + y: the second element + vote: an odd integer specifying the number of queries to make. 
The final result is a majority vote over the results. Since the LLM answers "yes"/"no", by default it counts "yes". If it is even, we add 1 to make it an odd number. + positional: Permute the order of presenting x and y. This mitigates the positional bias. + shuffle: It shuffles the variation of queries (symmetric/positional variations). + This helps when you are making multiple queries with a small vote count (less than 2*2=4 variations). + For example, when shuffle = False and vote = 1, the query always contains the original x y in the x y order. + Returns: + bool. + """ + return await m.abinary(criteria, x, y, + vote=vote, + symmetric=False, + asymmetric=True, + reflexive=True, + irreflexive=False, + shuffle=shuffle, **kwargs) + + async def aeq(m:MelleaSession, criteria:str, x:str, y:str, *, + vote:int=3, + positional:bool=True, + shuffle:bool=True, **kwargs) -> bool: + """Evaluates a query that evaluates an equivalence relation. + + Args: + criteria: A natural language statement on variables X and Y. LLM decides if X and Y satisfy this critria, and this function returns yes if so. + x: the first element + y: the second element + vote: an odd integer specifying the number of queries to make. The final result is a majority vote over the results. Since the LLM answers "yes"/"no", by default it counts "yes". If it is even, we add 1 to make it an odd number. + positional: Permute the order of presenting x and y. This mitigates the positional bias. + shuffle: It shuffles the variation of queries (symmetric/positional variations). + This helps when you are making multiple queries with a small vote count (less than 2*2=4 variations). + For example, when shuffle = False and vote = 1, the query always contains the original x y in the x y order. + Returns: + bool. 
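+
+        Example (an illustrative sketch, not part of the original code; it assumes a
+        session `m` to which the Relation powerup has been applied, as in the tests):
+
+            # symmetric + reflexive: identical inputs short-circuit to True, and half
+            # of the LLM queries are issued with x and y swapped.
+            same_language = m.eq(
+                "People in country X and country Y speak the same language.",
+                "Spain", "Mexico")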
+        """
+        return await m.abinary(criteria, x, y,
+                               vote=vote,
+                               symmetric=True,
+                               asymmetric=False,
+                               reflexive=True,
+                               irreflexive=False,
+                               shuffle=shuffle, **kwargs)
+
+
+Relation.binary = sync_wrapper(Relation.abinary)
+Relation.ternary = sync_wrapper(Relation.aternary)
+Relation.gt = sync_wrapper(Relation.agt)
+Relation.ge = sync_wrapper(Relation.age)
+Relation.eq = sync_wrapper(Relation.aeq)
+
diff --git a/mellea_contribs/va/sequence.py b/mellea_contribs/va/sequence.py
new file mode 100644
index 0000000..5777957
--- /dev/null
+++ b/mellea_contribs/va/sequence.py
@@ -0,0 +1,139 @@
+import random
+import functools
+import itertools
+import asyncio
+from mellea import MelleaSession
+from mellea.helpers.fancy_logger import FancyLogger
+from mellea.helpers.event_loop_helper import _run_async_in_thread
+
+from pydantic import BaseModel
+
+from typing import Literal
+
+from .util import sync_wrapper
+from .relation import Relation
+
+
+async def async_merge_sort(lst:list[str], acmp):
+    if len(lst) <= 1:
+        return lst
+    mid = len(lst) // 2
+    left = await async_merge_sort(lst[:mid], acmp)
+    right = await async_merge_sort(lst[mid:], acmp)
+    return await async_merge(left, right, acmp)
+
+async def async_merge(left:list[str], right:list[str], acmp):
+    result = []
+    while left and right:
+        if await acmp(left[0], right[0]):
+            result.append(left.pop(0))
+        else:
+            result.append(right.pop(0))
+    return result + left + right
+
+async def async_max(lst:list[str], acmp):
+    if len(lst) <= 1:
+        return lst[0]
+    mid = len(lst) // 2
+    left = await async_max(lst[:mid], acmp)
+    right = await async_max(lst[mid:], acmp)
+    if await acmp(left, right):
+        return left
+    else:
+        return right
+
+async def async_mom(seq:list[str], acmp, asort, block_size=5):
+    """
+    Median of medians algorithm for finding an approximate median. Worst-case runtime O(n)
+    """
+
+    async def median_fixed(seq):
+        return (await asort(seq))[len(seq)//2]
+
+    if len(seq) <= block_size:
+        return await median_fixed(seq)
+
+    # Step 1: Divide into groups of block_size
+    groups = [list(g) for g in itertools.batched(seq, block_size)]
+
+    # Step 2: Find median of each group
+    medians = await asyncio.gather(*[median_fixed(g) for g in groups])
+
+    # Step 3: Recursively find the pivot
+    return await async_mom(medians, acmp, asort, block_size=block_size)
+
+async def async_quickselect(seq:list[str], k, acmp, asort, block_size=5):
+    """
+    Quickselect algorithm that uses median-of-medians for pivot selection. Worst-case runtime O(n)
+    """
+
+    pivot = await async_mom(seq, acmp, asort, block_size=block_size)
+
+    # Step 4: Partition
+    lows, highs = [], []
+    for x in seq:
+        if await acmp(x, pivot):
+            lows.append(x)
+        else:
+            highs.append(x)
+
+    # Step 5: Recurse
+    if k < len(lows):
+        return await async_quickselect(lows, k, acmp, asort, block_size=block_size)
+    elif k == len(lows):
+        return pivot
+    else:
+        return await async_quickselect(highs, k - len(lows), acmp, asort, block_size=block_size)
+
+
+class Sequence(Relation):
+
+    async def asort(m:MelleaSession, criteria:str, elems:list[str], *,
+                    vote:int=3,
+                    positional:bool=True,
+                    shuffle:bool=True, **kwargs) -> list[str]:
+
+        async def acmp(x, y):
+            return await m.agt(criteria, x, y, vote=vote, positional=positional, shuffle=shuffle, **kwargs)
+
+        return await async_merge_sort(elems, acmp)
+
+    async def amax(m:MelleaSession, criteria:str, elems:list[str], *,
+                   vote:int=3,
+                   positional:bool=True,
+                   shuffle:bool=True, **kwargs) -> str:
+
+        async def acmp(x, y):
+            return await m.agt(criteria, x, y, vote=vote, positional=positional, shuffle=shuffle, **kwargs)
+
+        return await async_max(elems, acmp)
+
+    async def amedian(m:MelleaSession, criteria:str, elems:list[str], *,
+                      exact = False,
+                      vote:int=3,
+                      positional:bool=True,
+                      shuffle:bool=True,
+                      block_size:int=5,
+                      **kwargs) -> str:
+        """
+        If exact = True, use quickselect.
+        Otherwise, return the approximate median returned by median of medians.
+        """
+
+        async def acmp(x, y):
+            return await m.agt(criteria, x, y, vote=vote, positional=positional, shuffle=shuffle, **kwargs)
+
+        async def asort(elems:list[str]):
+            return await m.asort(criteria, elems, vote=vote, positional=positional, shuffle=shuffle, **kwargs)
+
+        if exact:
+            return await async_quickselect(elems, len(elems)//2, acmp, asort, block_size=block_size)
+        else:
+            return await async_mom(elems, acmp, asort, block_size=block_size)
+
+
+
+Sort.sort = sync_wrapper(Sort.asort)
+Sort.max = sync_wrapper(Sort.amax)
+Sort.median = sync_wrapper(Sort.amedian)
+
diff --git a/mellea_contribs/va/subset.py b/mellea_contribs/va/subset.py
new file mode 100644
index 0000000..0ef090b
--- /dev/null
+++ b/mellea_contribs/va/subset.py
@@ -0,0 +1,76 @@
+import random
+import functools
+import itertools
+import asyncio
+from mellea import MelleaSession
+from mellea.helpers.fancy_logger import FancyLogger
+from mellea.helpers.event_loop_helper import _run_async_in_thread
+
+from pydantic import BaseModel
+
+from typing import Literal
+
+from .util import sync_wrapper
+from .core import Core
+
+class Subset(Core):
+    async def asubset(m:MelleaSession,
+                      description:str,
+                      criteria: str,
+                      elems:list[str],
+                      k:int,
+                      *,
+                      vote:int=3,
+                      positional:bool=True,
+                      **kwargs):
+        """
+        Greedily select a k-elements subset from elems, maximizing the given criteria.
+
+        Args:
+          description: A description of what the current and the output subset is meant to represent.
+          criteria: A description of the desired property of the returned subset.
+          elems: The universe to select the subset from.
+          k: The number of elements to select from elems.
+          vote: When >=1, it samples multiple selections in each turn, and performs majority voting.
+          positional: Shuffle the order to present the elements to the LLM in order to mitigate the positional bias.
+
+        The criteria is assumed to contain a modular or submodular aspect.
+
+
+        Example 1:
+
+          description = "We are building a team of culturally diverse members."
+
+          criteria = "Maximize the cultural diversity among the members."
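+
+          A call for this example might look like the following sketch (the candidate
+          list is a hypothetical placeholder, not part of the original code):
+
+            team = m.subset(description, criteria,
+                            # hypothetical candidate pool
+                            ["Alice (US)", "Bola (Nigeria)", "Chen (China)", "Dana (US)"],
+                            k=3)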
+ + + Example 2: + + description = ("We need set of past legal cases that helps defending our case. " + "In our case, the defandant has ..." + "We want to see a variety of cases that are relevant to ours but" + "are also different from each other.") + + criteria = "Minimize the ovelap with the documents in the current set while staying relevant to our case." + """ + + current = [] + remaining = elems.copy() + + for _ in range(k): + chosen = await m.achoice(f"{description}\n" + "Choose the best element to add to the current set following the criteria:\n" + f"Criteria: {criteria}\n" + + "Current set:\n" + + "\n".join(current) + "\n", + remaining, + vote=vote, + positional=positional, + **kwargs) + current.append(chosen) + remaining.remove(chosen) + + return current + + +Subset.subset = sync_wrapper(Subset.asubset) diff --git a/mellea_contribs/va/util.py b/mellea_contribs/va/util.py new file mode 100644 index 0000000..3930c6e --- /dev/null +++ b/mellea_contribs/va/util.py @@ -0,0 +1,12 @@ + +import functools +from mellea.helpers.event_loop_helper import _run_async_in_thread + +def sync_wrapper(async_fn): + """Wrap an async function so it can be called synchronously.""" + @functools.wraps(async_fn) + def wrapper(*args, **kwargs): + return _run_async_in_thread(async_fn(*args, **kwargs)) + return wrapper + + From 4edc19a73a47e551949ad2a2d2a94707124a6b49 Mon Sep 17 00:00:00 2001 From: Masataro Asai Date: Wed, 12 Nov 2025 15:45:10 -0500 Subject: [PATCH 31/40] documentations --- mellea_contribs/va/core.py | 14 ++++++++++++++ mellea_contribs/va/relation.py | 5 +++++ mellea_contribs/va/sequence.py | 4 ++++ mellea_contribs/va/subset.py | 5 ++++- 4 files changed, 27 insertions(+), 1 deletion(-) diff --git a/mellea_contribs/va/core.py b/mellea_contribs/va/core.py index 2a99e83..7311413 100644 --- a/mellea_contribs/va/core.py +++ b/mellea_contribs/va/core.py @@ -16,8 +16,14 @@ class YesNo(BaseModel): answer : Literal["yes","no"] class Core: + """ + The Core powerup provides a core functionality for extracting the embedded reward model in the model. + """ async def abool(m:MelleaSession, prompt:str, **kwargs) -> bool: + """ + Answers a yes/no question. + """ output = await m.ainstruct(f"{prompt} Answer yes or no.", format=YesNo, **kwargs) @@ -27,6 +33,14 @@ async def abool(m:MelleaSession, prompt:str, **kwargs) -> bool: return yesno.answer == "yes" async def achoice(self:MelleaSession, prompt:str, choices:list[str], *, vote:int=3, positional:bool=True, **kwargs) -> str: + """ + Answers a multiple-choice question. Returns an element of choices. + + Args: + vote: When >=1, it samples multiple selections in each turn, and perform a majority voting. + positional: Shuffle the order to present the elements to the LLM in order to mitigate the positional bias. + + """ # note: constraint decoding does not respect pydantic.conint L = len(choices) diff --git a/mellea_contribs/va/relation.py b/mellea_contribs/va/relation.py index c2f348e..5edc393 100644 --- a/mellea_contribs/va/relation.py +++ b/mellea_contribs/va/relation.py @@ -14,6 +14,11 @@ from .core import Core class Relation(Core): + """ + The Relation powerup defines methods for binary and ternary predicates. + Options can be used to declare the property of the predicate, + such as being symmetric or reflexive with regard to certain arguments. 
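+
+    For example (a sketch assuming a session `m` with this powerup applied, as in the
+    test suite; not part of the original code):
+
+        # asymmetric, irreflexive "greater-than"-style comparison
+        m.gt("The population of X is larger than that of Y.", "China", "Singapore")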
+ """ async def abinary(m:MelleaSession, criteria:str, x:str, y:str, *, vote:int=3, diff --git a/mellea_contribs/va/sequence.py b/mellea_contribs/va/sequence.py index 5777957..12a63f2 100644 --- a/mellea_contribs/va/sequence.py +++ b/mellea_contribs/va/sequence.py @@ -87,6 +87,10 @@ async def async_quickselect(seq:list[str], k, acmp, asort, block_size=5): class Sequence(Relation): + """ + Sequence powerup provides a set of sequence operations, such as + sorting a list of strings, selecting an element, or extracting the median according to some criteria. + """ async def asort(m:MelleaSession, criteria:str, elems:list[str], *, vote:int=3, diff --git a/mellea_contribs/va/subset.py b/mellea_contribs/va/subset.py index 0ef090b..b822a5a 100644 --- a/mellea_contribs/va/subset.py +++ b/mellea_contribs/va/subset.py @@ -14,6 +14,9 @@ from .core import Core class Subset(Core): + """ + Subset powerup provides methods for selecting a subset of the input set. + """ async def asubset(m:MelleaSession, description:str, criteria: str, @@ -22,7 +25,7 @@ async def asubset(m:MelleaSession, *, vote:int=3, positional:bool=True, - **kwargs): + **kwargs) -> list[str]: """ Greedily select a k-elements subset from elems, maximizing the given criteria. From 77fab7756e7fc75869d239d7443dcf74047ca9b6 Mon Sep 17 00:00:00 2001 From: Masataro Asai Date: Wed, 12 Nov 2025 15:51:08 -0500 Subject: [PATCH 32/40] fixup! refactor: split into files --- mellea_contribs/va/sequence.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mellea_contribs/va/sequence.py b/mellea_contribs/va/sequence.py index 12a63f2..13278fb 100644 --- a/mellea_contribs/va/sequence.py +++ b/mellea_contribs/va/sequence.py @@ -137,7 +137,7 @@ async def asort(elems:list[str]): -Sort.sort = sync_wrapper(Sort.asort) -Sort.max = sync_wrapper(Sort.amax) -Sort.median = sync_wrapper(Sort.amedian) +Sequence.sort = sync_wrapper(Sequence.asort) +Sequence.max = sync_wrapper(Sequence.amax) +Sequence.median = sync_wrapper(Sequence.amedian) From f1caf546858908f50003f62cffc79268c8000c64 Mon Sep 17 00:00:00 2001 From: Masataro Asai Date: Wed, 12 Nov 2025 15:51:20 -0500 Subject: [PATCH 33/40] Added sequence.map / amap --- mellea_contribs/va/sequence.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/mellea_contribs/va/sequence.py b/mellea_contribs/va/sequence.py index 13278fb..dd48248 100644 --- a/mellea_contribs/va/sequence.py +++ b/mellea_contribs/va/sequence.py @@ -89,9 +89,23 @@ async def async_quickselect(seq:list[str], k, acmp, asort, block_size=5): class Sequence(Relation): """ Sequence powerup provides a set of sequence operations, such as - sorting a list of strings, selecting an element, or extracting the median according to some criteria. + mapping a list of strings, + sorting a list of strings, + selecting an element, or extracting the median according to some criteria. """ + async def amap(m:MelleaSession, criteria:str, elems:list[str], **kwargs) -> list[str]: + + tasks = [ + m.ainstruct("Apply the criteria to the target. 
\n"+ + f"Criteria: {criteria}\n" + f"Target: {elem}") + for elem in elems + ] + + return [o.value for o in asyncio.gather(*tasks)] + + async def asort(m:MelleaSession, criteria:str, elems:list[str], *, vote:int=3, positional:bool=True, @@ -137,6 +151,7 @@ async def asort(elems:list[str]): +Sequence.map = sync_wrapper(Sequence.amap) Sequence.sort = sync_wrapper(Sequence.asort) Sequence.max = sync_wrapper(Sequence.amax) Sequence.median = sync_wrapper(Sequence.amedian) From aa71b9bd7a7a0b327afbe8a4cd12c15e24afceb2 Mon Sep 17 00:00:00 2001 From: Masataro Asai Date: Wed, 12 Nov 2025 15:51:46 -0500 Subject: [PATCH 34/40] Added subset.filter / afilter --- mellea_contribs/va/subset.py | 43 ++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/mellea_contribs/va/subset.py b/mellea_contribs/va/subset.py index b822a5a..4c48b7c 100644 --- a/mellea_contribs/va/subset.py +++ b/mellea_contribs/va/subset.py @@ -17,6 +17,48 @@ class Subset(Core): """ Subset powerup provides methods for selecting a subset of the input set. """ + + async def afilter(m:MelleaSession, + criteria: str, + elems:list[str], + *, + vote:int=3, + **kwargs) -> list[str]: + """ + Returns a subset whose elements all satisfy the criteria. + + Args: + vote: When >=1, it samples multiple selections in each turn, and perform a majority voting. + """ + + if vote % 2 == 0: + logger.warning( + "the specified number of votes in a majority vote is even, making ties possible. Increasing the value by one to avoid this." + ) + vote += 1 + + async def per_elem(elem): + tasks = [ + m.abool("Does the input satisfy the criteria?\n"+ + f"Criteria: {criteria}\n"+ + f"Input: {elem}") + for _ in range(vote) + ] + return asyncio.gather(*tasks).count(True) >= (vote // 2 + 1) + + tasks = [ + per_elem(elem) + for elem in elems + ] + + results = [] + for answer, elem in zip(asyncio.gather(*tasks), elems): + if answer: + results.append(elem) + + return results + + async def asubset(m:MelleaSession, description:str, criteria: str, @@ -76,4 +118,5 @@ async def asubset(m:MelleaSession, return current +Subset.filter = sync_wrapper(Subset.afilter) Subset.subset = sync_wrapper(Subset.asubset) From 3a948971d3ecdfce74bf5104b686423fa8d2f149 Mon Sep 17 00:00:00 2001 From: Masataro Asai Date: Wed, 12 Nov 2025 16:16:59 -0500 Subject: [PATCH 35/40] Added sequence.find / afind --- mellea_contribs/va/sequence.py | 36 ++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/mellea_contribs/va/sequence.py b/mellea_contribs/va/sequence.py index dd48248..f3fe090 100644 --- a/mellea_contribs/va/sequence.py +++ b/mellea_contribs/va/sequence.py @@ -105,6 +105,41 @@ async def amap(m:MelleaSession, criteria:str, elems:list[str], **kwargs) -> list return [o.value for o in asyncio.gather(*tasks)] + async def afind(m:MelleaSession, criteria:str, elems:list[str], **kwargs) -> str | None: + + """ + Returns any element which satisfies the criteria. + It checks the criteria over the elements concurrently and returns the earliest element that satisfied the criteria, + cancelling all running or pending LLM calls. + + Args: + vote: When >=1, it samples multiple selections in each turn, and perform a majority voting. + """ + + if vote % 2 == 0: + logger.warning( + "the specified number of votes in a majority vote is even, making ties possible. Increasing the value by one to avoid this." 
+ ) + vote += 1 + + async def per_elem(elem): + tasks = [ + m.abool("Does the input satisfy the criteria?\n"+ + f"Criteria: {criteria}\n"+ + f"Input: {elem}") + for _ in range(vote) + ] + return asyncio.gather(*tasks).count(True) >= (vote // 2 + 1), elem + + tasks = [ + per_elem(elem) + for elem in elems + ] + + async for answer, elem in asyncio.as_completed(*tasks): + if answer: + return elem + pass async def asort(m:MelleaSession, criteria:str, elems:list[str], *, vote:int=3, @@ -152,6 +187,7 @@ async def asort(elems:list[str]): Sequence.map = sync_wrapper(Sequence.amap) +Sequence.find = sync_wrapper(Sequence.afind) Sequence.sort = sync_wrapper(Sequence.asort) Sequence.max = sync_wrapper(Sequence.amax) Sequence.median = sync_wrapper(Sequence.amedian) From 47af33935d998bf41e3b5b50272470befdd71fde Mon Sep 17 00:00:00 2001 From: Masataro Asai Date: Wed, 12 Nov 2025 16:26:27 -0500 Subject: [PATCH 36/40] Added sequence.merge / amerge --- mellea_contribs/va/sequence.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/mellea_contribs/va/sequence.py b/mellea_contribs/va/sequence.py index f3fe090..3bd6666 100644 --- a/mellea_contribs/va/sequence.py +++ b/mellea_contribs/va/sequence.py @@ -141,6 +141,20 @@ async def per_elem(elem): return elem pass + async def amerge(m:MelleaSession, criteria:str, elems1:list[str], elems2:list[str], *, + vote:int=3, + positional:bool=True, + shuffle:bool=True, **kwargs) -> list[str]: + """ + Given two lists already sorted according to the criteria, + merge them into a list so that the resulting list is also sorted according to the criteria. + """ + + async def acmp(x, y): + return await m.agt(criteria, x, y, vote=vote, positional=positional, shuffle=shuffle, **kwargs) + + return async_merge(elems1, elems2, acmp) + async def asort(m:MelleaSession, criteria:str, elems:list[str], *, vote:int=3, positional:bool=True, @@ -188,6 +202,7 @@ async def asort(elems:list[str]): Sequence.map = sync_wrapper(Sequence.amap) Sequence.find = sync_wrapper(Sequence.afind) +Sequence.merge = sync_wrapper(Sequence.amerge) Sequence.sort = sync_wrapper(Sequence.asort) Sequence.max = sync_wrapper(Sequence.amax) Sequence.median = sync_wrapper(Sequence.amedian) From 3ddd925b4aefdf2bbd7ada9bf8e9c66c953f410f Mon Sep 17 00:00:00 2001 From: Masataro Asai Date: Thu, 13 Nov 2025 15:38:32 -0500 Subject: [PATCH 37/40] cluster: use DBSCAN as the default model --- mellea_contribs/va/cluster.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/mellea_contribs/va/cluster.py b/mellea_contribs/va/cluster.py index 9f6a7e4..38587c9 100644 --- a/mellea_contribs/va/cluster.py +++ b/mellea_contribs/va/cluster.py @@ -13,6 +13,7 @@ import numpy as np from sklearn.base import ClusterMixin +from sklearn.cluster import DBSCAN from .util import sync_wrapper from .relation import Relation @@ -168,8 +169,8 @@ def query_triplets(m:MelleaSession, triplets: list[Triplet], prompt: str) -> lis for t, a in zip(triplets, answers) ] async def acluster(m:MelleaSession, criteria:str, elems:list[str], - model : ClusterMixin, *, + model : ClusterMixin = None, vote:int=3, positional:bool=True, shuffle:bool=True, @@ -190,7 +191,7 @@ async def acluster(m:MelleaSession, criteria:str, elems:list[str], Args: criteria: triplet comparison criteria elems: list of strings to cluster - model: an instance of sklearn.base.ClusterMixin, such as sklearn.cluster.KMEANS, sklearn.cluster.AgglomerativeClustering + model: an instance of sklearn.base.ClusterMixin, such as sklearn.cluster.KMEANS, 
sklearn.cluster.AgglomerativeClustering. default = sklearn.cluster.DBSCAN vote: an odd integer specifying the number of queries to make. The final result is a majority vote over the results. Since the LLM answers "yes"/"no", by default it counts "yes". If it is even, we add 1 to make it an odd number. positional: Permute the order of presenting x and y. This mitigates the positional bias. @@ -216,6 +217,9 @@ async def acluster(m:MelleaSession, criteria:str, elems:list[str], triplets_per_item and num_triplets are mutually exclusive (cannot be specified at the same time). """ + if model is None: + model = DBSCAN() + if verbose: logger.setLevel(logging.DEBUG) else: @@ -280,6 +284,6 @@ async def acluster(m:MelleaSession, criteria:str, elems:list[str], return model.fit_predict(embeddings) -Cluster.cluster = sync_wrapper(Sort.acluster) +Cluster.cluster = sync_wrapper(Cluster.acluster) From 9873b54d16c7997bf1259506f72afdc49d4577f8 Mon Sep 17 00:00:00 2001 From: Masataro Asai Date: Thu, 13 Nov 2025 16:31:33 -0500 Subject: [PATCH 38/40] map, find, filter uses variables --- mellea_contribs/va/sequence.py | 18 +++++++++--------- mellea_contribs/va/subset.py | 7 ++++--- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/mellea_contribs/va/sequence.py b/mellea_contribs/va/sequence.py index 3bd6666..076280c 100644 --- a/mellea_contribs/va/sequence.py +++ b/mellea_contribs/va/sequence.py @@ -94,21 +94,21 @@ class Sequence(Relation): selecting an element, or extracting the median according to some criteria. """ - async def amap(m:MelleaSession, criteria:str, elems:list[str], **kwargs) -> list[str]: + async def amap(m:MelleaSession, variable:str, output:str, elems:list[str], **kwargs) -> list[str]: tasks = [ - m.ainstruct("Apply the criteria to the target. \n"+ - f"Criteria: {criteria}\n" - f"Target: {elem}") + m.ainstruct(f"Given a value of {variable}, answer the value of the output. \n"+ + f"{variable}: {elem}\n" + + f"Output: {output}\n") for elem in elems ] return [o.value for o in asyncio.gather(*tasks)] - async def afind(m:MelleaSession, criteria:str, elems:list[str], **kwargs) -> str | None: + async def afind(m:MelleaSession, variable:str, criteria:str, elems:list[str], **kwargs) -> str | None: """ - Returns any element which satisfies the criteria. + Returns any element which satisfies the criteria about the variable. It checks the criteria over the elements concurrently and returns the earliest element that satisfied the criteria, cancelling all running or pending LLM calls. 
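+
+        For example (a sketch mirroring the test suite; assumes a session `m` with the
+        Sequence powerup applied):
+
+            m.find("X", "X is a country in Asia.",
+                   ["United States", "Norway", "Japan", "France", "Namibia"])
+            # expected to return "Japan"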
@@ -124,9 +124,9 @@ async def afind(m:MelleaSession, criteria:str, elems:list[str], **kwargs) -> str async def per_elem(elem): tasks = [ - m.abool("Does the input satisfy the criteria?\n"+ - f"Criteria: {criteria}\n"+ - f"Input: {elem}") + m.abool(f"Does {variable} satisfy the criteria?\n"+ + f"{variable}: {elem}\n"+ + f"Criteria: {criteria}") for _ in range(vote) ] return asyncio.gather(*tasks).count(True) >= (vote // 2 + 1), elem diff --git a/mellea_contribs/va/subset.py b/mellea_contribs/va/subset.py index 4c48b7c..9c4e05a 100644 --- a/mellea_contribs/va/subset.py +++ b/mellea_contribs/va/subset.py @@ -19,6 +19,7 @@ class Subset(Core): """ async def afilter(m:MelleaSession, + variable: str, criteria: str, elems:list[str], *, @@ -39,9 +40,9 @@ async def afilter(m:MelleaSession, async def per_elem(elem): tasks = [ - m.abool("Does the input satisfy the criteria?\n"+ - f"Criteria: {criteria}\n"+ - f"Input: {elem}") + m.abool(f"Does {variable} satisfy the criteria?\n"+ + f"{variable}: {elem}\n"+ + f"Criteria: {criteria}") for _ in range(vote) ] return asyncio.gather(*tasks).count(True) >= (vote // 2 + 1) From 33629f6ce4ff44a1d4d79755d09bc194966e55a1 Mon Sep 17 00:00:00 2001 From: Masataro Asai Date: Thu, 13 Nov 2025 16:31:45 -0500 Subject: [PATCH 39/40] writing the test --- test/test_va.py | 84 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 84 insertions(+) create mode 100644 test/test_va.py diff --git a/test/test_va.py b/test/test_va.py new file mode 100644 index 0000000..c7d7d9f --- /dev/null +++ b/test/test_va.py @@ -0,0 +1,84 @@ +import pytest + +from mellea.backends.vllm import LocalVLLMBackend + +from mellea_contribs.va import Core, Relation, Sequence, Subset, Cluster + +from mellea import start_session +from mellea.stdlib.requirement import req +from mellea.stdlib.sampling import RejectionSamplingStrategy + +import timer + +MelleaSession.powerup(Core) +MelleaSession.powerup(Relation) +MelleaSession.powerup(Sequence) +MelleaSession.powerup(Subset) +MelleaSession.powerup(Cluster) + +@pytest.fixture(scope="module") +def m() -> MelleaSession: + return MelleaSession(backend=LocalVLLMBackend("Qwen/Qwen3-1.7B")) + + +def test_core(m: MelleaSession): + assert m.bool("Is a number 2 even?") + assert not m.bool("Is a number 5 even?") + assert m.choice("Which country is in Asia?", ["United States", "Norway", "Japan", "France", "Namibia"]) == "Japan" + +def test_relation(m: MelleaSession): + assert m.gt("The population of X is larger than that of Y.", "China", "Singapore") + + t1 = time.now() + answer = m.eq("People in country X and country Y speak the same language.", "Spain", "Mexico") + t2 = time.now() + assert answer + + t3 = time.now() + answer = m.eq("People in country X and country Y speak the same language.", "Spain", "Spain") + t4 = time.now() + assert answer + assert t4 - t3 < 1.0 + assert t4 - t3 < t2 - t1 + +def test_sequence(m: MelleaSession): + + assert m.map("X", "X+1", ["3", "5"]) == ["4", "6"] + + assert m.find("X", "X is a country in Asia.", ["United States", "Norway", "Japan", "France", "Namibia"]) == "Japan" + + messages = [ + # "You are the worst scum in this world", + "I hate you", + "I dislike you", + # "You are a bit annoying", + "You are okay", + # "You are not bad", + # "You are kind of nice", + "I like you", + "I love you", + # "Oh my gosh you are the best person in the world" + ] + random.shuffle(messages) + + results = m.sort("Message X shows a more positive sentiment than message Y does.", messages) + + assert results.find("I love you") > 
results.find("I hate you") + + assert m.max("Message X shows a more positive sentiment than message Y does.", messages) == "I love you" + + assert m.median("Message X shows a more positive sentiment than message Y does.", messages) == "You are okay" + + +def test_subset(m: MelleaSession): + + assert m.filter("X", "X is an insect", ["crow", "dolphin", "cockroach", "cicada"]) == ["cockroach", "cicada"] + + subset = m.subset("We need a set of things with different colors.", + "Select an element whose color is different from any of the current set.", + ["crow", "orange", "tomato", "cucumber", "coal", "strawberry"]) + + assert "orange" in subset + assert ("tomato" in subset) != ("strawberry" in subset) + assert ("crow" in subset) != ("coal" in subset) + From f8303fefff32f1c85ba52e75f074385cfa2ca6e1 Mon Sep 17 00:00:00 2001 From: Masataro Asai Date: Wed, 26 Nov 2025 15:46:27 -0500 Subject: [PATCH 40/40] denaulay-based clustering [wip] --- mellea_contribs/va/cluster.py | 422 +++++++++++++---------------- mellea_contribs/va/cluster_test.py | 161 +++++++++++ 2 files changed, 343 insertions(+), 240 deletions(-) create mode 100644 mellea_contribs/va/cluster_test.py diff --git a/mellea_contribs/va/cluster.py b/mellea_contribs/va/cluster.py index 38587c9..254bff6 100644 --- a/mellea_contribs/va/cluster.py +++ b/mellea_contribs/va/cluster.py @@ -2,288 +2,230 @@ import functools import itertools import asyncio +import networkx as nx +import numpy as np +import matplotlib.pyplot as plt from mellea import MelleaSession from mellea.helpers.fancy_logger import FancyLogger from mellea.helpers.event_loop_helper import _run_async_in_thread from pydantic import BaseModel -from typing import Literal +from typing import ( + Literal, + Callable, + TypeVar, + List, +) import numpy as np -from sklearn.base import ClusterMixin -from sklearn.cluster import DBSCAN - from .util import sync_wrapper from .relation import Relation +T = TypeVar("T") +async def delaunay(elems:list[T], criteria:Callable[[T,T,T],bool], k:int=3) -> nx.Graph: -@dataclasses.dataclass -class Triplet: - z : str # anchor - x : str - y : str - z_index : int # z's index in items - x_index : int # x's index in items - y_index : int # y's index in items + assert len(elems) >= 2 - def swap(self): - "swaps x and y" - return Triplet(self.z, self.y, self.x, self.z_index, self.y_index, self.x_index) + g = nx.Graph() + for elem in elems: + g.add_node(elem) + for _ in range(k): -def sample_triplets(items: list[str], - triplets_per_item: int | None = None, - num_triplets: int | None = None, - repeat_x: int = 1, - ) -> list[Triplet]: - """Randomly sample a list of triplet comparison queries. + def select(elems:list[T]): + assert len(elems) >= 2 + _elems = elems.copy() + i1 = random.randint(0, len(_elems)-1) + r1 = _elems.pop(i1) + i2 = random.randint(0, len(_elems)-1) + r2 = _elems.pop(i2) + return r1, r2, _elems - Args: - items : input - triplets_per_item : how many triplets to generate relative to the number of items. - num_triplets : the number of triplets to generate. - repeat_x : how many times we reuse the same z, x for sampling y. + async def split(x:T, y:T, S:list[T]): + """Split a set S into Sx and Sy, which are closer to x or y, respectively""" - Either triplets_per_item or num_triplets must be specified. - triplets_per_item and num_triplets are mutually exclusive (cannot be specified at the same time). 
+ Sx = [] + Sy = [] + for z in S: + if await criteria(x, y, z): + Sx.append(z) + else: + Sy.append(z) - """ - N = len(items) - - assert (num_triplets is not None) or (triplets_per_item is not None), \ - "Specify either num_triplets and triplets_per_item." - assert (num_triplets is None) or (triplets_per_item is None), \ - "num_triplets and triplets_per_item are mutually exclusive; do not specify both." - if num_triplets is None: - assert isinstance(triplets_per_item, int) - logger.info(f"num_triplets = triplets_per_item * len(items) = {triplets_per_item} * {N} = {triplets_per_item * N}") - num_triplets = triplets_per_item * N - - # make sure z covers all elements - assert num_triplets / N >= 1, \ - ("Some items are never used as an anchor z. Increase num_triplets or triplets_per_item: " - f"num_triplets / len(items) = {num_triplets} / {N} = {num_triplets / N}") - - # make sure z covers all elements even if we sample multiple triplets with the same x - if repeat_x > num_triplets / N: - logger.warning(f"Some items are never used as an anchor z because of too large repeat_x. Overriding it with {num_triplets / N}: " - f"repeat_x = {repeat_x} > " - f"num_triplets / len(items) = {num_triplets} / {N} = {num_triplets / N}.") - repeat_x = num_triplets // N - - # switch to the exhaustive mode if num_triplets is large enough - all_triplets = N * (N-1) * (N-2) - logger.info(f"all_triplets = N * (N-1) * (N-2) = {all_triplets}, where N = {N}") - if num_triplets > all_triplets: - logger.warning(f"num_triplets = {num_triplets} is large enough to enumerate all triplets (> {N} * {(N-1)} * {(N-2)} = {all_triplets}). " - f"Switching to the exhaustive mode.") - exhaustive = True - num_triplets = all_triplets - else: - exhaustive = False - - triplets: list[Triplet] = [] - - bar = tqdm(total=num_triplets, desc="sampling triplets") - - if exhaustive: - for i, z in enumerate(items): - for j, x in enumerate(items): - if i == j: - continue - for k, y in enumerate(items): - if k == i or k == j: - continue - triplets.append(Triplet(z, x, y, i, j, k)) - bar.update() - assert len(triplets) == all_triplets - return triplets - - def sample_except(blacklist:set[str]): - while True: - sample = random.choice(items) - if sample not in blacklist: - return sample - - for z in cycle(items): - x = sample_except({z}) - for _ in range(repeat_x): - y = sample_except({z,x}) - triplets.append(Triplet(z, x, y, items.index(z), items.index(x), items.index(y))) - bar.update() - if len(triplets) >= num_triplets: - return triplets - -def update(embeddings: np.ndarray, triplets: list[Triplet], alpha: float, lr: float) -> int: - """ Update embeddings using the t-STE gradient for each triplet. 
""" - violations_fixed: int = 0 - for idx, t in enumerate(triplets): - xi = embeddings[t.z_index] - xj = embeddings[t.x_index] - xl = embeddings[t.y_index] - - # Squared distances - dij = np.sum((xi - xj) ** 2) - dil = np.sum((xi - xl) ** 2) - - # Student-t similarities - sij = (1 + dij / alpha) ** (-(alpha + 1) / 2) - sil = (1 + dil / alpha) ** (-(alpha + 1) / 2) - pijl = sij / (sij + sil) - - # Gradients (see t-STE paper) - grad_coeff = (alpha + 1) / alpha - grad_xi = grad_coeff * ( - (1 - pijl) * (xj - xi) / (1 + dij / alpha) - - (1 - pijl) * (xl - xi) / (1 + dil / alpha) - ) - grad_xj = grad_coeff * (1 - pijl) * (xi - xj) / (1 + dij / alpha) - grad_xl = -grad_coeff * (1 - pijl) * (xi - xl) / (1 + dil / alpha) + return Sx, Sy - # Update embeddings - embeddings[t.z_index] = (xi + lr * grad_xi) - embeddings[t.x_index] = (xj + lr * grad_xj) - embeddings[t.y_index] = (xl + lr * grad_xl) - violations_fixed += 1 - return violations_fixed + async def construct(parent, elems): + if len(elems) < 2: + for elem in elems: + g.add_edge(parent, elem) + return -default_prompt = "Considering the nature of X, Y and Z, is X more similar to Z than Y is to Z? " + c1, c2, _elems = select(elems) + g.add_edge(parent, c1) + g.add_edge(parent, c2) -class Cluster(Relation): + elems1, elems2 = split(c1, c2, _elems) + asyncio.gather( + construct(c1, elems1), + construct(c2, elems2)) - def query_triplets(m:MelleaSession, triplets: list[Triplet], prompt: str) -> list[Triplet]: - """Given a triplet comparison query, perform the query using the LLM. """ + r1, r2, _elems = select(elems) + g.add_edge(r1, r2) + elems1, elems2 = split(r1, r2, _elems) + asyncio.gather( + construct(c1, elems1), + construct(c2, elems2)) - answers = asyncio.gather(*[m.achoice(prompt + f"\nZ: {t.z}\nX: {t.x}\nY: {t.y}\n" , ["X", "Y"]) - for t in triplets]) + return g - logger.info(f"Queried {len(triplets)} triplets.") - for idx, (t, a) in enumerate(zip(triplets, answers)): - logger.debug(f"Triplet {idx+1}: Z(anchor): {t.z} X: {t.x} Y: {t.y} result: {a}") +class Cluster(Relation): + + async def atriplet(m:MelleaSession, prompt: str, x:str, y:str, z:str, **kwargs) -> bool: + """Given a triplet comparison query, perform the query using the LLM. + + It returns True if Z is closer to X than is to Y. + """ + answer = await m.achoice(prompt + f"\nZ: {z}\nX: {x}\nY: {y}\n" , ["X", "Y"], **kwargs) + return answer == "X" - return [t if a == "X" else t.swap() - for t, a in zip(triplets, answers) ] async def acluster(m:MelleaSession, criteria:str, elems:list[str], *, - model : ClusterMixin = None, - vote:int=3, - positional:bool=True, - shuffle:bool=True, - # - ndims: int = 2, - lr: float = 0.020, - max_iterations: int = 100, - tolerance: float = 1e-4, - alpha: float | None = None, - num_triplets: int | None = None, - triplets_per_item: int | None = None, - repeat_x: int = 3, - # - **kwargs): - """ - Generate Triplet Embeddings of the given strings, and run clustering + k:int = 3, + **kwargs) -> list[set[str]]: + """Generate an approximate Delaunay graph and perform graph clustering on it. + + The graph construction method follows the n log n algorithm by Haghiri et. al. [1] Args: criteria: triplet comparison criteria elems: list of strings to cluster - model: an instance of sklearn.base.ClusterMixin, such as sklearn.cluster.KMEANS, sklearn.cluster.AgglomerativeClustering. default = sklearn.cluster.DBSCAN - - vote: an odd integer specifying the number of queries to make. The final result is a majority vote over the results. 
Since the LLM answers "yes"/"no", by default it counts "yes". If it is even, we add 1 to make it an odd number. - positional: Permute the order of presenting x and y. This mitigates the positional bias. - shuffle: It shuffles the variation of queries (symmetric/positional variations). - This helps when you are making multiple queries with a small vote count (less than 2*2=4 variations). - For example, when shuffle = False and vote = 1, the query always contains the original x y in the x y order. - - ndims: number of dimensions for embeddings - lr: weight to give each triplet when updating embeddings - max_iterations: number of times to use LLM triplets to update embeddings - tolerance: hyperparamater; ??? - alpha: hyperparameter; ??? - num_triplets: the number of triplets to generate. - triplets_per_item: number of triplets to sample per item (will result in len(items) * triplets_per_item triplets) - repeat_x: how many times we reuse the same z, x for sampling y. - verbose: boolean to determine whether or not to provide verbose output - clustering_method: clustering method + k: k for k-ANNS for Delaunay Graph + + **kwargs: accepts vote, positional, shuffle. Returns: - Dictionary representing each label in items to its associated coordinate + A cluster representation as list[set[str]] + + + [1] Haghiri, Siavash, Debarghya Ghoshdastidar, and Ulrike von Luxburg. + "Comparison-based nearest neighbor search." Artificial Intelligence and Statistics. PMLR, 2017. - Either triplets_per_item or num_triplets must be specified. - triplets_per_item and num_triplets are mutually exclusive (cannot be specified at the same time). """ - if model is None: - model = DBSCAN() - - if verbose: - logger.setLevel(logging.DEBUG) - else: - logger.setLevel(logging.INFO) - - start_time = datetime.now() - # Set alpha default based on ndims if not provided - if alpha is None: - alpha = ndims - 1 - if criteria is None: - criteria = default_criteria - - N: int = len(items) - - logger.info(f"Starting triplet embedding with N={N} items...") - logger.info(f"Algorithm parameters:") - logger.info(f" Embedding dimensions (r): {ndims}") - logger.info(f" Learning rate: {lr}") - logger.info(f" Max iterations: {max_iterations}") - logger.info(f" Triplets per item: {triplets_per_item}") - logger.info(f" Reuse the same z and x for: {repeat_x} times") - logger.info(f" Tolerance: {tolerance}") - logger.info(f" Alpha (DoF): {alpha}") - - # Initialize random embeddings - embeddings = np.random.normal(0, 0.1, (len(items), ndims)) - - # Show initial positions - logger.debug(f"Generated initial random embeddings in {ndims}D space") - - # Generate LLM triplets ONCE - triplets = sample_triplets(items, - triplets_per_item=triplets_per_item, - num_triplets=num_triplets, - repeat_x=repeat_x,) - logger.debug(f"Using {len(triplets)} LLM-judged triplets for all iterations") - - # swap X/Y of triplets using LLM. 
Now X is always closer to anchor Z than Y is to anchor Z - triplets = m.query_triplets(triplets, criteria) - - # Iterative improvement - stat = { - "violations_fixed": 0, - "convergence_ratio": 0.0, - } - for iteration in tqdm(range(max_iterations), desc="updating the embedding (outer loop)", position=0, postfix=stat): - # Use the same triplets every iteration - violations_fixed: int = update(embeddings, triplets, alpha, lr) - convergence_ratio: float = violations_fixed / len(triplets) if len(triplets) > 0 else 0 - - stat["violations_fixed"] = violations_fixed - stat["convergence_ratio"] = convergence_ratio - - if convergence_ratio < tolerance: - logger.debug(f"Converged early at iteration {iteration + 1} (ratio < {tolerance})") - break - - elapsed_time = datetime.now() - start_time - formatted = str(elapsed_time).split('.')[0] - logger.debug(f"Elapsed time: {formatted}") - - return model.fit_predict(embeddings) + async def fn(x:str, y:str, z:str) -> bool: + return await m.atriplet(criteria, x, y, z, **kwargs) + + g = delaunay(elems, fn, k=k) + + communities = list(nx.algorithms.community.greedy_modularity_communities(g)) + + return communities Cluster.cluster = sync_wrapper(Cluster.acluster) +# Testing Delaunay Graph Clustering approach on 2D points. +# (for VA, we replace the triplet comparison with LLM-based one) + +Point = tuple[float,float] + +def points( + n_clusters=5, + points_per_cluster=20, + radius=10.0, + cluster_std=0.5, + seed=None, +): + """ + Generate 2D points in clusters centered at the vertices of a regular polyhedra. + + Returns + ------- + points : np.ndarray of shape (n_clusters*points_per_cluster, 2) + The generated 2D points. + """ + if seed is not None: + np.random.seed(seed) + + angles = np.linspace(0, 2 * np.pi, n_clusters, endpoint=False) + centers = np.column_stack([radius * np.cos(angles), + radius * np.sin(angles)]) + + # Generate clusters + points = [] + for cx, cy in centers: + cluster = np.random.normal( + loc=[cx, cy], + scale=cluster_std, + size=(points_per_cluster, 2) + ) + points.append(cluster) + + return np.vstack(points) + +async def criteria(x, y, z): + x = np.array(x) + y = np.array(y) + z = np.array(z) + return np.square(x-z).sum() < np.square(y-z).sum() + +def plot(g): + """ + Plot the graph g whose nodes are 2D points [x, y]. + Also compute greedy modularity communities and color + nodes by community assignment. 
+ """ + + # --- Compute communities --- + communities = list(nx.algorithms.community.greedy_modularity_communities(g)) + + # Assign a color index to each node + node_color = {} + for cid, comm in enumerate(communities): + for node in comm: + node_color[node] = cid + + # Color palette + # If many clusters, matplotlib cycles automatically + colors = [node_color[n] for n in g.nodes] + + # --- Extract node positions --- + xs = [node[0] for node in g.nodes] + ys = [node[1] for node in g.nodes] + + plt.figure(figsize=(7, 7)) + + # --- Draw edges --- + for u, v in g.edges: + plt.plot([u[0], v[0]], [u[1], v[1]], linewidth=0.8, color="gray", alpha=0.5) + + # --- Draw nodes --- + sc = plt.scatter(xs, ys, c=colors, cmap="tab10", s=35) + + plt.gca().set_aspect('equal', 'box') + plt.title("Graph with Greedy Modularity Communities") + plt.xlabel("x") + plt.ylabel("y") + plt.grid(True) + + cbar = plt.colorbar(sc) + cbar.set_label("Community ID") + + plt.show() + +def main(): + elems = points(points_per_cluster=30) + elems = [ tuple(p) for p in elems ] + g = delaunay(elems) + plot(g) + +if __name__ == "__main__": + + main() + diff --git a/mellea_contribs/va/cluster_test.py b/mellea_contribs/va/cluster_test.py new file mode 100644 index 0000000..4ad0382 --- /dev/null +++ b/mellea_contribs/va/cluster_test.py @@ -0,0 +1,161 @@ +""" +Testing Delaunay Graph Clustering approach on 2D points. +(for VA, we replace the triplet comparison with LLM-based one) +""" + +import random +import networkx as nx +import numpy as np +import matplotlib.pyplot as plt + +Point = tuple[float,float] + +def points( + n_clusters=5, + points_per_cluster=20, + radius=10.0, + cluster_std=0.5, + seed=None, +): + """ + Generate 2D points in clusters centered at the vertices of a regular polyhedra. + + Returns + ------- + points : np.ndarray of shape (n_clusters*points_per_cluster, 2) + The generated 2D points. 
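+
+    For example, main() below calls points(points_per_cluster=30), producing
+    150 points in five Gaussian blobs placed at the vertices of a regular pentagon.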
+ """ + if seed is not None: + np.random.seed(seed) + + angles = np.linspace(0, 2 * np.pi, n_clusters, endpoint=False) + centers = np.column_stack([radius * np.cos(angles), + radius * np.sin(angles)]) + + # Generate clusters + points = [] + for cx, cy in centers: + cluster = np.random.normal( + loc=[cx, cy], + scale=cluster_std, + size=(points_per_cluster, 2) + ) + points.append(cluster) + + return np.vstack(points) + +def criteria(x, y, z): + x = np.array(x) + y = np.array(y) + z = np.array(z) + return np.square(x-z).sum() < np.square(y-z).sum() + +def delaunay(elems, k=7): + + assert len(elems) >= 2 + + g = nx.Graph() + for elem in elems: + g.add_node(elem) + + for _ in range(k): + + def select(elems:list[Point]): + assert len(elems) >= 2 + _elems = elems.copy() + i1 = random.randint(0, len(_elems)-1) + r1 = _elems.pop(i1) + i2 = random.randint(0, len(_elems)-1) + r2 = _elems.pop(i2) + return r1, r2, _elems + + def split(x:Point, y:Point, S:list[Point]): + """Split a set S into Sx and Sy, which are closer to x or y, respectively""" + + Sx = [] + Sy = [] + for z in S: + if criteria(x, y, z): + Sx.append(z) + else: + Sy.append(z) + + return Sx, Sy + + def construct(parent, elems): + if len(elems) < 2: + for elem in elems: + g.add_edge(parent, elem) + return + + c1, c2, _elems = select(elems) + g.add_edge(parent, c1) + g.add_edge(parent, c2) + + elems1, elems2 = split(c1, c2, _elems) + construct(c1, elems1) + construct(c2, elems2) + + r1, r2, _elems = select(elems) + g.add_edge(r1, r2) + elems1, elems2 = split(r1, r2, _elems) + construct(r1, elems1) + construct(r2, elems2) + + return g + +def plot(g): + """ + Plot the graph g whose nodes are 2D points [x, y]. + Also compute greedy modularity communities and color + nodes by community assignment. + """ + + # --- Compute communities --- + communities = list(nx.algorithms.community.greedy_modularity_communities(g)) + + # Assign a color index to each node + node_color = {} + for cid, comm in enumerate(communities): + for node in comm: + node_color[node] = cid + + # Color palette + # If many clusters, matplotlib cycles automatically + colors = [node_color[n] for n in g.nodes] + + # --- Extract node positions --- + xs = [node[0] for node in g.nodes] + ys = [node[1] for node in g.nodes] + + plt.figure(figsize=(7, 7)) + + # --- Draw edges --- + for u, v in g.edges: + plt.plot([u[0], v[0]], [u[1], v[1]], linewidth=0.8, color="gray", alpha=0.5) + + # --- Draw nodes --- + sc = plt.scatter(xs, ys, c=colors, cmap="tab10", s=35) + + plt.gca().set_aspect('equal', 'box') + plt.title("Graph with Greedy Modularity Communities") + plt.xlabel("x") + plt.ylabel("y") + plt.grid(True) + + cbar = plt.colorbar(sc) + cbar.set_label("Community ID") + + plt.show() + +def main(): + elems = points(points_per_cluster=30) + elems = [ tuple(p) for p in elems ] + g = delaunay(elems) + plot(g) + + +if __name__ == "__main__": + + main() +