#!/usr/bin/env ruby

# Copyright (c) 2015 Alexandra Figlovskaya <fglval@gmail.com>
# Copyright (c) 2015-2017 Aleksey Cheusov <vle@gmx.net>
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

require 'optparse'

# Settings filled in from the command line by the option parser below.
$options = {}                 # not read anywhere in the visible script
$fold_cnt = nil               # number of folds (-c), mandatory
$tmp_dir = nil                # output directory (-d), mandatory
$seed = Random.new_seed       # PRNG seed, replaced by a non-empty -s value
$stratified = true            # switched off by -r

begin
  OptionParser.new do |opts|
    opts.banner = <<EOF
heri-split splits the given dataset (in svmlight format)
  into training and testing sets for further evaluation.
Usage:
    heri-split [OPTIONS] dataset1 [dataset2...]
OPTIONS:
EOF

    opts.on('-h', '--help','display this message and exit') do
      puts opts
      exit 0
    end

    opts.on("-cFOLD_CNT", "--folds=FOLD_CNT", "A number of folds (mandatory option)") do |c|
      $fold_cnt = c.to_i
    end

    opts.on("-dDIR", "--output-dir=DIR", "Output directory (mandatory option)") do |d|
      $tmp_dir = d
    end

    opts.on("-sSEED", "--seed=SEED", "Seed for pseudo-random number generator") do |s|
      # An empty argument keeps the randomly generated default seed.
      $seed = s.to_i unless s == ""
    end

    opts.on("-r", "--random", "Use random split instead of stratified") do
      $stratified = false
    end

    opts.separator " "
  end.parse!
rescue OptionParser::ParseError => e
  # Unknown or malformed options: report briefly instead of dumping a
  # backtrace. The exit status stays 1, as with the uncaught exception.
  STDERR.puts("heri-split: #{e.message}, see heri-split -h for details")
  exit(1)
end

if $tmp_dir.nil? || $fold_cnt.nil?
  STDERR.puts("Options -c and -d are mandatory, see heri-split -h for details")
  exit(1)
end

# String#to_i yields 0 for non-numeric input; a fold count below 1 would
# later divide by zero (random split) or silently produce empty output
# (stratified split), so reject it here.
if $fold_cnt < 1
  STDERR.puts("Option -c requires a positive integer, see heri-split -h for details")
  exit(1)
end

# Pseudo-random generator used by both splitting strategies.
$rnd = Random.new($seed)

# Output files, all created inside the output directory
# (same as in StratifiedSplitter):
#   testing_fold.txt       one handle, kept in $testing_fold
#   testN.txt / trainN.txt one pair per fold N = 1..$fold_cnt; the handles
#                          are stored at index N-1 of $files_test/$files_train
$files_test = []
$files_train = []
$testing_fold = File.open("#{$tmp_dir}/testing_fold.txt", 'w:ASCII-8BIT')
1.upto($fold_cnt) do |fold|
  $files_test << File.open("#{$tmp_dir}/test#{fold}.txt", 'w:ASCII-8BIT')
  $files_train << File.open("#{$tmp_dir}/train#{fold}.txt", 'w:ASCII-8BIT')
end

# Splits the examples found in the files named by ARGV into $fold_cnt folds
# without looking at their labels.
#
# A line counts as an example when it starts with a run of non-blank
# characters followed by whitespace; every other line is skipped.
#
# Fold numbers 0...$fold_cnt are dealt out round-robin, one per example, and
# then shuffled with $rnd, so fold sizes differ by at most one. For each
# example assigned to fold k the line is written to $files_test[k] and to
# $files_train[n] for every other n, and k+1 is written to $testing_fold.
#
# Reads:  ARGV, $fold_cnt, $rnd
# Writes: $files_test, $files_train, $testing_fold (must already be open)
def random_split
  # First pass: count the examples so the fold list can be built up front.
  example_cnt = 0
  ARGV.each do |fn|
    # Block form closes the input as soon as the pass is done.
    File.open(fn, "r:ASCII-8BIT") do |f|
      f.each_line do |line|
        example_cnt += 1 if line =~ /^([^\s]+)\s/
      end
    end
  end

  nums = Array.new(example_cnt) { |i| i % $fold_cnt }
  nums.shuffle!(random: $rnd)

  # Second pass: route every example according to its shuffled fold number.
  curr_number = 0
  ARGV.each do |fn|
    File.open(fn, "r:ASCII-8BIT") do |f|
      f.each_line do |line|
        next unless line =~ /^([^\s]+)\s/

        fold_num = nums[curr_number]
        curr_number += 1

        $fold_cnt.times do |n|
          if fold_num == n
            $files_test[n].puts line
            $testing_fold.puts n + 1
          else
            $files_train[n].puts line
          end
        end
      end
    end
  end
end

# Splits the examples found in the files named by ARGV into $fold_cnt folds,
# spreading the examples of each label evenly over the folds.
#
# A line counts as an example when it starts with a run of non-blank
# characters (its label) followed by whitespace; other lines are skipped.
#
# For each label, the ordinals 1..count of its examples are shuffled with
# $rnd and the shuffled list is cut into $fold_cnt nearly equal consecutive
# slices; the slice index becomes the test fold of that example. For an
# example assigned to fold k the line is written to $files_test[k] and to
# $files_train[n] for every other n, and k+1 is written to $testing_fold.
#
# Reads:  ARGV, $fold_cnt, $rnd
# Writes: $files_test, $files_train, $testing_fold (must already be open)
def stratified_split
  # First pass: count the examples of every label.
  classes = Hash.new(0)
  ARGV.each do |fn|
    # Block form closes the input as soon as the pass is done.
    File.open(fn, "r:ASCII-8BIT") do |f|
      f.each_line do |line|
        label = line[/^([^\s]+)\s/, 1]
        classes[label] += 1 if label
      end
    end
  end

  # fold_of[label][k] is the test fold of the k-th (1-based) example
  # carrying that label, in input order.
  fold_of = {}
  classes.each do |label, count|
    order = (1..count).to_a
    order.shuffle!(random: $rnd)
    fold_of[label] = {}
    order.each_with_index do |ordinal, pos|
      fold_of[label][ordinal] = (pos * $fold_cnt.to_f / count).to_i
    end
  end

  # Second pass: route every example according to the table built above.
  seen = Hash.new(0)
  ARGV.each do |fn|
    File.open(fn, "r:ASCII-8BIT") do |f|
      f.each_line do |line|
        label = line[/^([^\s]+)\s/, 1]
        next unless label

        seen[label] += 1
        fold_num = fold_of[label][seen[label]]

        $fold_cnt.times do |n|
          if fold_num == n
            $files_test[n].puts line
            $testing_fold.puts n + 1
          else
            $files_train[n].puts line
          end
        end
      end
    end
  end
end

# Run the selected strategy: stratified unless $stratified has been
# switched off, in which case the plain random split is used.
$stratified ? stratified_split() : random_split()

# Close every output file: per-fold test files first, then the per-fold
# train files, and finally the fold listing.
($files_test + $files_train).each(&:close)
$testing_fold.close
