Skip to content

Commit d9e1fba

Browse files
Update benchmark.py
1 parent 6219c16 commit d9e1fba

File tree

1 file changed

+5 -3 lines changed

1 file changed

+5 -3 lines changed

benchmarks/utils/benchmark.py

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -40,7 +40,7 @@ def parse_args():
40 40   type=str, default="local",
41 41   help="memory placement policy, 'local','interleave' or 'none'")
42 42   parser.add_argument("-fa",
43    - type=int, default=0, choices=range(0,2),
   43 + action="store_true",
44 44   help="enable flash attention")
45 45   return parser.parse_args()
46 46

@@ -152,8 +152,10 @@ def main():
152 152   "/llm/llama-batched-bench", "-m", args.model, "-c", str(args.kv_cache), "-b", "2048", "-ub", "512", "-npp", str(args.prompt_size), "-ntg", str(args.tg_size),
153 153   "-npl", str(args.batch_size), "-t", str(args.num_threads), "-tb", str(args.num_threads), "--no-mmap"]
154 154
155     - if args.fa != 0 :
156     - cmd.append("--flash-attn")
    155 + if args.fa:
    156 + cmd += ["-fa", "on"]
    157 + else:
    158 + cmd += ["-fa", "off"]
157 159
158 160   else:
159 161   print("FAIL: batched-bench not found!")

0 commit comments

Comments
 (0)