44# Copyright (c) 2024 Oracle and/or its affiliates.
55# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
66
7+ import os
78import logging
89import subprocess
910from unittest import TestCase
1011from unittest .mock import patch
11-
12+ from importlib import reload
1213from parameterized import parameterized
1314
15+ import ads .aqua
16+ import ads .config
1417from ads .aqua .cli import AquaCommand
1518
1619
1720class TestAquaCLI (TestCase ):
1821 """Tests the AQUA CLI."""
1922
20- DEFAUL_AQUA_CLI_LOGGING_LEVEL = "ERROR"
23+ DEFAULT_AQUA_CLI_LOGGING_LEVEL = "ERROR"
2124 logger = logging .getLogger (__name__ )
2225 logging .basicConfig (
2326 format = "%(asctime)s %(module)s %(levelname)s: %(message)s" ,
2427 datefmt = "%m/%d/%Y %I:%M:%S %p" ,
2528 level = logging .INFO ,
2629 )
30+ SERVICE_COMPARTMENT_ID = "ocid1.compartment.oc1..<OCID>"
2731
2832 def test_entrypoint (self ):
2933 """Tests CLI entrypoint."""
@@ -33,15 +37,55 @@ def test_entrypoint(self):
3337
3438 @parameterized .expand (
3539 [
36- ("default" , None , DEFAUL_AQUA_CLI_LOGGING_LEVEL ),
40+ ("default" , None , DEFAULT_AQUA_CLI_LOGGING_LEVEL ),
3741 ("set logging level" , "info" , "info" ),
3842 ]
3943 )
40- @patch ("ads.aqua.cli.set_log_level" )
41- def test_aquacommand (self , name , arg , expected , mock_setting_log ):
42- """Tests aqua command initailzation."""
43- if arg :
44- AquaCommand (arg )
45- else :
46- AquaCommand ()
47- mock_setting_log .assert_called_with (expected )
44+ def test_aquacommand (self , name , arg , expected ):
45+ """Tests aqua command initialization."""
46+ with patch .dict (
47+ os .environ ,
48+ {"ODSC_MODEL_COMPARTMENT_OCID" : TestAquaCLI .SERVICE_COMPARTMENT_ID },
49+ ):
50+ reload (ads .config )
51+ reload (ads .aqua )
52+ reload (ads .aqua .cli )
53+ with patch ("ads.aqua.cli.set_log_level" ) as mock_setting_log :
54+ if arg :
55+ AquaCommand (arg )
56+ else :
57+ AquaCommand ()
58+ mock_setting_log .assert_called_with (expected )
59+
60+ @parameterized .expand (
61+ [
62+ ("default" , None ),
63+ ("using jupyter instance" , "nb-session-ocid" ),
64+ ]
65+ )
66+ def test_aqua_command_without_compartment_env_var (self , name , session_ocid ):
67+ """Test whether exit is called when ODSC_MODEL_COMPARTMENT_OCID is not set. Also check if NB_SESSION_OCID is
68+ set then log the appropriate message."""
69+
70+ with patch ("sys.exit" ) as mock_exit :
71+ env_dict = {"ODSC_MODEL_COMPARTMENT_OCID" : "" }
72+ if session_ocid :
73+ env_dict .update ({"NB_SESSION_OCID" : session_ocid })
74+ with patch .dict (os .environ , env_dict ):
75+ reload (ads .config )
76+ reload (ads .aqua )
77+ reload (ads .aqua .cli )
78+ with patch ("ads.aqua.cli.set_log_level" ) as mock_setting_log :
79+ with patch ("ads.aqua.logger.error" ) as mock_logger_error :
80+ AquaCommand ()
81+ mock_setting_log .assert_called_with (
82+ TestAquaCLI .DEFAULT_AQUA_CLI_LOGGING_LEVEL
83+ )
84+ mock_logger_error .assert_any_call (
85+ "ODSC_MODEL_COMPARTMENT_OCID environment variable is not set for Aqua."
86+ )
87+ if session_ocid :
88+ mock_logger_error .assert_any_call (
89+ f"Aqua is not available for the notebook session { session_ocid } ."
90+ )
91+ mock_exit .assert_called_with (1 )
0 commit comments