Skip to content

Commit fcab25e

Browse files
he-jamesAssemblyAI
andauthored
chore: sync sdk code with DeepLearning repo (#129)
Co-authored-by: AssemblyAI <engineering.sdk@assemblyai.com>
1 parent 4677079 commit fcab25e

File tree

4 files changed

+165
-5
lines changed

4 files changed

+165
-5
lines changed

assemblyai/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
Sentiment,
5454
SentimentType,
5555
Settings,
56+
SpeakerOptions,
5657
SpeechModel,
5758
StatusResult,
5859
SummarizationModel,
@@ -114,6 +115,7 @@
114115
"Sentiment",
115116
"SentimentType",
116117
"Settings",
118+
"SpeakerOptions",
117119
"SpeechModel",
118120
"StatusResult",
119121
"SummarizationModel",

assemblyai/__version__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.41.3"
1+
__version__ = "0.41.4"

assemblyai/types.py

Lines changed: 66 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -489,6 +489,43 @@ class SpeechModel(str, Enum):
489489
"The model optimized for accuracy, low latency, ease of use, and multi-language support"
490490

491491

492+
class SpeakerOptions(BaseModel):
493+
"""
494+
Speaker options for controlling speaker diarization parameters
495+
"""
496+
497+
min_speakers_expected: Optional[int] = Field(
498+
None, ge=1, description="Minimum number of speakers expected in the audio"
499+
)
500+
max_speakers_expected: Optional[int] = Field(
501+
None, ge=1, description="Maximum number of speakers expected in the audio"
502+
)
503+
504+
if pydantic_v2:
505+
506+
@field_validator("max_speakers_expected")
507+
@classmethod
508+
def validate_max_speakers(cls, v, info):
509+
if v is not None and info.data.get("min_speakers_expected") is not None:
510+
min_speakers = info.data["min_speakers_expected"]
511+
if v < min_speakers:
512+
raise ValueError(
513+
"max_speakers_expected must be greater than or equal to min_speakers_expected"
514+
)
515+
return v
516+
else:
517+
518+
@validator("max_speakers_expected")
519+
def validate_max_speakers(cls, v, values):
520+
if v is not None and values.get("min_speakers_expected") is not None:
521+
min_speakers = values["min_speakers_expected"]
522+
if v < min_speakers:
523+
raise ValueError(
524+
"max_speakers_expected must be greater than or equal to min_speakers_expected"
525+
)
526+
return v
527+
528+
492529
class RawTranscriptionConfig(BaseModel):
493530
language_code: Optional[Union[str, LanguageCode]] = None
494531
"""
@@ -546,6 +583,9 @@ class RawTranscriptionConfig(BaseModel):
546583
speakers_expected: Optional[int] = None
547584
"The number of speakers you expect to be in your audio file."
548585

586+
speaker_options: Optional[SpeakerOptions] = None
587+
"Advanced options for controlling speaker diarization parameters."
588+
549589
content_safety: Optional[bool] = None
550590
"Enable Content Safety Detection."
551591

@@ -633,6 +673,7 @@ def __init__(
633673
redact_pii_sub: Optional[PIISubstitutionPolicy] = None,
634674
speaker_labels: Optional[bool] = None,
635675
speakers_expected: Optional[int] = None,
676+
speaker_options: Optional[SpeakerOptions] = None,
636677
content_safety: Optional[bool] = None,
637678
content_safety_confidence: Optional[int] = None,
638679
iab_categories: Optional[bool] = None,
@@ -675,6 +716,7 @@ def __init__(
675716
redact_pii_sub: The replacement logic for detected PII.
676717
speaker_labels: Enable Speaker Diarization.
677718
speakers_expected: The number of speakers you expect to hear in your audio file. Up to 10 speakers are supported.
719+
speaker_options: Advanced options for controlling speaker diarization parameters, including min and max speakers expected.
678720
content_safety: Enable Content Safety Detection.
679721
iab_categories: Enable Topic Detection.
680722
custom_spelling: Customize how words are spelled and formatted using to and from values.
@@ -722,7 +764,7 @@ def __init__(
722764
redact_pii_policies,
723765
redact_pii_sub,
724766
)
725-
self.set_speaker_diarization(speaker_labels, speakers_expected)
767+
self.set_speaker_diarization(speaker_labels, speakers_expected, speaker_options)
726768
self.set_content_safety(content_safety, content_safety_confidence)
727769
self.iab_categories = iab_categories
728770
self.set_custom_spelling(custom_spelling, override=True)
@@ -934,6 +976,12 @@ def speakers_expected(self) -> Optional[int]:
934976

935977
return self._raw_transcription_config.speakers_expected
936978

979+
@property
980+
def speaker_options(self) -> Optional[SpeakerOptions]:
981+
"Returns the advanced speaker diarization options."
982+
983+
return self._raw_transcription_config.speaker_options
984+
937985
@property
938986
def content_safety(self) -> Optional[bool]:
939987
"Returns the status of the Content Safety feature."
@@ -1162,21 +1210,32 @@ def set_speaker_diarization(
11621210
self,
11631211
enable: Optional[bool] = True,
11641212
speakers_expected: Optional[int] = None,
1213+
speaker_options: Optional[SpeakerOptions] = None,
11651214
) -> Self:
11661215
"""
11671216
Whether to enable Speaker Diarization on the transcript.
11681217
11691218
Args:
11701219
`enable`: Enable Speaker Diarization
11711220
`speakers_expected`: The number of speakers in the audio file.
1221+
`speaker_options`: Advanced options for controlling speaker diarization parameters.
11721222
"""
11731223

1174-
if not enable:
1224+
# If enable is explicitly False, clear all speaker settings
1225+
if enable is False:
11751226
self._raw_transcription_config.speaker_labels = None
11761227
self._raw_transcription_config.speakers_expected = None
1228+
self._raw_transcription_config.speaker_options = None
1229+
# If enable is True or None, set the values (allow setting speaker_options even when enable is None)
11771230
else:
1178-
self._raw_transcription_config.speaker_labels = True
1179-
self._raw_transcription_config.speakers_expected = speakers_expected
1231+
# Only set speaker_labels to True if enable is explicitly True
1232+
if enable is True:
1233+
self._raw_transcription_config.speaker_labels = True
1234+
# Always set these if provided, regardless of enable value
1235+
if speakers_expected is not None:
1236+
self._raw_transcription_config.speakers_expected = speakers_expected
1237+
if speaker_options is not None:
1238+
self._raw_transcription_config.speaker_options = speaker_options
11801239

11811240
return self
11821241

@@ -1712,6 +1771,9 @@ class BaseTranscript(BaseModel):
17121771
speakers_expected: Optional[int] = None
17131772
"The number of speakers you expect to be in your audio file."
17141773

1774+
speaker_options: Optional[SpeakerOptions] = None
1775+
"Advanced options for controlling speaker diarization parameters."
1776+
17151777
content_safety: Optional[bool] = None
17161778
"Enable Content Safety Detection."
17171779

tests/unit/test_speaker_options.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
import pytest
2+
3+
import assemblyai as aai
4+
5+
6+
def test_speaker_options_creation():
7+
"""Test that SpeakerOptions can be created with valid parameters."""
8+
speaker_options = aai.SpeakerOptions(
9+
min_speakers_expected=2, max_speakers_expected=5
10+
)
11+
assert speaker_options.min_speakers_expected == 2
12+
assert speaker_options.max_speakers_expected == 5
13+
14+
15+
def test_speaker_options_validation():
16+
"""Test that SpeakerOptions validates max >= min."""
17+
with pytest.raises(
18+
ValueError,
19+
match="max_speakers_expected must be greater than or equal to min_speakers_expected",
20+
):
21+
aai.SpeakerOptions(min_speakers_expected=5, max_speakers_expected=2)
22+
23+
24+
def test_speaker_options_min_only():
25+
"""Test that SpeakerOptions can be created with only min_speakers_expected."""
26+
speaker_options = aai.SpeakerOptions(min_speakers_expected=3)
27+
assert speaker_options.min_speakers_expected == 3
28+
assert speaker_options.max_speakers_expected is None
29+
30+
31+
def test_speaker_options_max_only():
32+
"""Test that SpeakerOptions can be created with only max_speakers_expected."""
33+
speaker_options = aai.SpeakerOptions(max_speakers_expected=5)
34+
assert speaker_options.min_speakers_expected is None
35+
assert speaker_options.max_speakers_expected == 5
36+
37+
38+
def test_transcription_config_with_speaker_options():
39+
"""Test that TranscriptionConfig accepts speaker_options parameter."""
40+
speaker_options = aai.SpeakerOptions(
41+
min_speakers_expected=2, max_speakers_expected=4
42+
)
43+
44+
config = aai.TranscriptionConfig(
45+
speaker_labels=True, speaker_options=speaker_options
46+
)
47+
48+
assert config.speaker_labels is True
49+
assert config.speaker_options == speaker_options
50+
assert config.speaker_options.min_speakers_expected == 2
51+
assert config.speaker_options.max_speakers_expected == 4
52+
53+
54+
def test_set_speaker_diarization_with_speaker_options():
55+
"""Test setting speaker diarization with speaker_options."""
56+
speaker_options = aai.SpeakerOptions(
57+
min_speakers_expected=1, max_speakers_expected=3
58+
)
59+
60+
config = aai.TranscriptionConfig()
61+
config.set_speaker_diarization(
62+
enable=True, speakers_expected=2, speaker_options=speaker_options
63+
)
64+
65+
assert config.speaker_labels is True
66+
assert config.speakers_expected == 2
67+
assert config.speaker_options == speaker_options
68+
69+
70+
def test_set_speaker_diarization_disable_clears_speaker_options():
71+
"""Test that disabling speaker diarization clears speaker_options."""
72+
speaker_options = aai.SpeakerOptions(min_speakers_expected=2)
73+
74+
config = aai.TranscriptionConfig()
75+
config.set_speaker_diarization(enable=True, speaker_options=speaker_options)
76+
77+
# Verify it was set
78+
assert config.speaker_options == speaker_options
79+
80+
# Now disable
81+
config.set_speaker_diarization(enable=False)
82+
83+
assert config.speaker_labels is None
84+
assert config.speakers_expected is None
85+
assert config.speaker_options is None
86+
87+
88+
def test_speaker_options_in_raw_config():
89+
"""Test that speaker_options is properly set in the raw config."""
90+
speaker_options = aai.SpeakerOptions(
91+
min_speakers_expected=2, max_speakers_expected=5
92+
)
93+
94+
config = aai.TranscriptionConfig(speaker_options=speaker_options)
95+
96+
assert config.raw.speaker_options == speaker_options

0 commit comments

Comments
 (0)