Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ if [[ ! -d venv ]] ; then
echo "./venv not found. Setting up venv"
python3 -m venv "$PWD/venv"
fi

source "$PWD/venv/bin/activate"

if which pip3 ; then
Expand Down Expand Up @@ -135,6 +136,10 @@ function write_e2e_env(){
val=$(gcloud secrets versions access latest --project "$TEST_PROJECT" --secret="$secret_name")
echo "export $env_var_name='$val'"
done
# Aliases for python e2e tests
echo "export POSTGRES_CUSTOMER_CAS_PASS_VALID_DOMAIN_NAME=\"\$POSTGRES_CUSTOMER_CAS_DOMAIN_NAME\""
echo "export POSTGRES_IAM_USER=\"\$POSTGRES_USER_IAM\""
echo "export MYSQL_IAM_USER=\"\$MYSQL_USER_IAM\""
} > "$1"

}
Expand Down
27 changes: 27 additions & 0 deletions google/cloud/sql/connector/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,33 @@ async def connect_async(
# the cache and re-raise the error
await self._remove_cached(str(conn_name), enable_iam_auth)
raise

# If the connector is configured with a custom DNS name, attempt to use
# that DNS name to connect to the instance. Fall back to the metadata IP
# address if the DNS name does not resolve to an IP address.
if conn_info.conn_name.domain_name and isinstance(self._resolver, DnsResolver):
try:
ips = await self._resolver.resolve_a_record(conn_info.conn_name.domain_name)
if ips:
ip_address = ips[0]
logger.debug(
f"['{instance_connection_string}']: Custom DNS name "
f"'{conn_info.conn_name.domain_name}' resolved to '{ip_address}', "
"using it to connect"
)
else:
logger.debug(
f"['{instance_connection_string}']: Custom DNS name "
f"'{conn_info.conn_name.domain_name}' resolved but returned no "
f"entries, using '{ip_address}' from instance metadata"
)
except Exception as e:
logger.debug(
f"['{instance_connection_string}']: Custom DNS name "
f"'{conn_info.conn_name.domain_name}' did not resolve to an IP "
f"address: {e}, using '{ip_address}' from instance metadata"
)

logger.debug(f"['{conn_info.conn_name}']: Connecting to {ip_address}:3307")
# format `user` param for automatic IAM database authn
if enable_iam_auth:
Expand Down
12 changes: 12 additions & 0 deletions google/cloud/sql/connector/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List

import dns.asyncresolver

from google.cloud.sql.connector.connection_name import _is_valid_domain
Expand Down Expand Up @@ -53,6 +55,16 @@ async def resolve(self, dns: str) -> ConnectionName: # type: ignore
)
return conn_name

async def resolve_a_record(self, dns: str) -> List[str]:
try:
# Attempt to query the A records.
records = await super().resolve(dns, "A", raise_on_no_answer=True)
# return IP addresses as strings
return [record.to_text() for record in records]
except Exception:
# On any error, return empty list
return []

async def query_dns(self, dns: str) -> ConnectionName:
try:
# Attempt to query the TXT records.
Expand Down
111 changes: 111 additions & 0 deletions tests/unit/test_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from google.cloud.sql.connector.exceptions import ConnectorLoopError
from google.cloud.sql.connector.exceptions import IncompatibleDriverError
from google.cloud.sql.connector.instance import RefreshAheadCache
from google.cloud.sql.connector.resolver import DnsResolver


@pytest.mark.asyncio
Expand Down Expand Up @@ -548,3 +549,113 @@ def test_connect_closed_connector(
exc_info.value.args[0]
== "Connection attempt failed because the connector has already been closed."
)


@pytest.mark.asyncio
async def test_Connector_connect_async_custom_dns_resolver(
fake_credentials: Credentials, fake_client: CloudSQLClient
) -> None:
"""Test that Connector.connect_async uses custom DNS name resolution."""

# Create a mock DnsResolver that returns a fixed IP
with patch(
"google.cloud.sql.connector.resolver.DnsResolver.resolve_a_record"
) as mock_resolve_a:
mock_resolve_a.return_value = ["1.2.3.4"]

# We also need to patch resolve because DnsResolver.resolve does DNS lookup for TXT
# But we can patch DnsResolver.resolve to return a ConnectionName with domain name
with patch(
"google.cloud.sql.connector.resolver.DnsResolver.resolve"
) as mock_resolve:
# This must return a ConnectionName object with domain_name set
conn_name_with_domain = ConnectionName(
"test-project", "test-region", "test-instance", "db.example.com"
)
mock_resolve.return_value = conn_name_with_domain

async with Connector(
credentials=fake_credentials,
loop=asyncio.get_running_loop(),
resolver=DnsResolver,
) as connector:
connector._client = fake_client

# patch db connection creation
with patch(
"google.cloud.sql.connector.asyncpg.connect"
) as mock_connect:
mock_connect.return_value = True

# Call connect_async
# Use "db.example.com" as instance connection string (resolver will handle it)
connection = await connector.connect_async(
"db.example.com",
"asyncpg",
user="my-user",
password="my-pass",
db="my-db",
)

# Verify mock_connect was called with resolved IP "1.2.3.4"
# The first arg to mock_connect (which patches connector call) is ip_address
args, _ = mock_connect.call_args
assert args[0] == "1.2.3.4"
assert connection is True


@pytest.mark.asyncio
async def test_Connector_connect_async_custom_dns_resolver_fallback(
fake_credentials: Credentials, fake_client: CloudSQLClient
) -> None:
"""Test that Connector.connect_async falls back if DNS resolution fails."""

# Create a mock DnsResolver that returns empty list (failure)
with patch(
"google.cloud.sql.connector.resolver.DnsResolver.resolve_a_record"
) as mock_resolve_a:
mock_resolve_a.return_value = []

with patch(
"google.cloud.sql.connector.resolver.DnsResolver.resolve"
) as mock_resolve:
conn_name_with_domain = ConnectionName(
"test-project", "test-region", "test-instance", "db.example.com"
)
mock_resolve.return_value = conn_name_with_domain

async with Connector(
credentials=fake_credentials,
loop=asyncio.get_running_loop(),
resolver=DnsResolver,
) as connector:
connector._client = fake_client

# Save original IPs to restore later (fake_instance is session-scoped)
original_ips = fake_client.instance.ip_addrs
# Set metadata IP to something specific
fake_client.instance.ip_addrs = {"PRIMARY": "5.6.7.8"}

try:
with patch(
"google.cloud.sql.connector.asyncpg.connect"
) as mock_connect:
mock_connect.return_value = True

connection = await connector.connect_async(
"db.example.com",
"asyncpg",
user="my-user",
password="my-pass",
db="my-db",
)

# Verify mock_connect was called with metadata IP "5.6.7.8"
args, _ = mock_connect.call_args
assert args[0] == "5.6.7.8"
assert connection is True
finally:
# Restore original IPs
fake_client.instance.ip_addrs = original_ips


37 changes: 37 additions & 0 deletions tests/unit/test_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,3 +129,40 @@ async def test_DnsResolver_with_bad_dns_name() -> None:
with pytest.raises(DnsResolutionError) as exc_info:
await resolver.resolve("bad.dns.com")
assert exc_info.value.args[0] == "Unable to resolve TXT record for `bad.dns.com`"


a_record_query_text = """id 1234
opcode QUERY
rcode NOERROR
flags QR AA RD RA
;QUESTION
db.example.com. IN A
;ANSWER
db.example.com. 0 IN A 127.0.0.1
;AUTHORITY
;ADDITIONAL
"""


async def test_DnsResolver_resolve_a_record() -> None:
"""Test DnsResolver resolves A record into IP address."""
with patch("dns.asyncresolver.Resolver.resolve") as mock_resolve:
answer = dns.resolver.Answer(
"db.example.com",
dns.rdatatype.A,
dns.rdataclass.IN,
dns.message.from_text(a_record_query_text),
)
mock_resolve.return_value = answer
resolver = DnsResolver()
result = await resolver.resolve_a_record("db.example.com")
assert result == ["127.0.0.1"]


async def test_DnsResolver_resolve_a_record_empty() -> None:
"""Test DnsResolver resolves A record but gets error."""
with patch("dns.asyncresolver.Resolver.resolve") as mock_resolve:
mock_resolve.side_effect = Exception("DNS Error")
resolver = DnsResolver()
result = await resolver.resolve_a_record("db.example.com")
assert result == []
Loading