
from sys import argv
import math
import numpy as np
from keras.models import Model, Sequential
from keras.layers import *


# experiment settings
repetitions = 10      # number of experiments
seqlen = 100          # sequence length
memcells = 2          # 1 does not work due to a keras bug
maxsteps = 100000     # maximal learning time, counted in full-batch training epochs
pretrain_ff = False   # pre-train the feed forward part; this is NOT e2e learning
verbose = False       # don't print too much

# optional command line overrides: <script> [seqlen [pretrain_ff]]
if len(argv) > 1:
	seqlen = int(argv[1])
	# the two sequences are told apart by the one-hot symbols 0 and 1, which
	# needs at least two positions / two symbol dimensions
	if seqlen < 2:
		raise SystemExit("sequence length must be at least 2, got " + str(seqlen))

if len(argv) > 2:
	# case-insensitive: "True", "true", "TRUE" all enable pre-training
	pretrain_ff = argv[2].lower() == 'true'

print("settings:")
print("sequence length: " + str(seqlen))
print("# LSTM cells: " + str(memcells))
print("max # training steps: " + str(maxsteps))
print("pre-train the FF net: " + str(pretrain_ff))
print("verbose: " + str(verbose))
print("")


# Build the two training sequences as one-hot arrays of shape
# (2, seqlen, seqlen). Sequence k starts and ends with symbol k (k = 0, 1);
# in between, position p carries symbol p + 1 in both sequences. The task is
# to predict the next element at every position.
x = np.zeros((2, seqlen, seqlen))
for seq in (0, 1):
	x[seq, 0, seq] = 1.0
	x[seq, seqlen-1, seq] = 1.0
middle = np.arange(1, seqlen-1)
x[:, middle, middle+1] = 1.0


# evaluate the network on both sequences
# NOTE: the name shadows the builtin eval(); it is kept because the training
# loop calls it under this name.
def eval(net, data=None):
	"""Return the summed squared next-element prediction error of `net`.

	`net` is fed every element of each sequence except the last one. Its
	outputs are compared with the same sequences shifted by one step.

	net:  object with a Keras-style ``predict(inputs)`` method.
	data: rank-3 array indexed as (sequence, position, symbol); defaults to
	      the module-level training sequences ``x``.
	"""
	if data is None:
		data = x
	inputs, targets = data[:, :-1, :], data[:, 1:, :]
	return np.sum(np.square(net.predict(inputs) - targets))


# experiments loop
EPOCHS_PER_CHUNK = 100  # full-batch epochs trained between two error evaluations


def _train_chunk(net, step):
	"""Train `net` for EPOCHS_PER_CHUNK epochs on both sequences, then evaluate it.

	batch_size equals the number of sequences, so every epoch is one
	full-batch update. Returns the advanced step counter and the error
	reported by eval() after training.
	"""
	net.fit(x[:, :-1, :], x[:, 1:, :], batch_size=x.shape[0], epochs=EPOCHS_PER_CHUNK, verbose=0, shuffle=False)
	step += EPOCHS_PER_CHUNK
	err = eval(net)
	if verbose:
		print(str(step) + "   " + str(err))
	return step, err


sum_steps = 0      # training steps summed over successful runs only
num_successes = 0  # runs whose error dropped below the success threshold
for run in range(repetitions):
	print("experiment " + str(run+1) + " of " + str(repetitions))
	np.random.seed(42 + run)
	step = 0

	if pretrain_ff:
		# create a feed forward network (no memory) that maps each input symbol
		# directly to a softmax over the next symbol
		inp = Input(shape=(seqlen-1, seqlen), name="input")
		out = TimeDistributed(Dense(seqlen, activation='softmax', use_bias=False, name='out'))(inp)
		net = Model(inputs=inp, outputs=out)
		net.compile(loss='mse', optimizer='rmsprop')

		# train until convergence. For seqlen >= 3 both sequences share the same
		# second-to-last symbol, so without memory the best reachable error is
		# 1.0: a 0.5/0.5 guess of the last symbol costs 0.5 per sequence.
		while step < maxsteps:
			step, err = _train_chunk(net, step)
			if err < 1.01:
				break
		# Append zero rows for the LSTM outputs. They follow the raw input in the
		# concatenation below, so the combined net starts out as the FF net.
		w = net.layers[-1].get_weights()
		w[0] = np.concatenate((w[0], np.zeros((memcells, seqlen))), axis=0)

	# create the network: the output layer sees the raw input and the LSTM output
	inp = Input(shape=(seqlen-1, seqlen), name="input")
	mem = LSTM(memcells, return_sequences=True, name='mem')(inp)
	out = TimeDistributed(Dense(seqlen, activation='softmax', use_bias=False, name='out'))(concatenate([inp, mem]))
	net = Model(inputs=inp, outputs=out)
	net.compile(loss='mse', optimizer='rmsprop')

	if pretrain_ff:
		net.layers[-1].set_weights(w)

	# Train until convergence or until the step budget is used up. The budget
	# is shared with the pre-training above, but at least one chunk is always
	# trained here.
	while True:
		step, err = _train_chunk(net, step)
		if err < 0.01:
			sum_steps += step
			num_successes += 1
			break
		# stop as soon as the budget is reached, not one chunk after it
		if step >= maxsteps:
			break

	print("# steps: " + str(step))


# Report the success rate over all runs and, if any run succeeded, the mean
# number of training steps those successful runs needed.
summary = [
	"",
	"--------------------------------------------------------",
	"SUMMARY:",
	" *  fraction of successful runs:   " + str(num_successes / repetitions),
]
if num_successes > 0:
	summary.append(" *  average number of steps when successful:   " + str(sum_steps / num_successes))
print("\n".join(summary))
