Skip to content
Open
129 changes: 127 additions & 2 deletions src/google/adk/plugins/save_files_as_artifacts_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,12 @@

import copy
import logging
import os
import tempfile
from typing import Optional
import urllib.parse

from google.genai import Client
from google.genai import types

from ..agents.invocation_context import InvocationContext
Expand All @@ -31,6 +34,12 @@
# capabilities.
_MODEL_ACCESSIBLE_URI_SCHEMES = {'gs', 'https', 'http'}

# Maximum file size for inline_data (20MB as per Gemini API documentation)
# Maximum file size for Files API (2GB as per Gemini API documentation)
# https://ai.google.dev/gemini-api/docs/files
_MAX_INLINE_DATA_SIZE_BYTES = 20 * 1024 * 1024 # 20 MB
_MAX_FILES_API_SIZE_BYTES = 2 * 1024 * 1024 * 1024 # 2 GB


class SaveFilesAsArtifactsPlugin(BasePlugin):
"""A plugin that saves files embedded in user messages as artifacts.
Expand Down Expand Up @@ -81,18 +90,77 @@ async def on_user_message_callback(
continue

try:
# Use display_name if available, otherwise generate a filename
# Check file size before processing
inline_data = part.inline_data
file_size = len(inline_data.data) if inline_data.data else 0

# Use display_name if available, otherwise generate a filename
file_name = inline_data.display_name
if not file_name:
file_name = f'artifact_{invocation_context.invocation_id}_{i}'
logger.info(
f'No display_name found, using generated filename: {file_name}'
)

# Store original filename for display to user/ placeholder
# Store original filename for display to user/placeholder
display_name = file_name

# Check if file exceeds Files API limit (2GB)
if file_size > _MAX_FILES_API_SIZE_BYTES:
file_size_gb = file_size / (1024 * 1024 * 1024)
error_message = (
f'File {display_name} ({file_size_gb:.2f} GB) exceeds the'
' maximum supported size of 2GB. Please upload a smaller file.'
)
logger.warning(error_message)
new_parts.append(types.Part(text=f'[Upload Error: {error_message}]'))
modified = True
continue

# For files larger than 20MB, use Files API
if file_size > _MAX_INLINE_DATA_SIZE_BYTES:
file_size_mb = file_size / (1024 * 1024)
logger.info(
f'File {display_name} ({file_size_mb:.2f} MB) exceeds'
' inline_data limit. Uploading via Files API...'
)

# Upload to Files API and convert to file_data
try:
file_part = await self._upload_to_files_api(
inline_data=inline_data,
file_name=file_name,
)

# Save the file_data artifact
version = await invocation_context.artifact_service.save_artifact(
app_name=invocation_context.app_name,
user_id=invocation_context.user_id,
session_id=invocation_context.session.id,
filename=file_name,
artifact=copy.copy(file_part),
)

placeholder_part = types.Part(
text=f'[Uploaded Artifact: "{display_name}"]'
)
new_parts.append(placeholder_part)
new_parts.append(file_part)
modified = True
logger.info(f'Successfully uploaded {display_name} via Files API')
except Exception as e:
error_message = (
f'Failed to upload file {display_name} ({file_size_mb:.2f} MB)'
f' via Files API: {str(e)}'
)
logger.error(error_message)
new_parts.append(
types.Part(text=f'[Upload Error: {error_message}]')
)
modified = True
continue

# For files <= 20MB, use inline_data (existing behavior)
# Create a copy to stop mutation of the saved artifact if the original part is modified
version = await invocation_context.artifact_service.save_artifact(
app_name=invocation_context.app_name,
Expand Down Expand Up @@ -131,6 +199,63 @@ async def on_user_message_callback(
else:
return None

async def _upload_to_files_api(
self,
*,
inline_data: types.Blob,
file_name: str,
) -> types.Part:

# Create a temporary file with the inline data
temp_file_path = None
try:
# Determine file extension from display_name or mime_type
file_extension = ''
if inline_data.display_name and '.' in inline_data.display_name:
file_extension = os.path.splitext(inline_data.display_name)[1]
elif inline_data.mime_type:
# Simple mime type to extension mapping
mime_to_ext = {
'application/pdf': '.pdf',
'image/png': '.png',
'image/jpeg': '.jpg',
'image/gif': '.gif',
'text/plain': '.txt',
'application/json': '.json',
}
file_extension = mime_to_ext.get(inline_data.mime_type, '')

# Create temporary file
with tempfile.NamedTemporaryFile(
mode='wb',
suffix=file_extension,
delete=False,
) as temp_file:
temp_file.write(inline_data.data)
temp_file_path = temp_file.name

# Upload to Files API
client = Client()
uploaded_file = client.files.upload(file=temp_file_path)

# Create file_data Part
return types.Part(
file_data=types.FileData(
file_uri=uploaded_file.uri,
mime_type=inline_data.mime_type,
display_name=inline_data.display_name or file_name,
)
)
finally:
# Clean up temporary file
if temp_file_path and os.path.exists(temp_file_path):
try:
os.unlink(temp_file_path)
except Exception as cleanup_error:
logger.warning(
f'Failed to cleanup temp file {temp_file_path}: {cleanup_error}'
)

async def _build_file_reference_part(
self,
*,
Expand Down
Loading