/*
 * Copyright (c) 1991-2004 Kyoto University
 * Copyright (c) 2000-2004 NAIST
 * All rights reserved
 */

/* word_align.c --- perform viterbi alignment for given words or phonemes */

/* $Id: word_align.c,v 1.9 2004/03/23 03:00:16 ri Exp $ */

#include <julius.h>

/* alignment unit selectors: passed as the per_what argument of */
/* do_align() and make_phseq() to choose what one segment covers */
#define PER_WORD 1		/* one segment per word */
#define PER_PHONEME 2		/* one segment per phoneme */
#define PER_STATE 3		/* one segment per HMM state */

/*
 * make_phseq --- expand a word sequence into its logical HMM (phoneme)
 * sequence and record the last state index of every alignment unit.
 *
 *   wseq     : word IDs, wseq[0] first (forward order)
 *   num      : number of words in wseq
 *   num_ret  : receives the number of phonemes in the returned array
 *   end_ret  : *end_ret must point to a caller-allocated int array; it is
 *              filled with one state index per alignment unit.  The size of
 *              that array is not known here and is not checked.
 *   per_what : PER_WORD, PER_PHONEME or PER_STATE (unit stored in *end_ret)
 *
 * Returns an array allocated with mymalloc() holding one HMM_Logical
 * pointer per phoneme; the array itself is to be released by the caller.
 *
 * When ccd_flag is set, the first phone of every word but the first is
 * replaced by the result of get_left_context_HMM(), and the last phone of
 * every word but the last by the result of get_right_context_HMM().  A NULL
 * result from either lookup leaves the phone unchanged.
 */
static HMM_Logical **
make_phseq(WORD_ID *wseq, short num, int *num_ret, int **end_ret, int per_what)
{
  HMM_Logical **ph;		/* phoneme sequence */
  int phnum;			/* num of above */
  WORD_ID tmpw;
  int w, i, j, pn, st, endn;
  int nstate;			/* states contributed by the current phone */
  int *ends = *end_ret;		/* caller's end-state buffer */
  HMM_Logical *tmpp, *ret;

  /* 1. count the total number of phones and allocate the sequence */
  phnum = 0;
  for (w = 0; w < num; w++) {
    phnum += winfo->wlen[wseq[w]];
  }
  ph = (HMM_Logical **)mymalloc(sizeof(HMM_Logical *) * phnum);

  /* 2. fill the sequence, word by word */
  st = 0;			/* states emitted so far */
  pn = 0;			/* phones emitted so far */
  endn = 0;			/* entries written to ends[] so far */
  for (w = 0; w < num; w++) {
    tmpw = wseq[w];
    for (i = 0; i < winfo->wlen[tmpw]; i++) {
      tmpp = winfo->wseq[tmpw][i];

      /* handle cross-word context dependency */
      if (ccd_flag) {
        /* word head: left context is the phone emitted just before.
           "pn > 0" keeps ph[pn-1] in bounds even if every earlier word
           was empty. */
        if (w > 0 && i == 0 && pn > 0) {
          ret = get_left_context_HMM(tmpp, ph[pn-1]->name, hmminfo);
          if (ret != NULL) {
            tmpp = ret;
          }
        }
        /* word tail: right context is the first phone of the next word,
           which is only read when that word actually has a phone */
        if (w < num - 1 && i == winfo->wlen[tmpw] - 1
            && winfo->wlen[wseq[w+1]] > 0) {
          ret = get_right_context_HMM(tmpp, winfo->wseq[wseq[w+1]][0]->name, hmminfo);
          if (ret != NULL) {
            tmpp = ret;
          }
        }
      }
      ph[pn++] = tmpp;

      /* two states of each HMM are not counted -- presumably the
         entry and exit states; confirm against the HMM definition */
      nstate = hmm_logical_state_num(tmpp) - 2;

      if (per_what == PER_STATE) {
        /* every counted state is its own unit */
        for (j = 0; j < nstate; j++) {
          ends[endn++] = st + j;
        }
      }
      st += nstate;
      if (per_what == PER_PHONEME) {
        ends[endn++] = st - 1;	/* last state of this phone */
      }
    }
    if (per_what == PER_WORD) {
      ends[endn++] = st - 1;	/* last state of this word */
    }
  }

  *num_ret = phnum;
  return ph;
}


/* build sentence HMM, call viterbi_segment() and output result */
/* NOTE: words are in reverse order */
static void
do_align(WORD_ID *words, short wnum, HTK_Param *param, int per_what)
{
  HMM_Logical **phones;		/* phoneme sequence */
  int phonenum;			/* num of above */
  HMM *shmm;			/* sentence HMM */
  int *end_state;		/* state number of word ends */
  int *end_frame;		/* segmented last frame of words */
  LOGPROB *end_score;		/* normalized score of each words */
  LOGPROB allscore;		/* total score of this word sequence */
  WORD_ID w;
  int i, rlen;
  int end_num;
  int *id_seq, *phloc, *stloc;

  /* initialize result storage buffer */
  switch(per_what) {
  case PER_WORD:
    j_printf("=== word alignment begin ===\n");
    end_num = wnum;
    phloc = (int *)mymalloc(sizeof(int)*wnum);
    i = 0;
    for(w=0;w<wnum;w++) {
      phloc[w] = i;
      i += winfo->wlen[words[w]];
    }
    break;
  case PER_PHONEME:
    j_printf("=== phoneme alignment begin ===\n");
    end_num = 0;
    for(w=0;w<wnum;w++) end_num += winfo->wlen[words[w]];
    break;
  case PER_STATE:
    j_printf("=== state alignment begin ===\n");
    end_num = 0;
    for(w=0;w<wnum;w++) {
      for (i=0;i<winfo->wlen[words[w]]; i++) {
	end_num += hmm_logical_state_num(winfo->wseq[words[w]][i]) - 2;
      }
    }
    phloc = (int *)mymalloc(sizeof(int)*end_num);
    stloc = (int *)mymalloc(sizeof(int)*end_num);
    {
      int j,n,p;
      n = 0;
      p = 0;
      for(w=0;w<wnum;w++) {
	for(i=0;i<winfo->wlen[words[w]]; i++) {
	  for(j=0; j<hmm_logical_state_num(winfo->wseq[words[w]][i]) - 2; j++) {
	    phloc[n] = p;
	    stloc[n] = j + 1;
	    n++;
	  }
	  p++;
	}
      }
    }
    
    break;
  }
  end_state = (int *)mymalloc(sizeof(int) * end_num);

  /* make phoneme sequence word sequence */
  phones = make_phseq(words, wnum, &phonenum, &end_state, per_what);
  /* build the sentence HMMs */
  shmm = new_make_word_hmm(hmminfo, phones, phonenum);

  /* call viterbi segmentation function */
  allscore = viterbi_segment(shmm, param, end_state, end_num, &id_seq, &end_frame, &end_score, &rlen);

  /* print result */
  {
    int i,p,n;
    j_printf("id: from  to    n_score    applied HMMs (logical[physical] or {pseudo})\n");
    j_printf("------------------------------------------------------------\n");
    for (i=0;i<rlen;i++) {
      j_printf("%2d: %4d %4d  %f ", id_seq[i], (i == 0) ? 0 : end_frame[i-1]+1, end_frame[i], end_score[i]);
      switch(per_what) {
      case PER_WORD:
	for(p=0;p<winfo->wlen[words[id_seq[i]]];p++) {
	  n = phloc[id_seq[i]] + p;
	  if (phones[n]->is_pseudo) {
	    j_printf(" %s{%s}", phones[n]->name, phones[n]->body.pseudo->name);
	  } else if (strmatch(phones[n]->name, phones[n]->body.defined->name)) {
	    j_printf(" %s", phones[n]->name);
	  } else {
	    j_printf(" %s[%s]", phones[n]->name, phones[n]->body.defined->name);
	  }
	}
	break;
      case PER_PHONEME:
	n = id_seq[i];
	if (phones[n]->is_pseudo) {
	  j_printf(" {%s}", phones[n]->name);
	} else if (strmatch(phones[n]->name, phones[n]->body.defined->name)) {
	  j_printf(" %s", phones[n]->name);
	} else {
	  j_printf(" %s[%s]", phones[n]->name, phones[n]->body.defined->name);
	}
	break;
      case PER_STATE:
	n = phloc[id_seq[i]];
	if (phones[n]->is_pseudo) {
	  j_printf(" {%s}", phones[n]->name);
	} else if (strmatch(phones[n]->name, phones[n]->body.defined->name)) {
	  j_printf(" %s", phones[n]->name);
	} else {
	  j_printf(" %s[%s]", phones[n]->name, phones[n]->body.defined->name);
	}
	j_printf(" #%d", stloc[id_seq[i]]);
	break;
      }
      j_printf("\n");
    }
  }
  j_printf("re-computed AM score: %f\n", allscore);

  free_hmm(shmm);
  free(id_seq);
  free(phones);
  free(end_score);
  free(end_frame);
  free(end_state);

  switch(per_what) {
  case PER_WORD:
    free(phloc);
    j_printf("=== word alignment end ===\n");
    break;
  case PER_PHONEME:
    j_printf("=== phoneme alignment end ===\n");
    break;
  case PER_STATE:
    free(phloc);
    free(stloc);
    j_printf("=== state alignment end ===\n");
  }
  
}

/* entry function */
/* do forced alignment per word for param */
/* words[0..wnum-1] are given in true order and are passed to do_align() */
/* unchanged, with PER_WORD as the alignment unit; output is printed there */
void
word_align(WORD_ID *words, short wnum, HTK_Param *param)
{
  do_align(words, wnum, param, PER_WORD);
}
/* do forced alignment per word for param, taking the words reversed: */
/* revwords[wnum-1] is the first word and revwords[0] the last one. */
/* A temporary true-order copy is built, aligned with PER_WORD and freed. */
void
word_rev_align(WORD_ID *revwords, short wnum, HTK_Param *param)
{
  WORD_ID *fwd;			/* word sequence (true order) */
  int src, dst;

  fwd = (WORD_ID *)mymalloc(sizeof(WORD_ID) * wnum);
  /* walk the input backwards while filling the copy forwards */
  for (dst = 0, src = wnum - 1; src >= 0; dst++, src--) {
    fwd[dst] = revwords[src];
  }
  do_align(fwd, wnum, param, PER_WORD);
  free(fwd);
}
/* do forced alignment per phoneme for param */
/* words[0..num-1] are given in true order and are passed to do_align() */
/* unchanged, with PER_PHONEME as the alignment unit; output is printed there */
void
phoneme_align(WORD_ID *words, short num, HTK_Param *param)
{
  do_align(words, num, param, PER_PHONEME);
}
/* do forced alignment per phoneme for param, taking the words reversed: */
/* revwords[num-1] is the first word and revwords[0] the last one. */
/* A temporary true-order copy is built, aligned with PER_PHONEME and freed. */
void
phoneme_rev_align(WORD_ID *revwords, short num, HTK_Param *param)
{
  WORD_ID *fwd;			/* word sequence (true order) */
  int src, dst;

  fwd = (WORD_ID *)mymalloc(sizeof(WORD_ID) * num);
  /* walk the input backwards while filling the copy forwards */
  for (dst = 0, src = num - 1; src >= 0; dst++, src--) {
    fwd[dst] = revwords[src];
  }
  do_align(fwd, num, param, PER_PHONEME);
  free(fwd);
}
/* do forced alignment per HMM state for param */
/* words[0..num-1] are given in true order and are passed to do_align() */
/* unchanged, with PER_STATE as the alignment unit; output is printed there */
void
state_align(WORD_ID *words, short num, HTK_Param *param)
{
  do_align(words, num, param, PER_STATE);
}
/* do forced alignment per HMM state for param, taking the words reversed: */
/* revwords[num-1] is the first word and revwords[0] the last one. */
/* A temporary true-order copy is built, aligned with PER_STATE and freed. */
void
state_rev_align(WORD_ID *revwords, short num, HTK_Param *param)
{
  WORD_ID *fwd;			/* word sequence (true order) */
  int src, dst;

  fwd = (WORD_ID *)mymalloc(sizeof(WORD_ID) * num);
  /* walk the input backwards while filling the copy forwards */
  for (dst = 0, src = num - 1; src >= 0; dst++, src--) {
    fwd[dst] = revwords[src];
  }
  do_align(fwd, num, param, PER_STATE);
  free(fwd);
}
