From dcce02e64c976e0a9095e11a1ca9b62b3c7e7495 Mon Sep 17 00:00:00 2001 From: Jonathan Hess Date: Tue, 16 Dec 2025 20:56:55 -0700 Subject: [PATCH] feat: Use configured DNS name to lookup instance IP address When a custom DNS name is used to connect to a Cloud SQL instance, the dialer should first attempt to resolve the custom DNS name to an IP address and use that for the connection. If the lookup fails, the dialer should fall back to using the IP address from the instance metadata. Fixes #1362 --- build.sh | 5 ++ google/cloud/sql/connector/connector.py | 27 ++++++ google/cloud/sql/connector/resolver.py | 12 +++ tests/unit/test_connector.py | 111 ++++++++++++++++++++++++ tests/unit/test_resolver.py | 37 ++++++++ 5 files changed, 192 insertions(+) diff --git a/build.sh b/build.sh index fc85edff9..5db0219be 100755 --- a/build.sh +++ b/build.sh @@ -28,6 +28,7 @@ if [[ ! -d venv ]] ; then echo "./venv not found. Setting up venv" python3 -m venv "$PWD/venv" fi + source "$PWD/venv/bin/activate" if which pip3 ; then @@ -135,6 +136,10 @@ function write_e2e_env(){ val=$(gcloud secrets versions access latest --project "$TEST_PROJECT" --secret="$secret_name") echo "export $env_var_name='$val'" done + # Aliases for python e2e tests + echo "export POSTGRES_CUSTOMER_CAS_PASS_VALID_DOMAIN_NAME=\"\$POSTGRES_CUSTOMER_CAS_DOMAIN_NAME\"" + echo "export POSTGRES_IAM_USER=\"\$POSTGRES_USER_IAM\"" + echo "export MYSQL_IAM_USER=\"\$MYSQL_USER_IAM\"" } > "$1" } diff --git a/google/cloud/sql/connector/connector.py b/google/cloud/sql/connector/connector.py index ceff31ac2..798969c2c 100755 --- a/google/cloud/sql/connector/connector.py +++ b/google/cloud/sql/connector/connector.py @@ -390,6 +390,33 @@ async def connect_async( # the cache and re-raise the error await self._remove_cached(str(conn_name), enable_iam_auth) raise + + # If the connector is configured with a custom DNS name, attempt to use + # that DNS name to connect to the instance. Fall back to the metadata IP + # address if the DNS name does not resolve to an IP address. + if conn_info.conn_name.domain_name and isinstance(self._resolver, DnsResolver): + try: + ips = await self._resolver.resolve_a_record(conn_info.conn_name.domain_name) + if ips: + ip_address = ips[0] + logger.debug( + f"['{instance_connection_string}']: Custom DNS name " + f"'{conn_info.conn_name.domain_name}' resolved to '{ip_address}', " + "using it to connect" + ) + else: + logger.debug( + f"['{instance_connection_string}']: Custom DNS name " + f"'{conn_info.conn_name.domain_name}' resolved but returned no " + f"entries, using '{ip_address}' from instance metadata" + ) + except Exception as e: + logger.debug( + f"['{instance_connection_string}']: Custom DNS name " + f"'{conn_info.conn_name.domain_name}' did not resolve to an IP " + f"address: {e}, using '{ip_address}' from instance metadata" + ) + logger.debug(f"['{conn_info.conn_name}']: Connecting to {ip_address}:3307") # format `user` param for automatic IAM database authn if enable_iam_auth: diff --git a/google/cloud/sql/connector/resolver.py b/google/cloud/sql/connector/resolver.py index 7143c1047..e255f328a 100644 --- a/google/cloud/sql/connector/resolver.py +++ b/google/cloud/sql/connector/resolver.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import List + import dns.asyncresolver from google.cloud.sql.connector.connection_name import _is_valid_domain @@ -53,6 +55,16 @@ async def resolve(self, dns: str) -> ConnectionName: # type: ignore ) return conn_name + async def resolve_a_record(self, dns: str) -> List[str]: + try: + # Attempt to query the A records. + records = await super().resolve(dns, "A", raise_on_no_answer=True) + # return IP addresses as strings + return [record.to_text() for record in records] + except Exception: + # On any error, return empty list + return [] + async def query_dns(self, dns: str) -> ConnectionName: try: # Attempt to query the TXT records. diff --git a/tests/unit/test_connector.py b/tests/unit/test_connector.py index 95766e144..a09b5b72f 100644 --- a/tests/unit/test_connector.py +++ b/tests/unit/test_connector.py @@ -34,6 +34,7 @@ from google.cloud.sql.connector.exceptions import ConnectorLoopError from google.cloud.sql.connector.exceptions import IncompatibleDriverError from google.cloud.sql.connector.instance import RefreshAheadCache +from google.cloud.sql.connector.resolver import DnsResolver @pytest.mark.asyncio @@ -548,3 +549,113 @@ def test_connect_closed_connector( exc_info.value.args[0] == "Connection attempt failed because the connector has already been closed." ) + + +@pytest.mark.asyncio +async def test_Connector_connect_async_custom_dns_resolver( + fake_credentials: Credentials, fake_client: CloudSQLClient +) -> None: + """Test that Connector.connect_async uses custom DNS name resolution.""" + + # Create a mock DnsResolver that returns a fixed IP + with patch( + "google.cloud.sql.connector.resolver.DnsResolver.resolve_a_record" + ) as mock_resolve_a: + mock_resolve_a.return_value = ["1.2.3.4"] + + # We also need to patch resolve because DnsResolver.resolve does DNS lookup for TXT + # But we can patch DnsResolver.resolve to return a ConnectionName with domain name + with patch( + "google.cloud.sql.connector.resolver.DnsResolver.resolve" + ) as mock_resolve: + # This must return a ConnectionName object with domain_name set + conn_name_with_domain = ConnectionName( + "test-project", "test-region", "test-instance", "db.example.com" + ) + mock_resolve.return_value = conn_name_with_domain + + async with Connector( + credentials=fake_credentials, + loop=asyncio.get_running_loop(), + resolver=DnsResolver, + ) as connector: + connector._client = fake_client + + # patch db connection creation + with patch( + "google.cloud.sql.connector.asyncpg.connect" + ) as mock_connect: + mock_connect.return_value = True + + # Call connect_async + # Use "db.example.com" as instance connection string (resolver will handle it) + connection = await connector.connect_async( + "db.example.com", + "asyncpg", + user="my-user", + password="my-pass", + db="my-db", + ) + + # Verify mock_connect was called with resolved IP "1.2.3.4" + # The first arg to mock_connect (which patches connector call) is ip_address + args, _ = mock_connect.call_args + assert args[0] == "1.2.3.4" + assert connection is True + + +@pytest.mark.asyncio +async def test_Connector_connect_async_custom_dns_resolver_fallback( + fake_credentials: Credentials, fake_client: CloudSQLClient +) -> None: + """Test that Connector.connect_async falls back if DNS resolution fails.""" + + # Create a mock DnsResolver that returns empty list (failure) + with patch( + "google.cloud.sql.connector.resolver.DnsResolver.resolve_a_record" + ) as mock_resolve_a: + mock_resolve_a.return_value = [] + + with patch( + "google.cloud.sql.connector.resolver.DnsResolver.resolve" + ) as mock_resolve: + conn_name_with_domain = ConnectionName( + "test-project", "test-region", "test-instance", "db.example.com" + ) + mock_resolve.return_value = conn_name_with_domain + + async with Connector( + credentials=fake_credentials, + loop=asyncio.get_running_loop(), + resolver=DnsResolver, + ) as connector: + connector._client = fake_client + + # Save original IPs to restore later (fake_instance is session-scoped) + original_ips = fake_client.instance.ip_addrs + # Set metadata IP to something specific + fake_client.instance.ip_addrs = {"PRIMARY": "5.6.7.8"} + + try: + with patch( + "google.cloud.sql.connector.asyncpg.connect" + ) as mock_connect: + mock_connect.return_value = True + + connection = await connector.connect_async( + "db.example.com", + "asyncpg", + user="my-user", + password="my-pass", + db="my-db", + ) + + # Verify mock_connect was called with metadata IP "5.6.7.8" + args, _ = mock_connect.call_args + assert args[0] == "5.6.7.8" + assert connection is True + finally: + # Restore original IPs + fake_client.instance.ip_addrs = original_ips + + diff --git a/tests/unit/test_resolver.py b/tests/unit/test_resolver.py index a9a7f2632..c649e8e58 100644 --- a/tests/unit/test_resolver.py +++ b/tests/unit/test_resolver.py @@ -129,3 +129,40 @@ async def test_DnsResolver_with_bad_dns_name() -> None: with pytest.raises(DnsResolutionError) as exc_info: await resolver.resolve("bad.dns.com") assert exc_info.value.args[0] == "Unable to resolve TXT record for `bad.dns.com`" + + +a_record_query_text = """id 1234 +opcode QUERY +rcode NOERROR +flags QR AA RD RA +;QUESTION +db.example.com. IN A +;ANSWER +db.example.com. 0 IN A 127.0.0.1 +;AUTHORITY +;ADDITIONAL +""" + + +async def test_DnsResolver_resolve_a_record() -> None: + """Test DnsResolver resolves A record into IP address.""" + with patch("dns.asyncresolver.Resolver.resolve") as mock_resolve: + answer = dns.resolver.Answer( + "db.example.com", + dns.rdatatype.A, + dns.rdataclass.IN, + dns.message.from_text(a_record_query_text), + ) + mock_resolve.return_value = answer + resolver = DnsResolver() + result = await resolver.resolve_a_record("db.example.com") + assert result == ["127.0.0.1"] + + +async def test_DnsResolver_resolve_a_record_empty() -> None: + """Test DnsResolver resolves A record but gets error.""" + with patch("dns.asyncresolver.Resolver.resolve") as mock_resolve: + mock_resolve.side_effect = Exception("DNS Error") + resolver = DnsResolver() + result = await resolver.resolve_a_record("db.example.com") + assert result == [] \ No newline at end of file