function [PE,g,H] = tvaflexchain(theta,tvamodel,tvatrial,logistidx,expidx,vcf)
% TVAFLEXCHAIN
%
%  Synopsis
%  ========
%
%  [PE,g,H]  = tvaflexchain(theta,tvamodel,tvatrial,logistidx,expidx)
%  [PE,g,H]  = tvaflexchain(theta,tvamodel,tvatrial,logistidx,expidx,vcf)
%  [v,taumu,w] = tvaflexchain(theta,tvamodel,tvatrial,logistidx,expidx,2)
%  
%  -- Author: Mads Dyrholm --
%     Center for Visual Cognition, University of Copenhagen.
%     2009 - September 2011
%  
%  Purpose
%  =======
%
%  Compute the trial response probability, its gradient, and its
%  Hessian, for various factorizations of the TVA model.
%  
%  Inputs
%  ======
%
%  theta - Vector of parameters. The elements and their order depend
%  on the '.facstr' field of tvamodel. The tables below give
%  the relationship between valid characters of facstr and theta. 
%  The elements must live in open domain on input (see TVAFIXER). 
%  Theta must be stripped from guessing constants on input
%  (see TVASTRIPTHETA).
%
%    character | Corresponding TVA parameter.
%    ----------+-----------------------------------------
%     'a'      | Efficiency of selection, alpha.
%     'v'      | Hazard rate.
%     'w'      | Attentional weight.
%     'C'      | Processing capacity.
%     's'      | Split processing capacity.
%     'u'      | u0 or t0. This must always be present.
%     'm'      | Mu.
%
%  With these characters the user can define which factorisation
%  of the TVA model to use. The following table gives examples
%  of various facstr and their corresponding theta.
%
%    facstr    | theta
%    ----------+-----------------------------------------
%     'vu'     | [v values; t0]
%     'wCu'    | [w values without w_1; C; t0]
%     'awCu'   | [alpha; w values without w_1; C; t0]
%     'awsu'   | [alpha; w values without w_1; s values; t0]
%     'awCum'  | [alpha; w values without w_1; C; t0; mu]
%     'awsum'  | [alpha; w values without w_1; s values; t0; mu]
%  
%  Note that the order of the elements of theta is independent
%  of the order of the characters of facstr. The order of the 
%  elements of theta follows the first table above. Note
%  that facstr must always contain a 'u'.
%
%  tvamodel - Model structure. The fields read by this function are
%  '.facstr', '.infl' (inflator matrix; when it has more than one
%  page along the third dimension, the page is selected by
%  tvatrial.trialnum), '.K', '.s0' (read when facstr contains 'u')
%  and '.sinfl' (read only for the 's' factorisations).
%  
%  tvatrial - A single trial, i.e. a single cell from a 
%  TVADATA array. (see also TVALOADER) The fields read here are
%  '.display', '.t', '.response', '.targets', '.distractors',
%  '.places' (only when there are distractors), '.task' ('CD', 'WR'
%  or 'PR'), '.probe' (task 'CD' only), and the optional '.unmask'.
%
%  logistidx, expidx - Index the elements of theta that are treated
%  as logistic- and exp-transformed, respectively. Both are passed
%  through TVAHUMAN together with theta, and when vcf is 1 the
%  corresponding entries of g and H are scaled by theta.*(1-theta)
%  and by theta, respectively.
%
%  vcf - Flag whether you want the output g and H to represent
%  theta in open domain or in human domain. Set to 1 for open 
%  domain, set to 0 for human domain. Default is 1. Set to 2 to
%  skip the probability computation and return the intermediate
%  quantities [v, tau+mu, w] instead (see Synopsis).
%
%  Outputs
%  =======
%  
%  PE - Trial response probability.
%
%  g - Gradient array. The element order follows the 
%  ordering of theta.
%
%  H - Hessian matrix. The element order follows the 
%  ordering of theta.

if nargin<6,
  vcf = 1;
end

facstr = tvamodel.facstr;
if size(tvamodel.infl,3)>1
  % trial-specific inflator
  infl = tvamodel.infl(:,:,tvatrial.trialnum);
else
  infl = tvamodel.infl;
end
K = tvamodel.K;
if ismember('u',facstr) %  & ~ismember('s',facstr) 
  s0 = tvamodel.s0;
end
S = tvatrial.display;
t = tvatrial.t;
R = tvatrial.response;
T = tvatrial.targets;
D = tvatrial.distractors;
if ~isempty(D)
  % distractor positions index the second (alpha-scaled) half of the
  % inflated weight vector
  S(D) = S(D) + tvatrial.places;
end  

% transform input
[theta,logistidx,expidx] = tvahuman(theta,logistidx,expidx);
[alpha,w,C,s,v,u0,chdetgm,mu] = tvadeal(tvamodel,theta);

if ~isempty(v)
  v = infl*v;
  % deflator
  E = eye(length(v));
  P = E(:,S); % for permute wrt S
  defl = infl';
  defl = defl*P;
  % forward
  v = v(S);
else
  if ~isempty(alpha)
    infl = [infl;alpha*infl]; % alpha inflator
  end
  w = infl * [1;w];
  % wtilde
  if ~isempty(alpha)
    z = zeros(length(w),1);
    z(length(w)/2+1:end) = w(1:end/2);
    z = z(S);
    zI = zeros(length(w),1); % for sub chains (***)
    zI(length(w)/2+1:end) = 1;
    zI = zI(S);    
  end
  % deflator
  E = eye(length(w));
  P = E(:,S); % for permute wrt S
  defl = infl';
  defl = defl*P;
  defl = defl(2:end,:); % w_1 is fixed to 1 and is not a free parameter
  % forward
  w = w(S);
  sumw = sum(w);
  w1 = w / sumw;
  if ~isempty(C)
    v = C * w1;
  else
    sinfl = tvamodel.sinfl;
    if ~isempty(alpha)
      sinfl = [sinfl;sinfl];
    end
    s = sinfl * s;
    % deflator s
    E = eye(length(s));
    P = E(:,S); % for permute wrt S
    sdefl = sinfl';
    sdefl = sdefl*P;
    % forward s
    s = s(S);
    v = s .* w1;
  end
end

% carry on...
Nx = size(infl,1);
nS = length(S);

% tau
tau = t-u0;

% mask
if isfield(tvatrial,'unmask')
  unmask = tvatrial.unmask;
else
  unmask = 0;
end

% mu only contributes on unmasked trials
if unmask & ~isempty(mu)
  muval = mu;
else
  muval = 0;
end

% output/runlevel 2
if vcf==2
  [PE,g,H] = deal(v,tau+muval,w);
  return
end

% v must be positive
if any(v<0) 
  % Zero probability and zero derivatives. The derivatives must have the
  % same shapes as those returned by the task functions below (gv and Hvt
  % are nS-by-1, Hvv is nS-by-nS, gt and Htt are scalars); scalar zeros
  % would break the chain-rule products and concatenations further down.
  PE  = 0;
  gv  = zeros(nS,1);
  Hvv = zeros(nS,nS);
  gt  = 0;
  Hvt = zeros(nS,1);
  Htt = 0;
else
  
  % compute
  switch tvatrial.task
   case 'CD'
    x = tvatrial.probe;
    [PE,gv,Hvv,gt,Hvt,Htt] = tvapenc(tau + muval,v,K,x,s0);
   case 'WR'
    [PW,gv,Hvv,gt,Hvt,Htt] = tvapwho(tau + muval,v,K,R,s0);
    PE = PW;
   case 'PR'
    [PP,gv,Hvv,gt,Hvt,Htt] = tvappar(tau + muval,v,K,R,T,s0);
    PE = PP;
   otherwise
    % fail here instead of with an undefined-variable error later on
    error('tvaflexchain:unknownTask', ...
          'Unsupported tvatrial.task; expected ''CD'', ''WR'' or ''PR''.');
  end

end

% compose gradient and Hessian
% (tau = t - u0, so the derivative wrt mu is minus the derivative wrt t0
% on unmasked trials and zero otherwise)
if ~isempty(mu)
  if unmask
    gmu = -gt;
  else
    gmu = 0;
  end
else
  gmu = [];
end

if ~isempty(w)
  if ~isempty(C)
    % gradient chain
    gC = gv'*w1;
    tmp1 = (gv - gv'*w1)/sumw;
    gw = C*tmp1;
    % Hessian chain
    tmp = (eye(nS) - repmat(w1,[1,nS]))/sumw; % d(w1)/d(w)
    gvGCw = gv'*tmp; % sub-chain (**)
    dvdw = C*tmp;
    Hvw = Hvv*dvdw;
    % Sub-chain interactions (*)
    gvGww = repmat(gv,[1,length(gv)]);
    gvGww = gvGww+gvGww';
    gvGww = -C*(gvGww - 2*gv'*w1)/(sumw^2);
    Hww = dvdw'*Hvw + gvGww; % (*)
    HCw = w1'*Hvw + gvGCw; % (**)
    HCC = w1'*Hvv*w1;
    Hwt = dvdw'*Hvt;
    HCt = w1'*Hvt;
    % chain wrt alpha
    if ~isempty(alpha)
      gvGaC = z'*tmp1; % sub chain (4*)
      dwGwa = gw.*zI;% (***) sub chain
      ga = z'*gw;
      Hwa = Hww*z + dwGwa; % (***)
      HCa = HCw*z + gvGaC; % (4*)
      Hat = z'*Hwt;
      Haa = z'*Hww*z;
    else
      ga = [];
      Hwa = [];
      HCa = [];
      Hat = [];
      Haa = [];
    end
    % deflate
    gw = defl*gw;
    Hww = defl*Hww*defl';
    HCw = HCw*defl';
    if ~isempty(alpha)
      Hwa = defl*Hwa;
    end
    Hwt = defl*Hwt;
    % mu
    if ~isempty(mu)
      if unmask
        Hmm =  Htt;
        Htm = -Htt;
        Ham = -Hat;
        Hwm = -Hwt;
        HCm = -HCt;
      else
        Hmm = 0*Htt;
        Htm = 0*Htt;
        Ham = 0*Hat;
        Hwm = 0*Hwt;
        HCm = 0*HCt;
      end
    else
      Hmm = [];
      Htm = [];
      Ham = [];
      Hwm = [];
      HCm = [];
    end
    % C takes the 's' slots in the gather step below
    gs = gC;
    Hsa = HCa;
    Hsw = HCw;
    Hss = HCC;
    Hst = HCt;
    Hsm = HCm;
  else
    % split capacity (s) factorisation
    % (this branch was marked 'XXXXXX' by the original author)
    % gradient chain
    gs = sdefl*(gv.*w1);
    tmp = (eye(nS) - repmat(w1,[1,nS]))/sumw; % d(w1)/d(w)
    gvGsw = repmat(gv'*tmp,[length(gs),1]); % sub-chain (**)
    dvdw = repmat(s,[1,nS]).*tmp;
    gw = dvdw'*gv;
    % Hessian chain
    Hvw = Hvv*dvdw;
    % Sub-chain interactions (*)
    tilde_gv = gv.*s;
    gvGww = repmat(tilde_gv,[1,length(gv)]);
    gvGww = gvGww+gvGww';
    gvGww = -(gvGww - 2*tilde_gv'*w1)/(sumw^2);
    Hww = dvdw'*Hvw + gvGww; % (*)
    % NOTE(review): Hst and Hss below use diag(w1) as dv/ds, whereas the
    % first term of Hsw uses dvdw' in that position (the C-branch analogue
    % is w1'*Hvw). Left unchanged -- confirm with a finite-difference check.
    Hsw = sdefl*dvdw'*Hvw + gvGsw; % (**)
    Hss = sdefl*(Hvv.*(w1*w1'))*sdefl';
    Hwt = dvdw'*Hvt;
    Hst = sdefl*diag(w1)*Hvt;
    % chain wrt alpha
    if ~isempty(alpha)
      dvGsa = gv'*tmp*z; % sub chain (4*)
      dwGwa = gw.*zI;% (***) sub chain
      ga = z'*gw;
      Hwa = Hww*z + dwGwa; % (***)
      Hsa = Hsw*z + dvGsa; % (4*)
      Hat = z'*Hwt;
      Haa = z'*Hww*z;
    else
      ga = [];
      Hwa = [];
      Hsa = [];
      Hat = [];
      Haa = [];
    end
    % deflate
    gw = defl*gw;
    Hww = defl*Hww*defl';
    Hsw = Hsw*defl';
    if ~isempty(alpha)
      Hwa = defl*Hwa;
    end
    Hwt = defl*Hwt;
    % mu
    if ~isempty(mu)
      if unmask
        Hmm =  Htt;
        Htm = -Htt;
        Ham = -Hat;
        Hwm = -Hwt;
        Hsm = -Hst;
      else
        Hmm = 0*Htt;
        Htm = 0*Htt;
        Ham = 0*Hat;
        Hwm = 0*Hwt;
        Hsm = 0*Hst;
      end
    else
      Hmm = [];
      Htm = [];
      Ham = [];
      Hwm = [];
      Hsm = [];
    end

  end
  % gather (empty blocks drop out of the concatenation)
  g = [ga; gw; gs; gt; gmu];  
  H = [Haa,                 Hwa',Hsa',Hat ,Ham;
       Hwa,Hww ,                 Hsw',Hwt ,Hwm;
       Hsa,Hsw ,Hss ,                  Hst, Hsm;  
       Hat,Hwt',Hst',Htt,                   Htm;
       Ham,Hwm',Hsm' ,Htm,Hmm];
else
  % with v...
  
  % deflate
  gv = defl*gv;
  Hvv = defl*Hvv*defl';
  Hvt = defl*Hvt;    
  % mu
  if ~isempty(mu)
    if unmask
      Hvm = -Hvt;
      Hmm =  Htt;
      Htm = -Htt;
    else
      Hvm = 0*Hvt;
      Hmm = 0*Htt;
      Htm = 0*Htt;
    end
  else
    Hvm = [];
    Hmm = [];
    Htm = [];
  end
  % gather
  g = [gv; gt; gmu];
  H = [Hvv,             Hvt, Hvm;
       Hvt', Htt,            Htm; 
       Hvm', Htm, Hmm];
end

% output variable transformation
% NOTE(review): only the Jacobian scaling of H is applied here; no
% gradient-times-second-derivative term is added to the diagonal --
% confirm whether that is intended or handled by the caller.
if vcf==1
  % logistic
  if ~isempty(logistidx)
    dadu = theta(logistidx).*(1-theta(logistidx));
    dadu = dadu(:);
    g(logistidx) = dadu.*g(logistidx);
    % scale columns, then rows, exactly as in the exp case below, so that
    % more than one logistic parameter is handled correctly
    H(:,logistidx) = H(:,logistidx).*repmat(dadu',[size(H,1),1]);
    H(logistidx,:) = H(logistidx,:).*repmat(dadu,[1,size(H,1)]);
  end
  % exp
  if ~isempty(expidx)
    dwdu = theta(expidx);
    g(expidx) = g(expidx).*dwdu;
    H(:,expidx) = H(:,expidx).*repmat(dwdu',[size(H,1),1]);
    H(expidx,:) = H(expidx,:).*repmat(dwdu,[1,size(H,1)]);
  end
end

