Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Exposing vector_add to zok frontend #78

Open
wants to merge 16 commits into
base: master
Choose a base branch
from
341 changes: 196 additions & 145 deletions Cargo.lock

Large diffs are not rendered by default.

7 changes: 7 additions & 0 deletions scripts/zx_tests/vadd_u16.zx
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from "EMBED" import vadd_u16

def main() -> u16[5]:
u16[5] a = [1, 2, 3, 4, 5]
u16[5] b = [2, 3, 4, 5, 6]
assert(vadd_u16(a, b) == [3, 5, 7, 9, 11])
return vadd_u16(a,b)
7 changes: 7 additions & 0 deletions scripts/zx_tests/vadd_u32.zx
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from "EMBED" import vadd_u32

def main() -> u32[5]:
u32[5] a = [1, 2, 3, 4, 5]
u32[5] b = [2, 3, 4, 5, 6]
assert(vadd_u32(a, b) == [3, 5, 7, 9, 11])
return vadd_u32(a,b)
6 changes: 6 additions & 0 deletions scripts/zx_tests/vadd_u32.zxf
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from "EMBED" import vadd_u32

def main() -> u32[5]:
u32[5] a = [1, 2, 3, 4, 5]
u32[5] b = []
return vadd_u32(a,b)
7 changes: 7 additions & 0 deletions scripts/zx_tests/vadd_u64.zx
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from "EMBED" import vadd_u64

def main() -> u64[5]:
u64[5] a = [1, 2, 3, 4, 5]
u64[5] b = [2, 3, 4, 5, 6]
assert(vadd_u64(a, b) == [3, 5, 7, 9, 11])
return vadd_u64(a,b)
7 changes: 7 additions & 0 deletions scripts/zx_tests/vadd_u8.zx
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from "EMBED" import vadd_u8

def main() -> u8[5]:
u8[5] a = [1, 2, 3, 4, 5]
u8[5] b = [2, 3, 4, 5, 6]
assert(vadd_u8(a, b) == [3, 5, 7, 9, 11])
return vadd_u8(a,b)
21 changes: 19 additions & 2 deletions src/front/zsharp/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ impl<'ast> ZGen<'ast> {
"bit_array_le" => {
if args.len() != 2 {
Err(format!(
"Got {} args to EMBED/bit_array_le, expected 1",
"Got {} args to EMBED/bit_array_le, expected 2",
args.len()
))
} else if generics.len() != 1 {
Expand Down Expand Up @@ -305,6 +305,24 @@ impl<'ast> ZGen<'ast> {
Ok(uint_lit(DFL_T.modulus().significant_bits(), 32))
}
}
"vadd_u8" | "vadd_u16" | "vadd_u32" | "vadd_u64" => {
if args.len() != 2 {
Err(format!(
"Got {} args to EMBED/vadd_*, expected 2",
args.len()
))
} else if generics.len() != 1 {
Err(format!(
"Got {} generic args to EMBED/vadd_*, expected 1",
generics.len()
))
} else {
assert!(args.iter().all(|t| matches!(t.type_(), Ty::Array(_, _))));
let b = args.pop().unwrap();
let a = args.pop().unwrap();
vector_op(BV_ADD, a, b)
}
}
_ => Err(format!("Unknown or unimplemented builtin '{}'", f_name)),
}
}
Expand Down Expand Up @@ -891,7 +909,6 @@ impl<'ast> ZGen<'ast> {
} else {
debug!("Expr: {}", e.span().as_str());
}

match e {
ast::Expression::Ternary(u) => {
match self.expr_impl_::<true>(&u.first).ok().and_then(const_bool) {
Expand Down
14 changes: 14 additions & 0 deletions src/front/zsharp/term.rs
Original file line number Diff line number Diff line change
Expand Up @@ -908,6 +908,20 @@ pub fn bit_array_le(a: T, b: T, n: usize) -> Result<T, String> {
))
}

pub fn vector_op(op: Op, a: T, b: T) -> Result<T, String> {
match (a.ty, b.ty) {
(Ty::Array(a_s, a_ty), Ty::Array(b_s, b_ty)) => {
if a_s == b_s && a_ty == b_ty {
let t = term![Op::Map(Box::new(op)); a.term, b.term];
Ok(T::new(Ty::Array(a_s, a_ty), t))
} else {
panic!("Mismatched array types (this is a bug: type checking should have caught this!)");
}
edwjchen marked this conversation as resolved.
Show resolved Hide resolved
}
_ => Err("Cannot do vector_op on non-array types".to_string()),
}
}

pub struct ZSharp {
values: Option<HashMap<String, Integer>>,
}
Expand Down
23 changes: 23 additions & 0 deletions src/ir/opt/cfold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,29 @@ pub fn fold_cache(node: &Term, cache: &mut TermCache<Term>) -> Term {
b.width() + w,
))))
}),
Op::Map(op) => match (get(0).as_array_opt(), get(1).as_array_opt()) {
edwjchen marked this conversation as resolved.
Show resolved Hide resolved
(Some(a), Some(b)) => {
// TODO: extend for n-ary arrays
let mut res = a.clone();
edwjchen marked this conversation as resolved.
Show resolved Hide resolved
let mut merge = ArrayMerge::new(a.clone(), b.clone());
edwjchen marked this conversation as resolved.
Show resolved Hide resolved
for (i, va, vb) in merge.into_iter() {
let r = fold_cache(
&term![*op.clone(); leaf_term(Op::Const(va.clone())), leaf_term(Op::Const(vb.clone()))],
edwjchen marked this conversation as resolved.
Show resolved Hide resolved
cache,
);
match r.as_value_opt() {
Some(v) => {
res = res.clone().store(i, v.clone());
edwjchen marked this conversation as resolved.
Show resolved Hide resolved
}
None => {
panic!("Unable to constant fold idx: {}", i);
}
}
}
Some(leaf_term(Op::Const(Value::Array(res))))
}
_ => None,
},
edwjchen marked this conversation as resolved.
Show resolved Hide resolved
_ => None,
};
let new_t = {
Expand Down
100 changes: 100 additions & 0 deletions src/ir/term/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,10 @@ use hashconsing::{HConsed, WHConsed};
use lazy_static::lazy_static;
use log::debug;
use rug::Integer;
use std::cmp::Ordering;
use std::collections::BTreeMap;
use std::fmt::{self, Debug, Display, Formatter};
use std::iter::Peekable;
use std::sync::{Arc, RwLock};

pub mod bv;
Expand Down Expand Up @@ -721,6 +723,104 @@ impl Array {
self.check_idx(idx);
self.map.get(idx).unwrap_or(&*self.default).clone()
}

/// Iter
pub fn into_iter(&self) -> std::collections::btree_map::IntoIter<Value, Value> {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you're going to implement the into_iter function, you should implement the std::iter::IntoIterator trait, not just add a member function.

Continuing my crusade against unnecessary cloning: you probably want impl<'a> IntoIterator for &'a Array, i.e., not a consuming iterator---just one that returns references. So I think what you really want is something like this (untested---see impl<'a, K, V> IntoIterator for &'a BTreeMap<K, V> in the stdlib):

impl<'a> IntoIterator for &'a Array {
    type Item = (&'a Value, &'a Value);
    type IntoIter = std::collection::btree_map::Iter<'a, Value, Value>;
    fn into_iter(self) -> Self::IntoIter {
        self.map.iter()
    }
}

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A second, higher-level comment: it's extremely un-idiomatic for an into_iter function to take &self rather than self. This may seem slightly counterintuitive given what I've just said in my prior comment, but notice in my proposed impl IntoIterator that the argument to fn into_iter is self---it's just that we think of self as being a value of type &'a Array.

In contrast, what is currently implemented in the PR takes &self and then clones it to produce a by-value iterator, which is bad because it means that if someone did

    for (k, v) in some_array.into_iter() {
    }

they will have potentially caused a very expensive clone that they didn't expect (because idiomatically, into_iter never clones---it either consumes a value or it takes a reference and returns references).

Does this make sense?

self.map.clone().into_iter()
}
}

/// Merge two Array Iterators
pub struct ArrayMerge {
edwjchen marked this conversation as resolved.
Show resolved Hide resolved
left: Peekable<Box<std::collections::btree_map::IntoIter<Value, Value>>>,
right: Peekable<Box<std::collections::btree_map::IntoIter<Value, Value>>>,
left_dfl: Value,
right_dfl: Value,
}

impl ArrayMerge {
/// Create a new [ArrayMerge] from two [Array]s
pub fn new(a: Array, b: Array) -> Self {
if a.size != b.size {
panic!("IR Arrays have different lengths: {}, {}", a.size, b.size);
}
if a.key_sort != b.key_sort {
panic!(
"IR Arrays have different key sorts: {}, {}",
a.key_sort, b.key_sort
);
}
if a.default.sort() != b.default.sort() {
panic!(
"IR Arrays default values have different key sorts: {}, {}",
a.default.sort(),
b.default.sort()
);
}
edwjchen marked this conversation as resolved.
Show resolved Hide resolved

Self {
left: Box::new(a.into_iter()).peekable(),
right: Box::new(b.into_iter()).peekable(),
edwjchen marked this conversation as resolved.
Show resolved Hide resolved
left_dfl: *a.default,
right_dfl: *b.default,
}
}

/// Iter
edwjchen marked this conversation as resolved.
Show resolved Hide resolved
pub fn into_iter(&mut self) -> Box<dyn Iterator<Item = (Value, Value, Value)>> {
let mut acc: Vec<(Value, Value, Value)> = Vec::new();
let mut next = self.next_();
while let Some(n) = next {
acc.push(n);
next = self.next_();
}
Box::new(acc.into_iter())
}

/// Next
pub fn next_(&mut self) -> Option<(Value, Value, Value)> {
let l_peek = self.left.peek();
let r_peek = self.right.peek();

let mut left_next = false;
let mut right_next = false;
edwjchen marked this conversation as resolved.
Show resolved Hide resolved

let res = match (l_peek, r_peek) {
edwjchen marked this conversation as resolved.
Show resolved Hide resolved
(Some((l_ind, l_val)), Some((r_ind, r_val))) => match l_ind.cmp(r_ind) {
Ordering::Less => {
left_next = true;
Some((l_ind.clone(), l_val.clone(), self.right_dfl.clone()))
edwjchen marked this conversation as resolved.
Show resolved Hide resolved
}
Ordering::Greater => {
right_next = true;
Some((r_ind.clone(), self.left_dfl.clone(), r_val.clone()))
}
Ordering::Equal => {
left_next = true;
right_next = true;
Some((l_ind.clone(), l_val.clone(), r_val.clone()))
}
},
(Some((l_ind, l_val)), None) => {
left_next = true;
Some((l_ind.clone(), l_val.clone(), self.right_dfl.clone()))
}
(None, Some((r_ind, r_val))) => {
right_next = true;
Some((r_ind.clone(), self.left_dfl.clone(), r_val.clone()))
}
(None, None) => None,
};

if left_next {
self.left.next();
}
if right_next {
self.right.next();
}

res
}
}

impl Display for Value {
Expand Down
13 changes: 13 additions & 0 deletions third_party/ZoKrates/zokrates_stdlib/stdlib/EMBED.zok
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,16 @@ def u16_to_u32(u16 i) -> u32:

def u8_to_u16(u8 i) -> u16:
return 0u16

// vector functions
def vadd_u8<N>(u8[N] a, u8[N] b) -> u8[N]:
return [0; N]

def vadd_u16<N>(u16[N] a, u16[N] b) -> u16[N]:
return [0; N]

def vadd_u32<N>(u32[N] a, u32[N] b) -> u32[N]:
return [0; N]

def vadd_u64<N>(u64[N] a, u64[N] b) -> u64[N]:
return [0; N]