Skip to content

Commit 6ead149

Browse files
committed
added object detection
1 parent d362b9c commit 6ead149

File tree

5 files changed

+272
-63
lines changed

5 files changed

+272
-63
lines changed

jigsawstack/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from .embedding import Embedding, AsyncEmbedding
1616
from .exceptions import JigsawStackError
1717
from .image_generation import ImageGeneration, AsyncImageGeneration
18+
from .object_detection import ObjectDetection, AsyncObjectDetection
1819

1920

2021
class JigsawStack:
@@ -118,6 +119,12 @@ def __init__(
118119
disable_request_logging=disable_request_logging,
119120
).image_generation
120121

122+
self.object_detection = ObjectDetection(
123+
api_key=api_key,
124+
api_url=api_url,
125+
disable_request_logging=disable_request_logging,
126+
)
127+
121128

122129
class AsyncJigsawStack:
123130
validate: AsyncValidate
@@ -228,6 +235,12 @@ def __init__(
228235
disable_request_logging=disable_request_logging,
229236
).image_generation
230237

238+
self.object_detection = AsyncObjectDetection(
239+
api_key=api_key,
240+
api_url=api_url,
241+
disable_request_logging=disable_request_logging,
242+
)
243+
231244

232245
# Create a global instance of the Web class
233246
__all__ = ["JigsawStack", "Search", "JigsawStackError", "AsyncJigsawStack"]

jigsawstack/_client.py

Lines changed: 0 additions & 62 deletions
This file was deleted.

jigsawstack/object_detection.py

Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
from typing import Any, Dict, List, Union, cast, Literal
2+
from typing_extensions import NotRequired, TypedDict
3+
from .request import Request, RequestConfig
4+
from .async_request import AsyncRequest, AsyncRequestConfig
5+
from ._config import ClientConfig
6+
7+
8+
class Point(TypedDict):
9+
x: int
10+
"""
11+
X coordinate of the point
12+
"""
13+
14+
y: int
15+
"""
16+
Y coordinate of the point
17+
"""
18+
19+
20+
class BoundingBox(TypedDict):
21+
top_left: Point
22+
"""
23+
Top-left corner of the bounding box
24+
"""
25+
26+
top_right: Point
27+
"""
28+
Top-right corner of the bounding box
29+
"""
30+
31+
bottom_left: Point
32+
"""
33+
Bottom-left corner of the bounding box
34+
"""
35+
36+
bottom_right: Point
37+
"""
38+
Bottom-right corner of the bounding box
39+
"""
40+
41+
width: int
42+
"""
43+
Width of the bounding box
44+
"""
45+
46+
height: int
47+
"""
48+
Height of the bounding box
49+
"""
50+
51+
52+
class GuiElement(TypedDict):
53+
bounds: BoundingBox
54+
"""
55+
Bounding box coordinates of the GUI element
56+
"""
57+
58+
content: Union[str, None]
59+
"""
60+
Content of the GUI element, can be null if no object detected
61+
"""
62+
63+
64+
class DetectedObject(TypedDict):
65+
bounds: BoundingBox
66+
"""
67+
Bounding box coordinates of the detected object
68+
"""
69+
70+
mask: NotRequired[str]
71+
"""
72+
URL or base64 string depending on return_type - only present for some objects
73+
"""
74+
75+
76+
class UsageStats(TypedDict):
77+
"""
78+
Usage statistics - structure depends on the RunPod response
79+
"""
80+
pass # Flexible structure for usage stats
81+
82+
83+
class ObjectDetectionParams(TypedDict):
84+
url: NotRequired[str]
85+
"""
86+
URL of the image to process
87+
"""
88+
89+
file_store_key: NotRequired[str]
90+
"""
91+
File store key of the image to process
92+
"""
93+
94+
prompts: NotRequired[List[str]]
95+
"""
96+
List of prompts for object detection
97+
"""
98+
99+
features: NotRequired[List[Literal["object_detection", "gui"]]]
100+
"""
101+
List of features to enable: object_detection, gui
102+
"""
103+
104+
annotated_image: NotRequired[bool]
105+
"""
106+
Whether to return an annotated image
107+
"""
108+
109+
return_type: NotRequired[Literal["url", "base64"]]
110+
"""
111+
Format for returned images: url or base64
112+
"""
113+
114+
115+
class ObjectDetectionResponse(TypedDict):
116+
annotated_image: NotRequired[str]
117+
"""
118+
URL or base64 string of annotated image (included only if annotated_image=true and objects/gui_elements exist)
119+
"""
120+
121+
gui_elements: NotRequired[List[GuiElement]]
122+
"""
123+
List of detected GUI elements (included only if features includes "gui")
124+
"""
125+
126+
objects: NotRequired[List[DetectedObject]]
127+
"""
128+
List of detected objects (included only if features includes "object_detection")
129+
"""
130+
131+
_usage: NotRequired[UsageStats]
132+
"""
133+
Optional usage statistics
134+
"""
135+
136+
137+
class ObjectDetection(ClientConfig):
138+
config: RequestConfig
139+
140+
def __init__(
141+
self,
142+
api_key: str,
143+
api_url: str,
144+
disable_request_logging: Union[bool, None] = False,
145+
):
146+
super().__init__(api_key, api_url, disable_request_logging)
147+
self.config = RequestConfig(
148+
api_url=api_url,
149+
api_key=api_key,
150+
disable_request_logging=disable_request_logging,
151+
)
152+
153+
def detect(self, params: ObjectDetectionParams) -> ObjectDetectionResponse:
154+
"""
155+
Detect objects and/or GUI elements in an image
156+
157+
Args:
158+
params: Object detection parameters
159+
160+
Returns:
161+
Object detection response with detected objects, GUI elements, and optional annotated image
162+
"""
163+
resp = Request(
164+
config=self.config,
165+
path="/ai/object_detection",
166+
params=cast(Dict[Any, Any], params),
167+
verb="POST",
168+
).perform_with_content()
169+
170+
return resp
171+
172+
173+
class AsyncObjectDetection(ClientConfig):
174+
config: AsyncRequestConfig
175+
176+
def __init__(
177+
self,
178+
api_key: str,
179+
api_url: str,
180+
disable_request_logging: Union[bool, None] = False,
181+
):
182+
super().__init__(api_key, api_url, disable_request_logging)
183+
self.config = AsyncRequestConfig(
184+
api_url=api_url,
185+
api_key=api_key,
186+
disable_request_logging=disable_request_logging,
187+
)
188+
189+
async def detect(self, params: ObjectDetectionParams) -> ObjectDetectionResponse:
190+
"""
191+
Detect objects and/or GUI elements in an image (async)
192+
193+
Args:
194+
params: Object detection parameters
195+
196+
Returns:
197+
Object detection response with detected objects, GUI elements, and optional annotated image
198+
"""
199+
resp = await AsyncRequest(
200+
config=self.config,
201+
path="/ai/object_detection",
202+
params=cast(Dict[Any, Any], params),
203+
verb="POST",
204+
).perform_with_content()
205+
206+
return resp

tests/test_object_detection.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
from unittest.mock import MagicMock
2+
import unittest
3+
from jigsawstack.exceptions import JigsawStackError
4+
import jigsawstack
5+
import pytest
6+
import asyncio
7+
import logging
8+
9+
logging.basicConfig(level=logging.INFO)
10+
logger = logging.getLogger(__name__)
11+
12+
jigsaw = jigsawstack.JigsawStack()
13+
async_jigsaw = jigsawstack.AsyncJigsawStack()
14+
15+
16+
def test_object_detection_response():
17+
try:
18+
result = jigsaw.object_detection.detect({"url": "https://rogilvkqloanxtvjfrkm.supabase.co/storage/v1/object/public/demo/Collabo%201080x842.jpg"})
19+
print(result)
20+
assert result["success"] == True
21+
except JigsawStackError as e:
22+
pytest.fail(f"Unexpected JigsawStackError: {e}")
23+
24+
25+
def test_object_detection_response_async():
26+
async def _test():
27+
client = jigsawstack.AsyncJigsawStack()
28+
try:
29+
result = await client.object_detection.detect({"url": "https://rogilvkqloanxtvjfrkm.supabase.co/storage/v1/object/public/demo/Collabo%201080x842.jpg"})
30+
print(result)
31+
assert result["success"] == True
32+
except JigsawStackError as e:
33+
pytest.fail(f"Unexpected JigsawStackError: {e}")
34+
35+
asyncio.run(_test())
36+

tests/test_search.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,22 @@
1414

1515

1616
def test_search_suggestion_response():
17+
try:
18+
result = jigsaw.web.search({"query": "Where is San Francisco"})
19+
assert result["success"] == True
20+
except JigsawStackError as e:
21+
pytest.fail(f"Unexpected JigsawStackError: {e}")
22+
23+
24+
def test_ai_search_response():
25+
try:
26+
result = jigsaw.web.search({"query": "Where is San Francisco"})
27+
assert result["success"] == True
28+
except JigsawStackError as e:
29+
pytest.fail(f"Unexpected JigsawStackError: {e}")
30+
31+
32+
def test_search_suggestion_response_async():
1733
async def _test():
1834
client = jigsawstack.AsyncJigsawStack()
1935
try:
@@ -25,7 +41,7 @@ async def _test():
2541
asyncio.run(_test())
2642

2743

28-
def test_ai_search_response():
44+
def test_ai_search_response_async():
2945
async def _test():
3046
client = jigsawstack.AsyncJigsawStack()
3147
try:

0 commit comments

Comments
 (0)