1- import base64
21import inspect
32import json
43import os
54import sys
6- import traceback
7- import six
5+ from adk .io import create_exception , format_data , format_response
86
97
108class ADK (object ):
@@ -15,6 +13,7 @@ def __init__(self, apply_func, load_func=None):
1513 :param load_func: An optional supplier function used if load time events are required, has an arity of 0.
1614 """
1715 self .FIFO_PATH = "/tmp/algoout"
16+ self .manifest_path = "data_manifest.json"
1817 apply_args , _ , _ , _ , _ , _ , _ = inspect .getfullargspec (apply_func )
1918 if load_func :
2019 load_args , _ , _ , _ , _ , _ , _ = inspect .getfullargspec (load_func )
@@ -49,51 +48,12 @@ def apply(self, payload):
4948 apply_result = self .apply_func (payload , self .load_result )
5049 else :
5150 apply_result = self .apply_func (payload )
52- response_obj = self . format_response (apply_result )
51+ response_obj = format_response (apply_result )
5352 return response_obj
5453 except Exception as e :
55- response_obj = self . create_exception (e )
54+ response_obj = create_exception (e )
5655 return response_obj
5756
58- def format_data (self , request ):
59- if request ["content_type" ] in ["text" , "json" ]:
60- data = request ["data" ]
61- elif request ["content_type" ] == "binary" :
62- data = self .wrap_binary_data (base64 .b64decode (request ["data" ]))
63- else :
64- raise Exception ("Invalid content_type: {}" .format (request ["content_type" ]))
65- return data
66-
67- def is_binary (self , arg ):
68- if six .PY3 :
69- return isinstance (arg , base64 .bytes_types )
70-
71- return isinstance (arg , bytearray )
72-
73- def wrap_binary_data (self , data ):
74- if six .PY3 :
75- return bytes (data )
76- else :
77- return bytearray (data )
78-
79- def format_response (self , response ):
80- if self .is_binary (response ):
81- content_type = "binary"
82- response = str (base64 .b64encode (response ), "utf-8" )
83- elif isinstance (response , six .string_types ) or isinstance (response , six .text_type ):
84- content_type = "text"
85- else :
86- content_type = "json"
87- response_string = json .dumps (
88- {
89- "result" : response ,
90- "metadata" : {
91- "content_type" : content_type
92- }
93- }
94- )
95- return response_string
96-
9757 def write_to_pipe (self , payload , pprint = print ):
9858 if self .is_local :
9959 if isinstance (payload , dict ):
@@ -109,40 +69,24 @@ def write_to_pipe(self, payload, pprint=print):
10969 if os .name == "nt" :
11070 sys .stdin = payload
11171
112- def create_exception (self , exception , loading_exception = False ):
113- if hasattr (exception , "error_type" ):
114- error_type = exception .error_type
115- elif loading_exception :
116- error_type = "LoadingError"
117- else :
118- error_type = "AlgorithmError"
119- response = json .dumps ({
120- "error" : {
121- "message" : str (exception ),
122- "stacktrace" : traceback .format_exc (),
123- "error_type" : error_type ,
124- }
125- })
126- return response
127-
12872 def process_local (self , local_payload , pprint ):
12973 result = self .apply (local_payload )
13074 self .write_to_pipe (result , pprint = pprint )
13175
13276 def init (self , local_payload = None , pprint = print ):
133- self .load ()
134- if self .is_local and local_payload :
77+ self .load ()
78+ if self .is_local and local_payload :
79+ if self .loading_exception :
80+ load_error = create_exception (self .loading_exception , loading_exception = True )
81+ self .write_to_pipe (load_error , pprint = pprint )
82+ self .process_local (local_payload , pprint )
83+ else :
84+ for line in sys .stdin :
85+ request = json .loads (line )
86+ formatted_input = format_data (request )
13587 if self .loading_exception :
136- load_error = self . create_exception (self .loading_exception , loading_exception = True )
88+ load_error = create_exception (self .loading_exception , loading_exception = True )
13789 self .write_to_pipe (load_error , pprint = pprint )
138- self .process_local (local_payload , pprint )
139- else :
140- for line in sys .stdin :
141- request = json .loads (line )
142- formatted_input = self .format_data (request )
143- if self .loading_exception :
144- load_error = self .create_exception (self .loading_exception , loading_exception = True )
145- self .write_to_pipe (load_error , pprint = pprint )
146- else :
147- result = self .apply (formatted_input )
148- self .write_to_pipe (result )
90+ else :
91+ result = self .apply (formatted_input )
92+ self .write_to_pipe (result )
0 commit comments