#!/usr/bin/python
# -*- coding: utf-8 -*-
"""Do value iteration to solve a Markov decision process.

I wrote this for the AI class <https://www.ai-class.com/> in order to
more deeply understand Markov decision processes, and in particular to
fully understand question 14 on the midterm.

"""

import string
import sys

class MDPSolver:
    """Mixin that performs value iteration on a finite Markov decision process.

    The subclass supplies the model:
      reward(state)          -- immediate reward for state
      absorbing(state)       -- true for terminal states, whose value is
                                simply their reward
      unreachable(state)     -- true for states that get no policy entry
      results(state, action) -- iterable of (probability, next_state) pairs
      actions                -- sequence of action labels
      gamma                  -- discount factor

    A value function is a dict mapping every state to a number.
    """

    def calculate_new_value(self, value):
        """Return a new value function after one Bellman backup of value.

        Absorbing states take their reward.  Every other state takes
        gamma * (best expected next-state value over all actions) + reward.
        value itself is left untouched.
        """
        updated = {}
        for state in value.keys():
            if self.absorbing(state):
                updated[state] = self.reward(state)
                continue
            best = max([self.expected_value(value, state, action)
                        for action in self.actions])
            updated[state] = self.gamma * best + self.reward(state)
        return updated

    def expected_value(self, value, state, action):
        """Return the probability-weighted value of the states reached by
        taking action in state, according to self.results()."""
        outcomes = self.results(state, action)
        return sum([chance * value[landing] for chance, landing in outcomes])

    def maxdelta(self, old_value, new_value):
        """Return the largest absolute change of any state's value.

        The states compared are the keys of old_value.  Raises ValueError
        (from max()) when old_value is empty.
        """
        return max([abs(old_value[key] - new_value[key])
                    for key in old_value])

    def sequence_of_values(self, old_value, minmax=1e-12):
        """Generate successive value functions, starting with old_value.

        After yielding a value function, the next one is computed.  If no
        state changed by minmax or more, the generator stops without
        yielding that final, converged update; otherwise the update is
        yielded and the process repeats.
        """
        current = old_value
        while True:
            yield current
            following = self.calculate_new_value(current)
            if self.maxdelta(current, following) < minmax:
                return
            current = following

    def policy(self, value):
        """Return the greedy policy for value as a dict state -> action.

        Absorbing and unreachable states map to ''.  Every other state
        maps to the action with the highest expected value.  On a tie,
        max() keeps the first such action in self.actions order.
        """
        chosen = {}
        for state in value.keys():
            if self.absorbing(state) or self.unreachable(state):
                chosen[state] = ''
            else:
                def score(action, state=state):
                    return self.expected_value(value, state, action)
                chosen[state] = max(self.actions, key=score)
        return chosen

class BaseGridWorld(MDPSolver):
    """Rectangular grid world built on MDPSolver.

    States are (x, y) pairs with 1 <= x <= width and 1 <= y <= height.
    Subclasses provide reward_map (absorbing state -> reward),
    default_reward, unreachable_states, width, height, gamma, and a
    deltas dict mapping each action to a list of (probability, (dx, dy))
    outcomes.
    """

    actions = ['N', 'S', 'E', 'W']

    # (action, (dx, dy)) for the move each action is aimed at; subclasses
    # derive their stochastic deltas from this list.
    intended_deltas = [('N', (0, 1)),
                       ('S', (0, -1)),
                       ('E', (1, 0)),
                       ('W', (-1, 0)),
                       ]

    def reward(self, state):
        """Return state's entry in reward_map, or default_reward."""
        return self.reward_map.get(state, self.default_reward)

    def absorbing(self, state):
        """True for the states listed in reward_map."""
        return state in self.reward_map

    def unreachable(self, state):
        """True for states in unreachable_states and for positions off the grid."""
        # Unpack in the body rather than the signature: tuple parameters
        # are deprecated in Python 2 and were removed in Python 3 (PEP 3113).
        x, y = state
        return (   (x, y) in self.unreachable_states
                or x < 1 or x > self.width
                or y < 1 or y > self.height)

    def results(self, state, action):
        """Yield (probability, next_state) for each outcome of action in state.

        An outcome that would land on an unreachable position leaves the
        agent where it was.
        """
        x, y = state
        for probability, (dx, dy) in self.deltas[action]:
            target = (x + dx, y + dy)
            if self.unreachable(target):
                target = (x, y)
            yield (probability, target)

    def draw_grid(self, contents):
        """Print contents, a dict (x, y) -> display value, as a labelled grid.

        Columns are numbered 1..width across the top.  Rows are lettered
        a, b, c, ... starting from the top row (y == height).  Each cell is
        formatted with '%5.5s', i.e. right-aligned and cut to five
        characters.  A blank line follows the header and every row.
        """
        # Build the text and write it in one go.  The output is identical to
        # the old trailing-comma print statements, but this works on Python 2
        # and 3 alike.
        columns = range(1, self.width + 1)
        lines = [' '.join([' '] + ['%5d' % x for x in columns]), '']
        for y, label in zip(range(self.height, 0, -1), string.ascii_lowercase):
            lines.append(' '.join([label] +
                                  ['%5.5s' % contents[x, y] for x in columns]))
            lines.append('')
        sys.stdout.write('\n'.join(lines) + '\n')

    def format_values(self, values):
        """Return a dict mapping each state in values to a display string.

        Absorbing states show '*' plus their reward.  Unreachable states
        show ' ### '.  All others show their value to one decimal place.
        """
        rv = {}
        for state, value in values.items():
            if self.absorbing(state):
                rv[state] = '*' + '%4.4s' % self.reward(state)
            elif self.unreachable(state):
                rv[state] = ' ### '
            else:
                rv[state] = '%3.1f' % value
        return rv

    def draw_values(self, values):
        """Print the value function values as a grid."""
        self.draw_grid(self.format_values(values))

    def initial_values(self):
        """Return the starting value function.

        Every grid position, unreachable ones included, maps to 0, except
        absorbing states, which start at their reward.
        """
        return dict(((x, y),
                     self.reward((x, y)) if self.absorbing((x, y)) else 0)
                    for x in range(1, self.width+1)
                    for y in range(1, self.height+1))


class GridWorld1(BaseGridWorld):
    """As in Unit 9, section 19, Value Iteration 3, from the AI class.

    States are (x, y), both 1-based, y counting from the bottom.

    Note that this demonstrates an error in Thrun’s explanation: the
    faraway states do not stay with their V at 0 until the values
    propagate back from the exit to reach them, but rather become
    increasingly dismayed at their -3 surroundings as they find no way
    out except a long string of -3s, until finally the good news of
    the +100 salvation reaches them.

    It also demonstrates an error in Unit 9, section 28, Stochastic
    Question 2.  Thrun says, “Assuming that we have the value function
    as shown over here, and all the open states have a value of
    assumed 0, because we’re still at the beginning of our value
    update.”  But in fact the open states all have a value of -3 at
    that point, not 0, so the correct value for b3 is 48.3, not 48.6.

    """
    def __init__(self):
        # self.deltas is calculated in __init__ so derived classes can
        # change success_probability.

        # The failure probability is split evenly between veering left and
        # veering right.  Divide by 2.0, not 2: under Python 2 an integer
        # success_probability (e.g. 0) would otherwise floor-divide to 0,
        # and the outcome probabilities would no longer sum to 1.
        left_turn_probability = right_turn_probability = (
            1 - self.success_probability) / 2.0

        # (probability, (deltax, deltay)) lists.  (-dy, dx) is the intended
        # step rotated a quarter turn counter-clockwise (a left turn);
        # (dy, -dx) is rotated clockwise (a right turn).
        self.deltas = dict((action,
                            [(self.success_probability, (dx, dy)),
                             (left_turn_probability, (-dy, dx)),
                             (right_turn_probability, (dy, -dx)),
                             ])
                           for action, (dx, dy) in self.intended_deltas
                           )

    reward_map = {(4, 3): 100, (4, 2): -100}
    default_reward = -3
    gamma = 1
    unreachable_states = set([(2, 2)])
    success_probability = 0.8
    width, height = 4, 3

def ok(a, b):
    """Check that a == b; raise AssertionError((a, b)) if they differ.

    The check raises explicitly instead of using an assert statement, so
    the module's import-time self-tests still run under python -O, which
    strips assert statements.
    """
    if not a == b:
        raise AssertionError((a, b))

# Import-time self-tests for GridWorld1's .results(state, action) method.
# Each case is (state, action, expected outcomes).  _one_tenth is written as
# (1 - 0.8)/2 rather than 0.1 because the two differ in floating point, and
# the expected probabilities must equal the computed ones exactly.
_one_tenth = (1 - 0.8)/2
for _state, _action, _expected in [
        ((1, 2), 'W', [(0.8, (1, 2)), (_one_tenth, (1, 1)), (_one_tenth, (1, 3))]),
        ((1, 2), 'E', [(0.8, (1, 2)), (_one_tenth, (1, 1)), (_one_tenth, (1, 3))]),
        ((1, 3), 'E', [(0.8, (2, 3)), (_one_tenth, (1, 2)), (_one_tenth, (1, 3))]),
        ]:
    ok(set(GridWorld1().results(_state, _action)), set(_expected))
del _state, _action, _expected

class GridWorld2(GridWorld1):
    """GridWorld1 with deterministic moves.

    From unit 9 part 21 (the deterministic answer): with
    success_probability 1 the left and right veer probabilities are 0.
    """
    success_probability = 1  # every action goes exactly where it is aimed

class GridWorld3(GridWorld1):
    """GridWorld1 with no penalty for taking forever.

    From unit 9 part 31, value iterations and policy.

    This exposes an error in Thrun's explanation: he assumes ties between
    equally good actions are broken sensibly, but policy() keeps whichever
    tied action comes first in the actions list.
    """
    default_reward = 0  # non-exit states neither cost nor pay anything

class GridWorld4(GridWorld1):
    """GridWorld1 with an enormous per-step penalty.

    From unit 9 part 31, value iterations and policy.
    """
    default_reward = -200  # each non-exit state costs 200


class GridWorld5(GridWorld1):
    """GridWorld1 with only a mild exponential time penalty.

    Non-exit states cost nothing, but future rewards are discounted by
    0.98 per step.  Not from a video.
    """
    gamma = 0.98
    default_reward = 0

class ReviewGridWorld(BaseGridWorld):
    """From unit 12, MDP review.

    A 4x2 grid with exits at (1, 1) worth -100 and (4, 1) worth +100.
    Every other state is worth -4, and there is no discounting.  Each
    action moves as intended with probability success_probability;
    otherwise it moves in the exact opposite direction.
    """
    def __init__(self):
        # Built per instance so subclasses can override success_probability.
        reverse_probability = 1 - self.success_probability

        # (probability, (deltax, deltay)) lists: straight ahead or straight back.
        self.deltas = dict((action,
                            [(self.success_probability, (dx, dy)),
                             (reverse_probability, (-dx, -dy))])
                           for action, (dx, dy) in self.intended_deltas
                           )

    reward_map = {(1, 1): -100, (4, 1): 100}
    default_reward = -4
    success_probability = 1
    gamma = 1
    # A set, like the other worlds' unreachable_states; this grid has no walls.
    unreachable_states = set()
    width, height = 4, 2

class ReviewGridWorld2(ReviewGridWorld):
    """From later in unit 12, MDP review.

    Same grid, but each move goes backwards 20% of the time.
    """
    success_probability = 0.8  # forward 80%, reversed 20%

# These are a couple of amusing variations on the unit 12 grid world.

class FunkyGridWorld(ReviewGridWorld):
    """Unit 12 grid where every move is a coin flip between forward and back."""
    success_probability = 0.5  # forward and reversed are equally likely

class BackwardsGridWorld(ReviewGridWorld):
    """Unit 12 grid where moves usually go the wrong way.

    Each action moves backwards 90% of the time.  Non-exit states cost
    nothing, and future rewards are discounted by 0.98 per step.
    """
    default_reward = 0
    gamma = 0.98
    success_probability = 0.1  # forward only 10% of the time

class MidtermGridWorld(BaseGridWorld):
    """Question 14 on midterm 1.

    A deterministic 4x2 grid: every action moves as intended.  The only
    exit is (4, 2), worth +100; (3, 2) is a wall; every other state is
    worth -5, and there is no discounting.
    """
    reward_map = {(4, 2): 100}
    default_reward = -5
    unreachable_states = set([(3, 2)])
    gamma = 1
    width, height = 4, 2

    def __init__(self):
        # Exactly one outcome per action, taken with probability 1.
        deltas = {}
        for action, step in self.intended_deltas:
            deltas[action] = [(1, step)]
        self.deltas = deltas

def chug(world, max_iterations=1000):
    """Run value iteration on world, printing every value function on the way.

    First prints world.  Then, for each value function produced by
    world.sequence_of_values(world.initial_values()), it:
      - draws the grid;
      - from the second value function onward, prints how much the
        largest per-state value changed;
      - prints two blank lines.
    It stops when the sequence ends or after value function number
    max_iterations (counting from 0) has been drawn.  Finally it draws
    the greedy policy for the last value function shown.

    max_iterations -- cap on iterations; the default of 1000 matches the
                      old hard-coded limit.
    """
    # sys.stdout.write produces the same text the old print statements did,
    # and also runs under Python 3.
    sys.stdout.write('%s\n' % (world,))
    prev_values = None
    for ii, values in enumerate(world.sequence_of_values(world.initial_values())):
        world.draw_values(values)

        if prev_values is not None:
            sys.stdout.write('iteration %s improved by %s\n'
                             % (ii, world.maxdelta(prev_values, values)))
        prev_values = values

        sys.stdout.write('\n\n')

        if ii >= max_iterations:
            break

    sys.stdout.write('Final policy:\n')
    world.draw_grid(world.policy(values))

def main():
    """Solve each example world in turn and print its values and policy."""
    world_classes = (GridWorld1, GridWorld2, GridWorld3, GridWorld4,
                     GridWorld5,
                     ReviewGridWorld, ReviewGridWorld2,
                     FunkyGridWorld, BackwardsGridWorld,
                     MidtermGridWorld)
    for world_class in world_classes:
        chug(world_class())

# When run as a script (not imported), solve and print every example world.
if __name__ == '__main__':
    main()

