Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extending the type preservation check #188

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 25 additions & 17 deletions kernel/typing.ml
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@

open Basic
open Format
open Rule
Expand Down Expand Up @@ -150,45 +151,52 @@ let unshift_reduce sg q t =
( try Some (Subst.unshift q (R.snf sg t))
with Subst.UnshiftExn -> None )

let rec pseudo_u sg (fail: int*term*term-> unit) (sigma:SS.t) : (int*term*term) list -> SS.t = function
| [] -> sigma
type unif = (int*term*term) list

let rec pseudo_u sg (fail: int*term*term-> unit) (sigma:SS.t)
(acc:unif) : unif -> SS.t * unif = function
| [] -> sigma, acc
| (q,t1,t2)::lst ->
begin
let t1' = R.whnf sg (SS.apply sigma q t1) in
let t2' = R.whnf sg (SS.apply sigma q t2) in
let keepon () = pseudo_u sg fail sigma lst in
if term_eq t1' t2' then keepon ()
let keepon_with acc = pseudo_u sg fail sigma lst acc in
let keepon () = keepon_with acc in
if term_eq t1' t2' then keepon_with acc
else
let warn () = fail (q,t1,t2); keepon () in
let register () = keepon_with ((q,t1',t2')::acc) in
let warn () = fail (q,t1,t2); register() in
match t1', t2' with
| Kind, Kind | Type _, Type _ -> assert false (* Equal terms *)
| DB (_,_,n), DB (_,_,n') when n=n' -> assert false (* Equal terms *)
| _, Kind | Kind, _ |_, Type _ | Type _, _ -> warn ()

| Pi (_,_,a,b), Pi (_,_,a',b') ->
pseudo_u sg fail sigma ((q,a,a')::(q+1,b,b')::lst)
pseudo_u sg fail sigma ((q,a,a')::(q+1,b,b')::lst) acc
| Lam (_,_,_,b), Lam (_,_,_,b') ->
pseudo_u sg fail sigma ((q+1,b,b')::lst)
pseudo_u sg fail sigma ((q+1,b,b')::lst) acc

(* Potentially eta-equivalent terms *)
| Lam (_,i,_,b), a when !Reduction.eta ->
let b' = mk_App (Subst.shift 1 a) (mk_DB dloc i 0) [] in
pseudo_u sg fail sigma ((q+1,b,b')::lst)
pseudo_u sg fail sigma ((q+1,b,b')::lst) acc
| a, Lam (_,i,_,b) when !Reduction.eta ->
let b' = mk_App (Subst.shift 1 a) (mk_DB dloc i 0) [] in
pseudo_u sg fail sigma ((q+1,b,b')::lst)
pseudo_u sg fail sigma ((q+1,b,b')::lst) acc

| Const (_,c), Const (_,c') when name_eq c c' -> keepon ()
| Const (l,cst), t when not (Signature.is_static sg l cst) ->
( match unshift_reduce sg q t with None -> warn () | Some _ -> keepon ())
( match unshift_reduce sg q t with None -> warn ()
| Some t -> keepon_with ((0,t1',t)::acc))
| t, Const (l,cst) when not (Signature.is_static sg l cst) ->
( match unshift_reduce sg q t with None -> warn () | Some _ -> keepon ())
( match unshift_reduce sg q t with None -> warn ()
| Some t -> keepon_with ((0,t,t2')::acc))

| DB (l1,x1,n1), DB (l2,x2,n2) when n1>=q && n2>=q ->
let (n,t) = if n1<n2
then (n1,mk_DB l2 x2 (n2-q))
else (n2,mk_DB l1 x1 (n1-q)) in
pseudo_u sg fail (SS.add sigma (n-q) t) lst
pseudo_u sg fail (SS.add sigma (n-q) t) lst acc
| DB (_,_,n), t when n>=q ->
begin
let n' = n-q in
Expand All @@ -197,7 +205,7 @@ let rec pseudo_u sg (fail: int*term*term-> unit) (sigma:SS.t) : (int*term*term)
| Some ut ->
let t' = if Subst.occurs n' ut then ut else R.snf sg ut in
if Subst.occurs n' t' then warn ()
else pseudo_u sg fail (SS.add sigma n' t') lst
else pseudo_u sg fail (SS.add sigma n' t') lst acc
end
| t, DB (_,_,n) when n>=q ->
begin
Expand All @@ -207,24 +215,24 @@ let rec pseudo_u sg (fail: int*term*term-> unit) (sigma:SS.t) : (int*term*term)
| Some ut ->
let t' = if Subst.occurs n' ut then ut else R.snf sg ut in
if Subst.occurs n' t' then warn ()
else pseudo_u sg fail (SS.add sigma n' t') lst
else pseudo_u sg fail (SS.add sigma n' t') lst acc
end

| App (DB (_,_,n),_,_), _ when n >= q ->
if R.are_convertible sg t1' t2' then keepon () else warn ()
| _ , App (DB (_,_,n),_,_) when n >= q ->
if R.are_convertible sg t1' t2' then keepon () else warn ()

| App (Const (l,cst),_,_), _ when not (Signature.is_static sg l cst) -> keepon ()
| _, App (Const (l,cst),_,_) when not (Signature.is_static sg l cst) -> keepon ()
| App (Const (l,cst),_,_), _ when not (Signature.is_static sg l cst) -> register ()
| _, App (Const (l,cst),_,_) when not (Signature.is_static sg l cst) -> register ()

| App (f,a,args), App (f',a',args') ->
(* f = Kind | Type | DB n when n<q | Pi _
* | Const name when (is_static name) *)
begin
match safe_add_to_list q lst args args' with
| None -> warn () (* Different number of arguments. *)
| Some lst2 -> pseudo_u sg fail sigma ((q,f,f')::(q,a,a')::lst2)
| Some lst2 -> pseudo_u sg fail sigma ((q,f,f')::(q,a,a')::lst2) acc
end

| _, _ -> warn ()
Expand Down
25 changes: 25 additions & 0 deletions tests/OK/higher_order_cstr1.dk
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
(; OK ;)

N : Type.
A : Type.
T : A -> Type.

P : (N -> A) -> Type.
p : f : (N -> A) -> P f.

g : (x : N) -> f : (N -> A) -> T (f x).
def h : f : (N -> A) -> (x : N -> T (f x)) -> P (x => f x).

(; The following rules are well-typed because

we infer the constraint
(under 1 lambda) X x[0] = Y x[0]

and we need to deduce
P (x => X x) = P (x => Y x)
;)

[X,Y] h Y (z => g z X ) --> p (x => X x).
[X,Y] h (y => Y y) (z => g z X ) --> p (x => X x).
[X,Y] h Y (z => g z (x => X x)) --> p (x => X x).
[X,Y] h (y => Y y) (z => g z (x => X x)) --> p (x => X x).
25 changes: 25 additions & 0 deletions tests/OK/higher_order_cstr2.dk
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
(; OK ;)

N : Type.
0 : N.

A : Type.
T : A -> Type.
p : f : (N -> A) -> T (f 0).

g : (x : N) -> f : (N -> A) -> T (f x).
def h : f : (N -> A) -> (x : N -> T (f x)) -> T (f 0).

(; The following rules are well-typed because

we infer the constraint
(under 1 lambda) X x[0] = Y x[0]

and we need to deduce
T (X 0) = T (Y 0)
;)

[X,Y] h Y (z => g z X ) --> p X.
[X,Y] h (y => Y y) (z => g z X ) --> p X.
[X,Y] h Y (z => g z (x => X x)) --> p (x => X x).
[X,Y] h (y => Y y) (z => g z (x => X x)) --> p (x => X x).