diff --git a/build.sh b/build.sh index fc85edff..5db0219b 100755 --- a/build.sh +++ b/build.sh @@ -28,6 +28,7 @@ if [[ ! -d venv ]] ; then echo "./venv not found. Setting up venv" python3 -m venv "$PWD/venv" fi + source "$PWD/venv/bin/activate" if which pip3 ; then @@ -135,6 +136,10 @@ function write_e2e_env(){ val=$(gcloud secrets versions access latest --project "$TEST_PROJECT" --secret="$secret_name") echo "export $env_var_name='$val'" done + # Aliases for python e2e tests + echo "export POSTGRES_CUSTOMER_CAS_PASS_VALID_DOMAIN_NAME=\"\$POSTGRES_CUSTOMER_CAS_DOMAIN_NAME\"" + echo "export POSTGRES_IAM_USER=\"\$POSTGRES_USER_IAM\"" + echo "export MYSQL_IAM_USER=\"\$MYSQL_USER_IAM\"" } > "$1" } diff --git a/google/cloud/sql/connector/connector.py b/google/cloud/sql/connector/connector.py index ceff31ac..798969c2 100755 --- a/google/cloud/sql/connector/connector.py +++ b/google/cloud/sql/connector/connector.py @@ -390,6 +390,33 @@ async def connect_async( # the cache and re-raise the error await self._remove_cached(str(conn_name), enable_iam_auth) raise + + # If the connector is configured with a custom DNS name, attempt to use + # that DNS name to connect to the instance. Fall back to the metadata IP + # address if the DNS name does not resolve to an IP address. + if conn_info.conn_name.domain_name and isinstance(self._resolver, DnsResolver): + try: + ips = await self._resolver.resolve_a_record(conn_info.conn_name.domain_name) + if ips: + ip_address = ips[0] + logger.debug( + f"['{instance_connection_string}']: Custom DNS name " + f"'{conn_info.conn_name.domain_name}' resolved to '{ip_address}', " + "using it to connect" + ) + else: + logger.debug( + f"['{instance_connection_string}']: Custom DNS name " + f"'{conn_info.conn_name.domain_name}' resolved but returned no " + f"entries, using '{ip_address}' from instance metadata" + ) + except Exception as e: + logger.debug( + f"['{instance_connection_string}']: Custom DNS name " + f"'{conn_info.conn_name.domain_name}' did not resolve to an IP " + f"address: {e}, using '{ip_address}' from instance metadata" + ) + logger.debug(f"['{conn_info.conn_name}']: Connecting to {ip_address}:3307") # format `user` param for automatic IAM database authn if enable_iam_auth: diff --git a/google/cloud/sql/connector/resolver.py b/google/cloud/sql/connector/resolver.py index 7143c104..e255f328 100644 --- a/google/cloud/sql/connector/resolver.py +++ b/google/cloud/sql/connector/resolver.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import List + import dns.asyncresolver from google.cloud.sql.connector.connection_name import _is_valid_domain @@ -53,6 +55,16 @@ async def resolve(self, dns: str) -> ConnectionName: # type: ignore ) return conn_name + async def resolve_a_record(self, dns: str) -> List[str]: + try: + # Attempt to query the A records. + records = await super().resolve(dns, "A", raise_on_no_answer=True) + # return IP addresses as strings + return [record.to_text() for record in records] + except Exception: + # On any error, return empty list + return [] + async def query_dns(self, dns: str) -> ConnectionName: try: # Attempt to query the TXT records. diff --git a/tests/unit/test_connector.py b/tests/unit/test_connector.py index 95766e14..a09b5b72 100644 --- a/tests/unit/test_connector.py +++ b/tests/unit/test_connector.py @@ -34,6 +34,7 @@ from google.cloud.sql.connector.exceptions import ConnectorLoopError from google.cloud.sql.connector.exceptions import IncompatibleDriverError from google.cloud.sql.connector.instance import RefreshAheadCache +from google.cloud.sql.connector.resolver import DnsResolver @pytest.mark.asyncio @@ -548,3 +549,113 @@ def test_connect_closed_connector( exc_info.value.args[0] == "Connection attempt failed because the connector has already been closed." ) + + +@pytest.mark.asyncio +async def test_Connector_connect_async_custom_dns_resolver( + fake_credentials: Credentials, fake_client: CloudSQLClient +) -> None: + """Test that Connector.connect_async uses custom DNS name resolution.""" + + # Create a mock DnsResolver that returns a fixed IP + with patch( + "google.cloud.sql.connector.resolver.DnsResolver.resolve_a_record" + ) as mock_resolve_a: + mock_resolve_a.return_value = ["1.2.3.4"] + + # We also need to patch resolve because DnsResolver.resolve does DNS lookup for TXT + # But we can patch DnsResolver.resolve to return a ConnectionName with domain name + with patch( + "google.cloud.sql.connector.resolver.DnsResolver.resolve" + ) as mock_resolve: + # This must return a ConnectionName object with domain_name set + conn_name_with_domain = ConnectionName( + "test-project", "test-region", "test-instance", "db.example.com" + ) + mock_resolve.return_value = conn_name_with_domain + + async with Connector( + credentials=fake_credentials, + loop=asyncio.get_running_loop(), + resolver=DnsResolver, + ) as connector: + connector._client = fake_client + + # patch db connection creation + with patch( + "google.cloud.sql.connector.asyncpg.connect" + ) as mock_connect: + mock_connect.return_value = True + + # Call connect_async + # Use "db.example.com" as instance connection string (resolver will handle it) + connection = await connector.connect_async( + "db.example.com", + "asyncpg", + user="my-user", + password="my-pass", + db="my-db", + ) + + # Verify mock_connect was called with resolved IP "1.2.3.4" + # The first arg to mock_connect (which patches connector call) is ip_address + args, _ = mock_connect.call_args + assert args[0] == "1.2.3.4" + assert connection is True + + +@pytest.mark.asyncio +async def test_Connector_connect_async_custom_dns_resolver_fallback( + fake_credentials: Credentials, fake_client: CloudSQLClient +) -> None: + """Test that Connector.connect_async falls back if DNS resolution fails.""" + + # Create a mock DnsResolver that returns empty list (failure) + with patch( + "google.cloud.sql.connector.resolver.DnsResolver.resolve_a_record" + ) as mock_resolve_a: + mock_resolve_a.return_value = [] + + with patch( + "google.cloud.sql.connector.resolver.DnsResolver.resolve" + ) as mock_resolve: + conn_name_with_domain = ConnectionName( + "test-project", "test-region", "test-instance", "db.example.com" + ) + mock_resolve.return_value = conn_name_with_domain + + async with Connector( + credentials=fake_credentials, + loop=asyncio.get_running_loop(), + resolver=DnsResolver, + ) as connector: + connector._client = fake_client + + # Save original IPs to restore later (fake_instance is session-scoped) + original_ips = fake_client.instance.ip_addrs + # Set metadata IP to something specific + fake_client.instance.ip_addrs = {"PRIMARY": "5.6.7.8"} + + try: + with patch( + "google.cloud.sql.connector.asyncpg.connect" + ) as mock_connect: + mock_connect.return_value = True + + connection = await connector.connect_async( + "db.example.com", + "asyncpg", + user="my-user", + password="my-pass", + db="my-db", + ) + + # Verify mock_connect was called with metadata IP "5.6.7.8" + args, _ = mock_connect.call_args + assert args[0] == "5.6.7.8" + assert connection is True + finally: + # Restore original IPs + fake_client.instance.ip_addrs = original_ips + + diff --git a/tests/unit/test_resolver.py b/tests/unit/test_resolver.py index a9a7f263..c649e8e5 100644 --- a/tests/unit/test_resolver.py +++ b/tests/unit/test_resolver.py @@ -129,3 +129,40 @@ async def test_DnsResolver_with_bad_dns_name() -> None: with pytest.raises(DnsResolutionError) as exc_info: await resolver.resolve("bad.dns.com") assert exc_info.value.args[0] == "Unable to resolve TXT record for `bad.dns.com`" + + +a_record_query_text = """id 1234 +opcode QUERY +rcode NOERROR +flags QR AA RD RA +;QUESTION +db.example.com. IN A +;ANSWER +db.example.com. 0 IN A 127.0.0.1 +;AUTHORITY +;ADDITIONAL +""" + + +async def test_DnsResolver_resolve_a_record() -> None: + """Test DnsResolver resolves A record into IP address.""" + with patch("dns.asyncresolver.Resolver.resolve") as mock_resolve: + answer = dns.resolver.Answer( + "db.example.com", + dns.rdatatype.A, + dns.rdataclass.IN, + dns.message.from_text(a_record_query_text), + ) + mock_resolve.return_value = answer + resolver = DnsResolver() + result = await resolver.resolve_a_record("db.example.com") + assert result == ["127.0.0.1"] + + +async def test_DnsResolver_resolve_a_record_empty() -> None: + """Test DnsResolver resolves A record but gets error.""" + with patch("dns.asyncresolver.Resolver.resolve") as mock_resolve: + mock_resolve.side_effect = Exception("DNS Error") + resolver = DnsResolver() + result = await resolver.resolve_a_record("db.example.com") + assert result == [] \ No newline at end of file