Skip to content

Commit 17815d1

Browse files
committed
initial commit, updated example to spec
1 parent 64dbb70 commit 17815d1

File tree

5 files changed

+126
-87
lines changed

5 files changed

+126
-87
lines changed

adk/ADK.py

Lines changed: 18 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
1-
import base64
21
import inspect
32
import json
43
import os
54
import sys
6-
import traceback
7-
import six
5+
from adk.io import create_exception, format_data, format_response
86

97

108
class ADK(object):
@@ -15,6 +13,7 @@ def __init__(self, apply_func, load_func=None):
1513
:param load_func: An optional supplier function used if load time events are required, has an arity of 0.
1614
"""
1715
self.FIFO_PATH = "/tmp/algoout"
16+
self.manifest_path = "data_manifest.json"
1817
apply_args, _, _, _, _, _, _ = inspect.getfullargspec(apply_func)
1918
if load_func:
2019
load_args, _, _, _, _, _, _ = inspect.getfullargspec(load_func)
@@ -49,51 +48,12 @@ def apply(self, payload):
4948
apply_result = self.apply_func(payload, self.load_result)
5049
else:
5150
apply_result = self.apply_func(payload)
52-
response_obj = self.format_response(apply_result)
51+
response_obj = format_response(apply_result)
5352
return response_obj
5453
except Exception as e:
55-
response_obj = self.create_exception(e)
54+
response_obj = create_exception(e)
5655
return response_obj
5756

58-
def format_data(self, request):
59-
if request["content_type"] in ["text", "json"]:
60-
data = request["data"]
61-
elif request["content_type"] == "binary":
62-
data = self.wrap_binary_data(base64.b64decode(request["data"]))
63-
else:
64-
raise Exception("Invalid content_type: {}".format(request["content_type"]))
65-
return data
66-
67-
def is_binary(self, arg):
68-
if six.PY3:
69-
return isinstance(arg, base64.bytes_types)
70-
71-
return isinstance(arg, bytearray)
72-
73-
def wrap_binary_data(self, data):
74-
if six.PY3:
75-
return bytes(data)
76-
else:
77-
return bytearray(data)
78-
79-
def format_response(self, response):
80-
if self.is_binary(response):
81-
content_type = "binary"
82-
response = str(base64.b64encode(response), "utf-8")
83-
elif isinstance(response, six.string_types) or isinstance(response, six.text_type):
84-
content_type = "text"
85-
else:
86-
content_type = "json"
87-
response_string = json.dumps(
88-
{
89-
"result": response,
90-
"metadata": {
91-
"content_type": content_type
92-
}
93-
}
94-
)
95-
return response_string
96-
9757
def write_to_pipe(self, payload, pprint=print):
9858
if self.is_local:
9959
if isinstance(payload, dict):
@@ -109,40 +69,24 @@ def write_to_pipe(self, payload, pprint=print):
10969
if os.name == "nt":
11070
sys.stdin = payload
11171

112-
def create_exception(self, exception, loading_exception=False):
113-
if hasattr(exception, "error_type"):
114-
error_type = exception.error_type
115-
elif loading_exception:
116-
error_type = "LoadingError"
117-
else:
118-
error_type = "AlgorithmError"
119-
response = json.dumps({
120-
"error": {
121-
"message": str(exception),
122-
"stacktrace": traceback.format_exc(),
123-
"error_type": error_type,
124-
}
125-
})
126-
return response
127-
12872
def process_local(self, local_payload, pprint):
12973
result = self.apply(local_payload)
13074
self.write_to_pipe(result, pprint=pprint)
13175

13276
def init(self, local_payload=None, pprint=print):
133-
self.load()
134-
if self.is_local and local_payload:
77+
self.load()
78+
if self.is_local and local_payload:
79+
if self.loading_exception:
80+
load_error = create_exception(self.loading_exception, loading_exception=True)
81+
self.write_to_pipe(load_error, pprint=pprint)
82+
self.process_local(local_payload, pprint)
83+
else:
84+
for line in sys.stdin:
85+
request = json.loads(line)
86+
formatted_input = format_data(request)
13587
if self.loading_exception:
136-
load_error = self.create_exception(self.loading_exception, loading_exception=True)
88+
load_error = create_exception(self.loading_exception, loading_exception=True)
13789
self.write_to_pipe(load_error, pprint=pprint)
138-
self.process_local(local_payload, pprint)
139-
else:
140-
for line in sys.stdin:
141-
request = json.loads(line)
142-
formatted_input = self.format_data(request)
143-
if self.loading_exception:
144-
load_error = self.create_exception(self.loading_exception, loading_exception=True)
145-
self.write_to_pipe(load_error, pprint=pprint)
146-
else:
147-
result = self.apply(formatted_input)
148-
self.write_to_pipe(result)
90+
else:
91+
result = self.apply(formatted_input)
92+
self.write_to_pipe(result)

adk/io.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import traceback
2+
import six
3+
import base64
4+
import json
5+
6+
7+
def format_data(request):
8+
if request["content_type"] in ["text", "json"]:
9+
data = request["data"]
10+
elif request["content_type"] == "binary":
11+
data = wrap_binary_data(base64.b64decode(request["data"]))
12+
else:
13+
raise Exception("Invalid content_type: {}".format(request["content_type"]))
14+
return data
15+
16+
17+
def is_binary(arg):
18+
if six.PY3:
19+
return isinstance(arg, base64.bytes_types)
20+
21+
return isinstance(arg, bytearray)
22+
23+
24+
def wrap_binary_data(data):
25+
if six.PY3:
26+
return bytes(data)
27+
else:
28+
return bytearray(data)
29+
30+
31+
def format_response(response):
32+
if is_binary(response):
33+
content_type = "binary"
34+
response = str(base64.b64encode(response), "utf-8")
35+
elif isinstance(response, six.string_types) or isinstance(response, six.text_type):
36+
content_type = "text"
37+
else:
38+
content_type = "json"
39+
response_string = json.dumps(
40+
{
41+
"result": response,
42+
"metadata": {
43+
"content_type": content_type
44+
}
45+
}
46+
)
47+
return response_string
48+
49+
50+
def create_exception(exception, loading_exception=False):
51+
if hasattr(exception, "error_type"):
52+
error_type = exception.error_type
53+
elif loading_exception:
54+
error_type = "LoadingError"
55+
else:
56+
error_type = "AlgorithmError"
57+
response = json.dumps({
58+
"error": {
59+
"message": str(exception),
60+
"stacktrace": traceback.format_exc(),
61+
"error_type": error_type,
62+
}
63+
})
64+
return response

adk/manifest.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import os
2+
import json
3+
import hashlib
4+
5+
6+
def process_manifest(self):
7+
if os.path.exists(self.manifest_path):
8+
with open(self.manifest_path) as f:
9+
manifest_data = json.load(f)
10+
manifest = manifest_data
11+
12+
13+
def check_hash(model_path, expected_hash):
14+
real_hash = md5(model_path)
15+
if real_hash == expected_hash:
16+
return True
17+
else:
18+
return False
19+
20+
21+
def md5(fname):
22+
hash_md5 = hashlib.md5()
23+
with open(fname, "rb") as f:
24+
for chunk in iter(lambda: f.read(4096), b""):
25+
hash_md5.update(chunk)
26+
return hash_md5.hexdigest()
Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,19 @@
11
{
2-
"label_file": {
3-
"filepath": "data://AlgorithmiaSE/image_cassification_demo/imagenet_class_index.json",
4-
"md5_hash": "c2c37ea517e94d9795004a39431a14cb",
5-
"origin_ref": "this file came from imagenet.org",
6-
"uploaded_utc": "2021-05-03-11:05"
7-
},
8-
"squeezenet": {
9-
"filepath": "data://AlgorithmiaSE/image_cassification_demo/squeezenet1_1-f364aa15.pth",
10-
"md5_hash": "46a44d32d2c5c07f7f66324bef4c7266",
11-
"origin_ref": "From https://download.pytorch.org/models/squeezenet1_1-f364aa15.pth",
12-
"uploaded_utc": "2021-05-03-11:05"
2+
"algorithm_name": "pytorch_image_classification",
3+
"model_versioning": "strict",
4+
"required_files" : [
5+
{ "name": "labels",
6+
"data_api_path": "data://AlgorithmiaSE/image_cassification_demo/imagenet_class_index.json",
7+
"md5_checksum": "c2c37ea517e94d9795004a39431a14cb",
8+
"origination_md5_checksum": "46a44d32d2c5c07f7f66324bef4c7266"
139
}
10+
],
11+
"optional_files": [
12+
{
13+
"name": "optional_data",
14+
"data_api_path": "data://AlgorithmiaSE/image_cassification_demo/imagenet_class_index.json",
15+
"md5_checksum": "c2c37ea517e94d9795004a39431a14cb",
16+
"origination_md5_checksum": "46a44d32d2c5c07f7f66324bef4c7266"
17+
}
18+
]
1419
}

examples/pytorch_image_classification/src/Algorithm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,8 @@ def load(manifest):
5555
globals = {}
5656
client = Algorithmia.client()
5757
globals["SMID_ALGO"] = "algo://util/SmartImageDownloader/0.2.x"
58-
globals["model"] = load_model(manifest["squeezenet"], client)
59-
globals["labels"] = load_labels(manifest["label_file"], client)
58+
globals["model"] = load_model(manifest["model_squeezenet"], client)
59+
globals["labels"] = load_labels(manifest["labels"], client)
6060
return globals
6161

6262

0 commit comments

Comments
 (0)