#!/usr/bin/env python

#
# Copyright 2012 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

#
# Derived from oauth2.py (https://github.com/google/gmail-oauth2-tools).
# Heavily modified and rewritten by Stefan Krah.
#
# Works for Microsoft Office 365 in addition to Gmail.


import os
import sys
import json
import argparse
import time
import urllib.request as urllibrequest
import urllib.parse as urllibparse
import urllib.error as urlliberror
from http.server import HTTPServer, BaseHTTPRequestHandler
urlparse = urllibparse
def codebytes(b):
    """Return the text *b* encoded as UTF-8 bytes."""
    encoded = b.encode("utf8")
    return encoded

# When true, OAuth2.code_url() and OAuth2.init_tokens() always use
# "http://localhost:<port>/" as the redirect URI, overriding any
# 'redirect_uri' stored in the token data file.  The command-line entry
# point resets this to False when no --port option is given.
force_local_redirect = True

def get_code(url):
    """Extract the authorization code from the query string of *url*.

    The first value of the ``code`` parameter is returned when present;
    otherwise the first value of ``approvalCode`` is used.  A KeyError is
    raised if the query string contains neither parameter.
    """
    fields = urlparse.parse_qs(urlparse.urlparse(url).query)
    # Prefer "code"; "approvalCode" is only consulted as a fallback.
    name = "code" if "code" in fields else "approvalCode"
    return fields[name][0]

# Local webserver nonsense required to receive code from redirect
class OAuthRedirectHandler(BaseHTTPRequestHandler):
    """HTTP handler that captures the OAuth2 authorization code.

    The code is taken from the query string of the requested path and
    stored on the server object as ``oauth_code``.
    """

    def _send_html(self, status, body):
        # Send a complete HTML response with the given status and body bytes.
        self.send_response(status)
        self.send_header("Content-type", "text/html")
        self.end_headers()
        self.wfile.write(body)

    def do_GET(self):
        """Store the code from the redirect and confirm it to the browser.

        A request without a code (for example an error redirect) is
        answered with HTTP 400 and leaves ``server.oauth_code`` untouched,
        rather than letting the KeyError escape from the handler.
        """
        try:
            code = get_code(self.path)
        except KeyError:
            self._send_html(400, b"<html><head></head><body><h2>No authorization code found in the request.</h2></body></html>")
            return
        self.server.oauth_code = code
        self._send_html(200, b"<html><head></head><body><h2>Your json file is updated.</h2></body></html>")

class OAuthRedirectServer(HTTPServer):
    """HTTP server on localhost that receives the OAuth2 redirect.

    ``oauth_code`` starts out as None; OAuthRedirectHandler.do_GET stores
    the received authorization code there.
    """

    def __init__(self, port):
        self.oauth_code = None
        address = ("localhost", int(port))
        super().__init__(address, OAuthRedirectHandler)
        # No timeout: handle_request() waits until a request arrives.
        self.timeout = None

class OAuth2(object):
    """OAuth2 token manager backed by a JSON token data file.

    The file is loaded into ``self.data`` on construction and rewritten
    whenever new tokens are obtained.  Keys read from it by this class are
    ``auth_uri``, ``token_uri``, ``client_id``, ``client_secret``,
    ``scope``, ``redirect_uri``, ``user``, ``prompt``, ``access_token``,
    ``refresh_token`` and ``expires_at``.
    """

    def __init__(self, token_data_path):
        """Load the token data from the JSON file at *token_data_path*."""
        self.token_data_path = token_data_path

        with open(self.token_data_path) as f:
            self.data = json.load(f)

    def copy(self, *keys):
        """Return a new dict with those of *keys* that exist in the data."""
        return {k: self.data[k] for k in keys if k in self.data}

    def query(self, params):
        """Build a query string from *params*, sorted by parameter name.

        Values are percent-escaped (only ``~-._`` are left unescaped in
        addition to alphanumerics); names are emitted as they are.
        """
        return '&'.join(
            '%s=%s' % (name, urllibparse.quote(value, safe='~-._'))
            for name, value in sorted(params.items(), key=lambda x: x[0])
        )

    def _set_redirect_uri(self, params, port):
        # Use the local redirect URI when none is configured or when the
        # module-level force_local_redirect flag is set.
        if 'redirect_uri' not in params or force_local_redirect:
            params['redirect_uri'] = 'http://localhost:' + str(port) + '/'

    def code_url(self, port):
        """Return the URL the user must visit to obtain an authorization code.

        *port* is the local port used for the ``http://localhost:<port>/``
        redirect URI.  The ``prompt`` parameter defaults to ``consent``
        unless the token data specifies one.
        """
        params = self.copy('scope', 'client_id', 'redirect_uri')
        self._set_redirect_uri(params, port)
        params['response_type'] = 'code'
        params['access_type'] = 'offline'
        params['prompt'] = self.data.get('prompt', 'consent')
        return '%s?%s' % (self.data['auth_uri'], self.query(params))

    def get_response(self, url, params):
        """POST the form-encoded *params* to *url* and return the parsed JSON.

        On an HTTP error the response body is printed and the process
        exits with status 1.
        """
        encoded = urllibparse.urlencode(params).encode('ascii')
        try:
            # The context manager closes the connection once the body is read.
            with urllibrequest.urlopen(url, encoded) as response:
                body = response.read()
        except urlliberror.HTTPError as httpError:
            error = httpError.read().decode()
            print("error: " + error)
            sys.exit(1)
        return json.loads(body)

    def update_config(self, d):
        """Store the tokens from token response *d* and rewrite the file.

        *d* must contain ``access_token`` and ``expires_in``; a
        ``refresh_token`` is stored only when the response carries one.
        The recorded expiry is 100 seconds earlier than the one reported.
        """
        self.data['access_token'] = d['access_token']
        # float() also accepts a numeric string, which some token endpoints
        # are reported to return for expires_in.
        self.data['expires_at'] = time.time() + float(d['expires_in']) - 100

        refresh_token = d.get('refresh_token')
        if refresh_token is not None:
            self.data['refresh_token'] = refresh_token

        # Serialize before opening so a serialization failure cannot leave
        # the token file truncated.
        serialized = json.dumps(self.data)
        with open(self.token_data_path, "w") as f:
            f.write(serialized)

    def init_tokens(self, code, port):
        """Exchange authorization *code* for tokens and store them.

        *port* must match the port used when building the code URL, since
        it determines the local redirect URI sent with the request.
        """
        params = self.copy('user', 'client_id', 'client_secret', 'redirect_uri')
        self._set_redirect_uri(params, port)
        params['code'] = code
        params['grant_type'] = 'authorization_code'

        d = self.get_response(self.data['token_uri'], params)
        self.update_config(d)

    def refresh_tokens(self):
        """Obtain a new access token using the stored refresh token."""
        params = self.copy('client_id', 'client_secret', 'refresh_token')
        params['grant_type'] = 'refresh_token'

        d = self.get_response(self.data['token_uri'], params)
        self.update_config(d)

    def token(self):
        """Return a valid access token, refreshing it first if necessary.

        A missing ``expires_at`` entry is treated as an expired token
        instead of failing on a comparison with None.
        """
        expires_at = self.data.get('expires_at')
        if expires_at is None or time.time() >= expires_at:
            self.refresh_tokens()

        return self.data['access_token']


if __name__ == '__main__':
    # Command-line entry point.  With --init, run the interactive
    # authorization flow and store the resulting tokens in the token data
    # file; in every case print a valid access token as the last output.
    parser = argparse.ArgumentParser()
    parser.add_argument("-i", "--init", action="store_true", default=False,
                        help="initialize access and refresh tokens")
    parser.add_argument('tokenfile', metavar='<token data file path>',
                        help="location of the token data file")
    parser.add_argument("-p", "--port", default=None,
                        help="local port to use for receiving oauth2 redirect (default 8083)")

    args = parser.parse_args()

    if not args.port:
        # Without an explicit port, a redirect_uri configured in the token
        # data file takes precedence over the local redirect.
        force_local_redirect = False
        args.port = 8083

    auth = OAuth2(args.tokenfile)

    if args.init:
        # True when the code must be pasted by hand, either because the
        # user pressed Ctrl+C or because the redirect delivered no code.
        interrupted = False
        print("Visit this url to obtain a verification code:")
        print("    %s\n" % auth.code_url(args.port))
        if not auth.copy('redirect_uri') or force_local_redirect:
            print("Press Ctrl+C if you are installing getmail on a device without a browser.")
            oauthd = OAuthRedirectServer(args.port)
            try:
                # Wait for the single redirect request from the browser.
                oauthd.handle_request()
            except KeyboardInterrupt:
                # Only Ctrl+C switches to manual entry; other errors are no
                # longer silently swallowed.
                interrupted = True
            finally:
                oauthd.server_close()
            if not interrupted and oauthd.oauth_code is not None:
                # A failed token exchange now ends the script (get_response
                # prints the error and exits with status 1).
                auth.init_tokens(oauthd.oauth_code, args.port)
            else:
                # No code was captured: fall back to pasting the URL.
                interrupted = True
        if (auth.copy('redirect_uri') and not force_local_redirect) or interrupted:
            print("Please paste the response from the address bar of the browser you used for the url above:")
            urlinput = input()
            code = get_code(urlinput)
            auth.init_tokens(code, args.port)
        print("\naccess token\n")

    print("%s" % auth.token())

    sys.exit(0)
