Skip to content

Commit 82fa060

Browse files
authored
Merge pull request #60 from JigsawStack/feat/embeddingV2
Feat/embedding v2
2 parents 65a1ab7 + a022213 commit 82fa060

File tree

4 files changed

+178
-32
lines changed

4 files changed

+178
-32
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,5 @@ pyproject.toml
3030
uv.lock
3131

3232
.ruff_cache/
33-
local_tests/
33+
local_tests/*
34+

jigsawstack/__init__.py

Lines changed: 42 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from .exceptions import JigsawStackError
1616
from .image_generation import ImageGeneration, AsyncImageGeneration
1717
from .classification import Classification, AsyncClassification
18+
from .embeddingV2 import EmbeddingV2, AsyncEmbeddingV2
1819

1920

2021
class JigsawStack:
@@ -48,7 +49,7 @@ def __init__(
4849
if api_url is None:
4950
api_url = os.environ.get("JIGSAWSTACK_API_URL")
5051
if api_url is None:
51-
api_url = f"https://api.jigsawstack.com/v1"
52+
api_url = f"https://api.jigsawstack.com/"
5253

5354
self.api_key = api_key
5455
self.api_url = api_url
@@ -59,73 +60,79 @@ def __init__(
5960

6061
self.audio = Audio(
6162
api_key=api_key,
62-
api_url=api_url,
63+
api_url=api_url + "/v1",
6364
disable_request_logging=disable_request_logging,
6465
)
6566
self.web = Web(
6667
api_key=api_key,
67-
api_url=api_url,
68+
api_url=api_url + "/v1",
6869
disable_request_logging=disable_request_logging,
6970
)
7071
self.sentiment = Sentiment(
7172
api_key=api_key,
72-
api_url=api_url,
73+
api_url=api_url + "/v1",
7374
disable_request_logging=disable_request_logging,
7475
).analyze
7576
self.validate = Validate(
7677
api_key=api_key,
77-
api_url=api_url,
78+
api_url=api_url + "/v1",
7879
disable_request_logging=disable_request_logging,
7980
)
8081
self.summary = Summary(
8182
api_key=api_key,
82-
api_url=api_url,
83+
api_url=api_url + "/v1",
8384
disable_request_logging=disable_request_logging,
8485
).summarize
8586
self.vision = Vision(
8687
api_key=api_key,
87-
api_url=api_url,
88+
api_url=api_url + "/v1",
8889
disable_request_logging=disable_request_logging,
8990
)
9091
self.prediction = Prediction(
9192
api_key=api_key,
92-
api_url=api_url,
93+
api_url=api_url + "/v1",
9394
disable_request_logging=disable_request_logging,
9495
).predict
9596
self.text_to_sql = SQL(
9697
api_key=api_key,
97-
api_url=api_url,
98+
api_url=api_url + "/v1",
9899
disable_request_logging=disable_request_logging,
99100
).text_to_sql
100101
self.store = Store(
101102
api_key=api_key,
102-
api_url=api_url,
103+
api_url=api_url + "/v1",
103104
disable_request_logging=disable_request_logging,
104105
)
105106
self.translate = Translate(
106107
api_key=api_key,
107-
api_url=api_url,
108+
api_url=api_url + "/v1",
108109
disable_request_logging=disable_request_logging,
109110
)
110111

111112
self.embedding = Embedding(
112113
api_key=api_key,
113-
api_url=api_url,
114+
api_url=api_url + "/v1",
114115
disable_request_logging=disable_request_logging,
115116
).execute
117+
118+
self.embeddingV2 = EmbeddingV2(
119+
api_key=api_key,
120+
api_url=api_url + "/v2",
121+
disable_request_logging=disable_request_logging,
122+
).execute
123+
116124
self.image_generation = ImageGeneration(
117125
api_key=api_key,
118-
api_url=api_url,
126+
api_url=api_url + "/v1",
119127
disable_request_logging=disable_request_logging,
120128
).image_generation
121129

122130
self.classification = Classification(
123131
api_key=api_key,
124-
api_url=api_url,
132+
api_url=api_url + "/v1",
125133
disable_request_logging=disable_request_logging,
126134
).classify
127135

128-
129136
class AsyncJigsawStack:
130137
validate: AsyncValidate
131138
web: AsyncWeb
@@ -154,87 +161,92 @@ def __init__(
154161
if api_url is None:
155162
api_url = os.environ.get("JIGSAWSTACK_API_URL")
156163
if api_url is None:
157-
api_url = f"https://api.jigsawstack.com/v1"
164+
api_url = f"https://api.jigsawstack.com/"
158165

159166
self.api_key = api_key
160167
self.api_url = api_url
161168

162169
self.web = AsyncWeb(
163170
api_key=api_key,
164-
api_url=api_url,
171+
api_url=api_url + "/v1",
165172
disable_request_logging=disable_request_logging,
166173
)
167174

168175
self.validate = AsyncValidate(
169176
api_key=api_key,
170-
api_url=api_url,
177+
api_url=api_url + "/v1",
171178
disable_request_logging=disable_request_logging,
172179
)
173180
self.audio = AsyncAudio(
174181
api_key=api_key,
175-
api_url=api_url,
182+
api_url=api_url + "/v1",
176183
disable_request_logging=disable_request_logging,
177184
)
178185

179186
self.vision = AsyncVision(
180187
api_key=api_key,
181-
api_url=api_url,
188+
api_url=api_url + "/v1",
182189
disable_request_logging=disable_request_logging,
183190
)
184191

185192
self.store = AsyncStore(
186193
api_key=api_key,
187-
api_url=api_url,
194+
api_url=api_url + "/v1",
188195
disable_request_logging=disable_request_logging,
189196
)
190197

191198
self.summary = AsyncSummary(
192199
api_key=api_key,
193-
api_url=api_url,
200+
api_url=api_url + "/v1",
194201
disable_request_logging=disable_request_logging,
195202
).summarize
196203

197204
self.prediction = AsyncPrediction(
198205
api_key=api_key,
199-
api_url=api_url,
206+
api_url=api_url + "/v1",
200207
disable_request_logging=disable_request_logging,
201208
).predict
202209
self.text_to_sql = AsyncSQL(
203210
api_key=api_key,
204-
api_url=api_url,
211+
api_url=api_url + "/v1",
205212
disable_request_logging=disable_request_logging,
206213
).text_to_sql
207214

208215
self.sentiment = AsyncSentiment(
209216
api_key=api_key,
210-
api_url=api_url,
217+
api_url=api_url + "/v1",
211218
disable_request_logging=disable_request_logging,
212219
).analyze
213220

214221
self.translate = AsyncTranslate(
215222
api_key=api_key,
216-
api_url=api_url,
223+
api_url=api_url + "/v1",
217224
disable_request_logging=disable_request_logging,
218225
)
219226

220227
self.embedding = AsyncEmbedding(
221228
api_key=api_key,
222-
api_url=api_url,
229+
api_url=api_url + "/v1",
230+
disable_request_logging=disable_request_logging,
231+
).execute
232+
233+
self.embeddingV2 = AsyncEmbeddingV2(
234+
api_key=api_key,
235+
api_url=api_url + "/v2",
223236
disable_request_logging=disable_request_logging,
224237
).execute
225238

226239
self.image_generation = AsyncImageGeneration(
227240
api_key=api_key,
228-
api_url=api_url,
241+
api_url=api_url + "/v1",
229242
disable_request_logging=disable_request_logging,
230243
).image_generation
231244

232245
self.classification = AsyncClassification(
233246
api_key=api_key,
234-
api_url=api_url,
247+
api_url=api_url + "/v1",
235248
disable_request_logging=disable_request_logging,
236249
).classify
237250

238-
239251
# Create a global instance of the Web class
240252
__all__ = ["JigsawStack", "Search", "JigsawStackError", "AsyncJigsawStack"]

jigsawstack/embedding.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ class Chunk(TypedDict):
2424

2525
class EmbeddingResponse(BaseResponse):
2626
embeddings: List[List[float]]
27-
chunks: List[Chunk]
27+
chunks: Union[List[Chunk], List[str]]
2828

2929

3030
class Embedding(ClientConfig):

jigsawstack/embeddingV2.py

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
from typing import Any, Dict, List, Union, cast, Literal, overload
2+
from typing_extensions import NotRequired, TypedDict
3+
from .request import Request, RequestConfig
4+
from .async_request import AsyncRequest
5+
from typing import List, Union
6+
from ._config import ClientConfig
7+
from .helpers import build_path
8+
from .embedding import Chunk
9+
10+
11+
class EmbeddingV2Params(TypedDict):
12+
text: NotRequired[str]
13+
file_content: NotRequired[Any]
14+
type: Literal["text", "text-other", "image", "audio", "pdf"]
15+
url: NotRequired[str]
16+
file_store_key: NotRequired[str]
17+
token_overflow_mode: NotRequired[Literal["truncate", "chunk", "error"]] = "chunk"
18+
speaker_fingerprint: NotRequired[bool]
19+
20+
21+
class EmbeddingV2Response(TypedDict):
22+
success: bool
23+
embeddings: List[List[float]]
24+
chunks: Union[List[str], List[Chunk]]
25+
speaker_embeddings: List[List[float]]
26+
27+
28+
class EmbeddingV2(ClientConfig):
29+
config: RequestConfig
30+
31+
def __init__(
32+
self,
33+
api_key: str,
34+
api_url: str,
35+
disable_request_logging: Union[bool, None] = False,
36+
):
37+
super().__init__(api_key, api_url, disable_request_logging)
38+
self.config = RequestConfig(
39+
api_url=api_url,
40+
api_key=api_key,
41+
disable_request_logging=disable_request_logging,
42+
)
43+
44+
@overload
45+
def execute(self, params: EmbeddingV2Params) -> EmbeddingV2Response: ...
46+
@overload
47+
def execute(
48+
self, blob: bytes, options: EmbeddingV2Params = None
49+
) -> EmbeddingV2Response: ...
50+
51+
def execute(
52+
self,
53+
blob: Union[EmbeddingV2Params, bytes],
54+
options: EmbeddingV2Params = None,
55+
) -> EmbeddingV2Response:
56+
path = "/embedding"
57+
if isinstance(blob, dict):
58+
resp = Request(
59+
config=self.config,
60+
path=path,
61+
params=cast(Dict[Any, Any], blob),
62+
verb="post",
63+
).perform_with_content()
64+
return resp
65+
66+
options = options or {}
67+
path = build_path(base_path=path, params=options)
68+
content_type = options.get("content_type", "application/octet-stream")
69+
_headers = {"Content-Type": content_type}
70+
71+
resp = Request(
72+
config=self.config,
73+
path=path,
74+
params=options,
75+
data=blob,
76+
headers=_headers,
77+
verb="post",
78+
).perform_with_content()
79+
return resp
80+
81+
82+
class AsyncEmbeddingV2(ClientConfig):
83+
config: RequestConfig
84+
85+
def __init__(
86+
self,
87+
api_key: str,
88+
api_url: str,
89+
disable_request_logging: Union[bool, None] = False,
90+
):
91+
super().__init__(api_key, api_url, disable_request_logging)
92+
self.config = RequestConfig(
93+
api_url=api_url,
94+
api_key=api_key,
95+
disable_request_logging=disable_request_logging,
96+
)
97+
98+
@overload
99+
async def execute(self, params: EmbeddingV2Params) -> EmbeddingV2Response: ...
100+
@overload
101+
async def execute(
102+
self, blob: bytes, options: EmbeddingV2Params = None
103+
) -> EmbeddingV2Response: ...
104+
105+
async def execute(
106+
self,
107+
blob: Union[EmbeddingV2Params, bytes],
108+
options: EmbeddingV2Params = None,
109+
) -> EmbeddingV2Response:
110+
path = "/embedding"
111+
if isinstance(blob, dict):
112+
resp = await AsyncRequest(
113+
config=self.config,
114+
path=path,
115+
params=cast(Dict[Any, Any], blob),
116+
verb="post",
117+
).perform_with_content()
118+
return resp
119+
120+
options = options or {}
121+
path = build_path(base_path=path, params=options)
122+
content_type = options.get("content_type", "application/octet-stream")
123+
_headers = {"Content-Type": content_type}
124+
125+
resp = await AsyncRequest(
126+
config=self.config,
127+
path=path,
128+
params=options,
129+
data=blob,
130+
headers=_headers,
131+
verb="post",
132+
).perform_with_content()
133+
return resp

0 commit comments

Comments
 (0)