import itertools
import multiprocessing
from collections import defaultdict
from concurrent.futures import Future
from dataclasses import dataclass, replace, field, fields, Field
from datetime import datetime, timedelta
from pathlib import Path
from typing import Literal, Union, DefaultDict, Callable, TypeVar, Any, get_args
from urllib.parse import urlparse
import filelock
from flask import Response, flash, request
from isodate import Duration
from marshmallow import Schema
from werkzeug.utils import secure_filename
from timApp.auth.accesshelper import AccessDenied, verify_admin
from timApp.auth.accesstype import AccessType
from timApp.auth.auth_models import get_duration_now, do_confirm
from timApp.auth.sessioninfo import get_current_user_object
from timApp.document.docentry import DocEntry
from timApp.folder.folder import Folder
from timApp.item.item import Item
from timApp.tim_app import app, csrf
from timApp.timdb.sqa import db
from timApp.user.user import User
from timApp.user.usergroup import UserGroup
from timApp.user.userutils import grant_access
from timApp.util.flask.requesthelper import RouteException
from timApp.util.flask.responsehelper import (
ok_response,
to_json_str,
json_response,
safe_redirect,
)
from timApp.util.flask.typedblueprint import TypedBlueprint
from timApp.util.logger import log_warning
from timApp.util.secret import check_secret, get_secret_or_abort
from timApp.util.utils import (
read_json_lines,
collect_errors_from_hosts,
get_current_time,
)
from tim_common.marshmallow_dataclass import field_for_schema, class_schema
from tim_common.utils import DurationSchema
from tim_common.vendor.requests_futures import FuturesSession
# Blueprint for the distributed-rights routes defined below; all of them are mounted under /distRights.
dist_bp = TypedBlueprint("dist_rights", __name__, url_prefix="/distRights")
[docs]@dataclass(slots=True)
class ConfirmOp:
type: Literal["confirm"]
email: str
timestamp: datetime
[docs]@dataclass(slots=True)
class ConfirmGroupOp:
type: Literal["confirmgroup"]
group: str
timestamp: datetime
[docs]@dataclass(slots=True)
class QuitOp:
type: Literal["quit"]
email: str
timestamp: datetime
[docs]@dataclass(slots=True)
class UnlockOp:
type: Literal["unlock"]
email: str
timestamp: datetime
[docs]@dataclass(slots=True)
class ChangeTimeOp:
type: Literal["changetime"]
email: str
secs: int
timestamp: datetime
[docs]@dataclass(slots=True)
class ChangeTimeGroupOp:
type: Literal["changetimegroup"]
group: str
secs: int
timestamp: datetime
[docs]@dataclass(slots=True)
class UndoConfirmOp:
type: Literal["undoconfirm"]
email: str
timestamp: datetime
[docs]@dataclass(slots=True)
class UndoQuitOp:
type: Literal["undoquit"]
email: str
timestamp: datetime
[docs]@dataclass(slots=True)
class ChangeStartTimeGroupOp:
type: Literal["changestarttimegroup"]
group: str
starttime: datetime
timestamp: datetime
@dataclass(slots=True)
class Right:
    """Snapshot of one user's access right.

    The fields map one-to-one to the ``grant_access`` arguments used in ``receive_right``.
    """

    # True while the right still needs confirming; receive_right then sends accessible_from as None.
    require_confirm: bool
    # Window within which a duration right can be unlocked (shifted together by change_starttime).
    duration_from: datetime | None
    duration_to: datetime | None
    # Length of access for a duration right; used when the right is unlocked.
    duration: Duration | None
    # Start and end of the active access period.
    accessible_from: datetime | None
    accessible_to: datetime | None
# Every operation that can appear in a rights log. The member order matters:
# get_args(RightOp) drives right_op_types, and field_for_schema receives the union as-is.
RightOp = Union[
    ConfirmOp,
    UnlockOp,
    ChangeTimeOp,
    QuitOp,
    UndoConfirmOp,
    UndoQuitOp,
    ChangeTimeGroupOp,
    ChangeStartTimeGroupOp,
    ConfirmGroupOp,
]
# Operations that target a whole user group instead of a single email.
GroupOp = Union[
    ChangeTimeGroupOp,
    ChangeStartTimeGroupOp,
    ConfirmGroupOp,
]
# The same classes as GroupOp, as a tuple usable with isinstance(); keep the two in sync.
# (Spelled out explicitly so type checkers can narrow after isinstance(op, GroupOps).)
GroupOps = (
    ChangeTimeGroupOp,
    ChangeStartTimeGroupOp,
    ConfirmGroupOp,
)
def _get_op_name(op_type: type) -> str:
    """Return the tag stored in the ``Literal[...]`` annotation of an op's ``type`` field.

    For example, ``ConfirmOp`` (``type: Literal["confirm"]``) yields ``"confirm"``.

    :param op_type: An op dataclass (the class itself, not an instance).
    :return: The first argument of the ``type`` field's ``Literal`` annotation.
    :raises ValueError: If ``op_type`` has no field named ``type``.
    """
    # Default to None so a missing field gives a clear error instead of a bare StopIteration.
    type_field: Field | None = next(
        (f for f in fields(op_type) if f.name == "type"), None
    )
    if type_field is None:
        raise ValueError(f"{op_type.__name__} has no 'type' field")
    return get_args(type_field.type)[0]
# Tag of each op (e.g. "confirm") -> a ready-made schema instance for that op's dataclass.
right_op_types: dict[str, Schema] = dict(
    (_get_op_name(op_cls), class_schema(op_cls)()) for op_cls in get_args(RightOp)
)
def _deserialize_right(line: dict[str, Any]) -> RightOp:
    """Load one parsed log line into the op dataclass selected by its ``type`` key.

    Uses the per-class schemas in ``right_op_types``. That is faster than
    deserializing through a field schema generated for the whole ``RightOp`` union.
    """
    op_schema = right_op_types[line["type"]]
    return op_schema.load(line)
# Schema field for the whole RightOp union (built by field_for_schema). Log lines themselves
# are deserialized through the per-class schemas in right_op_types instead.
RightOpSchema = field_for_schema(RightOp) # type: ignore[arg-type]
# Schema instance used to load the initial Right. DurationSchema is the base schema,
# presumably so the isodate ``duration`` field (de)serializes; confirm in tim_common.utils.
RightSchema = class_schema(Right, base_schema=DurationSchema)()
# Alias used in annotations for values that are user email addresses.
Email = str
@dataclass(frozen=True)
class RightLogEntry:
    """One history entry: an applied op and the right it produced for the user."""

    op: RightOp
    right: Right
# Op type shared by RightLog.process_group_rights' callback and the op passed alongside it.
T = TypeVar("T", bound=RightOp)
@dataclass
class RightLog:
    """Replays rights operations and tracks the resulting right of each user.

    Every user starts from ``initial_right``. Each applied operation appends a
    :class:`RightLogEntry` (the op and the right it produced) to that user's
    history, so the latest entry holds the user's current right.
    """

    initial_right: Right
    # Group name -> member emails; filled lazily by get_group_emails.
    group_cache: dict[str, list[Email]] = field(default_factory=dict)
    # Email -> chronological history of the operations applied to that user.
    op_history: DefaultDict[Email, list[RightLogEntry]] = field(
        default_factory=lambda: defaultdict(list)
    )

    def add_op(self, r: RightOp) -> None:
        """Apply one operation and record it in the history.

        Group ops are applied to every group member via :meth:`process_group_rights`.
        All other ops affect only the user identified by ``r.email``.

        :param r: The operation to apply.
        :raises Exception: If the op type is unknown, or a group op names a
            group that does not exist.
        """
        if isinstance(r, ChangeTimeGroupOp):
            self.process_group_rights(self.get_group_emails(r), change_time, r)
            return
        if isinstance(r, ChangeStartTimeGroupOp):
            self.process_group_rights(self.get_group_emails(r), change_starttime, r)
            return
        if isinstance(r, ConfirmGroupOp):
            self.process_group_rights(self.get_group_emails(r), confirm_group, r)
            return

        email = r.email
        # get_right returns a copy, so the mutations below never alter earlier history entries.
        curr_right = self.get_right(email)
        curr_hist = self.op_history[email]
        if isinstance(r, ConfirmOp):
            do_confirm(curr_right, r.timestamp)
        elif isinstance(r, UnlockOp):
            if curr_right.duration:
                curr_right.accessible_from = r.timestamp
                curr_right.accessible_to = (
                    curr_right.accessible_from
                    + get_duration_now(curr_right, r.timestamp)
                )
            else:
                # TODO: This shouldn't happen in practice because non-duration rights cannot be unlocked.
                # Perhaps log a warning etc.
                pass
        elif isinstance(r, ChangeTimeOp):
            change_time(curr_right, r)
        elif isinstance(r, QuitOp):
            curr_right.accessible_to = r.timestamp
        elif isinstance(r, UndoConfirmOp):
            # We _don't_ want to assign "accessible_from = None" here.
            # Otherwise, if the right is reconfirmed, the start time will be wrong (it gets current timestamp).
            # If a Right with require_confirm = True is distributed, accessible_from will be saved as None in the
            # receiving end, meaning that the right is not active.
            curr_right.require_confirm = True
        elif isinstance(r, UndoQuitOp):
            # Only an immediately preceding QuitOp can be undone. Otherwise (including an
            # empty history, which used to raise IndexError during replay) the op is
            # recorded without changing the right.
            if curr_hist and isinstance(curr_hist[-1].op, QuitOp):
                # curr_hist[-1] is the QuitOp itself; the entry before it (if any)
                # holds the right that was active before quitting.
                last_active = (
                    curr_hist[-2].right if len(curr_hist) > 1 else self.initial_right
                )
                curr_right.accessible_to = last_active.accessible_to
        else:
            raise Exception("unknown op")
        curr_hist.append(RightLogEntry(r, curr_right))

    def process_group_rights(
        self,
        emails: list[Email],
        fn: Callable[[Right, T], None],
        r: T,
    ) -> None:
        """Apply ``fn(right, r)`` to each listed user's current right and record ``r``.

        Users whose latest op is a :class:`QuitOp` are skipped, so group ops do
        not affect users who have already quit.
        """
        for email in emails:
            latest = self.latest_op(email)
            if latest and isinstance(latest.op, QuitOp):
                continue
            right = self.get_right(email)
            fn(right, r)
            self.op_history[email].append(RightLogEntry(r, right))

    def get_group_emails(self, r: GroupOp) -> list[Email]:
        """Return the member emails of ``r.group``, querying the database on first use.

        :raises Exception: If the group has no members and does not exist at all.
        """
        emails = self.group_cache.get(r.group)
        # Compare with None so a cached empty group is not re-queried on every op.
        if emails is None:
            emails = [
                e
                for e, in (
                    UserGroup.query.join(User, UserGroup.users)
                    .filter(UserGroup.name == r.group)
                    .with_entities(User.email)
                )
            ]
            if not emails and not UserGroup.get_by_name(r.group):
                raise Exception(f"Usergroup {r.group} not found")
            self.group_cache[r.group] = list(emails)
        return emails

    def get_right(self, email: Email) -> Right:
        """Return a (shallow) copy of the user's current right.

        This is ``initial_right`` if no op has been recorded for the user yet.
        A copy is returned so callers can modify it without altering the history.
        """
        latest = self.latest_op(email)
        return replace(latest.right if latest else self.initial_right)

    def latest_op(self, email: Email) -> RightLogEntry | None:
        """Return the user's most recent history entry, or None if there is none."""
        try:
            return self.op_history[email][-1]
        except IndexError:
            return None
def change_time(right: Right, op: ChangeTimeOp | ChangeTimeGroupOp) -> None:
    """Shift a right by ``op.secs`` seconds, in place (negative values shorten it).

    - If the right has an end time (``accessible_to``), the end time is moved.
    - If it is a duration right without ``accessible_from`` (not yet unlocked),
      the duration itself is changed.
    """
    delta = timedelta(seconds=op.secs)
    if right.accessible_to:
        right.accessible_to += delta
    if right.duration and not right.accessible_from:
        right.duration += delta
def change_starttime(right: Right, op: ChangeStartTimeGroupOp) -> None:
    """Move the start of a right to ``op.starttime``, in place.

    - If ``accessible_from`` is set (a non-duration right or an unlocked duration
      right), it is moved and ``accessible_to`` is shifted by the same amount.
    - Otherwise, if ``duration_from`` is set, the unlock window start is moved
      and ``duration_to`` is shifted so the window length stays the same.
    """
    if right.accessible_from:
        old_acc_from = right.accessible_from
        right.accessible_from = op.starttime
        # Keep the difference (accessible_to - accessible_from) constant.
        # It's not necessarily always desired, but makes sense in exams with non-duration rights.
        if right.accessible_to:
            right.accessible_to = op.starttime + (right.accessible_to - old_acc_from)
    elif right.duration_from:
        # Keep the difference (duration_to - duration_from) constant. Compare with None:
        # a zero-length period is timedelta(0), which is falsy, and must not clear duration_to.
        unlock_period = (
            right.duration_to - right.duration_from
            if right.duration_to is not None
            else None
        )
        right.duration_from = op.starttime
        right.duration_to = (
            op.starttime + unlock_period if unlock_period is not None else None
        )
def confirm_group(right: Right, op: ConfirmGroupOp) -> None:
    """Confirm one group member's right in place, using the op's timestamp.

    RightLog passes this as the per-member callback for ConfirmGroupOp.
    """
    do_confirm(right, op.timestamp)
def get_current_rights(target: str) -> tuple[RightLog, Path]:
    """Rebuild the current rights of ``target`` by replaying its rights files.

    Both files live under ``FILES_PATH``:

    - ``<target>.rights.initial``: the first line is the initial :class:`Right`;
      any further lines are operations.
    - ``<target>.rights.log``: operations appended by do_register_right. A missing
      file means nothing has been logged yet.

    :return: The replayed :class:`RightLog` and the path of the log file.
    :raises FileNotFoundError: If the initial rights file does not exist.
    :raises ValueError: If the initial rights file is empty.
    """
    files_path = Path(app.config["FILES_PATH"])
    initial_path = files_path / f"{target}.rights.initial"
    initial_ops, initial_lines = read_rights(initial_path, 1)
    if not initial_lines:
        # Previously surfaced as a bare IndexError on lines[0].
        raise ValueError(
            f"{initial_path} is empty; expected the initial right on the first line"
        )
    rights_log_path = files_path / f"{target}.rights.log"
    try:
        logged_ops, _ = read_rights(rights_log_path, 0)
    except FileNotFoundError:
        logged_ops = []
    rights = RightLog(RightSchema.load(initial_lines[0]))
    for op in itertools.chain(initial_ops, logged_ops):
        rights.add_op(op)
    return rights, rights_log_path
def read_rights(path: Path, index: int) -> tuple[list[RightOp], list[dict]]:
    """Read a JSON-lines rights file.

    :param path: The file to read. Errors from ``read_json_lines`` propagate
        (callers handle FileNotFoundError).
    :param index: Number of leading lines to skip when deserializing ops
        (1 for ``.rights.initial`` files, whose first line is the initial right).
    :return: The ops deserialized from line ``index`` onwards, and all raw lines.
    """
    raw_lines = read_json_lines(path)
    ops = [_deserialize_right(line) for line in raw_lines[index:]]
    return ops, raw_lines
def do_register_right(op: RightOp, target: str) -> tuple[RightLog | None, str | None]:
    """Validate ``op``, apply it to the current rights of ``target`` and append it to the log.

    register_right_impl calls this while holding a per-target file lock.

    Single-user ops are rejected in two cases:

    - after the user's QuitOp, unless the op is an UndoQuitOp;
    - an UndoQuitOp when the user's latest op is not a QuitOp.

    :return: ``(rights, None)`` on success, or ``(None, error_message)`` if rejected.
    """
    rights, right_log_path = get_current_rights(target)
    if not isinstance(op, GroupOps):
        latest = rights.latest_op(op.email)
        after_quit = latest is not None and isinstance(latest.op, QuitOp)
        if after_quit and not isinstance(op, UndoQuitOp):
            return None, f"{target}: Cannot register a non-UndoQuitOp after QuitOp"
        if isinstance(op, UndoQuitOp) and not after_quit:
            return None, f"{target}: There is no QuitOp to undo"
    # Apply before writing, so an op that fails to apply (e.g. unknown group) is never logged.
    rights.add_op(op)
    with right_log_path.open("a") as f:
        f.write(to_json_str(op) + "\n")
    return rights, None
def do_dist_rights(op: RightOp, rights: RightLog, target: str) -> list[str]:
    """Send the current rights of the users affected by ``op`` to every host of ``target``.

    For group ops the affected users are the group members cached in
    ``rights.group_cache`` while the op was applied; otherwise only ``op.email``.
    Each host in ``DIST_RIGHTS_HOSTS[target]["hosts"]`` receives a PUT to
    ``/distRights/receive``.

    :return: Error messages collected from the hosts by collect_errors_from_hosts.
    """
    emails = rights.group_cache[op.group] if isinstance(op, GroupOps) else [op.email]
    host_config = app.config["DIST_RIGHTS_HOSTS"][target]
    dist_rights_send_secret = get_secret_or_abort("DIST_RIGHTS_SEND_SECRET")
    hosts = host_config["hosts"]
    rights_to_send = [{"email": e, "right": rights.get_right(e)} for e in emails]
    # The payload is the same for every host, so serialize it once.
    payload = to_json_str(
        {
            "rights": rights_to_send,
            "secret": dist_rights_send_secret,
            "item_path": host_config["item"],
        }
    )
    session = FuturesSession(max_workers=multiprocessing.cpu_count())
    try:
        futures = [
            session.put(
                f"{host}/distRights/receive",
                data=payload,
                headers={"Content-Type": "application/json"},
                timeout=10,
            )
            for host in hosts
        ]
        return collect_errors_from_hosts(futures, hosts)
    finally:
        # Close the session once the responses have been collected, so its executor
        # and pooled connections are released instead of lingering.
        session.close()
def register_right_impl(
    op: RightOp,
    target: str | list[str],
    backup: bool = True,
    distribute: bool = True,
) -> list[str]:
    """Register ``op`` for one or more targets, optionally distributing and backing it up.

    For each target the op is logged under a per-target file lock. If
    ``distribute`` is set and registration succeeded, the resulting rights are
    sent to the target's hosts under a second per-target lock. If ``backup`` is
    set, the op is then forwarded to the other register hosts; backup failures
    are only logged.

    :param target: A target name or a list of names. Names are sanitized with
        ``secure_filename`` because they become parts of file names.
    :return: Registration and distribution error messages.
    :raises RouteException: If any target name is invalid. All names are checked
        before anything is registered, so an invalid name no longer leaves the op
        registered for only some of the targets.
    """
    targets = [target] if isinstance(target, str) else target
    safe_targets: list[str] = []
    for tgt in targets:
        target_s = secure_filename(tgt)
        if not target_s:
            raise RouteException(f"invalid target: {tgt}")
        safe_targets.append(target_s)

    errors: list[str] = []
    for target_s in safe_targets:
        with filelock.FileLock(f"/tmp/log_right_{target_s}"):
            rights, err = do_register_right(op, target_s)
        if err:
            errors.append(err)
        if distribute and rights:
            with filelock.FileLock(f"/tmp/dist_right_{target_s}"):
                errors.extend(do_dist_rights(op, rights, target_s))
    if backup:
        backup_errors = register_op_to_hosts(op, target, is_receiving_backup=True)
        if backup_errors:
            log_warning(f"Right backup failed for some hosts: {backup_errors}")
    return errors
@dist_bp.post("/register")
@csrf.exempt
def register_right(
    op: RightOp,
    target: str | list[str],
    secret: str,
    is_receiving_backup: bool = False,
) -> Response:
    """Register a rights operation posted to this host (e.g. by register_op_to_hosts).

    Requires ``DIST_RIGHTS_REGISTER_SECRET``. The op is never backed up from
    here. It is distributed only if this host is the active distributor
    (``DIST_RIGHTS_IS_DISTRIBUTOR``) and the request is not a backup copy.

    :return: JSON ``{"host_errors": [...]}``.
    """
    check_secret(secret, "DIST_RIGHTS_REGISTER_SECRET")
    is_active_distributor = app.config["DIST_RIGHTS_IS_DISTRIBUTOR"]
    errors = register_right_impl(
        op,
        target,
        backup=False,
        distribute=not is_receiving_backup and is_active_distributor,
    )
    return json_response({"host_errors": errors})
@dataclass
class RightEntry:
    """One user's right as received by the ``/distRights/receive`` route."""

    email: Email
    right: Right
@dist_bp.put("/receive")
@csrf.exempt
def receive_right(
    rights: list[RightEntry],
    item_path: str,
    secret: str,
) -> Response:
    """Apply distributed rights to the folder or document at ``item_path``.

    Each right is granted as a view right to the user group named like the user
    with the entry's email. Requires ``DIST_RIGHTS_RECEIVE_SECRET``.

    :raises RouteException: If the item does not exist, or if some emails match no
        user on this host. Nothing is committed in either case.
    """
    check_secret(secret, "DIST_RIGHTS_RECEIVE_SECRET")
    emails = [entry.email for entry in rights]
    group_map = {
        email: ug
        for ug, email in (
            UserGroup.query.join(User, UserGroup.name == User.name)
            .filter(User.email.in_(emails))
            .with_entities(UserGroup, User.email)
            .all()
        )
    }
    item: Item | None = Folder.find_by_path(item_path)
    if not item:
        item = DocEntry.find_by_path(item_path)
    if not item:
        raise RouteException(f"Item not found: {item_path}")
    # Fail with a clear message (instead of a KeyError/500 midway) before granting anything.
    unknown_emails = [e for e in emails if e not in group_map]
    if unknown_emails:
        raise RouteException(f"Users not found: {', '.join(unknown_emails)}")
    for entry in rights:
        right = entry.right
        grant_access(
            group_map[entry.email],
            item,
            AccessType.view,
            # In TIM, a right is considered active whenever accessible_from is set, so if the right still requires
            # confirmation, we must set accessible_from to be null.
            accessible_from=None if right.require_confirm else right.accessible_from,
            accessible_to=right.accessible_to,
            duration=right.duration,
            duration_from=right.duration_from,
            duration_to=right.duration_to,
            require_confirm=right.require_confirm,
        )
    db.session.commit()
    return ok_response()
@dist_bp.get("/changeStartTime")
def change_starttime_route(
    group: str,
    target: str,  # comma-separated; TODO: List[str] doesn't work for GET requests
    minutes: int,
    redir: str,
) -> Response:
    """Move the start time of ``group``'s rights in the given targets to now + ``minutes``.

    Allowed for admins and for members of the group named by
    ``DIST_RIGHTS_START_TIME_GROUP``. Registration errors are flashed, and the
    user is then redirected to ``redir``, which must be a relative URL.
    """
    # Validate the redirect target before registering anything, so a bad ``redir``
    # cannot leave the start time changed behind an error response.
    parsed = urlparse(redir)
    if parsed.scheme or parsed.netloc:
        raise RouteException("redir must be relative")

    targets = target.split(",")
    u = get_current_user_object()
    conf_name = "DIST_RIGHTS_START_TIME_GROUP"
    start_time_group = app.config[conf_name]
    if not start_time_group:
        raise RouteException(f"{conf_name} not configured.")
    # Checking is_admin first avoids loading the member list for admins, and a
    # configured-but-missing group now denies non-admins instead of crashing on None.
    if not u.is_admin:
        ug = UserGroup.get_by_name(start_time_group)
        if ug is None or u not in ug.users:
            raise AccessDenied("You are not in the group that can change the start time.")

    curr_time = get_current_time()
    op = ChangeStartTimeGroupOp(
        type="changestarttimegroup",
        timestamp=curr_time,
        group=group,
        starttime=curr_time + timedelta(minutes=minutes),
    )
    errors = register_right_impl(op, targets)
    if errors:
        flash(str(errors))
    return safe_redirect(request.host_url + redir)
def register_op_to_hosts(
    op: RightOp, target: str | list[str], is_receiving_backup: bool
) -> list[str]:
    """Forward ``op`` to the ``/distRights/register`` route of every other register host.

    Hosts come from ``DIST_RIGHTS_REGISTER_HOSTS``, excluding this host (``TIM_HOST``).

    :return: Error messages collected from the hosts by collect_errors_from_hosts.
    """
    curr_host = app.config["TIM_HOST"]
    register_hosts = [
        h for h in app.config["DIST_RIGHTS_REGISTER_HOSTS"] if h != curr_host
    ]
    session = FuturesSession(max_workers=multiprocessing.cpu_count())
    try:
        futures: list[Future] = []
        for h in register_hosts:
            futures.append(
                session.post(
                    f"{h}/distRights/register",
                    data=to_json_str(
                        {
                            "op": op,
                            "target": target,
                            "secret": app.config["DIST_RIGHTS_REGISTER_SEND_SECRET"],
                            "is_receiving_backup": is_receiving_backup,
                        }
                    ),
                    headers={"Content-type": "application/json"},
                    timeout=10,
                )
            )
        return collect_errors_from_hosts(futures, register_hosts)
    finally:
        # Close the session once the responses have been collected, so its executor
        # and pooled connections are released instead of lingering.
        session.close()
@dist_bp.get("/current")
def get_current_rights_route(
    groups: str,  # comma-separated; TODO: List[str] doesn't work for GET requests
    target: str,
) -> Response:
    """Return the current right of every user in ``groups`` for ``target`` (admins only).

    :return: JSON list of ``{"email": ..., "right": ...}`` objects ordered by email.
    :raises RouteException: If ``target`` is not a valid file name or has no rights files.
    """
    verify_admin()
    # The target becomes part of a file path; sanitize it the same way register_right_impl
    # does, so names like "../x" cannot reach files outside FILES_PATH.
    target_s = secure_filename(target)
    if not target_s:
        raise RouteException(f"invalid target: {target}")
    try:
        rights, _ = get_current_rights(target_s)
    except FileNotFoundError:
        raise RouteException(f"Unknown target: {target}")
    groups_list = groups.split(",")
    emails = (
        User.query.join(UserGroup, User.groups)
        .filter(UserGroup.name.in_(groups_list))
        .with_entities(User.email)
        .order_by(User.email)
        .all()
    )
    return json_response([{"email": e, "right": rights.get_right(e)} for e, in emails])