Merge branch 'sb/atomic-push'
[gitweb.git] / contrib / credential / wincred / git-credential-wincred.c
index 66194929fca77a3b0ab41fecf8abc945d9bf2c66..006134043a472edd54720d6d4237e9484476d3a3 100644 (file)
@@ -9,6 +9,8 @@
 
 /* common helpers */
 
+#define ARRAY_SIZE(x) (sizeof(x)/sizeof(x[0]))
+
 static void die(const char *err, ...)
 {
        char msg[4096];
@@ -30,14 +32,6 @@ static void *xmalloc(size_t size)
        return ret;
 }
 
-static char *xstrdup(const char *str)
-{
-       char *ret = strdup(str);
-       if (!ret)
-               die("Out of memory");
-       return ret;
-}
-
 /* MinGW doesn't have wincred.h, so we need to define stuff */
 
 typedef struct _CREDENTIAL_ATTRIBUTEW {
@@ -67,20 +61,14 @@ typedef struct _CREDENTIALW {
 #define CRED_MAX_ATTRIBUTES 64
 
 typedef BOOL (WINAPI *CredWriteWT)(PCREDENTIALW, DWORD);
-typedef BOOL (WINAPI *CredUnPackAuthenticationBufferWT)(DWORD, PVOID, DWORD,
-    LPWSTR, DWORD *, LPWSTR, DWORD *, LPWSTR, DWORD *);
 typedef BOOL (WINAPI *CredEnumerateWT)(LPCWSTR, DWORD, DWORD *,
     PCREDENTIALW **);
-typedef BOOL (WINAPI *CredPackAuthenticationBufferWT)(DWORD, LPWSTR, LPWSTR,
-    PBYTE, DWORD *);
 typedef VOID (WINAPI *CredFreeT)(PVOID);
 typedef BOOL (WINAPI *CredDeleteWT)(LPCWSTR, DWORD, DWORD);
 
-static HMODULE advapi, credui;
+static HMODULE advapi;
 static CredWriteWT CredWriteW;
-static CredUnPackAuthenticationBufferWT CredUnPackAuthenticationBufferW;
 static CredEnumerateWT CredEnumerateW;
-static CredPackAuthenticationBufferWT CredPackAuthenticationBufferW;
 static CredFreeT CredFree;
 static CredDeleteWT CredDeleteW;
 
@@ -88,74 +76,103 @@ static void load_cred_funcs(void)
 {
        /* load DLLs */
        advapi = LoadLibrary("advapi32.dll");
-       credui = LoadLibrary("credui.dll");
-       if (!advapi || !credui)
-               die("failed to load DLLs");
+       if (!advapi)
+               die("failed to load advapi32.dll");
 
        /* get function pointers */
        CredWriteW = (CredWriteWT)GetProcAddress(advapi, "CredWriteW");
-       CredUnPackAuthenticationBufferW = (CredUnPackAuthenticationBufferWT)
-           GetProcAddress(credui, "CredUnPackAuthenticationBufferW");
        CredEnumerateW = (CredEnumerateWT)GetProcAddress(advapi,
            "CredEnumerateW");
-       CredPackAuthenticationBufferW = (CredPackAuthenticationBufferWT)
-           GetProcAddress(credui, "CredPackAuthenticationBufferW");
        CredFree = (CredFreeT)GetProcAddress(advapi, "CredFree");
        CredDeleteW = (CredDeleteWT)GetProcAddress(advapi, "CredDeleteW");
-       if (!CredWriteW || !CredUnPackAuthenticationBufferW ||
-           !CredEnumerateW || !CredPackAuthenticationBufferW || !CredFree ||
-           !CredDeleteW)
+       if (!CredWriteW || !CredEnumerateW || !CredFree || !CredDeleteW)
                die("failed to load functions");
 }
 
-static char target_buf[1024];
-static char *protocol, *host, *path, *username;
-static WCHAR *wusername, *password, *target;
+static WCHAR *wusername, *password, *protocol, *host, *path, target[1024];
 
-static void write_item(const char *what, WCHAR *wbuf)
+static void write_item(const char *what, LPCWSTR wbuf, int wlen)
 {
        char *buf;
-       int len = WideCharToMultiByte(CP_UTF8, 0, wbuf, -1, NULL, 0, NULL,
+       int len = WideCharToMultiByte(CP_UTF8, 0, wbuf, wlen, NULL, 0, NULL,
            FALSE);
        buf = xmalloc(len);
 
-       if (!WideCharToMultiByte(CP_UTF8, 0, wbuf, -1, buf, len, NULL, FALSE))
+       if (!WideCharToMultiByte(CP_UTF8, 0, wbuf, wlen, buf, len, NULL, FALSE))
                die("WideCharToMultiByte failed!");
 
        printf("%s=", what);
-       fwrite(buf, 1, len - 1, stdout);
+       fwrite(buf, 1, len, stdout);
        putchar('\n');
        free(buf);
 }
 
-static int match_attr(const CREDENTIALW *cred, const WCHAR *keyword,
-    const char *want)
+/*
+ * Match an (optional) expected string and a delimiter in the target string,
+ * consuming the matched text by updating the target pointer.
+ */
+
+static LPCWSTR wcsstr_last(LPCWSTR str, LPCWSTR find)
 {
-       int i;
-       if (!want)
-               return 1;
+       LPCWSTR res = NULL, pos;
+       for (pos = wcsstr(str, find); pos; pos = wcsstr(pos + 1, find))
+               res = pos;
+       return res;
+}
 
-       for (i = 0; i < cred->AttributeCount; ++i)
-               if (!wcscmp(cred->Attributes[i].Keyword, keyword))
-                       return !strcmp((const char *)cred->Attributes[i].Value,
-                           want);
+static int match_part_with_last(LPCWSTR *ptarget, LPCWSTR want, LPCWSTR delim, int last)
+{
+       LPCWSTR delim_pos, start = *ptarget;
+       int len;
+
+       /* find start of delimiter (or end-of-string if delim is empty) */
+       if (*delim)
+               delim_pos = last ? wcsstr_last(start, delim) : wcsstr(start, delim);
+       else
+               delim_pos = start + wcslen(start);
+
+       /*
+        * match text up to delimiter, or end of string (e.g. the '/' after
+        * host is optional if not followed by a path)
+        */
+       if (delim_pos)
+               len = delim_pos - start;
+       else
+               len = wcslen(start);
+
+       /* update ptarget if we either found a delimiter or need a match */
+       if (delim_pos || want)
+               *ptarget = delim_pos ? delim_pos + wcslen(delim) : start + len;
+
+       return !want || (!wcsncmp(want, start, len) && !want[len]);
+}
+
+static int match_part(LPCWSTR *ptarget, LPCWSTR want, LPCWSTR delim)
+{
+       return match_part_with_last(ptarget, want, delim, 0);
+}
 
-       return 0; /* not found */
+static int match_part_last(LPCWSTR *ptarget, LPCWSTR want, LPCWSTR delim)
+{
+       return match_part_with_last(ptarget, want, delim, 1);
 }
 
 static int match_cred(const CREDENTIALW *cred)
 {
-       return (!wusername || !wcscmp(wusername, cred->UserName)) &&
-           match_attr(cred, L"git_protocol", protocol) &&
-           match_attr(cred, L"git_host", host) &&
-           match_attr(cred, L"git_path", path);
+       LPCWSTR target = cred->TargetName;
+       if (wusername && wcscmp(wusername, cred->UserName))
+               return 0;
+
+       return match_part(&target, L"git", L":") &&
+               match_part(&target, protocol, L"://") &&
+               match_part_last(&target, wusername, L"@") &&
+               match_part(&target, host, L"/") &&
+               match_part(&target, path, L"");
 }
 
 static void get_credential(void)
 {
-       WCHAR *user_buf, *pass_buf;
-       DWORD user_buf_size = 0, pass_buf_size = 0;
-       CREDENTIALW **creds, *cred = NULL;
+       CREDENTIALW **creds;
        DWORD num_creds;
        int i;
 
@@ -165,90 +182,36 @@ static void get_credential(void)
        /* search for the first credential that matches username */
        for (i = 0; i < num_creds; ++i)
                if (match_cred(creds[i])) {
-                       cred = creds[i];
+                       write_item("username", creds[i]->UserName,
+                               wcslen(creds[i]->UserName));
+                       write_item("password",
+                               (LPCWSTR)creds[i]->CredentialBlob,
+                               creds[i]->CredentialBlobSize / sizeof(WCHAR));
                        break;
                }
-       if (!cred)
-               return;
-
-       CredUnPackAuthenticationBufferW(0, cred->CredentialBlob,
-           cred->CredentialBlobSize, NULL, &user_buf_size, NULL, NULL,
-           NULL, &pass_buf_size);
-
-       user_buf = xmalloc(user_buf_size * sizeof(WCHAR));
-       pass_buf = xmalloc(pass_buf_size * sizeof(WCHAR));
-
-       if (!CredUnPackAuthenticationBufferW(0, cred->CredentialBlob,
-           cred->CredentialBlobSize, user_buf, &user_buf_size, NULL, NULL,
-           pass_buf, &pass_buf_size))
-               die("CredUnPackAuthenticationBuffer failed");
 
        CredFree(creds);
-
-       /* zero-terminate (sizes include zero-termination) */
-       user_buf[user_buf_size - 1] = L'\0';
-       pass_buf[pass_buf_size - 1] = L'\0';
-
-       write_item("username", user_buf);
-       write_item("password", pass_buf);
-
-       free(user_buf);
-       free(pass_buf);
-}
-
-static void write_attr(CREDENTIAL_ATTRIBUTEW *attr, const WCHAR *keyword,
-    const char *value)
-{
-       attr->Keyword = (LPWSTR)keyword;
-       attr->Flags = 0;
-       attr->ValueSize = strlen(value) + 1; /* store zero-termination */
-       attr->Value = (LPBYTE)value;
 }
 
 static void store_credential(void)
 {
        CREDENTIALW cred;
-       BYTE *auth_buf;
-       DWORD auth_buf_size = 0;
-       CREDENTIAL_ATTRIBUTEW attrs[CRED_MAX_ATTRIBUTES];
 
        if (!wusername || !password)
                return;
 
-       /* query buffer size */
-       CredPackAuthenticationBufferW(0, wusername, password,
-           NULL, &auth_buf_size);
-
-       auth_buf = xmalloc(auth_buf_size);
-
-       if (!CredPackAuthenticationBufferW(0, wusername, password,
-           auth_buf, &auth_buf_size))
-               die("CredPackAuthenticationBuffer failed");
-
        cred.Flags = 0;
        cred.Type = CRED_TYPE_GENERIC;
        cred.TargetName = target;
        cred.Comment = L"saved by git-credential-wincred";
-       cred.CredentialBlobSize = auth_buf_size;
-       cred.CredentialBlob = auth_buf;
+       cred.CredentialBlobSize = (wcslen(password)) * sizeof(WCHAR);
+       cred.CredentialBlob = (LPVOID)password;
        cred.Persist = CRED_PERSIST_LOCAL_MACHINE;
-       cred.AttributeCount = 1;
-       cred.Attributes = attrs;
+       cred.AttributeCount = 0;
+       cred.Attributes = NULL;
        cred.TargetAlias = NULL;
        cred.UserName = wusername;
 
-       write_attr(attrs, L"git_protocol", protocol);
-
-       if (host) {
-               write_attr(attrs + cred.AttributeCount, L"git_host", host);
-               cred.AttributeCount++;
-       }
-
-       if (path) {
-               write_attr(attrs + cred.AttributeCount, L"git_path", path);
-               cred.AttributeCount++;
-       }
-
        if (!CredWriteW(&cred, 0))
                die("CredWrite failed");
 }
@@ -284,10 +247,13 @@ static void read_credential(void)
 
        while (fgets(buf, sizeof(buf), stdin)) {
                char *v;
+               int len = strlen(buf);
+               /* strip trailing CR / LF */
+               while (len && strchr("\r\n", buf[len - 1]))
+                       buf[--len] = 0;
 
-               if (!strcmp(buf, "\n"))
+               if (!*buf)
                        break;
-               buf[strlen(buf)-1] = '\0';
 
                v = strchr(buf, '=');
                if (!v)
@@ -295,13 +261,12 @@ static void read_credential(void)
                *v++ = '\0';
 
                if (!strcmp(buf, "protocol"))
-                       protocol = xstrdup(v);
+                       protocol = utf8_to_utf16_dup(v);
                else if (!strcmp(buf, "host"))
-                       host = xstrdup(v);
+                       host = utf8_to_utf16_dup(v);
                else if (!strcmp(buf, "path"))
-                       path = xstrdup(v);
+                       path = utf8_to_utf16_dup(v);
                else if (!strcmp(buf, "username")) {
-                       username = xstrdup(v);
                        wusername = utf8_to_utf16_dup(v);
                } else if (!strcmp(buf, "password"))
                        password = utf8_to_utf16_dup(v);
@@ -330,22 +295,20 @@ int main(int argc, char *argv[])
                return 0;
 
        /* prepare 'target', the unique key for the credential */
-       strncat(target_buf, "git:", sizeof(target_buf));
-       strncat(target_buf, protocol, sizeof(target_buf));
-       strncat(target_buf, "://", sizeof(target_buf));
-       if (username) {
-               strncat(target_buf, username, sizeof(target_buf));
-               strncat(target_buf, "@", sizeof(target_buf));
+       wcscpy(target, L"git:");
+       wcsncat(target, protocol, ARRAY_SIZE(target));
+       wcsncat(target, L"://", ARRAY_SIZE(target));
+       if (wusername) {
+               wcsncat(target, wusername, ARRAY_SIZE(target));
+               wcsncat(target, L"@", ARRAY_SIZE(target));
        }
        if (host)
-               strncat(target_buf, host, sizeof(target_buf));
+               wcsncat(target, host, ARRAY_SIZE(target));
        if (path) {
-               strncat(target_buf, "/", sizeof(target_buf));
-               strncat(target_buf, path, sizeof(target_buf));
+               wcsncat(target, L"/", ARRAY_SIZE(target));
+               wcsncat(target, path, ARRAY_SIZE(target));
        }
 
-       target = utf8_to_utf16_dup(target_buf);
-
        if (!strcmp(argv[1], "get"))
                get_credential();
        else if (!strcmp(argv[1], "store"))