(*************************************************************
 *                                                           *
 *  Cryptographic protocol verifier                          *
 *                                                           *
 *  Bruno Blanchet, Vincent Cheval, and Marc Sylvestre       *
 *                                                           *
 *  Copyright (C) INRIA, CNRS 2000-2020                      *
 *                                                           *
 *************************************************************)

(*

    This program is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details (in file LICENSE).

    You should have received a copy of the GNU General Public License
    along with this program; if not, write to the Free Software
    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA

*)
(* Channel on which every input and output of this model takes place.
   It is declared free (not private), i.e. it is a public name. *)
free c.
(* 

Needham-Schroeder public key protocol

Message 1: A -> S : (A, B)
Message 2: S -> A : { pkB, B }skS
Message 3: A -> B : { Na, A }pkB
Message 4: B -> S : (B, A)
Message 5: S -> B : { pkA, A }skS
Message 6: B -> A : { Na, Nb }pkA
Message 7: A -> B : { Nb }pkB

The heart of the protocol is messages 3, 6, 7.

*)

(* Public key cryptography
   pk(y) builds the public key that corresponds to the secret key y.
   encrypt(x, k) encrypts x under the key k.
   decrypt recovers x from encrypt(x, pk(y)) only when it is given the
   matching secret key y; no other rewrite rule applies to encrypt. *)

fun pk/1.
fun encrypt/2.
reduc decrypt(encrypt(x,pk(y)),y) = x.

(* Host names
   The server has a table (host name, public key), which we
   represent by the function getkey.
   A host name is built as host(x), and getkey(host(x)) returns x.
   In the main process x is always the public key of the host.
   getkey is declared private, so it is not available to the adversary;
   in this file it is only applied by processS. *)

fun host/1.
private reduc getkey(host(x)) = x.

(* Signatures
   sign(x, y) is the signature of message x under the secret key y.
   checksign returns the signed message x when the signature is checked
   against pk(y), the public key matching the signing key.
   getmess returns the signed message without any key: signatures in
   this model do not hide the message they sign. *)

fun sign/2.
reduc checksign(sign(x,y),pk(y)) = x.
reduc getmess(sign(x,y)) = x.

(* Shared-key cryptography
   sencrypt(x, y) encrypts x under the shared key y; sdecrypt returns x
   only when given the same key y. In this file sencrypt is used at the
   end of each role to encrypt a test secret under a session nonce. *)

fun sencrypt/2.
reduc sdecrypt(sencrypt(x,y),y) = x.

(* Secrecy assumptions
   Declares that the adversary does not obtain the secret keys skA, skB
   and skS, which are the names created by "new" in the main process.
   NOTE(review): per the ProVerif manual these declarations are hints
   that the tool itself checks, so a wrong assumption is reported
   rather than silently trusted - confirm against the manual for the
   ProVerif version in use. *)

not skA.
not skB.
not skS.

(* Test secrets. They are private names, and each one is only ever
   output encrypted (with sencrypt) under one of the session nonces:
   secretANa / secretANb by processA under its Na and the received nonce,
   secretBNa / secretBNb by processB under the received nonce and its Nb. *)
private free secretANa, secretANb, secretBNa, secretBNb.
(* Secrecy queries: can the adversary obtain any of the four secrets? *)
query attacker:secretANa;
      attacker:secretANb;
      attacker:secretBNa;
      attacker:secretBNb.
(* Injective correspondence queries between the events of the two roles.
   endB* events are executed by processB and beginB* events by processA;
   endA* events are executed by processA and beginA* events by processB.
   The "param" variants carry only a host name; the "full" variants carry
   the nonces, host names and public keys of the session. *)
query evinj:endBparam(x) ==> evinj:beginBparam(x).
query evinj:endBfull(x1,x2,x3,x4,x5,x6) ==> evinj:beginBfull(x1,x2,x3,x4,x5,x6).
query evinj:endAparam(x) ==> evinj:beginAparam(x).
query evinj:endAfull(x1,x2,x3,x4,x5,x6) ==> evinj:beginAfull(x1,x2,x3,x4,x5,x6).

(* Initiator role, run by host A.
   Uses hostA, hostB, skA, pkA and pkS, which are bound by the main
   process before this macro is expanded. All messages travel on c. *)
let processA =
	(* The identity of the responder is received on the channel *)
	in(c, peer);
	event beginBparam(peer);
	(* Message 1: request the certificate of the chosen responder *)
	out(c, (hostA, peer));
	(* Message 2: the certificate must be signed by the server and
	   must be about the host that was asked for *)
	in(c, cert);
	let (pkPeer, =peer) = checksign(cert, pkS) in
	(* Message 3: fresh nonce and own identity, for the responder only *)
	new Na;
	out(c, encrypt((Na, hostA), pkPeer));
	(* Message 6: the reply must contain our nonce Na *)
	in(c, reply);
	let (=Na, noncePeer) = decrypt(reply, skA) in
	event beginBfull(Na, hostA, peer, pkPeer, pkA, noncePeer);
	(* Message 7: return the responder's nonce *)
	out(c, encrypt(noncePeer, pkPeer));
	(* The end events and the test secrets are only produced for
	   sessions in which the responder is B *)
	if peer = hostB then
	event endAparam(hostA);
	event endAfull(Na, hostA, peer, pkPeer, pkA, noncePeer);
	out(c, sencrypt(secretANa, Na));
	out(c, sencrypt(secretANb, noncePeer)).

(* Responder role, run by host B.
   Uses hostA, hostB, skB, pkA, pkB and pkS, which are bound by the main
   process before this macro is expanded. All messages travel on c. *)
let processB =
	(* Message 3: a nonce and the claimed identity of the initiator *)
	in(c, request);
	let (nonceInit, initiator) = decrypt(request, skB) in
	event beginAparam(initiator);
	(* Message 4: request the certificate of the claimed initiator *)
	out(c, (hostB, initiator));
	(* Message 5: the certificate must be signed by the server and
	   must be about the host that was asked for *)
	in(c, cert);
	let (pkInit, =initiator) = checksign(cert, pkS) in
	(* Message 6: received nonce and a fresh nonce, for the initiator *)
	new Nb;
	event beginAfull(nonceInit, initiator, hostB, pkB, pkInit, Nb);
	out(c, encrypt((nonceInit, Nb), pkInit));
	(* Message 7: must decrypt to our own nonce Nb *)
	in(c, confirmation);
	if Nb = decrypt(confirmation, skB) then
	(* The end events and the test secrets are only produced for
	   sessions in which the initiator is A *)
	if initiator = hostA then
	event endBparam(hostB);
	event endBfull(nonceInit, initiator, hostB, pkB, pkA, Nb);
	out(c, sencrypt(secretBNa, nonceInit));
	out(c, sencrypt(secretBNb, Nb)).

(* Key server.
   Receives a pair of host names on c, looks up the key of the second
   one with the private destructor getkey, and outputs the pair
   (key, host name) signed with the server's secret key skS.
   The first component of the request is not used. *)
let processS =
	in(c, request);
	let (requester, subject) = request in
	let subjectKey = getkey(subject) in
	out(c, sign((subjectKey, subject), skS)).

(* Main process.
   Creates the secret keys of A, B and the server S, derives the
   corresponding public keys with pk and outputs the three public keys
   on c. Then builds the host names of A and B, outputs them on c, and
   runs replicated copies of the three roles in parallel. *)
process new skA; let pkA = pk(skA) in
        out(c, pkA);
        new skB; let pkB = pk(skB) in
        out(c, pkB);
	new skS; let pkS = pk(skS) in
	out(c, pkS);
	(* A host name wraps the host's public key, which is what the
	   getkey destructor extracts in processS. *)
	let hostA = host(pkA) in
	out(c, hostA);
	let hostB = host(pkB) in
	out(c, hostB);
	((!processA) | (!processB) | (!processS))
