# Copyright 2010 Boris Figovsky <borfig@gmail.com>
#
# This file is part of pybfc.

# pybfc is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# pybfc is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with pybfc.  If not, see <http://www.gnu.org/licenses/>.
"""
A class that manages acquiring and releasing multiple locks at a time.
Locks are always acquired in a consistent (sorted) order.
Releasing an unknown lock (one never passed to acquire before) raises a KeyError.
Non-blocking acquisition is also supported.

>>> from threading import Thread
>>> def philosopher(multilock, locks, counter):
...     with MLock(multilock, locks):
...         counter[0] = counter[0] + 1
>>> m = MultiLock()
>>> counter = [0]
>>> N = 10
>>> philosophers = [Thread(target = philosopher, args = (m, [i, (i+1) % N], counter)) for i in xrange(N)]
>>> for p in philosophers:
...     p.start()
>>> for p in philosophers:
...     p.join()
>>> counter[0]
10
>>> m.acquire([])
True
>>> m.acquire([])
True
>>> m.release([])
>>> m.release([])
>>> m.acquire('b')
True
>>> m.acquire('ab', False)
False
>>> m.release('b')
>>> m.acquire('ab', False)
True
>>> m.release('ab')
>>> m.release(['foo'])
Traceback (most recent call last):
    ...
KeyError: 'foo'

"""

import threading, collections, itertools

class MultiLock(object):
    """Acquire and release several keyed locks as a single operation.

    A lock object is created lazily (via ``lock_factory``) the first time its
    key is looked up.  Keys are always acquired in sorted order (using
    ``sorting_key``) and released in reverse sorted order, so callers that
    lock overlapping key sets take the shared keys in the same order.

    Keys must be hashable and mutually sortable under ``sorting_key``.

    Note: per-key locks are never removed from the internal table, so every
    distinct key ever passed to ``acquire`` stays in memory for the lifetime
    of the instance.  Passing the same key twice in one call acquires the same
    lock twice; with a non-reentrant ``lock_factory`` such as
    ``threading.Lock`` that deadlocks (blocking) or fails (non-blocking).
    """

    __slots__ = ['_locks',
                 '_lock',
                 '_sorting_key',
                 ]

    def __init__(self, lock_factory = threading.Lock, sorting_key = None):
        """Create an empty multi-lock.

        :param lock_factory: zero-argument callable returning a lock object
            with ``acquire``/``release`` methods.  It is used both for the
            per-key locks and for the internal lock guarding the key table.
        :param sorting_key: optional ``key`` function handed to ``sorted`` to
            order lock keys; ``None`` sorts the keys themselves.
        """
        # Maps lock key -> lock object; unknown keys get a fresh lock.
        self._locks = collections.defaultdict(lock_factory)
        # Serializes lookups/insertions into self._locks.
        self._lock = lock_factory()
        self._sorting_key = sorting_key

    @staticmethod
    def _release_all(lock_values):
        """Release ``lock_values`` in reverse order (undoes a partial acquire)."""
        for lock_value in reversed(lock_values):
            lock_value.release()

    def acquire(self, locks, blocking = True):
        """Acquire every lock named in ``locks``.

        :param locks: iterable of lock keys (empty is allowed).
        :param blocking: if true, wait for each lock in turn; if false, give
            up as soon as one lock is unavailable.
        :returns: ``True`` if all requested locks are now held.  ``False``
            (non-blocking mode only) if some lock was busy; in that case every
            lock this call had already taken is released again, and no other
            lock is touched.
        :raises: any exception from a lock's ``acquire`` propagates after the
            locks already taken by this call have been released.
        """
        lock_keys = sorted(locks, key = self._sorting_key)
        if not lock_keys:
            return True

        with self._lock:
            lock_values = [self._locks[lock_key] for lock_key in lock_keys]

        # Track exactly what this call took so a rollback never releases a
        # lock it does not own (e.g. the busy one, held by someone else).
        acquired = []
        try:
            for lock_value in lock_values:
                if blocking:
                    lock_value.acquire()
                elif not lock_value.acquire(False):
                    break
                acquired.append(lock_value)
        except BaseException:
            # Never leave a partial set of locks held behind an exception.
            self._release_all(acquired)
            raise

        if len(acquired) == len(lock_values):
            return True

        # Non-blocking and some lock was busy: undo only our own acquisitions.
        self._release_all(acquired)
        return False

    def release(self, locks):
        """Release every lock named in ``locks``, in reverse sorted order.

        All keys are validated before any lock is released.

        :param locks: iterable of lock keys (empty is allowed).
        :raises KeyError: if some key was never passed to ``acquire`` on this
            instance; in that case no lock is released.
        """
        lock_keys = sorted(locks, key = self._sorting_key, reverse = True)
        if not lock_keys:
            return

        with self._lock:
            # .get() avoids the defaultdict creating locks for unknown keys.
            lock_values = [self._locks.get(lock_key) for lock_key in lock_keys]

        for lock_key, lock_value in zip(lock_keys, lock_values):
            if lock_value is None:
                raise KeyError(lock_key)

        for lock_value in lock_values:
            lock_value.release()

class MLock(object):
    """Hold a fixed set of keys of a multi-lock, usable as a context manager.

    ``with MLock(multilock, keys):`` acquires all ``keys`` (blocking) on entry
    and releases them on exit, whether or not the body raised; exceptions from
    the body are not suppressed.
    """
    __slots__ = ['_multilock', '_locks']

    def __init__(self, multilock, locks):
        """
        :param multilock: object providing ``acquire(locks, blocking)`` and
            ``release(locks)`` (e.g. a MultiLock).
        :param locks: iterable of lock keys.  It is snapshotted into a tuple
            so that a one-shot iterator (such as a generator) yields the same
            keys for both acquire and release; otherwise release would see an
            exhausted iterator and silently release nothing.
        """
        self._multilock = multilock
        self._locks = tuple(locks)

    def acquire(self, blocking = True):
        """Acquire all held keys; returns the multilock's result."""
        return self._multilock.acquire(self._locks, blocking)

    def release(self):
        """Release all held keys; returns the multilock's result."""
        return self._multilock.release(self._locks)

    def __enter__(self):
        self.acquire()
        return self

    def __exit__(self, type, value, traceback):
        self.release()
        # Returning False lets any exception from the with-body propagate.
        return False

# Alias of MLock kept so code that still uses the name MultiLocker keeps
# working; it is temporary.
# TODO: remove this in v0.2
MultiLocker = MLock
