/*

smbmount.c

*/

/**************************************************************
 *  $Workfile: smbmount.c $
 *  $Revision: 4 $
 *  $Modtime: 7.01.01 12:12 $
 *  $Date: 15.01.01 18:11 $
 **************************************************************
 *  This software is copyright (C) 1997 by Ilja V.Levinson    *
 *                                                            *
 *  Permission is granted to reproduce and distribute         *
 *  this package by any means so long as no fee is charged    *
 *  above a nominal handling fee and so long as this          *
 *  notice is always included in the copies.                  *
 *  Commercial use or incorporation into commercial software  *
 *  is prohibited without the written permission of the       *
 *  author.                                                   *
 *                                                            *
 *  Other rights are reserved except as explicitly granted    *
 *  by written permission of the author.                      *
 *                                                            *
 *     Ilja V.Levinson, Yekaterinburg, Russia                 *
 *     email: lev@oduurl.ru                                   *
 *************************************************************/

#include <types.h>
#include <stdio.h>
#include <memory.h>
#include <sg_codes.h>
#include <modes.h>
#include <process.h>
#include <const.h>
#include <module.h>
#include <string.h>
#include <errno.h>
#include <stdlib.h>
#include <ctype.h>
#include <events.h>

#include <inet/netdb.h>
#include <inet/socket.h>

#include "smb.h"

/* Advance an option pointer past the option letter and an optional '=' that
   follows it, so "-U=guest" and "-Uguest" both leave it at "guest".
   Wrapped in do/while(0) so it behaves as a single statement (safe in
   unbraced if/else); the argument is evaluated several times, so pass a
   plain variable only. */
#define SKIP(a) do { (a)++; if (*(a) == '=') (a)++; } while (0)

/* Settings chosen on the command line and written into the device
   descriptor module by create_descriptor(). */
struct descriptor_settings {
    enum smb_conn_type  type;           /* share type; main() sets STYPE_DISKTREE, STYPE_IPC or STYPE_PRINTQ */
    u_short             cache_size;     /* directory cache entries (-a), stored in pd_dcsiz */
    u_short             maxfiles;       /* max files per directory (-m), stored in pd_maxfiles */
};

/* Functions defined later in this file. */
void usage(void);                       /* print help and exit(0) */
char *build_rawname(char *);            /* "/name@" in a static buffer */
char *strupr(char *);                   /* upper-case a string in place */
error_code create_descriptor(char *modname, struct descriptor_settings *);
error_code parse_resource(char *service, char **server, char **share, char **user, int *);
/* SS_Mount SetStat/GetStat wrappers, implemented in the _asm block at the end of the file. */
error_code _os_ss_mount(path_id, smb_mount_data *);
error_code _os_gs_mount(path_id, smb_mount_data *, int);
error_code do_unmount(char *modname);
error_code do_info(char *modname);
error_code get_server_addr(char *, char *, smb_mount_data *);
/* NOTE(review): presumably declared here because the included headers do
   not prototype them - confirm. */
char *getpass(char *);
unsigned int atoh(char *);              /* hex string -> unsigned (defined below) */
char *inet_ntoa(struct in_addr in);

/* argv[0]; prefix for all diagnostic messages. */
char *progname;

/*
 * smbmount entry point.
 *
 * Usage: smbmount mount-point [resource] [options]   (see usage())
 *
 * Collects the mount parameters from the command line into an
 * smb_mount_data block and creates a device descriptor module named after
 * the mount point (without a leading '/'). It then opens the raw device
 * "/<name>@" and passes the block to the driver with the SS_Mount SetStat.
 * If the open or the mount fails, the freshly created descriptor is
 * unloaded again.
 *
 * The -u and -i options short-circuit this: they unmount or show
 * information for the mount point and return that operation's status.
 *
 * Returns 0 on success, 1 for errors diagnosed here, or the error code of
 * the failing call.
 */
int main(int argc, char **argv) {
    path_id         fd;
    error_code      result;
    smb_mount_data  m;
    char            *modname;
    char            *mount_point = 0;
    char            *rawname;
    char            *server = 0, *share = 0, *user = 0;
    char            *realhost = 0;
    int             upcase_password = 1;
    int             got_password = 0;
    char            hostname[64];
    int             service_ok = 0;
    struct descriptor_settings  desc;

    progname = argv[0];

    /* Invalid arguments?  usage() exits and does not return. */
    if (argc < 3)
        usage();

    /* Get our hostname.  Zero the buffer first so it is always
       NUL-terminated, even if gethostname() fails or truncates. */
    memset(hostname, 0, sizeof(hostname));
    (void) gethostname(hostname, sizeof(hostname)-1);

    /* Initialize required structs */
    memset(&m, 0, sizeof(m));

    /* Set default domain name */
    strcpy(m.domain, "WORKGROUP");

    /* Set default descriptor settings */
    desc.type = STYPE_DISKTREE;             /* disk share by default */
    desc.cache_size = SMB_DIRCACHE_SIZE;    /* dir cache size */
    desc.maxfiles = SMB_MAX_FILES;          /* max files per a directory */

    /* Set default mount values */
    m.file_mode = S_IREAD|S_IWRITE|S_IEXEC;
    m.dir_mode  = S_IREAD|S_IWRITE|S_IEXEC|
                  S_IOREAD|S_IOWRITE|S_IOEXEC|S_IFDIR;
    m.eol = 'o';    /* OS-9 by default */

    m.max_xmit = 8192;
    m.max_mids = 2;

    mount_point = argv[1];
    modname = mount_point[0] == '/' ? &mount_point[1] : mount_point;

    /* Skip mount_point argument */
    argv += 1; argc -= 1;

    /* Parse options and/or resource location */
    while (--argc) {
        char *p;
        if (argv[1][0] == '-') {
            p = &argv[1][1];
            switch( *p ) {
                case 'c' :
                    SKIP(p);
                    if (strlen(p) > sizeof(m.client_name) - 1) {
                        fprintf(stderr, "%s: client name too long: %s\n", progname, p);
                        return 1;
                    }
                    strcpy(m.client_name, p);
                    strupr(m.client_name);
                    break;
                case 'U' :
                    SKIP(p);
                    if (strlen(p) > sizeof(m.user_name) - 1) {
                        fprintf(stderr, "%s: user name too long: %s\n", progname, p);
                        return 1;
                    }
                    strcpy(m.user_name, p);
                    break;
                case 'W' :
                    SKIP(p);
                    if (strlen(p) > sizeof(m.domain) - 1) {
                        fprintf(stderr, "%s: domain name too long: %s\n", progname, p);
                        return 1;
                    }
                    strcpy(m.domain, p);
                    strupr(m.domain);
                    break;
                case 'P' :
                    SKIP(p);
                    if (strlen(p) > sizeof(m.password) - 1) {
                        fprintf(stderr, "%s: password too long: %s\n", progname, p);
                        return 1;
                    }
                    strcpy(m.password, p);
                    memset(p, 'X', strlen(p));  /* wipe password from argv */
                    /* fallthrough: a password was given, don't prompt */
                case 'n' :
                    got_password = 1;
                    break;
                case 'I' :
                    SKIP(p);
                    realhost = p;
                    break;
                case '$' :
                    desc.type = STYPE_IPC;
                    break;
                case 'p' :
                    desc.type = STYPE_PRINTQ;
                    break;
                case 'C' :
                    upcase_password = 0;
                    break;
                case 'u' :
                    return do_unmount(modname);
                case 'i' :
                    return do_info(modname);
                case 'f' :
                    SKIP(p);
                    m.file_mode = atoh(p);
                    break;
                case 'd' :
                    SKIP(p);
                    m.dir_mode = atoh(p) | S_IFDIR;
                    break;
                case 'e' :
                    SKIP(p);
                    if(!(*p == 'd' || *p == 'o' || *p == 'u')) {
                        fprintf(stderr, "%s: invalid EOL ('%c') id (o,d,u only)\n",
                                progname, *p);
                        return 1;
                    }
                    m.eol = *p;
                    break;
                case 'a' :
                    SKIP(p);
                    desc.cache_size = (u_short)atoi(p);
                    break;
                case 'm' :
                    SKIP(p);
                    desc.maxfiles = (u_short)atoi(p);
                    break;
                case 'x' :
                    SKIP(p);
                    m.max_mids = (u_short)atoi(p);
                    break;
                case 't' :
                    m.options |= SMB_OPT_SETIME;
                    break;
                case 'O' :
                    SKIP(p);
                    m.options |= (u_short)atoh(p);
                    break;
                case '?' :
                    usage();    /* exits; never falls into default */
                default  :
                    fprintf(stderr, "%s: invalid option '%c'.\n", progname, *p);
                    return 1;
            }
        }
        else {
            /* Parse resource location */
            if(result = parse_resource(argv[1],
                    &server, &share, &user, &service_ok)) {
                fprintf(stderr, "%s: bad resource location\n", progname);
                return result;
            }
        }
        argv++;
    }

    if(!service_ok) {
        fprintf(stderr, "%s: unknown resource location\n", progname);
        return 1;
    }

    /* parse_resource() bounds its own buffers; make sure the names also fit
       the mount data fields before copying. */
    if (strlen(server) > sizeof(m.server_name) - 1) {
        fprintf(stderr, "%s: server name too long: %s\n", progname, server);
        return 1;
    }
    strcpy(m.server_name, server);
    strupr(m.server_name);

    if (strlen(share) > sizeof(m.service) - 1) {
        fprintf(stderr, "%s: share name too long: %s\n", progname, share);
        return 1;
    }
    strcpy(m.service, share);
    strupr(m.service);

    /* User name precedence: -U option, then %user in the resource, then $USER. */
    if(m.user_name[0] == '\0') {
        if(user[0] == '\0') {
            char *env_user = getenv("USER");
            /* m was zeroed above, so the last byte stays NUL. */
            if(env_user)
                strncpy(m.user_name, env_user, sizeof(m.user_name)-1);
        }
        else {
            if (strlen(user) > sizeof(m.user_name) - 1) {
                fprintf(stderr, "%s: user name too long: %s\n", progname, user);
                return 1;
            }
            strcpy(m.user_name, user);
        }
    }
    strupr(m.user_name);

    if (m.client_name[0] == '\0') {
        if (hostname[0] == '\0') {
            fprintf(stderr, "%s: cannot determine my hostname\n", progname);
            fprintf(stderr, "Use option -c=client_name\n");
            return 1;
        }
        if (strlen(hostname) > sizeof(m.client_name) - 1) {
            fprintf(stderr, "%s: my hostname name too long as a "
                "netbios name: %s\n", progname, hostname);
            fprintf(stderr, "Use option -c=client_name\n");
            return 1;
        }
        strcpy(m.client_name, hostname);
        strupr(m.client_name);
    }

    /* Get server Inet address (get_server_addr prints its own message) */
    if(result = get_server_addr(realhost, server, &m))
        return 1;

    /* Handle password */
    if(got_password == 0) {
        char *pw = getpass("Password: ");
        if (pw == NULL) {
            fprintf(stderr, "%s: cannot read password\n", progname);
            return 1;
        }
        if (strlen(pw) > sizeof(m.password) - 1) {
            fprintf(stderr, "%s: password too long\n", progname);
            return 1;
        }
        strcpy(m.password, pw);
        memset(pw, 'X', strlen(pw)); /* wipe password */
    }

    if (upcase_password == 1)
        strupr(m.password);

    /* Some variables and constants */
    m.version = SMB_VERSION;
    m.cntype = desc.type;

    if(result = create_descriptor(modname, &desc))
        return result;

    /* Prepare to mount; on any failure unload the descriptor just created */
    rawname = build_rawname(modname);
    if(result = _os_open(rawname, S_IREAD|S_IWRITE, &fd))  {
        (void) _os_unload(modname, 0);
        return result;
    }

    result = _os_ss_mount(fd, &m);
    (void)_os_close(fd);
    memset(m.password, 'X', sizeof(m.password));    /* wipe password */

    if(result)
        (void) _os_unload(modname, 0);

    return result;
}

/* Create our device descriptor on the fly */
/*
 * Create the device descriptor module <modname> on the fly.
 *
 * A shared event named "_smbmount" (created on first use) is bumped by 1 on
 * every call. The value returned through evvalue is stored in the
 * descriptor's port field so each descriptor gets a distinct port.
 * NOTE(review): assumes _os_ev_setr returns the event value in *evvalue -
 * confirm.
 *
 * Module layout: mod_dev header, then optsize bytes of options (starting at
 * _mdtype), then the file manager name and the 4-byte-aligned driver name.
 * NOTE(review): the 64 spare bytes are presumably enough for both names -
 * confirm against SMB_FMNAME/SMB_PRINTER/SMB_DRVNAME.
 *
 * ds  selects printer (SCF) or disk/IPC (RBF) options, cache size and the
 *     max-files limit.
 *
 * Returns SUCCESS, or the error code of the failing call (E_PARAM for an
 * unsupported ds->type). Once the module has been created, any later
 * failure unlinks it again.
 */
error_code create_descriptor(char *modname, struct descriptor_settings *ds) {
    error_code      result;
    u_int16         attr_rev, type_lang;
    mod_dev         *module;
    mh_com          *modhead;
    char            *strp;      /* next free byte for the name strings */
    event_id        evid;
    int32           evvalue;
    u_int32         optsize;
    char            *fmname;

    /* Get a unique value - to store to desc->port */
    if(result = _os_ev_link("_smbmount", &evid)) {
        if(result = _os_ev_creat(0, 0, MP_OWNER_READ|MP_OWNER_WRITE,
                                 &evid, "_smbmount", 0x0, MEM_ANY))
            return result;
    }

    evvalue = 1;    /* Increment by 1 */
    result = _os_ev_setr(evid, &evvalue, 0);
    (void) _os_ev_unlink(evid);     /* drop our link on success and failure */
    if(result)
        return result;

    attr_rev = mkattrevs(MA_REENT, 1);
    type_lang = mktypelang(MT_DEVDESC, ML_OBJECT);

    /* Calculate opts field size */
    if(ds->type == STYPE_PRINTQ) {
        fmname = SMB_PRINTER;
        optsize = sizeof(struct smb_print_opt);
    }
    else {
        fmname = SMB_FMNAME;
        optsize = sizeof(struct smb_opt);
    }

    if(result = _os_mkmodule(modname,
            sizeof(mod_dev) + optsize + 64,
            &attr_rev, &type_lang,
            MP_OWNER_READ|MP_OWNER_EXEC|MP_OWNER_WRITE|
            MP_GROUP_READ|MP_GROUP_EXEC|
            MP_WORLD_READ|MP_WORLD_EXEC,
            (void *)&module,  &modhead, MEM_ANY))
        return result;

    module = (mod_dev *)modhead;
    strp = (char *)modhead + sizeof(mod_dev) + optsize;

    /* Set common descriptor fields; name fields hold offsets from the module start */
    module->_mport = (char *)evvalue;  /* Must be unique */
    module->_mmode = FAM_WRITE|FAM_APPEND|FAM_READ|FAM_DIR|FAM_SIZE|FAM_EXEC;
    module->_mfmgr = strp - (char *)modhead;
    strcpy(strp, fmname);
    strp += strlen(fmname) + 1;
    strp = (char *)(((int)strp + 3) & ~3);     /* align driver name to 4 bytes */
    module->_mpdev = strp - (char *)modhead;
    strcpy(strp, SMB_DRVNAME);

    module->_mopt =    optsize;

    switch(ds->type) {
        case STYPE_PRINTQ:
            {
                struct smb_print_opt *opt = (struct smb_print_opt *)&module->_mdtype;

                opt->pd_dtp = DT_SCF;
            }
            break;

        case STYPE_DISKTREE:
        case STYPE_IPC:
            {
                struct smb_opt *opt = (struct smb_opt *)&module->_mdtype;

                opt->pd_dtp = DT_RBF;
                opt->pd_typ = 0x80;                /* like hard disk */
                opt->pd_trys = 1;                  /* unused */
                opt->pd_cntl = 1;                  /* format inhibit */
                opt->pd_maxfiles = ds->maxfiles;
                opt->pd_dcsiz = ds->cache_size;

                /* Sanity check: max files must be below the cache size */
                if(opt->pd_maxfiles >= opt->pd_dcsiz)
                    /* about 40% (13/32) of the cache size */
                    opt->pd_maxfiles = ((u_int)opt->pd_dcsiz*13) >> 5;
            }
            break;

        default:
            /* Unsupported type: don't leave a half-built module behind */
            (void) _os_unlink(modhead);
            return E_PARAM;
    }

    /* Update CRC of the descriptor module */
    if(result = _os_setcrc(modhead)) {
        (void) _os_unlink(modhead);
        return result;
    }

    return SUCCESS;
}

/* Create device name for raw access (with '@' at end) */
/*
 * Create the device name for raw access: a leading '/', the module name and
 * a trailing '@' (e.g. "smb0" -> "/smb0@").
 *
 * Over-long module names are truncated so the result always fits the
 * buffer and always ends in '@'.
 *
 * Returns a pointer to a static buffer that is overwritten by each call.
 */
char *build_rawname(char *modname) {
    static char rawname[16];
    size_t len = strlen(modname);
    size_t avail = sizeof(rawname) - 3;     /* room for '/', '@' and NUL */

    if (len > avail)
        len = avail;

    rawname[0] = '/';
    memcpy(&rawname[1], modname, len);
    rawname[len + 1] = '@';
    rawname[len + 2] = '\0';
    return rawname;
}

/* Print help */
/*
 * Print the command-line help text to stdout and terminate the program
 * with exit status 0. Never returns.
 */
void usage(void) {
    printf("\n");
    printf("Usage: %s mount-point [resource] [options]\n", progname);
    printf("\n"
           "\tResource location must have a form //server/share[%%user]\n"
           "\n"
           "\t-c=clientname  Netbios name of client\n"
           "\t-I=hostname    The hostname of server\n"
           "\t-C             Don't convert password to uppercase\n"
           "\t-u             Do unmount\n"
           "\t-$             Connect to IPC$ share\n"
           "\t-p             Connect to printer share\n"
           "\t-t             Also set time from server\n"
           "\t-i             View mount information\n"
           "\t-P=password    Use this password\n"
           "\t-U=username    Use this user name\n"
           "\t-W=domain      Domain name to join to\n"
           "\t-n             Do not use any password\n"
           "\t               If neither -P nor -n are given, you are\n"
           "\t               asked for a password.\n"
           "\t-f=filemode    Use mode for files       (def. is -----ewr)\n"
           "\t-d=dirmode     Use mode for directories (def. is d-ewrewr)\n"
           "\t               modes are hex values without leading 0x.\n"
           "\t-a=cache-size  Set number of directory cache entries  (320)\n"
           "\t-m=max-files   Set max number of files in a directory (128)\n"
           "\t-x=max-mids    Set max number of multiplexed requests (2)\n"
           "\t-e=eol         Set end-of-line character\n"
           "\t               d - DOS (CRLF), u - UNIX (LF), o - OSK (CR).\n"
           "\t-O=options     Set compatibility flags\n"
           "\t-?             Print this help text\n"
           "\n");
    exit(0);
}

/* Parse resource string, get server, share and user name */
/*
 * Parse a resource location of the form //server/share[%user].
 *
 * The location must start with '/'; any number of leading slashes is
 * skipped. The server and share names must be non-empty and, like the
 * user name, fit in SMB_MAX_NAMELEN-1 characters.
 *
 * On success *server, *share and *user point to static buffers that are
 * overwritten by the next call (*user is "" when no %user part is given),
 * *ok is set to 1 and SUCCESS is returned. On a malformed or over-long
 * location, E_BNAM is returned and none of the outputs are written.
 */
error_code parse_resource(char *service,
        char **server, char **share, char **user, int *ok) {

    /* Room for three maximal names plus "//", "/", "%" and the NUL. */
    char   service_copy[SMB_MAX_NAMELEN*3 + 4];
    char   *server_start;
    char   *share_start;
    char   *user_start;

    static char s_server[SMB_MAX_NAMELEN];
    static char s_share[SMB_MAX_NAMELEN];
    static char s_user[SMB_MAX_NAMELEN];

    /* Reject over-long input instead of silently truncating it; this also
       guarantees service_copy is NUL-terminated. */
    if (strlen(service) > sizeof(service_copy) - 1)
        return E_BNAM;
    strcpy(service_copy, service);
    server_start = service_copy;

    if (strlen(server_start) < 4) {
        return E_BNAM;
    }

    if (server_start[0] != '/') {
        return E_BNAM;
    }

    while (server_start[0] == '/')
        server_start += 1;

    /* server_start is now non-empty up to the next '/', if there is one */
    share_start = strchr(server_start, '/');

    if (share_start == NULL) {
        return E_BNAM;
    }

    share_start[0] = '\0';
    share_start += 1;

    if(strlen(server_start) > sizeof(s_server)-1)
        return E_BNAM;

    /* Split off the optional %user suffix */
    user_start = strchr(share_start, '%');

    if (user_start != NULL) {
        user_start[0] = '\0';
        user_start += 1;

        if (strlen(user_start) > sizeof(s_user) - 1) {
            fprintf(stderr, "%s: user too long: %s\n", progname, user_start);
            return E_BNAM;
        }
    }

    if (share_start[0] == '\0')
        return E_BNAM;

    if(strlen(share_start) > sizeof(s_share) - 1 )
        return E_BNAM;

    /* All lengths have been checked; only now touch the outputs */
    strcpy(s_server, server_start);
    strcpy(s_share, share_start);
    if (user_start != NULL)
        strcpy(s_user, user_start);
    else
        s_user[0] = '\0';

    *server = s_server;
    *share  = s_share;
    *user   = s_user;

    *ok = 1;
    return SUCCESS;
}

/* Unmount */
/*
 * Unmount <modname>: open its raw device, send SS_Mount with a NULL mount
 * block, close the path and unload the descriptor module.
 *
 * If the raw device cannot be opened, nothing is unloaded and the open
 * error is returned. Otherwise the module is unloaded whatever the
 * SetStat reports, and the SetStat's status is returned.
 */
error_code do_unmount(char *modname) {
    char        *devname = build_rawname(modname);
    path_id     path;
    error_code  err;

    err = _os_open(devname, S_IREAD|S_IWRITE, &path);
    if (err)
        return err;

    err = _os_ss_mount(path, 0);    /* NULL block requests the unmount */
    (void) _os_close(path);
    (void) _os_unload(modname, 0);

    return err;
}

/* Get info about mounted volume */
/*
 * Print one line describing the volume mounted at <modname>: module name,
 * connection type, user, server IP address and //server/share.
 *
 * The mount data is read back from the driver with the SS_Mount GetStat on
 * the raw device. Returns SUCCESS, or the error code of the failing
 * open/GetStat.
 */
error_code do_info(char *modname) {
    static const char *const typestr[] = {
        "Disk", "Printer", "Device", "IPC"
    };
    error_code      result;
    path_id         fd;
    char            *rawname = build_rawname(modname);
    smb_mount_data  m;
    const char      *type;

    if(result = _os_open(rawname, S_IREAD|S_IWRITE, &fd))
        return result;

    result = _os_gs_mount(fd, &m, sizeof(m));
    (void)_os_close(fd);

    if(result)
        return result;

    /* cntype comes from the driver; never index past the table */
    if ((unsigned)m.cntype < sizeof(typestr) / sizeof(typestr[0]))
        type = typestr[m.cntype];
    else
        type = "?";

    printf("%-10s %-8s %-12s %-16s //%s/%s\n",
        modname, type, m.user_name,
        inet_ntoa(m.addr.sin_addr),
        m.server_name, m.service);

    return SUCCESS;
}

/* Get IP address of the server, both from domain name/numeric notation */
/*
 * Fill pm->addr with the server's IPv4 address, AF_INET and SMB_PORT.
 *
 * realhost  value of the -I option, or 0. When 0, the NetBIOS server name
 *           is resolved with gethostbyname().
 *           Otherwise realhost is used instead. If it starts with a digit it
 *           is first tried as a dotted-quad address. If that does not
 *           parse (e.g. a host name like "3com-srv"), it is resolved as a
 *           host name.
 * server    NetBIOS server name from the resource location.
 *
 * Returns SUCCESS, or E_BNAM after printing a message when the name cannot
 * be resolved.
 */
error_code get_server_addr(char *realhost, char *server, smb_mount_data *pm) {
    struct hostent *h;

    if (realhost == 0) {
        if ((h = gethostbyname(server)) == NULL) {
            fprintf(stderr, "%s: unknown host %s\n", progname, server);
            fprintf(stderr, "\tThe -I option may be useful.\n");
            return E_BNAM;
        }
        pm->addr.sin_addr.s_addr = ((struct in_addr *)(h->h_addr))->s_addr;
    }
    else {
        /* all-ones is inet_addr()'s failure value (INADDR_NONE) */
        u_int32 numeric = (u_int32)0xffffffff;

        /* <ctype.h> functions need an unsigned char value */
        if (isdigit((unsigned char)*realhost))
            numeric = (u_int32)inet_addr(realhost);

        if (numeric != (u_int32)0xffffffff) {
            pm->addr.sin_addr.s_addr = numeric;
        }
        else {
            if ((h = gethostbyname(realhost)) == NULL) {
                fprintf(stderr, "%s: unknown host %s\n", progname, realhost);
                return E_BNAM;
            }
            pm->addr.sin_addr.s_addr = ((struct in_addr *)(h->h_addr))->s_addr;
        }
    }

    pm->addr.sin_family      = AF_INET;
    pm->addr.sin_port        = SMB_PORT;   /* NOTE(review): no htons(); fine on big-endian 68k */

    return SUCCESS;
}

/* Support section */

#include <ctype.h>

/*
 * Convert the NUL-terminated string <str> to upper case in place.
 * Returns <str>.
 */
char *strupr( char *str )
{
    char *t = str;

    for (; *str; str++) {
        /* <ctype.h> functions are only defined for unsigned char values */
        unsigned char c = (unsigned char)*str;

        if (islower(c))
            *str = (char)toupper(c);
    }
    return t;
}

/*
 * Parse a hexadecimal number (no "0x" prefix) from the string <ap>.
 * Leading blanks and tabs are skipped; parsing stops at the first
 * character that is not a hex digit. Returns 0 if no digits are found.
 * Overflow wraps modulo UINT_MAX+1.
 */
unsigned int atoh(char *ap) {
    const char *cp = ap;
    unsigned int value = 0;
    int nibble;

    for (; *cp == ' ' || *cp == '\t'; cp++)
        ;

    for (;; cp++) {
        char c = *cp;

        if (c >= '0' && c <= '9')
            nibble = c - '0';
        else if (c >= 'a' && c <= 'f')
            nibble = 10 + (c - 'a');
        else if (c >= 'A' && c <= 'F')
            nibble = 10 + (c - 'A');
        else
            break;

        value = value * 16u + (unsigned int)nibble;
    }
    return value;
}


/*
 * Hand-written 68k stubs for the SS_Mount status calls declared at the top
 * of the file.
 *
 * _os_ss_mount(path, data): issues I$SetStt with function code SS_Mount.
 *   The second argument (arriving in d1) is moved to a0; d0 is passed to
 *   the system call unchanged.
 * _os_gs_mount(path, data, size): issues I$GetStt with SS_Mount, the second
 *   argument in a0 and the third argument in d2. After pushing three
 *   registers (12 bytes) plus the return address, the third argument is
 *   read from 16(a7).
 *
 * Both return 0 in d0 when the system call returns with carry clear.
 * Otherwise they return the error code left in d1. The scratch registers
 * they use (a0/d1, plus d2) are saved and restored.
 *
 * NOTE(review): relies on the Ultra C calling convention passing the first
 * two arguments in d0/d1 and the rest on the stack, and on I$SetStt/
 * I$GetStt taking the path in d0 - confirm.
 */
_asm("
* Issue mount command
_os_ss_mount:
    movem.l a0/d1,-(sp)
    move.l d1,a0
    move.l #SS_Mount,d1
    os9 I$SetStt
    bcc.b ok
    move.l d1,d0
    bra.b rt
ok  move.l #0,d0
rt  movem.l (sp)+,a0/d1
    rts

* Get mount info
_os_gs_mount:
    movem.l a0/d1/d2,-(sp)
    move.l 16(a7),d2
    move.l d1,a0
    move.l #SS_Mount,d1
    os9 I$GetStt
    bcc.b ok1
    move.l d1,d0
    bra.b rt1
ok1 move.l #0,d0
rt1 movem.l (sp)+,a0/d1/d2
    rts
");