Skip to content

Commit 613360d

Browse files
dmccrystals0h3yl
authored andcommitted
fix: allow import of base sdk without extras installed
See issue: #23 GitOrigin-RevId: 485ef59591db8f79ec82bcda63742a8510447bb7
1 parent 2e990ee commit 613360d

File tree

3 files changed

+81
-10
lines changed

3 files changed

+81
-10
lines changed

assemblyai/extras.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,20 @@
11
import time
22
from typing import Generator
33

4-
try:
5-
import pyaudio
6-
except ImportError:
7-
raise ImportError(
8-
"You must install the extras for this SDK to use this feature. "
9-
"Run `pip install assemblyai[extras]` to install the extras. "
10-
"Make sure to install `apt install portaudio19-dev` (Debian/Ubuntu) or "
11-
"`brew install portaudio` (MacOS) before installing the extras."
12-
)
4+
5+
class AssemblyAIExtrasNotInstalledError(ImportError):
6+
def __init__(
7+
self,
8+
msg="""
9+
You must install the extras for this SDK to use this feature.
10+
Run `pip install assemblyai[extras]` to install the extras.
11+
Make sure to install `apt install portaudio19-dev` (Debian/Ubuntu) or
12+
`brew install portaudio` (MacOS) before installing the extras
13+
""",
14+
*args,
15+
**kwargs,
16+
):
17+
super().__init__(msg, *args, **kwargs)
1318

1419

1520
class MicrophoneStream:
@@ -25,6 +30,10 @@ def __init__(
2530
channels: The number of channels to record audio from.
2631
sample_rate: The sample rate to record audio at.
2732
"""
33+
try:
34+
import pyaudio
35+
except ImportError:
36+
raise AssemblyAIExtrasNotInstalledError
2837

2938
self._pyaudio = pyaudio.PyAudio()
3039
self.sample_rate = sample_rate

tests/unit/test_imports.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
import sys
2+
from importlib import reload
3+
4+
import pytest
5+
import pytest_mock
6+
7+
import assemblyai as aai
8+
9+
10+
class ImportFailureMocker:
11+
def __init__(self, module: str):
12+
self.module = module
13+
14+
def find_spec(self, fullname, path, target=None):
15+
if fullname == self.module:
16+
raise ImportError
17+
18+
def __enter__(self):
19+
# Remove module if already imported
20+
if self.module in sys.modules:
21+
del sys.modules[self.module]
22+
23+
# Add self as first importer
24+
sys.meta_path.insert(0, self)
25+
return self
26+
27+
def __exit__(self, type, value, traceback):
28+
# Remove self as importer
29+
sys.meta_path.pop(0)
30+
31+
32+
def __reload_assesmblyai_module():
33+
reload(aai)
34+
aai.settings.api_key = "test"
35+
36+
37+
def test_import_sdk_without_extras_installed():
38+
with ImportFailureMocker("pyaudio"):
39+
__reload_assesmblyai_module()
40+
# Test succeeds if no failures
41+
42+
43+
def test_import_sdk_and_use_extras_without_extras_installed():
44+
with ImportFailureMocker("pyaudio"):
45+
__reload_assesmblyai_module()
46+
47+
with pytest.raises(aai.extras.AssemblyAIExtrasNotInstalledError):
48+
aai.extras.MicrophoneStream()
49+
50+
51+
def test_import_sdk_and_use_extras_with_extras_installed(
52+
mocker: pytest_mock.MockerFixture,
53+
):
54+
import pyaudio
55+
56+
__reload_assesmblyai_module()
57+
58+
mocker.patch.object(pyaudio.PyAudio, "open", return_value=None)
59+
aai.extras.MicrophoneStream()
60+
61+
# Test succeeds if no failures

tests/unit/test_realtime_transcriber.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
import assemblyai as aai
1414

1515
aai.settings.api_key = "test"
16-
aai.settings.base_url = "https://api.assemblyai.com/v2"
1716

1817

1918
def _disable_rw_threads(mocker: MockFixture):
@@ -28,6 +27,8 @@ def test_realtime_connect_has_parameters(mocker: MockFixture):
2827
"""
2928
Test that the connect method has the correct parameters set
3029
"""
30+
aai.settings.base_url = "https://api.assemblyai.com/v2"
31+
3132
actual_url = None
3233
actual_additional_headers = None
3334
actual_open_timeout = None

0 commit comments

Comments
 (0)