diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000000..d5a2923ff5 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "c/kastore"] + path = c/kastore + url = https://github.com/tskit-dev/kastore.git diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000..a4abd17e1f --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2018 tskit-dev + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md index bc65f319c4..4c69c5e5b4 100644 --- a/README.md +++ b/README.md @@ -1,15 +1,9 @@ # tskit The tree sequence toolkit -**This repository is a placeholder, and currently only exists to ensure the - name isn't taken on PyPI.** +**This repository is currently under heavy development.** -Currently, the `tskit` tools are still bundled with the Python package [`msprime`](https://github.com/tskit-dev/msprime). 
-We expect to separate the two soon, -but to use `tskit` in the meantime, [install `msprime`](https://msprime.readthedocs.io/en/latest/installation.html), -and then -``` -import msprime as tskit -``` -will get you the functionality that `tskit` will provide in the future. +Currently, the `tskit` tools are bundled with the Python package +[`msprime`](https://github.com/tskit-dev/msprime) and we are in the process +of extracting them here. diff --git a/c/CHANGELOG.rst b/c/CHANGELOG.rst new file mode 100644 index 0000000000..ee5179e567 --- /dev/null +++ b/c/CHANGELOG.rst @@ -0,0 +1,6 @@ +-------------------- +[0.0.0] - 2019-01-19 +-------------------- + +Initial extraction of tskit code from msprime. Relicense to MIT. +Code copied at hash 29921408661d5fe0b1a82b1ca302a8b87510fd23 diff --git a/c/Makefile b/c/Makefile new file mode 100644 index 0000000000..c28838deff --- /dev/null +++ b/c/Makefile @@ -0,0 +1,47 @@ +CC=gcc +CFLAGS=-g -O2 -std=c99 -pedantic -Werror -Wall -W \ + -Wmissing-prototypes -Wstrict-prototypes \ + -Wconversion -Wshadow -Wpointer-arith \ + -Wcast-align -Wcast-qual \ + -Wwrite-strings -Wnested-externs \ + -fshort-enums -fno-common -Dinline= \ + -Ikastore/c +LDFLAGS=-lm +OBJECTS=tsk_core.o tsk_tables.o tsk_trees.o tsk_genotypes.o \ + tsk_convert.o tsk_stats.o kastore.o +HEADERS=tsk_core.h tsk_tables.h tsk_trees.h tsk_genotypes.h \ + tsk_convert.h tsk_stats.h +EXECUTABLES=main test_core test_tables test_trees test_genotypes + +all: old_tests ${EXECUTABLES} libtskit.a + +main: main.c ${OBJECTS} argtable3.o + ${CC} ${CFLAGS} -o $@ $^ ${LDFLAGS} + +libtskit.a: ${OBJECTS} + ar rcs $@ $^ + +test_%: test_%.c testlib.o ${OBJECTS} + ${CC} ${CFLAGS} -o $@ $^ ${LDFLAGS} -lcunit + +tsk_%.o: tsk_%.c tsk_%.h tsk_core.h + ${CC} ${CFLAGS} $< -c + +# We can't turn on all the usual compiler checks because of CUnit. 
+testlib.o: testlib.c + ${CC} -Ikastore/c -std=c99 -Wall -g -c $^ + +argtable3.o: argtable3.c + ${CC} -Dlint -Wall -g -O2 -c argtable3.c + +old_tests: old_tests.c ${OBJECTS} + ${CC} -Ikastore/c -std=c99 -Wall -g $^ ${LDFLAGS} -lcunit -lgsl -lgslcblas + +kastore.o: kastore/c/kastore.c + $(CC) -c $(CFLAGS) $(CPPFLAGS) $< + +clean: + rm -f *.o ${EXECUTABLES} + +tags: + ctags *.c *.h diff --git a/c/argtable3.c b/c/argtable3.c new file mode 100644 index 0000000000..a86774c922 --- /dev/null +++ b/c/argtable3.c @@ -0,0 +1,4955 @@ +/******************************************************************************* + * This file is part of the argtable3 library. + * + * Copyright (C) 1998-2001,2003-2011,2013 Stewart Heitmann + * + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of STEWART HEITMANN nor the names of its contributors + * may be used to endorse or promote products derived from this software + * without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. 
IN NO EVENT SHALL STEWART HEITMANN BE LIABLE FOR ANY DIRECT, + * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + ******************************************************************************/ + +#include "argtable3.h" + +/******************************************************************************* + * This file is part of the argtable3 library. + * + * Copyright (C) 2013 Tom G. Huang + * + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of STEWART HEITMANN nor the names of its contributors + * may be used to endorse or promote products derived from this software + * without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. 
IN NO EVENT SHALL STEWART HEITMANN BE LIABLE FOR ANY DIRECT, + * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + ******************************************************************************/ + +#ifndef ARG_UTILS_H +#define ARG_UTILS_H + +#define ARG_ENABLE_TRACE 0 +#define ARG_ENABLE_LOG 1 + +#ifdef __cplusplus +extern "C" { +#endif + +enum +{ + EMINCOUNT = 1, + EMAXCOUNT, + EBADINT, + EOVERFLOW, + EBADDOUBLE, + EBADDATE, + EREGNOMATCH +}; + + +#if defined(_MSC_VER) +#define ARG_TRACE(x) \ + __pragma(warning(push)) \ + __pragma(warning(disable:4127)) \ + do { if (ARG_ENABLE_TRACE) dbg_printf x; } while (0) \ + __pragma(warning(pop)) + +#define ARG_LOG(x) \ + __pragma(warning(push)) \ + __pragma(warning(disable:4127)) \ + do { if (ARG_ENABLE_LOG) dbg_printf x; } while (0) \ + __pragma(warning(pop)) +#else +#define ARG_TRACE(x) \ + do { if (ARG_ENABLE_TRACE) dbg_printf x; } while (0) + +#define ARG_LOG(x) \ + do { if (ARG_ENABLE_LOG) dbg_printf x; } while (0) +#endif + +extern void dbg_printf(const char *fmt, ...); + +#ifdef __cplusplus +} +#endif + +#endif + +/******************************************************************************* + * This file is part of the argtable3 library. + * + * Copyright (C) 1998-2001,2003-2011,2013 Stewart Heitmann + * + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. 
+ * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of STEWART HEITMANN nor the names of its contributors + * may be used to endorse or promote products derived from this software + * without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL STEWART HEITMANN BE LIABLE FOR ANY DIRECT, + * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + ******************************************************************************/ + +#include +#include + + +void dbg_printf(const char *fmt, ...) +{ + va_list args; + va_start(args, fmt); + vfprintf(stderr, fmt, args); + va_end(args); +} + +/* $Id: getopt.h,v 1.1 2009/10/16 19:50:28 rodney Exp rodney $ */ +/* $OpenBSD: getopt.h,v 1.1 2002/12/03 20:24:29 millert Exp $ */ +/* $NetBSD: getopt.h,v 1.4 2000/07/07 10:43:54 ad Exp $ */ + +/*- + * Copyright (c) 2000 The NetBSD Foundation, Inc. + * All rights reserved. + * + * This code is derived from software contributed to The NetBSD Foundation + * by Dieter Baron and Thomas Klausner. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * 1. 
Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. All advertising materials mentioning features or use of this software + * must display the following acknowledgement: + * This product includes software developed by the NetBSD + * Foundation, Inc. and its contributors. + * 4. Neither the name of The NetBSD Foundation nor the names of its + * contributors may be used to endorse or promote products derived + * from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE NETBSD FOUNDATION, INC. AND CONTRIBUTORS + * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED + * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE FOUNDATION OR CONTRIBUTORS + * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. 
+ */ + +#ifndef _GETOPT_H_ +#define _GETOPT_H_ + +#if 0 +#include +#endif + +/* + * GNU-like getopt_long() and 4.4BSD getsubopt()/optreset extensions + */ +#define no_argument 0 +#define required_argument 1 +#define optional_argument 2 + +struct option { + /* name of long option */ + const char *name; + /* + * one of no_argument, required_argument, and optional_argument: + * whether option takes an argument + */ + int has_arg; + /* if not NULL, set *flag to val when option found */ + int *flag; + /* if flag not NULL, value to set *flag to; else return value */ + int val; +}; + +#ifdef __cplusplus +extern "C" { +#endif + +int getopt_long(int, char * const *, const char *, + const struct option *, int *); +int getopt_long_only(int, char * const *, const char *, + const struct option *, int *); +#ifndef _GETOPT_DEFINED +#define _GETOPT_DEFINED +int getopt(int, char * const *, const char *); +int getsubopt(char **, char * const *, char **); + +extern char *optarg; /* getopt(3) external variables */ +extern int opterr; +extern int optind; +extern int optopt; +extern int optreset; +extern char *suboptarg; /* getsubopt(3) external variable */ +#endif /* _GETOPT_DEFINED */ + +#ifdef __cplusplus +} +#endif +#endif /* !_GETOPT_H_ */ +/* $Id: getopt_long.c,v 1.1 2009/10/16 19:50:28 rodney Exp rodney $ */ +/* $OpenBSD: getopt_long.c,v 1.23 2007/10/31 12:34:57 chl Exp $ */ +/* $NetBSD: getopt_long.c,v 1.15 2002/01/31 22:43:40 tv Exp $ */ + +/* + * Copyright (c) 2002 Todd C. Miller + * + * Permission to use, copy, modify, and distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. 
IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + * + * Sponsored in part by the Defense Advanced Research Projects + * Agency (DARPA) and Air Force Research Laboratory, Air Force + * Materiel Command, USAF, under agreement number F39502-99-1-0512. + */ + +#ifndef lint +static const char rcsid[]="$Id: getopt_long.c,v 1.1 2009/10/16 19:50:28 rodney Exp rodney $"; +#endif /* lint */ +/*- + * Copyright (c) 2000 The NetBSD Foundation, Inc. + * All rights reserved. + * + * This code is derived from software contributed to The NetBSD Foundation + * by Dieter Baron and Thomas Klausner. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE NETBSD FOUNDATION, INC. AND CONTRIBUTORS + * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED + * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE FOUNDATION OR CONTRIBUTORS + * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ + +#if 0 +#include +#endif +#include +#include +#include + + +#define REPLACE_GETOPT /* use this getopt as the system getopt(3) */ + +#ifdef REPLACE_GETOPT +int opterr = 1; /* if error message should be printed */ +int optind = 1; /* index into parent argv vector */ +int optopt = '?'; /* character checked for validity */ +int optreset; /* reset getopt */ +char *optarg; /* argument associated with option */ +#endif + +#define PRINT_ERROR ((opterr) && (*options != ':')) + +#define FLAG_PERMUTE 0x01 /* permute non-options to the end of argv */ +#define FLAG_ALLARGS 0x02 /* treat non-options as args to option "-1" */ +#define FLAG_LONGONLY 0x04 /* operate as getopt_long_only */ + +/* return values */ +#define BADCH (int)'?' +#define BADARG ((*options == ':') ? 
(int)':' : (int)'?') +#define INORDER (int)1 + +#define EMSG "" + +static int getopt_internal(int, char * const *, const char *, + const struct option *, int *, int); +static int parse_long_options(char * const *, const char *, + const struct option *, int *, int); +static int gcd(int, int); +static void permute_args(int, int, int, char * const *); + +static char *place = EMSG; /* option letter processing */ + +/* XXX: set optreset to 1 rather than these two */ +static int nonopt_start = -1; /* first non option argument (for permute) */ +static int nonopt_end = -1; /* first option after non options (for permute) */ + +/* Error messages */ +static const char recargchar[] = "option requires an argument -- %c"; +static const char recargstring[] = "option requires an argument -- %s"; +static const char ambig[] = "ambiguous option -- %.*s"; +static const char noarg[] = "option doesn't take an argument -- %.*s"; +static const char illoptchar[] = "unknown option -- %c"; +static const char illoptstring[] = "unknown option -- %s"; + + + +#ifdef _WIN32 + +/* Windows needs warnx(). We change the definition though: + * 1. (another) global is defined, opterrmsg, which holds the error message + * 2. errors are always printed out on stderr w/o the program name + * Note that opterrmsg always gets set no matter what opterr is set to. The + * error message will not be printed if opterr is 0 as usual. + */ + +#include +#include + +extern char opterrmsg[128]; +char opterrmsg[128]; /* buffer for the last error message */ + +static void warnx(const char *fmt, ...) +{ + va_list ap; + va_start(ap, fmt); + /* + Make sure opterrmsg is always zero-terminated despite the _vsnprintf() + implementation specifics and manually suppress the warning. 
+ */ + memset(opterrmsg, 0, sizeof opterrmsg); + if (fmt != NULL) + _vsnprintf(opterrmsg, sizeof(opterrmsg) - 1, fmt, ap); + va_end(ap); + +#pragma warning(suppress: 6053) + fprintf(stderr, "%s\n", opterrmsg); +} + +#else +#include +#endif /*_WIN32*/ + + +/* + * Compute the greatest common divisor of a and b. + */ +static int +gcd(int a, int b) +{ + int c; + + c = a % b; + while (c != 0) { + a = b; + b = c; + c = a % b; + } + + return (b); +} + +/* + * Exchange the block from nonopt_start to nonopt_end with the block + * from nonopt_end to opt_end (keeping the same order of arguments + * in each block). + */ +static void +permute_args(int panonopt_start, int panonopt_end, int opt_end, + char * const *nargv) +{ + int cstart, cyclelen, i, j, ncycle, nnonopts, nopts, pos; + char *swap; + + /* + * compute lengths of blocks and number and size of cycles + */ + nnonopts = panonopt_end - panonopt_start; + nopts = opt_end - panonopt_end; + ncycle = gcd(nnonopts, nopts); + cyclelen = (opt_end - panonopt_start) / ncycle; + + for (i = 0; i < ncycle; i++) { + cstart = panonopt_end+i; + pos = cstart; + for (j = 0; j < cyclelen; j++) { + if (pos >= panonopt_end) + pos -= nnonopts; + else + pos += nopts; + swap = nargv[pos]; + /* LINTED const cast */ + ((char **) nargv)[pos] = nargv[cstart]; + /* LINTED const cast */ + ((char **)nargv)[cstart] = swap; + } + } +} + +/* + * parse_long_options -- + * Parse long options in argc/argv argument vector. + * Returns -1 if short_too is set and the option does not match long_options. 
+ */ +static int +parse_long_options(char * const *nargv, const char *options, + const struct option *long_options, int *idx, int short_too) +{ + char *current_argv, *has_equal; + size_t current_argv_len; + int i, match; + + current_argv = place; + match = -1; + + optind++; + + if ((has_equal = strchr(current_argv, '=')) != NULL) { + /* argument found (--option=arg) */ + current_argv_len = has_equal - current_argv; + has_equal++; + } else + current_argv_len = strlen(current_argv); + + for (i = 0; long_options[i].name; i++) { + /* find matching long option */ + if (strncmp(current_argv, long_options[i].name, + current_argv_len)) + continue; + + if (strlen(long_options[i].name) == current_argv_len) { + /* exact match */ + match = i; + break; + } + /* + * If this is a known short option, don't allow + * a partial match of a single character. + */ + if (short_too && current_argv_len == 1) + continue; + + if (match == -1) /* partial match */ + match = i; + else { + /* ambiguous abbreviation */ + if (PRINT_ERROR) + warnx(ambig, (int)current_argv_len, + current_argv); + optopt = 0; + return (BADCH); + } + } + if (match != -1) { /* option found */ + if (long_options[match].has_arg == no_argument + && has_equal) { + if (PRINT_ERROR) + warnx(noarg, (int)current_argv_len, + current_argv); + /* + * XXX: GNU sets optopt to val regardless of flag + */ + if (long_options[match].flag == NULL) + optopt = long_options[match].val; + else + optopt = 0; + return (BADARG); + } + if (long_options[match].has_arg == required_argument || + long_options[match].has_arg == optional_argument) { + if (has_equal) + optarg = has_equal; + else if (long_options[match].has_arg == + required_argument) { + /* + * optional argument doesn't use next nargv + */ + optarg = nargv[optind++]; + } + } + if ((long_options[match].has_arg == required_argument) + && (optarg == NULL)) { + /* + * Missing argument; leading ':' indicates no error + * should be generated. 
+ */ + if (PRINT_ERROR) + warnx(recargstring, + current_argv); + /* + * XXX: GNU sets optopt to val regardless of flag + */ + if (long_options[match].flag == NULL) + optopt = long_options[match].val; + else + optopt = 0; + --optind; + return (BADARG); + } + } else { /* unknown option */ + if (short_too) { + --optind; + return (-1); + } + if (PRINT_ERROR) + warnx(illoptstring, current_argv); + optopt = 0; + return (BADCH); + } + if (idx) + *idx = match; + if (long_options[match].flag) { + *long_options[match].flag = long_options[match].val; + return (0); + } else + return (long_options[match].val); +} + +/* + * getopt_internal -- + * Parse argc/argv argument vector. Called by user level routines. + */ +static int +getopt_internal(int nargc, char * const *nargv, const char *options, + const struct option *long_options, int *idx, int flags) +{ + char *oli; /* option letter list index */ + int optchar, short_too; + static int posixly_correct = -1; + + if (options == NULL) + return (-1); + + /* + * Disable GNU extensions if POSIXLY_CORRECT is set or options + * string begins with a '+'. + */ + if (posixly_correct == -1) + posixly_correct = (getenv("POSIXLY_CORRECT") != NULL); + if (posixly_correct || *options == '+') + flags &= ~FLAG_PERMUTE; + else if (*options == '-') + flags |= FLAG_ALLARGS; + if (*options == '+' || *options == '-') + options++; + + /* + * XXX Some GNU programs (like cvs) set optind to 0 instead of + * XXX using optreset. Work around this braindamage. 
+ */ + if (optind == 0) + optind = optreset = 1; + + optarg = NULL; + if (optreset) + nonopt_start = nonopt_end = -1; +start: + if (optreset || !*place) { /* update scanning pointer */ + optreset = 0; + if (optind >= nargc) { /* end of argument vector */ + place = EMSG; + if (nonopt_end != -1) { + /* do permutation, if we have to */ + permute_args(nonopt_start, nonopt_end, + optind, nargv); + optind -= nonopt_end - nonopt_start; + } + else if (nonopt_start != -1) { + /* + * If we skipped non-options, set optind + * to the first of them. + */ + optind = nonopt_start; + } + nonopt_start = nonopt_end = -1; + return (-1); + } + if (*(place = nargv[optind]) != '-' || + (place[1] == '\0' && strchr(options, '-') == NULL)) { + place = EMSG; /* found non-option */ + if (flags & FLAG_ALLARGS) { + /* + * GNU extension: + * return non-option as argument to option 1 + */ + optarg = nargv[optind++]; + return (INORDER); + } + if (!(flags & FLAG_PERMUTE)) { + /* + * If no permutation wanted, stop parsing + * at first non-option. + */ + return (-1); + } + /* do permutation */ + if (nonopt_start == -1) + nonopt_start = optind; + else if (nonopt_end != -1) { + permute_args(nonopt_start, nonopt_end, + optind, nargv); + nonopt_start = optind - + (nonopt_end - nonopt_start); + nonopt_end = -1; + } + optind++; + /* process next argument */ + goto start; + } + if (nonopt_start != -1 && nonopt_end == -1) + nonopt_end = optind; + + /* + * If we have "-" do nothing, if "--" we are done. + */ + if (place[1] != '\0' && *++place == '-' && place[1] == '\0') { + optind++; + place = EMSG; + /* + * We found an option (--), so if we skipped + * non-options, we have to permute. 
+ */ + if (nonopt_end != -1) { + permute_args(nonopt_start, nonopt_end, + optind, nargv); + optind -= nonopt_end - nonopt_start; + } + nonopt_start = nonopt_end = -1; + return (-1); + } + } + + /* + * Check long options if: + * 1) we were passed some + * 2) the arg is not just "-" + * 3) either the arg starts with -- we are getopt_long_only() + */ + if (long_options != NULL && place != nargv[optind] && + (*place == '-' || (flags & FLAG_LONGONLY))) { + short_too = 0; + if (*place == '-') + place++; /* --foo long option */ + else if (*place != ':' && strchr(options, *place) != NULL) + short_too = 1; /* could be short option too */ + + optchar = parse_long_options(nargv, options, long_options, + idx, short_too); + if (optchar != -1) { + place = EMSG; + return (optchar); + } + } + + if ((optchar = (int)*place++) == (int)':' || + (optchar == (int)'-' && *place != '\0') || + (oli = strchr(options, optchar)) == NULL) { + /* + * If the user specified "-" and '-' isn't listed in + * options, return -1 (non-option) as per POSIX. + * Otherwise, it is an unknown option character (or ':'). 
+ */ + if (optchar == (int)'-' && *place == '\0') + return (-1); + if (!*place) + ++optind; + if (PRINT_ERROR) + warnx(illoptchar, optchar); + optopt = optchar; + return (BADCH); + } + if (long_options != NULL && optchar == 'W' && oli[1] == ';') { + /* -W long-option */ + if (*place) /* no space */ + /* NOTHING */; + else if (++optind >= nargc) { /* no arg */ + place = EMSG; + if (PRINT_ERROR) + warnx(recargchar, optchar); + optopt = optchar; + return (BADARG); + } else /* white space */ + place = nargv[optind]; + optchar = parse_long_options(nargv, options, long_options, + idx, 0); + place = EMSG; + return (optchar); + } + if (*++oli != ':') { /* doesn't take argument */ + if (!*place) + ++optind; + } else { /* takes (optional) argument */ + optarg = NULL; + if (*place) /* no white space */ + optarg = place; + else if (oli[1] != ':') { /* arg not optional */ + if (++optind >= nargc) { /* no arg */ + place = EMSG; + if (PRINT_ERROR) + warnx(recargchar, optchar); + optopt = optchar; + return (BADARG); + } else + optarg = nargv[optind]; + } + place = EMSG; + ++optind; + } + /* dump back option letter */ + return (optchar); +} + +#ifdef REPLACE_GETOPT +/* + * getopt -- + * Parse argc/argv argument vector. + * + * [eventually this will replace the BSD getopt] + */ +int +getopt(int nargc, char * const *nargv, const char *options) +{ + + /* + * We don't pass FLAG_PERMUTE to getopt_internal() since + * the BSD getopt(3) (unlike GNU) has never done this. + * + * Furthermore, since many privileged programs call getopt() + * before dropping privileges it makes sense to keep things + * as simple (and bug-free) as possible. + */ + return (getopt_internal(nargc, nargv, options, NULL, NULL, 0)); +} +#endif /* REPLACE_GETOPT */ + +/* + * getopt_long -- + * Parse argc/argv argument vector. 
+ */ +int +getopt_long(int nargc, char * const *nargv, const char *options, + const struct option *long_options, int *idx) +{ + + return (getopt_internal(nargc, nargv, options, long_options, idx, + FLAG_PERMUTE)); +} + +/* + * getopt_long_only -- + * Parse argc/argv argument vector. + */ +int +getopt_long_only(int nargc, char * const *nargv, const char *options, + const struct option *long_options, int *idx) +{ + + return (getopt_internal(nargc, nargv, options, long_options, idx, + FLAG_PERMUTE|FLAG_LONGONLY)); +} +/******************************************************************************* + * This file is part of the argtable3 library. + * + * Copyright (C) 1998-2001,2003-2011,2013 Stewart Heitmann + * + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of STEWART HEITMANN nor the names of its contributors + * may be used to endorse or promote products derived from this software + * without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. 
IN NO EVENT SHALL STEWART HEITMANN BE LIABLE FOR ANY DIRECT, + * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + ******************************************************************************/ + +#include +#include + +#include "argtable3.h" + + +char * arg_strptime(const char *buf, const char *fmt, struct tm *tm); + + +static void arg_date_resetfn(struct arg_date *parent) +{ + ARG_TRACE(("%s:resetfn(%p)\n", __FILE__, parent)); + parent->count = 0; +} + + +static int arg_date_scanfn(struct arg_date *parent, const char *argval) +{ + int errorcode = 0; + + if (parent->count == parent->hdr.maxcount) + { + errorcode = EMAXCOUNT; + } + else if (!argval) + { + /* no argument value was given, leave parent->tmval[] unaltered but still count it */ + parent->count++; + } + else + { + const char *pend; + struct tm tm = parent->tmval[parent->count]; + + /* parse the given argument value, store result in parent->tmval[] */ + pend = arg_strptime(argval, parent->format, &tm); + if (pend && pend[0] == '\0') + parent->tmval[parent->count++] = tm; + else + errorcode = EBADDATE; + } + + ARG_TRACE(("%s:scanfn(%p) returns %d\n", __FILE__, parent, errorcode)); + return errorcode; +} + + +static int arg_date_checkfn(struct arg_date *parent) +{ + int errorcode = (parent->count < parent->hdr.mincount) ? 
EMINCOUNT : 0; + + ARG_TRACE(("%s:checkfn(%p) returns %d\n", __FILE__, parent, errorcode)); + return errorcode; +} + + +static void arg_date_errorfn( + struct arg_date *parent, + FILE *fp, + int errorcode, + const char *argval, + const char *progname) +{ + const char *shortopts = parent->hdr.shortopts; + const char *longopts = parent->hdr.longopts; + const char *datatype = parent->hdr.datatype; + + /* make argval NULL safe */ + argval = argval ? argval : ""; + + fprintf(fp, "%s: ", progname); + switch(errorcode) + { + case EMINCOUNT: + fputs("missing option ", fp); + arg_print_option(fp, shortopts, longopts, datatype, "\n"); + break; + + case EMAXCOUNT: + fputs("excess option ", fp); + arg_print_option(fp, shortopts, longopts, argval, "\n"); + break; + + case EBADDATE: + { + struct tm tm; + char buff[200]; + + fprintf(fp, "illegal timestamp format \"%s\"\n", argval); + memset(&tm, 0, sizeof(tm)); + arg_strptime("1999-12-31 23:59:59", "%F %H:%M:%S", &tm); + strftime(buff, sizeof(buff), parent->format, &tm); + printf("correct format is \"%s\"\n", buff); + break; + } + } +} + + +struct arg_date * arg_date0( + const char * shortopts, + const char * longopts, + const char * format, + const char *datatype, + const char *glossary) +{ + return arg_daten(shortopts, longopts, format, datatype, 0, 1, glossary); +} + + +struct arg_date * arg_date1( + const char * shortopts, + const char * longopts, + const char * format, + const char *datatype, + const char *glossary) +{ + return arg_daten(shortopts, longopts, format, datatype, 1, 1, glossary); +} + + +struct arg_date * arg_daten( + const char * shortopts, + const char * longopts, + const char * format, + const char *datatype, + int mincount, + int maxcount, + const char *glossary) +{ + size_t nbytes; + struct arg_date *result; + + /* foolproof things by ensuring maxcount is not less than mincount */ + maxcount = (maxcount < mincount) ? 
mincount : maxcount; + + /* default time format is the national date format for the locale */ + if (!format) + format = "%x"; + + nbytes = sizeof(struct arg_date) /* storage for struct arg_date */ + + maxcount * sizeof(struct tm); /* storage for tmval[maxcount] array */ + + /* allocate storage for the arg_date struct + tmval[] array. */ + /* we use calloc because we want the tmval[] array zero filled. */ + result = (struct arg_date *)calloc(1, nbytes); + if (result) + { + /* init the arg_hdr struct */ + result->hdr.flag = ARG_HASVALUE; + result->hdr.shortopts = shortopts; + result->hdr.longopts = longopts; + result->hdr.datatype = datatype ? datatype : format; + result->hdr.glossary = glossary; + result->hdr.mincount = mincount; + result->hdr.maxcount = maxcount; + result->hdr.parent = result; + result->hdr.resetfn = (arg_resetfn *)arg_date_resetfn; + result->hdr.scanfn = (arg_scanfn *)arg_date_scanfn; + result->hdr.checkfn = (arg_checkfn *)arg_date_checkfn; + result->hdr.errorfn = (arg_errorfn *)arg_date_errorfn; + + /* store the tmval[maxcount] array immediately after the arg_date struct */ + result->tmval = (struct tm *)(result + 1); + + /* init the remaining arg_date member variables */ + result->count = 0; + result->format = format; + } + + ARG_TRACE(("arg_daten() returns %p\n", result)); + return result; +} + + +/*- + * Copyright (c) 1997, 1998, 2005, 2008 The NetBSD Foundation, Inc. + * All rights reserved. + * + * This code was contributed to The NetBSD Foundation by Klaus Klein. + * Heavily optimised by David Laight + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. 
Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE NETBSD FOUNDATION, INC. AND CONTRIBUTORS + * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED + * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE FOUNDATION OR CONTRIBUTORS + * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ + +#include +#include +#include + +/* + * We do not implement alternate representations. However, we always + * check whether a given modifier is allowed for a certain conversion. 
+ */ +#define ALT_E 0x01 +#define ALT_O 0x02 +#define LEGAL_ALT(x) { if (alt_format & ~(x)) return (0); } +#define TM_YEAR_BASE (1900) + +static int conv_num(const char * *, int *, int, int); + +static const char *day[7] = { + "Sunday", "Monday", "Tuesday", "Wednesday", "Thursday", + "Friday", "Saturday" +}; + +static const char *abday[7] = { + "Sun", "Mon", "Tue", "Wed", "Thu", "Fri", "Sat" +}; + +static const char *mon[12] = { + "January", "February", "March", "April", "May", "June", "July", + "August", "September", "October", "November", "December" +}; + +static const char *abmon[12] = { + "Jan", "Feb", "Mar", "Apr", "May", "Jun", + "Jul", "Aug", "Sep", "Oct", "Nov", "Dec" +}; + +static const char *am_pm[2] = { + "AM", "PM" +}; + + +static int arg_strcasecmp(const char *s1, const char *s2) +{ + const unsigned char *us1 = (const unsigned char *)s1; + const unsigned char *us2 = (const unsigned char *)s2; + while (tolower(*us1) == tolower(*us2++)) + if (*us1++ == '\0') + return 0; + + return tolower(*us1) - tolower(*--us2); +} + + +static int arg_strncasecmp(const char *s1, const char *s2, size_t n) +{ + if (n != 0) + { + const unsigned char *us1 = (const unsigned char *)s1; + const unsigned char *us2 = (const unsigned char *)s2; + do + { + if (tolower(*us1) != tolower(*us2++)) + return tolower(*us1) - tolower(*--us2); + + if (*us1++ == '\0') + break; + } while (--n != 0); + } + + return 0; +} + + +char * arg_strptime(const char *buf, const char *fmt, struct tm *tm) +{ + char c; + const char *bp; + size_t len = 0; + int alt_format, i, split_year = 0; + + bp = buf; + + while ((c = *fmt) != '\0') { + /* Clear `alternate' modifier prior to new conversion. */ + alt_format = 0; + + /* Eat up white-space. */ + if (isspace(c)) { + while (isspace(*bp)) + bp++; + + fmt++; + continue; + } + + if ((c = *fmt++) != '%') + goto literal; + + +again: + switch (c = *fmt++) + { + case '%': /* "%%" is converted to "%". 
*/ +literal: + if (c != *bp++) + return (0); + break; + + /* + * "Alternative" modifiers. Just set the appropriate flag + * and start over again. + */ + case 'E': /* "%E?" alternative conversion modifier. */ + LEGAL_ALT(0); + alt_format |= ALT_E; + goto again; + + case 'O': /* "%O?" alternative conversion modifier. */ + LEGAL_ALT(0); + alt_format |= ALT_O; + goto again; + + /* + * "Complex" conversion rules, implemented through recursion. + */ + case 'c': /* Date and time, using the locale's format. */ + LEGAL_ALT(ALT_E); + bp = arg_strptime(bp, "%x %X", tm); + if (!bp) + return (0); + break; + + case 'D': /* The date as "%m/%d/%y". */ + LEGAL_ALT(0); + bp = arg_strptime(bp, "%m/%d/%y", tm); + if (!bp) + return (0); + break; + + case 'R': /* The time as "%H:%M". */ + LEGAL_ALT(0); + bp = arg_strptime(bp, "%H:%M", tm); + if (!bp) + return (0); + break; + + case 'r': /* The time in 12-hour clock representation. */ + LEGAL_ALT(0); + bp = arg_strptime(bp, "%I:%M:%S %p", tm); + if (!bp) + return (0); + break; + + case 'T': /* The time as "%H:%M:%S". */ + LEGAL_ALT(0); + bp = arg_strptime(bp, "%H:%M:%S", tm); + if (!bp) + return (0); + break; + + case 'X': /* The time, using the locale's format. */ + LEGAL_ALT(ALT_E); + bp = arg_strptime(bp, "%H:%M:%S", tm); + if (!bp) + return (0); + break; + + case 'x': /* The date, using the locale's format. */ + LEGAL_ALT(ALT_E); + bp = arg_strptime(bp, "%m/%d/%y", tm); + if (!bp) + return (0); + break; + + /* + * "Elementary" conversion rules. + */ + case 'A': /* The day of week, using the locale's form. */ + case 'a': + LEGAL_ALT(0); + for (i = 0; i < 7; i++) { + /* Full name. */ + len = strlen(day[i]); + if (arg_strncasecmp(day[i], bp, len) == 0) + break; + + /* Abbreviated name. */ + len = strlen(abday[i]); + if (arg_strncasecmp(abday[i], bp, len) == 0) + break; + } + + /* Nothing matched. */ + if (i == 7) + return (0); + + tm->tm_wday = i; + bp += len; + break; + + case 'B': /* The month, using the locale's form. 
*/ + case 'b': + case 'h': + LEGAL_ALT(0); + for (i = 0; i < 12; i++) { + /* Full name. */ + len = strlen(mon[i]); + if (arg_strncasecmp(mon[i], bp, len) == 0) + break; + + /* Abbreviated name. */ + len = strlen(abmon[i]); + if (arg_strncasecmp(abmon[i], bp, len) == 0) + break; + } + + /* Nothing matched. */ + if (i == 12) + return (0); + + tm->tm_mon = i; + bp += len; + break; + + case 'C': /* The century number. */ + LEGAL_ALT(ALT_E); + if (!(conv_num(&bp, &i, 0, 99))) + return (0); + + if (split_year) { + tm->tm_year = (tm->tm_year % 100) + (i * 100); + } else { + tm->tm_year = i * 100; + split_year = 1; + } + break; + + case 'd': /* The day of month. */ + case 'e': + LEGAL_ALT(ALT_O); + if (!(conv_num(&bp, &tm->tm_mday, 1, 31))) + return (0); + break; + + case 'k': /* The hour (24-hour clock representation). */ + LEGAL_ALT(0); + /* FALLTHROUGH */ + case 'H': + LEGAL_ALT(ALT_O); + if (!(conv_num(&bp, &tm->tm_hour, 0, 23))) + return (0); + break; + + case 'l': /* The hour (12-hour clock representation). */ + LEGAL_ALT(0); + /* FALLTHROUGH */ + case 'I': + LEGAL_ALT(ALT_O); + if (!(conv_num(&bp, &tm->tm_hour, 1, 12))) + return (0); + if (tm->tm_hour == 12) + tm->tm_hour = 0; + break; + + case 'j': /* The day of year. */ + LEGAL_ALT(0); + if (!(conv_num(&bp, &i, 1, 366))) + return (0); + tm->tm_yday = i - 1; + break; + + case 'M': /* The minute. */ + LEGAL_ALT(ALT_O); + if (!(conv_num(&bp, &tm->tm_min, 0, 59))) + return (0); + break; + + case 'm': /* The month. */ + LEGAL_ALT(ALT_O); + if (!(conv_num(&bp, &i, 1, 12))) + return (0); + tm->tm_mon = i - 1; + break; + + case 'p': /* The locale's equivalent of AM/PM. */ + LEGAL_ALT(0); + /* AM? */ + if (arg_strcasecmp(am_pm[0], bp) == 0) { + if (tm->tm_hour > 11) + return (0); + + bp += strlen(am_pm[0]); + break; + } + /* PM? */ + else if (arg_strcasecmp(am_pm[1], bp) == 0) { + if (tm->tm_hour > 11) + return (0); + + tm->tm_hour += 12; + bp += strlen(am_pm[1]); + break; + } + + /* Nothing matched. 
*/ + return (0); + + case 'S': /* The seconds. */ + LEGAL_ALT(ALT_O); + if (!(conv_num(&bp, &tm->tm_sec, 0, 61))) + return (0); + break; + + case 'U': /* The week of year, beginning on sunday. */ + case 'W': /* The week of year, beginning on monday. */ + LEGAL_ALT(ALT_O); + /* + * XXX This is bogus, as we can not assume any valid + * information present in the tm structure at this + * point to calculate a real value, so just check the + * range for now. + */ + if (!(conv_num(&bp, &i, 0, 53))) + return (0); + break; + + case 'w': /* The day of week, beginning on sunday. */ + LEGAL_ALT(ALT_O); + if (!(conv_num(&bp, &tm->tm_wday, 0, 6))) + return (0); + break; + + case 'Y': /* The year. */ + LEGAL_ALT(ALT_E); + if (!(conv_num(&bp, &i, 0, 9999))) + return (0); + + tm->tm_year = i - TM_YEAR_BASE; + break; + + case 'y': /* The year within 100 years of the epoch. */ + LEGAL_ALT(ALT_E | ALT_O); + if (!(conv_num(&bp, &i, 0, 99))) + return (0); + + if (split_year) { + tm->tm_year = ((tm->tm_year / 100) * 100) + i; + break; + } + split_year = 1; + if (i <= 68) + tm->tm_year = i + 2000 - TM_YEAR_BASE; + else + tm->tm_year = i + 1900 - TM_YEAR_BASE; + break; + + /* + * Miscellaneous conversions. + */ + case 'n': /* Any kind of white-space. */ + case 't': + LEGAL_ALT(0); + while (isspace(*bp)) + bp++; + break; + + + default: /* Unknown/unsupported conversion. */ + return (0); + } + + + } + + /* LINTED functional specification */ + return ((char *)bp); +} + + +static int conv_num(const char * *buf, int *dest, int llim, int ulim) +{ + int result = 0; + + /* The limit also determines the number of valid digits. 
*/ + int rulim = ulim; + + if (**buf < '0' || **buf > '9') + return (0); + + do { + result *= 10; + result += *(*buf)++ - '0'; + rulim /= 10; + } while ((result * 10 <= ulim) && rulim && **buf >= '0' && **buf <= '9'); + + if (result < llim || result > ulim) + return (0); + + *dest = result; + return (1); +} +/******************************************************************************* + * This file is part of the argtable3 library. + * + * Copyright (C) 1998-2001,2003-2011,2013 Stewart Heitmann + * + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of STEWART HEITMANN nor the names of its contributors + * may be used to endorse or promote products derived from this software + * without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. 
IN NO EVENT SHALL STEWART HEITMANN BE LIABLE FOR ANY DIRECT, + * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + ******************************************************************************/ + +#include + +#include "argtable3.h" + + +static void arg_dbl_resetfn(struct arg_dbl *parent) +{ + ARG_TRACE(("%s:resetfn(%p)\n", __FILE__, parent)); + parent->count = 0; +} + + +static int arg_dbl_scanfn(struct arg_dbl *parent, const char *argval) +{ + int errorcode = 0; + + if (parent->count == parent->hdr.maxcount) + { + /* maximum number of arguments exceeded */ + errorcode = EMAXCOUNT; + } + else if (!argval) + { + /* a valid argument with no argument value was given. */ + /* This happens when an optional argument value was invoked. */ + /* leave parent argument value unaltered but still count the argument. */ + parent->count++; + } + else + { + double val; + char *end; + + /* extract double from argval into val */ + val = strtod(argval, &end); + + /* if success then store result in parent->dval[] array otherwise return error*/ + if (*end == 0) + parent->dval[parent->count++] = val; + else + errorcode = EBADDOUBLE; + } + + ARG_TRACE(("%s:scanfn(%p) returns %d\n", __FILE__, parent, errorcode)); + return errorcode; +} + + +static int arg_dbl_checkfn(struct arg_dbl *parent) +{ + int errorcode = (parent->count < parent->hdr.mincount) ? 
EMINCOUNT : 0; + + ARG_TRACE(("%s:checkfn(%p) returns %d\n", __FILE__, parent, errorcode)); + return errorcode; +} + + +static void arg_dbl_errorfn( + struct arg_dbl *parent, + FILE *fp, + int errorcode, + const char *argval, + const char *progname) +{ + const char *shortopts = parent->hdr.shortopts; + const char *longopts = parent->hdr.longopts; + const char *datatype = parent->hdr.datatype; + + /* make argval NULL safe */ + argval = argval ? argval : ""; + + fprintf(fp, "%s: ", progname); + switch(errorcode) + { + case EMINCOUNT: + fputs("missing option ", fp); + arg_print_option(fp, shortopts, longopts, datatype, "\n"); + break; + + case EMAXCOUNT: + fputs("excess option ", fp); + arg_print_option(fp, shortopts, longopts, argval, "\n"); + break; + + case EBADDOUBLE: + fprintf(fp, "invalid argument \"%s\" to option ", argval); + arg_print_option(fp, shortopts, longopts, datatype, "\n"); + break; + } +} + + +struct arg_dbl * arg_dbl0( + const char * shortopts, + const char * longopts, + const char *datatype, + const char *glossary) +{ + return arg_dbln(shortopts, longopts, datatype, 0, 1, glossary); +} + + +struct arg_dbl * arg_dbl1( + const char * shortopts, + const char * longopts, + const char *datatype, + const char *glossary) +{ + return arg_dbln(shortopts, longopts, datatype, 1, 1, glossary); +} + + +struct arg_dbl * arg_dbln( + const char * shortopts, + const char * longopts, + const char *datatype, + int mincount, + int maxcount, + const char *glossary) +{ + size_t nbytes; + struct arg_dbl *result; + + /* foolproof things by ensuring maxcount is not less than mincount */ + maxcount = (maxcount < mincount) ? 
mincount : maxcount; + + nbytes = sizeof(struct arg_dbl) /* storage for struct arg_dbl */ + + (maxcount + 1) * sizeof(double); /* storage for dval[maxcount] array plus one extra for padding to memory boundary */ + + result = (struct arg_dbl *)malloc(nbytes); + if (result) + { + size_t addr; + size_t rem; + + /* init the arg_hdr struct */ + result->hdr.flag = ARG_HASVALUE; + result->hdr.shortopts = shortopts; + result->hdr.longopts = longopts; + result->hdr.datatype = datatype ? datatype : ""; + result->hdr.glossary = glossary; + result->hdr.mincount = mincount; + result->hdr.maxcount = maxcount; + result->hdr.parent = result; + result->hdr.resetfn = (arg_resetfn *)arg_dbl_resetfn; + result->hdr.scanfn = (arg_scanfn *)arg_dbl_scanfn; + result->hdr.checkfn = (arg_checkfn *)arg_dbl_checkfn; + result->hdr.errorfn = (arg_errorfn *)arg_dbl_errorfn; + + /* Store the dval[maxcount] array on the first double boundary that + * immediately follows the arg_dbl struct. We do the memory alignment + * purely for SPARC and Motorola systems. They require floats and + * doubles to be aligned on natural boundaries. + */ + addr = (size_t)(result + 1); + rem = addr % sizeof(double); + result->dval = (double *)(addr + sizeof(double) - rem); + ARG_TRACE(("addr=%p, dval=%p, sizeof(double)=%d rem=%d\n", addr, result->dval, (int)sizeof(double), (int)rem)); + + result->count = 0; + } + + ARG_TRACE(("arg_dbln() returns %p\n", result)); + return result; +} +/******************************************************************************* + * This file is part of the argtable3 library. + * + * Copyright (C) 1998-2001,2003-2011,2013 Stewart Heitmann + * + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. 
+ * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of STEWART HEITMANN nor the names of its contributors + * may be used to endorse or promote products derived from this software + * without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL STEWART HEITMANN BE LIABLE FOR ANY DIRECT, + * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + ******************************************************************************/ + +#include + +#include "argtable3.h" + + +static void arg_end_resetfn(struct arg_end *parent) +{ + ARG_TRACE(("%s:resetfn(%p)\n", __FILE__, parent)); + parent->count = 0; +} + +static void arg_end_errorfn( + void *parent, + FILE *fp, + int error, + const char *argval, + const char *progname) +{ + /* suppress unreferenced formal parameter warning */ + (void)parent; + + progname = progname ? progname : ""; + argval = argval ? 
argval : ""; + + fprintf(fp, "%s: ", progname); + switch(error) + { + case ARG_ELIMIT: + fputs("too many errors to display", fp); + break; + case ARG_EMALLOC: + fputs("insufficent memory", fp); + break; + case ARG_ENOMATCH: + fprintf(fp, "unexpected argument \"%s\"", argval); + break; + case ARG_EMISSARG: + fprintf(fp, "option \"%s\" requires an argument", argval); + break; + case ARG_ELONGOPT: + fprintf(fp, "invalid option \"%s\"", argval); + break; + default: + fprintf(fp, "invalid option \"-%c\"", error); + break; + } + + fputc('\n', fp); +} + + +struct arg_end * arg_end(int maxcount) +{ + size_t nbytes; + struct arg_end *result; + + nbytes = sizeof(struct arg_end) + + maxcount * sizeof(int) /* storage for int error[maxcount] array*/ + + maxcount * sizeof(void *) /* storage for void* parent[maxcount] array */ + + maxcount * sizeof(char *); /* storage for char* argval[maxcount] array */ + + result = (struct arg_end *)malloc(nbytes); + if (result) + { + /* init the arg_hdr struct */ + result->hdr.flag = ARG_TERMINATOR; + result->hdr.shortopts = NULL; + result->hdr.longopts = NULL; + result->hdr.datatype = NULL; + result->hdr.glossary = NULL; + result->hdr.mincount = 1; + result->hdr.maxcount = maxcount; + result->hdr.parent = result; + result->hdr.resetfn = (arg_resetfn *)arg_end_resetfn; + result->hdr.scanfn = NULL; + result->hdr.checkfn = NULL; + result->hdr.errorfn = (arg_errorfn *)arg_end_errorfn; + + /* store error[maxcount] array immediately after struct arg_end */ + result->error = (int *)(result + 1); + + /* store parent[maxcount] array immediately after error[] array */ + result->parent = (void * *)(result->error + maxcount ); + + /* store argval[maxcount] array immediately after parent[] array */ + result->argval = (const char * *)(result->parent + maxcount ); + } + + ARG_TRACE(("arg_end(%d) returns %p\n", maxcount, result)); + return result; +} + + +void arg_print_errors(FILE * fp, struct arg_end * end, const char * progname) +{ + int i; + 
ARG_TRACE(("arg_errors()\n")); + for (i = 0; i < end->count; i++) + { + struct arg_hdr *errorparent = (struct arg_hdr *)(end->parent[i]); + if (errorparent->errorfn) + errorparent->errorfn(end->parent[i], + fp, + end->error[i], + end->argval[i], + progname); + } +} +/******************************************************************************* + * This file is part of the argtable3 library. + * + * Copyright (C) 1998-2001,2003-2011,2013 Stewart Heitmann + * + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of STEWART HEITMANN nor the names of its contributors + * may be used to endorse or promote products derived from this software + * without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL STEWART HEITMANN BE LIABLE FOR ANY DIRECT, + * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ ******************************************************************************/ + +#include +#include + +#include "argtable3.h" + +#ifdef WIN32 +# define FILESEPARATOR1 '\\' +# define FILESEPARATOR2 '/' +#else +# define FILESEPARATOR1 '/' +# define FILESEPARATOR2 '/' +#endif + + +static void arg_file_resetfn(struct arg_file *parent) +{ + ARG_TRACE(("%s:resetfn(%p)\n", __FILE__, parent)); + parent->count = 0; +} + + +/* Returns ptr to the base filename within *filename */ +static const char * arg_basename(const char *filename) +{ + const char *result = NULL, *result1, *result2; + + /* Find the last occurrence of eother file separator character. */ + /* Two alternative file separator chars are supported as legal */ + /* file separators but not both together in the same filename. */ + result1 = (filename ? strrchr(filename, FILESEPARATOR1) : NULL); + result2 = (filename ? strrchr(filename, FILESEPARATOR2) : NULL); + + if (result2) + result = result2 + 1; /* using FILESEPARATOR2 (the alternative file separator) */ + + if (result1) + result = result1 + 1; /* using FILESEPARATOR1 (the preferred file separator) */ + + if (!result) + result = filename; /* neither file separator was found so basename is the whole filename */ + + /* special cases of "." and ".." are not considered basenames */ + if (result && ( strcmp(".", result) == 0 || strcmp("..", result) == 0 )) + result = filename + strlen(filename); + + return result; +} + + +/* Returns ptr to the file extension within *basename */ +static const char * arg_extension(const char *basename) +{ + /* find the last occurrence of '.' in basename */ + const char *result = (basename ? strrchr(basename, '.') : NULL); + + /* if no '.' 
was found then return pointer to end of basename */ + if (basename && !result) + result = basename + strlen(basename); + + /* special case: basenames with a single leading dot (eg ".foo") are not considered as true extensions */ + if (basename && result == basename) + result = basename + strlen(basename); + + /* special case: empty extensions (eg "foo.","foo..") are not considered as true extensions */ + if (basename && result && result[1] == '\0') + result = basename + strlen(basename); + + return result; +} + + +static int arg_file_scanfn(struct arg_file *parent, const char *argval) +{ + int errorcode = 0; + + if (parent->count == parent->hdr.maxcount) + { + /* maximum number of arguments exceeded */ + errorcode = EMAXCOUNT; + } + else if (!argval) + { + /* a valid argument with no argument value was given. */ + /* This happens when an optional argument value was invoked. */ + /* leave parent arguiment value unaltered but still count the argument. */ + parent->count++; + } + else + { + parent->filename[parent->count] = argval; + parent->basename[parent->count] = arg_basename(argval); + parent->extension[parent->count] = + arg_extension(parent->basename[parent->count]); /* only seek extensions within the basename (not the file path)*/ + parent->count++; + } + + ARG_TRACE(("%s4:scanfn(%p) returns %d\n", __FILE__, parent, errorcode)); + return errorcode; +} + + +static int arg_file_checkfn(struct arg_file *parent) +{ + int errorcode = (parent->count < parent->hdr.mincount) ? EMINCOUNT : 0; + + ARG_TRACE(("%s:checkfn(%p) returns %d\n", __FILE__, parent, errorcode)); + return errorcode; +} + + +static void arg_file_errorfn( + struct arg_file *parent, + FILE *fp, + int errorcode, + const char *argval, + const char *progname) +{ + const char *shortopts = parent->hdr.shortopts; + const char *longopts = parent->hdr.longopts; + const char *datatype = parent->hdr.datatype; + + /* make argval NULL safe */ + argval = argval ? 
argval : ""; + + fprintf(fp, "%s: ", progname); + switch(errorcode) + { + case EMINCOUNT: + fputs("missing option ", fp); + arg_print_option(fp, shortopts, longopts, datatype, "\n"); + break; + + case EMAXCOUNT: + fputs("excess option ", fp); + arg_print_option(fp, shortopts, longopts, argval, "\n"); + break; + + default: + fprintf(fp, "unknown error at \"%s\"\n", argval); + } +} + + +struct arg_file * arg_file0( + const char * shortopts, + const char * longopts, + const char *datatype, + const char *glossary) +{ + return arg_filen(shortopts, longopts, datatype, 0, 1, glossary); +} + + +struct arg_file * arg_file1( + const char * shortopts, + const char * longopts, + const char *datatype, + const char *glossary) +{ + return arg_filen(shortopts, longopts, datatype, 1, 1, glossary); +} + + +struct arg_file * arg_filen( + const char * shortopts, + const char * longopts, + const char *datatype, + int mincount, + int maxcount, + const char *glossary) +{ + size_t nbytes; + struct arg_file *result; + + /* foolproof things by ensuring maxcount is not less than mincount */ + maxcount = (maxcount < mincount) ? mincount : maxcount; + + nbytes = sizeof(struct arg_file) /* storage for struct arg_file */ + + sizeof(char *) * maxcount /* storage for filename[maxcount] array */ + + sizeof(char *) * maxcount /* storage for basename[maxcount] array */ + + sizeof(char *) * maxcount; /* storage for extension[maxcount] array */ + + result = (struct arg_file *)malloc(nbytes); + if (result) + { + int i; + + /* init the arg_hdr struct */ + result->hdr.flag = ARG_HASVALUE; + result->hdr.shortopts = shortopts; + result->hdr.longopts = longopts; + result->hdr.glossary = glossary; + result->hdr.datatype = datatype ? 
datatype : ""; + result->hdr.mincount = mincount; + result->hdr.maxcount = maxcount; + result->hdr.parent = result; + result->hdr.resetfn = (arg_resetfn *)arg_file_resetfn; + result->hdr.scanfn = (arg_scanfn *)arg_file_scanfn; + result->hdr.checkfn = (arg_checkfn *)arg_file_checkfn; + result->hdr.errorfn = (arg_errorfn *)arg_file_errorfn; + + /* store the filename,basename,extension arrays immediately after the arg_file struct */ + result->filename = (const char * *)(result + 1); + result->basename = result->filename + maxcount; + result->extension = result->basename + maxcount; + result->count = 0; + + /* foolproof the string pointers by initialising them with empty strings */ + for (i = 0; i < maxcount; i++) + { + result->filename[i] = ""; + result->basename[i] = ""; + result->extension[i] = ""; + } + } + + ARG_TRACE(("arg_filen() returns %p\n", result)); + return result; +} +/******************************************************************************* + * This file is part of the argtable3 library. + * + * Copyright (C) 1998-2001,2003-2011,2013 Stewart Heitmann + * + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of STEWART HEITMANN nor the names of its contributors + * may be used to endorse or promote products derived from this software + * without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL STEWART HEITMANN BE LIABLE FOR ANY DIRECT, + * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + ******************************************************************************/ + +#include +#include +#include + +#include "argtable3.h" + + +static void arg_int_resetfn(struct arg_int *parent) +{ + ARG_TRACE(("%s:resetfn(%p)\n", __FILE__, parent)); + parent->count = 0; +} + + +/* strtol0x() is like strtol() except that the numeric string is */ +/* expected to be prefixed by "0X" where X is a user supplied char. */ +/* The string may optionally be prefixed by white space and + or - */ +/* as in +0X123 or -0X123. */ +/* Once the prefix has been scanned, the remainder of the numeric */ +/* string is converted using strtol() with the given base. */ +/* eg: to parse hex str="-0X12324", specify X='X' and base=16. */ +/* eg: to parse oct str="+0o12324", specify X='O' and base=8. */ +/* eg: to parse bin str="-0B01010", specify X='B' and base=2. */ +/* Failure of conversion is indicated by result where *endptr==str. 
*/ +static long int strtol0X(const char * str, + const char * *endptr, + char X, + int base) +{ + long int val; /* stores result */ + int s = 1; /* sign is +1 or -1 */ + const char *ptr = str; /* ptr to current position in str */ + + /* skip leading whitespace */ + while (isspace(*ptr)) + ptr++; + /* printf("1) %s\n",ptr); */ + + /* scan optional sign character */ + switch (*ptr) + { + case '+': + ptr++; + s = 1; + break; + case '-': + ptr++; + s = -1; + break; + default: + s = 1; + break; + } + /* printf("2) %s\n",ptr); */ + + /* '0X' prefix */ + if ((*ptr++) != '0') + { + /* printf("failed to detect '0'\n"); */ + *endptr = str; + return 0; + } + /* printf("3) %s\n",ptr); */ + if (toupper(*ptr++) != toupper(X)) + { + /* printf("failed to detect '%c'\n",X); */ + *endptr = str; + return 0; + } + /* printf("4) %s\n",ptr); */ + + /* attempt conversion on remainder of string using strtol() */ + val = strtol(ptr, (char * *)endptr, base); + if (*endptr == ptr) + { + /* conversion failed */ + *endptr = str; + return 0; + } + + /* success */ + return s * val; +} + + +/* Returns 1 if str matches suffix (case insensitive). */ +/* Str may contain trailing whitespace, but nothing else. */ +static int detectsuffix(const char *str, const char *suffix) +{ + /* scan pairwise through strings until mismatch detected */ + while( toupper(*str) == toupper(*suffix) ) + { + /* printf("'%c' '%c'\n", *str, *suffix); */ + + /* return 1 (success) if match persists until the string terminator */ + if (*str == '\0') + return 1; + + /* next chars */ + str++; + suffix++; + } + /* printf("'%c' '%c' mismatch\n", *str, *suffix); */ + + /* return 0 (fail) if the matching did not consume the entire suffix */ + if (*suffix != 0) + return 0; /* failed to consume entire suffix */ + + /* skip any remaining whitespace in str */ + while (isspace(*str)) + str++; + + /* return 1 (success) if we have reached end of str else return 0 (fail) */ + return (*str == '\0') ? 
1 : 0; +} + + +static int arg_int_scanfn(struct arg_int *parent, const char *argval) +{ + int errorcode = 0; + + if (parent->count == parent->hdr.maxcount) + { + /* maximum number of arguments exceeded */ + errorcode = EMAXCOUNT; + } + else if (!argval) + { + /* a valid argument with no argument value was given. */ + /* This happens when an optional argument value was invoked. */ + /* leave parent arguiment value unaltered but still count the argument. */ + parent->count++; + } + else + { + long int val; + const char *end; + + /* attempt to extract hex integer (eg: +0x123) from argval into val conversion */ + val = strtol0X(argval, &end, 'X', 16); + if (end == argval) + { + /* hex failed, attempt octal conversion (eg +0o123) */ + val = strtol0X(argval, &end, 'O', 8); + if (end == argval) + { + /* octal failed, attempt binary conversion (eg +0B101) */ + val = strtol0X(argval, &end, 'B', 2); + if (end == argval) + { + /* binary failed, attempt decimal conversion with no prefix (eg 1234) */ + val = strtol(argval, (char * *)&end, 10); + if (end == argval) + { + /* all supported number formats failed */ + return EBADINT; + } + } + } + } + + /* Safety check for integer overflow. WARNING: this check */ + /* achieves nothing on machines where size(int)==size(long). */ + if ( val > INT_MAX || val < INT_MIN ) + errorcode = EOVERFLOW; + + /* Detect any suffixes (KB,MB,GB) and multiply argument value appropriately. */ + /* We need to be mindful of integer overflows when using such big numbers. 
*/ + if (detectsuffix(end, "KB")) /* kilobytes */ + { + if ( val > (INT_MAX / 1024) || val < (INT_MIN / 1024) ) + errorcode = EOVERFLOW; /* Overflow would occur if we proceed */ + else + val *= 1024; /* 1KB = 1024 */ + } + else if (detectsuffix(end, "MB")) /* megabytes */ + { + if ( val > (INT_MAX / 1048576) || val < (INT_MIN / 1048576) ) + errorcode = EOVERFLOW; /* Overflow would occur if we proceed */ + else + val *= 1048576; /* 1MB = 1024*1024 */ + } + else if (detectsuffix(end, "GB")) /* gigabytes */ + { + if ( val > (INT_MAX / 1073741824) || val < (INT_MIN / 1073741824) ) + errorcode = EOVERFLOW; /* Overflow would occur if we proceed */ + else + val *= 1073741824; /* 1GB = 1024*1024*1024 */ + } + else if (!detectsuffix(end, "")) + errorcode = EBADINT; /* invalid suffix detected */ + + /* if success then store result in parent->ival[] array */ + if (errorcode == 0) + parent->ival[parent->count++] = val; + } + + /* printf("%s:scanfn(%p,%p) returns %d\n",__FILE__,parent,argval,errorcode); */ + return errorcode; +} + + +static int arg_int_checkfn(struct arg_int *parent) +{ + int errorcode = (parent->count < parent->hdr.mincount) ? EMINCOUNT : 0; + /*printf("%s:checkfn(%p) returns %d\n",__FILE__,parent,errorcode);*/ + return errorcode; +} + + +static void arg_int_errorfn( + struct arg_int *parent, + FILE *fp, + int errorcode, + const char *argval, + const char *progname) +{ + const char *shortopts = parent->hdr.shortopts; + const char *longopts = parent->hdr.longopts; + const char *datatype = parent->hdr.datatype; + + /* make argval NULL safe */ + argval = argval ? 
argval : ""; + + fprintf(fp, "%s: ", progname); + switch(errorcode) + { + case EMINCOUNT: + fputs("missing option ", fp); + arg_print_option(fp, shortopts, longopts, datatype, "\n"); + break; + + case EMAXCOUNT: + fputs("excess option ", fp); + arg_print_option(fp, shortopts, longopts, argval, "\n"); + break; + + case EBADINT: + fprintf(fp, "invalid argument \"%s\" to option ", argval); + arg_print_option(fp, shortopts, longopts, datatype, "\n"); + break; + + case EOVERFLOW: + fputs("integer overflow at option ", fp); + arg_print_option(fp, shortopts, longopts, datatype, " "); + fprintf(fp, "(%s is too large)\n", argval); + break; + } +} + + +struct arg_int * arg_int0( + const char *shortopts, + const char *longopts, + const char *datatype, + const char *glossary) +{ + return arg_intn(shortopts, longopts, datatype, 0, 1, glossary); +} + + +struct arg_int * arg_int1( + const char *shortopts, + const char *longopts, + const char *datatype, + const char *glossary) +{ + return arg_intn(shortopts, longopts, datatype, 1, 1, glossary); +} + + +struct arg_int * arg_intn( + const char *shortopts, + const char *longopts, + const char *datatype, + int mincount, + int maxcount, + const char *glossary) +{ + size_t nbytes; + struct arg_int *result; + + /* foolproof things by ensuring maxcount is not less than mincount */ + maxcount = (maxcount < mincount) ? mincount : maxcount; + + nbytes = sizeof(struct arg_int) /* storage for struct arg_int */ + + maxcount * sizeof(int); /* storage for ival[maxcount] array */ + + result = (struct arg_int *)malloc(nbytes); + if (result) + { + /* init the arg_hdr struct */ + result->hdr.flag = ARG_HASVALUE; + result->hdr.shortopts = shortopts; + result->hdr.longopts = longopts; + result->hdr.datatype = datatype ? 
datatype : ""; + result->hdr.glossary = glossary; + result->hdr.mincount = mincount; + result->hdr.maxcount = maxcount; + result->hdr.parent = result; + result->hdr.resetfn = (arg_resetfn *)arg_int_resetfn; + result->hdr.scanfn = (arg_scanfn *)arg_int_scanfn; + result->hdr.checkfn = (arg_checkfn *)arg_int_checkfn; + result->hdr.errorfn = (arg_errorfn *)arg_int_errorfn; + + /* store the ival[maxcount] array immediately after the arg_int struct */ + result->ival = (int *)(result + 1); + result->count = 0; + } + + ARG_TRACE(("arg_intn() returns %p\n", result)); + return result; +} +/******************************************************************************* + * This file is part of the argtable3 library. + * + * Copyright (C) 1998-2001,2003-2011,2013 Stewart Heitmann + * + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of STEWART HEITMANN nor the names of its contributors + * may be used to endorse or promote products derived from this software + * without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. 
IN NO EVENT SHALL STEWART HEITMANN BE LIABLE FOR ANY DIRECT, + * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + ******************************************************************************/ + +#include + +#include "argtable3.h" + + +static void arg_lit_resetfn(struct arg_lit *parent) +{ + ARG_TRACE(("%s:resetfn(%p)\n", __FILE__, parent)); + parent->count = 0; +} + + +static int arg_lit_scanfn(struct arg_lit *parent, const char *argval) +{ + int errorcode = 0; + if (parent->count < parent->hdr.maxcount ) + parent->count++; + else + errorcode = EMAXCOUNT; + + ARG_TRACE(("%s:scanfn(%p,%s) returns %d\n", __FILE__, parent, argval, + errorcode)); + return errorcode; +} + + +static int arg_lit_checkfn(struct arg_lit *parent) +{ + int errorcode = (parent->count < parent->hdr.mincount) ? 
EMINCOUNT : 0; + ARG_TRACE(("%s:checkfn(%p) returns %d\n", __FILE__, parent, errorcode)); + return errorcode; +} + + +static void arg_lit_errorfn( + struct arg_lit *parent, + FILE *fp, + int errorcode, + const char *argval, + const char *progname) +{ + const char *shortopts = parent->hdr.shortopts; + const char *longopts = parent->hdr.longopts; + const char *datatype = parent->hdr.datatype; + + switch(errorcode) + { + case EMINCOUNT: + fprintf(fp, "%s: missing option ", progname); + arg_print_option(fp, shortopts, longopts, datatype, "\n"); + fprintf(fp, "\n"); + break; + + case EMAXCOUNT: + fprintf(fp, "%s: extraneous option ", progname); + arg_print_option(fp, shortopts, longopts, datatype, "\n"); + break; + } + + ARG_TRACE(("%s:errorfn(%p, %p, %d, %s, %s)\n", __FILE__, parent, fp, + errorcode, argval, progname)); +} + + +struct arg_lit * arg_lit0( + const char * shortopts, + const char * longopts, + const char * glossary) +{ + return arg_litn(shortopts, longopts, 0, 1, glossary); +} + + +struct arg_lit * arg_lit1( + const char *shortopts, + const char *longopts, + const char *glossary) +{ + return arg_litn(shortopts, longopts, 1, 1, glossary); +} + + +struct arg_lit * arg_litn( + const char *shortopts, + const char *longopts, + int mincount, + int maxcount, + const char *glossary) +{ + struct arg_lit *result; + + /* foolproof things by ensuring maxcount is not less than mincount */ + maxcount = (maxcount < mincount) ? 
mincount : maxcount; + + result = (struct arg_lit *)malloc(sizeof(struct arg_lit)); + if (result) + { + /* init the arg_hdr struct */ + result->hdr.flag = 0; + result->hdr.shortopts = shortopts; + result->hdr.longopts = longopts; + result->hdr.datatype = NULL; + result->hdr.glossary = glossary; + result->hdr.mincount = mincount; + result->hdr.maxcount = maxcount; + result->hdr.parent = result; + result->hdr.resetfn = (arg_resetfn *)arg_lit_resetfn; + result->hdr.scanfn = (arg_scanfn *)arg_lit_scanfn; + result->hdr.checkfn = (arg_checkfn *)arg_lit_checkfn; + result->hdr.errorfn = (arg_errorfn *)arg_lit_errorfn; + + /* init local variables */ + result->count = 0; + } + + ARG_TRACE(("arg_litn() returns %p\n", result)); + return result; +} +/******************************************************************************* + * This file is part of the argtable3 library. + * + * Copyright (C) 1998-2001,2003-2011,2013 Stewart Heitmann + * + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of STEWART HEITMANN nor the names of its contributors + * may be used to endorse or promote products derived from this software + * without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. 
IN NO EVENT SHALL STEWART HEITMANN BE LIABLE FOR ANY DIRECT, + * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + ******************************************************************************/ + +#include + +#include "argtable3.h" + +struct arg_rem *arg_rem(const char *datatype, const char *glossary) +{ + struct arg_rem *result = (struct arg_rem *)malloc(sizeof(struct arg_rem)); + if (result) + { + result->hdr.flag = 0; + result->hdr.shortopts = NULL; + result->hdr.longopts = NULL; + result->hdr.datatype = datatype; + result->hdr.glossary = glossary; + result->hdr.mincount = 1; + result->hdr.maxcount = 1; + result->hdr.parent = result; + result->hdr.resetfn = NULL; + result->hdr.scanfn = NULL; + result->hdr.checkfn = NULL; + result->hdr.errorfn = NULL; + } + + ARG_TRACE(("arg_rem() returns %p\n", result)); + return result; +} + +/******************************************************************************* + * This file is part of the argtable3 library. + * + * Copyright (C) 1998-2001,2003-2011,2013 Stewart Heitmann + * + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. 
+ * * Neither the name of STEWART HEITMANN nor the names of its contributors + * may be used to endorse or promote products derived from this software + * without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL STEWART HEITMANN BE LIABLE FOR ANY DIRECT, + * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + ******************************************************************************/ + +#include +#include + +#include "argtable3.h" + + +#ifndef _TREX_H_ +#define _TREX_H_ +/*************************************************************** + T-Rex a tiny regular expression library + + Copyright (C) 2003-2006 Alberto Demichelis + + This software is provided 'as-is', without any express + or implied warranty. In no event will the authors be held + liable for any damages arising from the use of this software. + + Permission is granted to anyone to use this software for + any purpose, including commercial applications, and to alter + it and redistribute it freely, subject to the following restrictions: + + 1. The origin of this software must not be misrepresented; + you must not claim that you wrote the original software. + If you use this software in a product, an acknowledgment + in the product documentation would be appreciated but + is not required. + + 2. 
Altered source versions must be plainly marked as such, + and must not be misrepresented as being the original software. + + 3. This notice may not be removed or altered from any + source distribution. + +****************************************************************/ + +#ifdef __cplusplus +extern "C" { +#endif + +#ifdef _UNICODE +#define TRexChar unsigned short +#define MAX_CHAR 0xFFFF +#define _TREXC(c) L##c +#define trex_strlen wcslen +#define trex_printf wprintf +#else +#define TRexChar char +#define MAX_CHAR 0xFF +#define _TREXC(c) (c) +#define trex_strlen strlen +#define trex_printf printf +#endif + +#ifndef TREX_API +#define TREX_API extern +#endif + +#define TRex_True 1 +#define TRex_False 0 + +#define TREX_ICASE ARG_REX_ICASE + +typedef unsigned int TRexBool; +typedef struct TRex TRex; + +typedef struct { + const TRexChar *begin; + int len; +} TRexMatch; + +TREX_API TRex *trex_compile(const TRexChar *pattern, const TRexChar **error, int flags); +TREX_API void trex_free(TRex *exp); +TREX_API TRexBool trex_match(TRex* exp, const TRexChar* text); +TREX_API TRexBool trex_search(TRex* exp, const TRexChar* text, const TRexChar** out_begin, const TRexChar** out_end); +TREX_API TRexBool trex_searchrange(TRex* exp, const TRexChar* text_begin, const TRexChar* text_end, const TRexChar** out_begin, const TRexChar** out_end); +TREX_API int trex_getsubexpcount(TRex* exp); +TREX_API TRexBool trex_getsubexp(TRex* exp, int n, TRexMatch *subexp); + +#ifdef __cplusplus +} +#endif + +#endif + + + +struct privhdr +{ + const char *pattern; + int flags; +}; + + +static void arg_rex_resetfn(struct arg_rex *parent) +{ + ARG_TRACE(("%s:resetfn(%p)\n", __FILE__, parent)); + parent->count = 0; +} + +static int arg_rex_scanfn(struct arg_rex *parent, const char *argval) +{ + int errorcode = 0; + const TRexChar *error = NULL; + TRex *rex = NULL; + TRexBool is_match = TRex_False; + + if (parent->count == parent->hdr.maxcount ) + { + /* maximum number of arguments exceeded */ + 
errorcode = EMAXCOUNT; + } + else if (!argval) + { + /* a valid argument with no argument value was given. */ + /* This happens when an optional argument value was invoked. */ + /* leave parent argument value unaltered but still count the argument. */ + parent->count++; + } + else + { + struct privhdr *priv = (struct privhdr *)parent->hdr.priv; + + /* test the current argument value for a match with the regular expression */ + /* if a match is detected, record the argument value in the arg_rex struct */ + + rex = trex_compile(priv->pattern, &error, priv->flags); + is_match = trex_match(rex, argval); + if (!is_match) + errorcode = EREGNOMATCH; + else + parent->sval[parent->count++] = argval; + + trex_free(rex); + } + + ARG_TRACE(("%s:scanfn(%p) returns %d\n",__FILE__,parent,errorcode)); + return errorcode; +} + +static int arg_rex_checkfn(struct arg_rex *parent) +{ + int errorcode = (parent->count < parent->hdr.mincount) ? EMINCOUNT : 0; + //struct privhdr *priv = (struct privhdr*)parent->hdr.priv; + + /* free the regex "program" we constructed in resetfn */ + //regfree(&(priv->regex)); + + /*printf("%s:checkfn(%p) returns %d\n",__FILE__,parent,errorcode);*/ + return errorcode; +} + +static void arg_rex_errorfn(struct arg_rex *parent, + FILE *fp, + int errorcode, + const char *argval, + const char *progname) +{ + const char *shortopts = parent->hdr.shortopts; + const char *longopts = parent->hdr.longopts; + const char *datatype = parent->hdr.datatype; + + /* make argval NULL safe */ + argval = argval ? 
argval : ""; + + fprintf(fp, "%s: ", progname); + switch(errorcode) + { + case EMINCOUNT: + fputs("missing option ", fp); + arg_print_option(fp, shortopts, longopts, datatype, "\n"); + break; + + case EMAXCOUNT: + fputs("excess option ", fp); + arg_print_option(fp, shortopts, longopts, argval, "\n"); + break; + + case EREGNOMATCH: + fputs("illegal value ", fp); + arg_print_option(fp, shortopts, longopts, argval, "\n"); + break; + + default: + { + //char errbuff[256]; + //regerror(errorcode, NULL, errbuff, sizeof(errbuff)); + //printf("%s\n", errbuff); + } + break; + } +} + + +struct arg_rex * arg_rex0(const char * shortopts, + const char * longopts, + const char * pattern, + const char *datatype, + int flags, + const char *glossary) +{ + return arg_rexn(shortopts, + longopts, + pattern, + datatype, + 0, + 1, + flags, + glossary); +} + +struct arg_rex * arg_rex1(const char * shortopts, + const char * longopts, + const char * pattern, + const char *datatype, + int flags, + const char *glossary) +{ + return arg_rexn(shortopts, + longopts, + pattern, + datatype, + 1, + 1, + flags, + glossary); +} + + +struct arg_rex * arg_rexn(const char * shortopts, + const char * longopts, + const char * pattern, + const char *datatype, + int mincount, + int maxcount, + int flags, + const char *glossary) +{ + size_t nbytes; + struct arg_rex *result; + struct privhdr *priv; + int i; + const TRexChar *error = NULL; + TRex *rex = NULL; + + if (!pattern) + { + printf( + "argtable: ERROR - illegal regular expression pattern \"(NULL)\"\n"); + printf("argtable: Bad argument table.\n"); + return NULL; + } + + /* foolproof things by ensuring maxcount is not less than mincount */ + maxcount = (maxcount < mincount) ? 
mincount : maxcount; + + nbytes = sizeof(struct arg_rex) /* storage for struct arg_rex */ + + sizeof(struct privhdr) /* storage for private arg_rex data */ + + maxcount * sizeof(char *); /* storage for sval[maxcount] array */ + + result = (struct arg_rex *)malloc(nbytes); + if (result == NULL) + return result; + + /* init the arg_hdr struct */ + result->hdr.flag = ARG_HASVALUE; + result->hdr.shortopts = shortopts; + result->hdr.longopts = longopts; + result->hdr.datatype = datatype ? datatype : pattern; + result->hdr.glossary = glossary; + result->hdr.mincount = mincount; + result->hdr.maxcount = maxcount; + result->hdr.parent = result; + result->hdr.resetfn = (arg_resetfn *)arg_rex_resetfn; + result->hdr.scanfn = (arg_scanfn *)arg_rex_scanfn; + result->hdr.checkfn = (arg_checkfn *)arg_rex_checkfn; + result->hdr.errorfn = (arg_errorfn *)arg_rex_errorfn; + + /* store the arg_rex_priv struct immediately after the arg_rex struct */ + result->hdr.priv = result + 1; + priv = (struct privhdr *)(result->hdr.priv); + priv->pattern = pattern; + priv->flags = flags; + + /* store the sval[maxcount] array immediately after the arg_rex_priv struct */ + result->sval = (const char * *)(priv + 1); + result->count = 0; + + /* foolproof the string pointers by initializing them to reference empty strings */ + for (i = 0; i < maxcount; i++) + result->sval[i] = ""; + + /* here we construct and destroy a regex representation of the regular + * expression for no other reason than to force any regex errors to be + * trapped now rather than later. If we don't, then errors may go undetected + * until an argument is actually parsed. + */ + + rex = trex_compile(priv->pattern, &error, priv->flags); + if (rex == NULL) + { + ARG_LOG(("argtable: %s \"%s\"\n", error ? 
error : _TREXC("undefined"), priv->pattern)); + ARG_LOG(("argtable: Bad argument table.\n")); + } + + trex_free(rex); + + ARG_TRACE(("arg_rexn() returns %p\n", result)); + return result; +} + + + +/* see copyright notice in trex.h */ +#include +#include +#include +#include + +#ifdef _UINCODE +#define scisprint iswprint +#define scstrlen wcslen +#define scprintf wprintf +#define _SC(x) L(x) +#else +#define scisprint isprint +#define scstrlen strlen +#define scprintf printf +#define _SC(x) (x) +#endif + +#ifdef _DEBUG +#include + +static const TRexChar *g_nnames[] = +{ + _SC("NONE"),_SC("OP_GREEDY"), _SC("OP_OR"), + _SC("OP_EXPR"),_SC("OP_NOCAPEXPR"),_SC("OP_DOT"), _SC("OP_CLASS"), + _SC("OP_CCLASS"),_SC("OP_NCLASS"),_SC("OP_RANGE"),_SC("OP_CHAR"), + _SC("OP_EOL"),_SC("OP_BOL"),_SC("OP_WB") +}; + +#endif +#define OP_GREEDY (MAX_CHAR+1) // * + ? {n} +#define OP_OR (MAX_CHAR+2) +#define OP_EXPR (MAX_CHAR+3) //parentesis () +#define OP_NOCAPEXPR (MAX_CHAR+4) //parentesis (?:) +#define OP_DOT (MAX_CHAR+5) +#define OP_CLASS (MAX_CHAR+6) +#define OP_CCLASS (MAX_CHAR+7) +#define OP_NCLASS (MAX_CHAR+8) //negates class the [^ +#define OP_RANGE (MAX_CHAR+9) +#define OP_CHAR (MAX_CHAR+10) +#define OP_EOL (MAX_CHAR+11) +#define OP_BOL (MAX_CHAR+12) +#define OP_WB (MAX_CHAR+13) + +#define TREX_SYMBOL_ANY_CHAR ('.') +#define TREX_SYMBOL_GREEDY_ONE_OR_MORE ('+') +#define TREX_SYMBOL_GREEDY_ZERO_OR_MORE ('*') +#define TREX_SYMBOL_GREEDY_ZERO_OR_ONE ('?') +#define TREX_SYMBOL_BRANCH ('|') +#define TREX_SYMBOL_END_OF_STRING ('$') +#define TREX_SYMBOL_BEGINNING_OF_STRING ('^') +#define TREX_SYMBOL_ESCAPE_CHAR ('\\') + + +typedef int TRexNodeType; + +typedef struct tagTRexNode{ + TRexNodeType type; + int left; + int right; + int next; +}TRexNode; + +struct TRex{ + const TRexChar *_eol; + const TRexChar *_bol; + const TRexChar *_p; + int _first; + int _op; + TRexNode *_nodes; + int _nallocated; + int _nsize; + int _nsubexpr; + TRexMatch *_matches; + int _currsubexp; + void *_jmpbuf; + 
const TRexChar **_error; + int _flags; +}; + +static int trex_list(TRex *exp); + +static int trex_newnode(TRex *exp, TRexNodeType type) +{ + TRexNode n; + int newid; + n.type = type; + n.next = n.right = n.left = -1; + if(type == OP_EXPR) + n.right = exp->_nsubexpr++; + if(exp->_nallocated < (exp->_nsize + 1)) { + exp->_nallocated *= 2; + exp->_nodes = (TRexNode *)realloc(exp->_nodes, exp->_nallocated * sizeof(TRexNode)); + } + exp->_nodes[exp->_nsize++] = n; + newid = exp->_nsize - 1; + return (int)newid; +} + +static void trex_error(TRex *exp,const TRexChar *error) +{ + if(exp->_error) *exp->_error = error; + longjmp(*((jmp_buf*)exp->_jmpbuf),-1); +} + +static void trex_expect(TRex *exp, int n){ + if((*exp->_p) != n) + trex_error(exp, _SC("expected paren")); + exp->_p++; +} + +static TRexChar trex_escapechar(TRex *exp) +{ + if(*exp->_p == TREX_SYMBOL_ESCAPE_CHAR){ + exp->_p++; + switch(*exp->_p) { + case 'v': exp->_p++; return '\v'; + case 'n': exp->_p++; return '\n'; + case 't': exp->_p++; return '\t'; + case 'r': exp->_p++; return '\r'; + case 'f': exp->_p++; return '\f'; + default: return (*exp->_p++); + } + } else if(!scisprint(*exp->_p)) trex_error(exp,_SC("letter expected")); + return (*exp->_p++); +} + +static int trex_charclass(TRex *exp,int classid) +{ + int n = trex_newnode(exp,OP_CCLASS); + exp->_nodes[n].left = classid; + return n; +} + +static int trex_charnode(TRex *exp,TRexBool isclass) +{ + TRexChar t; + if(*exp->_p == TREX_SYMBOL_ESCAPE_CHAR) { + exp->_p++; + switch(*exp->_p) { + case 'n': exp->_p++; return trex_newnode(exp,'\n'); + case 't': exp->_p++; return trex_newnode(exp,'\t'); + case 'r': exp->_p++; return trex_newnode(exp,'\r'); + case 'f': exp->_p++; return trex_newnode(exp,'\f'); + case 'v': exp->_p++; return trex_newnode(exp,'\v'); + case 'a': case 'A': case 'w': case 'W': case 's': case 'S': + case 'd': case 'D': case 'x': case 'X': case 'c': case 'C': + case 'p': case 'P': case 'l': case 'u': + { + t = *exp->_p; exp->_p++; + return 
trex_charclass(exp,t); + } + case 'b': + case 'B': + if(!isclass) { + int node = trex_newnode(exp,OP_WB); + exp->_nodes[node].left = *exp->_p; + exp->_p++; + return node; + } //else default + default: + t = *exp->_p; exp->_p++; + return trex_newnode(exp,t); + } + } + else if(!scisprint(*exp->_p)) { + + trex_error(exp,_SC("letter expected")); + } + t = *exp->_p; exp->_p++; + return trex_newnode(exp,t); +} +static int trex_class(TRex *exp) +{ + int ret = -1; + int first = -1,chain; + if(*exp->_p == TREX_SYMBOL_BEGINNING_OF_STRING){ + ret = trex_newnode(exp,OP_NCLASS); + exp->_p++; + }else ret = trex_newnode(exp,OP_CLASS); + + if(*exp->_p == ']') trex_error(exp,_SC("empty class")); + chain = ret; + while(*exp->_p != ']' && exp->_p != exp->_eol) { + if(*exp->_p == '-' && first != -1){ + int r,t; + if(*exp->_p++ == ']') trex_error(exp,_SC("unfinished range")); + r = trex_newnode(exp,OP_RANGE); + if(first>*exp->_p) trex_error(exp,_SC("invalid range")); + if(exp->_nodes[first].type == OP_CCLASS) trex_error(exp,_SC("cannot use character classes in ranges")); + exp->_nodes[r].left = exp->_nodes[first].type; + t = trex_escapechar(exp); + exp->_nodes[r].right = t; + exp->_nodes[chain].next = r; + chain = r; + first = -1; + } + else{ + if(first!=-1){ + int c = first; + exp->_nodes[chain].next = c; + chain = c; + first = trex_charnode(exp,TRex_True); + } + else{ + first = trex_charnode(exp,TRex_True); + } + } + } + if(first!=-1){ + int c = first; + exp->_nodes[chain].next = c; + chain = c; + first = -1; + } + /* hack? 
*/ + exp->_nodes[ret].left = exp->_nodes[ret].next; + exp->_nodes[ret].next = -1; + return ret; +} + +static int trex_parsenumber(TRex *exp) +{ + int ret = *exp->_p-'0'; + int positions = 10; + exp->_p++; + while(isdigit(*exp->_p)) { + ret = ret*10+(*exp->_p++-'0'); + if(positions==1000000000) trex_error(exp,_SC("overflow in numeric constant")); + positions *= 10; + }; + return ret; +} + +static int trex_element(TRex *exp) +{ + int ret = -1; + switch(*exp->_p) + { + case '(': { + int expr,newn; + exp->_p++; + + + if(*exp->_p =='?') { + exp->_p++; + trex_expect(exp,':'); + expr = trex_newnode(exp,OP_NOCAPEXPR); + } + else + expr = trex_newnode(exp,OP_EXPR); + newn = trex_list(exp); + exp->_nodes[expr].left = newn; + ret = expr; + trex_expect(exp,')'); + } + break; + case '[': + exp->_p++; + ret = trex_class(exp); + trex_expect(exp,']'); + break; + case TREX_SYMBOL_END_OF_STRING: exp->_p++; ret = trex_newnode(exp,OP_EOL);break; + case TREX_SYMBOL_ANY_CHAR: exp->_p++; ret = trex_newnode(exp,OP_DOT);break; + default: + ret = trex_charnode(exp,TRex_False); + break; + } + + { + TRexBool isgreedy = TRex_False; + unsigned short p0 = 0, p1 = 0; + switch(*exp->_p){ + case TREX_SYMBOL_GREEDY_ZERO_OR_MORE: p0 = 0; p1 = 0xFFFF; exp->_p++; isgreedy = TRex_True; break; + case TREX_SYMBOL_GREEDY_ONE_OR_MORE: p0 = 1; p1 = 0xFFFF; exp->_p++; isgreedy = TRex_True; break; + case TREX_SYMBOL_GREEDY_ZERO_OR_ONE: p0 = 0; p1 = 1; exp->_p++; isgreedy = TRex_True; break; + case '{': + exp->_p++; + if(!isdigit(*exp->_p)) trex_error(exp,_SC("number expected")); + p0 = (unsigned short)trex_parsenumber(exp); + /*******************************/ + switch(*exp->_p) { + case '}': + p1 = p0; exp->_p++; + break; + case ',': + exp->_p++; + p1 = 0xFFFF; + if(isdigit(*exp->_p)){ + p1 = (unsigned short)trex_parsenumber(exp); + } + trex_expect(exp,'}'); + break; + default: + trex_error(exp,_SC(", or } expected")); + } + /*******************************/ + isgreedy = TRex_True; + break; + + } + 
if(isgreedy) { + int nnode = trex_newnode(exp,OP_GREEDY); + exp->_nodes[nnode].left = ret; + exp->_nodes[nnode].right = ((p0)<<16)|p1; + ret = nnode; + } + } + if((*exp->_p != TREX_SYMBOL_BRANCH) && (*exp->_p != ')') && (*exp->_p != TREX_SYMBOL_GREEDY_ZERO_OR_MORE) && (*exp->_p != TREX_SYMBOL_GREEDY_ONE_OR_MORE) && (*exp->_p != '\0')) { + int nnode = trex_element(exp); + exp->_nodes[ret].next = nnode; + } + + return ret; +} + +static int trex_list(TRex *exp) +{ + int ret=-1,e; + if(*exp->_p == TREX_SYMBOL_BEGINNING_OF_STRING) { + exp->_p++; + ret = trex_newnode(exp,OP_BOL); + } + e = trex_element(exp); + if(ret != -1) { + exp->_nodes[ret].next = e; + } + else ret = e; + + if(*exp->_p == TREX_SYMBOL_BRANCH) { + int temp,tright; + exp->_p++; + temp = trex_newnode(exp,OP_OR); + exp->_nodes[temp].left = ret; + tright = trex_list(exp); + exp->_nodes[temp].right = tright; + ret = temp; + } + return ret; +} + +static TRexBool trex_matchcclass(int cclass,TRexChar c) +{ + switch(cclass) { + case 'a': return isalpha(c)?TRex_True:TRex_False; + case 'A': return !isalpha(c)?TRex_True:TRex_False; + case 'w': return (isalnum(c) || c == '_')?TRex_True:TRex_False; + case 'W': return (!isalnum(c) && c != '_')?TRex_True:TRex_False; + case 's': return isspace(c)?TRex_True:TRex_False; + case 'S': return !isspace(c)?TRex_True:TRex_False; + case 'd': return isdigit(c)?TRex_True:TRex_False; + case 'D': return !isdigit(c)?TRex_True:TRex_False; + case 'x': return isxdigit(c)?TRex_True:TRex_False; + case 'X': return !isxdigit(c)?TRex_True:TRex_False; + case 'c': return iscntrl(c)?TRex_True:TRex_False; + case 'C': return !iscntrl(c)?TRex_True:TRex_False; + case 'p': return ispunct(c)?TRex_True:TRex_False; + case 'P': return !ispunct(c)?TRex_True:TRex_False; + case 'l': return islower(c)?TRex_True:TRex_False; + case 'u': return isupper(c)?TRex_True:TRex_False; + } + return TRex_False; /*cannot happen*/ +} + +static TRexBool trex_matchclass(TRex* exp,TRexNode *node,TRexChar c) +{ + do { + 
switch(node->type) { + case OP_RANGE: + if (exp->_flags & TREX_ICASE) + { + if(c >= toupper(node->left) && c <= toupper(node->right)) return TRex_True; + if(c >= tolower(node->left) && c <= tolower(node->right)) return TRex_True; + } + else + { + if(c >= node->left && c <= node->right) return TRex_True; + } + break; + case OP_CCLASS: + if(trex_matchcclass(node->left,c)) return TRex_True; + break; + default: + if (exp->_flags & TREX_ICASE) + { + if (c == tolower(node->type) || c == toupper(node->type)) return TRex_True; + } + else + { + if(c == node->type)return TRex_True; + } + + } + } while((node->next != -1) && (node = &exp->_nodes[node->next])); + return TRex_False; +} + +static const TRexChar *trex_matchnode(TRex* exp,TRexNode *node,const TRexChar *str,TRexNode *next) +{ + + TRexNodeType type = node->type; + switch(type) { + case OP_GREEDY: { + //TRexNode *greedystop = (node->next != -1) ? &exp->_nodes[node->next] : NULL; + TRexNode *greedystop = NULL; + int p0 = (node->right >> 16)&0x0000FFFF, p1 = node->right&0x0000FFFF, nmaches = 0; + const TRexChar *s=str, *good = str; + + if(node->next != -1) { + greedystop = &exp->_nodes[node->next]; + } + else { + greedystop = next; + } + + while((nmaches == 0xFFFF || nmaches < p1)) { + + const TRexChar *stop; + if(!(s = trex_matchnode(exp,&exp->_nodes[node->left],s,greedystop))) + break; + nmaches++; + good=s; + if(greedystop) { + //checks that 0 matches satisfy the expression(if so skips) + //if not would always stop(for instance if is a '?') + if(greedystop->type != OP_GREEDY || + (greedystop->type == OP_GREEDY && ((greedystop->right >> 16)&0x0000FFFF) != 0)) + { + TRexNode *gnext = NULL; + if(greedystop->next != -1) { + gnext = &exp->_nodes[greedystop->next]; + }else if(next && next->next != -1){ + gnext = &exp->_nodes[next->next]; + } + stop = trex_matchnode(exp,greedystop,s,gnext); + if(stop) { + //if satisfied stop it + if(p0 == p1 && p0 == nmaches) break; + else if(nmaches >= p0 && p1 == 0xFFFF) break; + else 
if(nmaches >= p0 && nmaches <= p1) break; + } + } + } + + if(s >= exp->_eol) + break; + } + if(p0 == p1 && p0 == nmaches) return good; + else if(nmaches >= p0 && p1 == 0xFFFF) return good; + else if(nmaches >= p0 && nmaches <= p1) return good; + return NULL; + } + case OP_OR: { + const TRexChar *asd = str; + TRexNode *temp=&exp->_nodes[node->left]; + while( (asd = trex_matchnode(exp,temp,asd,NULL)) ) { + if(temp->next != -1) + temp = &exp->_nodes[temp->next]; + else + return asd; + } + asd = str; + temp = &exp->_nodes[node->right]; + while( (asd = trex_matchnode(exp,temp,asd,NULL)) ) { + if(temp->next != -1) + temp = &exp->_nodes[temp->next]; + else + return asd; + } + return NULL; + break; + } + case OP_EXPR: + case OP_NOCAPEXPR:{ + TRexNode *n = &exp->_nodes[node->left]; + const TRexChar *cur = str; + int capture = -1; + if(node->type != OP_NOCAPEXPR && node->right == exp->_currsubexp) { + capture = exp->_currsubexp; + exp->_matches[capture].begin = cur; + exp->_currsubexp++; + } + + do { + TRexNode *subnext = NULL; + if(n->next != -1) { + subnext = &exp->_nodes[n->next]; + }else { + subnext = next; + } + if(!(cur = trex_matchnode(exp,n,cur,subnext))) { + if(capture != -1){ + exp->_matches[capture].begin = 0; + exp->_matches[capture].len = 0; + } + return NULL; + } + } while((n->next != -1) && (n = &exp->_nodes[n->next])); + + if(capture != -1) + exp->_matches[capture].len = cur - exp->_matches[capture].begin; + return cur; + } + case OP_WB: + if((str == exp->_bol && !isspace(*str)) + || ((str == exp->_eol && !isspace(*(str-1)))) + || ((!isspace(*str) && isspace(*(str+1)))) + || ((isspace(*str) && !isspace(*(str+1)))) ) { + return (node->left == 'b')?str:NULL; + } + return (node->left == 'b')?NULL:str; + case OP_BOL: + if(str == exp->_bol) return str; + return NULL; + case OP_EOL: + if(str == exp->_eol) return str; + return NULL; + case OP_DOT: + str++; + return str; + case OP_NCLASS: + case OP_CLASS: + if(trex_matchclass(exp,&exp->_nodes[node->left],*str)?(type 
== OP_CLASS?TRex_True:TRex_False):(type == OP_NCLASS?TRex_True:TRex_False)) { + str++; + return str; + } + return NULL; + case OP_CCLASS: + if(trex_matchcclass(node->left,*str)) { + str++; + return str; + } + return NULL; + default: /* char */ + if (exp->_flags & TREX_ICASE) + { + if(*str != tolower(node->type) && *str != toupper(node->type)) return NULL; + } + else + { + if (*str != node->type) return NULL; + } + str++; + return str; + } + return NULL; +} + +/* public api */ +TRex *trex_compile(const TRexChar *pattern,const TRexChar **error,int flags) +{ + TRex *exp = (TRex *)malloc(sizeof(TRex)); + exp->_eol = exp->_bol = NULL; + exp->_p = pattern; + exp->_nallocated = (int)scstrlen(pattern) * sizeof(TRexChar); + exp->_nodes = (TRexNode *)malloc(exp->_nallocated * sizeof(TRexNode)); + exp->_nsize = 0; + exp->_matches = 0; + exp->_nsubexpr = 0; + exp->_first = trex_newnode(exp,OP_EXPR); + exp->_error = error; + exp->_jmpbuf = malloc(sizeof(jmp_buf)); + exp->_flags = flags; + if(setjmp(*((jmp_buf*)exp->_jmpbuf)) == 0) { + int res = trex_list(exp); + exp->_nodes[exp->_first].left = res; + if(*exp->_p!='\0') + trex_error(exp,_SC("unexpected character")); +#ifdef _DEBUG + { + int nsize,i; + TRexNode *t; + nsize = exp->_nsize; + t = &exp->_nodes[0]; + scprintf(_SC("\n")); + for(i = 0;i < nsize; i++) { + if(exp->_nodes[i].type>MAX_CHAR) + scprintf(_SC("[%02d] %10s "),i,g_nnames[exp->_nodes[i].type-MAX_CHAR]); + else + scprintf(_SC("[%02d] %10c "),i,exp->_nodes[i].type); + scprintf(_SC("left %02d right %02d next %02d\n"),exp->_nodes[i].left,exp->_nodes[i].right,exp->_nodes[i].next); + } + scprintf(_SC("\n")); + } +#endif + exp->_matches = (TRexMatch *) malloc(exp->_nsubexpr * sizeof(TRexMatch)); + memset(exp->_matches,0,exp->_nsubexpr * sizeof(TRexMatch)); + } + else{ + trex_free(exp); + return NULL; + } + return exp; +} + +void trex_free(TRex *exp) +{ + if(exp) { + if(exp->_nodes) free(exp->_nodes); + if(exp->_jmpbuf) free(exp->_jmpbuf); + if(exp->_matches) 
free(exp->_matches); + free(exp); + } +} + +TRexBool trex_match(TRex* exp,const TRexChar* text) +{ + const TRexChar* res = NULL; + exp->_bol = text; + exp->_eol = text + scstrlen(text); + exp->_currsubexp = 0; + res = trex_matchnode(exp,exp->_nodes,text,NULL); + if(res == NULL || res != exp->_eol) + return TRex_False; + return TRex_True; +} + +TRexBool trex_searchrange(TRex* exp,const TRexChar* text_begin,const TRexChar* text_end,const TRexChar** out_begin, const TRexChar** out_end) +{ + const TRexChar *cur = NULL; + int node = exp->_first; + if(text_begin >= text_end) return TRex_False; + exp->_bol = text_begin; + exp->_eol = text_end; + do { + cur = text_begin; + while(node != -1) { + exp->_currsubexp = 0; + cur = trex_matchnode(exp,&exp->_nodes[node],cur,NULL); + if(!cur) + break; + node = exp->_nodes[node].next; + } + text_begin++; + } while(cur == NULL && text_begin != text_end); + + if(cur == NULL) + return TRex_False; + + --text_begin; + + if(out_begin) *out_begin = text_begin; + if(out_end) *out_end = cur; + return TRex_True; +} + +TRexBool trex_search(TRex* exp,const TRexChar* text, const TRexChar** out_begin, const TRexChar** out_end) +{ + return trex_searchrange(exp,text,text + scstrlen(text),out_begin,out_end); +} + +int trex_getsubexpcount(TRex* exp) +{ + return exp->_nsubexpr; +} + +TRexBool trex_getsubexp(TRex* exp, int n, TRexMatch *subexp) +{ + if( n<0 || n >= exp->_nsubexpr) return TRex_False; + *subexp = exp->_matches[n]; + return TRex_True; +} +/******************************************************************************* + * This file is part of the argtable3 library. + * + * Copyright (C) 1998-2001,2003-2011,2013 Stewart Heitmann + * + * All rights reserved. 
+ * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of STEWART HEITMANN nor the names of its contributors + * may be used to endorse or promote products derived from this software + * without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL STEWART HEITMANN BE LIABLE FOR ANY DIRECT, + * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ ******************************************************************************/ + +#include + +#include "argtable3.h" + + +static void arg_str_resetfn(struct arg_str *parent) +{ + ARG_TRACE(("%s:resetfn(%p)\n", __FILE__, parent)); + parent->count = 0; +} + + +static int arg_str_scanfn(struct arg_str *parent, const char *argval) +{ + int errorcode = 0; + + if (parent->count == parent->hdr.maxcount) + { + /* maximum number of arguments exceeded */ + errorcode = EMAXCOUNT; + } + else if (!argval) + { + /* a valid argument with no argument value was given. */ + /* This happens when an optional argument value was invoked. */ + /* leave parent arguiment value unaltered but still count the argument. */ + parent->count++; + } + else + { + parent->sval[parent->count++] = argval; + } + + ARG_TRACE(("%s:scanfn(%p) returns %d\n", __FILE__, parent, errorcode)); + return errorcode; +} + + +static int arg_str_checkfn(struct arg_str *parent) +{ + int errorcode = (parent->count < parent->hdr.mincount) ? EMINCOUNT : 0; + + ARG_TRACE(("%s:checkfn(%p) returns %d\n", __FILE__, parent, errorcode)); + return errorcode; +} + + +static void arg_str_errorfn( + struct arg_str *parent, + FILE *fp, + int errorcode, + const char *argval, + const char *progname) +{ + const char *shortopts = parent->hdr.shortopts; + const char *longopts = parent->hdr.longopts; + const char *datatype = parent->hdr.datatype; + + /* make argval NULL safe */ + argval = argval ? 
argval : ""; + + fprintf(fp, "%s: ", progname); + switch(errorcode) + { + case EMINCOUNT: + fputs("missing option ", fp); + arg_print_option(fp, shortopts, longopts, datatype, "\n"); + break; + + case EMAXCOUNT: + fputs("excess option ", fp); + arg_print_option(fp, shortopts, longopts, argval, "\n"); + break; + } +} + + +struct arg_str * arg_str0( + const char *shortopts, + const char *longopts, + const char *datatype, + const char *glossary) +{ + return arg_strn(shortopts, longopts, datatype, 0, 1, glossary); +} + + +struct arg_str * arg_str1( + const char *shortopts, + const char *longopts, + const char *datatype, + const char *glossary) +{ + return arg_strn(shortopts, longopts, datatype, 1, 1, glossary); +} + + +struct arg_str * arg_strn( + const char *shortopts, + const char *longopts, + const char *datatype, + int mincount, + int maxcount, + const char *glossary) +{ + size_t nbytes; + struct arg_str *result; + + /* should not allow this stupid error */ + /* we should return an error code warning this logic error */ + /* foolproof things by ensuring maxcount is not less than mincount */ + maxcount = (maxcount < mincount) ? mincount : maxcount; + + nbytes = sizeof(struct arg_str) /* storage for struct arg_str */ + + maxcount * sizeof(char *); /* storage for sval[maxcount] array */ + + result = (struct arg_str *)malloc(nbytes); + if (result) + { + int i; + + /* init the arg_hdr struct */ + result->hdr.flag = ARG_HASVALUE; + result->hdr.shortopts = shortopts; + result->hdr.longopts = longopts; + result->hdr.datatype = datatype ? 
datatype : ""; + result->hdr.glossary = glossary; + result->hdr.mincount = mincount; + result->hdr.maxcount = maxcount; + result->hdr.parent = result; + result->hdr.resetfn = (arg_resetfn *)arg_str_resetfn; + result->hdr.scanfn = (arg_scanfn *)arg_str_scanfn; + result->hdr.checkfn = (arg_checkfn *)arg_str_checkfn; + result->hdr.errorfn = (arg_errorfn *)arg_str_errorfn; + + /* store the sval[maxcount] array immediately after the arg_str struct */ + result->sval = (const char * *)(result + 1); + result->count = 0; + + /* foolproof the string pointers by initialising them to reference empty strings */ + for (i = 0; i < maxcount; i++) + result->sval[i] = ""; + } + + ARG_TRACE(("arg_strn() returns %p\n", result)); + return result; +} +/******************************************************************************* + * This file is part of the argtable3 library. + * + * Copyright (C) 1998-2001,2003-2011,2013 Stewart Heitmann + * + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of STEWART HEITMANN nor the names of its contributors + * may be used to endorse or promote products derived from this software + * without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. 
IN NO EVENT SHALL STEWART HEITMANN BE LIABLE FOR ANY DIRECT, + * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + ******************************************************************************/ + +#include +#include +#include +#include + +#include "argtable3.h" + +static +void arg_register_error(struct arg_end *end, + void *parent, + int error, + const char *argval) +{ + /* printf("arg_register_error(%p,%p,%d,%s)\n",end,parent,error,argval); */ + if (end->count < end->hdr.maxcount) + { + end->error[end->count] = error; + end->parent[end->count] = parent; + end->argval[end->count] = argval; + end->count++; + } + else + { + end->error[end->hdr.maxcount - 1] = ARG_ELIMIT; + end->parent[end->hdr.maxcount - 1] = end; + end->argval[end->hdr.maxcount - 1] = NULL; + } +} + + +/* + * Return index of first table entry with a matching short option + * or -1 if no match was found. 
+ */ +static +int find_shortoption(struct arg_hdr * *table, char shortopt) +{ + int tabindex; + for(tabindex = 0; !(table[tabindex]->flag & ARG_TERMINATOR); tabindex++) + { + if (table[tabindex]->shortopts && + strchr(table[tabindex]->shortopts, shortopt)) + return tabindex; + } + return -1; +} + + +struct longoptions +{ + int getoptval; + int noptions; + struct option *options; +}; + +#if 0 +static +void dump_longoptions(struct longoptions * longoptions) +{ + int i; + printf("getoptval = %d\n", longoptions->getoptval); + printf("noptions = %d\n", longoptions->noptions); + for (i = 0; i < longoptions->noptions; i++) + { + printf("options[%d].name = \"%s\"\n", + i, + longoptions->options[i].name); + printf("options[%d].has_arg = %d\n", i, longoptions->options[i].has_arg); + printf("options[%d].flag = %p\n", i, longoptions->options[i].flag); + printf("options[%d].val = %d\n", i, longoptions->options[i].val); + } +} +#endif + +static +struct longoptions * alloc_longoptions(struct arg_hdr * *table) +{ + struct longoptions *result; + size_t nbytes; + int noptions = 1; + size_t longoptlen = 0; + int tabindex; + + /* + * Determine the total number of option structs required + * by counting the number of comma separated long options + * in all table entries and return the count in noptions. + * note: noptions starts at 1 not 0 because we getoptlong + * requires a NULL option entry to terminate the option array. + * While we are at it, count the number of chars required + * to store private copies of all the longoption strings + * and return that count in logoptlen. + */ + tabindex = 0; + do + { + const char *longopts = table[tabindex]->longopts; + longoptlen += (longopts ? 
strlen(longopts) : 0) + 1; + while (longopts) + { + noptions++; + longopts = strchr(longopts + 1, ','); + } + } while(!(table[tabindex++]->flag & ARG_TERMINATOR)); + /*printf("%d long options consuming %d chars in total\n",noptions,longoptlen);*/ + + + /* allocate storage for return data structure as: */ + /* (struct longoptions) + (struct options)[noptions] + char[longoptlen] */ + nbytes = sizeof(struct longoptions) + + sizeof(struct option) * noptions + + longoptlen; + result = (struct longoptions *)malloc(nbytes); + if (result) + { + int option_index = 0; + char *store; + + result->getoptval = 0; + result->noptions = noptions; + result->options = (struct option *)(result + 1); + store = (char *)(result->options + noptions); + + for(tabindex = 0; !(table[tabindex]->flag & ARG_TERMINATOR); tabindex++) + { + const char *longopts = table[tabindex]->longopts; + + while(longopts && *longopts) + { + char *storestart = store; + + /* copy progressive longopt strings into the store */ + while (*longopts != 0 && *longopts != ',') + *store++ = *longopts++; + *store++ = 0; + if (*longopts == ',') + longopts++; + /*fprintf(stderr,"storestart=\"%s\"\n",storestart);*/ + + result->options[option_index].name = storestart; + result->options[option_index].flag = &(result->getoptval); + result->options[option_index].val = tabindex; + if (table[tabindex]->flag & ARG_HASOPTVALUE) + result->options[option_index].has_arg = 2; + else if (table[tabindex]->flag & ARG_HASVALUE) + result->options[option_index].has_arg = 1; + else + result->options[option_index].has_arg = 0; + + option_index++; + } + } + /* terminate the options array with a zero-filled entry */ + result->options[option_index].name = 0; + result->options[option_index].has_arg = 0; + result->options[option_index].flag = 0; + result->options[option_index].val = 0; + } + + /*dump_longoptions(result);*/ + return result; +} + +static +char * alloc_shortoptions(struct arg_hdr * *table) +{ + char *result; + size_t len = 2; + int 
tabindex; + + /* determine the total number of option chars required */ + for(tabindex = 0; !(table[tabindex]->flag & ARG_TERMINATOR); tabindex++) + { + struct arg_hdr *hdr = table[tabindex]; + len += 3 * (hdr->shortopts ? strlen(hdr->shortopts) : 0); + } + + result = malloc(len); + if (result) + { + char *res = result; + + /* add a leading ':' so getopt return codes distinguish */ + /* unrecognised option and options missing argument values */ + *res++ = ':'; + + for(tabindex = 0; !(table[tabindex]->flag & ARG_TERMINATOR); tabindex++) + { + struct arg_hdr *hdr = table[tabindex]; + const char *shortopts = hdr->shortopts; + while(shortopts && *shortopts) + { + *res++ = *shortopts++; + if (hdr->flag & ARG_HASVALUE) + *res++ = ':'; + if (hdr->flag & ARG_HASOPTVALUE) + *res++ = ':'; + } + } + /* null terminate the string */ + *res = 0; + } + + /*printf("alloc_shortoptions() returns \"%s\"\n",(result?result:"NULL"));*/ + return result; +} + + +/* return index of the table terminator entry */ +static +int arg_endindex(struct arg_hdr * *table) +{ + int tabindex = 0; + while (!(table[tabindex]->flag & ARG_TERMINATOR)) + tabindex++; + return tabindex; +} + + +static +void arg_parse_tagged(int argc, + char * *argv, + struct arg_hdr * *table, + struct arg_end *endtable) +{ + struct longoptions *longoptions; + char *shortoptions; + int copt; + + /*printf("arg_parse_tagged(%d,%p,%p,%p)\n",argc,argv,table,endtable);*/ + + /* allocate short and long option arrays for the given opttable[]. */ + /* if the allocs fail then put an error msg in the last table entry. 
*/ + longoptions = alloc_longoptions(table); + shortoptions = alloc_shortoptions(table); + if (!longoptions || !shortoptions) + { + /* one or both memory allocs failed */ + arg_register_error(endtable, endtable, ARG_EMALLOC, NULL); + /* free anything that was allocated (this is null safe) */ + free(shortoptions); + free(longoptions); + return; + } + + /*dump_longoptions(longoptions);*/ + + /* reset getopts internal option-index to zero, and disable error reporting */ + optind = 0; + opterr = 0; + + /* fetch and process args using getopt_long */ + while( (copt = + getopt_long(argc, argv, shortoptions, longoptions->options, + NULL)) != -1) + { + /* + printf("optarg='%s'\n",optarg); + printf("optind=%d\n",optind); + printf("copt=%c\n",(char)copt); + printf("optopt=%c (%d)\n",optopt, (int)(optopt)); + */ + switch(copt) + { + case 0: + { + int tabindex = longoptions->getoptval; + void *parent = table[tabindex]->parent; + /*printf("long option detected from argtable[%d]\n", tabindex);*/ + if (optarg && optarg[0] == 0 && + (table[tabindex]->flag & ARG_HASVALUE)) + { + /* printf(": long option %s requires an argument\n",argv[optind-1]); */ + arg_register_error(endtable, endtable, ARG_EMISSARG, + argv[optind - 1]); + /* continue to scan the (empty) argument value to enforce argument count checking */ + } + if (table[tabindex]->scanfn) + { + int errorcode = table[tabindex]->scanfn(parent, optarg); + if (errorcode != 0) + arg_register_error(endtable, parent, errorcode, optarg); + } + } + break; + + case '?': + /* + * getopt_long() found an unrecognised short option. 
+ * if it was a short option its value is in optopt + * if it was a long option then optopt=0 + */ + switch (optopt) + { + case 0: + /*printf("?0 unrecognised long option %s\n",argv[optind-1]);*/ + arg_register_error(endtable, endtable, ARG_ELONGOPT, + argv[optind - 1]); + break; + default: + /*printf("?* unrecognised short option '%c'\n",optopt);*/ + arg_register_error(endtable, endtable, optopt, NULL); + break; + } + break; + + case ':': + /* + * getopt_long() found an option with its argument missing. + */ + /*printf(": option %s requires an argument\n",argv[optind-1]); */ + arg_register_error(endtable, endtable, ARG_EMISSARG, + argv[optind - 1]); + break; + + default: + { + /* getopt_long() found a valid short option */ + int tabindex = find_shortoption(table, (char)copt); + /*printf("short option detected from argtable[%d]\n", tabindex);*/ + if (tabindex == -1) + { + /* should never get here - but handle it just in case */ + /*printf("unrecognised short option %d\n",copt);*/ + arg_register_error(endtable, endtable, copt, NULL); + } + else + { + if (table[tabindex]->scanfn) + { + void *parent = table[tabindex]->parent; + int errorcode = table[tabindex]->scanfn(parent, optarg); + if (errorcode != 0) + arg_register_error(endtable, parent, errorcode, optarg); + } + } + break; + } + } + } + + free(shortoptions); + free(longoptions); +} + + +static +void arg_parse_untagged(int argc, + char * *argv, + struct arg_hdr * *table, + struct arg_end *endtable) +{ + int tabindex = 0; + int errorlast = 0; + const char *optarglast = NULL; + void *parentlast = NULL; + + /*printf("arg_parse_untagged(%d,%p,%p,%p)\n",argc,argv,table,endtable);*/ + while (!(table[tabindex]->flag & ARG_TERMINATOR)) + { + void *parent; + int errorcode; + + /* if we have exhausted our argv[optind] entries then we have finished */ + if (optind >= argc) + { + /*printf("arg_parse_untagged(): argv[] exhausted\n");*/ + return; + } + + /* skip table entries with non-null long or short options (they are not 
untagged entries) */ + if (table[tabindex]->longopts || table[tabindex]->shortopts) + { + /*printf("arg_parse_untagged(): skipping argtable[%d] (tagged argument)\n",tabindex);*/ + tabindex++; + continue; + } + + /* skip table entries with NULL scanfn */ + if (!(table[tabindex]->scanfn)) + { + /*printf("arg_parse_untagged(): skipping argtable[%d] (NULL scanfn)\n",tabindex);*/ + tabindex++; + continue; + } + + /* attempt to scan the current argv[optind] with the current */ + /* table[tabindex] entry. If it succeeds then keep it, otherwise */ + /* try again with the next table[] entry. */ + parent = table[tabindex]->parent; + errorcode = table[tabindex]->scanfn(parent, argv[optind]); + if (errorcode == 0) + { + /* success, move onto next argv[optind] but stay with same table[tabindex] */ + /*printf("arg_parse_untagged(): argtable[%d] successfully matched\n",tabindex);*/ + optind++; + + /* clear the last tentative error */ + errorlast = 0; + } + else + { + /* failure, try same argv[optind] with next table[tabindex] entry */ + /*printf("arg_parse_untagged(): argtable[%d] failed match\n",tabindex);*/ + tabindex++; + + /* remember this as a tentative error we may wish to reinstate later */ + errorlast = errorcode; + optarglast = argv[optind]; + parentlast = parent; + } + + } + + /* if a tenative error still remains at this point then register it as a proper error */ + if (errorlast) + { + arg_register_error(endtable, parentlast, errorlast, optarglast); + optind++; + } + + /* only get here when not all argv[] entries were consumed */ + /* register an error for each unused argv[] entry */ + while (optind < argc) + { + /*printf("arg_parse_untagged(): argv[%d]=\"%s\" not consumed\n",optind,argv[optind]);*/ + arg_register_error(endtable, endtable, ARG_ENOMATCH, argv[optind++]); + } + + return; +} + + +static +void arg_parse_check(struct arg_hdr * *table, struct arg_end *endtable) +{ + int tabindex = 0; + /* printf("arg_parse_check()\n"); */ + do + { + if 
(table[tabindex]->checkfn) + { + void *parent = table[tabindex]->parent; + int errorcode = table[tabindex]->checkfn(parent); + if (errorcode != 0) + arg_register_error(endtable, parent, errorcode, NULL); + } + } while(!(table[tabindex++]->flag & ARG_TERMINATOR)); +} + + +static +void arg_reset(void * *argtable) +{ + struct arg_hdr * *table = (struct arg_hdr * *)argtable; + int tabindex = 0; + /*printf("arg_reset(%p)\n",argtable);*/ + do + { + if (table[tabindex]->resetfn) + table[tabindex]->resetfn(table[tabindex]->parent); + } while(!(table[tabindex++]->flag & ARG_TERMINATOR)); +} + + +int arg_parse(int argc, char * *argv, void * *argtable) +{ + struct arg_hdr * *table = (struct arg_hdr * *)argtable; + struct arg_end *endtable; + int endindex; + char * *argvcopy = NULL; + + /*printf("arg_parse(%d,%p,%p)\n",argc,argv,argtable);*/ + + /* reset any argtable data from previous invocations */ + arg_reset(argtable); + + /* locate the first end-of-table marker within the array */ + endindex = arg_endindex(table); + endtable = (struct arg_end *)table[endindex]; + + /* Special case of argc==0. This can occur on Texas Instruments DSP. */ + /* Failure to trap this case results in an unwanted NULL result from */ + /* the malloc for argvcopy (next code block). */ + if (argc == 0) + { + /* We must still perform post-parse checks despite the absence of command line arguments */ + arg_parse_check(table, endtable); + + /* Now we are finished */ + return endtable->count; + } + + argvcopy = (char **)malloc(sizeof(char *) * (argc + 1)); + if (argvcopy) + { + int i; + + /* + Fill in the local copy of argv[]. We need a local copy + because getopt rearranges argv[] which adversely affects + susbsequent parsing attempts. 
+ */ + for (i = 0; i < argc; i++) + argvcopy[i] = argv[i]; + + argvcopy[argc] = NULL; + + /* parse the command line (local copy) for tagged options */ + arg_parse_tagged(argc, argvcopy, table, endtable); + + /* parse the command line (local copy) for untagged options */ + arg_parse_untagged(argc, argvcopy, table, endtable); + + /* if no errors so far then perform post-parse checks otherwise dont bother */ + if (endtable->count == 0) + arg_parse_check(table, endtable); + + /* release the local copt of argv[] */ + free(argvcopy); + } + else + { + /* memory alloc failed */ + arg_register_error(endtable, endtable, ARG_EMALLOC, NULL); + } + + return endtable->count; +} + + +/* + * Concatenate contents of src[] string onto *pdest[] string. + * The *pdest pointer is altered to point to the end of the + * target string and *pndest is decremented by the same number + * of chars. + * Does not append more than *pndest chars into *pdest[] + * so as to prevent buffer overruns. + * Its something like strncat() but more efficient for repeated + * calls on the same destination string. + * Example of use: + * char dest[30] = "good" + * size_t ndest = sizeof(dest); + * char *pdest = dest; + * arg_char(&pdest,"bye ",&ndest); + * arg_char(&pdest,"cruel ",&ndest); + * arg_char(&pdest,"world!",&ndest); + * Results in: + * dest[] == "goodbye cruel world!" 
+ * ndest == 10 + */ +static +void arg_cat(char * *pdest, const char *src, size_t *pndest) +{ + char *dest = *pdest; + char *end = dest + *pndest; + + /*locate null terminator of dest string */ + while(dest < end && *dest != 0) + dest++; + + /* concat src string to dest string */ + while(dest < end && *src != 0) + *dest++ = *src++; + + /* null terminate dest string */ + *dest = 0; + + /* update *pdest and *pndest */ + *pndest = end - dest; + *pdest = dest; +} + + +static +void arg_cat_option(char *dest, + size_t ndest, + const char *shortopts, + const char *longopts, + const char *datatype, + int optvalue) +{ + if (shortopts) + { + char option[3]; + + /* note: option array[] is initialiazed dynamically here to satisfy */ + /* a deficiency in the watcom compiler wrt static array initializers. */ + option[0] = '-'; + option[1] = shortopts[0]; + option[2] = 0; + + arg_cat(&dest, option, &ndest); + if (datatype) + { + arg_cat(&dest, " ", &ndest); + if (optvalue) + { + arg_cat(&dest, "[", &ndest); + arg_cat(&dest, datatype, &ndest); + arg_cat(&dest, "]", &ndest); + } + else + arg_cat(&dest, datatype, &ndest); + } + } + else if (longopts) + { + size_t ncspn; + + /* add "--" tag prefix */ + arg_cat(&dest, "--", &ndest); + + /* add comma separated option tag */ + ncspn = strcspn(longopts, ","); + strncat(dest, longopts, (ncspn < ndest) ? ncspn : ndest); + + if (datatype) + { + arg_cat(&dest, "=", &ndest); + if (optvalue) + { + arg_cat(&dest, "[", &ndest); + arg_cat(&dest, datatype, &ndest); + arg_cat(&dest, "]", &ndest); + } + else + arg_cat(&dest, datatype, &ndest); + } + } + else if (datatype) + { + if (optvalue) + { + arg_cat(&dest, "[", &ndest); + arg_cat(&dest, datatype, &ndest); + arg_cat(&dest, "]", &ndest); + } + else + arg_cat(&dest, datatype, &ndest); + } +} + +static +void arg_cat_optionv(char *dest, + size_t ndest, + const char *shortopts, + const char *longopts, + const char *datatype, + int optvalue, + const char *separator) +{ + separator = separator ? 
separator : ""; + + if (shortopts) + { + const char *c = shortopts; + while(*c) + { + /* "-a|-b|-c" */ + char shortopt[3]; + + /* note: shortopt array[] is initialiazed dynamically here to satisfy */ + /* a deficiency in the watcom compiler wrt static array initializers. */ + shortopt[0] = '-'; + shortopt[1] = *c; + shortopt[2] = 0; + + arg_cat(&dest, shortopt, &ndest); + if (*++c) + arg_cat(&dest, separator, &ndest); + } + } + + /* put separator between long opts and short opts */ + if (shortopts && longopts) + arg_cat(&dest, separator, &ndest); + + if (longopts) + { + const char *c = longopts; + while(*c) + { + size_t ncspn; + + /* add "--" tag prefix */ + arg_cat(&dest, "--", &ndest); + + /* add comma separated option tag */ + ncspn = strcspn(c, ","); + strncat(dest, c, (ncspn < ndest) ? ncspn : ndest); + c += ncspn; + + /* add given separator in place of comma */ + if (*c == ',') + { + arg_cat(&dest, separator, &ndest); + c++; + } + } + } + + if (datatype) + { + if (longopts) + arg_cat(&dest, "=", &ndest); + else if (shortopts) + arg_cat(&dest, " ", &ndest); + + if (optvalue) + { + arg_cat(&dest, "[", &ndest); + arg_cat(&dest, datatype, &ndest); + arg_cat(&dest, "]", &ndest); + } + else + arg_cat(&dest, datatype, &ndest); + } +} + + +/* this function should be deprecated because it doesnt consider optional argument values (ARG_HASOPTVALUE) */ +void arg_print_option(FILE *fp, + const char *shortopts, + const char *longopts, + const char *datatype, + const char *suffix) +{ + char syntax[200] = ""; + suffix = suffix ? 
suffix : ""; + + /* there is no way of passing the proper optvalue for optional argument values here, so we must ignore it */ + arg_cat_optionv(syntax, + sizeof(syntax), + shortopts, + longopts, + datatype, + 0, + "|"); + + fputs(syntax, fp); + fputs(suffix, fp); +} + + +/* + * Print a GNU style [OPTION] string in which all short options that + * do not take argument values are presented in abbreviated form, as + * in: -xvfsd, or -xvf[sd], or [-xvsfd] + */ +static +void arg_print_gnuswitch(FILE *fp, struct arg_hdr * *table) +{ + int tabindex; + char *format1 = " -%c"; + char *format2 = " [-%c"; + char *suffix = ""; + + /* print all mandatory switches that are without argument values */ + for(tabindex = 0; + table[tabindex] && !(table[tabindex]->flag & ARG_TERMINATOR); + tabindex++) + { + /* skip optional options */ + if (table[tabindex]->mincount < 1) + continue; + + /* skip non-short options */ + if (table[tabindex]->shortopts == NULL) + continue; + + /* skip options that take argument values */ + if (table[tabindex]->flag & ARG_HASVALUE) + continue; + + /* print the short option (only the first short option char, ignore multiple choices)*/ + fprintf(fp, format1, table[tabindex]->shortopts[0]); + format1 = "%c"; + format2 = "[%c"; + } + + /* print all optional switches that are without argument values */ + for(tabindex = 0; + table[tabindex] && !(table[tabindex]->flag & ARG_TERMINATOR); + tabindex++) + { + /* skip mandatory args */ + if (table[tabindex]->mincount > 0) + continue; + + /* skip args without short options */ + if (table[tabindex]->shortopts == NULL) + continue; + + /* skip args with values */ + if (table[tabindex]->flag & ARG_HASVALUE) + continue; + + /* print first short option */ + fprintf(fp, format2, table[tabindex]->shortopts[0]); + format2 = "%c"; + suffix = "]"; + } + + fprintf(fp, "%s", suffix); +} + + +void arg_print_syntax(FILE *fp, void * *argtable, const char *suffix) +{ + struct arg_hdr * *table = (struct arg_hdr * *)argtable; + int i, 
tabindex; + + /* print GNU style [OPTION] string */ + arg_print_gnuswitch(fp, table); + + /* print remaining options in abbreviated style */ + for(tabindex = 0; + table[tabindex] && !(table[tabindex]->flag & ARG_TERMINATOR); + tabindex++) + { + char syntax[200] = ""; + const char *shortopts, *longopts, *datatype; + + /* skip short options without arg values (they were printed by arg_print_gnu_switch) */ + if (table[tabindex]->shortopts && + !(table[tabindex]->flag & ARG_HASVALUE)) + continue; + + shortopts = table[tabindex]->shortopts; + longopts = table[tabindex]->longopts; + datatype = table[tabindex]->datatype; + arg_cat_option(syntax, + sizeof(syntax), + shortopts, + longopts, + datatype, + table[tabindex]->flag & ARG_HASOPTVALUE); + + if (strlen(syntax) > 0) + { + /* print mandatory instances of this option */ + for (i = 0; i < table[tabindex]->mincount; i++) + fprintf(fp, " %s", syntax); + + /* print optional instances enclosed in "[..]" */ + switch ( table[tabindex]->maxcount - table[tabindex]->mincount ) + { + case 0: + break; + case 1: + fprintf(fp, " [%s]", syntax); + break; + case 2: + fprintf(fp, " [%s] [%s]", syntax, syntax); + break; + default: + fprintf(fp, " [%s]...", syntax); + break; + } + } + } + + if (suffix) + fprintf(fp, "%s", suffix); +} + + +void arg_print_syntaxv(FILE *fp, void * *argtable, const char *suffix) +{ + struct arg_hdr * *table = (struct arg_hdr * *)argtable; + int i, tabindex; + + /* print remaining options in abbreviated style */ + for(tabindex = 0; + table[tabindex] && !(table[tabindex]->flag & ARG_TERMINATOR); + tabindex++) + { + char syntax[200] = ""; + const char *shortopts, *longopts, *datatype; + + shortopts = table[tabindex]->shortopts; + longopts = table[tabindex]->longopts; + datatype = table[tabindex]->datatype; + arg_cat_optionv(syntax, + sizeof(syntax), + shortopts, + longopts, + datatype, + table[tabindex]->flag & ARG_HASOPTVALUE, + "|"); + + /* print mandatory options */ + for (i = 0; i < 
table[tabindex]->mincount; i++) + fprintf(fp, " %s", syntax); + + /* print optional args enclosed in "[..]" */ + switch ( table[tabindex]->maxcount - table[tabindex]->mincount ) + { + case 0: + break; + case 1: + fprintf(fp, " [%s]", syntax); + break; + case 2: + fprintf(fp, " [%s] [%s]", syntax, syntax); + break; + default: + fprintf(fp, " [%s]...", syntax); + break; + } + } + + if (suffix) + fprintf(fp, "%s", suffix); +} + + +void arg_print_glossary(FILE *fp, void * *argtable, const char *format) +{ + struct arg_hdr * *table = (struct arg_hdr * *)argtable; + int tabindex; + + format = format ? format : " %-20s %s\n"; + for (tabindex = 0; !(table[tabindex]->flag & ARG_TERMINATOR); tabindex++) + { + if (table[tabindex]->glossary) + { + char syntax[200] = ""; + const char *shortopts = table[tabindex]->shortopts; + const char *longopts = table[tabindex]->longopts; + const char *datatype = table[tabindex]->datatype; + const char *glossary = table[tabindex]->glossary; + arg_cat_optionv(syntax, + sizeof(syntax), + shortopts, + longopts, + datatype, + table[tabindex]->flag & ARG_HASOPTVALUE, + ", "); + fprintf(fp, format, syntax, glossary); + } + } +} + + +/** + * Print a piece of text formatted, which means in a column with a + * left and a right margin. The lines are wrapped at whitspaces next + * to right margin. The function does not indent the first line, but + * only the following ones. + * + * Example: + * arg_print_formatted( fp, 0, 5, "Some text that doesn't fit." ) + * will result in the following output: + * + * Some + * text + * that + * doesn' + * t fit. + * + * Too long lines will be wrapped in the middle of a word. + * + * arg_print_formatted( fp, 2, 7, "Some text that doesn't fit." ) + * will result in the following output: + * + * Some + * text + * that + * doesn' + * t fit. + * + * As you see, the first line is not indented. This enables output of + * lines, which start in a line where output already happened. 
+ * + * Author: Uli Fouquet + */ +static +void arg_print_formatted( FILE *fp, + const unsigned lmargin, + const unsigned rmargin, + const char *text ) +{ + const unsigned textlen = strlen( text ); + unsigned line_start = 0; + unsigned line_end = textlen + 1; + const unsigned colwidth = (rmargin - lmargin) + 1; + + /* Someone doesn't like us... */ + if ( line_end < line_start ) + { fprintf( fp, "%s\n", text ); } + + while (line_end - 1 > line_start ) + { + /* Eat leading whitespaces. This is essential because while + wrapping lines, there will often be a whitespace at beginning + of line */ + while ( isspace(*(text + line_start)) ) + { line_start++; } + + if ((line_end - line_start) > colwidth ) + { line_end = line_start + colwidth; } + + /* Find last whitespace, that fits into line */ + while ( ( line_end > line_start ) + && ( line_end - line_start > colwidth ) + && !isspace(*(text + line_end))) + { line_end--; } + + /* Do not print trailing whitespace. If this text + has got only one line, line_end now points to the + last char due to initialization. */ + line_end--; + + /* Output line of text */ + while ( line_start < line_end ) + { + fputc(*(text + line_start), fp ); + line_start++; + } + fputc( '\n', fp ); + + /* Initialize another line */ + if ( line_end + 1 < textlen ) + { + unsigned i; + + for (i = 0; i < lmargin; i++ ) + { fputc( ' ', fp ); } + + line_end = textlen; + } + + /* If we have to print another line, get also the last char. */ + line_end++; + + } /* lines of text */ +} + +/** + * Prints the glossary in strict GNU format. 
+ * Differences to arg_print_glossary() are: + * - wraps lines after 80 chars + * - indents lines without shortops + * - does not accept formatstrings + * + * Contributed by Uli Fouquet + */ +void arg_print_glossary_gnu(FILE *fp, void * *argtable ) +{ + struct arg_hdr * *table = (struct arg_hdr * *)argtable; + int tabindex; + + for(tabindex = 0; !(table[tabindex]->flag & ARG_TERMINATOR); tabindex++) + { + if (table[tabindex]->glossary) + { + char syntax[200] = ""; + const char *shortopts = table[tabindex]->shortopts; + const char *longopts = table[tabindex]->longopts; + const char *datatype = table[tabindex]->datatype; + const char *glossary = table[tabindex]->glossary; + + if ( !shortopts && longopts ) + { + /* Indent trailing line by 4 spaces... */ + memset( syntax, ' ', 4 ); + *(syntax + 4) = '\0'; + } + + arg_cat_optionv(syntax, + sizeof(syntax), + shortopts, + longopts, + datatype, + table[tabindex]->flag & ARG_HASOPTVALUE, + ", "); + + /* If syntax fits not into column, print glossary in new line... */ + if ( strlen(syntax) > 25 ) + { + fprintf( fp, " %-25s %s\n", syntax, "" ); + *syntax = '\0'; + } + + fprintf( fp, " %-25s ", syntax ); + arg_print_formatted( fp, 28, 79, glossary ); + } + } /* for each table entry */ + + fputc( '\n', fp ); +} + + +/** + * Checks the argtable[] array for NULL entries and returns 1 + * if any are found, zero otherwise. + */ +int arg_nullcheck(void * *argtable) +{ + struct arg_hdr * *table = (struct arg_hdr * *)argtable; + int tabindex; + /*printf("arg_nullcheck(%p)\n",argtable);*/ + + if (!table) + return 1; + + tabindex = 0; + do + { + /*printf("argtable[%d]=%p\n",tabindex,argtable[tabindex]);*/ + if (!table[tabindex]) + return 1; + } while(!(table[tabindex++]->flag & ARG_TERMINATOR)); + + return 0; +} + + +/* + * arg_free() is deprecated in favour of arg_freetable() due to a flaw in its design. 
+ * The flaw results in memory leak in the (very rare) case that an intermediate + * entry in the argtable array failed its memory allocation while others following + * that entry were still allocated ok. Those subsequent allocations will not be + * deallocated by arg_free(). + * Despite the unlikeliness of the problem occurring, and the even unlikelier event + * that it has any deliterious effect, it is fixed regardless by replacing arg_free() + * with the newer arg_freetable() function. + * We still keep arg_free() for backwards compatibility. + */ +void arg_free(void * *argtable) +{ + struct arg_hdr * *table = (struct arg_hdr * *)argtable; + int tabindex = 0; + int flag; + /*printf("arg_free(%p)\n",argtable);*/ + do + { + /* + if we encounter a NULL entry then somewhat incorrectly we presume + we have come to the end of the array. It isnt strictly true because + an intermediate entry could be NULL with other non-NULL entries to follow. + The subsequent argtable entries would then not be freed as they should. + */ + if (table[tabindex] == NULL) + break; + + flag = table[tabindex]->flag; + free(table[tabindex]); + table[tabindex++] = NULL; + + } while(!(flag & ARG_TERMINATOR)); +} + +/* frees each non-NULL element of argtable[], where n is the size of the number of entries in the array */ +void arg_freetable(void * *argtable, size_t n) +{ + struct arg_hdr * *table = (struct arg_hdr * *)argtable; + size_t tabindex = 0; + /*printf("arg_freetable(%p)\n",argtable);*/ + for (tabindex = 0; tabindex < n; tabindex++) + { + if (table[tabindex] == NULL) + continue; + + free(table[tabindex]); + table[tabindex] = NULL; + }; +} + diff --git a/c/argtable3.h b/c/argtable3.h new file mode 100644 index 0000000000..1107de250b --- /dev/null +++ b/c/argtable3.h @@ -0,0 +1,305 @@ +/******************************************************************************* + * This file is part of the argtable3 library. 
+ * + * Copyright (C) 1998-2001,2003-2011,2013 Stewart Heitmann + * + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of STEWART HEITMANN nor the names of its contributors + * may be used to endorse or promote products derived from this software + * without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL STEWART HEITMANN BE LIABLE FOR ANY DIRECT, + * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ ******************************************************************************/ + +#ifndef ARGTABLE3 +#define ARGTABLE3 + +#include /* FILE */ +#include /* struct tm */ + +#ifdef __cplusplus +extern "C" { +#endif + +#define ARG_REX_ICASE 1 + +/* bit masks for arg_hdr.flag */ +enum +{ + ARG_TERMINATOR=0x1, + ARG_HASVALUE=0x2, + ARG_HASOPTVALUE=0x4 +}; + +typedef void (arg_resetfn)(void *parent); +typedef int (arg_scanfn)(void *parent, const char *argval); +typedef int (arg_checkfn)(void *parent); +typedef void (arg_errorfn)(void *parent, FILE *fp, int error, const char *argval, const char *progname); + + +/* +* The arg_hdr struct defines properties that are common to all arg_xxx structs. +* The argtable library requires each arg_xxx struct to have an arg_hdr +* struct as its first data member. +* The argtable library functions then use this data to identify the +* properties of the command line option, such as its option tags, +* datatype string, and glossary strings, and so on. +* Moreover, the arg_hdr struct contains pointers to custom functions that +* are provided by each arg_xxx struct which perform the tasks of parsing +* that particular arg_xxx arguments, performing post-parse checks, and +* reporting errors. +* These functions are private to the individual arg_xxx source code +* and are the pointer to them are initiliased by that arg_xxx struct's +* constructor function. The user could alter them after construction +* if desired, but the original intention is for them to be set by the +* constructor and left unaltered. +*/ +struct arg_hdr +{ + char flag; /* Modifier flags: ARG_TERMINATOR, ARG_HASVALUE. 
*/ + const char *shortopts; /* String defining the short options */ + const char *longopts; /* String defiing the long options */ + const char *datatype; /* Description of the argument data type */ + const char *glossary; /* Description of the option as shown by arg_print_glossary function */ + int mincount; /* Minimum number of occurences of this option accepted */ + int maxcount; /* Maximum number of occurences if this option accepted */ + void *parent; /* Pointer to parent arg_xxx struct */ + arg_resetfn *resetfn; /* Pointer to parent arg_xxx reset function */ + arg_scanfn *scanfn; /* Pointer to parent arg_xxx scan function */ + arg_checkfn *checkfn; /* Pointer to parent arg_xxx check function */ + arg_errorfn *errorfn; /* Pointer to parent arg_xxx error function */ + void *priv; /* Pointer to private header data for use by arg_xxx functions */ +}; + +struct arg_rem +{ + struct arg_hdr hdr; /* The mandatory argtable header struct */ +}; + +struct arg_lit +{ + struct arg_hdr hdr; /* The mandatory argtable header struct */ + int count; /* Number of matching command line args */ +}; + +struct arg_int +{ + struct arg_hdr hdr; /* The mandatory argtable header struct */ + int count; /* Number of matching command line args */ + int *ival; /* Array of parsed argument values */ +}; + +struct arg_dbl +{ + struct arg_hdr hdr; /* The mandatory argtable header struct */ + int count; /* Number of matching command line args */ + double *dval; /* Array of parsed argument values */ +}; + +struct arg_str +{ + struct arg_hdr hdr; /* The mandatory argtable header struct */ + int count; /* Number of matching command line args */ + const char **sval; /* Array of parsed argument values */ +}; + +struct arg_rex +{ + struct arg_hdr hdr; /* The mandatory argtable header struct */ + int count; /* Number of matching command line args */ + const char **sval; /* Array of parsed argument values */ +}; + +struct arg_file +{ + struct arg_hdr hdr; /* The mandatory argtable header struct */ + int 
count; /* Number of matching command line args*/ + const char **filename; /* Array of parsed filenames (eg: /home/foo.bar) */ + const char **basename; /* Array of parsed basenames (eg: foo.bar) */ + const char **extension; /* Array of parsed extensions (eg: .bar) */ +}; + +struct arg_date +{ + struct arg_hdr hdr; /* The mandatory argtable header struct */ + const char *format; /* strptime format string used to parse the date */ + int count; /* Number of matching command line args */ + struct tm *tmval; /* Array of parsed time values */ +}; + +enum {ARG_ELIMIT=1, ARG_EMALLOC, ARG_ENOMATCH, ARG_ELONGOPT, ARG_EMISSARG}; +struct arg_end +{ + struct arg_hdr hdr; /* The mandatory argtable header struct */ + int count; /* Number of errors encountered */ + int *error; /* Array of error codes */ + void **parent; /* Array of pointers to offending arg_xxx struct */ + const char **argval; /* Array of pointers to offending argv[] string */ +}; + + +/**** arg_xxx constructor functions *********************************/ + +struct arg_rem* arg_rem(const char* datatype, const char* glossary); + +struct arg_lit* arg_lit0(const char* shortopts, + const char* longopts, + const char* glossary); +struct arg_lit* arg_lit1(const char* shortopts, + const char* longopts, + const char *glossary); +struct arg_lit* arg_litn(const char* shortopts, + const char* longopts, + int mincount, + int maxcount, + const char *glossary); + +struct arg_key* arg_key0(const char* keyword, + int flags, + const char* glossary); +struct arg_key* arg_key1(const char* keyword, + int flags, + const char* glossary); +struct arg_key* arg_keyn(const char* keyword, + int flags, + int mincount, + int maxcount, + const char* glossary); + +struct arg_int* arg_int0(const char* shortopts, + const char* longopts, + const char* datatype, + const char* glossary); +struct arg_int* arg_int1(const char* shortopts, + const char* longopts, + const char* datatype, + const char *glossary); +struct arg_int* arg_intn(const char* 
shortopts, + const char* longopts, + const char *datatype, + int mincount, + int maxcount, + const char *glossary); + +struct arg_dbl* arg_dbl0(const char* shortopts, + const char* longopts, + const char* datatype, + const char* glossary); +struct arg_dbl* arg_dbl1(const char* shortopts, + const char* longopts, + const char* datatype, + const char *glossary); +struct arg_dbl* arg_dbln(const char* shortopts, + const char* longopts, + const char *datatype, + int mincount, + int maxcount, + const char *glossary); + +struct arg_str* arg_str0(const char* shortopts, + const char* longopts, + const char* datatype, + const char* glossary); +struct arg_str* arg_str1(const char* shortopts, + const char* longopts, + const char* datatype, + const char *glossary); +struct arg_str* arg_strn(const char* shortopts, + const char* longopts, + const char* datatype, + int mincount, + int maxcount, + const char *glossary); + +struct arg_rex* arg_rex0(const char* shortopts, + const char* longopts, + const char* pattern, + const char* datatype, + int flags, + const char* glossary); +struct arg_rex* arg_rex1(const char* shortopts, + const char* longopts, + const char* pattern, + const char* datatype, + int flags, + const char *glossary); +struct arg_rex* arg_rexn(const char* shortopts, + const char* longopts, + const char* pattern, + const char* datatype, + int mincount, + int maxcount, + int flags, + const char *glossary); + +struct arg_file* arg_file0(const char* shortopts, + const char* longopts, + const char* datatype, + const char* glossary); +struct arg_file* arg_file1(const char* shortopts, + const char* longopts, + const char* datatype, + const char *glossary); +struct arg_file* arg_filen(const char* shortopts, + const char* longopts, + const char* datatype, + int mincount, + int maxcount, + const char *glossary); + +struct arg_date* arg_date0(const char* shortopts, + const char* longopts, + const char* format, + const char* datatype, + const char* glossary); +struct arg_date* 
arg_date1(const char* shortopts, + const char* longopts, + const char* format, + const char* datatype, + const char *glossary); +struct arg_date* arg_daten(const char* shortopts, + const char* longopts, + const char* format, + const char* datatype, + int mincount, + int maxcount, + const char *glossary); + +struct arg_end* arg_end(int maxerrors); + + +/**** other functions *******************************************/ +int arg_nullcheck(void **argtable); +int arg_parse(int argc, char **argv, void **argtable); +void arg_print_option(FILE *fp, const char *shortopts, const char *longopts, const char *datatype, const char *suffix); +void arg_print_syntax(FILE *fp, void **argtable, const char *suffix); +void arg_print_syntaxv(FILE *fp, void **argtable, const char *suffix); +void arg_print_glossary(FILE *fp, void **argtable, const char *format); +void arg_print_glossary_gnu(FILE *fp, void **argtable); +void arg_print_errors(FILE* fp, struct arg_end* end, const char* progname); +void arg_freetable(void **argtable, size_t n); + +/**** deprecated functions, for back-compatibility only ********/ +void arg_free(void **argtable); + +#ifdef __cplusplus +} +#endif +#endif diff --git a/c/kastore b/c/kastore new file mode 160000 index 0000000000..a6599e9d58 --- /dev/null +++ b/c/kastore @@ -0,0 +1 @@ +Subproject commit a6599e9d58332e5f46da4869b814babaaf86afc1 diff --git a/c/main.c b/c/main.c new file mode 100644 index 0000000000..2be32b7540 --- /dev/null +++ b/c/main.c @@ -0,0 +1,558 @@ +#include +#include +#include +#include +#include +#include +#include + +#include +#include "argtable3.h" + +#include "tskit.h" + +/* This file defines a crude CLI for tskit. It is intended for development + * use only. + */ + +typedef struct { + int alphabet; + double mutation_rate; +} mutation_params_t; + +static void +fatal_error(const char *msg, ...) 
+{ + va_list argp; + fprintf(stderr, "main:"); + va_start(argp, msg); + vfprintf(stderr, msg, argp); + va_end(argp); + fprintf(stderr, "\n"); + exit(EXIT_FAILURE); +} + +static void +fatal_library_error(int err, const char *msg, ...) +{ + va_list argp; + fprintf(stderr, "error:"); + va_start(argp, msg); + vfprintf(stderr, msg, argp); + va_end(argp); + fprintf(stderr, ":%d:'%s'\n", err, tsk_strerror(err)); + exit(EXIT_FAILURE); +} + +static void +load_tree_sequence(tsk_treeseq_t *ts, const char *filename) +{ + int ret = tsk_treeseq_load(ts, filename, 0); + if (ret != 0) { + fatal_library_error(ret, "Load error"); + } +} + +static void +print_variants(tsk_treeseq_t *ts) +{ + int ret = 0; + tsk_vargen_t vg; + uint32_t j, k; + tsk_variant_t* var; + + printf("variants (%d) \n", (int) tsk_treeseq_get_num_sites(ts)); + ret = tsk_vargen_alloc(&vg, ts, NULL, 0, 0); + if (ret != 0) { + fatal_library_error(ret, "tsk_vargen_alloc"); + } + j = 0; + while ((ret = tsk_vargen_next(&vg, &var)) == 1) { + printf("%.2f\t", var->site->position); + for (j = 0; j < var->num_alleles; j++) { + for (k = 0; k < var->allele_lengths[j]; k++) { + printf("%c", var->alleles[j][k]); + } + if (j < var->num_alleles - 1) { + printf(","); + } + } + printf("\t"); + for (k = 0; k < ts->num_samples; k++) { + printf("%d\t", var->genotypes.u8[k]); + } + printf("\n"); + } + if (ret != 0) { + fatal_library_error(ret, "tsk_vargen_next"); + } + tsk_vargen_free(&vg); +} + +static void +print_haplotypes(tsk_treeseq_t *ts) +{ + int ret = 0; + tsk_hapgen_t hg; + uint32_t j; + char *haplotype; + + printf("haplotypes \n"); + ret = tsk_hapgen_alloc(&hg, ts); + if (ret != 0) { + fatal_library_error(ret, "tsk_hapgen_alloc"); + } + for (j = 0; j < ts->num_samples; j++) { + ret = tsk_hapgen_get_haplotype(&hg, (tsk_id_t) j, &haplotype); + if (ret < 0) { + fatal_library_error(ret, "tsk_hapgen_get_haplotype"); + } + printf("%d\t%s\n", j, haplotype); + } + tsk_hapgen_free(&hg); +} + +static void 
+print_ld_matrix(tsk_treeseq_t *ts) +{ + int ret; + size_t num_sites = tsk_treeseq_get_num_sites(ts); + tsk_site_t sA, sB; + double *r2 = malloc(num_sites * sizeof(double)); + size_t j, k, num_r2_values; + tsk_ld_calc_t ld_calc; + + if (r2 == NULL) { + fatal_error("no memory"); + } + ret = tsk_ld_calc_alloc(&ld_calc, ts); + printf("alloc: ret = %d\n", ret); + if (ret != 0) { + fatal_library_error(ret, "tsk_ld_calc_alloc"); + } + tsk_ld_calc_print_state(&ld_calc, stdout); + for (j = 0; j < num_sites; j++) { + ret = tsk_ld_calc_get_r2_array(&ld_calc, j, TSK_DIR_FORWARD, num_sites, + DBL_MAX, r2, &num_r2_values); + if (ret != 0) { + fatal_library_error(ret, "tsk_ld_calc_get_r2_array"); + } + for (k = 0; k < num_r2_values; k++) { + ret = tsk_treeseq_get_site(ts, j, &sA); + if (ret != 0) { + fatal_library_error(ret, "get_site"); + } + ret = tsk_treeseq_get_site(ts, (j + k + 1), &sB); + if (ret != 0) { + fatal_library_error(ret, "get_site"); + } + printf("%d\t%f\t%d\t%f\t%.3f\n", + (int) sA.id, sA.position, (int) sB.id, sB.position, r2[k]); + } + } + free(r2); + ret = tsk_ld_calc_free(&ld_calc); + if (ret != 0) { + fatal_library_error(ret, "tsk_ld_calc_write_table"); + } +} + +static void +print_stats(tsk_treeseq_t *ts) +{ + int ret = 0; + uint32_t j; + size_t num_samples = tsk_treeseq_get_num_samples(ts) / 2; + tsk_id_t *sample = malloc(num_samples * sizeof(tsk_id_t)); + double pi; + + if (sample == NULL) { + fatal_error("no memory"); + } + for (j = 0; j < num_samples; j++) { + sample[j] = (tsk_id_t) j; + } + ret = tsk_treeseq_get_pairwise_diversity(ts, sample, num_samples, &pi); + if (ret != 0) { + fatal_library_error(ret, "get_pairwise_diversity"); + } + printf("pi = %f\n", pi); + free(sample); +} + +static void +print_vcf(tsk_treeseq_t *ts, unsigned int ploidy, const char *chrom, int verbose) +{ + int ret = 0; + char *record = NULL; + char *header = NULL; + tsk_vcf_converter_t vc; + + ret = tsk_vcf_converter_alloc(&vc, ts, ploidy, chrom); + if (ret != 0) { + 
fatal_library_error(ret, "vcf alloc");
    }
    if (verbose > 0) {
        tsk_vcf_converter_print_state(&vc, stdout);
        printf("START VCF\n");
    }
    ret = tsk_vcf_converter_get_header(&vc, &header);
    if (ret != 0) {
        fatal_library_error(ret, "vcf get header");
    }
    printf("%s", header);
    /* tsk_vcf_converter_next returns 1 while records remain, 0 at the end
     * and negative on error. The header/record buffers appear to be owned
     * by the converter — presumably released by tsk_vcf_converter_free
     * (TODO confirm; they are not freed here). */
    while ((ret = tsk_vcf_converter_next(&vc, &record)) == 1) {
        printf("%s", record);
    }
    if (ret != 0) {
        fatal_library_error(ret, "vcf next");
    }
    tsk_vcf_converter_free(&vc);
}

/* Print every tree in the sequence in newick form, one "index:\tnewick"
 * line per tree. */
static void
print_newick_trees(tsk_treeseq_t *ts)
{
    int ret = 0;
    char *newick = NULL;
    size_t precision = 8;
    /* Heuristic upper bound on the newick string length; assumed large
     * enough for (precision + 5) characters per node — TODO confirm. */
    size_t newick_buffer_size = (precision + 5) * tsk_treeseq_get_num_nodes(ts);
    tsk_tree_t tree;

    newick = malloc(newick_buffer_size);
    if (newick == NULL) {
        fatal_error("No memory\n");
    }

    ret = tsk_tree_alloc(&tree, ts, 0);
    if (ret != 0) {
        fatal_error("ERROR: %d: %s\n", ret, tsk_strerror(ret));
    }
    /* tsk_tree_first/next return 1 while trees remain, 0 at the end and
     * negative on error. */
    for (ret = tsk_tree_first(&tree); ret == 1; ret = tsk_tree_next(&tree)) {
        ret = tsk_convert_newick(&tree, tree.left_root, precision,
            0, newick_buffer_size, newick);
        if (ret != 0) {
            fatal_library_error(ret ,"newick");
        }
        printf("%d:\t%s\n", (int) tree.index, newick);
    }
    if (ret < 0) {
        fatal_error("ERROR: %d: %s\n", ret, tsk_strerror(ret));
    }
    tsk_tree_free(&tree);
    free(newick);
}

/* Dump the tree sequence state to stdout; with verbose > 0 also print the
 * state of every tree. */
static void
print_tree_sequence(tsk_treeseq_t *ts, int verbose)
{
    int ret = 0;
    tsk_tree_t tree;

    tsk_treeseq_print_state(ts, stdout);
    if (verbose > 0) {
        printf("========================\n");
        printf("trees\n");
        printf("========================\n");
        ret = tsk_tree_alloc(&tree, ts, TSK_SAMPLE_COUNTS|TSK_SAMPLE_LISTS);
        if (ret != 0) {
            fatal_error("ERROR: %d: %s\n", ret, tsk_strerror(ret));
        }
        for (ret = tsk_tree_first(&tree); ret == 1; ret = tsk_tree_next(&tree)) {
            printf("-------------------------\n");
            printf("New tree: %d: %f (%d)\n", (int) tree.index,
                tree.right - tree.left, (int) tree.num_nodes);
            printf("-------------------------\n");
tsk_tree_print_state(&tree, stdout);
        }
        if (ret < 0) {
            fatal_error("ERROR: %d: %s\n", ret, tsk_strerror(ret));
        }
        tsk_tree_free(&tree);
    }
}

/* Command runners: each loads the tree sequence from filename, invokes the
 * corresponding printer and frees the tree sequence. */

static void
run_ld(const char *filename, int TSK_UNUSED(verbose))
{
    tsk_treeseq_t ts;

    load_tree_sequence(&ts, filename);
    print_ld_matrix(&ts);
    tsk_treeseq_free(&ts);
}

static void
run_haplotypes(const char *filename, int TSK_UNUSED(verbose))
{
    tsk_treeseq_t ts;

    load_tree_sequence(&ts, filename);
    print_haplotypes(&ts);
    tsk_treeseq_free(&ts);
}

static void
run_variants(const char *filename, int TSK_UNUSED(verbose))
{
    tsk_treeseq_t ts;

    load_tree_sequence(&ts, filename);
    print_variants(&ts);
    tsk_treeseq_free(&ts);
}

static void
run_vcf(const char *filename, int verbose, int ploidy, const char *chrom)
{
    tsk_treeseq_t ts;

    load_tree_sequence(&ts, filename);
    print_vcf(&ts, (unsigned int) ploidy, chrom, verbose);
    tsk_treeseq_free(&ts);
}

static void
run_print(const char *filename, int verbose)
{
    tsk_treeseq_t ts;

    load_tree_sequence(&ts, filename);
    print_tree_sequence(&ts, verbose);
    tsk_treeseq_free(&ts);
}

static void
run_newick(const char *filename, int TSK_UNUSED(verbose))
{
    tsk_treeseq_t ts;

    load_tree_sequence(&ts, filename);
    print_newick_trees(&ts);
    tsk_treeseq_free(&ts);
}

static void
run_stats(const char *filename, int TSK_UNUSED(verbose))
{
    tsk_treeseq_t ts;

    load_tree_sequence(&ts, filename);
    print_stats(&ts);
    tsk_treeseq_free(&ts);
}

/* Simplify the tree sequence in input_filename down to num_samples samples
 * (0 means keep all samples) and dump the result to output_filename. */
static void
run_simplify(const char *input_filename, const char *output_filename, size_t num_samples,
        bool filter_sites, int verbose)
{
    tsk_treeseq_t ts, subset;
    tsk_id_t *samples;
    int flags = 0;
    int ret;

    if (filter_sites) {
        flags |= TSK_FILTER_SITES;
    }

    load_tree_sequence(&ts, input_filename);
    if (verbose > 0) {
        printf(">>>>>>>>\nINPUT:\n>>>>>>>>\n");
        tsk_treeseq_print_state(&ts, stdout);
    }
    if (num_samples == 0) {
        num_samples =
tsk_treeseq_get_num_samples(&ts);
    } else {
        num_samples = TSK_MIN(num_samples, tsk_treeseq_get_num_samples(&ts));
    }
    ret = tsk_treeseq_get_samples(&ts, &samples);
    if (ret != 0) {
        fatal_library_error(ret, "get_samples");
    }
    ret = tsk_treeseq_simplify(&ts, samples, num_samples, flags, &subset, NULL);
    if (ret != 0) {
        fatal_library_error(ret, "Subset error");
    }
    ret = tsk_treeseq_dump(&subset, output_filename, 0);
    if (ret != 0) {
        fatal_library_error(ret, "Write error");
    }
    if (verbose > 0) {
        printf(">>>>>>>>\nOUTPUT:\n>>>>>>>>\n");
        tsk_treeseq_print_state(&subset, stdout);
    }
    tsk_treeseq_free(&ts);
    tsk_treeseq_free(&subset);
}

/* Entry point: eight argtable command syntaxes are parsed independently
 * and the first one that parses without errors is dispatched. */
int
main(int argc, char** argv)
{
    /* SYNTAX 1: simplify [-vi] [-s] */
    struct arg_rex *cmd1 = arg_rex1(NULL, NULL, "simplify", NULL, REG_ICASE, NULL);
    struct arg_lit *verbose1 = arg_lit0("v", "verbose", NULL);
    struct arg_int *num_samples1 = arg_int0("s", "sample-size", "",
        "Number of samples to keep in the simplified tree sequence.");
    struct arg_lit *filter_sites1 = arg_lit0("i",
        "filter-invariant-sites", "");
    struct arg_file *infiles1 = arg_file1(NULL, NULL, NULL, NULL);
    struct arg_file *outfiles1 = arg_file1(NULL, NULL, NULL, NULL);
    struct arg_end *end1 = arg_end(20);
    void* argtable1[] = {cmd1, verbose1, filter_sites1, num_samples1,
        infiles1, outfiles1, end1};
    int nerrors1;

    /* SYNTAX 2: ld [-v] */
    struct arg_rex *cmd2 = arg_rex1(NULL, NULL, "ld", NULL, REG_ICASE, NULL);
    struct arg_lit *verbose2 = arg_lit0("v", "verbose", NULL);
    struct arg_file *infiles2 = arg_file1(NULL, NULL, NULL, NULL);
    struct arg_end *end2 = arg_end(20);
    void* argtable2[] = {cmd2, verbose2, infiles2, end2};
    int nerrors2;

    /* SYNTAX 3: haplotypes [-v] */
    struct arg_rex *cmd3 = arg_rex1(NULL, NULL, "haplotypes", NULL, REG_ICASE, NULL);
    struct arg_lit *verbose3 = arg_lit0("v", "verbose", NULL);
    struct arg_file *infiles3 = arg_file1(NULL, NULL, NULL, NULL);
    struct arg_end *end3 = arg_end(20);
    void* argtable3[] = {cmd3, verbose3, infiles3, end3};
    int nerrors3;

    /* SYNTAX 4: variants [-v] */
    struct arg_rex *cmd4 = arg_rex1(NULL, NULL, "variants", NULL, REG_ICASE, NULL);
    struct arg_lit *verbose4 = arg_lit0("v", "verbose", NULL);
    struct arg_file *infiles4 = arg_file1(NULL, NULL, NULL, NULL);
    struct arg_end *end4 = arg_end(20);
    void* argtable4[] = {cmd4, verbose4, infiles4, end4};
    int nerrors4;

    /* SYNTAX 5: vcf [-v] */
    struct arg_rex *cmd5 = arg_rex1(NULL, NULL, "vcf", NULL, REG_ICASE, NULL);
    struct arg_lit *verbose5 = arg_lit0("v", "verbose", NULL);
    struct arg_file *infiles5 = arg_file1(NULL, NULL, NULL, NULL);
    struct arg_int *ploidy5 = arg_int0("p", "ploidy", "",
        "Ploidy level of the VCF");
    struct arg_str *chrom5 = arg_str0("c", "chrom", "",
        "Value for the CHROM column in the VCF (default='1')");
    struct arg_end *end5 = arg_end(20);
    void* argtable5[] = {cmd5, verbose5, infiles5, ploidy5, chrom5, end5};
    int nerrors5;

    /* SYNTAX 6: print [-v] */
    struct arg_rex *cmd6 = arg_rex1(NULL, NULL, "print", NULL, REG_ICASE, NULL);
    struct arg_lit *verbose6 = arg_lit0("v", "verbose", NULL);
    struct arg_file *infiles6 = arg_file1(NULL, NULL, NULL, NULL);
    struct arg_end *end6 = arg_end(20);
    void* argtable6[] = {cmd6, verbose6, infiles6, end6};
    int nerrors6;

    /* SYNTAX 7: newick [-v] */
    struct arg_rex *cmd7 = arg_rex1(NULL, NULL, "newick", NULL, REG_ICASE, NULL);
    struct arg_lit *verbose7 = arg_lit0("v", "verbose", NULL);
    struct arg_file *infiles7 = arg_file1(NULL, NULL, NULL, NULL);
    struct arg_end *end7 = arg_end(20);
    void* argtable7[] = {cmd7, verbose7, infiles7, end7};
    int nerrors7;

    /* SYNTAX 8: stats [-v] */
    struct arg_rex *cmd8 = arg_rex1(NULL, NULL, "stats", NULL, REG_ICASE, NULL);
    struct arg_lit *verbose8 = arg_lit0("v", "verbose", NULL);
    struct arg_file *infiles8 = arg_file1(NULL, NULL, NULL, NULL);
    struct arg_end *end8 = arg_end(20);
    void* argtable8[] = {cmd8, verbose8, infiles8, end8};
    int nerrors8;

    int exitcode = EXIT_SUCCESS;
    const char *progname = "main";

    /* Set defaults */
    ploidy5->ival[0] = 1;
    chrom5->sval[0] = "1";
    num_samples1->ival[0] = 0;

    /* NOTE(review): the arg_* allocations above are not NULL-checked
     * (argtable's usual arg_nullcheck pattern) — confirm acceptable for
     * this development tool. */
    nerrors1 = arg_parse(argc, argv, argtable1);
    nerrors2 = arg_parse(argc, argv, argtable2);
    nerrors3 = arg_parse(argc, argv, argtable3);
    nerrors4 = arg_parse(argc, argv, argtable4);
    nerrors5 = arg_parse(argc, argv, argtable5);
    nerrors6 = arg_parse(argc, argv, argtable6);
    nerrors7 = arg_parse(argc, argv, argtable7);
    nerrors8 = arg_parse(argc, argv, argtable8);

    if (nerrors1 == 0) {
        run_simplify(infiles1->filename[0], outfiles1->filename[0],
            (size_t) num_samples1->ival[0], (bool) filter_sites1->count,
            verbose1->count);
    } else if (nerrors2 == 0) {
        run_ld(infiles2->filename[0], verbose2->count);
    } else if (nerrors3 == 0) {
        run_haplotypes(infiles3->filename[0], verbose3->count);
    } else if (nerrors4 == 0) {
        run_variants(infiles4->filename[0], verbose4->count);
    } else if (nerrors5 == 0) {
        run_vcf(infiles5->filename[0], verbose5->count, ploidy5->ival[0], chrom5->sval[0]);
    } else if (nerrors6 == 0) {
        run_print(infiles6->filename[0], verbose6->count);
    } else if (nerrors7 == 0) {
        run_newick(infiles7->filename[0], verbose7->count);
    } else if (nerrors8 == 0) {
        run_stats(infiles8->filename[0], verbose8->count);
    } else {
        /* We get here if the command line matched none of the possible syntaxes */
        if (cmd1->count > 0) {
            arg_print_errors(stdout, end1, progname);
            printf("usage: %s ", progname);
            arg_print_syntax(stdout, argtable1, "\n");
        } else if (cmd2->count > 0) {
            arg_print_errors(stdout, end2, progname);
            printf("usage: %s ", progname);
            arg_print_syntax(stdout, argtable2, "\n");
        } else if (cmd3->count > 0) {
            arg_print_errors(stdout, end3, progname);
            printf("usage: %s ", progname);
            arg_print_syntax(stdout, argtable3, "\n");
        } else if (cmd4->count > 0) {
            arg_print_errors(stdout, end4, progname);
            printf("usage: %s ", progname);
            arg_print_syntax(stdout, argtable4, "\n");
        } else if (cmd5->count > 0) {
            arg_print_errors(stdout, end5, progname);
            printf("usage: %s ", progname);
            arg_print_syntax(stdout, argtable5, "\n");
        } else if (cmd6->count > 0) {
            arg_print_errors(stdout, end6, progname);
            printf("usage: %s ", progname);
            arg_print_syntax(stdout, argtable6, "\n");
        } else if (cmd7->count > 0) {
            arg_print_errors(stdout, end7, progname);
            printf("usage: %s ", progname);
            arg_print_syntax(stdout, argtable7, "\n");
        } else if (cmd8->count > 0) {
            arg_print_errors(stdout, end8, progname);
            printf("usage: %s ", progname);
            arg_print_syntax(stdout, argtable8, "\n");
        } else {
            /* no correct cmd literals were given, so we cant presume which syntax was intended */
            printf("%s: missing command.\n",progname);
            printf("usage 1: %s ", progname);  arg_print_syntax(stdout, argtable1, "\n");
            printf("usage 2: %s ", progname);  arg_print_syntax(stdout, argtable2, "\n");
            printf("usage 3: %s ", progname);  arg_print_syntax(stdout, argtable3, "\n");
            printf("usage 4: %s ", progname);  arg_print_syntax(stdout, argtable4, "\n");
            printf("usage 5: %s ", progname);  arg_print_syntax(stdout, argtable5, "\n");
            printf("usage 6: %s ", progname);  arg_print_syntax(stdout, argtable6, "\n");
            printf("usage 7: %s ", progname);  arg_print_syntax(stdout, argtable7, "\n");
            printf("usage 8: %s ", progname);  arg_print_syntax(stdout, argtable8, "\n");
        }
    }

    arg_freetable(argtable1, sizeof(argtable1) / sizeof(argtable1[0]));
    arg_freetable(argtable2, sizeof(argtable2) / sizeof(argtable2[0]));
    arg_freetable(argtable3, sizeof(argtable3) / sizeof(argtable3[0]));
    arg_freetable(argtable4, sizeof(argtable4) / sizeof(argtable4[0]));
    arg_freetable(argtable5, sizeof(argtable5) / sizeof(argtable5[0]));
    arg_freetable(argtable6, sizeof(argtable6) / sizeof(argtable6[0]));
    arg_freetable(argtable7, sizeof(argtable7) / sizeof(argtable7[0]));
arg_freetable(argtable8, sizeof(argtable8) / sizeof(argtable8[0]));

    return exitcode;
}
diff --git a/c/old_tests.c b/c/old_tests.c
new file mode 100644
index 0000000000..c620a2523a
--- /dev/null
+++ b/c/old_tests.c
@@ -0,0 +1,3160 @@
/*
** Copyright (C) 2016-2018 University of Oxford
**
** This file is part of msprime.
**
** msprime is free software: you can redistribute it and/or modify
** it under the terms of the GNU General Public License as published by
** the Free Software Foundation, either version 3 of the License, or
** (at your option) any later version.
**
** msprime is distributed in the hope that it will be useful,
** but WITHOUT ANY WARRANTY; without even the implied warranty of
** MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
** GNU General Public License for more details.
**
** You should have received a copy of the GNU General Public License
** along with msprime. If not, see <http://www.gnu.org/licenses/>.
*/

#define _GNU_SOURCE
/*
 * Unit tests for the low-level msprime API.
 */
#include "tsk_genotypes.h"
#include "tsk_convert.h"
#include "tsk_stats.h"

/* NOTE(review): the angle-bracket header names below were lost in
 * extraction (presumably the C stdlib, CUnit and GSL headers used by this
 * file) — restore them from the original source. */
#include
#include
#include
#include

#include
#include
#include

/* Global state shared by the test suite. */

char * _tmp_file_name;
FILE * _devnull;

#define SIMPLE_BOTTLENECK 0
#define INSTANTANEOUS_BOTTLENECK 1

/* Description of a demographic bottleneck event used by some tests. */
typedef struct {
    int type;
    double time;
    uint32_t population_id;
    double parameter;
} bottleneck_desc_t;

/* Example tree sequences used in some of the tests. */


/* Simple single tree example.
 */
const char *single_tree_ex_nodes =/*      6      */
    "1 0 -1 -1\n"                 /*     / \     */
    "1 0 -1 -1\n"                 /*    /   \    */
    "1 0 -1 -1\n"                 /*   /     \   */
    "1 0 -1 -1\n"                 /*  /       5  */
    "0 1 -1 -1\n"                 /* 4        / \ */
    "0 2 -1 -1\n"                 /* / \     /   \ */
    "0 3 -1 -1\n";                /* 0  1   2     3 */
const char *single_tree_ex_edges =
    "0 1 4 0,1\n"
    "0 1 5 2,3\n"
    "0 1 6 4,5\n";
const char *single_tree_ex_sites =
    "0.1 0\n"
    "0.2 0\n"
    "0.3 0\n";
const char *single_tree_ex_mutations =
    "0 2 1 -1\n"
    "1 4 1 -1\n"
    "1 0 0 1\n"  /* Back mutation over 0 */
    "2 0 1 -1\n" /* recurrent mutations over samples */
    "2 1 1 -1\n"
    "2 2 1 -1\n"
    "2 3 1 -1\n";

/* Example from the PLOS paper */
const char *paper_ex_nodes =
    "1 0 -1 0\n"
    "1 0 -1 0\n"
    "1 0 -1 1\n"
    "1 0 -1 1\n"
    "0 0.071 -1 -1\n"
    "0 0.090 -1 -1\n"
    "0 0.170 -1 -1\n"
    "0 0.202 -1 -1\n"
    "0 0.253 -1 -1\n";
const char *paper_ex_edges =
    "2 10 4 2\n"
    "2 10 4 3\n"
    "0 10 5 1\n"
    "0 2 5 3\n"
    "2 10 5 4\n"
    "0 7 6 0,5\n"
    "7 10 7 0,5\n"
    "0 2 8 2,6\n";
/* We make one mutation for each tree */
const char *paper_ex_sites =
    "1 0\n"
    "4.5 0\n"
    "8.5 0\n";
const char *paper_ex_mutations =
    "0 2 1\n"
    "1 0 1\n"
    "2 5 1\n";
/* Two (diploid) individuals */
const char *paper_ex_individuals =
    "0 0.2,1.5\n"
    "0 0.0,0.0\n";

/* An example of a nonbinary tree sequence */
const char *nonbinary_ex_nodes =
    "1 0 0 -1\n"
    "1 0 0 -1\n"
    "1 0 0 -1\n"
    "1 0 0 -1\n"
    "1 0 0 -1\n"
    "1 0 0 -1\n"
    "1 0 0 -1\n"
    "1 0 0 -1\n"
    "0 0.01 0 -1\n"
    "0 0.068 0 -1\n"
    "0 0.130 0 -1\n"
    "0 0.279 0 -1\n"
    "0 0.405 0 -1\n";
const char *nonbinary_ex_edges =
    "0 100 8 0,1,2,3\n"
    "0 100 9 6,8\n"
    "0 100 10 4\n"
    "0 17 10 5\n"
    "0 100 10 7\n"
    "17 100 11 5,9\n"
    "0 17 12 9\n"
    "0 100 12 10\n"
    "17 100 12 11";
const char *nonbinary_ex_sites =
    "1 0\n"
    "18 0\n";
const char *nonbinary_ex_mutations =
    "0 2 1\n"
    "1 11 1";

/* An example of a tree sequence with unary nodes.
 */

const char *unary_ex_nodes =
    "1 0 0 -1\n"
    "1 0 0 -1\n"
    "1 0 0 -1\n"
    "1 0 0 -1\n"
    "0 0.071 0 -1\n"
    "0 0.090 0 -1\n"
    "0 0.170 0 -1\n"
    "0 0.202 0 -1\n"
    "0 0.253 0 -1\n";
const char *unary_ex_edges =
    "2 10 4 2,3\n"
    "0 10 5 1\n"
    "0 2 5 3\n"
    "2 10 5 4\n"
    "0 7 6 0,5\n"
    "7 10 7 0\n"
    "0 2 7 2\n"
    "7 10 7 5\n"
    "0 7 8 6\n"
    "0 2 8 7\n";

/* We make one mutation for each tree, over unary nodes if this exist */
const char *unary_ex_sites =
    "1.0 0\n"
    "4.5 0\n"
    "8.5 0\n";
const char *unary_ex_mutations =
    "0 2 1\n"
    "1 6 1\n"
    "2 5 1\n";

/* An example of a tree sequence with internally sampled nodes. */

/* TODO: find a way to draw these side-by-side */
/*
  7
+-+-+
|   5
| +-++
| |  4
| | +++
| | | 3
| | |
| 1 2
|
0

  8
+-+-+
|   5
| +-++
| |  4
| | +++
3 | | |
  | | |
  1 2 |
      |
      0

  6
+-+-+
|   5
| +-++
| |  4
| | +++
| | | 3
| | |
| 1 2
|
0
*/

const char *internal_sample_ex_nodes =
    "1 0.0 0 -1\n"
    "1 0.1 0 -1\n"
    "1 0.1 0 -1\n"
    "1 0.2 0 -1\n"
    "0 0.4 0 -1\n"
    "1 0.5 0 -1\n"
    "0 0.7 0 -1\n"
    "0 1.0 0 -1\n"
    "0 1.2 0 -1\n";
const char *internal_sample_ex_edges =
    "2 8 4 0\n"
    "0 10 4 2\n"
    "0 2 4 3\n"
    "8 10 4 3\n"
    "0 10 5 1,4\n"
    "8 10 6 0,5\n"
    "0 2 7 0,5\n"
    "2 8 8 3,5\n";
/* We make one mutation for each tree, some above the internal node */
const char *internal_sample_ex_sites =
    "1.0 0\n"
    "4.5 0\n"
    "8.5 0\n";
const char *internal_sample_ex_mutations =
    "0 2 1\n"
    "1 5 1\n"
    "2 5 1\n";


/* Simple utilities to parse text so we can write declarative
 * tests. This is not intended as a robust general input mechanism.
+ */ + +static void +parse_nodes(const char *text, tsk_node_tbl_t *node_table) +{ + int ret; + size_t c, k; + size_t MAX_LINE = 1024; + char line[MAX_LINE]; + const char *whitespace = " \t"; + char *p; + double time; + int flags, population, individual; + char *name; + + c = 0; + while (text[c] != '\0') { + /* Fill in the line */ + k = 0; + while (text[c] != '\n' && text[c] != '\0') { + CU_ASSERT_FATAL(k < MAX_LINE - 1); + line[k] = text[c]; + c++; + k++; + } + if (text[c] == '\n') { + c++; + } + line[k] = '\0'; + p = strtok(line, whitespace); + CU_ASSERT_FATAL(p != NULL); + flags = atoi(p); + p = strtok(NULL, whitespace); + CU_ASSERT_FATAL(p != NULL); + time = atof(p); + p = strtok(NULL, whitespace); + CU_ASSERT_FATAL(p != NULL); + population = atoi(p); + p = strtok(NULL, whitespace); + if (p == NULL) { + individual = -1; + } else { + individual = atoi(p); + p = strtok(NULL, whitespace); + } + if (p == NULL) { + name = ""; + } else { + name = p; + } + ret = tsk_node_tbl_add_row(node_table, flags, time, population, + individual, name, strlen(name)); + CU_ASSERT_FATAL(ret >= 0); + } +} + +static void +parse_edges(const char *text, tsk_edge_tbl_t *edge_table) +{ + int ret; + size_t c, k; + size_t MAX_LINE = 1024; + char line[MAX_LINE], sub_line[MAX_LINE]; + const char *whitespace = " \t"; + char *p, *q; + double left, right; + tsk_id_t parent, child; + uint32_t num_children; + + c = 0; + while (text[c] != '\0') { + /* Fill in the line */ + k = 0; + while (text[c] != '\n' && text[c] != '\0') { + CU_ASSERT_FATAL(k < MAX_LINE - 1); + line[k] = text[c]; + c++; + k++; + } + if (text[c] == '\n') { + c++; + } + line[k] = '\0'; + p = strtok(line, whitespace); + CU_ASSERT_FATAL(p != NULL); + left = atof(p); + p = strtok(NULL, whitespace); + CU_ASSERT_FATAL(p != NULL); + right = atof(p); + p = strtok(NULL, whitespace); + CU_ASSERT_FATAL(p != NULL); + parent = atoi(p); + num_children = 0; + p = strtok(NULL, whitespace); + CU_ASSERT_FATAL(p != NULL); + + num_children = 1; + q = 
p; + while (*q != '\0') { + if (*q == ',') { + num_children++; + } + q++; + } + CU_ASSERT_FATAL(num_children >= 1); + strncpy(sub_line, p, MAX_LINE); + q = strtok(sub_line, ","); + for (k = 0; k < num_children; k++) { + CU_ASSERT_FATAL(q != NULL); + child = atoi(q); + ret = tsk_edge_tbl_add_row(edge_table, left, right, parent, child); + CU_ASSERT_FATAL(ret >= 0); + q = strtok(NULL, ","); + } + CU_ASSERT_FATAL(q == NULL); + } +} + +static void +parse_sites(const char *text, tsk_site_tbl_t *site_table) +{ + int ret; + size_t c, k; + size_t MAX_LINE = 1024; + char line[MAX_LINE]; + double position; + char ancestral_state[MAX_LINE]; + const char *whitespace = " \t"; + char *p; + + c = 0; + while (text[c] != '\0') { + /* Fill in the line */ + k = 0; + while (text[c] != '\n' && text[c] != '\0') { + CU_ASSERT_FATAL(k < MAX_LINE - 1); + line[k] = text[c]; + c++; + k++; + } + if (text[c] == '\n') { + c++; + } + line[k] = '\0'; + p = strtok(line, whitespace); + CU_ASSERT_FATAL(p != NULL); + position = atof(p); + p = strtok(NULL, whitespace); + CU_ASSERT_FATAL(p != NULL); + strncpy(ancestral_state, p, MAX_LINE); + ret = tsk_site_tbl_add_row(site_table, position, ancestral_state, + strlen(ancestral_state), NULL, 0); + CU_ASSERT_FATAL(ret >= 0); + } +} + +static void +parse_mutations(const char *text, tsk_mutation_tbl_t *mutation_table) +{ + int ret; + size_t c, k; + size_t MAX_LINE = 1024; + char line[MAX_LINE]; + const char *whitespace = " \t"; + char *p; + tsk_id_t node; + tsk_id_t site; + tsk_id_t parent; + char derived_state[MAX_LINE]; + + c = 0; + while (text[c] != '\0') { + /* Fill in the line */ + k = 0; + while (text[c] != '\n' && text[c] != '\0') { + CU_ASSERT_FATAL(k < MAX_LINE - 1); + line[k] = text[c]; + c++; + k++; + } + if (text[c] == '\n') { + c++; + } + line[k] = '\0'; + p = strtok(line, whitespace); + site = atoi(p); + CU_ASSERT_FATAL(p != NULL); + p = strtok(NULL, whitespace); + CU_ASSERT_FATAL(p != NULL); + node = atoi(p); + p = strtok(NULL, whitespace); + 
CU_ASSERT_FATAL(p != NULL);
        strncpy(derived_state, p, MAX_LINE);
        /* The parent column is optional; TSK_NULL means no parent. */
        parent = TSK_NULL;
        p = strtok(NULL, whitespace);
        if (p != NULL) {
            parent = atoi(p);
        }
        ret = tsk_mutation_tbl_add_row(mutation_table, site, node, parent,
            derived_state, strlen(derived_state), NULL, 0);
        CU_ASSERT_FATAL(ret >= 0);
    }
}

/* Parse individual rows ("flags location[,location...] [metadata]") into
 * individual_table. */
static void
parse_individuals(const char *text, tsk_individual_tbl_t *individual_table)
{
    int ret;
    size_t c, k;
    size_t MAX_LINE = 1024;
    char line[MAX_LINE];
    char sub_line[MAX_LINE];
    const char *whitespace = " \t";
    char *p, *q;
    double location[MAX_LINE];
    int location_len;
    int flags;
    char *name;

    c = 0;
    while (text[c] != '\0') {
        /* Fill in the line */
        k = 0;
        while (text[c] != '\n' && text[c] != '\0') {
            CU_ASSERT_FATAL(k < MAX_LINE - 1);
            line[k] = text[c];
            c++;
            k++;
        }
        if (text[c] == '\n') {
            c++;
        }
        line[k] = '\0';
        p = strtok(line, whitespace);
        CU_ASSERT_FATAL(p != NULL);
        flags = atoi(p);

        p = strtok(NULL, whitespace);
        CU_ASSERT_FATAL(p != NULL);
        // the locations are comma-separated
        location_len = 1;
        q = p;
        while (*q != '\0') {
            if (*q == ',') {
                location_len++;
            }
            q++;
        }
        CU_ASSERT_FATAL(location_len >= 1);
        strncpy(sub_line, p, MAX_LINE);
        q = strtok(sub_line, ",");
        for (k = 0; k < location_len; k++) {
            CU_ASSERT_FATAL(q != NULL);
            location[k] = atof(q);
            q = strtok(NULL, ",");
        }
        CU_ASSERT_FATAL(q == NULL);
        p = strtok(NULL, whitespace);
        if (p == NULL) {
            name = "";
        } else {
            name = p;
        }
        ret = tsk_individual_tbl_add_row(individual_table, flags, location, location_len,
            name, strlen(name));
        CU_ASSERT_FATAL(ret >= 0);
    }
}

/* Build a tree sequence in ts from the text table descriptions used by the
 * example fixtures above. */
static void
tsk_treeseq_from_text(tsk_treeseq_t *ts, double sequence_length,
        const char *nodes, const char *edges,
        const char *migrations, const char *sites, const char *mutations,
        const char *individuals, const char *provenance)
{
    int ret;
    tsk_tbl_collection_t tables;
    tsk_id_t max_population_id;
    tsk_tbl_size_t
j; + + CU_ASSERT_FATAL(ts != NULL); + CU_ASSERT_FATAL(nodes != NULL); + CU_ASSERT_FATAL(edges != NULL); + /* Not supporting provenance here for now */ + CU_ASSERT_FATAL(provenance == NULL); + + ret = tsk_tbl_collection_alloc(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tables.sequence_length = sequence_length; + parse_nodes(nodes, tables.nodes); + parse_edges(edges, tables.edges); + if (sites != NULL) { + parse_sites(sites, tables.sites); + } + if (mutations != NULL) { + parse_mutations(mutations, tables.mutations); + } + if (individuals != NULL) { + parse_individuals(individuals, tables.individuals); + } + /* We need to add in populations if they are referenced */ + max_population_id = -1; + for (j = 0; j < tables.nodes->num_rows; j++) { + max_population_id = TSK_MAX(max_population_id, tables.nodes->population[j]); + } + if (max_population_id >= 0) { + for (j = 0; j <= (tsk_tbl_size_t) max_population_id; j++) { + ret = tsk_population_tbl_add_row(tables.populations, NULL, 0); + CU_ASSERT_EQUAL_FATAL(ret, j); + } + } + + ret = tsk_treeseq_alloc(ts, &tables, TSK_BUILD_INDEXES); + /* tsk_treeseq_print_state(ts, stdout); */ + /* printf("ret = %s\n", tsk_strerror(ret)); */ + CU_ASSERT_EQUAL_FATAL(ret, 0); + tsk_tbl_collection_free(&tables); +} + +static int +get_max_site_mutations(tsk_treeseq_t *ts) +{ + int ret; + int max_mutations = 0; + size_t j; + tsk_site_t site; + + for (j = 0; j < tsk_treeseq_get_num_sites(ts); j++) { + ret = tsk_treeseq_get_site(ts, j, &site); + CU_ASSERT_EQUAL_FATAL(ret, 0); + max_mutations = TSK_MAX(max_mutations, site.mutations_length); + } + return max_mutations; +} + +static bool +multi_mutations_exist(tsk_treeseq_t *ts, size_t start, size_t end) +{ + int ret; + size_t j; + tsk_site_t site; + + for (j = 0; j < TSK_MIN(tsk_treeseq_get_num_sites(ts), end); j++) { + ret = tsk_treeseq_get_site(ts, j, &site); + CU_ASSERT_EQUAL_FATAL(ret, 0); + if (site.mutations_length > 1) { + return true; + } + } + return false; +} + +static void 
unsort_edges(tsk_edge_tbl_t *edges, size_t start)
{
    /* Randomly shuffle the rows of edges from index start onwards, using a
     * fixed GSL seed so the permutation is deterministic. */
    size_t j, k;
    size_t n = edges->num_rows - start;
    tsk_edge_t *buff = malloc(n * sizeof(tsk_edge_t));
    gsl_rng *rng = gsl_rng_alloc(gsl_rng_default);

    CU_ASSERT_FATAL(edges != NULL);
    CU_ASSERT_FATAL(rng != NULL);
    gsl_rng_set(rng, 1);

    /* Copy out the rows, shuffle the copy, then write them back. */
    for (j = 0; j < n; j++) {
        k = start + j;
        buff[j].left = edges->left[k];
        buff[j].right = edges->right[k];
        buff[j].parent = edges->parent[k];
        buff[j].child = edges->child[k];
    }
    gsl_ran_shuffle(rng, buff, n, sizeof(tsk_edge_t));
    for (j = 0; j < n; j++) {
        k = start + j;
        edges->left[k] = buff[j].left;
        edges->right[k] = buff[j].right;
        edges->parent[k] = buff[j].parent;
        edges->child[k] = buff[j].child;
    }
    free(buff);
    gsl_rng_free(rng);
}

/* Swap the first two rows of the site table (and retarget their mutations)
 * so the table violates sortedness; used to exercise sorting code. */
static void
unsort_sites(tsk_site_tbl_t *sites, tsk_mutation_tbl_t *mutations)
{
    double position;
    char *ancestral_state = NULL;
    size_t j, k, length;

    if (sites->num_rows > 1) {
        /* Swap the first two sites */
        CU_ASSERT_EQUAL_FATAL(sites->ancestral_state_offset[0], 0);

        position = sites->position[0];
        length = sites->ancestral_state_offset[1];
        /* Save a copy of the first ancestral state */
        ancestral_state = malloc(length);
        CU_ASSERT_FATAL(ancestral_state != NULL);
        memcpy(ancestral_state, sites->ancestral_state, length);
        /* Now write the ancestral state for the site 1 here */
        k = 0;
        for (j = sites->ancestral_state_offset[1]; j < sites->ancestral_state_offset[2];
                j++) {
            sites->ancestral_state[k] = sites->ancestral_state[j];
            k++;
        }
        sites->ancestral_state_offset[1] = k;
        memcpy(sites->ancestral_state + k, ancestral_state, length);
        sites->position[0] = sites->position[1];
        sites->position[1] = position;

        /* Update the mutations for these sites: mutations are grouped by
         * site, so flip the site IDs of the leading runs for sites 0/1. */
        j = 0;
        while (j < mutations->num_rows && mutations->site[j] == 0) {
            mutations->site[j] = 1;
            j++;
        }
        while (j < mutations->num_rows && mutations->site[j] == 1) {
            mutations->site[j] = 0;

            j++;
        }
    }
tsk_safe_free(ancestral_state); +} + +static void +add_individuals(tsk_treeseq_t *ts) +{ + int ret; + int max_inds = 20; + tsk_id_t j; + int k = 0; + int ploidy = 2; + tsk_tbl_collection_t tables; + char *metadata = "abc"; + size_t metadata_length = 3; + tsk_id_t *samples; + tsk_tbl_size_t num_samples = tsk_treeseq_get_num_samples(ts); + + ret = tsk_treeseq_get_samples(ts, &samples); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret = tsk_treeseq_copy_tables(ts, &tables); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + tsk_individual_tbl_clear(tables.individuals); + memset(tables.nodes->individual, 0xff, tables.nodes->num_rows * sizeof(tsk_id_t)); + + k = 0; + for (j = 0; j < num_samples; j++) { + if ((k % ploidy) == 0) { + tsk_individual_tbl_add_row(tables.individuals, (uint32_t) k, + NULL, 0, metadata, metadata_length); + CU_ASSERT_TRUE(ret >= 0) + } + tables.nodes->individual[samples[j]] = k / ploidy; + k += 1; + if (k >= ploidy * max_inds) { + break; + } + } + ret = tsk_treeseq_free(ts); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_treeseq_alloc(ts, &tables, TSK_BUILD_INDEXES); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tsk_tbl_collection_free(&tables); +} + +static void +verify_nodes_equal(tsk_node_t *n1, tsk_node_t *n2) +{ + double eps = 1e-6; + + CU_ASSERT_DOUBLE_EQUAL_FATAL(n1->time, n1->time, eps); + CU_ASSERT_EQUAL_FATAL(n1->population, n2->population); + CU_ASSERT_EQUAL_FATAL(n1->flags, n2->flags); + CU_ASSERT_FATAL(n1->metadata_length == n2->metadata_length); + CU_ASSERT_NSTRING_EQUAL_FATAL(n1->metadata, n2->metadata, n1->metadata_length); +} + +static void +verify_edges_equal(tsk_edge_t *r1, tsk_edge_t *r2, double scale) +{ + double eps = 1e-6; + + CU_ASSERT_DOUBLE_EQUAL_FATAL(r1->left * scale, r2->left, eps); + CU_ASSERT_DOUBLE_EQUAL_FATAL(r1->right * scale, r2->right, eps); + CU_ASSERT_EQUAL_FATAL(r1->parent, r2->parent); + CU_ASSERT_EQUAL_FATAL(r1->child, r2->child); +} + +static void +verify_migrations_equal(tsk_migration_t *r1, tsk_migration_t *r2, double scale) +{ + double 
eps = 1e-6;

    CU_ASSERT_DOUBLE_EQUAL_FATAL(r1->left * scale, r2->left, eps);
    CU_ASSERT_DOUBLE_EQUAL_FATAL(r1->right * scale, r2->right, eps);
    CU_ASSERT_DOUBLE_EQUAL_FATAL(r1->time, r2->time, eps);
    CU_ASSERT_EQUAL_FATAL(r1->node, r2->node);
    CU_ASSERT_EQUAL_FATAL(r1->source, r2->source);
    CU_ASSERT_EQUAL_FATAL(r1->dest, r2->dest);
}

/* Assert that two provenance records are identical. */
static void
verify_provenances_equal(tsk_provenance_t *p1, tsk_provenance_t *p2)
{
    CU_ASSERT_FATAL(p1->timestamp_length == p2->timestamp_length);
    CU_ASSERT_NSTRING_EQUAL_FATAL(p1->timestamp, p2->timestamp, p1->timestamp_length);
    CU_ASSERT_FATAL(p1->record_length == p2->record_length);
    CU_ASSERT_NSTRING_EQUAL_FATAL(p1->record, p2->record, p1->record_length);
}

/* Assert that two individual records are identical. */
static void
verify_individuals_equal(tsk_individual_t *i1, tsk_individual_t *i2)
{
    tsk_tbl_size_t j;

    CU_ASSERT_FATAL(i1->id == i2->id);
    CU_ASSERT_FATAL(i1->flags == i2->flags);
    CU_ASSERT_FATAL(i1->metadata_length == i2->metadata_length);
    CU_ASSERT_NSTRING_EQUAL_FATAL(i1->metadata, i2->metadata, i1->metadata_length);
    CU_ASSERT_FATAL(i1->location_length == i2->location_length);
    for (j = 0; j < i1->location_length; j++) {
        CU_ASSERT_EQUAL_FATAL(i1->location[j], i2->location[j]);
    }
}

/* Assert that two population records are identical. */
static void
verify_populations_equal(tsk_population_t *p1, tsk_population_t *p2)
{
    CU_ASSERT_FATAL(p1->id == p2->id);
    CU_ASSERT_FATAL(p1->metadata_length == p2->metadata_length);
    CU_ASSERT_NSTRING_EQUAL_FATAL(p1->metadata, p2->metadata, p1->metadata_length);
}

/* Return a malloced array holding a copy of every tree in ts, indexed by
 * tree index. The caller must free each tree and then the array. */
static tsk_tree_t *
get_tree_list(tsk_treeseq_t *ts)
{
    int ret;
    tsk_tree_t t, *trees;
    size_t num_trees;

    num_trees = tsk_treeseq_get_num_trees(ts);
    ret = tsk_tree_alloc(&t, ts, 0);
    CU_ASSERT_EQUAL_FATAL(ret, 0);
    trees = malloc(num_trees * sizeof(tsk_tree_t));
    CU_ASSERT_FATAL(trees != NULL);
    for (ret = tsk_tree_first(&t); ret == 1; ret = tsk_tree_next(&t)) {
        CU_ASSERT_FATAL(t.index < num_trees);
        ret = tsk_tree_alloc(&trees[t.index], ts, 0);
CU_ASSERT_EQUAL_FATAL(ret, 0);
        ret = tsk_tree_copy(&trees[t.index], &t);
        CU_ASSERT_EQUAL_FATAL(ret, 0);
        ret = tsk_tree_equal(&trees[t.index], &t);
        CU_ASSERT_EQUAL_FATAL(ret, 0);
        /* Make sure the left and right coordinates are also OK */
        CU_ASSERT_DOUBLE_EQUAL(trees[t.index].left, t.left, 1e-6);
        CU_ASSERT_DOUBLE_EQUAL(trees[t.index].right, t.right, 1e-6);
    }
    CU_ASSERT_EQUAL_FATAL(ret, 0);
    ret = tsk_tree_free(&t);
    CU_ASSERT_EQUAL_FATAL(ret, 0);
    return trees;
}

/* Exhaustively exercise forward/backward tree iteration (tsk_tree_first/
 * next/last/prev), checking every visited tree against the reference copies
 * from get_tree_list. */
static void
verify_tree_next_prev(tsk_treeseq_t *ts)
{
    int ret;
    tsk_tree_t *trees, t;
    size_t j;
    size_t num_trees = tsk_treeseq_get_num_trees(ts);

    trees = get_tree_list(ts);
    ret = tsk_tree_alloc(&t, ts, 0);
    CU_ASSERT_EQUAL_FATAL(ret, 0);

    /* Single forward pass */
    j = 0;
    for (ret = tsk_tree_first(&t); ret == 1; ret = tsk_tree_next(&t)) {
        CU_ASSERT_EQUAL_FATAL(j, t.index);
        ret = tsk_tree_equal(&t, &trees[t.index]);
        CU_ASSERT_EQUAL_FATAL(ret, 0);
        j++;
    }
    CU_ASSERT_EQUAL_FATAL(ret, 0);
    CU_ASSERT_EQUAL_FATAL(j, num_trees);

    /* Single reverse pass */
    j = num_trees;
    for (ret = tsk_tree_last(&t); ret == 1; ret = tsk_tree_prev(&t)) {
        CU_ASSERT_EQUAL_FATAL(j - 1, t.index);
        ret = tsk_tree_equal(&t, &trees[t.index]);
        if (ret != 0) {
            /* Diagnostic dump to help debug mismatched traversals. */
            printf("trees differ\n");
            printf("REVERSE tree::\n");
            tsk_tree_print_state(&t, stdout);
            printf("FORWARD tree::\n");
            tsk_tree_print_state(&trees[t.index], stdout);
        }
        CU_ASSERT_EQUAL_FATAL(ret, 0);
        j--;
    }
    CU_ASSERT_EQUAL_FATAL(ret, 0);
    CU_ASSERT_EQUAL_FATAL(j, 0);

    /* Full forward, then reverse */
    j = 0;
    for (ret = tsk_tree_first(&t); ret == 1; ret = tsk_tree_next(&t)) {
        CU_ASSERT_EQUAL_FATAL(j, t.index);
        ret = tsk_tree_equal(&t, &trees[t.index]);
        CU_ASSERT_EQUAL_FATAL(ret, 0);
        j++;
    }
    CU_ASSERT_EQUAL_FATAL(ret, 0);
    CU_ASSERT_EQUAL_FATAL(j, num_trees);
    j--;
    while ((ret = tsk_tree_prev(&t)) == 1) {
        CU_ASSERT_EQUAL_FATAL(j - 1, t.index);
        ret = tsk_tree_equal(&t, &trees[t.index]);
        CU_ASSERT_EQUAL_FATAL(ret, 0);
        j--;
    }
    CU_ASSERT_EQUAL_FATAL(ret, 0);
    CU_ASSERT_EQUAL_FATAL(j, 0);
    CU_ASSERT_EQUAL_FATAL(t.index, 0);
    /* Calling prev should return 0 and have no effect. */
    for (j = 0; j < 10; j++) {
        ret = tsk_tree_prev(&t);
        CU_ASSERT_EQUAL_FATAL(ret, 0);
        CU_ASSERT_EQUAL_FATAL(t.index, 0);
        ret = tsk_tree_equal(&t, &trees[t.index]);
        CU_ASSERT_EQUAL_FATAL(ret, 0);
    }

    /* Full reverse then forward */
    j = num_trees;
    for (ret = tsk_tree_last(&t); ret == 1; ret = tsk_tree_prev(&t)) {
        CU_ASSERT_EQUAL_FATAL(j - 1, t.index);
        ret = tsk_tree_equal(&t, &trees[t.index]);
        CU_ASSERT_EQUAL_FATAL(ret, 0);
        j--;
    }
    CU_ASSERT_EQUAL_FATAL(ret, 0);
    CU_ASSERT_EQUAL_FATAL(j, 0);
    j++;
    while ((ret = tsk_tree_next(&t)) == 1) {
        CU_ASSERT_EQUAL_FATAL(j, t.index);
        ret = tsk_tree_equal(&t, &trees[t.index]);
        CU_ASSERT_EQUAL_FATAL(ret, 0);
        j++;
    }
    CU_ASSERT_EQUAL_FATAL(ret, 0);
    CU_ASSERT_EQUAL_FATAL(j, num_trees);
    CU_ASSERT_EQUAL_FATAL(t.index, num_trees - 1);
    /* Calling next should return 0 and have no effect. */
    for (j = 0; j < 10; j++) {
        ret = tsk_tree_next(&t);
        CU_ASSERT_EQUAL_FATAL(ret, 0);
        CU_ASSERT_EQUAL_FATAL(t.index, num_trees - 1);
        ret = tsk_tree_equal(&t, &trees[t.index]);
        CU_ASSERT_EQUAL_FATAL(ret, 0);
    }

    /* Do a zigzagging traversal */
    ret = tsk_tree_first(&t);
    CU_ASSERT_EQUAL_FATAL(ret, 1);
    for (j = 1; j < TSK_MIN(10, num_trees / 2); j++) {
        while (t.index < num_trees - j) {
            ret = tsk_tree_next(&t);
            CU_ASSERT_EQUAL_FATAL(ret, 1);
        }
        CU_ASSERT_EQUAL_FATAL(t.index, num_trees - j);
        ret = tsk_tree_equal(&t, &trees[t.index]);
        CU_ASSERT_EQUAL_FATAL(ret, 0);
        while (t.index > j) {
            ret = tsk_tree_prev(&t);
            CU_ASSERT_EQUAL_FATAL(ret, 1);
        }
        CU_ASSERT_EQUAL_FATAL(t.index, j);
        ret = tsk_tree_equal(&t, &trees[t.index]);
        CU_ASSERT_EQUAL_FATAL(ret, 0);
    }

    /* Free the trees.
*/ + ret = tsk_tree_free(&t); + CU_ASSERT_EQUAL_FATAL(ret, 0); + for (j = 0; j < tsk_treeseq_get_num_trees(ts); j++) { + ret = tsk_tree_free(&trees[j]); + } + free(trees); +} + +static void +verify_hapgen(tsk_treeseq_t *ts) +{ + int ret; + tsk_hapgen_t hapgen; + char *haplotype; + size_t num_samples = tsk_treeseq_get_num_samples(ts); + size_t num_sites = tsk_treeseq_get_num_sites(ts); + tsk_site_t site; + size_t j; + int k; + bool single_char = true; + + for (j = 0; j < num_sites; j++) { + ret = tsk_treeseq_get_site(ts, j, &site); + CU_ASSERT_EQUAL_FATAL(ret, 0); + if (site.ancestral_state_length != 1) { + single_char = false; + } + for (k = 0; k < site.mutations_length; k++) { + if (site.mutations[k].derived_state_length != 1) { + single_char = false; + } + } + } + + ret = tsk_hapgen_alloc(&hapgen, ts); + if (single_char) { + CU_ASSERT_EQUAL_FATAL(ret, 0); + tsk_hapgen_print_state(&hapgen, _devnull); + + for (j = 0; j < num_samples; j++) { + ret = tsk_hapgen_get_haplotype(&hapgen, j, &haplotype); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_EQUAL(strlen(haplotype), num_sites); + } + for (j = num_samples; j < num_samples + 10; j++) { + ret = tsk_hapgen_get_haplotype(&hapgen, j, &haplotype); + CU_ASSERT_EQUAL(ret, TSK_ERR_OUT_OF_BOUNDS); + } + } else { + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_NON_SINGLE_CHAR_MUTATION); + } + ret = tsk_hapgen_free(&hapgen); + CU_ASSERT_EQUAL_FATAL(ret, 0); +} + +static void +verify_vargen(tsk_treeseq_t *ts) +{ + int ret; + tsk_vargen_t vargen; + size_t num_samples = tsk_treeseq_get_num_samples(ts); + size_t num_sites = tsk_treeseq_get_num_sites(ts); + tsk_variant_t *var; + size_t j, k, f, s; + int flags[] = {0, TSK_16_BIT_GENOTYPES}; + tsk_id_t *samples[] = {NULL, NULL}; + + ret = tsk_treeseq_get_samples(ts, samples); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + for (s = 0; s < 2; s++) { + for (f = 0; f < sizeof(flags) / sizeof(*flags); f++) { + ret = tsk_vargen_alloc(&vargen, ts, samples[s], num_samples, flags[f]); + CU_ASSERT_EQUAL_FATAL(ret, 0); + 
tsk_vargen_print_state(&vargen, _devnull); + j = 0; + while ((ret = tsk_vargen_next(&vargen, &var)) == 1) { + CU_ASSERT_EQUAL(var->site->id, j); + if (var->site->mutations_length == 0) { + CU_ASSERT_EQUAL(var->num_alleles, 1); + } else { + CU_ASSERT_TRUE(var->num_alleles > 1); + } + CU_ASSERT_EQUAL(var->allele_lengths[0], var->site->ancestral_state_length); + CU_ASSERT_NSTRING_EQUAL_FATAL(var->alleles[0], var->site->ancestral_state, + var->allele_lengths[0]); + for (k = 0; k < var->num_alleles; k++) { + CU_ASSERT_TRUE(var->allele_lengths[k] >= 0); + } + for (k = 0; k < num_samples; k++) { + if (flags[f] == TSK_16_BIT_GENOTYPES) { + CU_ASSERT(var->genotypes.u16[k] <= var->num_alleles); + } else { + CU_ASSERT(var->genotypes.u8[k] <= var->num_alleles); + } + } + j++; + } + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL(j, num_sites); + CU_ASSERT_EQUAL_FATAL(tsk_vargen_next(&vargen, &var), 0); + ret = tsk_vargen_free(&vargen); + CU_ASSERT_EQUAL_FATAL(ret, 0); + } + } +} + +static void +verify_stats(tsk_treeseq_t *ts) +{ + int ret; + uint32_t num_samples = tsk_treeseq_get_num_samples(ts); + tsk_id_t *samples; + uint32_t j; + double pi; + int max_site_mutations = get_max_site_mutations(ts); + + ret = tsk_treeseq_get_pairwise_diversity(ts, NULL, 0, &pi); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_PARAM_VALUE); + ret = tsk_treeseq_get_pairwise_diversity(ts, NULL, 1, &pi); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_PARAM_VALUE); + ret = tsk_treeseq_get_pairwise_diversity(ts, NULL, num_samples + 1, &pi); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_PARAM_VALUE); + + ret = tsk_treeseq_get_samples(ts, &samples); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + for (j = 2; j < num_samples; j++) { + ret = tsk_treeseq_get_pairwise_diversity(ts, samples, j, &pi); + if (max_site_mutations <= 1) { + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE_FATAL(pi >= 0); + } else { + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_ONLY_INFINITE_SITES); + } + } +} + +/* FIXME: this test is weak and should check the return 
value somehow. + * We should also have simplest and single tree tests along with separate + * tests for the error conditions. This should be done as part of the general + * stats framework. + */ +static void +verify_genealogical_nearest_neighbours(tsk_treeseq_t *ts) +{ + int ret; + tsk_id_t *samples; + tsk_id_t *sample_sets[2]; + size_t sample_set_size[2]; + size_t num_samples = tsk_treeseq_get_num_samples(ts); + double *A = malloc(2 * num_samples * sizeof(double)); + CU_ASSERT_FATAL(A != NULL); + + ret = tsk_treeseq_get_samples(ts, &samples); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + sample_sets[0] = samples; + sample_set_size[0] = num_samples / 2; + sample_sets[1] = samples + sample_set_size[0]; + sample_set_size[1] = num_samples - sample_set_size[0]; + + ret = tsk_treeseq_genealogical_nearest_neighbours(ts, + samples, num_samples, sample_sets, sample_set_size, 2, 0, A); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + free(A); +} + +/* FIXME: this test is weak and should check the return value somehow. + * We should also have simplest and single tree tests along with separate + * tests for the error conditions. This should be done as part of the general + * stats framework. 
+ */ +static void +verify_mean_descendants(tsk_treeseq_t *ts) +{ + int ret; + tsk_id_t *samples; + tsk_id_t *sample_sets[2]; + size_t sample_set_size[2]; + size_t num_samples = tsk_treeseq_get_num_samples(ts); + double *C = malloc(2 * tsk_treeseq_get_num_nodes(ts) * sizeof(double)); + CU_ASSERT_FATAL(C != NULL); + + ret = tsk_treeseq_get_samples(ts, &samples); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + sample_sets[0] = samples; + sample_set_size[0] = num_samples / 2; + sample_sets[1] = samples + sample_set_size[0]; + sample_set_size[1] = num_samples - sample_set_size[0]; + + ret = tsk_treeseq_mean_descendants(ts, sample_sets, sample_set_size, 2, 0, C); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + /* Check some error conditions */ + ret = tsk_treeseq_mean_descendants(ts, sample_sets, sample_set_size, 0, 0, C); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_PARAM_VALUE); + samples[0] = -1; + ret = tsk_treeseq_mean_descendants(ts, sample_sets, sample_set_size, 2, 0, C); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_OUT_OF_BOUNDS); + samples[0] = tsk_treeseq_get_num_nodes(ts) + 1; + ret = tsk_treeseq_mean_descendants(ts, sample_sets, sample_set_size, 2, 0, C); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_OUT_OF_BOUNDS); + + free(C); +} + + +static void +verify_compute_mutation_parents(tsk_treeseq_t *ts) +{ + int ret; + size_t size = tsk_treeseq_get_num_mutations(ts) * sizeof(tsk_id_t); + tsk_id_t *parent = malloc(size); + tsk_tbl_collection_t tables; + + CU_ASSERT_FATAL(parent != NULL); + ret = tsk_treeseq_copy_tables(ts, &tables); + CU_ASSERT_EQUAL_FATAL(ret, 0); + memcpy(parent, tables.mutations->parent, size); + /* tsk_tbl_collection_print_state(&tables, stdout); */ + /* Make sure the tables are actually updated */ + memset(tables.mutations->parent, 0xff, size); + + ret = tsk_tbl_collection_compute_mutation_parents(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(memcmp(parent, tables.mutations->parent, size), 0); + /* printf("after\n"); */ + /* 
tsk_tbl_collection_print_state(&tables, stdout); */ + + free(parent); + tsk_tbl_collection_free(&tables); +} + +static void +verify_individual_nodes(tsk_treeseq_t *ts) +{ + int ret; + tsk_individual_t individual; + tsk_id_t k; + size_t num_nodes = tsk_treeseq_get_num_nodes(ts); + size_t num_individuals = tsk_treeseq_get_num_individuals(ts); + size_t j; + + for (k = 0; k < (tsk_id_t) num_individuals; k++) { + ret = tsk_treeseq_get_individual(ts, k, &individual); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_FATAL(individual.nodes_length >= 0); + for (j = 0; j < individual.nodes_length; j++) { + CU_ASSERT_FATAL(individual.nodes[j] < num_nodes); + CU_ASSERT_EQUAL_FATAL(k, + ts->tables->nodes->individual[individual.nodes[j]]); + } + } +} + +/* When we keep all sites in simplify, the genotypes for the subset of the + * samples should be the same as the original */ +static void +verify_simplify_genotypes(tsk_treeseq_t *ts, tsk_treeseq_t *subset, + tsk_id_t *samples, uint32_t num_samples, tsk_id_t *node_map) +{ + int ret; + size_t m = tsk_treeseq_get_num_sites(ts); + tsk_vargen_t vargen, subset_vargen; + tsk_variant_t *variant, *subset_variant; + size_t j, k; + tsk_id_t *all_samples; + uint8_t a1, a2; + tsk_id_t *sample_index_map; + + tsk_treeseq_get_sample_index_map(ts, &sample_index_map); + + /* tsk_treeseq_print_state(ts, stdout); */ + /* tsk_treeseq_print_state(subset, stdout); */ + + ret = tsk_vargen_alloc(&vargen, ts, NULL, 0, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_vargen_alloc(&subset_vargen, subset, NULL, 0, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(m, tsk_treeseq_get_num_sites(subset)); + tsk_treeseq_get_samples(ts, &all_samples); + + for (j = 0; j < m; j++) { + ret = tsk_vargen_next(&vargen, &variant); + CU_ASSERT_EQUAL_FATAL(ret, 1); + ret = tsk_vargen_next(&subset_vargen, &subset_variant); + CU_ASSERT_EQUAL_FATAL(ret, 1); + CU_ASSERT_EQUAL(variant->site->id, j) + CU_ASSERT_EQUAL(subset_variant->site->id, j) + 
CU_ASSERT_EQUAL(variant->site->position, subset_variant->site->position); + for (k = 0; k < num_samples; k++) { + CU_ASSERT_FATAL(sample_index_map[samples[k]] < ts->num_samples); + a1 = variant->genotypes.u8[sample_index_map[samples[k]]]; + a2 = subset_variant->genotypes.u8[k]; + /* printf("a1 = %d, a2 = %d\n", a1, a2); */ + /* printf("k = %d original node = %d " */ + /* "original_index = %d a1=%.*s a2=%.*s\n", */ + /* (int) k, samples[k], sample_index_map[samples[k]], */ + /* variant->allele_lengths[a1], variant->alleles[a1], */ + /* subset_variant->allele_lengths[a2], subset_variant->alleles[a2]); */ + CU_ASSERT_FATAL(a1 < variant->num_alleles); + CU_ASSERT_FATAL(a2 < subset_variant->num_alleles); + CU_ASSERT_EQUAL_FATAL(variant->allele_lengths[a1], + subset_variant->allele_lengths[a2]); + CU_ASSERT_NSTRING_EQUAL_FATAL( + variant->alleles[a1], subset_variant->alleles[a2], + variant->allele_lengths[a1]); + } + } + tsk_vargen_free(&vargen); + tsk_vargen_free(&subset_vargen); +} + + +static void +verify_simplify_properties(tsk_treeseq_t *ts, tsk_treeseq_t *subset, + tsk_id_t *samples, uint32_t num_samples, tsk_id_t *node_map) +{ + int ret; + tsk_node_t n1, n2; + tsk_tree_t full_tree, subset_tree; + tsk_site_t *tree_sites; + tsk_tbl_size_t tree_sites_length; + uint32_t j, k; + tsk_id_t u, mrca1, mrca2; + size_t total_sites; + + CU_ASSERT_EQUAL( + tsk_treeseq_get_sequence_length(ts), + tsk_treeseq_get_sequence_length(subset)); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_samples(subset), num_samples); + CU_ASSERT( + tsk_treeseq_get_num_nodes(ts) >= tsk_treeseq_get_num_nodes(subset)); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_samples(subset), num_samples); + + /* Check the sample properties */ + for (j = 0; j < num_samples; j++) { + ret = tsk_treeseq_get_node(ts, samples[j], &n1); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL(node_map[samples[j]], j); + ret = tsk_treeseq_get_node(subset, node_map[samples[j]], &n2); + CU_ASSERT_EQUAL_FATAL(ret, 0); + 
CU_ASSERT_EQUAL_FATAL(n1.population, n2.population); + CU_ASSERT_EQUAL_FATAL(n1.time, n2.time); + CU_ASSERT_EQUAL_FATAL(n1.flags, n2.flags); + CU_ASSERT_EQUAL_FATAL(n1.metadata_length, n2.metadata_length); + CU_ASSERT_NSTRING_EQUAL(n1.metadata, n2.metadata, n2.metadata_length); + } + /* Check that node mappings are correct */ + for (j = 0; j < tsk_treeseq_get_num_nodes(ts); j++) { + ret = tsk_treeseq_get_node(ts, j, &n1); + CU_ASSERT_EQUAL_FATAL(ret, 0); + if (node_map[j] != TSK_NULL) { + ret = tsk_treeseq_get_node(subset, node_map[j], &n2); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(n1.population, n2.population); + CU_ASSERT_EQUAL_FATAL(n1.time, n2.time); + CU_ASSERT_EQUAL_FATAL(n1.flags, n2.flags); + CU_ASSERT_EQUAL_FATAL(n1.metadata_length, n2.metadata_length); + CU_ASSERT_NSTRING_EQUAL(n1.metadata, n2.metadata, n2.metadata_length); + } + } + if (num_samples == 0) { + CU_ASSERT_EQUAL(tsk_treeseq_get_num_edges(subset), 0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_nodes(subset), 0); + } else if (num_samples == 1) { + CU_ASSERT_EQUAL(tsk_treeseq_get_num_edges(subset), 0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_nodes(subset), 1); + } + /* Check the pairwise MRCAs */ + ret = tsk_tree_alloc(&full_tree, ts, 0); + CU_ASSERT_EQUAL(ret, 0); + ret = tsk_tree_alloc(&subset_tree, subset, 0); + CU_ASSERT_EQUAL(ret, 0); + ret = tsk_tree_first(&full_tree); + CU_ASSERT_EQUAL(ret, 1); + ret = tsk_tree_first(&subset_tree); + CU_ASSERT_EQUAL(ret, 1); + + total_sites = 0; + while (1) { + while (full_tree.right <= subset_tree.right) { + for (j = 0; j < num_samples; j++) { + for (k = j + 1; k < num_samples; k++) { + ret = tsk_tree_get_mrca(&full_tree, samples[j], samples[k], &mrca1); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_tree_get_mrca(&subset_tree, + node_map[samples[j]], node_map[samples[k]], &mrca2); + CU_ASSERT_EQUAL_FATAL(ret, 0); + if (mrca1 == TSK_NULL) { + CU_ASSERT_EQUAL_FATAL(mrca2, TSK_NULL); + } else { + CU_ASSERT_EQUAL(node_map[mrca1], mrca2); + } + } + } 
+ ret = tsk_tree_next(&full_tree); + CU_ASSERT_FATAL(ret >= 0); + if (ret != 1) { + break; + } + } + /* Check the sites in this tree */ + ret = tsk_tree_get_sites(&subset_tree, &tree_sites, &tree_sites_length); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + for (j = 0; j < tree_sites_length; j++) { + CU_ASSERT(subset_tree.left <= tree_sites[j].position); + CU_ASSERT(tree_sites[j].position < subset_tree.right); + for (k = 0; k < tree_sites[j].mutations_length; k++) { + ret = tsk_tree_get_parent(&subset_tree, + tree_sites[j].mutations[k].node, &u); + CU_ASSERT_EQUAL(ret, 0); + } + total_sites++; + } + ret = tsk_tree_next(&subset_tree); + if (ret != 1) { + break; + } + } + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_sites(subset), total_sites); + + tsk_tree_free(&subset_tree); + tsk_tree_free(&full_tree); + verify_vargen(subset); + verify_hapgen(subset); +} + +static void +verify_simplify(tsk_treeseq_t *ts) +{ + int ret; + uint32_t n = tsk_treeseq_get_num_samples(ts); + uint32_t num_samples[] = {0, 1, 2, 3, n / 2, n - 1, n}; + size_t j; + tsk_id_t *sample; + tsk_id_t *node_map = malloc(tsk_treeseq_get_num_nodes(ts) * sizeof(tsk_id_t)); + tsk_treeseq_t subset; + int flags = TSK_FILTER_SITES; + + CU_ASSERT_FATAL(node_map != NULL); + ret = tsk_treeseq_get_samples(ts, &sample); + CU_ASSERT_EQUAL_FATAL(ret, 0); + if (tsk_treeseq_get_num_migrations(ts) > 0) { + ret = tsk_treeseq_simplify(ts, sample, 2, 0, &subset, NULL); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_SIMPLIFY_MIGRATIONS_NOT_SUPPORTED); + /* Exiting early here because simplify isn't supported with migrations. 
*/ + goto out; + } + + for (j = 0; j < sizeof(num_samples) / sizeof(uint32_t); j++) { + if (num_samples[j] <= n) { + ret = tsk_treeseq_simplify(ts, sample, num_samples[j], flags, &subset, + node_map); + /* printf("ret = %s\n", tsk_strerror(ret)); */ + CU_ASSERT_EQUAL_FATAL(ret, 0); + verify_simplify_properties(ts, &subset, sample, num_samples[j], node_map); + tsk_treeseq_free(&subset); + + /* Keep all sites */ + ret = tsk_treeseq_simplify(ts, sample, num_samples[j], 0, &subset, + node_map); + CU_ASSERT_EQUAL_FATAL(ret, 0); + verify_simplify_properties(ts, &subset, sample, num_samples[j], node_map); + verify_simplify_genotypes(ts, &subset, sample, num_samples[j], node_map); + tsk_treeseq_free(&subset); + } + } +out: + free(node_map); +} + +static void +verify_reduce_topology(tsk_treeseq_t *ts) +{ + int ret; + size_t j; + tsk_id_t *sample; + tsk_treeseq_t reduced; + tsk_edge_t edge; + double *X; + size_t num_sites; + size_t n = tsk_treeseq_get_num_samples(ts); + int flags = TSK_REDUCE_TO_SITE_TOPOLOGY; + + ret = tsk_treeseq_get_samples(ts, &sample); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + if (tsk_treeseq_get_num_migrations(ts) > 0) { + ret = tsk_treeseq_simplify(ts, sample, 2, flags, &reduced, NULL); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_SIMPLIFY_MIGRATIONS_NOT_SUPPORTED); + return; + } + + ret = tsk_treeseq_simplify(ts, sample, n, flags, &reduced, NULL); + CU_ASSERT_EQUAL_FATAL(ret, 0); + X = reduced.tables->sites->position; + num_sites = reduced.tables->sites->num_rows; + if (num_sites == 0) { + CU_ASSERT_EQUAL_FATAL(tsk_treeseq_get_num_edges(&reduced), 0); + } + for (j = 0; j < tsk_treeseq_get_num_edges(&reduced); j++) { + ret = tsk_treeseq_get_edge(&reduced, j, &edge); + CU_ASSERT_EQUAL_FATAL(ret, 0); + if (edge.left != 0) { + CU_ASSERT_EQUAL_FATAL(edge.left, + X[tsk_search_sorted(X, num_sites, edge.left)]); + } + if (edge.right != tsk_treeseq_get_sequence_length(&reduced)) { + CU_ASSERT_EQUAL_FATAL(edge.right, + X[tsk_search_sorted(X, num_sites, edge.right)]); + } + 
} + tsk_treeseq_free(&reduced); +} + +/* Utility function to return a tree sequence for testing. It is the + * callers responsilibility to free all memory. + */ +static tsk_treeseq_t * +get_example_tree_sequence(uint32_t num_samples, + uint32_t num_historical_samples, uint32_t num_loci, + double sequence_length, double recombination_rate, + double mutation_rate, uint32_t num_bottlenecks, + bottleneck_desc_t *bottlenecks, int alphabet) +{ + return NULL; +} + +tsk_treeseq_t ** +get_example_nonbinary_tree_sequences(void) +{ + return NULL; +} + +tsk_treeseq_t * +make_recurrent_and_back_mutations_copy(tsk_treeseq_t *ts) +{ + return NULL; +} + +tsk_treeseq_t * +make_permuted_nodes_copy(tsk_treeseq_t *ts) +{ + return NULL; +} + +/* Insert some gaps into the specified tree sequence, i.e., positions + * that no edge covers. */ +tsk_treeseq_t * +make_gappy_copy(tsk_treeseq_t *ts) +{ + return NULL; +} + +/* Return a copy of the tree sequence after deleting half of its edges. + */ +tsk_treeseq_t * +make_decapitated_copy(tsk_treeseq_t *ts) +{ + return NULL; +} + +tsk_treeseq_t * +make_multichar_mutations_copy(tsk_treeseq_t *ts) +{ + return NULL; +} + +tsk_treeseq_t ** +get_example_tree_sequences(int include_nonbinary) +{ + size_t max_examples = 1024; + tsk_treeseq_t **ret = malloc(max_examples * sizeof(tsk_treeseq_t *)); + ret[0] = NULL; + return ret; +} + +static void +verify_vcf_converter(tsk_treeseq_t *ts, unsigned int ploidy) +{ + int ret; + char *str = NULL; + tsk_vcf_converter_t vc; + unsigned int num_variants; + + ret = tsk_vcf_converter_alloc(&vc, ts, ploidy, "chr1234"); + CU_ASSERT_FATAL(ret == 0); + tsk_vcf_converter_print_state(&vc, _devnull); + ret = tsk_vcf_converter_get_header(&vc, &str); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_NSTRING_EQUAL("##", str, 2); + num_variants = 0; + while ((ret = tsk_vcf_converter_next(&vc, &str)) == 1) { + CU_ASSERT_NSTRING_EQUAL("chr1234\t", str, 2); + num_variants++; + } + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_TRUE(num_variants == 
tsk_treeseq_get_num_mutations(ts)); + tsk_vcf_converter_free(&vc); +} + +static void +test_vcf(void) +{ + int ret; + unsigned int ploidy; + tsk_vcf_converter_t *vc = malloc(sizeof(tsk_vcf_converter_t)); + tsk_treeseq_t *ts = get_example_tree_sequence(10, 0, 100, 100.0, 1.0, 1.0, + 0, NULL, 0); + + CU_ASSERT_FATAL(ts != NULL); + CU_ASSERT_FATAL(vc != NULL); + + ret = tsk_vcf_converter_alloc(vc, ts, 0, "1"); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_PARAM_VALUE); + ret = tsk_vcf_converter_alloc(vc, ts, 3, "1"); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_PARAM_VALUE); + ret = tsk_vcf_converter_alloc(vc, ts, 11, "1"); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_PARAM_VALUE); + + for (ploidy = 1; ploidy < 3; ploidy++) { + verify_vcf_converter(ts, ploidy); + } + + free(vc); + tsk_treeseq_free(ts); + free(ts); +} + +static void +test_vcf_no_mutations(void) +{ + int ret; + char *str = NULL; + tsk_vcf_converter_t *vc = malloc(sizeof(tsk_vcf_converter_t)); + tsk_treeseq_t *ts = get_example_tree_sequence(100, 0, 1, 1.0, 0.0, 0.0, 0, NULL, + 0); + + CU_ASSERT_FATAL(ts != NULL); + CU_ASSERT_FATAL(vc != NULL); + CU_ASSERT_EQUAL_FATAL(tsk_treeseq_get_num_mutations(ts), 0); + + ret = tsk_vcf_converter_alloc(vc, ts, 1, "1"); + CU_ASSERT_FATAL(ret == 0); + tsk_vcf_converter_print_state(vc, _devnull); + ret = tsk_vcf_converter_get_header(vc, &str); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_NSTRING_EQUAL("##", str, 2); + ret = tsk_vcf_converter_next(vc, &str); + CU_ASSERT_EQUAL(ret, 0); + tsk_vcf_converter_free(vc); + + free(vc); + tsk_treeseq_free(ts); + free(ts); +} + +static void +test_node_metadata(void) +{ + const char *nodes = + "1 0 0 -1 n1\n" + "1 0 0 -1 n2\n" + "0 1 0 -1 A_much_longer_name\n" + "0 1 0 -1\n" + "0 1 0 -1 n4"; + const char *edges = + "0 1 2 0,1\n"; + tsk_treeseq_t ts; + int ret; + tsk_node_t node; + + tsk_treeseq_from_text(&ts, 1, nodes, edges, NULL, NULL, NULL, NULL, NULL); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_samples(&ts), 2); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_nodes(&ts), 5); + + 
ret = tsk_treeseq_get_node(&ts, 0, &node); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_NSTRING_EQUAL(node.metadata, "n1", 2); + + ret = tsk_treeseq_get_node(&ts, 1, &node); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_NSTRING_EQUAL(node.metadata, "n2", 2); + + ret = tsk_treeseq_get_node(&ts, 2, &node); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_NSTRING_EQUAL(node.metadata, "A_much_longer_name", 18); + + ret = tsk_treeseq_get_node(&ts, 3, &node); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_NSTRING_EQUAL(node.metadata, "", 0); + + ret = tsk_treeseq_get_node(&ts, 4, &node); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_NSTRING_EQUAL(node.metadata, "n4", 2); + + tsk_treeseq_free(&ts); +} + +static void +verify_trees_consistent(tsk_treeseq_t *ts) +{ + int ret; + size_t num_trees; + tsk_tree_t tree; + + ret = tsk_tree_alloc(&tree, ts, 0); + CU_ASSERT_EQUAL(ret, 0); + + num_trees = 0; + for (ret = tsk_tree_first(&tree); ret == 1; ret = tsk_tree_next(&tree)) { + tsk_tree_print_state(&tree, _devnull); + CU_ASSERT_EQUAL(tree.index, num_trees); + num_trees++; + } + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_trees(ts), num_trees); + + tsk_tree_free(&tree); +} + +static void +verify_ld(tsk_treeseq_t *ts) +{ + int ret; + size_t num_sites = tsk_treeseq_get_num_sites(ts); + tsk_site_t *sites = malloc(num_sites * sizeof(tsk_site_t)); + int *num_site_mutations = malloc(num_sites * sizeof(int)); + tsk_ld_calc_t ld_calc; + double *r2, *r2_prime, x; + size_t j, num_r2_values; + double eps = 1e-6; + + r2 = calloc(num_sites, sizeof(double)); + r2_prime = calloc(num_sites, sizeof(double)); + CU_ASSERT_FATAL(r2 != NULL); + CU_ASSERT_FATAL(r2_prime != NULL); + CU_ASSERT_FATAL(sites != NULL); + CU_ASSERT_FATAL(num_site_mutations != NULL); + + ret = tsk_ld_calc_alloc(&ld_calc, ts); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tsk_ld_calc_print_state(&ld_calc, _devnull); + + for (j = 0; j < num_sites; j++) { + ret = tsk_treeseq_get_site(ts, j, sites + j); + 
CU_ASSERT_EQUAL_FATAL(ret, 0); + num_site_mutations[j] = sites[j].mutations_length; + ret = tsk_ld_calc_get_r2(&ld_calc, j, j, &x); + if (num_site_mutations[j] <= 1) { + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_DOUBLE_EQUAL_FATAL(x, 1.0, eps); + } else { + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_ONLY_INFINITE_SITES); + } + } + + if (num_sites > 0) { + /* Some checks in the forward direction */ + ret = tsk_ld_calc_get_r2_array(&ld_calc, 0, TSK_DIR_FORWARD, + num_sites, DBL_MAX, r2, &num_r2_values); + if (multi_mutations_exist(ts, 0, num_sites)) { + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_ONLY_INFINITE_SITES); + } else { + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(num_r2_values, num_sites - 1); + } + tsk_ld_calc_print_state(&ld_calc, _devnull); + + ret = tsk_ld_calc_get_r2_array(&ld_calc, num_sites - 2, TSK_DIR_FORWARD, + num_sites, DBL_MAX, r2_prime, &num_r2_values); + if (multi_mutations_exist(ts, num_sites - 2, num_sites)) { + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_ONLY_INFINITE_SITES); + } else { + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(num_r2_values, 1); + } + tsk_ld_calc_print_state(&ld_calc, _devnull); + + ret = tsk_ld_calc_get_r2_array(&ld_calc, 0, TSK_DIR_FORWARD, + num_sites, DBL_MAX, r2_prime, &num_r2_values); + if (multi_mutations_exist(ts, 0, num_sites)) { + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_ONLY_INFINITE_SITES); + } else { + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(num_r2_values, num_sites - 1); + tsk_ld_calc_print_state(&ld_calc, _devnull); + for (j = 0; j < num_r2_values; j++) { + CU_ASSERT_EQUAL_FATAL(r2[j], r2_prime[j]); + ret = tsk_ld_calc_get_r2(&ld_calc, 0, j + 1, &x); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_DOUBLE_EQUAL_FATAL(r2[j], x, eps); + } + + } + + /* Some checks in the reverse direction */ + ret = tsk_ld_calc_get_r2_array(&ld_calc, num_sites - 1, + TSK_DIR_REVERSE, num_sites, DBL_MAX, + r2, &num_r2_values); + if (multi_mutations_exist(ts, 0, num_sites)) { + CU_ASSERT_EQUAL_FATAL(ret, 
TSK_ERR_ONLY_INFINITE_SITES); + } else { + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(num_r2_values, num_sites - 1); + } + tsk_ld_calc_print_state(&ld_calc, _devnull); + + ret = tsk_ld_calc_get_r2_array(&ld_calc, 1, TSK_DIR_REVERSE, + num_sites, DBL_MAX, r2_prime, &num_r2_values); + if (multi_mutations_exist(ts, 0, 1)) { + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_ONLY_INFINITE_SITES); + } else { + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(num_r2_values, 1); + } + tsk_ld_calc_print_state(&ld_calc, _devnull); + + ret = tsk_ld_calc_get_r2_array(&ld_calc, num_sites - 1, + TSK_DIR_REVERSE, num_sites, DBL_MAX, + r2_prime, &num_r2_values); + if (multi_mutations_exist(ts, 0, num_sites)) { + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_ONLY_INFINITE_SITES); + } else { + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(num_r2_values, num_sites - 1); + tsk_ld_calc_print_state(&ld_calc, _devnull); + + for (j = 0; j < num_r2_values; j++) { + CU_ASSERT_EQUAL_FATAL(r2[j], r2_prime[j]); + ret = tsk_ld_calc_get_r2(&ld_calc, num_sites - 1, + num_sites - j - 2, &x); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_DOUBLE_EQUAL_FATAL(r2[j], x, eps); + } + } + + /* Check some error conditions */ + ret = tsk_ld_calc_get_r2_array(&ld_calc, 0, 0, num_sites, DBL_MAX, + r2, &num_r2_values); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_PARAM_VALUE); + } + + if (num_sites > 3) { + /* Check for some basic distance calculations */ + j = num_sites / 2; + x = sites[j + 1].position - sites[j].position; + ret = tsk_ld_calc_get_r2_array(&ld_calc, j, TSK_DIR_FORWARD, num_sites, + x, r2, &num_r2_values); + if (multi_mutations_exist(ts, j, num_sites)) { + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_ONLY_INFINITE_SITES); + } else { + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(num_r2_values, 1); + } + + x = sites[j].position - sites[j - 1].position; + ret = tsk_ld_calc_get_r2_array(&ld_calc, j, TSK_DIR_REVERSE, num_sites, + x, r2, &num_r2_values); + if (multi_mutations_exist(ts, 0, j + 1)) { + 
CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_ONLY_INFINITE_SITES); + } else { + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(num_r2_values, 1); + } + } + + /* Check some error conditions */ + for (j = num_sites; j < num_sites + 2; j++) { + ret = tsk_ld_calc_get_r2_array(&ld_calc, j, TSK_DIR_FORWARD, + num_sites, DBL_MAX, r2, &num_r2_values); + CU_ASSERT_EQUAL(ret, TSK_ERR_OUT_OF_BOUNDS); + ret = tsk_ld_calc_get_r2(&ld_calc, j, 0, r2); + CU_ASSERT_EQUAL(ret, TSK_ERR_OUT_OF_BOUNDS); + ret = tsk_ld_calc_get_r2(&ld_calc, 0, j, r2); + CU_ASSERT_EQUAL(ret, TSK_ERR_OUT_OF_BOUNDS); + } + + tsk_ld_calc_free(&ld_calc); + free(r2); + free(r2_prime); + free(sites); + free(num_site_mutations); +} + +static void +verify_empty_tree_sequence(tsk_treeseq_t *ts, double sequence_length) +{ + CU_ASSERT_EQUAL(tsk_treeseq_get_num_edges(ts), 0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_mutations(ts), 0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_mutations(ts), 0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_migrations(ts), 0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_samples(ts), 0); + CU_ASSERT_EQUAL(tsk_treeseq_get_sequence_length(ts), sequence_length); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_trees(ts), 1); + verify_trees_consistent(ts); + verify_ld(ts); + verify_stats(ts); + verify_hapgen(ts); + verify_vargen(ts); + verify_vcf_converter(ts, 1); +} +static void +test_single_tree_newick(void) +{ + /* int ret; */ + /* tsk_treeseq_t ts; */ + /* tsk_tree_t t; */ + /* size_t buffer_size = 1024; */ + /* char newick[buffer_size]; */ + + /* tsk_treeseq_from_text(&ts, 1, single_tree_ex_nodes, single_tree_ex_edges, */ + /* NULL, NULL, NULL, NULL, NULL); */ + /* CU_ASSERT_EQUAL(tsk_treeseq_get_num_samples(&ts), 4); */ + /* CU_ASSERT_EQUAL(tsk_treeseq_get_num_trees(&ts), 1); */ + + /* ret = tsk_tree_alloc(&t, &ts, 0); */ + /* CU_ASSERT_EQUAL_FATAL(ret, 0) */ + /* ret = tsk_tree_first(&t); */ + /* CU_ASSERT_EQUAL_FATAL(ret, 1) */ + + + /* ret = tsk_tree_get_newick(&t, -1, 1, 0, buffer_size, newick); */ + /* 
CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS); */ + /* ret = tsk_tree_get_newick(&t, 7, 1, 0, buffer_size, newick); */ + /* CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS); */ + + /* ret = tsk_tree_get_newick(&t, 0, 0, 0, buffer_size, newick); */ + /* CU_ASSERT_EQUAL_FATAL(ret, 0); */ + /* /1* Seems odd, but this is what a single node newick tree looks like. */ + /* * Newick parsers seems to accept it in any case *1/ */ + /* CU_ASSERT_STRING_EQUAL(newick, "1;"); */ + + /* ret = tsk_tree_get_newick(&t, 4, 0, 0, buffer_size, newick); */ + /* CU_ASSERT_EQUAL_FATAL(ret, 0); */ + /* CU_ASSERT_STRING_EQUAL(newick, "(1:1,2:1);"); */ + + /* ret = tsk_tree_get_newick(&t, 6, 0, 0, buffer_size, newick); */ + /* CU_ASSERT_EQUAL_FATAL(ret, 0); */ + /* CU_ASSERT_STRING_EQUAL(newick, "((1:1,2:1):2,(3:2,4:2):1);"); */ + + /* tsk_tree_free(&t); */ + /* tsk_treeseq_free(&ts); */ +} + + +static void +verify_sample_sets_for_tree(tsk_tree_t *tree) +{ + int ret, stack_top, j; + tsk_id_t u, v, n, num_nodes, num_samples; + size_t tmp; + tsk_id_t *stack, *samples; + tsk_treeseq_t *ts = tree->tree_sequence; + tsk_id_t *sample_index_map = ts->sample_index_map; + const tsk_id_t *list_left = tree->left_sample; + const tsk_id_t *list_right = tree->right_sample; + const tsk_id_t *list_next = tree->next_sample; + tsk_id_t stop, sample_index; + + n = tsk_treeseq_get_num_samples(ts); + num_nodes = tsk_treeseq_get_num_nodes(ts); + stack = malloc(n * sizeof(tsk_id_t)); + samples = malloc(n * sizeof(tsk_id_t)); + CU_ASSERT_FATAL(stack != NULL); + CU_ASSERT_FATAL(samples != NULL); + for (u = 0; u < num_nodes; u++) { + if (tree->left_child[u] == TSK_NULL && !tsk_treeseq_is_sample(ts, u)) { + CU_ASSERT_EQUAL(list_left[u], TSK_NULL); + CU_ASSERT_EQUAL(list_right[u], TSK_NULL); + } else { + stack_top = 0; + num_samples = 0; + stack[stack_top] = u; + while (stack_top >= 0) { + v = stack[stack_top]; + stack_top--; + if (tsk_treeseq_is_sample(ts, v)) { + samples[num_samples] = v; + 
num_samples++; + } + for (v = tree->right_child[v]; v != TSK_NULL; v = tree->left_sib[v]) { + stack_top++; + stack[stack_top] = v; + } + } + ret = tsk_tree_get_num_samples(tree, u, &tmp); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(num_samples, tmp); + + j = 0; + sample_index = list_left[u]; + if (sample_index != TSK_NULL) { + stop = list_right[u]; + while (true) { + CU_ASSERT_TRUE_FATAL(j < n); + CU_ASSERT_EQUAL_FATAL(sample_index, sample_index_map[samples[j]]); + j++; + if (sample_index == stop) { + break; + } + sample_index = list_next[sample_index]; + } + } + CU_ASSERT_EQUAL_FATAL(j, num_samples); + } + } + free(stack); + free(samples); +} + +static void +verify_sample_sets(tsk_treeseq_t *ts) +{ + int ret; + tsk_tree_t t; + + ret = tsk_tree_alloc(&t, ts, TSK_SAMPLE_COUNTS|TSK_SAMPLE_LISTS); + CU_ASSERT_EQUAL(ret, 0); + + for (ret = tsk_tree_first(&t); ret == 1; ret = tsk_tree_next(&t)) { + verify_sample_sets_for_tree(&t); + } + CU_ASSERT_EQUAL_FATAL(ret, 0); + for (ret = tsk_tree_last(&t); ret == 1; ret = tsk_tree_prev(&t)) { + verify_sample_sets_for_tree(&t); + } + CU_ASSERT_EQUAL_FATAL(ret, 0); + + tsk_tree_free(&t); +} + +static void +verify_tree_equals(tsk_treeseq_t *ts) +{ + int ret; + tsk_tree_t *trees, t; + size_t j, k; + tsk_treeseq_t *other_ts = get_example_tree_sequence( + 10, 0, 100, 100.0, 1.0, 1.0, 0, NULL, 0); + int flags[] = {0, TSK_SAMPLE_LISTS, TSK_SAMPLE_COUNTS, + TSK_SAMPLE_LISTS | TSK_SAMPLE_COUNTS}; + + trees = get_tree_list(ts); + ret = tsk_tree_alloc(&t, other_ts, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + for (j = 0; j < tsk_treeseq_get_num_trees(ts); j++) { + ret = tsk_tree_equal(&t, &trees[j]); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_PARAM_VALUE); + for (k = 0; k < tsk_treeseq_get_num_trees(ts); k++) { + ret = tsk_tree_equal(&trees[j], &trees[k]); + if (j == k) { + CU_ASSERT_EQUAL_FATAL(ret, 0); + } else { + CU_ASSERT_EQUAL_FATAL(ret, 1); + } + } + } + ret = tsk_tree_free(&t); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + for (j = 0; j < 
sizeof(flags) / sizeof(int); j++) { + ret = tsk_tree_alloc(&t, ts, flags[j]); + CU_ASSERT_EQUAL_FATAL(ret, 0); + for (ret = tsk_tree_first(&t); ret == 1; + ret = tsk_tree_next(&t)) { + for (k = 0; k < tsk_treeseq_get_num_trees(ts); k++) { + ret = tsk_tree_equal(&t, &trees[k]); + if (t.index == k) { + CU_ASSERT_EQUAL_FATAL(ret, 0); + } else { + CU_ASSERT_EQUAL_FATAL(ret, 1); + } + } + } + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_tree_free(&t); + CU_ASSERT_EQUAL(ret, 0); + } + for (j = 0; j < tsk_treeseq_get_num_trees(ts); j++) { + ret = tsk_tree_free(&trees[j]); + } + free(trees); + tsk_treeseq_free(other_ts); + free(other_ts); +} + +static void +test_individual_nodes_from_examples(void) +{ + tsk_treeseq_t **examples = get_example_tree_sequences(1); + uint32_t j; + + CU_ASSERT_FATAL(examples != NULL); + for (j = 0; examples[j] != NULL; j++) { + verify_individual_nodes(examples[j]); + add_individuals(examples[j]); + verify_individual_nodes(examples[j]); + tsk_treeseq_free(examples[j]); + free(examples[j]); + } + free(examples); +} + +static void +test_diff_iter_from_examples(void) +{ + /* tsk_treeseq_t **examples = get_example_tree_sequences(1); */ + /* uint32_t j; */ + + /* CU_ASSERT_FATAL(examples != NULL); */ + /* for (j = 0; examples[j] != NULL; j++) { */ + /* verify_tree_diffs(examples[j]); */ + /* tsk_treeseq_free(examples[j]); */ + /* free(examples[j]); */ + /* } */ + /* free(examples); */ +} + +static void +test_tree_iter_from_examples(void) +{ + + tsk_treeseq_t **examples = get_example_tree_sequences(1); + uint32_t j; + + CU_ASSERT_FATAL(examples != NULL); + for (j = 0; examples[j] != NULL; j++) { + verify_trees_consistent(examples[j]); + tsk_treeseq_free(examples[j]); + free(examples[j]); + } + free(examples); +} + +static void +test_sample_sets_from_examples(void) +{ + tsk_treeseq_t **examples = get_example_tree_sequences(1); + uint32_t j; + + CU_ASSERT_FATAL(examples != NULL); + for (j = 0; examples[j] != NULL; j++) { + verify_sample_sets(examples[j]); 
+ tsk_treeseq_free(examples[j]); + free(examples[j]); + } + free(examples); +} + +static void +test_tree_equals_from_examples(void) +{ + tsk_treeseq_t **examples = get_example_tree_sequences(1); + uint32_t j; + + CU_ASSERT_FATAL(examples != NULL); + for (j = 0; examples[j] != NULL; j++) { + verify_tree_equals(examples[j]); + tsk_treeseq_free(examples[j]); + free(examples[j]); + } + free(examples); +} + +static void +test_next_prev_from_examples(void) +{ + + tsk_treeseq_t **examples = get_example_tree_sequences(1); + uint32_t j; + + CU_ASSERT_FATAL(examples != NULL); + for (j = 0; examples[j] != NULL; j++) { + verify_tree_next_prev(examples[j]); + tsk_treeseq_free(examples[j]); + free(examples[j]); + } + free(examples); +} + +static void +test_tsk_hapgen_from_examples(void) +{ + tsk_treeseq_t **examples = get_example_tree_sequences(1); + uint32_t j; + + CU_ASSERT_FATAL(examples != NULL); + for (j = 0; examples[j] != NULL; j++) { + verify_hapgen(examples[j]); + tsk_treeseq_free(examples[j]); + free(examples[j]); + } + free(examples); +} + +static void +test_ld_from_examples(void) +{ + tsk_treeseq_t **examples = get_example_tree_sequences(1); + uint32_t j; + + CU_ASSERT_FATAL(examples != NULL); + for (j = 0; examples[j] != NULL; j++) { + verify_ld(examples[j]); + tsk_treeseq_free(examples[j]); + free(examples[j]); + } + free(examples); +} + +static void +test_tsk_vargen_from_examples(void) +{ + tsk_treeseq_t **examples = get_example_tree_sequences(1); + uint32_t j; + + CU_ASSERT_FATAL(examples != NULL); + for (j = 0; examples[j] != NULL; j++) { + verify_vargen(examples[j]); + tsk_treeseq_free(examples[j]); + free(examples[j]); + } + free(examples); +} + +static void +test_stats_from_examples(void) +{ + tsk_treeseq_t **examples = get_example_tree_sequences(1); + uint32_t j; + + CU_ASSERT_FATAL(examples != NULL); + for (j = 0; examples[j] != NULL; j++) { + verify_stats(examples[j]); + tsk_treeseq_free(examples[j]); + free(examples[j]); + } + free(examples); +} + +static 
void +test_genealogical_nearest_neighbours_from_examples(void) +{ + tsk_treeseq_t **examples = get_example_tree_sequences(1); + uint32_t j; + + CU_ASSERT_FATAL(examples != NULL); + for (j = 0; examples[j] != NULL; j++) { + verify_genealogical_nearest_neighbours(examples[j]); + tsk_treeseq_free(examples[j]); + free(examples[j]); + } + free(examples); +} + +static void +test_mean_descendants_from_examples(void) +{ + tsk_treeseq_t **examples = get_example_tree_sequences(1); + uint32_t j; + + CU_ASSERT_FATAL(examples != NULL); + for (j = 0; examples[j] != NULL; j++) { + verify_mean_descendants(examples[j]); + tsk_treeseq_free(examples[j]); + free(examples[j]); + } + free(examples); +} + + + +static void +test_compute_mutation_parents_from_examples(void) +{ + tsk_treeseq_t **examples = get_example_tree_sequences(1); + uint32_t j; + + CU_ASSERT_FATAL(examples != NULL); + for (j = 0; examples[j] != NULL; j++) { + verify_compute_mutation_parents(examples[j]); + tsk_treeseq_free(examples[j]); + free(examples[j]); + } + free(examples); +} + +static void +verify_simplify_errors(tsk_treeseq_t *ts) +{ + int ret; + tsk_id_t *s; + tsk_id_t u; + tsk_treeseq_t subset; + tsk_id_t sample[2]; + + ret = tsk_treeseq_get_samples(ts, &s); + CU_ASSERT_EQUAL_FATAL(ret, 0); + memcpy(sample, s, 2 * sizeof(tsk_id_t)); + + for (u = 0; u < (tsk_id_t) tsk_treeseq_get_num_nodes(ts); u++) { + if (! 
tsk_treeseq_is_sample(ts, u)) { + sample[1] = u; + ret = tsk_treeseq_simplify(ts, sample, 2, 0, &subset, NULL); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_SAMPLES); + } + } + sample[0] = -1; + ret = tsk_treeseq_simplify(ts, sample, 2, 0, &subset, NULL); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS); + sample[0] = s[0]; + sample[1] = s[0]; + ret = tsk_treeseq_simplify(ts, sample, 2, 0, &subset, NULL); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_DUPLICATE_SAMPLE); +} + +static void +test_simplify_from_examples(void) +{ + tsk_treeseq_t **examples = get_example_tree_sequences(1); + uint32_t j; + + CU_ASSERT_FATAL(examples != NULL); + for (j = 0; examples[j] != NULL; j++) { + verify_simplify(examples[j]); + if (tsk_treeseq_get_num_migrations(examples[j]) == 0) { + /* Migrations are not supported at the moment, so skip these tests + * rather than complicate them */ + verify_simplify_errors(examples[j]); + } + tsk_treeseq_free(examples[j]); + free(examples[j]); + } + free(examples); +} + +static void +test_reduce_topology_from_examples(void) +{ + tsk_treeseq_t **examples = get_example_tree_sequences(1); + uint32_t j; + + CU_ASSERT_FATAL(examples != NULL); + for (j = 0; examples[j] != NULL; j++) { + verify_reduce_topology(examples[j]); + tsk_treeseq_free(examples[j]); + free(examples[j]); + } + free(examples); +} + +static void +verify_newick(tsk_treeseq_t *ts) +{ + /* int ret, err; */ + /* tsk_tree_t t; */ + /* tsk_id_t root; */ + /* size_t precision = 4; */ + /* size_t buffer_size = 1024 * 1024; */ + /* char *newick = malloc(buffer_size); */ + /* size_t j, size; */ + + /* CU_ASSERT_FATAL(newick != NULL); */ + + /* ret = tsk_tree_alloc(&t, ts, 0); */ + /* CU_ASSERT_EQUAL_FATAL(ret, 0); */ + /* ret = tsk_tree_first(&t); */ + /* CU_ASSERT_FATAL(ret == 1); */ + /* for (root = t.left_root; root != TSK_NULL; root = t.right_sib[root]) { */ + /* err = tsk_tree_get_newick(&t, root, precision, 0, buffer_size, newick); */ + /* CU_ASSERT_EQUAL_FATAL(err, 0); */ + /* size = 
strlen(newick); */ + /* CU_ASSERT_TRUE(size > 0); */ + /* CU_ASSERT_TRUE(size < buffer_size); */ + /* for (j = 0; j <= size; j++) { */ + /* err = tsk_tree_get_newick(&t, root, precision, 0, j, newick); */ + /* CU_ASSERT_EQUAL_FATAL(err, TSK_ERR_BUFFER_OVERFLOW); */ + /* } */ + /* err = tsk_tree_get_newick(&t, root, precision, 0, size + 1, newick); */ + /* CU_ASSERT_EQUAL_FATAL(err, 0); */ + /* } */ + + /* for (ret = tsk_tree_first(&t); ret == 1; ret = tsk_tree_next(&t)) { */ + /* for (root = t.left_root; root != TSK_NULL; root = t.right_sib[root]) { */ + /* err = tsk_tree_get_newick(&t, root, precision, 0, 0, NULL); */ + /* CU_ASSERT_EQUAL_FATAL(err, TSK_ERR_BAD_PARAM_VALUE); */ + /* err = tsk_tree_get_newick(&t, root, precision, 0, buffer_size, newick); */ + /* CU_ASSERT_EQUAL_FATAL(err, 0); */ + /* size = strlen(newick); */ + /* CU_ASSERT_EQUAL(newick[size - 1], ';'); */ + /* } */ + /* } */ + /* CU_ASSERT_EQUAL_FATAL(ret, 0); */ + + /* tsk_tree_free(&t); */ + /* free(newick); */ +} + +static void +test_newick_from_examples(void) +{ + tsk_treeseq_t **examples = get_example_tree_sequences(1); + uint32_t j; + + CU_ASSERT_FATAL(examples != NULL); + for (j = 0; examples[j] != NULL; j++) { + verify_newick(examples[j]); + tsk_treeseq_free(examples[j]); + free(examples[j]); + } + free(examples); +} + +static void +verify_tree_sequences_equal(tsk_treeseq_t *ts1, tsk_treeseq_t *ts2, + bool check_migrations, bool check_mutations, + bool check_provenance) +{ + int ret, err1, err2; + size_t j; + tsk_edge_t r1, r2; + tsk_node_t n1, n2; + tsk_migration_t m1, m2; + tsk_provenance_t p1, p2; + tsk_individual_t i1, i2; + tsk_population_t pop1, pop2; + size_t num_mutations = tsk_treeseq_get_num_mutations(ts1); + tsk_site_t site_1, site_2; + tsk_mutation_t mutation_1, mutation_2; + tsk_tree_t t1, t2; + + /* tsk_treeseq_print_state(ts1, stdout); */ + /* tsk_treeseq_print_state(ts2, stdout); */ + + CU_ASSERT_EQUAL( + tsk_treeseq_get_num_samples(ts1), + 
tsk_treeseq_get_num_samples(ts2)); + CU_ASSERT_EQUAL( + tsk_treeseq_get_sequence_length(ts1), + tsk_treeseq_get_sequence_length(ts2)); + CU_ASSERT_EQUAL( + tsk_treeseq_get_num_edges(ts1), + tsk_treeseq_get_num_edges(ts2)); + CU_ASSERT_EQUAL( + tsk_treeseq_get_num_nodes(ts1), + tsk_treeseq_get_num_nodes(ts2)); + CU_ASSERT_EQUAL( + tsk_treeseq_get_num_trees(ts1), + tsk_treeseq_get_num_trees(ts2)); + + for (j = 0; j < tsk_treeseq_get_num_nodes(ts1); j++) { + ret = tsk_treeseq_get_node(ts1, j, &n1); + CU_ASSERT_EQUAL(ret, 0); + ret = tsk_treeseq_get_node(ts2, j, &n2); + CU_ASSERT_EQUAL(ret, 0); + verify_nodes_equal(&n1, &n2); + } + for (j = 0; j < tsk_treeseq_get_num_edges(ts1); j++) { + ret = tsk_treeseq_get_edge(ts1, j, &r1); + CU_ASSERT_EQUAL(ret, 0); + ret = tsk_treeseq_get_edge(ts2, j, &r2); + CU_ASSERT_EQUAL(ret, 0); + verify_edges_equal(&r1, &r2, 1.0); + } + if (check_mutations) { + CU_ASSERT_EQUAL_FATAL( + tsk_treeseq_get_num_sites(ts1), + tsk_treeseq_get_num_sites(ts2)); + for (j = 0; j < tsk_treeseq_get_num_sites(ts1); j++) { + ret = tsk_treeseq_get_site(ts1, j, &site_1); + CU_ASSERT_EQUAL(ret, 0); + ret = tsk_treeseq_get_site(ts2, j, &site_2); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_EQUAL(site_1.position, site_2.position); + CU_ASSERT_EQUAL(site_1.ancestral_state_length, site_2.ancestral_state_length); + CU_ASSERT_NSTRING_EQUAL(site_1.ancestral_state, site_2.ancestral_state, + site_1.ancestral_state_length); + CU_ASSERT_EQUAL(site_1.metadata_length, site_2.metadata_length); + CU_ASSERT_NSTRING_EQUAL(site_1.metadata, site_2.metadata, + site_1.metadata_length); + } + CU_ASSERT_EQUAL_FATAL( + tsk_treeseq_get_num_mutations(ts1), + tsk_treeseq_get_num_mutations(ts2)); + for (j = 0; j < num_mutations; j++) { + ret = tsk_treeseq_get_mutation(ts1, j, &mutation_1); + CU_ASSERT_EQUAL(ret, 0); + ret = tsk_treeseq_get_mutation(ts2, j, &mutation_2); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_EQUAL(mutation_1.id, j); + CU_ASSERT_EQUAL(mutation_1.id, mutation_2.id); + 
CU_ASSERT_EQUAL(mutation_1.site, mutation_2.site); + CU_ASSERT_EQUAL(mutation_1.node, mutation_2.node); + CU_ASSERT_EQUAL_FATAL(mutation_1.parent, mutation_2.parent); + CU_ASSERT_EQUAL_FATAL(mutation_1.derived_state_length, + mutation_2.derived_state_length); + CU_ASSERT_NSTRING_EQUAL(mutation_1.derived_state, + mutation_2.derived_state, mutation_1.derived_state_length); + CU_ASSERT_EQUAL_FATAL(mutation_1.metadata_length, + mutation_2.metadata_length); + CU_ASSERT_NSTRING_EQUAL(mutation_1.metadata, + mutation_2.metadata, mutation_1.metadata_length); + } + } + if (check_migrations) { + CU_ASSERT_EQUAL_FATAL( + tsk_treeseq_get_num_migrations(ts1), + tsk_treeseq_get_num_migrations(ts2)); + for (j = 0; j < tsk_treeseq_get_num_migrations(ts1); j++) { + ret = tsk_treeseq_get_migration(ts1, j, &m1); + CU_ASSERT_EQUAL(ret, 0); + ret = tsk_treeseq_get_migration(ts2, j, &m2); + CU_ASSERT_EQUAL(ret, 0); + verify_migrations_equal(&m1, &m2, 1.0); + } + } + if (check_provenance) { + CU_ASSERT_EQUAL_FATAL( + tsk_treeseq_get_num_provenances(ts1), + tsk_treeseq_get_num_provenances(ts2)); + for (j = 0; j < tsk_treeseq_get_num_provenances(ts1); j++) { + ret = tsk_treeseq_get_provenance(ts1, j, &p1); + CU_ASSERT_EQUAL(ret, 0); + ret = tsk_treeseq_get_provenance(ts2, j, &p2); + CU_ASSERT_EQUAL(ret, 0); + verify_provenances_equal(&p1, &p2); + } + } + + CU_ASSERT_EQUAL_FATAL( + tsk_treeseq_get_num_individuals(ts1), + tsk_treeseq_get_num_individuals(ts2)); + for (j = 0; j < tsk_treeseq_get_num_individuals(ts1); j++) { + ret = tsk_treeseq_get_individual(ts1, j, &i1); + CU_ASSERT_EQUAL(ret, 0); + ret = tsk_treeseq_get_individual(ts2, j, &i2); + CU_ASSERT_EQUAL(ret, 0); + verify_individuals_equal(&i1, &i2); + } + + CU_ASSERT_EQUAL_FATAL( + tsk_treeseq_get_num_populations(ts1), + tsk_treeseq_get_num_populations(ts2)); + for (j = 0; j < tsk_treeseq_get_num_populations(ts1); j++) { + ret = tsk_treeseq_get_population(ts1, j, &pop1); + CU_ASSERT_EQUAL(ret, 0); + ret = 
tsk_treeseq_get_population(ts2, j, &pop2); + CU_ASSERT_EQUAL(ret, 0); + verify_populations_equal(&pop1, &pop2); + } + + ret = tsk_tree_alloc(&t1, ts1, 0); + CU_ASSERT_EQUAL(ret, 0); + ret = tsk_tree_alloc(&t2, ts2, 0); + CU_ASSERT_EQUAL(ret, 0); + ret = tsk_tree_first(&t1); + CU_ASSERT_EQUAL(ret, 1); + ret = tsk_tree_first(&t2); + CU_ASSERT_EQUAL(ret, 1); + while (1) { + err1 = tsk_tree_next(&t1); + err2 = tsk_tree_next(&t2); + CU_ASSERT_EQUAL_FATAL(err1, err2); + if (err1 != 1) { + break; + } + } + tsk_tree_free(&t1); + tsk_tree_free(&t2); +} + +static void +test_save_empty_kas(void) +{ + int ret; + tsk_treeseq_t ts1, ts2; + double sequence_length = 1234.00; + tsk_tbl_collection_t tables; + + ret = tsk_tbl_collection_alloc(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tables.sequence_length = sequence_length; + + ret = tsk_treeseq_alloc(&ts1, &tables, TSK_BUILD_INDEXES); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret = tsk_treeseq_dump(&ts1, _tmp_file_name, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + verify_empty_tree_sequence(&ts1, sequence_length); + ret = tsk_treeseq_load(&ts2, _tmp_file_name, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + verify_empty_tree_sequence(&ts2, sequence_length); + + tsk_treeseq_free(&ts1); + tsk_treeseq_free(&ts2); + tsk_tbl_collection_free(&tables); +} + +static void +test_save_kas(void) +{ + int ret; + size_t j, k; + tsk_treeseq_t **examples = get_example_tree_sequences(1); + tsk_treeseq_t ts2; + tsk_treeseq_t *ts1; + char *file_uuid; + int dump_flags[] = {0}; + + CU_ASSERT_FATAL(examples != NULL); + + for (j = 0; examples[j] != NULL; j++) { + ts1 = examples[j]; + file_uuid = tsk_treeseq_get_file_uuid(ts1); + CU_ASSERT_EQUAL_FATAL(file_uuid, NULL); + for (k = 0; k < sizeof(dump_flags) / sizeof(int); k++) { + ret = tsk_treeseq_dump(ts1, _tmp_file_name, dump_flags[k]); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_treeseq_load(&ts2, _tmp_file_name, TSK_LOAD_EXTENDED_CHECKS); + CU_ASSERT_EQUAL_FATAL(ret, 0); + verify_tree_sequences_equal(ts1, &ts2, 
true, true, true); + tsk_treeseq_print_state(&ts2, _devnull); + verify_hapgen(&ts2); + verify_vargen(&ts2); + file_uuid = tsk_treeseq_get_file_uuid(&ts2); + CU_ASSERT_NOT_EQUAL_FATAL(file_uuid, NULL); + CU_ASSERT_EQUAL(strlen(file_uuid), TSK_UUID_SIZE); + tsk_treeseq_free(&ts2); + } + tsk_treeseq_free(ts1); + free(ts1); + } + free(examples); +} + +static void +test_save_kas_tables(void) +{ + int ret; + size_t j, k; + tsk_treeseq_t **examples = get_example_tree_sequences(1); + tsk_treeseq_t *ts1; + tsk_tbl_collection_t t1, t2; + int dump_flags[] = {0}; + + CU_ASSERT_FATAL(examples != NULL); + + for (j = 0; examples[j] != NULL; j++) { + ts1 = examples[j]; + ret = tsk_tbl_collection_alloc(&t1, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_treeseq_copy_tables(ts1, &t1); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(t1.file_uuid, NULL); + for (k = 0; k < sizeof(dump_flags) / sizeof(int); k++) { + ret = tsk_tbl_collection_dump(&t1, _tmp_file_name, dump_flags[k]); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_tbl_collection_alloc(&t2, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_tbl_collection_load(&t2, _tmp_file_name, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_tbl_collection_equals(&t1, &t2)); + CU_ASSERT_EQUAL_FATAL(t1.file_uuid, NULL); + CU_ASSERT_NOT_EQUAL_FATAL(t2.file_uuid, NULL); + CU_ASSERT_EQUAL(strlen(t2.file_uuid), TSK_UUID_SIZE); + tsk_tbl_collection_free(&t2); + } + tsk_tbl_collection_free(&t1); + tsk_treeseq_free(ts1); + free(ts1); + } + free(examples); +} + +static void +test_sort_tables(void) +{ + int ret; + tsk_treeseq_t **examples = get_example_tree_sequences(1); + tsk_treeseq_t ts2; + tsk_treeseq_t *ts1; + size_t j, k, start, starts[3]; + tsk_tbl_collection_t tables; + int load_flags = TSK_BUILD_INDEXES; + tsk_id_t tmp_node; + + ret = tsk_tbl_collection_alloc(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_FATAL(examples != NULL); + + for (j = 0; examples[j] != NULL; j++) { + ts1 = examples[j]; + + ret = 
tsk_treeseq_copy_tables(ts1, &tables); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + /* Check the input validation */ + ret = tsk_tbl_collection_sort(NULL, 0, 0); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_PARAM_VALUE); + /* Check edge sorting */ + if (tables.edges->num_rows == 2) { + starts[0] = 0; + starts[1] = 0; + starts[2] = 0; + } else { + starts[0] = 0; + starts[1] = tables.edges->num_rows / 2; + starts[2] = tables.edges->num_rows - 2; + } + for (k = 0; k < 3; k++) { + start = starts[k]; + unsort_edges(tables.edges, start); + ret = tsk_treeseq_alloc(&ts2, &tables, load_flags); + CU_ASSERT_NOT_EQUAL_FATAL(ret, 0); + tsk_treeseq_free(&ts2); + + ret = tsk_tbl_collection_sort(&tables, 0, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret = tsk_treeseq_alloc(&ts2, &tables, load_flags); + CU_ASSERT_EQUAL_FATAL(ret, 0); + verify_tree_sequences_equal(ts1, &ts2, true, true, false); + tsk_treeseq_free(&ts2); + } + + /* A start value of num_tables.edges should have no effect */ + ret = tsk_tbl_collection_sort(&tables, tables.edges->num_rows, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_treeseq_alloc(&ts2, &tables, load_flags); + CU_ASSERT_EQUAL_FATAL(ret, 0); + verify_tree_sequences_equal(ts1, &ts2, true, true, false); + tsk_treeseq_free(&ts2); + + if (tables.sites->num_rows > 1) { + /* Check site sorting */ + unsort_sites(tables.sites, tables.mutations); + ret = tsk_treeseq_alloc(&ts2, &tables, load_flags); + CU_ASSERT_NOT_EQUAL(ret, 0); + tsk_treeseq_free(&ts2); + + ret = tsk_tbl_collection_sort(&tables, 0, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret = tsk_treeseq_alloc(&ts2, &tables, load_flags); + CU_ASSERT_EQUAL_FATAL(ret, 0); + verify_tree_sequences_equal(ts1, &ts2, true, true, false); + tsk_treeseq_free(&ts2); + + /* Check for site bounds error */ + tables.mutations->site[0] = tables.sites->num_rows; + ret = tsk_tbl_collection_sort(&tables, 0, 0); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_SITE_OUT_OF_BOUNDS); + tables.mutations->site[0] = 0; + ret = 
tsk_tbl_collection_sort(&tables, 0, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + /* Check for edge node bounds error */ + tmp_node = tables.edges->parent[0]; + tables.edges->parent[0] = tables.nodes->num_rows; + ret = tsk_tbl_collection_sort(&tables, 0, 0); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS); + tables.edges->parent[0] = tmp_node; + ret = tsk_tbl_collection_sort(&tables, 0, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + /* Check for mutation node bounds error */ + tmp_node = tables.mutations->node[0]; + tables.mutations->node[0] = tables.nodes->num_rows; + ret = tsk_tbl_collection_sort(&tables, 0, 0); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS); + tables.mutations->node[0] = tmp_node; + + /* Check for mutation parent bounds error */ + tables.mutations->parent[0] = tables.mutations->num_rows; + ret = tsk_tbl_collection_sort(&tables, 0, 0); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_MUTATION_OUT_OF_BOUNDS); + tables.mutations->parent[0] = TSK_NULL; + ret = tsk_tbl_collection_sort(&tables, 0, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + } + tsk_treeseq_free(ts1); + free(ts1); + } + free(examples); + tsk_tbl_collection_free(&tables); +} +static void +test_dump_tables(void) +{ + int ret; + tsk_treeseq_t **examples = get_example_tree_sequences(1); + tsk_treeseq_t ts2; + tsk_treeseq_t *ts1; + tsk_tbl_collection_t tables; + size_t j; + int load_flags = TSK_BUILD_INDEXES; + + ret = tsk_tbl_collection_alloc(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_FATAL(examples != NULL); + + for (j = 0; examples[j] != NULL; j++) { + ts1 = examples[j]; + + ret = tsk_treeseq_copy_tables(ts1, NULL); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_PARAM_VALUE); + + ret = tsk_treeseq_copy_tables(ts1, &tables); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_treeseq_alloc(&ts2, &tables, load_flags); + CU_ASSERT_EQUAL_FATAL(ret, 0); + verify_tree_sequences_equal(ts1, &ts2, true, true, true); + tsk_treeseq_print_state(&ts2, _devnull); + tsk_treeseq_free(&ts2); + 
tsk_treeseq_free(ts1); + free(ts1); + } + + free(examples); + tsk_tbl_collection_free(&tables); +} + +static void +test_dump_tables_kas(void) +{ + int ret; + size_t k; + tsk_treeseq_t *ts1, ts2, ts3, **examples; + tsk_tbl_collection_t tables; + int load_flags = TSK_BUILD_INDEXES; + + ret = tsk_tbl_collection_alloc(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + examples = get_example_tree_sequences(1); + for (k = 0; examples[k] != NULL; k++) { + ts1 = examples[k]; + CU_ASSERT_FATAL(ts1 != NULL); + ret = tsk_treeseq_copy_tables(ts1, &tables); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_treeseq_alloc(&ts2, &tables, load_flags); + ret = tsk_treeseq_dump(&ts2, _tmp_file_name, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_treeseq_load(&ts3, _tmp_file_name, TSK_LOAD_EXTENDED_CHECKS); + CU_ASSERT_EQUAL_FATAL(ret, 0); + verify_tree_sequences_equal(ts1, &ts3, true, true, true); + tsk_treeseq_print_state(&ts2, _devnull); + + tsk_treeseq_free(&ts2); + tsk_treeseq_free(&ts3); + tsk_treeseq_free(ts1); + free(ts1); + } + free(examples); + tsk_tbl_collection_free(&tables); +} + +void +test_tsk_tbl_collection_simplify_errors(void) +{ + int ret; + tsk_tbl_collection_t tables; + tsk_id_t samples[] = {0, 1}; + + + ret = tsk_tbl_collection_alloc(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tables.sequence_length = 1; + + ret = tsk_site_tbl_add_row(tables.sites, 0, "A", 1, NULL, 0); + CU_ASSERT_FATAL(ret >= 0); + ret = tsk_site_tbl_add_row(tables.sites, 0, "A", 1, NULL, 0); + CU_ASSERT_FATAL(ret >= 0); + ret = tsk_tbl_collection_simplify(&tables, samples, 0, 0, NULL); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_DUPLICATE_SITE_POSITION); + + /* Out of order positions */ + tables.sites->position[0] = 0.5; + ret = tsk_tbl_collection_simplify(&tables, samples, 0, 0, NULL); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_UNSORTED_SITES); + + /* Position out of bounds */ + tables.sites->position[0] = 1.5; + ret = tsk_tbl_collection_simplify(&tables, samples, 0, 0, NULL); + CU_ASSERT_EQUAL_FATAL(ret, 
TSK_ERR_BAD_SITE_POSITION); + + /* TODO More tests for this: see + * https://github.com/tskit-dev/msprime/issues/517 */ + + tsk_tbl_collection_free(&tables); + +} +void +test_tsk_tbl_collection_position_errors(void) +{ + int ret; + int j; + tsk_tbl_collection_t t1, t2; + tsk_tbl_collection_position_t pos1, pos2; + tsk_treeseq_t **examples = get_example_tree_sequences(1); + + CU_ASSERT_FATAL(examples != NULL); + for (j = 0; examples[j] != NULL; j++) { + // set-up + ret = tsk_tbl_collection_alloc(&t1, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_tbl_collection_alloc(&t2, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_treeseq_copy_tables(examples[j], &t1); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_tbl_collection_copy(&t1, &t2); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tsk_tbl_collection_record_position(&t1, &pos1); + + // for each table, add a new row to t2, bookmark that location, + // then try to reset t1 to this illegal location + + // individuals + tsk_individual_tbl_add_row(t2.individuals, 0, NULL, 0, NULL, 0); + tsk_tbl_collection_record_position(&t2, &pos2); + ret = tsk_tbl_collection_reset_position(&t1, &pos2); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_TABLE_POSITION); + ret = tsk_tbl_collection_reset_position(&t2, &pos1); + CU_ASSERT_EQUAL(ret, 0); + + // nodes + tsk_node_tbl_add_row(t2.nodes, 0, 1.2, 0, -1, NULL, 0); + tsk_tbl_collection_record_position(&t2, &pos2); + ret = tsk_tbl_collection_reset_position(&t1, &pos2); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_TABLE_POSITION); + ret = tsk_tbl_collection_reset_position(&t2, &pos1); + CU_ASSERT_EQUAL(ret, 0); + + // edges + tsk_edge_tbl_add_row(t2.edges, 0.1, 0.4, 0, 3); + tsk_tbl_collection_record_position(&t2, &pos2); + ret = tsk_tbl_collection_reset_position(&t1, &pos2); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_TABLE_POSITION); + ret = tsk_tbl_collection_reset_position(&t2, &pos1); + CU_ASSERT_EQUAL(ret, 0); + + // migrations + tsk_migration_tbl_add_row(t2.migrations, 0.1, 0.2, 2, 1, 2, 1.2); + 
tsk_tbl_collection_record_position(&t2, &pos2); + ret = tsk_tbl_collection_reset_position(&t1, &pos2); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_TABLE_POSITION); + ret = tsk_tbl_collection_reset_position(&t2, &pos1); + CU_ASSERT_EQUAL(ret, 0); + + // sites + tsk_site_tbl_add_row(t2.sites, 0.3, "A", 1, NULL, 0); + tsk_tbl_collection_record_position(&t2, &pos2); + ret = tsk_tbl_collection_reset_position(&t1, &pos2); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_TABLE_POSITION); + ret = tsk_tbl_collection_reset_position(&t2, &pos1); + CU_ASSERT_EQUAL(ret, 0); + + // mutations + tsk_mutation_tbl_add_row(t2.mutations, 0, 1, -1, "X", 1, NULL, 0); + tsk_tbl_collection_record_position(&t2, &pos2); + ret = tsk_tbl_collection_reset_position(&t1, &pos2); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_TABLE_POSITION); + ret = tsk_tbl_collection_reset_position(&t2, &pos1); + CU_ASSERT_EQUAL(ret, 0); + + // populations + tsk_population_tbl_add_row(t2.populations, NULL, 0); + tsk_tbl_collection_record_position(&t2, &pos2); + ret = tsk_tbl_collection_reset_position(&t1, &pos2); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_TABLE_POSITION); + ret = tsk_tbl_collection_reset_position(&t2, &pos1); + CU_ASSERT_EQUAL(ret, 0); + + // provenance + tsk_provenance_tbl_add_row(t2.provenances, "abc", 3, NULL, 0); + tsk_tbl_collection_record_position(&t2, &pos2); + ret = tsk_tbl_collection_reset_position(&t1, &pos2); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_TABLE_POSITION); + ret = tsk_tbl_collection_reset_position(&t2, &pos1); + CU_ASSERT_EQUAL(ret, 0); + + tsk_tbl_collection_free(&t1); + tsk_tbl_collection_free(&t2); + tsk_treeseq_free(examples[j]); + free(examples[j]); + } + free(examples); +} + +void +test_tsk_tbl_collection_position(void) +{ + int ret; + int j, k; + tsk_treeseq_t **examples; + tsk_tbl_collection_t t1, t2, t3; + tsk_tbl_collection_position_t pos1, pos2; + + examples = get_example_tree_sequences(1); + CU_ASSERT_FATAL(examples != NULL); + + for (j = 0; examples[j] != NULL; j++) { + ret = tsk_tbl_collection_alloc(&t1, 0); 
+ CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_tbl_collection_alloc(&t2, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_tbl_collection_alloc(&t3, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret = tsk_treeseq_copy_tables(examples[j], &t1); + + // bookmark at pos1 + tsk_tbl_collection_record_position(&t1, &pos1); + // copy to t2 + ret = tsk_tbl_collection_copy(&t1, &t2); + CU_ASSERT_EQUAL_FATAL(ret, 0); + // resetting position should do nothing + ret = tsk_tbl_collection_reset_position(&t2, &pos1); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_tbl_collection_equals(&t1, &t2)); + // add more rows to t2 + // (they don't have to make sense for this test) + for (k = 0; k < 3; k++) { + tsk_node_tbl_add_row(t2.nodes, 0, 1.2, 0, -1, NULL, 0); + tsk_node_tbl_add_row(t2.nodes, 0, 1.2, k, -1, NULL, 0); + tsk_edge_tbl_add_row(t2.edges, 0.1, 0.5, k, k+1); + tsk_edge_tbl_add_row(t2.edges, 0.3, 0.8, k, k+2); + } + // bookmark at pos2 + tsk_tbl_collection_record_position(&t2, &pos2); + // copy to t3 + ret = tsk_tbl_collection_copy(&t2, &t3); + CU_ASSERT_EQUAL_FATAL(ret, 0); + // add more rows to t3 + for (k = 0; k < 3; k++) { + tsk_node_tbl_add_row(t3.nodes, 0, 1.2, k+5, -1, NULL, 0); + tsk_site_tbl_add_row(t3.sites, 0.2, "A", 1, NULL, 0); + tsk_site_tbl_add_row(t3.sites, 0.2, "C", 1, NULL, 0); + tsk_mutation_tbl_add_row(t3.mutations, 0, k, -1, "T", 1, NULL, 0); + tsk_migration_tbl_add_row(t3.migrations, 0.0, 0.5, 1, 0, 1, 1.2); + tsk_individual_tbl_add_row(t3.individuals, k, NULL, 0, NULL, 0); + tsk_population_tbl_add_row(t3.populations, "X", 1); + tsk_provenance_tbl_add_row(t3.provenances, "abc", 3, NULL, 0); + } + // now resetting t3 to pos2 should equal t2 + ret = tsk_tbl_collection_reset_position(&t3, &pos2); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_tbl_collection_equals(&t2, &t3)); + // and resetting to pos1 should equal t1 + ret = tsk_tbl_collection_reset_position(&t3, &pos1); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_tbl_collection_equals(&t1, 
&t3)); + + ret = tsk_tbl_collection_clear(&t1); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL(t1.individuals->num_rows, 0); + CU_ASSERT_EQUAL(t1.populations->num_rows, 0); + CU_ASSERT_EQUAL(t1.nodes->num_rows, 0); + CU_ASSERT_EQUAL(t1.edges->num_rows, 0); + CU_ASSERT_EQUAL(t1.migrations->num_rows, 0); + CU_ASSERT_EQUAL(t1.sites->num_rows, 0); + CU_ASSERT_EQUAL(t1.mutations->num_rows, 0); + CU_ASSERT_EQUAL(t1.provenances->num_rows, 0); + + tsk_tbl_collection_free(&t1); + tsk_tbl_collection_free(&t2); + tsk_tbl_collection_free(&t3); + tsk_treeseq_free(examples[j]); + free(examples[j]); + } + free(examples); +} + +static int +msprime_suite_init(void) +{ + int fd; + static char template[] = "/tmp/tsk_c_test_XXXXXX"; + + _tmp_file_name = NULL; + _devnull = NULL; + + _tmp_file_name = malloc(sizeof(template)); + if (_tmp_file_name == NULL) { + return CUE_NOMEMORY; + } + strcpy(_tmp_file_name, template); + fd = mkstemp(_tmp_file_name); + if (fd == -1) { + return CUE_SINIT_FAILED; + } + close(fd); + _devnull = fopen("/dev/null", "w"); + if (_devnull == NULL) { + return CUE_SINIT_FAILED; + } + return CUE_SUCCESS; +} + +static int +msprime_suite_cleanup(void) +{ + if (_tmp_file_name != NULL) { + unlink(_tmp_file_name); + free(_tmp_file_name); + } + if (_devnull != NULL) { + fclose(_devnull); + } + return CUE_SUCCESS; +} + +static void +handle_cunit_error() +{ + fprintf(stderr, "CUnit error occured: %d: %s\n", + CU_get_error(), CU_get_error_msg()); + exit(EXIT_FAILURE); +} + +int +main(int argc, char **argv) +{ + int ret; + CU_pTest test; + CU_pSuite suite; + CU_TestInfo tests[] = { + {"test_vcf", test_vcf}, + {"test_vcf_no_mutations", test_vcf_no_mutations}, + {"test_node_metadata", test_node_metadata}, + + {"test_single_tree_newick", test_single_tree_newick}, + + + {"test_diff_iter_from_examples", test_diff_iter_from_examples}, + {"test_tree_iter_from_examples", test_tree_iter_from_examples}, + {"test_tree_equals_from_examples", test_tree_equals_from_examples}, + 
{"test_next_prev_from_examples", test_next_prev_from_examples}, + {"test_sample_sets_from_examples", test_sample_sets_from_examples}, + {"test_tsk_hapgen_from_examples", test_tsk_hapgen_from_examples}, + {"test_tsk_vargen_from_examples", test_tsk_vargen_from_examples}, + {"test_newick_from_examples", test_newick_from_examples}, + {"test_stats_from_examples", test_stats_from_examples}, + {"test_compute_mutation_parents_from_examples", + test_compute_mutation_parents_from_examples}, + {"test_individual_nodes_from_examples", + test_individual_nodes_from_examples}, + {"test_ld_from_examples", test_ld_from_examples}, + {"test_simplify_from_examples", test_simplify_from_examples}, + {"test_reduce_topology_from_examples", test_reduce_topology_from_examples}, + {"test_save_empty_kas", test_save_empty_kas}, + {"test_save_kas", test_save_kas}, + {"test_save_kas_tables", test_save_kas_tables}, + {"test_dump_tables", test_dump_tables}, + {"test_sort_tables", test_sort_tables}, + {"test_dump_tables_kas", test_dump_tables_kas}, + + {"test_tsk_tbl_collection_position", test_tsk_tbl_collection_position}, + {"test_tsk_tbl_collection_position_errors", test_tsk_tbl_collection_position_errors}, + + {"test_genealogical_nearest_neighbours_from_examples", + test_genealogical_nearest_neighbours_from_examples}, + {"test_mean_descendants_from_examples", test_mean_descendants_from_examples}, + CU_TEST_INFO_NULL, + }; + + /* We use initialisers here as the struct definitions change between + * versions of CUnit */ + CU_SuiteInfo suites[] = { + { + .pName = "msprime", + .pInitFunc = msprime_suite_init, + .pCleanupFunc = msprime_suite_cleanup, + .pTests = tests + }, + CU_SUITE_INFO_NULL, + }; + if (CUE_SUCCESS != CU_initialize_registry()) { + handle_cunit_error(); + } + if (CUE_SUCCESS != CU_register_suites(suites)) { + handle_cunit_error(); + } + CU_basic_set_mode(CU_BRM_VERBOSE); + + if (argc == 1) { + CU_basic_run_tests(); + } else if (argc == 2) { + suite = CU_get_suite_by_name("msprime", 
CU_get_registry()); + if (suite == NULL) { + printf("Suite not found\n"); + return EXIT_FAILURE; + } + test = CU_get_test_by_name(argv[1], suite); + if (test == NULL) { + printf("Test '%s' not found\n", argv[1]); + return EXIT_FAILURE; + } + CU_basic_run_test(suite, test); + } else { + printf("usage: ./tests <test_name>\n"); + return EXIT_FAILURE; + } + + ret = EXIT_SUCCESS; + if (CU_get_number_of_tests_failed() != 0) { + printf("Test failed!\n"); + ret = EXIT_FAILURE; + } + CU_cleanup_registry(); + return ret; +} diff --git a/c/test_core.c b/c/test_core.c new file mode 100644 index 0000000000..83ea7393eb --- /dev/null +++ b/c/test_core.c @@ -0,0 +1,67 @@ +#include "testlib.h" +#include "tsk_core.h" + +#include <string.h> + +static void +test_strerror(void) +{ + int j; + const char *msg; + int max_error_code = 1024; /* totally arbitrary */ + + for (j = 0; j < max_error_code; j++) { + msg = tsk_strerror(-j); + CU_ASSERT_FATAL(msg != NULL); + CU_ASSERT(strlen(msg) > 0); + } +} + +static void +test_strerror_kastore(void) +{ + int kastore_errors[] = {KAS_ERR_NO_MEMORY, KAS_ERR_IO, KAS_ERR_KEY_NOT_FOUND}; + size_t j; + int err; + + for (j = 0; j < sizeof(kastore_errors) / sizeof(*kastore_errors); j++) { + err = tsk_set_kas_error(kastore_errors[j]); + CU_ASSERT_TRUE(tsk_is_kas_error(err)); + CU_ASSERT_STRING_EQUAL(tsk_strerror(err), kas_strerror(kastore_errors[j])); + } +} + +static void +test_generate_uuid(void) +{ + size_t uuid_size = 36; + char uuid[uuid_size + 1]; + char other_uuid[uuid_size + 1]; + int ret; + + ret = tsk_generate_uuid(uuid, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(strlen(uuid), uuid_size); + CU_ASSERT_EQUAL(uuid[8], '-'); + CU_ASSERT_EQUAL(uuid[13], '-'); + CU_ASSERT_EQUAL(uuid[18], '-'); + CU_ASSERT_EQUAL(uuid[23], '-'); + + ret = tsk_generate_uuid(other_uuid, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(strlen(other_uuid), uuid_size); + CU_ASSERT_STRING_NOT_EQUAL(uuid, other_uuid); +} + +int +main(int argc, char **argv) +{ + 
CU_TestInfo tests[] = { + {"test_strerror", test_strerror}, + {"test_strerror_kastore", test_strerror_kastore}, + {"test_generate_uuid", test_generate_uuid}, + {NULL}, + }; + + return test_main(tests, argc, argv); +} diff --git a/c/test_genotypes.c b/c/test_genotypes.c new file mode 100644 index 0000000000..3b4ed09b87 --- /dev/null +++ b/c/test_genotypes.c @@ -0,0 +1,578 @@ +#include "testlib.h" +#include "tsk_genotypes.h" + +#include +#include + +static void +test_single_tree_hapgen_char_alphabet(void) +{ + int ret = 0; + const char *sites = + "0.0 A\n" + "0.1 A\n" + "0.2 C\n" + "0.4 A\n"; + const char *mutations = + "0 0 T\n" + "1 1 T\n" + "2 0 G\n" + "2 1 A\n" + "2 2 T\n" // A bunch of different sample mutations + "3 4 T\n" + "3 0 A\n"; // A back mutation from T -> A + uint32_t num_samples = 4; + tsk_treeseq_t ts; + char *haplotype; + size_t j; + tsk_hapgen_t hapgen; + + tsk_treeseq_from_text(&ts, 1, single_tree_ex_nodes, single_tree_ex_edges, NULL, + NULL, NULL, NULL, NULL); + ret = tsk_hapgen_alloc(&hapgen, &ts); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tsk_hapgen_print_state(&hapgen, _devnull); + for (j = 0; j < num_samples; j++) { + ret = tsk_hapgen_get_haplotype(&hapgen, (tsk_id_t) j, &haplotype); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_STRING_EQUAL(haplotype, ""); + } + ret = tsk_hapgen_free(&hapgen); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tsk_treeseq_free(&ts); + + tsk_treeseq_from_text(&ts, 1, single_tree_ex_nodes, single_tree_ex_edges, NULL, + sites, mutations, NULL, NULL); + ret = tsk_hapgen_alloc(&hapgen, &ts); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tsk_hapgen_print_state(&hapgen, _devnull); + + ret = tsk_hapgen_get_haplotype(&hapgen, 0, &haplotype); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_STRING_EQUAL(haplotype, "TAGA"); + ret = tsk_hapgen_get_haplotype(&hapgen, 1, &haplotype); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_STRING_EQUAL(haplotype, "ATAT"); + ret = tsk_hapgen_get_haplotype(&hapgen, 2, &haplotype); + CU_ASSERT_EQUAL(ret, 0); + 
CU_ASSERT_STRING_EQUAL(haplotype, "AATA"); + ret = tsk_hapgen_get_haplotype(&hapgen, 3, &haplotype); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_STRING_EQUAL(haplotype, "AACA"); + + ret = tsk_hapgen_free(&hapgen); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + tsk_treeseq_free(&ts); +} + +static void +test_single_tree_vargen_char_alphabet(void) +{ + int ret = 0; + const char *sites = + "0.0 A\n" + "0.1 A\n" + "0.2 C\n" + "0.4 A\n"; + const char *mutations = + "0 0 T -1\n" + "1 1 TTTAAGGG -1\n" + "2 0 G -1\n" + "2 1 AT -1\n" + "2 2 T -1\n" // A bunch of different sample mutations + "3 4 T -1\n" + "3 0 A 5\n"; // A back mutation from T -> A + tsk_treeseq_t ts; + tsk_vargen_t vargen; + tsk_variant_t *var; + + tsk_treeseq_from_text(&ts, 1, single_tree_ex_nodes, single_tree_ex_edges, NULL, + sites, mutations, NULL, NULL); + ret = tsk_vargen_alloc(&vargen, &ts, NULL, 0, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret = tsk_vargen_next(&vargen, &var); + CU_ASSERT_EQUAL_FATAL(ret, 1); + CU_ASSERT_EQUAL(var->site->position, 0.0); + CU_ASSERT_EQUAL(var->num_alleles, 2); + CU_ASSERT_EQUAL(var->allele_lengths[0], 1); + CU_ASSERT_EQUAL(var->allele_lengths[1], 1); + CU_ASSERT_NSTRING_EQUAL(var->alleles[0], "A", 1); + CU_ASSERT_NSTRING_EQUAL(var->alleles[1], "T", 1); + CU_ASSERT_EQUAL(var->genotypes.u8[0], 1); + CU_ASSERT_EQUAL(var->genotypes.u8[1], 0); + CU_ASSERT_EQUAL(var->genotypes.u8[2], 0); + CU_ASSERT_EQUAL(var->genotypes.u8[3], 0); + + ret = tsk_vargen_next(&vargen, &var); + CU_ASSERT_EQUAL_FATAL(ret, 1); + CU_ASSERT_EQUAL(var->site->position, 0.1); + CU_ASSERT_EQUAL(var->num_alleles, 2); + CU_ASSERT_EQUAL(var->allele_lengths[0], 1); + CU_ASSERT_EQUAL(var->allele_lengths[1], 8); + CU_ASSERT_NSTRING_EQUAL(var->alleles[0], "A", 1); + CU_ASSERT_NSTRING_EQUAL(var->alleles[1], "TTTAAGGG", 8); + CU_ASSERT_EQUAL(var->genotypes.u8[0], 0); + CU_ASSERT_EQUAL(var->genotypes.u8[1], 1); + CU_ASSERT_EQUAL(var->genotypes.u8[2], 0); + CU_ASSERT_EQUAL(var->genotypes.u8[3], 0); + + ret = 
tsk_vargen_next(&vargen, &var); + CU_ASSERT_EQUAL_FATAL(ret, 1); + CU_ASSERT_EQUAL(var->site->position, 0.2); + CU_ASSERT_EQUAL(var->num_alleles, 4); + CU_ASSERT_EQUAL(var->allele_lengths[0], 1); + CU_ASSERT_EQUAL(var->allele_lengths[1], 1); + CU_ASSERT_EQUAL(var->allele_lengths[2], 2); + CU_ASSERT_EQUAL(var->allele_lengths[3], 1); + CU_ASSERT_NSTRING_EQUAL(var->alleles[0], "C", 1); + CU_ASSERT_NSTRING_EQUAL(var->alleles[1], "G", 1); + CU_ASSERT_NSTRING_EQUAL(var->alleles[2], "AT", 2); + CU_ASSERT_NSTRING_EQUAL(var->alleles[3], "T", 1); + CU_ASSERT_EQUAL(var->genotypes.u8[0], 1); + CU_ASSERT_EQUAL(var->genotypes.u8[1], 2); + CU_ASSERT_EQUAL(var->genotypes.u8[2], 3); + CU_ASSERT_EQUAL(var->genotypes.u8[3], 0); + + ret = tsk_vargen_next(&vargen, &var); + CU_ASSERT_EQUAL_FATAL(ret, 1); + CU_ASSERT_EQUAL(var->site->position, 0.4); + CU_ASSERT_EQUAL(var->num_alleles, 2); + CU_ASSERT_EQUAL(var->allele_lengths[0], 1); + CU_ASSERT_EQUAL(var->allele_lengths[1], 1); + CU_ASSERT_NSTRING_EQUAL(var->alleles[0], "A", 1); + CU_ASSERT_NSTRING_EQUAL(var->alleles[1], "T", 1); + CU_ASSERT_EQUAL(var->genotypes.u8[0], 0); + CU_ASSERT_EQUAL(var->genotypes.u8[1], 1); + CU_ASSERT_EQUAL(var->genotypes.u8[2], 0); + CU_ASSERT_EQUAL(var->genotypes.u8[3], 0); + + ret = tsk_vargen_next(&vargen, &var); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + tsk_vargen_free(&vargen); + tsk_treeseq_free(&ts); +} + +static void +test_single_tree_hapgen_binary_alphabet(void) +{ + int ret = 0; + uint32_t num_samples = 4; + tsk_treeseq_t ts; + char *haplotype; + size_t j; + tsk_hapgen_t hapgen; + + tsk_treeseq_from_text(&ts, 1, single_tree_ex_nodes, single_tree_ex_edges, NULL, + NULL, NULL, NULL, NULL); + ret = tsk_hapgen_alloc(&hapgen, &ts); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tsk_hapgen_print_state(&hapgen, _devnull); + for (j = 0; j < num_samples; j++) { + ret = tsk_hapgen_get_haplotype(&hapgen, (tsk_id_t) j, &haplotype); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_STRING_EQUAL(haplotype, ""); + } + ret = 
tsk_hapgen_free(&hapgen); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tsk_treeseq_free(&ts); + + tsk_treeseq_from_text(&ts, 1, single_tree_ex_nodes, single_tree_ex_edges, NULL, + single_tree_ex_sites, single_tree_ex_mutations, NULL, NULL); + ret = tsk_hapgen_alloc(&hapgen, &ts); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tsk_hapgen_print_state(&hapgen, _devnull); + + ret = tsk_hapgen_get_haplotype(&hapgen, 0, &haplotype); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_STRING_EQUAL(haplotype, "001"); + ret = tsk_hapgen_get_haplotype(&hapgen, 1, &haplotype); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_STRING_EQUAL(haplotype, "011"); + ret = tsk_hapgen_get_haplotype(&hapgen, 2, &haplotype); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_STRING_EQUAL(haplotype, "101"); + ret = tsk_hapgen_get_haplotype(&hapgen, 3, &haplotype); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_STRING_EQUAL(haplotype, "001"); + + ret = tsk_hapgen_free(&hapgen); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + tsk_treeseq_free(&ts); +} + +static void +test_single_tree_vargen_binary_alphabet(void) +{ + int ret = 0; + tsk_treeseq_t ts; + tsk_vargen_t vargen; + tsk_variant_t *var; + + tsk_treeseq_from_text(&ts, 1, single_tree_ex_nodes, single_tree_ex_edges, NULL, + single_tree_ex_sites, single_tree_ex_mutations, NULL, NULL); + ret = tsk_vargen_alloc(&vargen, &ts, NULL, 0, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tsk_vargen_print_state(&vargen, _devnull); + + ret = tsk_vargen_next(&vargen, &var); + CU_ASSERT_EQUAL_FATAL(ret, 1); + CU_ASSERT_EQUAL(var->genotypes.u8[0], 0); + CU_ASSERT_EQUAL(var->genotypes.u8[1], 0); + CU_ASSERT_EQUAL(var->genotypes.u8[2], 1); + CU_ASSERT_EQUAL(var->genotypes.u8[3], 0); + CU_ASSERT_EQUAL(var->num_alleles, 2); + CU_ASSERT_NSTRING_EQUAL(var->alleles[0], "0", 1); + CU_ASSERT_NSTRING_EQUAL(var->alleles[1], "1", 1); + CU_ASSERT_EQUAL(var->site->id, 0); + CU_ASSERT_EQUAL(var->site->mutations_length, 1); + + ret = tsk_vargen_next(&vargen, &var); + CU_ASSERT_EQUAL_FATAL(ret, 1); + CU_ASSERT_EQUAL(var->genotypes.u8[0], 0); + 
CU_ASSERT_EQUAL(var->genotypes.u8[1], 1); + CU_ASSERT_EQUAL(var->genotypes.u8[2], 0); + CU_ASSERT_EQUAL(var->genotypes.u8[3], 0); + CU_ASSERT_EQUAL(var->num_alleles, 2); + CU_ASSERT_NSTRING_EQUAL(var->alleles[0], "0", 1); + CU_ASSERT_NSTRING_EQUAL(var->alleles[1], "1", 1); + CU_ASSERT_EQUAL(var->site->id, 1); + CU_ASSERT_EQUAL(var->site->mutations_length, 2); + + ret = tsk_vargen_next(&vargen, &var); + CU_ASSERT_EQUAL_FATAL(ret, 1); + CU_ASSERT_EQUAL(var->genotypes.u8[0], 1); + CU_ASSERT_EQUAL(var->genotypes.u8[1], 1); + CU_ASSERT_EQUAL(var->genotypes.u8[2], 1); + CU_ASSERT_EQUAL(var->genotypes.u8[3], 1); + CU_ASSERT_EQUAL(var->num_alleles, 2); + CU_ASSERT_NSTRING_EQUAL(var->alleles[0], "0", 1); + CU_ASSERT_NSTRING_EQUAL(var->alleles[1], "1", 1); + CU_ASSERT_EQUAL(var->site->id, 2); + CU_ASSERT_EQUAL(var->site->mutations_length, 4); + + ret = tsk_vargen_next(&vargen, &var); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret = tsk_vargen_free(&vargen); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tsk_treeseq_free(&ts); +} + +static void +test_single_tree_vargen_errors(void) +{ + int ret; + tsk_treeseq_t ts; + tsk_vargen_t vargen; + tsk_id_t samples[] = {0, 3}; + + tsk_treeseq_from_text(&ts, 1, single_tree_ex_nodes, single_tree_ex_edges, NULL, + single_tree_ex_sites, single_tree_ex_mutations, NULL, NULL); + ret = tsk_vargen_alloc(&vargen, &ts, samples, 2, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tsk_vargen_free(&vargen); + + samples[0] = -1; + ret = tsk_vargen_alloc(&vargen, &ts, samples, 2, 0); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_OUT_OF_BOUNDS); + tsk_vargen_free(&vargen); + + samples[0] = 7; + ret = tsk_vargen_alloc(&vargen, &ts, samples, 2, 0); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_OUT_OF_BOUNDS); + tsk_vargen_free(&vargen); + + samples[0] = 3; + ret = tsk_vargen_alloc(&vargen, &ts, samples, 2, 0); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_DUPLICATE_SAMPLE); + tsk_vargen_free(&vargen); + + tsk_treeseq_free(&ts); +} + +static void +test_single_tree_vargen_subsample(void) +{ + int ret = 0; + 
tsk_treeseq_t ts; + tsk_vargen_t vargen; + tsk_variant_t *var; + tsk_id_t samples[] = {0, 3}; + + tsk_treeseq_from_text(&ts, 1, single_tree_ex_nodes, single_tree_ex_edges, NULL, + single_tree_ex_sites, single_tree_ex_mutations, NULL, NULL); + ret = tsk_vargen_alloc(&vargen, &ts, samples, 2, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tsk_vargen_print_state(&vargen, _devnull); + + ret = tsk_vargen_next(&vargen, &var); + CU_ASSERT_EQUAL_FATAL(ret, 1); + CU_ASSERT_EQUAL(var->genotypes.u8[0], 0); + CU_ASSERT_EQUAL(var->genotypes.u8[1], 0); + CU_ASSERT_EQUAL(var->num_alleles, 2); + CU_ASSERT_NSTRING_EQUAL(var->alleles[0], "0", 1); + CU_ASSERT_NSTRING_EQUAL(var->alleles[1], "1", 1); + CU_ASSERT_EQUAL(var->site->id, 0); + CU_ASSERT_EQUAL(var->site->mutations_length, 1); + + ret = tsk_vargen_next(&vargen, &var); + CU_ASSERT_EQUAL_FATAL(ret, 1); + CU_ASSERT_EQUAL(var->genotypes.u8[0], 0); + CU_ASSERT_EQUAL(var->genotypes.u8[1], 0); + CU_ASSERT_EQUAL(var->num_alleles, 2); + CU_ASSERT_NSTRING_EQUAL(var->alleles[0], "0", 1); + CU_ASSERT_NSTRING_EQUAL(var->alleles[1], "1", 1); + CU_ASSERT_EQUAL(var->site->id, 1); + CU_ASSERT_EQUAL(var->site->mutations_length, 2); + + ret = tsk_vargen_next(&vargen, &var); + CU_ASSERT_EQUAL_FATAL(ret, 1); + CU_ASSERT_EQUAL(var->genotypes.u8[0], 1); + CU_ASSERT_EQUAL(var->genotypes.u8[1], 1); + CU_ASSERT_EQUAL(var->num_alleles, 2); + CU_ASSERT_NSTRING_EQUAL(var->alleles[0], "0", 1); + CU_ASSERT_NSTRING_EQUAL(var->alleles[1], "1", 1); + CU_ASSERT_EQUAL(var->site->id, 2); + CU_ASSERT_EQUAL(var->site->mutations_length, 4); + + ret = tsk_vargen_next(&vargen, &var); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret = tsk_vargen_free(&vargen); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + /* Zero samples */ + ret = tsk_vargen_alloc(&vargen, &ts, samples, 0, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tsk_vargen_print_state(&vargen, _devnull); + + ret = tsk_vargen_next(&vargen, &var); + CU_ASSERT_EQUAL_FATAL(ret, 1); + CU_ASSERT_EQUAL(var->num_alleles, 2); + 
CU_ASSERT_NSTRING_EQUAL(var->alleles[0], "0", 1); + CU_ASSERT_NSTRING_EQUAL(var->alleles[1], "1", 1); + CU_ASSERT_EQUAL(var->site->id, 0); + CU_ASSERT_EQUAL(var->site->mutations_length, 1); + + ret = tsk_vargen_next(&vargen, &var); + CU_ASSERT_EQUAL_FATAL(ret, 1); + CU_ASSERT_EQUAL(var->num_alleles, 2); + CU_ASSERT_NSTRING_EQUAL(var->alleles[0], "0", 1); + CU_ASSERT_NSTRING_EQUAL(var->alleles[1], "1", 1); + CU_ASSERT_EQUAL(var->site->id, 1); + CU_ASSERT_EQUAL(var->site->mutations_length, 2); + + ret = tsk_vargen_next(&vargen, &var); + CU_ASSERT_EQUAL_FATAL(ret, 1); + CU_ASSERT_EQUAL(var->num_alleles, 2); + CU_ASSERT_NSTRING_EQUAL(var->alleles[0], "0", 1); + CU_ASSERT_NSTRING_EQUAL(var->alleles[1], "1", 1); + CU_ASSERT_EQUAL(var->site->id, 2); + CU_ASSERT_EQUAL(var->site->mutations_length, 4); + + ret = tsk_vargen_next(&vargen, &var); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret = tsk_vargen_free(&vargen); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + tsk_treeseq_free(&ts); +} + +static void +test_single_tree_vargen_many_alleles(void) +{ + int ret = 0; + tsk_treeseq_t ts; + tsk_vargen_t vargen; + tsk_variant_t *var; + tsk_tbl_size_t num_alleles = 257; + tsk_id_t j, k, l; + int flags; + char alleles[num_alleles]; + tsk_tbl_collection_t tables; + + tsk_treeseq_from_text(&ts, 1, single_tree_ex_nodes, single_tree_ex_edges, NULL, + NULL, NULL, NULL, NULL); + ret = tsk_tbl_collection_alloc(&tables, 0); + CU_ASSERT_FATAL(ret == 0); + ret = tsk_treeseq_copy_tables(&ts, &tables); + CU_ASSERT_FATAL(ret == 0); + tsk_treeseq_free(&ts); + memset(alleles, 'X', (size_t) num_alleles); + ret = tsk_site_tbl_add_row(tables.sites, 0, "Y", 1, NULL, 0); + CU_ASSERT_FATAL(ret >= 0); + + /* Add j mutations over a single node. 
*/ + for (j = 0; j < (tsk_id_t) num_alleles; j++) { + /* When j = 0 we get a parent of -1, which is the NULL_NODE */ + ret = tsk_mutation_tbl_add_row(tables.mutations, 0, 0, j - 1, alleles, + (tsk_tbl_size_t) j, NULL, 0); + CU_ASSERT_FATAL(ret >= 0); + ret = tsk_treeseq_alloc(&ts, &tables, TSK_BUILD_INDEXES); + CU_ASSERT_EQUAL_FATAL(ret, 0); + for (l = 0; l < 2; l++) { + flags = 0; + if (l == 1) { + flags = TSK_16_BIT_GENOTYPES; + } + ret = tsk_vargen_alloc(&vargen, &ts, NULL, 0, flags); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tsk_vargen_print_state(&vargen, _devnull); + ret = tsk_vargen_next(&vargen, &var); + /* We have j + 2 alleles. So, if j >= 254, we should fail with 8bit + * genotypes */ + if (l == 0 && j >= 254) { + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_TOO_MANY_ALLELES); + } else { + CU_ASSERT_EQUAL_FATAL(ret, 1); + CU_ASSERT_NSTRING_EQUAL(var->alleles[0], "Y", 1); + for (k = 1; k < (tsk_id_t) var->num_alleles; k++) { + CU_ASSERT_EQUAL(k - 1, var->allele_lengths[k]); + CU_ASSERT_NSTRING_EQUAL(var->alleles[k], alleles, var->allele_lengths[k]); + } + CU_ASSERT_EQUAL(var->num_alleles, j + 2); + } + ret = tsk_vargen_free(&vargen); + CU_ASSERT_EQUAL_FATAL(ret, 0); + } + tsk_treeseq_free(&ts); + } + tsk_tbl_collection_free(&tables); +} + +static void +test_single_unary_tree_hapgen(void) +{ + int ret = 0; + const char *nodes = + "1 0 0\n" + "1 0 0\n" + "0 1 0\n" + "0 1 0\n" + "0 2 0\n" + "0 3 0\n" + "0 4 0\n"; + const char *edges = + "0 1 2 0\n" + "0 1 3 1\n" + "0 1 4 2,3\n" + "0 1 5 4\n" + "0 1 6 5\n"; + const char *sites = + "0 0\n" + "0.1 0\n" + "0.2 0\n" + "0.3 0\n"; + const char *mutations = + "0 0 1\n" + "1 1 1\n" + "2 4 1\n" + "3 5 1\n"; + tsk_treeseq_t ts; + size_t num_samples = 2; + size_t j; + tsk_hapgen_t hapgen; + char *haplotype; + + tsk_treeseq_from_text(&ts, 1, nodes, edges, NULL, NULL, NULL, NULL, NULL); + + ret = tsk_hapgen_alloc(&hapgen, &ts); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tsk_hapgen_print_state(&hapgen, _devnull); + for (j = 0; j < num_samples; 
j++) { + ret = tsk_hapgen_get_haplotype(&hapgen, (tsk_id_t) j, &haplotype); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_STRING_EQUAL(haplotype, ""); + } + ret = tsk_hapgen_free(&hapgen); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tsk_treeseq_free(&ts); + + tsk_treeseq_from_text(&ts, 1, nodes, edges, NULL, sites, mutations, NULL, NULL); + ret = tsk_hapgen_alloc(&hapgen, &ts); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tsk_hapgen_print_state(&hapgen, _devnull); + + ret = tsk_hapgen_get_haplotype(&hapgen, 0, &haplotype); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_STRING_EQUAL(haplotype, "1011"); + ret = tsk_hapgen_get_haplotype(&hapgen, 1, &haplotype); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_STRING_EQUAL(haplotype, "0111"); + + ret = tsk_hapgen_free(&hapgen); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + tsk_treeseq_free(&ts); +} + +static void +test_single_tree_inconsistent_mutations(void) +{ + const char *sites = + "0.0 0\n" + "0.1 0\n" + "0.2 0\n"; + const char *mutations = + "0 0 1\n" + "1 1 1\n" + "2 4 1\n" + "2 0 1\n"; + tsk_treeseq_t ts; + tsk_variant_t *var; + tsk_vargen_t vargen; + tsk_hapgen_t hapgen; + int flags[] = {0, TSK_16_BIT_GENOTYPES}; + tsk_id_t all_samples[] = {0, 1, 2, 3}; + tsk_id_t *samples[] = {NULL, all_samples}; + size_t num_samples = 4; + size_t s, f; + int ret; + + tsk_treeseq_from_text(&ts, 1, single_tree_ex_nodes, single_tree_ex_edges, NULL, + sites, mutations, NULL, NULL); + + ret = tsk_hapgen_alloc(&hapgen, &ts); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_INCONSISTENT_MUTATIONS); + ret = tsk_hapgen_free(&hapgen); + + for (s = 0; s < 2; s++) { + for (f = 0; f < sizeof(flags) / sizeof(*flags); f++) { + ret = tsk_vargen_alloc(&vargen, &ts, samples[s], num_samples, flags[f]); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_vargen_next(&vargen, &var); + CU_ASSERT_EQUAL_FATAL(ret, 1); + ret = tsk_vargen_next(&vargen, &var); + CU_ASSERT_EQUAL_FATAL(ret, 1); + ret = tsk_vargen_next(&vargen, &var); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_INCONSISTENT_MUTATIONS); + tsk_vargen_free(&vargen); + 
} + } + + tsk_treeseq_free(&ts); +} + +int +main(int argc, char **argv) +{ + CU_TestInfo tests[] = { + {"test_single_tree_hapgen_char_alphabet", test_single_tree_hapgen_char_alphabet}, + {"test_single_tree_hapgen_binary_alphabet", test_single_tree_hapgen_binary_alphabet}, + {"test_single_unary_tree_hapgen", test_single_unary_tree_hapgen}, + {"test_single_tree_vargen_char_alphabet", test_single_tree_vargen_char_alphabet}, + {"test_single_tree_vargen_binary_alphabet", test_single_tree_vargen_binary_alphabet}, + {"test_single_tree_vargen_errors", test_single_tree_vargen_errors}, + {"test_single_tree_vargen_subsample", test_single_tree_vargen_subsample}, + {"test_single_tree_vargen_many_alleles", test_single_tree_vargen_many_alleles}, + {"test_single_tree_inconsistent_mutations", test_single_tree_inconsistent_mutations}, + {NULL}, + }; + + return test_main(tests, argc, argv); +} diff --git a/c/test_tables.c b/c/test_tables.c new file mode 100644 index 0000000000..2eea907056 --- /dev/null +++ b/c/test_tables.c @@ -0,0 +1,1857 @@ +#include "testlib.h" +#include "tsk_tables.h" + +#include +#include + +typedef struct { + const char *name; + void *array; + tsk_tbl_size_t len; + int type; +} write_table_col_t; + +static void +write_table_cols(kastore_t *store, write_table_col_t *write_cols, size_t num_cols) +{ + size_t j; + int ret; + + for (j = 0; j < num_cols; j++) { + ret = kastore_puts(store, write_cols[j].name, write_cols[j].array, + write_cols[j].len, write_cols[j].type, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + } +} + +static void +test_format_data_load_errors(void) +{ + size_t uuid_size = 36; + char uuid[uuid_size]; + char format_name[TSK_FILE_FORMAT_NAME_LENGTH]; + double L[2]; + uint32_t version[2] = { + TSK_FILE_FORMAT_VERSION_MAJOR, TSK_FILE_FORMAT_VERSION_MINOR}; + write_table_col_t write_cols[] = { + {"format/name", (void *) format_name, sizeof(format_name), KAS_INT8}, + {"format/version", (void *) version, 2, KAS_UINT32}, + {"sequence_length", (void *) L, 1, 
KAS_FLOAT64}, + {"uuid", (void *) uuid, (tsk_tbl_size_t) uuid_size, KAS_INT8}, + }; + tsk_tbl_collection_t tables; + kastore_t store; + size_t j; + int ret; + + L[0] = 1; + L[1] = 0; + memcpy(format_name, TSK_FILE_FORMAT_NAME, sizeof(format_name)); + /* Note: this will fail if we ever start parsing the form of the UUID */ + memset(uuid, 0, uuid_size); + + ret = kastore_open(&store, _tmp_file_name, "w", 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + write_table_cols(&store, write_cols, sizeof(write_cols) / sizeof(*write_cols)); + ret = kastore_close(&store); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_tbl_collection_load(&tables, _tmp_file_name, 0); + /* We've only defined the format headers, so we should fail immediately + * after with key not found */ + CU_ASSERT_TRUE(tsk_is_kas_error(ret)); + CU_ASSERT_EQUAL_FATAL(ret ^ (1 << TSK_KAS_ERR_BIT), KAS_ERR_KEY_NOT_FOUND); + ret = tsk_tbl_collection_free(&tables); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + /* Version too old */ + version[0] = TSK_FILE_FORMAT_VERSION_MAJOR - 1; + ret = kastore_open(&store, _tmp_file_name, "w", 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + write_table_cols(&store, write_cols, sizeof(write_cols) / sizeof(*write_cols)); + ret = kastore_close(&store); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_tbl_collection_load(&tables, _tmp_file_name, 0); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_FILE_VERSION_TOO_OLD); + ret = tsk_tbl_collection_free(&tables); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + /* Version too new */ + version[0] = TSK_FILE_FORMAT_VERSION_MAJOR + 1; + ret = kastore_open(&store, _tmp_file_name, "w", 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + write_table_cols(&store, write_cols, sizeof(write_cols) / sizeof(*write_cols)); + ret = kastore_close(&store); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_tbl_collection_load(&tables, _tmp_file_name, 0); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_FILE_VERSION_TOO_NEW); + ret = tsk_tbl_collection_free(&tables); + CU_ASSERT_EQUAL_FATAL(ret, 0); + version[0] = 
TSK_FILE_FORMAT_VERSION_MAJOR; + + /* Bad version length */ + write_cols[1].len = 0; + ret = kastore_open(&store, _tmp_file_name, "w", 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + write_table_cols(&store, write_cols, sizeof(write_cols) / sizeof(*write_cols)); + ret = kastore_close(&store); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_tbl_collection_load(&tables, _tmp_file_name, 0); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_FILE_FORMAT); + ret = tsk_tbl_collection_free(&tables); + CU_ASSERT_EQUAL_FATAL(ret, 0); + write_cols[1].len = 2; + + /* Bad format name length */ + write_cols[0].len = 0; + ret = kastore_open(&store, _tmp_file_name, "w", 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + write_table_cols(&store, write_cols, sizeof(write_cols) / sizeof(*write_cols)); + ret = kastore_close(&store); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_tbl_collection_load(&tables, _tmp_file_name, 0); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_FILE_FORMAT); + ret = tsk_tbl_collection_free(&tables); + CU_ASSERT_EQUAL_FATAL(ret, 0); + write_cols[0].len = TSK_FILE_FORMAT_NAME_LENGTH; + + /* Bad format name */ + format_name[0] = 'X'; + ret = kastore_open(&store, _tmp_file_name, "w", 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + write_table_cols(&store, write_cols, sizeof(write_cols) / sizeof(*write_cols)); + ret = kastore_close(&store); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_tbl_collection_load(&tables, _tmp_file_name, 0); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_FILE_FORMAT); + ret = tsk_tbl_collection_free(&tables); + CU_ASSERT_EQUAL_FATAL(ret, 0); + format_name[0] = 't'; + + /* Bad type for sequence length. 
*/ + write_cols[2].type = KAS_FLOAT32; + ret = kastore_open(&store, _tmp_file_name, "w", 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + write_table_cols(&store, write_cols, sizeof(write_cols) / sizeof(*write_cols)); + ret = kastore_close(&store); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_tbl_collection_load(&tables, _tmp_file_name, 0); + CU_ASSERT_TRUE(tsk_is_kas_error(ret)); + CU_ASSERT_EQUAL_FATAL(ret ^ (1 << TSK_KAS_ERR_BIT), KAS_ERR_TYPE_MISMATCH); + ret = tsk_tbl_collection_free(&tables); + CU_ASSERT_EQUAL_FATAL(ret, 0); + write_cols[2].type = KAS_FLOAT64; + + /* Bad length for sequence length. */ + write_cols[2].len = 2; + ret = kastore_open(&store, _tmp_file_name, "w", 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + write_table_cols(&store, write_cols, sizeof(write_cols) / sizeof(*write_cols)); + ret = kastore_close(&store); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_tbl_collection_load(&tables, _tmp_file_name, 0); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_FILE_FORMAT); + ret = tsk_tbl_collection_free(&tables); + CU_ASSERT_EQUAL_FATAL(ret, 0); + write_cols[2].len = 1; + + /* Bad value for sequence length. 
*/ + L[0] = -1; + ret = kastore_open(&store, _tmp_file_name, "w", 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + write_table_cols(&store, write_cols, sizeof(write_cols) / sizeof(*write_cols)); + ret = kastore_close(&store); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_tbl_collection_load(&tables, _tmp_file_name, 0); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_SEQUENCE_LENGTH); + ret = tsk_tbl_collection_free(&tables); + CU_ASSERT_EQUAL_FATAL(ret, 0); + L[0] = 1; + + /* Wrong length for uuid */ + write_cols[3].len = 1; + ret = kastore_open(&store, _tmp_file_name, "w", 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + write_table_cols(&store, write_cols, sizeof(write_cols) / sizeof(*write_cols)); + ret = kastore_close(&store); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_tbl_collection_load(&tables, _tmp_file_name, 0); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_FILE_FORMAT); + ret = tsk_tbl_collection_free(&tables); + CU_ASSERT_EQUAL_FATAL(ret, 0); + write_cols[3].len = (tsk_tbl_size_t) uuid_size; + + /* Missing keys */ + for (j = 0; j < sizeof(write_cols) / sizeof(*write_cols) - 1; j++) { + ret = kastore_open(&store, _tmp_file_name, "w", 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + write_table_cols(&store, write_cols, j); + ret = kastore_close(&store); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_tbl_collection_load(&tables, _tmp_file_name, 0); + CU_ASSERT_TRUE(tsk_is_kas_error(ret)); + CU_ASSERT_EQUAL_FATAL(ret ^ (1 << TSK_KAS_ERR_BIT), KAS_ERR_KEY_NOT_FOUND); + CU_ASSERT_STRING_EQUAL(tsk_strerror(ret), kas_strerror(KAS_ERR_KEY_NOT_FOUND)); + ret = tsk_tbl_collection_free(&tables); + CU_ASSERT_EQUAL_FATAL(ret, 0); + } +} + +static void +test_dump_unindexed(void) +{ + tsk_tbl_collection_t tables, loaded; + int ret; + + ret = tsk_tbl_collection_alloc(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + tables.sequence_length = 1; + parse_nodes(single_tree_ex_nodes, tables.nodes); + CU_ASSERT_EQUAL_FATAL(tables.nodes->num_rows, 7); + parse_edges(single_tree_ex_edges, tables.edges); + 
CU_ASSERT_EQUAL_FATAL(tables.edges->num_rows, 6); + CU_ASSERT_FALSE(tsk_tbl_collection_is_indexed(&tables)); + ret = tsk_tbl_collection_dump(&tables, _tmp_file_name, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_tbl_collection_is_indexed(&tables)); + + ret = tsk_tbl_collection_load(&loaded, _tmp_file_name, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_tbl_collection_is_indexed(&loaded)); + CU_ASSERT_TRUE(tsk_node_tbl_equals(tables.nodes, loaded.nodes)); + CU_ASSERT_TRUE(tsk_edge_tbl_equals(tables.edges, loaded.edges)); + + tsk_tbl_collection_free(&loaded); + tsk_tbl_collection_free(&tables); +} + +static void +test_tbl_collection_load_errors(void) +{ + tsk_tbl_collection_t tables; + int ret; + const char *str; + + ret = tsk_tbl_collection_load(&tables, "/", 0); + CU_ASSERT_TRUE(tsk_is_kas_error(ret)); + CU_ASSERT_EQUAL_FATAL(ret ^ (1 << TSK_KAS_ERR_BIT), KAS_ERR_IO); + str = tsk_strerror(ret); + CU_ASSERT_TRUE(strlen(str) > 0); + + tsk_tbl_collection_free(&tables); +} + +static void +test_tbl_collection_dump_errors(void) +{ + tsk_tbl_collection_t tables; + int ret; + const char *str; + + ret = tsk_tbl_collection_alloc(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_tbl_collection_dump(&tables, "/", 0); + CU_ASSERT_TRUE(tsk_is_kas_error(ret)); + CU_ASSERT_EQUAL_FATAL(ret ^ (1 << TSK_KAS_ERR_BIT), KAS_ERR_IO); + str = tsk_strerror(ret); + CU_ASSERT_TRUE(strlen(str) > 0); + + tsk_tbl_collection_free(&tables); +} +static void +test_tbl_collection_simplify_errors(void) +{ + int ret; + tsk_tbl_collection_t tables; + tsk_id_t samples[] = {0, 1}; + + ret = tsk_tbl_collection_alloc(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tables.sequence_length = 1; + + ret = tsk_site_tbl_add_row(tables.sites, 0, "A", 1, NULL, 0); + CU_ASSERT_FATAL(ret >= 0); + ret = tsk_site_tbl_add_row(tables.sites, 0, "A", 1, NULL, 0); + CU_ASSERT_FATAL(ret >= 0); + ret = tsk_tbl_collection_simplify(&tables, samples, 0, 0, NULL); + CU_ASSERT_EQUAL_FATAL(ret, 
TSK_ERR_DUPLICATE_SITE_POSITION); + + /* Out of order positions */ + tables.sites->position[0] = 0.5; + ret = tsk_tbl_collection_simplify(&tables, samples, 0, 0, NULL); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_UNSORTED_SITES); + + /* Position out of bounds */ + tables.sites->position[0] = 1.5; + ret = tsk_tbl_collection_simplify(&tables, samples, 0, 0, NULL); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_SITE_POSITION); + + /* TODO More tests for this: see + * https://github.com/tskit-dev/msprime/issues/517 */ + + tsk_tbl_collection_free(&tables); +} + +static void +test_load_tsk_node_tbl_errors(void) +{ + char format_name[TSK_FILE_FORMAT_NAME_LENGTH]; + tsk_tbl_size_t uuid_size = 36; + char uuid[uuid_size]; + double L = 1; + double time = 0; + double flags = 0; + int32_t population = 0; + int32_t individual = 0; + int8_t metadata = 0; + uint32_t metadata_offset[] = {0, 1}; + uint32_t version[2] = { + TSK_FILE_FORMAT_VERSION_MAJOR, TSK_FILE_FORMAT_VERSION_MINOR}; + write_table_col_t write_cols[] = { + {"nodes/time", (void *) &time, 1, KAS_FLOAT64}, + {"nodes/flags", (void *) &flags, 1, KAS_UINT32}, + {"nodes/population", (void *) &population, 1, KAS_INT32}, + {"nodes/individual", (void *) &individual, 1, KAS_INT32}, + {"nodes/metadata", (void *) &metadata, 1, KAS_UINT8}, + {"nodes/metadata_offset", (void *) metadata_offset, 2, KAS_UINT32}, + {"format/name", (void *) format_name, sizeof(format_name), KAS_INT8}, + {"format/version", (void *) version, 2, KAS_UINT32}, + {"uuid", (void *) uuid, uuid_size, KAS_INT8}, + {"sequence_length", (void *) &L, 1, KAS_FLOAT64}, + }; + tsk_tbl_collection_t tables; + kastore_t store; + int ret; + + memcpy(format_name, TSK_FILE_FORMAT_NAME, sizeof(format_name)); + /* Note: this will fail if we ever start parsing the form of the UUID */ + memset(uuid, 0, uuid_size); + + ret = kastore_open(&store, _tmp_file_name, "w", 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + write_table_cols(&store, write_cols, sizeof(write_cols) / sizeof(*write_cols)); + ret = 
kastore_close(&store); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_tbl_collection_load(&tables, _tmp_file_name, 0); + /* We've only defined the format headers and nodes, so we should fail immediately + * after with key not found */ + CU_ASSERT_TRUE(tsk_is_kas_error(ret)); + CU_ASSERT_EQUAL_FATAL(ret ^ (1 << TSK_KAS_ERR_BIT), KAS_ERR_KEY_NOT_FOUND); + ret = tsk_tbl_collection_free(&tables); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + /* Wrong type for time */ + write_cols[0].type = KAS_INT64; + ret = kastore_open(&store, _tmp_file_name, "w", 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + write_table_cols(&store, write_cols, sizeof(write_cols) / sizeof(*write_cols)); + ret = kastore_close(&store); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_tbl_collection_load(&tables, _tmp_file_name, 0); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_FILE_FORMAT); + ret = tsk_tbl_collection_free(&tables); + CU_ASSERT_EQUAL_FATAL(ret, 0); + write_cols[0].type = KAS_FLOAT64; + + /* Wrong length for flags */ + write_cols[1].len = 0; + ret = kastore_open(&store, _tmp_file_name, "w", 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + write_table_cols(&store, write_cols, sizeof(write_cols) / sizeof(*write_cols)); + ret = kastore_close(&store); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_tbl_collection_load(&tables, _tmp_file_name, 0); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_FILE_FORMAT); + ret = tsk_tbl_collection_free(&tables); + CU_ASSERT_EQUAL_FATAL(ret, 0); + write_cols[1].len = 1; + + /* Missing key */ + ret = kastore_open(&store, _tmp_file_name, "w", 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + write_table_cols(&store, write_cols, sizeof(write_cols) / sizeof(*write_cols) - 1); + ret = kastore_close(&store); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_tbl_collection_load(&tables, _tmp_file_name, 0); + CU_ASSERT_TRUE(tsk_is_kas_error(ret)); + CU_ASSERT_EQUAL_FATAL(ret ^ (1 << TSK_KAS_ERR_BIT), KAS_ERR_KEY_NOT_FOUND); + ret = tsk_tbl_collection_free(&tables); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + /* Wrong length for metadata 
offset */ + write_cols[5].len = 1; + ret = kastore_open(&store, _tmp_file_name, "w", 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + write_table_cols(&store, write_cols, sizeof(write_cols) / sizeof(*write_cols)); + ret = kastore_close(&store); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_tbl_collection_load(&tables, _tmp_file_name, 0); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_FILE_FORMAT); + ret = tsk_tbl_collection_free(&tables); + CU_ASSERT_EQUAL_FATAL(ret, 0); + write_cols[5].len = 2; + +} + +static void +test_node_table(void) +{ + int ret; + tsk_node_tbl_t table; + tsk_node_t node; + uint32_t num_rows = 100; + uint32_t j; + uint32_t *flags; + tsk_id_t *population; + double *time; + tsk_id_t *individual; + char *metadata; + uint32_t *metadata_offset; + const char *test_metadata = "test"; + size_t test_metadata_length = 4; + char metadata_copy[test_metadata_length + 1]; + + metadata_copy[test_metadata_length] = '\0'; + ret = tsk_node_tbl_alloc(&table, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tsk_node_tbl_set_max_rows_increment(&table, 1); + tsk_node_tbl_set_max_metadata_length_increment(&table, 1); + tsk_node_tbl_print_state(&table, _devnull); + tsk_node_tbl_dump_text(&table, _devnull); + + for (j = 0; j < num_rows; j++) { + ret = tsk_node_tbl_add_row(&table, j, j, (tsk_id_t) j, (tsk_id_t) j, + test_metadata, test_metadata_length); + CU_ASSERT_EQUAL_FATAL(ret, j); + CU_ASSERT_EQUAL(table.flags[j], j); + CU_ASSERT_EQUAL(table.time[j], j); + CU_ASSERT_EQUAL(table.population[j], j); + CU_ASSERT_EQUAL(table.individual[j], j); + CU_ASSERT_EQUAL(table.num_rows, j + 1); + CU_ASSERT_EQUAL(table.metadata_length, (j + 1) * test_metadata_length); + CU_ASSERT_EQUAL(table.metadata_offset[j + 1], table.metadata_length); + /* check the metadata */ + memcpy(metadata_copy, table.metadata + table.metadata_offset[j], test_metadata_length); + CU_ASSERT_NSTRING_EQUAL(metadata_copy, test_metadata, test_metadata_length); + ret = tsk_node_tbl_get_row(&table, j, &node); + CU_ASSERT_EQUAL_FATAL(ret, 0); 
+ CU_ASSERT_EQUAL(node.id, j); + CU_ASSERT_EQUAL(node.flags, j); + CU_ASSERT_EQUAL(node.time, j); + CU_ASSERT_EQUAL(node.population, j); + CU_ASSERT_EQUAL(node.individual, j); + CU_ASSERT_EQUAL(node.metadata_length, test_metadata_length); + CU_ASSERT_NSTRING_EQUAL(node.metadata, test_metadata, test_metadata_length); + } + CU_ASSERT_EQUAL(tsk_node_tbl_get_row(&table, num_rows, &node), + TSK_ERR_NODE_OUT_OF_BOUNDS); + tsk_node_tbl_print_state(&table, _devnull); + tsk_node_tbl_dump_text(&table, _devnull); + + tsk_node_tbl_clear(&table); + CU_ASSERT_EQUAL(table.num_rows, 0); + CU_ASSERT_EQUAL(table.metadata_length, 0); + + num_rows *= 2; + flags = malloc(num_rows * sizeof(uint32_t)); + CU_ASSERT_FATAL(flags != NULL); + memset(flags, 1, num_rows * sizeof(uint32_t)); + population = malloc(num_rows * sizeof(uint32_t)); + CU_ASSERT_FATAL(population != NULL); + memset(population, 2, num_rows * sizeof(uint32_t)); + time = malloc(num_rows * sizeof(double)); + CU_ASSERT_FATAL(time != NULL); + memset(time, 0, num_rows * sizeof(double)); + individual = malloc(num_rows * sizeof(uint32_t)); + CU_ASSERT_FATAL(individual != NULL); + memset(individual, 3, num_rows * sizeof(uint32_t)); + metadata = malloc(num_rows * sizeof(char)); + CU_ASSERT_FATAL(metadata != NULL); + memset(metadata, 'a', num_rows * sizeof(char)); + metadata_offset = malloc((num_rows + 1) * sizeof(tsk_tbl_size_t)); + CU_ASSERT_FATAL(metadata_offset != NULL); + for (j = 0; j < num_rows + 1; j++) { + metadata_offset[j] = j; + } + ret = tsk_node_tbl_set_columns(&table, num_rows, flags, time, population, + individual, metadata, metadata_offset); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_EQUAL(memcmp(table.flags, flags, num_rows * sizeof(uint32_t)), 0); + CU_ASSERT_EQUAL(memcmp(table.population, population, num_rows * sizeof(uint32_t)), 0); + CU_ASSERT_EQUAL(memcmp(table.time, time, num_rows * sizeof(double)), 0); + CU_ASSERT_EQUAL(memcmp(table.individual, individual, num_rows * sizeof(uint32_t)), 0); + 
CU_ASSERT_EQUAL(memcmp(table.metadata, metadata, num_rows * sizeof(char)), 0); + CU_ASSERT_EQUAL(memcmp(table.metadata_offset, metadata_offset, + (num_rows + 1) * sizeof(tsk_tbl_size_t)), 0); + CU_ASSERT_EQUAL(table.num_rows, num_rows); + CU_ASSERT_EQUAL(table.metadata_length, num_rows); + tsk_node_tbl_print_state(&table, _devnull); + tsk_node_tbl_dump_text(&table, _devnull); + + /* Append another num_rows onto the end */ + ret = tsk_node_tbl_append_columns(&table, num_rows, flags, time, population, + individual, metadata, metadata_offset); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_EQUAL(memcmp(table.flags, flags, num_rows * sizeof(uint32_t)), 0); + CU_ASSERT_EQUAL(memcmp(table.flags + num_rows, flags, num_rows * sizeof(uint32_t)), 0); + CU_ASSERT_EQUAL(memcmp(table.population, population, num_rows * sizeof(uint32_t)), 0); + CU_ASSERT_EQUAL(memcmp(table.population + num_rows, population, + num_rows * sizeof(uint32_t)), 0); + CU_ASSERT_EQUAL(memcmp(table.time, time, num_rows * sizeof(double)), 0); + CU_ASSERT_EQUAL(memcmp(table.time + num_rows, time, num_rows * sizeof(double)), 0); + CU_ASSERT_EQUAL(memcmp(table.individual, individual, num_rows * sizeof(uint32_t)), 0); + CU_ASSERT_EQUAL(memcmp(table.individual + num_rows, individual, + num_rows * sizeof(uint32_t)), 0); + CU_ASSERT_EQUAL(memcmp(table.metadata, metadata, num_rows * sizeof(char)), 0); + CU_ASSERT_EQUAL(memcmp(table.metadata + num_rows, metadata, num_rows * sizeof(char)), 0); + CU_ASSERT_EQUAL(table.num_rows, 2 * num_rows); + CU_ASSERT_EQUAL(table.metadata_length, 2 * num_rows); + tsk_node_tbl_print_state(&table, _devnull); + tsk_node_tbl_dump_text(&table, _devnull); + + /* Truncate back to the original number of rows. 
*/ + ret = tsk_node_tbl_truncate(&table, num_rows); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_EQUAL(memcmp(table.flags, flags, num_rows * sizeof(uint32_t)), 0); + CU_ASSERT_EQUAL(memcmp(table.population, population, num_rows * sizeof(uint32_t)), 0); + CU_ASSERT_EQUAL(memcmp(table.time, time, num_rows * sizeof(double)), 0); + CU_ASSERT_EQUAL(memcmp(table.individual, individual, num_rows * sizeof(uint32_t)), 0); + CU_ASSERT_EQUAL(memcmp(table.metadata, metadata, num_rows * sizeof(char)), 0); + CU_ASSERT_EQUAL(memcmp(table.metadata_offset, metadata_offset, + (num_rows + 1) * sizeof(tsk_tbl_size_t)), 0); + CU_ASSERT_EQUAL(table.num_rows, num_rows); + CU_ASSERT_EQUAL(table.metadata_length, num_rows); + + ret = tsk_node_tbl_truncate(&table, num_rows + 1); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_TABLE_POSITION); + + /* If population is NULL it should be set to -1. If metadata is NULL all metadatas + * should be set to the empty string. If individual is NULL it should be set to -1. */ + num_rows = 10; + memset(population, 0xff, num_rows * sizeof(uint32_t)); + memset(individual, 0xff, num_rows * sizeof(uint32_t)); + ret = tsk_node_tbl_set_columns(&table, num_rows, flags, time, NULL, NULL, metadata, metadata_offset); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_EQUAL(memcmp(table.flags, flags, num_rows * sizeof(uint32_t)), 0); + CU_ASSERT_EQUAL(memcmp(table.population, population, num_rows * sizeof(uint32_t)), 0); + CU_ASSERT_EQUAL(memcmp(table.time, time, num_rows * sizeof(double)), 0); + CU_ASSERT_EQUAL(memcmp(table.individual, individual, num_rows * sizeof(uint32_t)), 0); + CU_ASSERT_EQUAL(memcmp(table.metadata, metadata, num_rows * sizeof(char)), 0); + CU_ASSERT_EQUAL(memcmp(table.metadata_offset, metadata_offset, num_rows * sizeof(uint32_t)), 0); + CU_ASSERT_EQUAL(table.num_rows, num_rows); + CU_ASSERT_EQUAL(table.metadata_length, num_rows); + + /* flags and time cannot be NULL */ + ret = tsk_node_tbl_set_columns(&table, num_rows, NULL, time, population, individual, + metadata, 
metadata_offset); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_PARAM_VALUE); + ret = tsk_node_tbl_set_columns(&table, num_rows, flags, NULL, population, individual, + metadata, metadata_offset); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_PARAM_VALUE); + ret = tsk_node_tbl_set_columns(&table, num_rows, flags, time, population, individual, + NULL, metadata_offset); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_PARAM_VALUE); + ret = tsk_node_tbl_set_columns(&table, num_rows, flags, time, population, individual, + metadata, NULL); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_PARAM_VALUE); + + /* if metadata and metadata_offset are both null, all metadatas are zero length */ + num_rows = 10; + memset(metadata_offset, 0, (num_rows + 1) * sizeof(tsk_tbl_size_t)); + ret = tsk_node_tbl_set_columns(&table, num_rows, flags, time, NULL, NULL, NULL, NULL); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_EQUAL(memcmp(table.flags, flags, num_rows * sizeof(uint32_t)), 0); + CU_ASSERT_EQUAL(memcmp(table.time, time, num_rows * sizeof(double)), 0); + CU_ASSERT_EQUAL(memcmp(table.metadata_offset, metadata_offset, + (num_rows + 1) * sizeof(tsk_tbl_size_t)), 0); + CU_ASSERT_EQUAL(table.num_rows, num_rows); + CU_ASSERT_EQUAL(table.metadata_length, 0); + ret = tsk_node_tbl_append_columns(&table, num_rows, flags, time, NULL, NULL, NULL, NULL); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_EQUAL(memcmp(table.flags, flags, num_rows * sizeof(uint32_t)), 0); + CU_ASSERT_EQUAL(memcmp(table.flags + num_rows, flags, num_rows * sizeof(uint32_t)), 0); + CU_ASSERT_EQUAL(memcmp(table.time, time, num_rows * sizeof(double)), 0); + CU_ASSERT_EQUAL(memcmp(table.time + num_rows, time, num_rows * sizeof(double)), 0); + CU_ASSERT_EQUAL(memcmp(table.metadata_offset, metadata_offset, + (num_rows + 1) * sizeof(tsk_tbl_size_t)), 0); + CU_ASSERT_EQUAL(memcmp(table.metadata_offset + num_rows, metadata_offset, + num_rows * sizeof(uint32_t)), 0); + CU_ASSERT_EQUAL(table.num_rows, 2 * num_rows); + CU_ASSERT_EQUAL(table.metadata_length, 0); + 
tsk_node_tbl_print_state(&table, _devnull); + tsk_node_tbl_dump_text(&table, _devnull); + + tsk_node_tbl_free(&table); + free(flags); + free(population); + free(time); + free(metadata); + free(metadata_offset); + free(individual); +} + +static void +test_edge_table(void) +{ + int ret; + tsk_edge_tbl_t table; + tsk_tbl_size_t num_rows = 100; + tsk_id_t j; + tsk_edge_t edge; + tsk_id_t *parent, *child; + double *left, *right; + + ret = tsk_edge_tbl_alloc(&table, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tsk_edge_tbl_set_max_rows_increment(&table, 1); + tsk_edge_tbl_print_state(&table, _devnull); + tsk_edge_tbl_dump_text(&table, _devnull); + + for (j = 0; j < (tsk_id_t) num_rows; j++) { + ret = tsk_edge_tbl_add_row(&table, (double) j, (double) j, j, j); + CU_ASSERT_EQUAL_FATAL(ret, j); + CU_ASSERT_EQUAL(table.left[j], j); + CU_ASSERT_EQUAL(table.right[j], j); + CU_ASSERT_EQUAL(table.parent[j], j); + CU_ASSERT_EQUAL(table.child[j], j); + CU_ASSERT_EQUAL(table.num_rows, j + 1); + ret = tsk_edge_tbl_get_row(&table, (tsk_tbl_size_t) j, &edge); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL(edge.id, j); + CU_ASSERT_EQUAL(edge.left, j); + CU_ASSERT_EQUAL(edge.right, j); + CU_ASSERT_EQUAL(edge.parent, j); + CU_ASSERT_EQUAL(edge.child, j); + } + ret = tsk_edge_tbl_get_row(&table, num_rows, &edge); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_EDGE_OUT_OF_BOUNDS); + tsk_edge_tbl_print_state(&table, _devnull); + tsk_edge_tbl_dump_text(&table, _devnull); + + num_rows *= 2; + left = malloc(num_rows * sizeof(double)); + CU_ASSERT_FATAL(left != NULL); + memset(left, 0, num_rows * sizeof(double)); + right = malloc(num_rows * sizeof(double)); + CU_ASSERT_FATAL(right != NULL); + memset(right, 0, num_rows * sizeof(double)); + parent = malloc(num_rows * sizeof(tsk_id_t)); + CU_ASSERT_FATAL(parent != NULL); + memset(parent, 1, num_rows * sizeof(tsk_id_t)); + child = malloc(num_rows * sizeof(tsk_id_t)); + CU_ASSERT_FATAL(child != NULL); + memset(child, 1, num_rows * sizeof(tsk_id_t)); + + ret = 
tsk_edge_tbl_set_columns(&table, num_rows, left, right, parent, child); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_EQUAL(memcmp(table.left, left, num_rows * sizeof(double)), 0); + CU_ASSERT_EQUAL(memcmp(table.right, right, num_rows * sizeof(double)), 0); + CU_ASSERT_EQUAL(memcmp(table.parent, parent, num_rows * sizeof(tsk_id_t)), 0); + CU_ASSERT_EQUAL(memcmp(table.child, child, num_rows * sizeof(tsk_id_t)), 0); + CU_ASSERT_EQUAL(table.num_rows, num_rows); + /* Append another num_rows to the end. */ + ret = tsk_edge_tbl_append_columns(&table, num_rows, left, right, parent, child); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_EQUAL(memcmp(table.left, left, num_rows * sizeof(double)), 0); + CU_ASSERT_EQUAL(memcmp(table.left + num_rows, left, num_rows * sizeof(double)), 0); + CU_ASSERT_EQUAL(memcmp(table.right, right, num_rows * sizeof(double)), 0); + CU_ASSERT_EQUAL(memcmp(table.right + num_rows, right, num_rows * sizeof(double)), 0); + CU_ASSERT_EQUAL(memcmp(table.parent, parent, num_rows * sizeof(tsk_id_t)), 0); + CU_ASSERT_EQUAL(memcmp(table.parent + num_rows, parent, num_rows * sizeof(tsk_id_t)), 0); + CU_ASSERT_EQUAL(memcmp(table.child, child, num_rows * sizeof(tsk_id_t)), 0); + CU_ASSERT_EQUAL(memcmp(table.child + num_rows, child, num_rows * sizeof(tsk_id_t)), 0); + CU_ASSERT_EQUAL(table.num_rows, 2 * num_rows); + + /* Truncate back to num_rows */ + ret = tsk_edge_tbl_truncate(&table, num_rows); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_EQUAL(memcmp(table.left, left, num_rows * sizeof(double)), 0); + CU_ASSERT_EQUAL(memcmp(table.right, right, num_rows * sizeof(double)), 0); + CU_ASSERT_EQUAL(memcmp(table.parent, parent, num_rows * sizeof(tsk_id_t)), 0); + CU_ASSERT_EQUAL(memcmp(table.child, child, num_rows * sizeof(tsk_id_t)), 0); + CU_ASSERT_EQUAL(table.num_rows, num_rows); + + ret = tsk_edge_tbl_truncate(&table, num_rows + 1); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_TABLE_POSITION); + + /* Inputs cannot be NULL */ + ret = tsk_edge_tbl_set_columns(&table, num_rows, NULL, 
right, parent, child); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_PARAM_VALUE); + ret = tsk_edge_tbl_set_columns(&table, num_rows, left, NULL, parent, child); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_PARAM_VALUE); + ret = tsk_edge_tbl_set_columns(&table, num_rows, left, right, NULL, child); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_PARAM_VALUE); + ret = tsk_edge_tbl_set_columns(&table, num_rows, left, right, parent, NULL); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_PARAM_VALUE); + + tsk_edge_tbl_clear(&table); + CU_ASSERT_EQUAL(table.num_rows, 0); + + tsk_edge_tbl_free(&table); + free(left); + free(right); + free(parent); + free(child); +} + +static void +test_site_table(void) +{ + int ret; + tsk_site_tbl_t table; + tsk_tbl_size_t num_rows, j; + char *ancestral_state; + char *metadata; + double *position; + tsk_site_t site; + tsk_tbl_size_t *ancestral_state_offset; + tsk_tbl_size_t *metadata_offset; + + ret = tsk_site_tbl_alloc(&table, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tsk_site_tbl_set_max_rows_increment(&table, 1); + tsk_site_tbl_set_max_metadata_length_increment(&table, 1); + tsk_site_tbl_set_max_ancestral_state_length_increment(&table, 1); + tsk_site_tbl_print_state(&table, _devnull); + tsk_site_tbl_dump_text(&table, _devnull); + + ret = tsk_site_tbl_add_row(&table, 0, "A", 1, NULL, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL(table.position[0], 0); + CU_ASSERT_EQUAL(table.ancestral_state_offset[0], 0); + CU_ASSERT_EQUAL(table.ancestral_state_offset[1], 1); + CU_ASSERT_EQUAL(table.ancestral_state_length, 1); + CU_ASSERT_EQUAL(table.metadata_offset[0], 0); + CU_ASSERT_EQUAL(table.metadata_offset[1], 0); + CU_ASSERT_EQUAL(table.metadata_length, 0); + CU_ASSERT_EQUAL(table.num_rows, 1); + + ret = tsk_site_tbl_get_row(&table, 0, &site); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL(site.position, 0); + CU_ASSERT_EQUAL(site.ancestral_state_length, 1); + CU_ASSERT_NSTRING_EQUAL(site.ancestral_state, "A", 1); + CU_ASSERT_EQUAL(site.metadata_length, 0); + + ret = 
tsk_site_tbl_add_row(&table, 1, "AA", 2, "{}", 2); + CU_ASSERT_EQUAL_FATAL(ret, 1); + CU_ASSERT_EQUAL(table.position[1], 1); + CU_ASSERT_EQUAL(table.ancestral_state_offset[2], 3); + CU_ASSERT_EQUAL(table.metadata_offset[1], 0); + CU_ASSERT_EQUAL(table.metadata_offset[2], 2); + CU_ASSERT_EQUAL(table.metadata_length, 2); + CU_ASSERT_EQUAL(table.num_rows, 2); + + ret = tsk_site_tbl_get_row(&table, 1, &site); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL(site.position, 1); + CU_ASSERT_EQUAL(site.ancestral_state_length, 2); + CU_ASSERT_NSTRING_EQUAL(site.ancestral_state, "AA", 2); + CU_ASSERT_EQUAL(site.metadata_length, 2); + CU_ASSERT_NSTRING_EQUAL(site.metadata, "{}", 2); + + ret = tsk_site_tbl_add_row(&table, 2, "A", 1, "metadata", 8); + CU_ASSERT_EQUAL_FATAL(ret, 2); + CU_ASSERT_EQUAL(table.position[1], 1); + CU_ASSERT_EQUAL(table.ancestral_state_offset[3], 4); + CU_ASSERT_EQUAL(table.ancestral_state_length, 4); + CU_ASSERT_EQUAL(table.metadata_offset[3], 10); + CU_ASSERT_EQUAL(table.metadata_length, 10); + CU_ASSERT_EQUAL(table.num_rows, 3); + + ret = tsk_site_tbl_get_row(&table, 3, &site); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_SITE_OUT_OF_BOUNDS); + + tsk_site_tbl_print_state(&table, _devnull); + tsk_site_tbl_dump_text(&table, _devnull); + tsk_site_tbl_clear(&table); + CU_ASSERT_EQUAL(table.num_rows, 0); + CU_ASSERT_EQUAL(table.ancestral_state_length, 0); + CU_ASSERT_EQUAL(table.metadata_length, 0); + CU_ASSERT_EQUAL(table.ancestral_state_offset[0], 0); + CU_ASSERT_EQUAL(table.metadata_offset[0], 0); + + num_rows = 100; + position = malloc(num_rows * sizeof(double)); + CU_ASSERT_FATAL(position != NULL); + ancestral_state = malloc(num_rows * sizeof(char)); + CU_ASSERT_FATAL(ancestral_state != NULL); + ancestral_state_offset = malloc((num_rows + 1) * sizeof(uint32_t)); + CU_ASSERT_FATAL(ancestral_state_offset != NULL); + metadata = malloc(num_rows * sizeof(char)); + CU_ASSERT_FATAL(metadata != NULL); + metadata_offset = malloc((num_rows + 1) * sizeof(uint32_t)); 
+ CU_ASSERT_FATAL(metadata_offset != NULL); + + for (j = 0; j < num_rows; j++) { + position[j] = (double) j; + ancestral_state[j] = (char) j; + ancestral_state_offset[j] = (tsk_tbl_size_t) j; + metadata[j] = (char) ('A' + j); + metadata_offset[j] = (tsk_tbl_size_t) j; + } + ancestral_state_offset[num_rows] = num_rows; + metadata_offset[num_rows] = num_rows; + + ret = tsk_site_tbl_set_columns(&table, num_rows, position, + ancestral_state, ancestral_state_offset, + metadata, metadata_offset); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL(memcmp(table.position, position, + num_rows * sizeof(double)), 0); + CU_ASSERT_EQUAL(memcmp(table.ancestral_state, ancestral_state, + num_rows * sizeof(char)), 0); + CU_ASSERT_EQUAL(table.ancestral_state_length, num_rows); + CU_ASSERT_EQUAL(memcmp(table.metadata, metadata, num_rows * sizeof(char)), 0); + CU_ASSERT_EQUAL(table.metadata_length, num_rows); + CU_ASSERT_EQUAL(table.num_rows, num_rows); + + /* Append another num rows */ + ret = tsk_site_tbl_append_columns(&table, num_rows, position, ancestral_state, + ancestral_state_offset, metadata, metadata_offset); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_EQUAL(memcmp(table.position, position, + num_rows * sizeof(double)), 0); + CU_ASSERT_EQUAL(memcmp(table.position + num_rows, position, + num_rows * sizeof(double)), 0); + CU_ASSERT_EQUAL(memcmp(table.ancestral_state, ancestral_state, + num_rows * sizeof(char)), 0); + CU_ASSERT_EQUAL(memcmp(table.ancestral_state + num_rows, ancestral_state, + num_rows * sizeof(char)), 0); + CU_ASSERT_EQUAL(memcmp(table.metadata, metadata, + num_rows * sizeof(char)), 0); + CU_ASSERT_EQUAL(memcmp(table.metadata + num_rows, metadata, + num_rows * sizeof(char)), 0); + CU_ASSERT_EQUAL(table.num_rows, 2 * num_rows); + CU_ASSERT_EQUAL(table.ancestral_state_length, 2 * num_rows); + + /* truncate back to num_rows */ + ret = tsk_site_tbl_truncate(&table, num_rows); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL(memcmp(table.position, position, + num_rows 
* sizeof(double)), 0); + CU_ASSERT_EQUAL(memcmp(table.ancestral_state, ancestral_state, + num_rows * sizeof(char)), 0); + CU_ASSERT_EQUAL(table.ancestral_state_length, num_rows); + CU_ASSERT_EQUAL(memcmp(table.metadata, metadata, num_rows * sizeof(char)), 0); + CU_ASSERT_EQUAL(table.metadata_length, num_rows); + CU_ASSERT_EQUAL(table.num_rows, num_rows); + + ret = tsk_site_tbl_truncate(&table, num_rows + 1); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_TABLE_POSITION); + + /* Inputs cannot be NULL */ + ret = tsk_site_tbl_set_columns(&table, num_rows, NULL, ancestral_state, + ancestral_state_offset, metadata, metadata_offset); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_PARAM_VALUE); + ret = tsk_site_tbl_set_columns(&table, num_rows, position, NULL, ancestral_state_offset, + metadata, metadata_offset); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_PARAM_VALUE); + ret = tsk_site_tbl_set_columns(&table, num_rows, position, ancestral_state, NULL, + metadata, metadata_offset); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_PARAM_VALUE); + /* Metadata and metadata_offset must both be null */ + ret = tsk_site_tbl_set_columns(&table, num_rows, position, ancestral_state, + ancestral_state_offset, NULL, metadata_offset); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_PARAM_VALUE); + ret = tsk_site_tbl_set_columns(&table, num_rows, position, ancestral_state, + ancestral_state_offset, metadata, NULL); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_PARAM_VALUE); + + /* Set metadata to NULL */ + ret = tsk_site_tbl_set_columns(&table, num_rows, position, + ancestral_state, ancestral_state_offset, NULL, NULL); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + memset(metadata_offset, 0, (num_rows + 1) * sizeof(uint32_t)); + CU_ASSERT_EQUAL(memcmp(table.position, position, + num_rows * sizeof(double)), 0); + CU_ASSERT_EQUAL(memcmp(table.ancestral_state, ancestral_state, + num_rows * sizeof(char)), 0); + CU_ASSERT_EQUAL(table.ancestral_state_length, num_rows); + CU_ASSERT_EQUAL(memcmp(table.metadata_offset, metadata_offset, + (num_rows + 1) * sizeof(uint32_t)), 0); + 
CU_ASSERT_EQUAL(table.metadata_length, 0); + CU_ASSERT_EQUAL(table.num_rows, num_rows); + + /* Test for bad offsets */ + ancestral_state_offset[0] = 1; + ret = tsk_site_tbl_set_columns(&table, num_rows, position, + ancestral_state, ancestral_state_offset, + metadata, metadata_offset); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_OFFSET); + ancestral_state_offset[0] = 0; + ancestral_state_offset[num_rows] = 0; + ret = tsk_site_tbl_set_columns(&table, num_rows, position, + ancestral_state, ancestral_state_offset, + metadata, metadata_offset); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_OFFSET); + ancestral_state_offset[0] = 0; + + metadata_offset[0] = 1; + ret = tsk_site_tbl_set_columns(&table, num_rows, position, + ancestral_state, ancestral_state_offset, + metadata, metadata_offset); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_OFFSET); + metadata_offset[0] = 0; + metadata_offset[num_rows] = 0; + ret = tsk_site_tbl_set_columns(&table, num_rows, position, + ancestral_state, ancestral_state_offset, + metadata, metadata_offset); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_OFFSET); + + ret = tsk_site_tbl_clear(&table); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_EQUAL(table.num_rows, 0); + CU_ASSERT_EQUAL(table.ancestral_state_length, 0); + CU_ASSERT_EQUAL(table.metadata_length, 0); + + tsk_site_tbl_free(&table); + free(position); + free(ancestral_state); + free(ancestral_state_offset); + free(metadata); + free(metadata_offset); +} + +static void +test_mutation_table(void) +{ + int ret; + tsk_mutation_tbl_t table; + tsk_tbl_size_t num_rows = 100; + tsk_tbl_size_t max_len = 20; + tsk_tbl_size_t k, len; + tsk_id_t j; + tsk_id_t *node; + tsk_id_t *parent; + tsk_id_t *site; + char *derived_state, *metadata; + char c[max_len + 1]; + tsk_tbl_size_t *derived_state_offset, *metadata_offset; + tsk_mutation_t mutation; + + for (j = 0; j < (tsk_id_t) max_len; j++) { + c[j] = (char) ('A' + j); + } + + ret = tsk_mutation_tbl_alloc(&table, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tsk_mutation_tbl_set_max_rows_increment(&table, 
1); + tsk_mutation_tbl_set_max_metadata_length_increment(&table, 1); + tsk_mutation_tbl_set_max_derived_state_length_increment(&table, 1); + tsk_mutation_tbl_print_state(&table, _devnull); + tsk_mutation_tbl_dump_text(&table, _devnull); + + len = 0; + for (j = 0; j < (tsk_id_t) num_rows; j++) { + k = TSK_MIN((tsk_tbl_size_t) j + 1, max_len); + ret = tsk_mutation_tbl_add_row(&table, j, j, j, c, k, c, k); + CU_ASSERT_EQUAL_FATAL(ret, j); + CU_ASSERT_EQUAL(table.site[j], j); + CU_ASSERT_EQUAL(table.node[j], j); + CU_ASSERT_EQUAL(table.parent[j], j); + CU_ASSERT_EQUAL(table.derived_state_offset[j], len); + CU_ASSERT_EQUAL(table.metadata_offset[j], len); + CU_ASSERT_EQUAL(table.num_rows, j + 1); + len += k; + CU_ASSERT_EQUAL(table.derived_state_offset[j + 1], len); + CU_ASSERT_EQUAL(table.derived_state_length, len); + CU_ASSERT_EQUAL(table.metadata_offset[j + 1], len); + CU_ASSERT_EQUAL(table.metadata_length, len); + + ret = tsk_mutation_tbl_get_row(&table, (tsk_tbl_size_t) j, &mutation); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL(mutation.id, j); + CU_ASSERT_EQUAL(mutation.site, j); + CU_ASSERT_EQUAL(mutation.node, j); + CU_ASSERT_EQUAL(mutation.parent, j); + CU_ASSERT_EQUAL(mutation.metadata_length, k); + CU_ASSERT_NSTRING_EQUAL(mutation.metadata, c, k); + CU_ASSERT_EQUAL(mutation.derived_state_length, k); + CU_ASSERT_NSTRING_EQUAL(mutation.derived_state, c, k); + } + ret = tsk_mutation_tbl_get_row(&table, num_rows, &mutation); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_MUTATION_OUT_OF_BOUNDS); + tsk_mutation_tbl_print_state(&table, _devnull); + tsk_mutation_tbl_dump_text(&table, _devnull); + + num_rows *= 2; + site = malloc(num_rows * sizeof(tsk_id_t)); + CU_ASSERT_FATAL(site != NULL); + node = malloc(num_rows * sizeof(tsk_id_t)); + CU_ASSERT_FATAL(node != NULL); + parent = malloc(num_rows * sizeof(tsk_id_t)); + CU_ASSERT_FATAL(parent != NULL); + derived_state = malloc(num_rows * sizeof(char)); + CU_ASSERT_FATAL(derived_state != NULL); + derived_state_offset = 
malloc((num_rows + 1) * sizeof(tsk_tbl_size_t)); + CU_ASSERT_FATAL(derived_state_offset != NULL); + metadata = malloc(num_rows * sizeof(char)); + CU_ASSERT_FATAL(metadata != NULL); + metadata_offset = malloc((num_rows + 1) * sizeof(tsk_tbl_size_t)); + CU_ASSERT_FATAL(metadata_offset != NULL); + + for (j = 0; j < (tsk_id_t) num_rows; j++) { + node[j] = j; + site[j] = j + 1; + parent[j] = j + 2; + derived_state[j] = 'Y'; + derived_state_offset[j] = (tsk_tbl_size_t) j; + metadata[j] = 'M'; + metadata_offset[j] = (tsk_tbl_size_t) j; + } + + derived_state_offset[num_rows] = num_rows; + metadata_offset[num_rows] = num_rows; + ret = tsk_mutation_tbl_set_columns(&table, num_rows, site, node, parent, + derived_state, derived_state_offset, metadata, metadata_offset); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL(memcmp(table.site, site, num_rows * sizeof(tsk_id_t)), 0); + CU_ASSERT_EQUAL(memcmp(table.node, node, num_rows * sizeof(tsk_id_t)), 0); + CU_ASSERT_EQUAL(memcmp(table.parent, parent, num_rows * sizeof(tsk_id_t)), 0); + CU_ASSERT_EQUAL(memcmp(table.derived_state, derived_state, + num_rows * sizeof(char)), 0); + CU_ASSERT_EQUAL(memcmp(table.metadata, metadata, num_rows * sizeof(char)), 0); + CU_ASSERT_EQUAL(table.num_rows, num_rows); + CU_ASSERT_EQUAL(table.derived_state_length, num_rows); + CU_ASSERT_EQUAL(table.metadata_length, num_rows); + + /* Append another num_rows */ + ret = tsk_mutation_tbl_append_columns(&table, num_rows, site, node, parent, derived_state, + derived_state_offset, metadata, metadata_offset); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_EQUAL(memcmp(table.site, site, num_rows * sizeof(tsk_id_t)), 0); + CU_ASSERT_EQUAL(memcmp(table.site + num_rows, site, num_rows * sizeof(tsk_id_t)), 0); + CU_ASSERT_EQUAL(memcmp(table.node, node, num_rows * sizeof(tsk_id_t)), 0); + CU_ASSERT_EQUAL(memcmp(table.node + num_rows, node, num_rows * sizeof(tsk_id_t)), 0); + CU_ASSERT_EQUAL(memcmp(table.parent, parent, num_rows * sizeof(tsk_id_t)), 0); + 
CU_ASSERT_EQUAL(memcmp(table.parent + num_rows, parent, + num_rows * sizeof(tsk_id_t)), 0); + CU_ASSERT_EQUAL(memcmp(table.derived_state, derived_state, + num_rows * sizeof(char)), 0); + CU_ASSERT_EQUAL(memcmp(table.derived_state, derived_state, + num_rows * sizeof(char)), 0); + CU_ASSERT_EQUAL(table.derived_state_length, 2 * num_rows); + CU_ASSERT_EQUAL(memcmp(table.metadata, metadata, + num_rows * sizeof(char)), 0); + CU_ASSERT_EQUAL(memcmp(table.metadata, metadata, + num_rows * sizeof(char)), 0); + CU_ASSERT_EQUAL(table.metadata_length, 2 * num_rows); + CU_ASSERT_EQUAL(table.num_rows, 2 * num_rows); + + /* Truncate back to num_rows */ + ret = tsk_mutation_tbl_truncate(&table, num_rows); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL(memcmp(table.site, site, num_rows * sizeof(tsk_id_t)), 0); + CU_ASSERT_EQUAL(memcmp(table.node, node, num_rows * sizeof(tsk_id_t)), 0); + CU_ASSERT_EQUAL(memcmp(table.parent, parent, num_rows * sizeof(tsk_id_t)), 0); + CU_ASSERT_EQUAL(memcmp(table.derived_state, derived_state, + num_rows * sizeof(char)), 0); + CU_ASSERT_EQUAL(memcmp(table.metadata, metadata, num_rows * sizeof(char)), 0); + CU_ASSERT_EQUAL(table.num_rows, num_rows); + CU_ASSERT_EQUAL(table.derived_state_length, num_rows); + CU_ASSERT_EQUAL(table.metadata_length, num_rows); + + ret = tsk_mutation_tbl_truncate(&table, num_rows + 1); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_TABLE_POSITION); + + /* Check all this again, except with parent == NULL and metadata == NULL. 
*/ + memset(parent, 0xff, num_rows * sizeof(tsk_id_t)); + memset(metadata_offset, 0, (num_rows + 1) * sizeof(tsk_tbl_size_t)); + ret = tsk_mutation_tbl_set_columns(&table, num_rows, site, node, NULL, + derived_state, derived_state_offset, NULL, NULL); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL(memcmp(table.site, site, num_rows * sizeof(tsk_id_t)), 0); + CU_ASSERT_EQUAL(memcmp(table.node, node, num_rows * sizeof(tsk_id_t)), 0); + CU_ASSERT_EQUAL(memcmp(table.parent, parent, num_rows * sizeof(tsk_id_t)), 0); + CU_ASSERT_EQUAL(memcmp(table.derived_state, derived_state, + num_rows * sizeof(char)), 0); + CU_ASSERT_EQUAL(memcmp(table.derived_state_offset, derived_state_offset, + num_rows * sizeof(tsk_tbl_size_t)), 0); + CU_ASSERT_EQUAL(memcmp(table.metadata_offset, metadata_offset, + (num_rows + 1) * sizeof(tsk_tbl_size_t)), 0); + CU_ASSERT_EQUAL(table.num_rows, num_rows); + CU_ASSERT_EQUAL(table.derived_state_length, num_rows); + CU_ASSERT_EQUAL(table.metadata_length, 0); + + /* Append another num_rows */ + ret = tsk_mutation_tbl_append_columns(&table, num_rows, site, node, NULL, derived_state, + derived_state_offset, NULL, NULL); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_EQUAL(memcmp(table.site, site, num_rows * sizeof(tsk_id_t)), 0); + CU_ASSERT_EQUAL(memcmp(table.site + num_rows, site, num_rows * sizeof(tsk_id_t)), 0); + CU_ASSERT_EQUAL(memcmp(table.node, node, num_rows * sizeof(tsk_id_t)), 0); + CU_ASSERT_EQUAL(memcmp(table.node + num_rows, node, num_rows * sizeof(tsk_id_t)), 0); + CU_ASSERT_EQUAL(memcmp(table.parent, parent, num_rows * sizeof(tsk_id_t)), 0); + CU_ASSERT_EQUAL(memcmp(table.parent + num_rows, parent, + num_rows * sizeof(tsk_id_t)), 0); + CU_ASSERT_EQUAL(memcmp(table.derived_state, derived_state, + num_rows * sizeof(char)), 0); + CU_ASSERT_EQUAL(memcmp(table.derived_state + num_rows, derived_state, + num_rows * sizeof(char)), 0); + CU_ASSERT_EQUAL(table.num_rows, 2 * num_rows); + CU_ASSERT_EQUAL(table.derived_state_length, 2 * num_rows); + 
CU_ASSERT_EQUAL(table.metadata_length, 0); + + + /* Inputs except parent, metadata and metadata_offset cannot be NULL*/ + ret = tsk_mutation_tbl_set_columns(&table, num_rows, NULL, node, parent, + derived_state, derived_state_offset, metadata, metadata_offset); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_PARAM_VALUE); + ret = tsk_mutation_tbl_set_columns(&table, num_rows, site, NULL, parent, + derived_state, derived_state_offset, metadata, metadata_offset); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_PARAM_VALUE); + ret = tsk_mutation_tbl_set_columns(&table, num_rows, site, node, parent, + NULL, derived_state_offset, metadata, metadata_offset); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_PARAM_VALUE); + ret = tsk_mutation_tbl_set_columns(&table, num_rows, site, node, parent, + derived_state, NULL, metadata, metadata_offset); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_PARAM_VALUE); + ret = tsk_mutation_tbl_set_columns(&table, num_rows, site, node, parent, + derived_state, derived_state_offset, NULL, metadata_offset); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_PARAM_VALUE); + ret = tsk_mutation_tbl_set_columns(&table, num_rows, site, node, parent, + derived_state, derived_state_offset, metadata, NULL); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_PARAM_VALUE); + + /* Inputs except parent, metadata and metadata_offset cannot be NULL*/ + ret = tsk_mutation_tbl_append_columns(&table, num_rows, NULL, node, parent, + derived_state, derived_state_offset, metadata, metadata_offset); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_PARAM_VALUE); + ret = tsk_mutation_tbl_append_columns(&table, num_rows, site, NULL, parent, + derived_state, derived_state_offset, metadata, metadata_offset); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_PARAM_VALUE); + ret = tsk_mutation_tbl_append_columns(&table, num_rows, site, node, parent, + NULL, derived_state_offset, metadata, metadata_offset); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_PARAM_VALUE); + ret = tsk_mutation_tbl_append_columns(&table, num_rows, site, node, parent, + derived_state, NULL, metadata, 
metadata_offset); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_PARAM_VALUE); + ret = tsk_mutation_tbl_append_columns(&table, num_rows, site, node, parent, + derived_state, derived_state_offset, NULL, metadata_offset); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_PARAM_VALUE); + ret = tsk_mutation_tbl_append_columns(&table, num_rows, site, node, parent, + derived_state, derived_state_offset, metadata, NULL); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_PARAM_VALUE); + + /* Test for bad offsets */ + derived_state_offset[0] = 1; + ret = tsk_mutation_tbl_set_columns(&table, num_rows, site, node, parent, + derived_state, derived_state_offset, NULL, NULL); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_OFFSET); + derived_state_offset[0] = 0; + derived_state_offset[num_rows] = 0; + ret = tsk_mutation_tbl_set_columns(&table, num_rows, site, node, parent, + derived_state, derived_state_offset, NULL, NULL); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_OFFSET); + + tsk_mutation_tbl_clear(&table); + CU_ASSERT_EQUAL(table.num_rows, 0); + CU_ASSERT_EQUAL(table.derived_state_length, 0); + CU_ASSERT_EQUAL(table.metadata_length, 0); + + tsk_mutation_tbl_free(&table); + free(site); + free(node); + free(parent); + free(derived_state); + free(derived_state_offset); + free(metadata); + free(metadata_offset); +} + +static void +test_migration_table(void) +{ + int ret; + tsk_migration_tbl_t table; + tsk_tbl_size_t num_rows = 100; + tsk_id_t j; + tsk_id_t *node; + tsk_id_t *source, *dest; + double *left, *right, *time; + tsk_migration_t migration; + + ret = tsk_migration_tbl_alloc(&table, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tsk_migration_tbl_set_max_rows_increment(&table, 1); + tsk_migration_tbl_print_state(&table, _devnull); + tsk_migration_tbl_dump_text(&table, _devnull); + + for (j = 0; j < (tsk_id_t) num_rows; j++) { + ret = tsk_migration_tbl_add_row(&table, j, j, j, j, j, j); + CU_ASSERT_EQUAL_FATAL(ret, j); + CU_ASSERT_EQUAL(table.left[j], j); + CU_ASSERT_EQUAL(table.right[j], j); + CU_ASSERT_EQUAL(table.node[j], j); + 
CU_ASSERT_EQUAL(table.source[j], j); + CU_ASSERT_EQUAL(table.dest[j], j); + CU_ASSERT_EQUAL(table.time[j], j); + CU_ASSERT_EQUAL(table.num_rows, j + 1); + + ret = tsk_migration_tbl_get_row(&table, (tsk_tbl_size_t) j, &migration); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL(migration.id, j); + CU_ASSERT_EQUAL(migration.left, j); + CU_ASSERT_EQUAL(migration.right, j); + CU_ASSERT_EQUAL(migration.node, j); + CU_ASSERT_EQUAL(migration.source, j); + CU_ASSERT_EQUAL(migration.dest, j); + CU_ASSERT_EQUAL(migration.time, j); + } + ret = tsk_migration_tbl_get_row(&table, num_rows, &migration); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_MIGRATION_OUT_OF_BOUNDS); + tsk_migration_tbl_print_state(&table, _devnull); + tsk_migration_tbl_dump_text(&table, _devnull); + + num_rows *= 2; + left = malloc(num_rows * sizeof(double)); + CU_ASSERT_FATAL(left != NULL); + memset(left, 1, num_rows * sizeof(double)); + right = malloc(num_rows * sizeof(double)); + CU_ASSERT_FATAL(right != NULL); + memset(right, 2, num_rows * sizeof(double)); + time = malloc(num_rows * sizeof(double)); + CU_ASSERT_FATAL(time != NULL); + memset(time, 3, num_rows * sizeof(double)); + node = malloc(num_rows * sizeof(tsk_id_t)); + CU_ASSERT_FATAL(node != NULL); + memset(node, 4, num_rows * sizeof(tsk_id_t)); + source = malloc(num_rows * sizeof(tsk_id_t)); + CU_ASSERT_FATAL(source != NULL); + memset(source, 5, num_rows * sizeof(tsk_id_t)); + dest = malloc(num_rows * sizeof(tsk_id_t)); + CU_ASSERT_FATAL(dest != NULL); + memset(dest, 6, num_rows * sizeof(tsk_id_t)); + + ret = tsk_migration_tbl_set_columns(&table, num_rows, left, right, node, source, + dest, time); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_EQUAL(memcmp(table.left, left, num_rows * sizeof(double)), 0); + CU_ASSERT_EQUAL(memcmp(table.right, right, num_rows * sizeof(double)), 0); + CU_ASSERT_EQUAL(memcmp(table.time, time, num_rows * sizeof(double)), 0); + CU_ASSERT_EQUAL(memcmp(table.node, node, num_rows * sizeof(tsk_id_t)), 0); + 
CU_ASSERT_EQUAL(memcmp(table.source, source, num_rows * sizeof(tsk_id_t)), 0); + CU_ASSERT_EQUAL(memcmp(table.dest, dest, num_rows * sizeof(tsk_id_t)), 0); + CU_ASSERT_EQUAL(table.num_rows, num_rows); + /* Append another num_rows */ + ret = tsk_migration_tbl_append_columns(&table, num_rows, left, right, node, source, + dest, time); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_EQUAL(memcmp(table.left, left, num_rows * sizeof(double)), 0); + CU_ASSERT_EQUAL(memcmp(table.left + num_rows, left, num_rows * sizeof(double)), 0); + CU_ASSERT_EQUAL(memcmp(table.right, right, num_rows * sizeof(double)), 0); + CU_ASSERT_EQUAL(memcmp(table.right + num_rows, right, num_rows * sizeof(double)), 0); + CU_ASSERT_EQUAL(memcmp(table.time, time, num_rows * sizeof(double)), 0); + CU_ASSERT_EQUAL(memcmp(table.time + num_rows, time, num_rows * sizeof(double)), 0); + CU_ASSERT_EQUAL(memcmp(table.node, node, num_rows * sizeof(tsk_id_t)), 0); + CU_ASSERT_EQUAL(memcmp(table.node + num_rows, node, num_rows * sizeof(tsk_id_t)), 0); + CU_ASSERT_EQUAL(memcmp(table.source, source, num_rows * sizeof(tsk_id_t)), 0); + CU_ASSERT_EQUAL(memcmp(table.source + num_rows, source, + num_rows * sizeof(tsk_id_t)), 0); + CU_ASSERT_EQUAL(memcmp(table.dest, dest, num_rows * sizeof(tsk_id_t)), 0); + CU_ASSERT_EQUAL(memcmp(table.dest + num_rows, dest, + num_rows * sizeof(tsk_id_t)), 0); + CU_ASSERT_EQUAL(table.num_rows, 2 * num_rows); + + /* Truncate back to num_rows */ + ret = tsk_migration_tbl_truncate(&table, num_rows); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_EQUAL(memcmp(table.left, left, num_rows * sizeof(double)), 0); + CU_ASSERT_EQUAL(memcmp(table.right, right, num_rows * sizeof(double)), 0); + CU_ASSERT_EQUAL(memcmp(table.time, time, num_rows * sizeof(double)), 0); + CU_ASSERT_EQUAL(memcmp(table.node, node, num_rows * sizeof(tsk_id_t)), 0); + CU_ASSERT_EQUAL(memcmp(table.source, source, num_rows * sizeof(tsk_id_t)), 0); + CU_ASSERT_EQUAL(memcmp(table.dest, dest, num_rows * sizeof(tsk_id_t)), 0); + 
CU_ASSERT_EQUAL(table.num_rows, num_rows); + + ret = tsk_migration_tbl_truncate(&table, num_rows + 1); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_TABLE_POSITION); + + /* inputs cannot be NULL */ + ret = tsk_migration_tbl_set_columns(&table, num_rows, NULL, right, node, source, + dest, time); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_PARAM_VALUE); + ret = tsk_migration_tbl_set_columns(&table, num_rows, left, NULL, node, source, + dest, time); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_PARAM_VALUE); + ret = tsk_migration_tbl_set_columns(&table, num_rows, left, right, NULL, source, + dest, time); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_PARAM_VALUE); + ret = tsk_migration_tbl_set_columns(&table, num_rows, left, right, node, NULL, + dest, time); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_PARAM_VALUE); + ret = tsk_migration_tbl_set_columns(&table, num_rows, left, right, node, source, + NULL, time); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_PARAM_VALUE); + ret = tsk_migration_tbl_set_columns(&table, num_rows, left, right, node, source, + dest, NULL); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_PARAM_VALUE); + + tsk_migration_tbl_clear(&table); + CU_ASSERT_EQUAL(table.num_rows, 0); + + tsk_migration_tbl_free(&table); + free(left); + free(right); + free(time); + free(node); + free(source); + free(dest); +} + +static void +test_individual_table(void) +{ + int ret = 0; + tsk_individual_tbl_t table; + /* tsk_tbl_collection_t tables, tables2; */ + tsk_tbl_size_t num_rows = 100; + tsk_id_t j; + tsk_tbl_size_t k; + uint32_t *flags; + double *location; + char *metadata; + tsk_tbl_size_t *metadata_offset; + tsk_tbl_size_t *location_offset; + tsk_individual_t individual; + const char *test_metadata = "test"; + tsk_tbl_size_t test_metadata_length = 4; + char metadata_copy[test_metadata_length + 1]; + tsk_tbl_size_t spatial_dimension = 2; + double test_location[spatial_dimension]; + + for (k = 0; k < spatial_dimension; k++) { + test_location[k] = (double) k; + } + metadata_copy[test_metadata_length] = '\0'; + ret = 
tsk_individual_tbl_alloc(&table, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tsk_individual_tbl_set_max_rows_increment(&table, 1); + tsk_individual_tbl_set_max_metadata_length_increment(&table, 1); + tsk_individual_tbl_set_max_location_length_increment(&table, 1); + + tsk_individual_tbl_print_state(&table, _devnull); + + for (j = 0; j < (tsk_id_t) num_rows; j++) { + ret = tsk_individual_tbl_add_row(&table, (uint32_t) j, test_location, + (size_t) spatial_dimension, test_metadata, test_metadata_length); + CU_ASSERT_EQUAL_FATAL(ret, j); + CU_ASSERT_EQUAL(table.flags[j], j); + for (k = 0; k < spatial_dimension; k++) { + test_location[k] = (double) k; + CU_ASSERT_EQUAL(table.location[spatial_dimension * (size_t) j + k], + test_location[k]); + } + CU_ASSERT_EQUAL(table.metadata_length, (tsk_tbl_size_t) (j + 1) * test_metadata_length); + CU_ASSERT_EQUAL(table.metadata_offset[j + 1], table.metadata_length); + /* check the metadata */ + memcpy(metadata_copy, table.metadata + table.metadata_offset[j], + test_metadata_length); + CU_ASSERT_NSTRING_EQUAL(metadata_copy, test_metadata, test_metadata_length); + + ret = tsk_individual_tbl_get_row(&table, (tsk_tbl_size_t) j, &individual); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL(individual.id, j); + CU_ASSERT_EQUAL(individual.flags, j); + CU_ASSERT_EQUAL(individual.location_length, spatial_dimension); + CU_ASSERT_NSTRING_EQUAL(individual.location, test_location, + spatial_dimension * sizeof(double)); + CU_ASSERT_EQUAL(individual.metadata_length, test_metadata_length); + CU_ASSERT_NSTRING_EQUAL(individual.metadata, test_metadata, test_metadata_length); + } + ret = tsk_individual_tbl_get_row(&table, num_rows, &individual); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_INDIVIDUAL_OUT_OF_BOUNDS); + tsk_individual_tbl_print_state(&table, _devnull); + tsk_individual_tbl_clear(&table); + CU_ASSERT_EQUAL(table.num_rows, 0); + CU_ASSERT_EQUAL(table.metadata_length, 0); + + num_rows *= 2; + flags = malloc(num_rows * sizeof(uint32_t)); + 
CU_ASSERT_FATAL(flags != NULL); + memset(flags, 1, num_rows * sizeof(uint32_t)); + location = malloc(spatial_dimension * num_rows * sizeof(double)); + CU_ASSERT_FATAL(location != NULL); + memset(location, 0, spatial_dimension * num_rows * sizeof(double)); + location_offset = malloc((num_rows + 1) * sizeof(tsk_tbl_size_t)); + CU_ASSERT_FATAL(location_offset != NULL); + for (j = 0; j < (tsk_id_t) num_rows + 1; j++) { + location_offset[j] = (tsk_tbl_size_t) j * spatial_dimension; + } + metadata = malloc(num_rows * sizeof(char)); + memset(metadata, 'a', num_rows * sizeof(char)); + CU_ASSERT_FATAL(metadata != NULL); + metadata_offset = malloc((num_rows + 1) * sizeof(tsk_tbl_size_t)); + CU_ASSERT_FATAL(metadata_offset != NULL); + for (j = 0; j < (tsk_id_t) num_rows + 1; j++) { + metadata_offset[j] = (tsk_tbl_size_t) j; + } + ret = tsk_individual_tbl_set_columns(&table, num_rows, flags, + location, location_offset, metadata, metadata_offset); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_EQUAL(memcmp(table.flags, flags, num_rows * sizeof(uint32_t)), 0); + CU_ASSERT_EQUAL(memcmp(table.location, location, + spatial_dimension * num_rows * sizeof(double)), 0); + CU_ASSERT_EQUAL(memcmp(table.location_offset, location_offset, + (num_rows + 1) * sizeof(tsk_tbl_size_t)), 0); + CU_ASSERT_EQUAL(memcmp(table.metadata, metadata, num_rows * sizeof(char)), 0); + CU_ASSERT_EQUAL(memcmp(table.metadata_offset, metadata_offset, + (num_rows + 1) * sizeof(tsk_tbl_size_t)), 0); + CU_ASSERT_EQUAL(table.num_rows, num_rows); + CU_ASSERT_EQUAL(table.location_length, spatial_dimension * num_rows); + CU_ASSERT_EQUAL(table.metadata_length, num_rows); + tsk_individual_tbl_print_state(&table, _devnull); + + /* Append another num_rows onto the end */ + ret = tsk_individual_tbl_append_columns(&table, num_rows, flags, location, + location_offset, metadata, metadata_offset); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_EQUAL(memcmp(table.flags, flags, num_rows * sizeof(uint32_t)), 0); + 
CU_ASSERT_EQUAL(memcmp(table.flags + num_rows, flags, num_rows * sizeof(uint32_t)), 0); + CU_ASSERT_EQUAL(memcmp(table.metadata, metadata, num_rows * sizeof(char)), 0); + CU_ASSERT_EQUAL(memcmp(table.metadata + num_rows, metadata, num_rows * sizeof(char)), 0); + CU_ASSERT_EQUAL(memcmp(table.location, location, + spatial_dimension * num_rows * sizeof(double)), 0); + CU_ASSERT_EQUAL(memcmp(table.location + spatial_dimension * num_rows, + location, spatial_dimension * num_rows * sizeof(double)), 0); + CU_ASSERT_EQUAL(table.num_rows, 2 * num_rows); + CU_ASSERT_EQUAL(table.metadata_length, 2 * num_rows); + tsk_individual_tbl_print_state(&table, _devnull); + tsk_individual_tbl_dump_text(&table, _devnull); + + /* Truncate back to num_rows */ + ret = tsk_individual_tbl_truncate(&table, num_rows); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_EQUAL(memcmp(table.flags, flags, num_rows * sizeof(uint32_t)), 0); + CU_ASSERT_EQUAL(memcmp(table.location, location, + spatial_dimension * num_rows * sizeof(double)), 0); + CU_ASSERT_EQUAL(memcmp(table.location_offset, location_offset, + (num_rows + 1) * sizeof(tsk_tbl_size_t)), 0); + CU_ASSERT_EQUAL(memcmp(table.metadata, metadata, num_rows * sizeof(char)), 0); + CU_ASSERT_EQUAL(memcmp(table.metadata_offset, metadata_offset, + (num_rows + 1) * sizeof(tsk_tbl_size_t)), 0); + CU_ASSERT_EQUAL(table.num_rows, num_rows); + CU_ASSERT_EQUAL(table.location_length, spatial_dimension * num_rows); + CU_ASSERT_EQUAL(table.metadata_length, num_rows); + tsk_individual_tbl_print_state(&table, _devnull); + + ret = tsk_individual_tbl_truncate(&table, num_rows + 1); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_TABLE_POSITION); + + /* flags can't be NULL */ + ret = tsk_individual_tbl_set_columns(&table, num_rows, NULL, + location, location_offset, metadata, metadata_offset); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_PARAM_VALUE); + /* location and location offset must be simultaneously NULL or not */ + ret = tsk_individual_tbl_set_columns(&table, num_rows, flags, + 
location, NULL, metadata, metadata_offset); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_PARAM_VALUE); + ret = tsk_individual_tbl_set_columns(&table, num_rows, flags, + NULL, location_offset, metadata, metadata_offset); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_PARAM_VALUE); + /* metadata and metadata offset must be simultaneously NULL or not */ + ret = tsk_individual_tbl_set_columns(&table, num_rows, flags, + location, location_offset, NULL, metadata_offset); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_PARAM_VALUE); + ret = tsk_individual_tbl_set_columns(&table, num_rows, flags, + location, location_offset, metadata, NULL); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_PARAM_VALUE); + + /* if location and location_offset are both null, all locations are zero length */ + num_rows = 10; + memset(location_offset, 0, (num_rows + 1) * sizeof(tsk_tbl_size_t)); + ret = tsk_individual_tbl_set_columns(&table, num_rows, flags, + NULL, NULL, NULL, NULL); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_EQUAL(memcmp(table.location_offset, location_offset, + (num_rows + 1) * sizeof(tsk_tbl_size_t)), 0); + CU_ASSERT_EQUAL(table.num_rows, num_rows); + CU_ASSERT_EQUAL(table.location_length, 0); + ret = tsk_individual_tbl_append_columns(&table, num_rows, flags, NULL, NULL, NULL, NULL); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_EQUAL(memcmp(table.location_offset, location_offset, + (num_rows + 1) * sizeof(tsk_tbl_size_t)), 0); + CU_ASSERT_EQUAL(memcmp(table.location_offset + num_rows, location_offset, + num_rows * sizeof(uint32_t)), 0); + CU_ASSERT_EQUAL(table.num_rows, 2 * num_rows); + CU_ASSERT_EQUAL(table.location_length, 0); + tsk_individual_tbl_print_state(&table, _devnull); + tsk_individual_tbl_dump_text(&table, _devnull); + + /* if metadata and metadata_offset are both null, all metadatas are zero length */ + num_rows = 10; + memset(metadata_offset, 0, (num_rows + 1) * sizeof(tsk_tbl_size_t)); + ret = tsk_individual_tbl_set_columns(&table, num_rows, flags, + location, location_offset, NULL, NULL); + CU_ASSERT_EQUAL(ret, 0); 
+ CU_ASSERT_EQUAL(memcmp(table.flags, flags, num_rows * sizeof(uint32_t)), 0); + CU_ASSERT_EQUAL(memcmp(table.location, location, + spatial_dimension * num_rows * sizeof(double)), 0); + CU_ASSERT_EQUAL(memcmp(table.metadata_offset, metadata_offset, + (num_rows + 1) * sizeof(tsk_tbl_size_t)), 0); + CU_ASSERT_EQUAL(table.num_rows, num_rows); + CU_ASSERT_EQUAL(table.metadata_length, 0); + ret = tsk_individual_tbl_append_columns(&table, num_rows, flags, location, + location_offset, NULL, NULL); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_EQUAL(memcmp(table.location, location, + spatial_dimension * num_rows * sizeof(double)), 0); + CU_ASSERT_EQUAL(memcmp(table.location + spatial_dimension * num_rows, + location, spatial_dimension * num_rows * sizeof(double)), 0); + CU_ASSERT_EQUAL(memcmp(table.metadata_offset, metadata_offset, + (num_rows + 1) * sizeof(tsk_tbl_size_t)), 0); + CU_ASSERT_EQUAL(memcmp(table.metadata_offset + num_rows, metadata_offset, + num_rows * sizeof(uint32_t)), 0); + CU_ASSERT_EQUAL(table.num_rows, 2 * num_rows); + CU_ASSERT_EQUAL(table.metadata_length, 0); + tsk_individual_tbl_print_state(&table, _devnull); + tsk_individual_tbl_dump_text(&table, _devnull); + + ret = tsk_individual_tbl_free(&table); + CU_ASSERT_EQUAL(ret, 0); + free(flags); + free(location); + free(location_offset); + free(metadata); + free(metadata_offset); +} + +static void +test_population_table(void) +{ + int ret; + tsk_population_tbl_t table; + tsk_tbl_size_t num_rows = 100; + tsk_tbl_size_t max_len = 20; + tsk_tbl_size_t k, len; + tsk_id_t j; + char *metadata; + char c[max_len + 1]; + tsk_tbl_size_t *metadata_offset; + tsk_population_t population; + + for (j = 0; j < (tsk_id_t) max_len; j++) { + c[j] = (char) ('A' + j); + } + + ret = tsk_population_tbl_alloc(&table, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tsk_population_tbl_set_max_rows_increment(&table, 1); + tsk_population_tbl_set_max_metadata_length_increment(&table, 1); + tsk_population_tbl_print_state(&table, _devnull); + 
tsk_population_tbl_dump_text(&table, _devnull); + /* Adding zero length metadata with NULL should be fine */ + + ret = tsk_population_tbl_add_row(&table, NULL, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + CU_ASSERT_EQUAL(table.metadata_length, 0); + CU_ASSERT_EQUAL(table.num_rows, 1); + CU_ASSERT_EQUAL(table.metadata_offset[0], 0); + CU_ASSERT_EQUAL(table.metadata_offset[1], 0); + tsk_population_tbl_clear(&table); + CU_ASSERT_EQUAL(table.num_rows, 0); + + len = 0; + for (j = 0; j < (tsk_id_t) num_rows; j++) { + k = TSK_MIN((tsk_tbl_size_t) j + 1, max_len); + ret = tsk_population_tbl_add_row(&table, c, k); + CU_ASSERT_EQUAL_FATAL(ret, j); + CU_ASSERT_EQUAL(table.metadata_offset[j], len); + CU_ASSERT_EQUAL(table.num_rows, j + 1); + len += k; + CU_ASSERT_EQUAL(table.metadata_offset[j + 1], len); + CU_ASSERT_EQUAL(table.metadata_length, len); + + ret = tsk_population_tbl_get_row(&table, (tsk_tbl_size_t) j, &population); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL(population.id, j); + CU_ASSERT_EQUAL(population.metadata_length, k); + CU_ASSERT_NSTRING_EQUAL(population.metadata, c, k); + } + ret = tsk_population_tbl_get_row(&table, num_rows, &population); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_POPULATION_OUT_OF_BOUNDS); + tsk_population_tbl_print_state(&table, _devnull); + tsk_population_tbl_dump_text(&table, _devnull); + + num_rows *= 2; + metadata = malloc(num_rows * sizeof(char)); + CU_ASSERT_FATAL(metadata != NULL); + metadata_offset = malloc((num_rows + 1) * sizeof(tsk_tbl_size_t)); + CU_ASSERT_FATAL(metadata_offset != NULL); + + for (j = 0; j < (tsk_id_t) num_rows; j++) { + metadata[j] = 'M'; + metadata_offset[j] = (tsk_tbl_size_t) j; + } + + metadata_offset[num_rows] = num_rows; + ret = tsk_population_tbl_set_columns(&table, num_rows, metadata, metadata_offset); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL(memcmp(table.metadata, metadata, num_rows * sizeof(char)), 0); + CU_ASSERT_EQUAL(table.num_rows, num_rows); + CU_ASSERT_EQUAL(table.metadata_length, 
num_rows); + + /* Append another num_rows */ + ret = tsk_population_tbl_append_columns(&table, num_rows, metadata, metadata_offset); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_EQUAL(memcmp(table.metadata, metadata, num_rows * sizeof(char)), 0); + CU_ASSERT_EQUAL(memcmp(table.metadata + num_rows, metadata, + num_rows * sizeof(char)), 0); + CU_ASSERT_EQUAL(table.metadata_length, 2 * num_rows); + CU_ASSERT_EQUAL(table.num_rows, 2 * num_rows); + + /* Truncate back to num_rows */ + ret = tsk_population_tbl_truncate(&table, num_rows); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL(memcmp(table.metadata, metadata, num_rows * sizeof(char)), 0); + CU_ASSERT_EQUAL(table.num_rows, num_rows); + CU_ASSERT_EQUAL(table.metadata_length, num_rows); + + ret = tsk_population_tbl_truncate(&table, num_rows + 1); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_TABLE_POSITION); + + /* Metadata = NULL gives an error */ + ret = tsk_population_tbl_set_columns(&table, num_rows, NULL, NULL); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_PARAM_VALUE); + ret = tsk_population_tbl_set_columns(&table, num_rows, metadata, NULL); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_PARAM_VALUE); + ret = tsk_population_tbl_set_columns(&table, num_rows, NULL, metadata_offset); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_PARAM_VALUE); + + /* Test for bad offsets */ + metadata_offset[0] = 1; + ret = tsk_population_tbl_set_columns(&table, num_rows, metadata, metadata_offset); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_OFFSET); + metadata_offset[0] = 0; + metadata_offset[num_rows] = 0; + ret = tsk_population_tbl_set_columns(&table, num_rows, metadata, metadata_offset); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_OFFSET); + + tsk_population_tbl_clear(&table); + CU_ASSERT_EQUAL(table.num_rows, 0); + CU_ASSERT_EQUAL(table.metadata_length, 0); + + tsk_population_tbl_free(&table); + free(metadata); + free(metadata_offset); +} + +static void +test_provenance_table(void) +{ + int ret; + tsk_provenance_tbl_t table; + tsk_tbl_size_t num_rows = 100; + 
tsk_tbl_size_t j; + char *timestamp; + uint32_t *timestamp_offset; + const char *test_timestamp = "2017-12-06T20:40:25+00:00"; + size_t test_timestamp_length = strlen(test_timestamp); + char timestamp_copy[test_timestamp_length + 1]; + char *record; + uint32_t *record_offset; + const char *test_record = "{\"json\"=1234}"; + size_t test_record_length = strlen(test_record); + char record_copy[test_record_length + 1]; + tsk_provenance_t provenance; + + timestamp_copy[test_timestamp_length] = '\0'; + record_copy[test_record_length] = '\0'; + ret = tsk_provenance_tbl_alloc(&table, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tsk_provenance_tbl_set_max_rows_increment(&table, 1); + tsk_provenance_tbl_set_max_timestamp_length_increment(&table, 1); + tsk_provenance_tbl_set_max_record_length_increment(&table, 1); + tsk_provenance_tbl_print_state(&table, _devnull); + tsk_provenance_tbl_dump_text(&table, _devnull); + + for (j = 0; j < num_rows; j++) { + ret = tsk_provenance_tbl_add_row(&table, test_timestamp, test_timestamp_length, + test_record, test_record_length); + CU_ASSERT_EQUAL_FATAL(ret, j); + CU_ASSERT_EQUAL(table.timestamp_length, (j + 1) * test_timestamp_length); + CU_ASSERT_EQUAL(table.timestamp_offset[j + 1], table.timestamp_length); + CU_ASSERT_EQUAL(table.record_length, (j + 1) * test_record_length); + CU_ASSERT_EQUAL(table.record_offset[j + 1], table.record_length); + /* check the timestamp */ + memcpy(timestamp_copy, table.timestamp + table.timestamp_offset[j], + test_timestamp_length); + CU_ASSERT_NSTRING_EQUAL(timestamp_copy, test_timestamp, test_timestamp_length); + /* check the record */ + memcpy(record_copy, table.record + table.record_offset[j], + test_record_length); + CU_ASSERT_NSTRING_EQUAL(record_copy, test_record, test_record_length); + + ret = tsk_provenance_tbl_get_row(&table, j, &provenance); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL(provenance.id, j); + CU_ASSERT_EQUAL(provenance.timestamp_length, test_timestamp_length); + 
CU_ASSERT_NSTRING_EQUAL(provenance.timestamp, test_timestamp, + test_timestamp_length); + CU_ASSERT_EQUAL(provenance.record_length, test_record_length); + CU_ASSERT_NSTRING_EQUAL(provenance.record, test_record, + test_record_length); + } + ret = tsk_provenance_tbl_get_row(&table, num_rows, &provenance); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_PROVENANCE_OUT_OF_BOUNDS); + tsk_provenance_tbl_print_state(&table, _devnull); + tsk_provenance_tbl_dump_text(&table, _devnull); + tsk_provenance_tbl_clear(&table); + CU_ASSERT_EQUAL(table.num_rows, 0); + CU_ASSERT_EQUAL(table.timestamp_length, 0); + CU_ASSERT_EQUAL(table.record_length, 0); + + num_rows *= 2; + timestamp = malloc(num_rows * sizeof(char)); + memset(timestamp, 'a', num_rows * sizeof(char)); + CU_ASSERT_FATAL(timestamp != NULL); + timestamp_offset = malloc((num_rows + 1) * sizeof(tsk_tbl_size_t)); + CU_ASSERT_FATAL(timestamp_offset != NULL); + record = malloc(num_rows * sizeof(char)); + memset(record, 'a', num_rows * sizeof(char)); + CU_ASSERT_FATAL(record != NULL); + record_offset = malloc((num_rows + 1) * sizeof(tsk_tbl_size_t)); + CU_ASSERT_FATAL(record_offset != NULL); + for (j = 0; j < num_rows + 1; j++) { + timestamp_offset[j] = j; + record_offset[j] = j; + } + ret = tsk_provenance_tbl_set_columns(&table, num_rows, + timestamp, timestamp_offset, record, record_offset); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_EQUAL(memcmp(table.timestamp, timestamp, num_rows * sizeof(char)), 0); + CU_ASSERT_EQUAL(memcmp(table.timestamp_offset, timestamp_offset, + (num_rows + 1) * sizeof(tsk_tbl_size_t)), 0); + CU_ASSERT_EQUAL(memcmp(table.record, record, num_rows * sizeof(char)), 0); + CU_ASSERT_EQUAL(memcmp(table.record_offset, record_offset, + (num_rows + 1) * sizeof(tsk_tbl_size_t)), 0); + CU_ASSERT_EQUAL(table.num_rows, num_rows); + CU_ASSERT_EQUAL(table.timestamp_length, num_rows); + CU_ASSERT_EQUAL(table.record_length, num_rows); + tsk_provenance_tbl_print_state(&table, _devnull); + + /* Append another num_rows onto the 
end */ + ret = tsk_provenance_tbl_append_columns(&table, num_rows, + timestamp, timestamp_offset, record, record_offset); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_EQUAL(memcmp(table.timestamp, timestamp, num_rows * sizeof(char)), 0); + CU_ASSERT_EQUAL(memcmp(table.timestamp + num_rows, timestamp, num_rows * sizeof(char)), 0); + CU_ASSERT_EQUAL(memcmp(table.record, record, num_rows * sizeof(char)), 0); + CU_ASSERT_EQUAL(memcmp(table.record + num_rows, record, num_rows * sizeof(char)), 0); + CU_ASSERT_EQUAL(table.num_rows, 2 * num_rows); + CU_ASSERT_EQUAL(table.timestamp_length, 2 * num_rows); + CU_ASSERT_EQUAL(table.record_length, 2 * num_rows); + tsk_provenance_tbl_print_state(&table, _devnull); + + /* Truncate back to num_rows */ + ret = tsk_provenance_tbl_truncate(&table, num_rows); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_EQUAL(memcmp(table.timestamp, timestamp, num_rows * sizeof(char)), 0); + CU_ASSERT_EQUAL(memcmp(table.timestamp_offset, timestamp_offset, + (num_rows + 1) * sizeof(tsk_tbl_size_t)), 0); + CU_ASSERT_EQUAL(memcmp(table.record, record, num_rows * sizeof(char)), 0); + CU_ASSERT_EQUAL(memcmp(table.record_offset, record_offset, + (num_rows + 1) * sizeof(tsk_tbl_size_t)), 0); + CU_ASSERT_EQUAL(table.num_rows, num_rows); + CU_ASSERT_EQUAL(table.timestamp_length, num_rows); + CU_ASSERT_EQUAL(table.record_length, num_rows); + tsk_provenance_tbl_print_state(&table, _devnull); + + ret = tsk_provenance_tbl_truncate(&table, num_rows + 1); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_TABLE_POSITION); + + /* No arguments can be null */ + ret = tsk_provenance_tbl_set_columns(&table, num_rows, NULL, timestamp_offset, + record, record_offset); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_PARAM_VALUE); + ret = tsk_provenance_tbl_set_columns(&table, num_rows, timestamp, NULL, + record, record_offset); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_PARAM_VALUE); + ret = tsk_provenance_tbl_set_columns(&table, num_rows, timestamp, timestamp_offset, + NULL, record_offset); + CU_ASSERT_EQUAL(ret, 
TSK_ERR_BAD_PARAM_VALUE); + ret = tsk_provenance_tbl_set_columns(&table, num_rows, timestamp, timestamp_offset, + record, NULL); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_PARAM_VALUE); + + tsk_provenance_tbl_free(&table); + free(timestamp); + free(timestamp_offset); + free(record); + free(record_offset); +} + +static void +test_simplify_tables_drops_indexes(void) +{ + int ret; + tsk_treeseq_t ts; + tsk_tbl_collection_t tables; + tsk_id_t samples[] = {0, 1}; + + tsk_treeseq_from_text(&ts, 1, single_tree_ex_nodes, single_tree_ex_edges, + NULL, NULL, NULL, NULL, NULL); + ret = tsk_tbl_collection_alloc(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_treeseq_copy_tables(&ts, &tables); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + CU_ASSERT_TRUE(tsk_tbl_collection_is_indexed(&tables)) + ret = tsk_tbl_collection_simplify(&tables, samples, 2, 0, NULL); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_FALSE(tsk_tbl_collection_is_indexed(&tables)) + + tsk_tbl_collection_free(&tables); + tsk_treeseq_free(&ts); +} + +static void +test_sort_tables_drops_indexes(void) +{ + int ret; + tsk_treeseq_t ts; + tsk_tbl_collection_t tables; + + tsk_treeseq_from_text(&ts, 1, single_tree_ex_nodes, single_tree_ex_edges, + NULL, NULL, NULL, NULL, NULL); + ret = tsk_tbl_collection_alloc(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_treeseq_copy_tables(&ts, &tables); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + CU_ASSERT_TRUE(tsk_tbl_collection_is_indexed(&tables)) + ret = tsk_tbl_collection_sort(&tables, 0, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_FALSE(tsk_tbl_collection_is_indexed(&tables)) + + tsk_tbl_collection_free(&tables); + tsk_treeseq_free(&ts); +} + +int +main(int argc, char **argv) +{ + CU_TestInfo tests[] = { + {"test_node_table", test_node_table}, + {"test_edge_table", test_edge_table}, + {"test_site_table", test_site_table}, + {"test_mutation_table", test_mutation_table}, + {"test_migration_table", test_migration_table}, + {"test_individual_table", test_individual_table}, + 
{"test_population_table", test_population_table}, + {"test_provenance_table", test_provenance_table}, + {"test_format_data_load_errors", test_format_data_load_errors}, + {"test_dump_unindexed", test_dump_unindexed}, + {"test_tbl_collection_load_errors", test_tbl_collection_load_errors}, + {"test_tbl_collection_dump_errors", test_tbl_collection_dump_errors}, + {"test_tbl_collection_simplify_errors", test_tbl_collection_simplify_errors}, + {"test_load_tsk_node_tbl_errors", test_load_tsk_node_tbl_errors}, + {"test_simplify_tables_drops_indexes", test_simplify_tables_drops_indexes}, + {"test_sort_tables_drops_indexes", test_sort_tables_drops_indexes}, + {NULL}, + }; + + return test_main(tests, argc, argv); +} diff --git a/c/test_trees.c b/c/test_trees.c new file mode 100644 index 0000000000..e74bfe2df0 --- /dev/null +++ b/c/test_trees.c @@ -0,0 +1,4194 @@ +#include "testlib.h" +#include "tsk_trees.h" +#include "tsk_genotypes.h" + +#include +#include + + +/*======================================================= + * Verification utilities. 
+ *======================================================*/ + +static void +verify_compute_mutation_parents(tsk_treeseq_t *ts) +{ + int ret; + size_t size = tsk_treeseq_get_num_mutations(ts) * sizeof(tsk_id_t); + tsk_id_t *parent = malloc(size); + tsk_tbl_collection_t tables; + + CU_ASSERT_FATAL(parent != NULL); + ret = tsk_tbl_collection_alloc(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_treeseq_copy_tables(ts, &tables); + CU_ASSERT_EQUAL_FATAL(ret, 0); + memcpy(parent, tables.mutations->parent, size); + /* tsk_tbl_collection_print_state(&tables, stdout); */ + /* Make sure the tables are actually updated */ + memset(tables.mutations->parent, 0xff, size); + + ret = tsk_tbl_collection_compute_mutation_parents(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(memcmp(parent, tables.mutations->parent, size), 0); + /* printf("after\n"); */ + /* tsk_tbl_collection_print_state(&tables, stdout); */ + + free(parent); + tsk_tbl_collection_free(&tables); +} + +static void +verify_individual_nodes(tsk_treeseq_t *ts) +{ + int ret; + tsk_individual_t individual; + tsk_id_t k; + size_t num_nodes = tsk_treeseq_get_num_nodes(ts); + size_t num_individuals = tsk_treeseq_get_num_individuals(ts); + size_t j; + + for (k = 0; k < (tsk_id_t) num_individuals; k++) { + ret = tsk_treeseq_get_individual(ts, (size_t) k, &individual); + CU_ASSERT_EQUAL_FATAL(ret, 0); + for (j = 0; j < individual.nodes_length; j++) { + CU_ASSERT_FATAL(individual.nodes[j] < (tsk_id_t) num_nodes); + CU_ASSERT_EQUAL_FATAL(k, + ts->tables->nodes->individual[individual.nodes[j]]); + } + } +} + +static void +verify_trees(tsk_treeseq_t *ts, uint32_t num_trees, tsk_id_t* parents) +{ + int ret; + tsk_id_t u, j, v; + uint32_t mutation_index, site_index; + tsk_tbl_size_t k, l, tree_sites_length; + tsk_site_t *sites = NULL; + tsk_tree_t tree; + size_t num_nodes = tsk_treeseq_get_num_nodes(ts); + size_t num_sites = tsk_treeseq_get_num_sites(ts); + size_t num_mutations = 
tsk_treeseq_get_num_mutations(ts); + + ret = tsk_tree_alloc(&tree, ts, 0); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_trees(ts), num_trees); + + site_index = 0; + mutation_index = 0; + j = 0; + for (ret = tsk_tree_first(&tree); ret == 1; ret = tsk_tree_next(&tree)) { + CU_ASSERT_EQUAL(j, tree.index); + tsk_tree_print_state(&tree, _devnull); + /* tsk_tree_print_state(&tree, stdout); */ + for (u = 0; u < (tsk_id_t) num_nodes; u++) { + ret = tsk_tree_get_parent(&tree, u, &v); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_EQUAL(v, parents[j * (tsk_id_t) num_nodes + u]); + } + ret = tsk_tree_get_sites(&tree, &sites, &tree_sites_length); + CU_ASSERT_EQUAL(ret, 0); + for (k = 0; k < tree_sites_length; k++) { + CU_ASSERT_EQUAL(sites[k].id, site_index); + for (l = 0; l < sites[k].mutations_length; l++) { + CU_ASSERT_EQUAL(sites[k].mutations[l].id, mutation_index); + CU_ASSERT_EQUAL(sites[k].mutations[l].site, site_index); + mutation_index++; + } + site_index++; + } + j++; + } + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_EQUAL(site_index, num_sites); + CU_ASSERT_EQUAL(mutation_index, num_mutations); + + ret = tsk_tree_next(&tree); + CU_ASSERT_EQUAL(ret, 0); + + tsk_tree_free(&tree); +} + +static tsk_tree_t * +get_tree_list(tsk_treeseq_t *ts) +{ + int ret; + tsk_tree_t t, *trees; + size_t num_trees; + + num_trees = tsk_treeseq_get_num_trees(ts); + ret = tsk_tree_alloc(&t, ts, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + trees = malloc(num_trees * sizeof(tsk_tree_t)); + CU_ASSERT_FATAL(trees != NULL); + for (ret = tsk_tree_first(&t); ret == 1; ret = tsk_tree_next(&t)) { + CU_ASSERT_FATAL(t.index < num_trees); + ret = tsk_tree_alloc(&trees[t.index], ts, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_tree_copy(&trees[t.index], &t); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_tree_equal(&trees[t.index], &t); + CU_ASSERT_EQUAL_FATAL(ret, 0); + /* Make sure the left and right coordinates are also OK */ + CU_ASSERT_DOUBLE_EQUAL(trees[t.index].left, t.left, 1e-6); + 
CU_ASSERT_DOUBLE_EQUAL(trees[t.index].right, t.right, 1e-6); + } + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_tree_free(&t); + CU_ASSERT_EQUAL_FATAL(ret, 0); + return trees; +} + + +static void +verify_tree_next_prev(tsk_treeseq_t *ts) +{ + int ret; + tsk_tree_t *trees, t; + size_t j; + size_t num_trees = tsk_treeseq_get_num_trees(ts); + + trees = get_tree_list(ts); + ret = tsk_tree_alloc(&t, ts, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + /* Single forward pass */ + j = 0; + for (ret = tsk_tree_first(&t); ret == 1; ret = tsk_tree_next(&t)) { + CU_ASSERT_EQUAL_FATAL(j, t.index); + ret = tsk_tree_equal(&t, &trees[t.index]); + CU_ASSERT_EQUAL_FATAL(ret, 0); + j++; + } + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(j, num_trees); + + /* Single reverse pass */ + j = num_trees; + for (ret = tsk_tree_last(&t); ret == 1; ret = tsk_tree_prev(&t)) { + CU_ASSERT_EQUAL_FATAL(j - 1, t.index); + ret = tsk_tree_equal(&t, &trees[t.index]); + if (ret != 0) { + printf("trees differ\n"); + printf("REVERSE tree::\n"); + tsk_tree_print_state(&t, stdout); + printf("FORWARD tree::\n"); + tsk_tree_print_state(&trees[t.index], stdout); + } + CU_ASSERT_EQUAL_FATAL(ret, 0); + j--; + } + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(j, 0); + + /* Full forward, then reverse */ + j = 0; + for (ret = tsk_tree_first(&t); ret == 1; ret = tsk_tree_next(&t)) { + CU_ASSERT_EQUAL_FATAL(j, t.index); + ret = tsk_tree_equal(&t, &trees[t.index]); + CU_ASSERT_EQUAL_FATAL(ret, 0); + j++; + } + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(j, num_trees); + j--; + while ((ret = tsk_tree_prev(&t)) == 1) { + CU_ASSERT_EQUAL_FATAL(j - 1, t.index); + ret = tsk_tree_equal(&t, &trees[t.index]); + CU_ASSERT_EQUAL_FATAL(ret, 0); + j--; + } + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(j, 0); + CU_ASSERT_EQUAL_FATAL(t.index, 0); + /* Calling prev should return 0 and have no effect. 
*/ + for (j = 0; j < 10; j++) { + ret = tsk_tree_prev(&t); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(t.index, 0); + ret = tsk_tree_equal(&t, &trees[t.index]); + CU_ASSERT_EQUAL_FATAL(ret, 0); + } + + /* Full reverse then forward */ + j = num_trees; + for (ret = tsk_tree_last(&t); ret == 1; ret = tsk_tree_prev(&t)) { + CU_ASSERT_EQUAL_FATAL(j - 1, t.index); + ret = tsk_tree_equal(&t, &trees[t.index]); + CU_ASSERT_EQUAL_FATAL(ret, 0); + j--; + } + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(j, 0); + j++; + while ((ret = tsk_tree_next(&t)) == 1) { + CU_ASSERT_EQUAL_FATAL(j, t.index); + ret = tsk_tree_equal(&t, &trees[t.index]); + CU_ASSERT_EQUAL_FATAL(ret, 0); + j++; + } + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(j, num_trees); + CU_ASSERT_EQUAL_FATAL(t.index, num_trees - 1); + /* Calling next should return 0 and have no effect. */ + for (j = 0; j < 10; j++) { + ret = tsk_tree_next(&t); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(t.index, num_trees - 1); + ret = tsk_tree_equal(&t, &trees[t.index]); + CU_ASSERT_EQUAL_FATAL(ret, 0); + } + + /* Do a zigzagging traversal */ + ret = tsk_tree_first(&t); + CU_ASSERT_EQUAL_FATAL(ret, 1); + for (j = 1; j < TSK_MIN(10, num_trees / 2); j++) { + while (t.index < num_trees - j) { + ret = tsk_tree_next(&t); + CU_ASSERT_EQUAL_FATAL(ret, 1); + } + CU_ASSERT_EQUAL_FATAL(t.index, num_trees - j); + ret = tsk_tree_equal(&t, &trees[t.index]); + CU_ASSERT_EQUAL_FATAL(ret, 0); + while (t.index > j) { + ret = tsk_tree_prev(&t); + CU_ASSERT_EQUAL_FATAL(ret, 1); + } + CU_ASSERT_EQUAL_FATAL(t.index, j); + ret = tsk_tree_equal(&t, &trees[t.index]); + CU_ASSERT_EQUAL_FATAL(ret, 0); + } + + /* Free the trees. 
*/ + ret = tsk_tree_free(&t); + CU_ASSERT_EQUAL_FATAL(ret, 0); + for (j = 0; j < tsk_treeseq_get_num_trees(ts); j++) { + ret = tsk_tree_free(&trees[j]); + } + free(trees); +} + +static void +verify_tree_diffs(tsk_treeseq_t *ts) +{ + int ret; + tsk_diff_iter_t iter; + tsk_tree_t tree; + tsk_edge_list_t *record, *records_out, *records_in; + size_t num_nodes = tsk_treeseq_get_num_nodes(ts); + size_t j, num_trees; + double left, right; + tsk_id_t *parent = malloc(num_nodes * sizeof(tsk_id_t)); + tsk_id_t *child = malloc(num_nodes * sizeof(tsk_id_t)); + tsk_id_t *sib = malloc(num_nodes * sizeof(tsk_id_t)); + tsk_id_t *samples; + + CU_ASSERT_FATAL(parent != NULL); + CU_ASSERT_FATAL(child != NULL); + CU_ASSERT_FATAL(sib != NULL); + for (j = 0; j < num_nodes; j++) { + parent[j] = TSK_NULL; + child[j] = TSK_NULL; + sib[j] = TSK_NULL; + } + ret = tsk_treeseq_get_samples(ts, &samples); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_diff_iter_alloc(&iter, ts); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_tree_alloc(&tree, ts, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_tree_first(&tree); + CU_ASSERT_EQUAL_FATAL(ret, 1); + tsk_diff_iter_print_state(&iter, _devnull); + + num_trees = 0; + while ((ret = tsk_diff_iter_next( + &iter, &left, &right, &records_out, &records_in)) == 1) { + tsk_diff_iter_print_state(&iter, _devnull); + num_trees++; + for (record = records_out; record != NULL; record = record->next) { + parent[record->edge.child] = TSK_NULL; + } + for (record = records_in; record != NULL; record = record->next) { + parent[record->edge.child] = record->edge.parent; + } + /* Now check against the sparse tree iterator. 
*/ + for (j = 0; j < num_nodes; j++) { + CU_ASSERT_EQUAL(parent[j], tree.parent[j]); + } + CU_ASSERT_EQUAL(tree.left, left); + CU_ASSERT_EQUAL(tree.right, right); + ret = tsk_tree_next(&tree); + if (num_trees < tsk_treeseq_get_num_trees(ts)) { + CU_ASSERT_EQUAL(ret, 1); + } else { + CU_ASSERT_EQUAL(ret, 0); + } + } + CU_ASSERT_EQUAL(num_trees, tsk_treeseq_get_num_trees(ts)); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_tree_next(&tree); + CU_ASSERT_EQUAL(ret, 0); + ret = tsk_diff_iter_free(&iter); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_tree_free(&tree); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + free(parent); + free(child); + free(sib); +} + +/* When we keep all sites in simplify, the genotypes for the subset of the + * samples should be the same as the original */ +static void +verify_simplify_genotypes(tsk_treeseq_t *ts, tsk_treeseq_t *subset, + tsk_id_t *samples, size_t num_samples) +{ + int ret; + size_t m = tsk_treeseq_get_num_sites(ts); + tsk_vargen_t vargen, subset_vargen; + tsk_variant_t *variant, *subset_variant; + size_t j, k; + tsk_id_t *all_samples; + uint8_t a1, a2; + tsk_id_t *sample_index_map; + + tsk_treeseq_get_sample_index_map(ts, &sample_index_map); + + /* tsk_treeseq_print_state(ts, stdout); */ + /* tsk_treeseq_print_state(subset, stdout); */ + + ret = tsk_vargen_alloc(&vargen, ts, NULL, 0, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_vargen_alloc(&subset_vargen, subset, NULL, 0, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(m, tsk_treeseq_get_num_sites(subset)); + tsk_treeseq_get_samples(ts, &all_samples); + + for (j = 0; j < m; j++) { + ret = tsk_vargen_next(&vargen, &variant); + CU_ASSERT_EQUAL_FATAL(ret, 1); + ret = tsk_vargen_next(&subset_vargen, &subset_variant); + CU_ASSERT_EQUAL_FATAL(ret, 1); + CU_ASSERT_EQUAL(variant->site->id, j) + CU_ASSERT_EQUAL(subset_variant->site->id, j) + CU_ASSERT_EQUAL(variant->site->position, subset_variant->site->position); + for (k = 0; k < num_samples; k++) { + 
CU_ASSERT_FATAL(sample_index_map[samples[k]] < (tsk_id_t) ts->num_samples); + a1 = variant->genotypes.u8[sample_index_map[samples[k]]]; + a2 = subset_variant->genotypes.u8[k]; + /* printf("a1 = %d, a2 = %d\n", a1, a2); */ + /* printf("k = %d original node = %d " */ + /* "original_index = %d a1=%.*s a2=%.*s\n", */ + /* (int) k, samples[k], sample_index_map[samples[k]], */ + /* variant->allele_lengths[a1], variant->alleles[a1], */ + /* subset_variant->allele_lengths[a2], subset_variant->alleles[a2]); */ + CU_ASSERT_FATAL(a1 < variant->num_alleles); + CU_ASSERT_FATAL(a2 < subset_variant->num_alleles); + CU_ASSERT_EQUAL_FATAL(variant->allele_lengths[a1], + subset_variant->allele_lengths[a2]); + CU_ASSERT_NSTRING_EQUAL_FATAL( + variant->alleles[a1], subset_variant->alleles[a2], + variant->allele_lengths[a1]); + } + } + tsk_vargen_free(&vargen); + tsk_vargen_free(&subset_vargen); +} + + +static void +verify_simplify_properties(tsk_treeseq_t *ts, tsk_treeseq_t *subset, + tsk_id_t *samples, size_t num_samples, tsk_id_t *node_map) +{ + int ret; + tsk_node_t n1, n2; + tsk_tree_t full_tree, subset_tree; + tsk_site_t *tree_sites; + tsk_tbl_size_t tree_sites_length; + uint32_t j, k; + tsk_id_t u, mrca1, mrca2; + size_t total_sites; + + CU_ASSERT_EQUAL( + tsk_treeseq_get_sequence_length(ts), + tsk_treeseq_get_sequence_length(subset)); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_samples(subset), num_samples); + CU_ASSERT( + tsk_treeseq_get_num_nodes(ts) >= tsk_treeseq_get_num_nodes(subset)); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_samples(subset), num_samples); + + /* Check the sample properties */ + for (j = 0; j < num_samples; j++) { + ret = tsk_treeseq_get_node(ts, (size_t) samples[j], &n1); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL(node_map[samples[j]], j); + ret = tsk_treeseq_get_node(subset, (size_t) node_map[samples[j]], &n2); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(n1.population, n2.population); + CU_ASSERT_EQUAL_FATAL(n1.time, n2.time); + 
CU_ASSERT_EQUAL_FATAL(n1.flags, n2.flags); + CU_ASSERT_EQUAL_FATAL(n1.metadata_length, n2.metadata_length); + CU_ASSERT_NSTRING_EQUAL(n1.metadata, n2.metadata, n2.metadata_length); + } + /* Check that node mappings are correct */ + for (j = 0; j < tsk_treeseq_get_num_nodes(ts); j++) { + ret = tsk_treeseq_get_node(ts, j, &n1); + CU_ASSERT_EQUAL_FATAL(ret, 0); + if (node_map[j] != TSK_NULL) { + ret = tsk_treeseq_get_node(subset, (size_t) node_map[j], &n2); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(n1.population, n2.population); + CU_ASSERT_EQUAL_FATAL(n1.time, n2.time); + CU_ASSERT_EQUAL_FATAL(n1.flags, n2.flags); + CU_ASSERT_EQUAL_FATAL(n1.metadata_length, n2.metadata_length); + CU_ASSERT_NSTRING_EQUAL(n1.metadata, n2.metadata, n2.metadata_length); + } + } + if (num_samples == 0) { + CU_ASSERT_EQUAL(tsk_treeseq_get_num_edges(subset), 0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_nodes(subset), 0); + } else if (num_samples == 1) { + CU_ASSERT_EQUAL(tsk_treeseq_get_num_edges(subset), 0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_nodes(subset), 1); + } + /* Check the pairwise MRCAs */ + ret = tsk_tree_alloc(&full_tree, ts, 0); + CU_ASSERT_EQUAL(ret, 0); + ret = tsk_tree_alloc(&subset_tree, subset, 0); + CU_ASSERT_EQUAL(ret, 0); + ret = tsk_tree_first(&full_tree); + CU_ASSERT_EQUAL(ret, 1); + ret = tsk_tree_first(&subset_tree); + CU_ASSERT_EQUAL(ret, 1); + + total_sites = 0; + while (1) { + while (full_tree.right <= subset_tree.right) { + for (j = 0; j < num_samples; j++) { + for (k = j + 1; k < num_samples; k++) { + ret = tsk_tree_get_mrca(&full_tree, samples[j], samples[k], &mrca1); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_tree_get_mrca(&subset_tree, + node_map[samples[j]], node_map[samples[k]], &mrca2); + CU_ASSERT_EQUAL_FATAL(ret, 0); + if (mrca1 == TSK_NULL) { + CU_ASSERT_EQUAL_FATAL(mrca2, TSK_NULL); + } else { + CU_ASSERT_EQUAL(node_map[mrca1], mrca2); + } + } + } + ret = tsk_tree_next(&full_tree); + CU_ASSERT_FATAL(ret >= 0); + if (ret != 1) { + 
break; + } + } + /* Check the sites in this tree */ + ret = tsk_tree_get_sites(&subset_tree, &tree_sites, &tree_sites_length); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + for (j = 0; j < tree_sites_length; j++) { + CU_ASSERT(subset_tree.left <= tree_sites[j].position); + CU_ASSERT(tree_sites[j].position < subset_tree.right); + for (k = 0; k < tree_sites[j].mutations_length; k++) { + ret = tsk_tree_get_parent(&subset_tree, + tree_sites[j].mutations[k].node, &u); + CU_ASSERT_EQUAL(ret, 0); + } + total_sites++; + } + ret = tsk_tree_next(&subset_tree); + if (ret != 1) { + break; + } + } + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_sites(subset), total_sites); + + tsk_tree_free(&subset_tree); + tsk_tree_free(&full_tree); +} + +static void +verify_simplify(tsk_treeseq_t *ts) +{ + int ret; + size_t n = tsk_treeseq_get_num_samples(ts); + size_t num_samples[] = {0, 1, 2, 3, n / 2, n - 1, n}; + size_t j; + tsk_id_t *sample; + tsk_id_t *node_map = malloc(tsk_treeseq_get_num_nodes(ts) * sizeof(tsk_id_t)); + tsk_treeseq_t subset; + int flags = TSK_FILTER_SITES; + + CU_ASSERT_FATAL(node_map != NULL); + ret = tsk_treeseq_get_samples(ts, &sample); + CU_ASSERT_EQUAL_FATAL(ret, 0); + if (tsk_treeseq_get_num_migrations(ts) > 0) { + ret = tsk_treeseq_simplify(ts, sample, 2, 0, &subset, NULL); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_SIMPLIFY_MIGRATIONS_NOT_SUPPORTED); + /* Exiting early here because simplify isn't supported with migrations. 
*/ + goto out; + } + + for (j = 0; j < sizeof(num_samples) / sizeof(*num_samples); j++) { + if (num_samples[j] <= n) { + ret = tsk_treeseq_simplify(ts, sample, num_samples[j], flags, &subset, + node_map); + /* printf("ret = %s\n", tsk_strerror(ret)); */ + CU_ASSERT_EQUAL_FATAL(ret, 0); + verify_simplify_properties(ts, &subset, sample, num_samples[j], node_map); + tsk_treeseq_free(&subset); + + /* Keep all sites */ + ret = tsk_treeseq_simplify(ts, sample, num_samples[j], 0, &subset, + node_map); + CU_ASSERT_EQUAL_FATAL(ret, 0); + verify_simplify_properties(ts, &subset, sample, num_samples[j], node_map); + verify_simplify_genotypes(ts, &subset, sample, num_samples[j]); + tsk_treeseq_free(&subset); + } + } +out: + free(node_map); +} + +typedef struct { + uint32_t tree_index; + tsk_id_t node; + uint32_t count; +} sample_count_test_t; + +static void +verify_sample_counts(tsk_treeseq_t *ts, size_t num_tests, sample_count_test_t *tests) +{ + int ret; + size_t j, num_samples, n, k; + tsk_id_t stop, sample_index; + tsk_tree_t tree; + tsk_id_t *samples; + + n = tsk_treeseq_get_num_samples(ts); + ret = tsk_treeseq_get_samples(ts, &samples); + CU_ASSERT_EQUAL(ret, 0); + + /* First run without the TSK_SAMPLE_COUNTS feature */ + ret = tsk_tree_alloc(&tree, ts, 0); + CU_ASSERT_EQUAL(ret, 0); + ret = tsk_tree_first(&tree); + CU_ASSERT_EQUAL_FATAL(ret, 1); + for (j = 0; j < num_tests; j++) { + while (tree.index < tests[j].tree_index) { + ret = tsk_tree_next(&tree); + CU_ASSERT_EQUAL_FATAL(ret, 1); + } + ret = tsk_tree_get_num_samples(&tree, tests[j].node, &num_samples); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL(tests[j].count, num_samples); + /* all operations depending on tracked samples should fail. */ + ret = tsk_tree_get_num_tracked_samples(&tree, 0, &num_samples); + CU_ASSERT_EQUAL(ret, TSK_ERR_UNSUPPORTED_OPERATION); + } + tsk_tree_free(&tree); + + /* Now run with TSK_SAMPLE_COUNTS but with no samples tracked. 
*/ + ret = tsk_tree_alloc(&tree, ts, TSK_SAMPLE_COUNTS); + CU_ASSERT_EQUAL(ret, 0); + ret = tsk_tree_first(&tree); + CU_ASSERT_EQUAL_FATAL(ret, 1); + for (j = 0; j < num_tests; j++) { + while (tree.index < tests[j].tree_index) { + ret = tsk_tree_next(&tree); + CU_ASSERT_EQUAL_FATAL(ret, 1); + } + ret = tsk_tree_get_num_samples(&tree, tests[j].node, &num_samples); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL(tests[j].count, num_samples); + /* all operations depending on tracked samples should fail. */ + ret = tsk_tree_get_num_tracked_samples(&tree, 0, &num_samples); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_EQUAL(num_samples, 0); + } + tsk_tree_free(&tree); + + /* Run with TSK_SAMPLE_LISTS, but without TSK_SAMPLE_COUNTS */ + ret = tsk_tree_alloc(&tree, ts, TSK_SAMPLE_LISTS); + CU_ASSERT_EQUAL(ret, 0); + ret = tsk_tree_first(&tree); + CU_ASSERT_EQUAL_FATAL(ret, 1); + for (j = 0; j < num_tests; j++) { + while (tree.index < tests[j].tree_index) { + ret = tsk_tree_next(&tree); + CU_ASSERT_EQUAL_FATAL(ret, 1); + } + ret = tsk_tree_get_num_samples(&tree, tests[j].node, &num_samples); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL(tests[j].count, num_samples); + /* all operations depending on tracked samples should fail. 
*/ + ret = tsk_tree_get_num_tracked_samples(&tree, 0, &num_samples); + CU_ASSERT_EQUAL(ret, TSK_ERR_UNSUPPORTED_OPERATION); + + sample_index = tree.left_sample[tests[j].node]; + k = 0; + if (sample_index != TSK_NULL) { + stop = tree.right_sample[tests[j].node]; + while (true) { + k++; + CU_ASSERT_FATAL(k <= tests[j].count); + if (sample_index == stop) { + break; + } + sample_index = tree.next_sample[sample_index]; + } + } + CU_ASSERT_EQUAL(tests[j].count, k); + } + tsk_tree_free(&tree); + + /* Now use TSK_SAMPLE_COUNTS|TSK_SAMPLE_LISTS */ + ret = tsk_tree_alloc(&tree, ts, TSK_SAMPLE_COUNTS|TSK_SAMPLE_LISTS); + CU_ASSERT_EQUAL(ret, 0); + ret = tsk_tree_set_tracked_samples(&tree, n, samples); + CU_ASSERT_EQUAL(ret, 0); + ret = tsk_tree_first(&tree); + CU_ASSERT_EQUAL_FATAL(ret, 1); + for (j = 0; j < num_tests; j++) { + while (tree.index < tests[j].tree_index) { + ret = tsk_tree_next(&tree); + CU_ASSERT_EQUAL_FATAL(ret, 1); + } + ret = tsk_tree_get_num_samples(&tree, tests[j].node, &num_samples); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL(tests[j].count, num_samples); + + /* We're tracking all samples, so the count should be the same */ + ret = tsk_tree_get_num_tracked_samples(&tree, tests[j].node, &num_samples); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL(tests[j].count, num_samples); + + sample_index = tree.left_sample[tests[j].node]; + k = 0; + if (sample_index != TSK_NULL) { + stop = tree.right_sample[tests[j].node]; + while (true) { + k++; + if (sample_index == stop) { + break; + } + sample_index = tree.next_sample[sample_index]; + } + } + CU_ASSERT_EQUAL(tests[j].count, k); + } + tsk_tree_free(&tree); +} + +static void +verify_sample_sets_for_tree(tsk_tree_t *tree) +{ + int ret, stack_top, j; + tsk_id_t u, v; + size_t tmp, n, num_nodes, num_samples; + tsk_id_t *stack, *samples; + tsk_treeseq_t *ts = tree->tree_sequence; + tsk_id_t *sample_index_map = ts->sample_index_map; + const tsk_id_t *list_left = tree->left_sample; + const tsk_id_t 
*list_right = tree->right_sample; + const tsk_id_t *list_next = tree->next_sample; + tsk_id_t stop, sample_index; + + n = tsk_treeseq_get_num_samples(ts); + num_nodes = tsk_treeseq_get_num_nodes(ts); + stack = malloc(n * sizeof(tsk_id_t)); + samples = malloc(n * sizeof(tsk_id_t)); + CU_ASSERT_FATAL(stack != NULL); + CU_ASSERT_FATAL(samples != NULL); + for (u = 0; u < (tsk_id_t) num_nodes; u++) { + if (tree->left_child[u] == TSK_NULL && !tsk_treeseq_is_sample(ts, u)) { + CU_ASSERT_EQUAL(list_left[u], TSK_NULL); + CU_ASSERT_EQUAL(list_right[u], TSK_NULL); + } else { + stack_top = 0; + num_samples = 0; + stack[stack_top] = u; + while (stack_top >= 0) { + v = stack[stack_top]; + stack_top--; + if (tsk_treeseq_is_sample(ts, v)) { + samples[num_samples] = v; + num_samples++; + } + for (v = tree->right_child[v]; v != TSK_NULL; v = tree->left_sib[v]) { + stack_top++; + stack[stack_top] = v; + } + } + ret = tsk_tree_get_num_samples(tree, u, &tmp); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(num_samples, tmp); + + j = 0; + sample_index = list_left[u]; + if (sample_index != TSK_NULL) { + stop = list_right[u]; + while (true) { + CU_ASSERT_TRUE_FATAL(j < (tsk_id_t) n); + CU_ASSERT_EQUAL_FATAL(sample_index, sample_index_map[samples[j]]); + j++; + if (sample_index == stop) { + break; + } + sample_index = list_next[sample_index]; + } + } + CU_ASSERT_EQUAL_FATAL(j, num_samples); + } + } + free(stack); + free(samples); +} + +static void +verify_sample_sets(tsk_treeseq_t *ts) +{ + int ret; + tsk_tree_t t; + + ret = tsk_tree_alloc(&t, ts, TSK_SAMPLE_COUNTS|TSK_SAMPLE_LISTS); + CU_ASSERT_EQUAL(ret, 0); + + for (ret = tsk_tree_first(&t); ret == 1; ret = tsk_tree_next(&t)) { + verify_sample_sets_for_tree(&t); + } + CU_ASSERT_EQUAL_FATAL(ret, 0); + for (ret = tsk_tree_last(&t); ret == 1; ret = tsk_tree_prev(&t)) { + verify_sample_sets_for_tree(&t); + } + CU_ASSERT_EQUAL_FATAL(ret, 0); + + tsk_tree_free(&t); +} + + +static void +verify_empty_tree_sequence(tsk_treeseq_t *ts, double 
sequence_length) +{ + CU_ASSERT_EQUAL(tsk_treeseq_get_num_edges(ts), 0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_mutations(ts), 0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_mutations(ts), 0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_migrations(ts), 0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_samples(ts), 0); + CU_ASSERT_EQUAL(tsk_treeseq_get_sequence_length(ts), sequence_length); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_trees(ts), 1); +} + +/*======================================================= + * Simplest test cases. + *======================================================*/ + +static void +test_simplest_records(void) +{ + const char *nodes = + "1 0 0\n" + "1 0 0\n" + "0 1 0"; + const char *edges = + "0 1 2 0,1\n"; + tsk_treeseq_t ts; + + tsk_treeseq_from_text(&ts, 1, nodes, edges, NULL, NULL, NULL, NULL, NULL); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_samples(&ts), 2); + CU_ASSERT_EQUAL(tsk_treeseq_get_sequence_length(&ts), 1.0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_nodes(&ts), 3); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_mutations(&ts), 0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_trees(&ts), 1); + tsk_treeseq_free(&ts); +} + +static void +test_simplest_nonbinary_records(void) +{ + const char *nodes = + "1 0 0\n" + "1 0 0\n" + "1 0 0\n" + "1 0 0\n" + "0 1 0"; + const char *edges = + "0 1 4 0,1,2,3\n"; + tsk_treeseq_t ts; + + tsk_treeseq_from_text(&ts, 1, nodes, edges, NULL, NULL, NULL, NULL, NULL); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_samples(&ts), 4); + CU_ASSERT_EQUAL(tsk_treeseq_get_sequence_length(&ts), 1.0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_nodes(&ts), 5); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_mutations(&ts), 0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_trees(&ts), 1); + tsk_treeseq_free(&ts); +} + +static void +test_simplest_unary_records(void) +{ + int ret; + const char *nodes = + "1 0 0\n" + "1 0 0\n" + "0 1 0\n" + "0 1 0\n" + "0 2 0"; + const char *edges = + "0 1 2 0\n" + "0 1 3 1\n" + "0 1 4 2,3\n"; + tsk_treeseq_t ts, simplified; + tsk_id_t sample_ids[] = {0, 1}; 
+ + tsk_treeseq_from_text(&ts, 1, nodes, edges, NULL, NULL, NULL, NULL, NULL); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_samples(&ts), 2); + CU_ASSERT_EQUAL(tsk_treeseq_get_sequence_length(&ts), 1.0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_nodes(&ts), 5); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_mutations(&ts), 0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_trees(&ts), 1); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_populations(&ts), 1); + + ret = tsk_treeseq_simplify(&ts, sample_ids, 2, 0, &simplified, NULL); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_samples(&simplified), 2); + CU_ASSERT_EQUAL(tsk_treeseq_get_sequence_length(&simplified), 1.0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_nodes(&simplified), 3); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_mutations(&simplified), 0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_trees(&simplified), 1); + + tsk_treeseq_free(&ts); + tsk_treeseq_free(&simplified); +} + +static void +test_simplest_non_sample_leaf_records(void) +{ + int ret; + const char *nodes = + "1 0 0\n" + "1 0 0\n" + "0 1 0\n" + "0 0 0\n" + "0 0 0"; + const char *edges = + "0 1 2 0,1,3,4\n"; + const char *sites = + "0.1 0\n" + "0.2 0\n" + "0.3 0\n" + "0.4 0\n"; + const char *mutations = + "0 0 1\n" + "1 1 1\n" + "2 3 1\n" + "3 4 1"; + tsk_treeseq_t ts, simplified; + tsk_id_t sample_ids[] = {0, 1}; + tsk_hapgen_t hapgen; + tsk_vargen_t vargen; + char *haplotype; + tsk_variant_t *var; + + tsk_treeseq_from_text(&ts, 1, nodes, edges, NULL, sites, mutations, NULL, NULL); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_samples(&ts), 2); + CU_ASSERT_EQUAL(tsk_treeseq_get_sequence_length(&ts), 1.0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_nodes(&ts), 5); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_mutations(&ts), 4); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_trees(&ts), 1); + + ret = tsk_hapgen_alloc(&hapgen, &ts); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_hapgen_get_haplotype(&hapgen, 0, &haplotype); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_STRING_EQUAL(haplotype, "1000"); + ret = 
tsk_hapgen_get_haplotype(&hapgen, 1, &haplotype); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_STRING_EQUAL(haplotype, "0100"); + tsk_hapgen_free(&hapgen); + + ret = tsk_vargen_alloc(&vargen, &ts, NULL, 0, 0); + tsk_vargen_print_state(&vargen, _devnull); + ret = tsk_vargen_next(&vargen, &var); + CU_ASSERT_EQUAL_FATAL(ret, 1); + CU_ASSERT_NSTRING_EQUAL(var->alleles[0], "0", 1); + CU_ASSERT_NSTRING_EQUAL(var->alleles[1], "1", 1); + CU_ASSERT_EQUAL(var->genotypes.u8[0], 1); + CU_ASSERT_EQUAL(var->genotypes.u8[1], 0); + + ret = tsk_vargen_next(&vargen, &var); + CU_ASSERT_EQUAL_FATAL(ret, 1); + CU_ASSERT_NSTRING_EQUAL(var->alleles[0], "0", 1); + CU_ASSERT_NSTRING_EQUAL(var->alleles[1], "1", 1); + CU_ASSERT_EQUAL(var->genotypes.u8[0], 0); + CU_ASSERT_EQUAL(var->genotypes.u8[1], 1); + + ret = tsk_vargen_next(&vargen, &var); + CU_ASSERT_EQUAL_FATAL(ret, 1); + CU_ASSERT_NSTRING_EQUAL(var->alleles[0], "0", 1); + CU_ASSERT_NSTRING_EQUAL(var->alleles[1], "1", 1); + CU_ASSERT_EQUAL(var->genotypes.u8[0], 0); + CU_ASSERT_EQUAL(var->genotypes.u8[1], 0); + + ret = tsk_vargen_next(&vargen, &var); + CU_ASSERT_EQUAL_FATAL(ret, 1); + CU_ASSERT_NSTRING_EQUAL(var->alleles[0], "0", 1); + CU_ASSERT_NSTRING_EQUAL(var->alleles[1], "1", 1); + CU_ASSERT_EQUAL(var->genotypes.u8[0], 0); + CU_ASSERT_EQUAL(var->genotypes.u8[1], 0); + + ret = tsk_vargen_next(&vargen, &var); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tsk_vargen_free(&vargen); + + ret = tsk_treeseq_simplify(&ts, sample_ids, 2, 0, &simplified, NULL); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_samples(&simplified), 2); + CU_ASSERT_EQUAL(tsk_treeseq_get_sequence_length(&simplified), 1.0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_nodes(&simplified), 3); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_mutations(&simplified), 2); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_trees(&simplified), 1); + + tsk_treeseq_free(&ts); + tsk_treeseq_free(&simplified); +} + +static void +test_simplest_degenerate_multiple_root_records(void) +{ + + int 
ret; + const char *nodes = + "1 0 0\n" + "1 0 0\n" + "0 1 0\n" + "0 1 0\n"; + const char *edges = + "0 1 2 0\n" + "0 1 3 1\n"; + tsk_treeseq_t ts, simplified; + tsk_tree_t t; + tsk_id_t sample_ids[] = {0, 1}; + + tsk_treeseq_from_text(&ts, 1, nodes, edges, NULL, NULL, NULL, NULL, NULL); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_samples(&ts), 2); + CU_ASSERT_EQUAL(tsk_treeseq_get_sequence_length(&ts), 1.0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_nodes(&ts), 4); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_mutations(&ts), 0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_trees(&ts), 1); + + ret = tsk_tree_alloc(&t, &ts, 0); + CU_ASSERT_EQUAL(ret, 0); + ret = tsk_tree_first(&t); + CU_ASSERT_EQUAL(ret, 1); + CU_ASSERT_EQUAL(tsk_tree_get_num_roots(&t), 2); + CU_ASSERT_EQUAL(t.left_root, 2); + CU_ASSERT_EQUAL(t.right_sib[2], 3); + CU_ASSERT_EQUAL(t.right_sib[3], TSK_NULL); + + ret = tsk_treeseq_simplify(&ts, sample_ids, 2, 0, &simplified, NULL); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_samples(&simplified), 2); + CU_ASSERT_EQUAL(tsk_treeseq_get_sequence_length(&simplified), 1.0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_nodes(&simplified), 2); + + tsk_treeseq_free(&simplified); + tsk_treeseq_free(&ts); + tsk_tree_free(&t); +} + +static void +test_simplest_multiple_root_records(void) +{ + int ret; + const char *nodes = + "1 0 0\n" + "1 0 0\n" + "1 0 0\n" + "1 0 0\n" + "0 1 0\n" + "0 1 0\n"; + const char *edges = + "0 1 4 0,1\n" + "0 1 5 2,3\n"; + tsk_treeseq_t ts, simplified; + tsk_id_t sample_ids[] = {0, 1, 2, 3}; + + tsk_treeseq_from_text(&ts, 1, nodes, edges, NULL, NULL, NULL, NULL, NULL); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_samples(&ts), 4); + CU_ASSERT_EQUAL(tsk_treeseq_get_sequence_length(&ts), 1.0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_nodes(&ts), 6); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_mutations(&ts), 0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_trees(&ts), 1); + + ret = tsk_treeseq_simplify(&ts, sample_ids, 4, 0, &simplified, NULL); + 
CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_samples(&simplified), 4); + CU_ASSERT_EQUAL(tsk_treeseq_get_sequence_length(&simplified), 1.0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_nodes(&simplified), 6); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_mutations(&simplified), 0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_trees(&simplified), 1); + tsk_treeseq_free(&simplified); + + /* Make one tree degenerate */ + ret = tsk_treeseq_simplify(&ts, sample_ids, 3, 0, &simplified, NULL); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_samples(&simplified), 3); + CU_ASSERT_EQUAL(tsk_treeseq_get_sequence_length(&simplified), 1.0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_nodes(&simplified), 4); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_mutations(&simplified), 0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_trees(&simplified), 1); + tsk_treeseq_free(&simplified); + tsk_treeseq_free(&ts); +} + +static void +test_simplest_zero_root_tree(void) +{ + int ret; + const char *nodes = + "0 0 0\n" + "0 0 0\n" + "0 0 0\n" + "0 0 0\n" + "0 1 0\n" + "0 1 0\n"; + const char *edges = + "0 1 4 0,1\n" + "0 1 5 2,3\n"; + tsk_treeseq_t ts; + tsk_tree_t t; + + tsk_treeseq_from_text(&ts, 1, nodes, edges, NULL, NULL, NULL, NULL, NULL); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_samples(&ts), 0); + CU_ASSERT_EQUAL(tsk_treeseq_get_sequence_length(&ts), 1.0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_nodes(&ts), 6); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_mutations(&ts), 0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_trees(&ts), 1); + + ret = tsk_tree_alloc(&t, &ts, 0); + CU_ASSERT_EQUAL(ret, 0); + ret = tsk_tree_first(&t); + CU_ASSERT_EQUAL(ret, 1); + CU_ASSERT_EQUAL(tsk_tree_get_num_roots(&t), 0); + CU_ASSERT_EQUAL(t.left_root, TSK_NULL); + CU_ASSERT_EQUAL(t.right_sib[2], 3); + CU_ASSERT_EQUAL(t.right_sib[3], TSK_NULL); + + tsk_tree_free(&t); + tsk_treeseq_free(&ts); +} + +static void +test_simplest_root_mutations(void) +{ + int ret; + uint32_t j; + const char *nodes = + "1 0 0\n" + "1 0 0\n" + 
"0 1 0\n"; + const char *edges = + "0 1 2 0,1\n"; + const char *sites = + "0.1 0"; + const char *mutations = + "0 2 1"; + tsk_hapgen_t hapgen; + char *haplotype; + int flags = 0; + tsk_id_t sample_ids[] = {0, 1}; + tsk_treeseq_t ts, simplified; + + tsk_treeseq_from_text(&ts, 1, nodes, edges, NULL, sites, mutations, NULL, NULL); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_samples(&ts), 2); + CU_ASSERT_EQUAL(tsk_treeseq_get_sequence_length(&ts), 1.0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_nodes(&ts), 3); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_sites(&ts), 1); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_mutations(&ts), 1); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_trees(&ts), 1); + + ret = tsk_hapgen_alloc(&hapgen, &ts); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tsk_hapgen_print_state(&hapgen, _devnull); + for (j = 0; j < 2; j++) { + ret = tsk_hapgen_get_haplotype(&hapgen, (tsk_id_t) j, &haplotype); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_STRING_EQUAL(haplotype, "1"); + } + tsk_hapgen_free(&hapgen); + + ret = tsk_treeseq_simplify(&ts, sample_ids, 2, flags, &simplified, NULL); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_samples(&simplified), 2); + CU_ASSERT_EQUAL(tsk_treeseq_get_sequence_length(&simplified), 1.0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_nodes(&simplified), 3); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_sites(&simplified), 1); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_mutations(&simplified), 1); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_trees(&simplified), 1); + tsk_treeseq_free(&simplified); + + tsk_treeseq_free(&ts); +} + +static void +test_simplest_back_mutations(void) +{ + int ret; + uint32_t j; + const char *nodes = + "1 0 0\n" + "1 0 0\n" + "1 0 0\n" + "0 1 0\n" + "0 2 0\n"; + const char *edges = + "0 1 3 0,1\n" + "0 1 4 2,3\n"; + const char *sites = + "0.5 0"; + const char *mutations = + "0 3 1 -1\n" + "0 0 0 0"; + tsk_hapgen_t hapgen; + const char *haplotypes[] = {"0", "1", "0"}; + char *haplotype; + tsk_treeseq_t ts; + tsk_vargen_t vargen; + tsk_variant_t 
*var; + + tsk_treeseq_from_text(&ts, 1, nodes, edges, NULL, sites, mutations, NULL, NULL); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_samples(&ts), 3); + CU_ASSERT_EQUAL(tsk_treeseq_get_sequence_length(&ts), 1.0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_nodes(&ts), 5); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_sites(&ts), 1); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_mutations(&ts), 2); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_trees(&ts), 1); + + ret = tsk_hapgen_alloc(&hapgen, &ts); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tsk_hapgen_print_state(&hapgen, _devnull); + for (j = 0; j < 2; j++) { + ret = tsk_hapgen_get_haplotype(&hapgen, (tsk_id_t) j, &haplotype); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_STRING_EQUAL(haplotype, haplotypes[j]); + } + tsk_hapgen_free(&hapgen); + + ret = tsk_vargen_alloc(&vargen, &ts, NULL, 0, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret = tsk_vargen_next(&vargen, &var); + CU_ASSERT_EQUAL_FATAL(ret, 1); + CU_ASSERT_EQUAL(var->num_alleles, 2); + CU_ASSERT_NSTRING_EQUAL(var->alleles[0], "0", 1); + CU_ASSERT_NSTRING_EQUAL(var->alleles[1], "1", 1); + CU_ASSERT_EQUAL(var->genotypes.u8[0], 0); + CU_ASSERT_EQUAL(var->genotypes.u8[1], 1); + CU_ASSERT_EQUAL(var->genotypes.u8[2], 0); + CU_ASSERT_EQUAL(var->site->id, 0); + CU_ASSERT_EQUAL(var->site->mutations_length, 2); + tsk_vargen_free(&vargen); + + tsk_treeseq_free(&ts); +} + +static void +test_simplest_general_samples(void) +{ + const char *nodes = + "1 0 0\n" + "0 1 0\n" + "1 0 0"; + const char *edges = + "0 1 1 0,2\n"; + const char *sites = + "0.5 0\n" + "0.75 0\n"; + const char *mutations = + "0 2 1\n" + "1 0 1"; + const char *haplotypes[] = {"01", "10"}; + char *haplotype; + unsigned int j; + tsk_id_t samples[2] = {0, 2}; + tsk_id_t *s; + int ret; + + tsk_treeseq_t ts, simplified; + tsk_hapgen_t hapgen; + + tsk_treeseq_from_text(&ts, 1, nodes, edges, NULL, sites, mutations, NULL, NULL); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_samples(&ts), 2); + CU_ASSERT_EQUAL(tsk_treeseq_get_sequence_length(&ts), 1.0); + 
CU_ASSERT_EQUAL(tsk_treeseq_get_num_nodes(&ts), 3); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_sites(&ts), 2); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_mutations(&ts), 2); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_trees(&ts), 1); + + ret = tsk_treeseq_get_samples(&ts, &s); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_FATAL(s != NULL); + CU_ASSERT_EQUAL(s[0], 0); + CU_ASSERT_EQUAL(s[1], 2); + + ret = tsk_hapgen_alloc(&hapgen, &ts); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tsk_hapgen_print_state(&hapgen, _devnull); + for (j = 0; j < 2; j++) { + ret = tsk_hapgen_get_haplotype(&hapgen, (tsk_id_t) j, &haplotype); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_STRING_EQUAL(haplotype, haplotypes[j]); + } + tsk_hapgen_free(&hapgen); + + ret = tsk_treeseq_simplify(&ts, samples, 2, 0, &simplified, NULL); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret = tsk_treeseq_get_samples(&simplified, &s); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_FATAL(s != NULL); + CU_ASSERT_EQUAL(s[0], 0); + CU_ASSERT_EQUAL(s[1], 1); + + ret = tsk_hapgen_alloc(&hapgen, &simplified); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tsk_hapgen_print_state(&hapgen, _devnull); + for (j = 0; j < 2; j++) { + ret = tsk_hapgen_get_haplotype(&hapgen, (tsk_id_t) j, &haplotype); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_STRING_EQUAL(haplotype, haplotypes[j]); + } + tsk_hapgen_free(&hapgen); + + tsk_treeseq_free(&simplified); + tsk_treeseq_free(&ts); +} + +static void +test_simplest_holey_tree_sequence(void) +{ + const char *nodes_txt = + "1 0 0\n" + "1 0 0\n" + "0 1 0"; + const char *edges_txt = + "0 1 2 0\n" + "2 3 2 0\n" + "0 1 2 1\n" + "2 3 2 1\n"; + const char *sites_txt = + "0.5 0\n" + "1.5 0\n" + "2.5 0\n"; + const char *mutations_txt = + "0 0 1\n" + "1 1 1\n" + "2 2 1\n"; + const char *haplotypes[] = {"101", "011"}; + char *haplotype; + unsigned int j; + int ret; + tsk_treeseq_t ts; + tsk_hapgen_t hapgen; + + tsk_treeseq_from_text(&ts, 3, nodes_txt, edges_txt, NULL, sites_txt, + mutations_txt, NULL, NULL); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_samples(&ts), 2); 
+ CU_ASSERT_EQUAL(tsk_treeseq_get_sequence_length(&ts), 3.0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_nodes(&ts), 3); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_sites(&ts), 3); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_mutations(&ts), 3); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_trees(&ts), 3); + + ret = tsk_hapgen_alloc(&hapgen, &ts); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tsk_hapgen_print_state(&hapgen, _devnull); + for (j = 0; j < 2; j++) { + ret = tsk_hapgen_get_haplotype(&hapgen, (tsk_id_t) j, &haplotype); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_STRING_EQUAL(haplotype, haplotypes[j]); + } + + tsk_hapgen_free(&hapgen); + tsk_treeseq_free(&ts); +} + +static void +test_simplest_holey_tsk_treeseq_mutation_parents(void) +{ + const char *nodes_txt = + "1 0 0\n" + "1 0 0\n" + "0 1 0"; + const char *edges_txt = + "0 1 2 0\n" + "2 3 2 0\n" + "0 1 2 1\n" + "2 3 2 1\n"; + const char *sites_txt = + "0.5 0\n" + "1.5 0\n" + "2.5 0\n"; + const char *mutations_txt = + "0 0 1\n" + "0 0 1\n" + "1 1 1\n" + "1 1 1\n" + "2 2 1\n" + "2 2 1\n"; + tsk_treeseq_t ts; + tsk_tbl_collection_t tables; + int ret; + + tsk_treeseq_from_text(&ts, 3, nodes_txt, edges_txt, NULL, sites_txt, + mutations_txt, NULL, NULL); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_sites(&ts), 3); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_mutations(&ts), 6); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_trees(&ts), 3); + ret = tsk_tbl_collection_alloc(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_treeseq_copy_tables(&ts, &tables); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_tbl_collection_compute_mutation_parents(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL(tables.mutations->parent[0], -1); + CU_ASSERT_EQUAL(tables.mutations->parent[1], 0); + CU_ASSERT_EQUAL(tables.mutations->parent[2], -1); + CU_ASSERT_EQUAL(tables.mutations->parent[3], 2); + CU_ASSERT_EQUAL(tables.mutations->parent[4], -1); + CU_ASSERT_EQUAL(tables.mutations->parent[5], 4); + tsk_tbl_collection_free(&tables); + tsk_treeseq_free(&ts); +} + +static void 
+test_simplest_initial_gap_tree_sequence(void) +{ + const char *nodes = + "1 0 0\n" + "1 0 0\n" + "0 1 0"; + const char *edges = + "2 3 2 0,1\n"; + const char *sites = + "0.5 0\n" + "1.5 0\n" + "2.5 0\n"; + const char *mutations = + "0 0 1\n" + "1 1 1\n" + "2 2 1"; + const char *haplotypes[] = {"101", "011"}; + char *haplotype; + unsigned int j; + int ret; + tsk_treeseq_t ts; + tsk_hapgen_t hapgen; + const tsk_id_t z = TSK_NULL; + tsk_id_t parents[] = { + z, z, z, + 2, 2, z, + }; + uint32_t num_trees = 2; + + tsk_treeseq_from_text(&ts, 3, nodes, edges, NULL, sites, mutations, NULL, NULL); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_samples(&ts), 2); + CU_ASSERT_EQUAL(tsk_treeseq_get_sequence_length(&ts), 3.0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_nodes(&ts), 3); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_sites(&ts), 3); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_mutations(&ts), 3); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_trees(&ts), 2); + + verify_trees(&ts, num_trees, parents); + + ret = tsk_hapgen_alloc(&hapgen, &ts); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tsk_hapgen_print_state(&hapgen, _devnull); + for (j = 0; j < 2; j++) { + ret = tsk_hapgen_get_haplotype(&hapgen, (tsk_id_t) j, &haplotype); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_STRING_EQUAL(haplotype, haplotypes[j]); + } + tsk_hapgen_free(&hapgen); + tsk_treeseq_free(&ts); +} + +static void +test_simplest_initial_gap_zero_roots(void) +{ + const char *nodes = + "0 0 0\n" + "0 0 0\n" + "0 1 0"; + const char *edges = + "2 3 2 0,1\n"; + int ret; + tsk_treeseq_t ts; + const tsk_id_t z = TSK_NULL; + tsk_id_t parents[] = { + z, z, z, + 2, 2, z, + }; + uint32_t num_trees = 2; + tsk_tree_t tree; + + tsk_treeseq_from_text(&ts, 3, nodes, edges, NULL, NULL, NULL, NULL, NULL); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_samples(&ts), 0); + CU_ASSERT_EQUAL(tsk_treeseq_get_sequence_length(&ts), 3.0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_nodes(&ts), 3); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_trees(&ts), 2); + + verify_trees(&ts, num_trees, parents); + 
+ ret = tsk_tree_alloc(&tree, &ts, 0); + CU_ASSERT_EQUAL(ret, 0); + ret = tsk_tree_first(&tree); + CU_ASSERT_EQUAL(ret, 1); + CU_ASSERT_EQUAL(tree.left_root, TSK_NULL); + CU_ASSERT_EQUAL(tsk_tree_get_num_roots(&tree), 0); + ret = tsk_tree_next(&tree); + CU_ASSERT_EQUAL(ret, 1); + CU_ASSERT_EQUAL(tree.left_root, TSK_NULL); + CU_ASSERT_EQUAL(tsk_tree_get_num_roots(&tree), 0); + CU_ASSERT_EQUAL(tree.parent[0], 2); + CU_ASSERT_EQUAL(tree.parent[1], 2); + + tsk_tree_free(&tree); + tsk_treeseq_free(&ts); +} + +static void +test_simplest_holey_tsk_treeseq_zero_roots(void) +{ + const char *nodes_txt = + "0 0 0\n" + "0 0 0\n" + "0 1 0"; + const char *edges_txt = + "0 1 2 0\n" + "2 3 2 0\n" + "0 1 2 1\n" + "2 3 2 1\n"; + int ret; + tsk_treeseq_t ts; + const tsk_id_t z = TSK_NULL; + tsk_id_t parents[] = { + 2, 2, z, + z, z, z, + 2, 2, z, + }; + uint32_t num_trees = 3; + tsk_tree_t tree; + + tsk_treeseq_from_text(&ts, 3, nodes_txt, edges_txt, NULL, NULL, NULL, NULL, NULL); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_samples(&ts), 0); + CU_ASSERT_EQUAL(tsk_treeseq_get_sequence_length(&ts), 3.0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_samples(&ts), 0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_nodes(&ts), 3); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_trees(&ts), 3); + + verify_trees(&ts, num_trees, parents); + + ret = tsk_tree_alloc(&tree, &ts, 0); + CU_ASSERT_EQUAL(ret, 0); + ret = tsk_tree_first(&tree); + CU_ASSERT_EQUAL(ret, 1); + CU_ASSERT_EQUAL(tree.left_root, TSK_NULL); + CU_ASSERT_EQUAL(tree.parent[0], 2); + CU_ASSERT_EQUAL(tree.parent[1], 2); + CU_ASSERT_EQUAL(tsk_tree_get_num_roots(&tree), 0); + + ret = tsk_tree_next(&tree); + CU_ASSERT_EQUAL(ret, 1); + CU_ASSERT_EQUAL(tree.left_root, TSK_NULL); + CU_ASSERT_EQUAL(tsk_tree_get_num_roots(&tree), 0); + + ret = tsk_tree_next(&tree); + CU_ASSERT_EQUAL(ret, 1); + CU_ASSERT_EQUAL(tree.left_root, TSK_NULL); + CU_ASSERT_EQUAL(tsk_tree_get_num_roots(&tree), 0); + CU_ASSERT_EQUAL(tree.parent[0], 2); + CU_ASSERT_EQUAL(tree.parent[1], 2); + + 
tsk_tree_free(&tree); + tsk_treeseq_free(&ts); +} + +static void +test_simplest_initial_gap_tsk_treeseq_mutation_parents(void) +{ + const char *nodes_txt = + "1 0 0\n" + "1 0 0\n" + "0 1 0"; + const char *edges_txt = + "2 3 2 0,1\n"; + const char *sites_txt = + "0.5 0\n" + "1.5 0\n" + "2.5 0\n"; + const char *mutations_txt = + "0 0 1\n" + "0 0 1\n" + "1 1 1\n" + "1 1 1\n" + "2 2 1\n" + "2 2 1\n"; + tsk_treeseq_t ts; + tsk_tbl_collection_t tables; + int ret; + + tsk_treeseq_from_text(&ts, 3, nodes_txt, edges_txt, NULL, sites_txt, + mutations_txt, NULL, NULL); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_sites(&ts), 3); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_mutations(&ts), 6); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_trees(&ts), 2); + ret = tsk_tbl_collection_alloc(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_treeseq_copy_tables(&ts, &tables); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_tbl_collection_compute_mutation_parents(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL(tables.mutations->parent[0], -1); + CU_ASSERT_EQUAL(tables.mutations->parent[1], 0); + CU_ASSERT_EQUAL(tables.mutations->parent[2], -1); + CU_ASSERT_EQUAL(tables.mutations->parent[3], 2); + CU_ASSERT_EQUAL(tables.mutations->parent[4], -1); + CU_ASSERT_EQUAL(tables.mutations->parent[5], 4); + tsk_tbl_collection_free(&tables); + tsk_treeseq_free(&ts); +} + +static void +test_simplest_final_gap_tree_sequence(void) +{ + const char *nodes = + "1 0 0\n" + "1 0 0\n" + "0 1 0"; + const char *edges = + "0 2 2 0,1\n"; + const char *sites = + "0.5 0\n" + "1.5 0\n" + "2.5 0\n"; + const char *mutations = + "0 0 1\n" + "1 1 1\n" + "2 0 1"; + const char *haplotypes[] = {"101", "010"}; + char *haplotype; + unsigned int j; + int ret; + tsk_treeseq_t ts; + tsk_hapgen_t hapgen; + const tsk_id_t z = TSK_NULL; + tsk_id_t parents[] = { + 2, 2, z, + z, z, z, + }; + uint32_t num_trees = 2; + + tsk_treeseq_from_text(&ts, 3, nodes, edges, NULL, sites, mutations, NULL, NULL); + 
CU_ASSERT_EQUAL(tsk_treeseq_get_num_samples(&ts), 2); + CU_ASSERT_EQUAL(tsk_treeseq_get_sequence_length(&ts), 3.0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_nodes(&ts), 3); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_sites(&ts), 3); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_mutations(&ts), 3); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_trees(&ts), 2); + + verify_trees(&ts, num_trees, parents); + + ret = tsk_hapgen_alloc(&hapgen, &ts); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tsk_hapgen_print_state(&hapgen, _devnull); + for (j = 0; j < 2; j++) { + ret = tsk_hapgen_get_haplotype(&hapgen, (tsk_id_t) j, &haplotype); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_STRING_EQUAL(haplotype, haplotypes[j]); + } + tsk_hapgen_free(&hapgen); + tsk_treeseq_free(&ts); +} + +static void +test_simplest_final_gap_tsk_treeseq_mutation_parents(void) +{ + const char *nodes_txt = + "1 0 0\n" + "1 0 0\n" + "0 1 0"; + const char *edges_txt = + "0 2 2 0,1\n"; + const char *sites_txt = + "0.5 0\n" + "1.5 0\n" + "2.5 0\n"; + const char *mutations_txt = + "0 0 1\n" + "0 0 1\n" + "1 1 1\n" + "1 1 1\n" + "2 0 1\n" + "2 0 1\n"; + tsk_treeseq_t ts; + tsk_tbl_collection_t tables; + int ret; + + tsk_treeseq_from_text(&ts, 3, nodes_txt, edges_txt, NULL, sites_txt, + mutations_txt, NULL, NULL); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_sites(&ts), 3); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_mutations(&ts), 6); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_trees(&ts), 2); + ret = tsk_tbl_collection_alloc(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_treeseq_copy_tables(&ts, &tables); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_tbl_collection_compute_mutation_parents(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL(tables.mutations->parent[0], -1); + CU_ASSERT_EQUAL(tables.mutations->parent[1], 0); + CU_ASSERT_EQUAL(tables.mutations->parent[2], -1); + CU_ASSERT_EQUAL(tables.mutations->parent[3], 2); + CU_ASSERT_EQUAL(tables.mutations->parent[4], -1); + CU_ASSERT_EQUAL(tables.mutations->parent[5], 4); + 
tsk_tbl_collection_free(&tables); + tsk_treeseq_free(&ts); +} + +static void +test_simplest_individuals(void) +{ + const char *individuals = + "1 0.2\n" + "2 0.5,0.6\n"; + const char *nodes = + "1 0 -1 -1\n" + "1 0 -1 1\n" + "0 0 -1 -1\n" + "1 0 -1 0\n" + "0 0 -1 1\n"; + tsk_tbl_collection_t tables; + tsk_treeseq_t ts; + tsk_node_t node; + tsk_individual_t individual; + int ret; + + ret = tsk_tbl_collection_alloc(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + tables.sequence_length = 1.0; + parse_individuals(individuals, tables.individuals); + CU_ASSERT_EQUAL_FATAL(tables.individuals->num_rows, 2); + parse_nodes(nodes, tables.nodes); + CU_ASSERT_EQUAL_FATAL(tables.nodes->num_rows, 5); + + ret = tsk_treeseq_alloc(&ts, &tables, TSK_BUILD_INDEXES); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret = tsk_treeseq_get_node(&ts, 0, &node); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(node.individual, TSK_NULL); + + ret = tsk_treeseq_get_node(&ts, 1, &node); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(node.individual, 1); + + ret = tsk_treeseq_get_individual(&ts, 0, &individual); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(individual.id, 0); + CU_ASSERT_EQUAL_FATAL(individual.flags, 1); + CU_ASSERT_EQUAL_FATAL(individual.location_length, 1); + CU_ASSERT_EQUAL_FATAL(individual.location[0], 0.2); + CU_ASSERT_EQUAL_FATAL(individual.nodes_length, 1); + CU_ASSERT_EQUAL_FATAL(individual.nodes[0], 3); + + ret = tsk_treeseq_get_individual(&ts, 1, &individual); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(individual.id, 1); + CU_ASSERT_EQUAL_FATAL(individual.flags, 2); + CU_ASSERT_EQUAL_FATAL(individual.location_length, 2); + CU_ASSERT_EQUAL_FATAL(individual.location[0], 0.5); + CU_ASSERT_EQUAL_FATAL(individual.location[1], 0.6); + CU_ASSERT_EQUAL_FATAL(individual.nodes_length, 2); + CU_ASSERT_EQUAL_FATAL(individual.nodes[0], 1); + CU_ASSERT_EQUAL_FATAL(individual.nodes[1], 4); + + ret = tsk_treeseq_get_individual(&ts, 3, &individual); + 
CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_INDIVIDUAL_OUT_OF_BOUNDS); + + tsk_tbl_collection_free(&tables); + tsk_treeseq_free(&ts); +} + +static void +test_simplest_bad_individuals(void) +{ + const char *nodes = + "1 0 0\n" + "1 0 0\n" + "0 1 0\n" + "1 0 0\n" + "0 1 0\n"; + const char *edges = + "0 1 2 0\n" + "0 1 2 1\n" + "0 1 4 3\n"; + tsk_treeseq_t ts; + tsk_tbl_collection_t tables; + int load_flags = TSK_BUILD_INDEXES; + int ret; + + ret = tsk_tbl_collection_alloc(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + tables.sequence_length = 1.0; + parse_nodes(nodes, tables.nodes); + CU_ASSERT_EQUAL_FATAL(tables.nodes->num_rows, 5); + parse_edges(edges, tables.edges); + CU_ASSERT_EQUAL_FATAL(tables.edges->num_rows, 3); + ret = tsk_population_tbl_add_row(tables.populations, NULL, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + /* Make sure we have a good set of records */ + ret = tsk_treeseq_alloc(&ts, &tables, load_flags); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tsk_treeseq_free(&ts); + + /* Bad individual ID */ + tables.nodes->individual[0] = -2; + ret = tsk_treeseq_alloc(&ts, &tables, load_flags); + CU_ASSERT_EQUAL(ret, TSK_ERR_INDIVIDUAL_OUT_OF_BOUNDS); + tsk_treeseq_free(&ts); + tables.nodes->individual[0] = TSK_NULL; + + /* Bad individual ID */ + tables.nodes->individual[0] = 0; + ret = tsk_treeseq_alloc(&ts, &tables, load_flags); + CU_ASSERT_EQUAL(ret, TSK_ERR_INDIVIDUAL_OUT_OF_BOUNDS); + tsk_treeseq_free(&ts); + tables.nodes->individual[0] = TSK_NULL; + + /* add two individuals */ + ret = tsk_individual_tbl_add_row(tables.individuals, 0, NULL, 0, NULL, 0); + CU_ASSERT_EQUAL(ret, 0); + ret = tsk_individual_tbl_add_row(tables.individuals, 0, NULL, 0, NULL, 0); + CU_ASSERT_EQUAL(ret, 1); + + /* Bad individual ID */ + tables.nodes->individual[0] = 2; + ret = tsk_treeseq_alloc(&ts, &tables, load_flags); + CU_ASSERT_EQUAL(ret, TSK_ERR_INDIVIDUAL_OUT_OF_BOUNDS); + tsk_treeseq_free(&ts); + tables.nodes->individual[0] = TSK_NULL; + + tsk_treeseq_free(&ts); + 
tsk_tbl_collection_free(&tables); +} + +static void +test_simplest_bad_edges(void) +{ + const char *nodes = + "1 0 0\n" + "1 0 0\n" + "0 1 0\n" + "1 0 0\n" + "0 1 0\n"; + const char *edges = + "0 1 2 0\n" + "0 1 2 1\n" + "0 1 4 3\n"; + tsk_treeseq_t ts; + tsk_tbl_collection_t tables; + int ret; + int load_flags = TSK_BUILD_INDEXES; + + ret = tsk_tbl_collection_alloc(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + tables.sequence_length = 1.0; + parse_nodes(nodes, tables.nodes); + CU_ASSERT_EQUAL_FATAL(tables.nodes->num_rows, 5); + parse_edges(edges, tables.edges); + CU_ASSERT_EQUAL_FATAL(tables.edges->num_rows, 3); + ret = tsk_population_tbl_add_row(tables.populations, NULL, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + /* Make sure we have a good set of records */ + ret = tsk_treeseq_alloc(&ts, &tables, load_flags); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tsk_treeseq_free(&ts); + + /* NULL for tables should be an error */ + ret = tsk_treeseq_alloc(&ts, NULL, load_flags); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_PARAM_VALUE); + tsk_treeseq_free(&ts); + + /* Bad population ID */ + tables.nodes->population[0] = -2; + ret = tsk_treeseq_alloc(&ts, &tables, load_flags); + CU_ASSERT_EQUAL(ret, TSK_ERR_POPULATION_OUT_OF_BOUNDS); + tsk_treeseq_free(&ts); + tables.nodes->population[0] = 0; + + /* Bad population ID */ + tables.nodes->population[0] = 1; + ret = tsk_treeseq_alloc(&ts, &tables, load_flags); + CU_ASSERT_EQUAL(ret, TSK_ERR_POPULATION_OUT_OF_BOUNDS); + tsk_treeseq_free(&ts); + tables.nodes->population[0] = 0; + + /* Bad interval */ + tables.edges->right[0] = 0.0; + ret = tsk_treeseq_alloc(&ts, &tables, load_flags); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_EDGE_INTERVAL); + tsk_treeseq_free(&ts); + tables.edges->right[0]= 1.0; + + /* Left coordinate < 0. */ + tables.edges->left[0] = -1; + ret = tsk_treeseq_alloc(&ts, &tables, load_flags); + CU_ASSERT_EQUAL(ret, TSK_ERR_LEFT_LESS_ZERO); + tsk_treeseq_free(&ts); + tables.edges->left[0]= 0.0; + + /* Right coordinate > sequence length. 
*/ + tables.edges->right[0] = 2.0; + ret = tsk_treeseq_alloc(&ts, &tables, load_flags); + CU_ASSERT_EQUAL(ret, TSK_ERR_RIGHT_GREATER_SEQ_LENGTH); + tsk_treeseq_free(&ts); + tables.edges->right[0]= 1.0; + + /* Duplicate records */ + tables.edges->child[0] = 1; + ret = tsk_treeseq_alloc(&ts, &tables, load_flags); + CU_ASSERT_EQUAL(ret, TSK_ERR_DUPLICATE_EDGES); + tsk_treeseq_free(&ts); + tables.edges->child[0] = 0; + + /* Duplicate records */ + tables.edges->child[0] = 1; + tables.edges->left[0] = 0.5; + ret = tsk_treeseq_alloc(&ts, &tables, load_flags); + CU_ASSERT_EQUAL(ret, TSK_ERR_EDGES_NOT_SORTED_LEFT); + tsk_treeseq_free(&ts); + tables.edges->child[0] = 0; + tables.edges->left[0] = 0.0; + + /* child node == parent */ + tables.edges->child[1] = 2; + ret = tsk_treeseq_alloc(&ts, &tables, load_flags); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_NODE_TIME_ORDERING); + tsk_treeseq_free(&ts); + tables.edges->child[1] = 1; + + /* Unsorted child nodes */ + tables.edges->child[0] = 1; + tables.edges->child[1] = 0; + ret = tsk_treeseq_alloc(&ts, &tables, load_flags); + CU_ASSERT_EQUAL(ret, TSK_ERR_EDGES_NOT_SORTED_CHILD); + tsk_treeseq_free(&ts); + tables.edges->child[0] = 0; + tables.edges->child[1] = 1; + + /* discontinuous parent nodes */ + /* Swap rows 1 and 2 */ + tables.edges->parent[1] = 4; + tables.edges->child[1] = 3; + tables.edges->parent[2] = 2; + tables.edges->child[2] = 1; + ret = tsk_treeseq_alloc(&ts, &tables, load_flags); + CU_ASSERT_EQUAL(ret, TSK_ERR_EDGES_NONCONTIGUOUS_PARENTS); + tsk_treeseq_free(&ts); + tables.edges->parent[2] = 4; + tables.edges->child[2] = 3; + tables.edges->parent[1] = 2; + tables.edges->child[1] = 1; + + /* Null parent */ + tables.edges->parent[0] = TSK_NULL; + ret = tsk_treeseq_alloc(&ts, &tables, load_flags); + CU_ASSERT_EQUAL(ret, TSK_ERR_NULL_PARENT); + tsk_treeseq_free(&ts); + tables.edges->parent[0] = 2; + + /* parent not in nodes list */ + tables.nodes->num_rows = 2; + ret = tsk_treeseq_alloc(&ts, &tables, load_flags); + 
CU_ASSERT_EQUAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS); + tsk_treeseq_free(&ts); + tables.nodes->num_rows = 5; + + /* parent negative */ + tables.edges->parent[0] = -2; + ret = tsk_treeseq_alloc(&ts, &tables, load_flags); + CU_ASSERT_EQUAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS); + tsk_treeseq_free(&ts); + tables.edges->parent[0] = 2; + + /* Null child */ + tables.edges->child[0] = TSK_NULL; + ret = tsk_treeseq_alloc(&ts, &tables, load_flags); + CU_ASSERT_EQUAL(ret, TSK_ERR_NULL_CHILD); + tsk_treeseq_free(&ts); + tables.edges->child[0] = 0; + + /* child node reference out of bounds */ + tables.edges->child[0] = 100; + ret = tsk_treeseq_alloc(&ts, &tables, load_flags); + CU_ASSERT_EQUAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS); + tsk_treeseq_free(&ts); + tables.edges->child[0] = 0; + + /* child node reference negative */ + tables.edges->child[0] = -2; + ret = tsk_treeseq_alloc(&ts, &tables, load_flags); + CU_ASSERT_EQUAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS); + tsk_treeseq_free(&ts); + tables.edges->child[0] = 0; + + /* Make sure we've preserved a good tree sequence */ + ret = tsk_treeseq_alloc(&ts, &tables, load_flags); + CU_ASSERT_EQUAL(ret, 0); + tsk_treeseq_free(&ts); + + tsk_tbl_collection_free(&tables); +} + +static void +test_simplest_bad_indexes(void) +{ + const char *nodes = + "1 0 0\n" + "1 0 0\n" + "0 1 0\n" + "1 0 0\n" + "0 1 0\n"; + const char *edges = + "0 1 2 0\n" + "0 1 2 1\n" + "0 1 4 3\n"; + tsk_tbl_collection_t tables; + tsk_id_t bad_indexes[] = {-1, 3, 4, 1000}; + size_t j; + int ret; + + ret = tsk_tbl_collection_alloc(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + tables.sequence_length = 1.0; + parse_nodes(nodes, tables.nodes); + CU_ASSERT_EQUAL_FATAL(tables.nodes->num_rows, 5); + parse_edges(edges, tables.edges); + CU_ASSERT_EQUAL_FATAL(tables.edges->num_rows, 3); + ret = tsk_population_tbl_add_row(tables.populations, NULL, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + /* Make sure we have a good set of records */ + ret = tsk_tbl_collection_check_integrity(&tables, 0); + 
CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret = tsk_tbl_collection_check_integrity(&tables, TSK_CHECK_INDEXES); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_TABLES_NOT_INDEXED); + ret = tsk_tbl_collection_build_indexes(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_tbl_collection_check_integrity(&tables, TSK_CHECK_ALL); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + for (j = 0; j < sizeof(bad_indexes) / sizeof(*bad_indexes); j++) { + tables.indexes.edge_insertion_order[0] = bad_indexes[j]; + ret = tsk_tbl_collection_check_integrity(&tables, TSK_CHECK_ALL); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_EDGE_OUT_OF_BOUNDS); + tables.indexes.edge_insertion_order[0] = 0; + + tables.indexes.edge_removal_order[0] = bad_indexes[j]; + ret = tsk_tbl_collection_check_integrity(&tables, TSK_CHECK_ALL); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_EDGE_OUT_OF_BOUNDS); + tables.indexes.edge_removal_order[0] = 0; + } + + ret = tsk_tbl_collection_drop_indexes(&tables); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_tbl_collection_check_integrity(&tables, TSK_CHECK_INDEXES); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_TABLES_NOT_INDEXED); + + tsk_tbl_collection_free(&tables); +} + +static void +test_simplest_bad_migrations(void) +{ + tsk_tbl_collection_t tables; + int ret; + + ret = tsk_tbl_collection_alloc(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tables.sequence_length = 1; + + /* insert two populations and one node to refer to. */ + ret = tsk_node_tbl_add_row(tables.nodes, 0, 0.0, TSK_NULL, + TSK_NULL, NULL, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_population_tbl_add_row(tables.populations, NULL, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_population_tbl_add_row(tables.populations, NULL, 0); + CU_ASSERT_EQUAL_FATAL(ret, 1); + /* One migration, node 0 goes from population 0 to 1. 
*/ + ret = tsk_migration_tbl_add_row(tables.migrations, 0, 1, 0, 0, 1, 1.0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + /* We only need basic intregity checks for migrations */ + ret = tsk_tbl_collection_check_integrity(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + /* Bad node reference */ + tables.migrations->node[0] = -1; + ret = tsk_tbl_collection_check_integrity(&tables, 0); + CU_ASSERT_EQUAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS); + tables.migrations->node[0] = 0; + + /* Bad node reference */ + tables.migrations->node[0] = 1; + ret = tsk_tbl_collection_check_integrity(&tables, 0); + CU_ASSERT_EQUAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS); + tables.migrations->node[0] = 0; + + /* Bad population reference */ + tables.migrations->source[0] = -1; + ret = tsk_tbl_collection_check_integrity(&tables, 0); + CU_ASSERT_EQUAL(ret, TSK_ERR_POPULATION_OUT_OF_BOUNDS); + tables.migrations->source[0] = 0; + + /* Bad population reference */ + tables.migrations->source[0] = 2; + ret = tsk_tbl_collection_check_integrity(&tables, 0); + CU_ASSERT_EQUAL(ret, TSK_ERR_POPULATION_OUT_OF_BOUNDS); + tables.migrations->source[0] = 0; + + /* Bad population reference */ + tables.migrations->dest[0] = -1; + ret = tsk_tbl_collection_check_integrity(&tables, 0); + CU_ASSERT_EQUAL(ret, TSK_ERR_POPULATION_OUT_OF_BOUNDS); + tables.migrations->dest[0] = 1; + + /* Bad population reference */ + tables.migrations->dest[0] = 2; + ret = tsk_tbl_collection_check_integrity(&tables, 0); + CU_ASSERT_EQUAL(ret, TSK_ERR_POPULATION_OUT_OF_BOUNDS); + tables.migrations->dest[0] = 1; + + /* Bad left coordinate */ + tables.migrations->left[0] = -1; + ret = tsk_tbl_collection_check_integrity(&tables, 0); + CU_ASSERT_EQUAL(ret, TSK_ERR_LEFT_LESS_ZERO); + tables.migrations->left[0] = 0; + + /* Bad right coordinate */ + tables.migrations->right[0] = 2; + ret = tsk_tbl_collection_check_integrity(&tables, 0); + CU_ASSERT_EQUAL(ret, TSK_ERR_RIGHT_GREATER_SEQ_LENGTH); + tables.migrations->right[0] = 1; + + /* Bad interval coordinate 
*/ + tables.migrations->right[0] = 0; + ret = tsk_tbl_collection_check_integrity(&tables, 0); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_EDGE_INTERVAL); + tables.migrations->right[0] = 1; + + tsk_tbl_collection_free(&tables); +} + +static void +test_simplest_migration_simplify(void) +{ + tsk_tbl_collection_t tables; + int ret; + tsk_id_t samples[] = {0, 1}; + + ret = tsk_tbl_collection_alloc(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tables.sequence_length = 1; + + /* insert two populations and one node to refer to. */ + ret = tsk_node_tbl_add_row(tables.nodes, TSK_NODE_IS_SAMPLE, 0.0, + TSK_NULL, TSK_NULL, NULL, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_node_tbl_add_row(tables.nodes, TSK_NODE_IS_SAMPLE, 0.0, + TSK_NULL, TSK_NULL, NULL, 0); + CU_ASSERT_EQUAL_FATAL(ret, 1); + ret = tsk_population_tbl_add_row(tables.populations, NULL, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_population_tbl_add_row(tables.populations, NULL, 0); + CU_ASSERT_EQUAL_FATAL(ret, 1); + /* One migration, node 0 goes from population 0 to 1. 
*/ + ret = tsk_migration_tbl_add_row(tables.migrations, 0, 1, 0, 0, 1, 1.0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret = tsk_tbl_collection_simplify(&tables, samples, 2, 0, NULL); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_SIMPLIFY_MIGRATIONS_NOT_SUPPORTED); + + tsk_tbl_collection_free(&tables); +} + +static void +test_simplest_overlapping_parents(void) +{ + const char *nodes = + "1 0 -1\n" + "1 0 -1\n" + "0 1 -1\n"; + const char *edges = + "0 1 2 0\n" + "0 1 2 1\n"; + tsk_treeseq_t ts; + tsk_tbl_collection_t tables; + tsk_tree_t tree; + int ret; + int load_flags = TSK_BUILD_INDEXES; + + ret = tsk_tbl_collection_alloc(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + tables.sequence_length = 1; + parse_nodes(nodes, tables.nodes); + CU_ASSERT_EQUAL_FATAL(tables.nodes->num_rows, 3); + parse_edges(edges, tables.edges); + CU_ASSERT_EQUAL_FATAL(tables.edges->num_rows, 2); + + tables.edges->left[0] = 0; + tables.edges->parent[0] = 2; + ret = tsk_treeseq_alloc(&ts, &tables, load_flags); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_tree_alloc(&tree, &ts, 0); + CU_ASSERT_EQUAL(ret, 0); + ret = tsk_tree_first(&tree); + CU_ASSERT_EQUAL(ret, 1); + CU_ASSERT_EQUAL(tree.parent[0], 2); + CU_ASSERT_EQUAL(tree.parent[1], 2); + CU_ASSERT_EQUAL(tree.left_sib[2], TSK_NULL); + CU_ASSERT_EQUAL(tree.right_sib[2], TSK_NULL); + CU_ASSERT_EQUAL(tree.left_child[2], 0); + CU_ASSERT_EQUAL(tree.right_child[2], 1); + CU_ASSERT_EQUAL(tree.left_sib[0], TSK_NULL); + CU_ASSERT_EQUAL(tree.right_sib[0], 1); + CU_ASSERT_EQUAL(tree.left_sib[1], 0); + CU_ASSERT_EQUAL(tree.right_sib[1], TSK_NULL); + + tsk_tree_free(&tree); + tsk_treeseq_free(&ts); + tsk_tbl_collection_free(&tables); +} + +static void +test_simplest_contradictory_children(void) +{ + const char *nodes = + "1 0 -1\n" + "1 1 -1\n" + "0 1 -1\n"; + const char *edges = + "0 1 1 0\n" + "0 1 2 0\n"; + tsk_treeseq_t ts; + tsk_tbl_collection_t tables; + tsk_tree_t tree; + int ret; + int load_flags = TSK_BUILD_INDEXES; + + ret = 
tsk_tbl_collection_alloc(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + parse_nodes(nodes, tables.nodes); + CU_ASSERT_EQUAL_FATAL(tables.nodes->num_rows, 3); + parse_edges(edges, tables.edges); + CU_ASSERT_EQUAL_FATAL(tables.edges->num_rows, 2); + tables.sequence_length = 1.0; + + ret = tsk_treeseq_alloc(&ts, &tables, load_flags); + CU_ASSERT_EQUAL(ret, 0); + ret = tsk_tree_alloc(&tree, &ts, 0); + CU_ASSERT_EQUAL(ret, 0); + ret = tsk_tree_first(&tree); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_EDGES_CONTRADICTORY_CHILDREN); + + tsk_tree_free(&tree); + tsk_treeseq_free(&ts); + tsk_tbl_collection_free(&tables); +} + +static void +test_simplest_overlapping_edges_simplify(void) +{ + const char *nodes = + "1 0 -1\n" + "1 0 -1\n" + "1 0 -1\n" + "0 1 -1"; + const char *edges = + "0 2 3 0\n" + "1 3 3 1\n" + "0 3 3 2\n"; + tsk_id_t samples[] = {0, 1, 2}; + tsk_tbl_collection_t tables; + int ret; + + ret = tsk_tbl_collection_alloc(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + tables.sequence_length = 3; + parse_nodes(nodes, tables.nodes); + CU_ASSERT_EQUAL_FATAL(tables.nodes->num_rows, 4); + parse_edges(edges, tables.edges); + CU_ASSERT_EQUAL_FATAL(tables.edges->num_rows, 3); + + ret = tsk_tbl_collection_simplify(&tables, samples, 3, 0, NULL); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + CU_ASSERT_EQUAL(tables.nodes->num_rows, 4); + CU_ASSERT_EQUAL(tables.edges->num_rows, 3); + + /* Identical to the input. 
+ 0 2 3 0 + 1 3 3 1 + 0 3 3 2 + */ + CU_ASSERT_EQUAL(tables.edges->left[0], 0); + CU_ASSERT_EQUAL(tables.edges->left[1], 1); + CU_ASSERT_EQUAL(tables.edges->left[2], 0); + CU_ASSERT_EQUAL(tables.edges->right[0], 2); + CU_ASSERT_EQUAL(tables.edges->right[1], 3); + CU_ASSERT_EQUAL(tables.edges->right[2], 3); + CU_ASSERT_EQUAL(tables.edges->parent[0], 3); + CU_ASSERT_EQUAL(tables.edges->parent[1], 3); + CU_ASSERT_EQUAL(tables.edges->parent[2], 3); + CU_ASSERT_EQUAL(tables.edges->child[0], 0); + CU_ASSERT_EQUAL(tables.edges->child[1], 1); + CU_ASSERT_EQUAL(tables.edges->child[2], 2); + + tsk_tbl_collection_free(&tables); +} + +static void +test_simplest_overlapping_unary_edges_simplify(void) +{ + const char *nodes = + "1 0 -1\n" + "1 0 -1\n" + "0 1 -1"; + const char *edges = + "0 2 2 0\n" + "1 3 2 1\n"; + tsk_id_t samples[] = {0, 1}; + tsk_tbl_collection_t tables; + int ret; + + ret = tsk_tbl_collection_alloc(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + tables.sequence_length = 3; + parse_nodes(nodes, tables.nodes); + CU_ASSERT_EQUAL_FATAL(tables.nodes->num_rows, 3); + parse_edges(edges, tables.edges); + CU_ASSERT_EQUAL_FATAL(tables.edges->num_rows, 2); + + ret = tsk_tbl_collection_simplify(&tables, samples, 2, 0, NULL); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + CU_ASSERT_EQUAL(tables.nodes->num_rows, 3); + CU_ASSERT_EQUAL(tables.edges->num_rows, 2); + + /* Because we only sample 0 and 1, the flanking unary edges are removed + 1 2 2 0 + 1 2 2 1 + */ + CU_ASSERT_EQUAL(tables.edges->left[0], 1); + CU_ASSERT_EQUAL(tables.edges->right[0], 2); + CU_ASSERT_EQUAL(tables.edges->parent[0], 2); + CU_ASSERT_EQUAL(tables.edges->child[0], 0); + CU_ASSERT_EQUAL(tables.edges->left[1], 1); + CU_ASSERT_EQUAL(tables.edges->right[1], 2); + CU_ASSERT_EQUAL(tables.edges->parent[1], 2); + CU_ASSERT_EQUAL(tables.edges->child[1], 1); + + tsk_tbl_collection_free(&tables); +} + +static void +test_simplest_overlapping_unary_edges_internal_samples_simplify(void) +{ + const char *nodes = + "1 0 
-1\n" + "1 0 -1\n" + "1 1 -1"; + const char *edges = + "0 2 2 0\n" + "1 3 2 1\n"; + tsk_id_t samples[] = {0, 1, 2}; + tsk_tbl_collection_t tables; + int ret; + + ret = tsk_tbl_collection_alloc(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + tables.sequence_length = 3; + parse_nodes(nodes, tables.nodes); + CU_ASSERT_EQUAL_FATAL(tables.nodes->num_rows, 3); + parse_edges(edges, tables.edges); + CU_ASSERT_EQUAL_FATAL(tables.edges->num_rows, 2); + + ret = tsk_tbl_collection_simplify(&tables, samples, 3, 0, NULL); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + CU_ASSERT_EQUAL(tables.nodes->num_rows, 3); + CU_ASSERT_EQUAL(tables.edges->num_rows, 2); + /* Identical to the input. + 0 2 2 0 + 1 3 2 1 + */ + CU_ASSERT_EQUAL(tables.edges->left[0], 0); + CU_ASSERT_EQUAL(tables.edges->left[1], 1); + CU_ASSERT_EQUAL(tables.edges->right[0], 2); + CU_ASSERT_EQUAL(tables.edges->right[1], 3); + CU_ASSERT_EQUAL(tables.edges->parent[0], 2); + CU_ASSERT_EQUAL(tables.edges->parent[1], 2); + CU_ASSERT_EQUAL(tables.edges->child[0], 0); + CU_ASSERT_EQUAL(tables.edges->child[1], 1); + + tsk_tbl_collection_free(&tables); +} + +static void +test_simplest_reduce_site_topology(void) +{ + /* Two trees side by side, with a site on the second one. The first + * tree should disappear. 
*/ + const char *nodes = + "1 0 -1\n" + "1 0 -1\n" + "0 1 -1\n" + "0 2 -1\n"; + const char *edges = + "0 1 2 0\n" + "0 1 2 1\n" + "1 2 3 0\n" + "1 2 3 1\n"; + const char *sites = + "1.0 0\n"; + tsk_id_t samples[] = {0, 1}; + tsk_tbl_collection_t tables; + int ret; + + ret = tsk_tbl_collection_alloc(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + tables.sequence_length = 2; + parse_nodes(nodes, tables.nodes); + CU_ASSERT_EQUAL_FATAL(tables.nodes->num_rows, 4); + parse_edges(edges, tables.edges); + CU_ASSERT_EQUAL_FATAL(tables.edges->num_rows, 4); + parse_sites(sites, tables.sites); + CU_ASSERT_EQUAL_FATAL(tables.sites->num_rows, 1); + + ret = tsk_tbl_collection_simplify(&tables, samples, 2, + TSK_REDUCE_TO_SITE_TOPOLOGY, NULL); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + CU_ASSERT_EQUAL(tables.nodes->num_rows, 3); + CU_ASSERT_EQUAL(tables.edges->num_rows, 2); + CU_ASSERT_EQUAL(tables.edges->left[0], 0); + CU_ASSERT_EQUAL(tables.edges->left[1], 0); + CU_ASSERT_EQUAL(tables.edges->right[0], 2); + CU_ASSERT_EQUAL(tables.edges->right[1], 2); + CU_ASSERT_EQUAL(tables.edges->parent[0], 2); + CU_ASSERT_EQUAL(tables.edges->parent[1], 2); + CU_ASSERT_EQUAL(tables.edges->child[0], 0); + CU_ASSERT_EQUAL(tables.edges->child[1], 1); + + tsk_tbl_collection_free(&tables); +} + +static void +test_simplest_population_filter(void) +{ + tsk_tbl_collection_t tables; + tsk_id_t samples[] = {0, 1}; + int ret; + + ret = tsk_tbl_collection_alloc(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + tables.sequence_length = 1; + tsk_population_tbl_add_row(tables.populations, "0", 1); + tsk_population_tbl_add_row(tables.populations, "1", 1); + tsk_population_tbl_add_row(tables.populations, "2", 1); + /* Two nodes referring to population 1 */ + tsk_node_tbl_add_row(tables.nodes, TSK_NODE_IS_SAMPLE, 0.0, 1, TSK_NULL, + NULL, 0); + tsk_node_tbl_add_row(tables.nodes, TSK_NODE_IS_SAMPLE, 0.0, 1, TSK_NULL, + NULL, 0); + + ret = tsk_tbl_collection_simplify(&tables, samples, 2, 0, NULL); + 
CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL(tables.nodes->num_rows, 2); + CU_ASSERT_EQUAL(tables.populations->num_rows, 3); + CU_ASSERT_EQUAL(tables.populations->metadata[0], '0'); + CU_ASSERT_EQUAL(tables.populations->metadata[1], '1'); + CU_ASSERT_EQUAL(tables.populations->metadata[2], '2'); + + ret = tsk_tbl_collection_simplify(&tables, samples, 2, TSK_FILTER_POPULATIONS, NULL); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL(tables.nodes->num_rows, 2); + CU_ASSERT_EQUAL(tables.nodes->population[0], 0); + CU_ASSERT_EQUAL(tables.nodes->population[1], 0); + CU_ASSERT_EQUAL(tables.populations->num_rows, 1); + CU_ASSERT_EQUAL(tables.populations->metadata[0], '1'); + + tsk_tbl_collection_free(&tables); +} + +static void +test_simplest_individual_filter(void) +{ + tsk_tbl_collection_t tables; + tsk_id_t samples[] = {0, 1}; + int ret; + + ret = tsk_tbl_collection_alloc(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + tables.sequence_length = 1; + tsk_individual_tbl_add_row(tables.individuals, 0, NULL, 0, "0", 1); + tsk_individual_tbl_add_row(tables.individuals, 0, NULL, 0, "1", 1); + tsk_individual_tbl_add_row(tables.individuals, 0, NULL, 0, "2", 1); + /* Two nodes referring to individual 1 */ + tsk_node_tbl_add_row(tables.nodes, TSK_NODE_IS_SAMPLE, 0.0, TSK_NULL, 1, + NULL, 0); + tsk_node_tbl_add_row(tables.nodes, TSK_NODE_IS_SAMPLE, 0.0, TSK_NULL, 1, + NULL, 0); + + ret = tsk_tbl_collection_simplify(&tables, samples, 2, 0, NULL); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL(tables.nodes->num_rows, 2); + CU_ASSERT_EQUAL(tables.individuals->num_rows, 3); + CU_ASSERT_EQUAL(tables.individuals->metadata[0], '0'); + CU_ASSERT_EQUAL(tables.individuals->metadata[1], '1'); + CU_ASSERT_EQUAL(tables.individuals->metadata[2], '2'); + + ret = tsk_tbl_collection_simplify(&tables, samples, 2, TSK_FILTER_INDIVIDUALS, NULL); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL(tables.nodes->num_rows, 2); + CU_ASSERT_EQUAL(tables.nodes->individual[0], 0); + 
CU_ASSERT_EQUAL(tables.nodes->individual[1], 0); + CU_ASSERT_EQUAL(tables.individuals->num_rows, 1); + CU_ASSERT_EQUAL(tables.individuals->metadata[0], '1'); + + tsk_tbl_collection_free(&tables); +} + +/*======================================================= + * Single tree tests. + *======================================================*/ + +static void +test_single_tree_good_records(void) +{ + tsk_treeseq_t ts; + + tsk_treeseq_from_text(&ts, 1, single_tree_ex_nodes, single_tree_ex_edges, + NULL, NULL, NULL, NULL, NULL); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_samples(&ts), 4); + CU_ASSERT_EQUAL(tsk_treeseq_get_sequence_length(&ts), 1.0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_nodes(&ts), 7); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_mutations(&ts), 0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_trees(&ts), 1); + tsk_treeseq_free(&ts); +} + + +static void +test_single_nonbinary_tree_good_records(void) +{ + const char *nodes = + "1 0 0\n" + "1 0 0\n" + "1 0 0\n" + "1 0 0\n" + "1 0 0\n" + "1 0 0\n" + "1 0 0\n" + "0 1 0\n" + "0 2 0\n" + "0 3 0\n"; + const char *edges = + "0 1 7 0,1,2,3\n" + "0 1 8 4,5\n" + "0 1 9 6,7,8"; + tsk_treeseq_t ts; + + tsk_treeseq_from_text(&ts, 1, nodes, edges, NULL, NULL, NULL, NULL, NULL); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_samples(&ts), 7); + CU_ASSERT_EQUAL(tsk_treeseq_get_sequence_length(&ts), 1.0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_nodes(&ts), 10); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_mutations(&ts), 0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_trees(&ts), 1); + tsk_treeseq_free(&ts); +} + +static void +test_single_tree_bad_records(void) +{ + int ret = 0; + tsk_treeseq_t ts; + tsk_tbl_collection_t tables; + int load_flags = TSK_BUILD_INDEXES; + + ret = tsk_tbl_collection_alloc(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + tables.sequence_length = 1; + parse_nodes(single_tree_ex_nodes, tables.nodes); + CU_ASSERT_EQUAL_FATAL(tables.nodes->num_rows, 7); + parse_edges(single_tree_ex_edges, tables.edges); + 
CU_ASSERT_EQUAL_FATAL(tables.edges->num_rows, 6); + + /* Not sorted in time order */ + tables.nodes->time[5] = 0.5; + ret = tsk_treeseq_alloc(&ts, &tables, load_flags); + CU_ASSERT_EQUAL(ret, TSK_ERR_EDGES_NOT_SORTED_PARENT_TIME); + tsk_treeseq_free(&ts); + tables.nodes->time[5] = 2.0; + + /* Left value greater than sequence right */ + tables.edges->left[2] = 2.0; + ret = tsk_treeseq_alloc(&ts, &tables, load_flags); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_EDGE_INTERVAL); + tsk_treeseq_free(&ts); + tables.edges->left[2] = 0.0; + + ret = tsk_treeseq_alloc(&ts, &tables, load_flags); + CU_ASSERT_EQUAL(ret, 0); + tsk_treeseq_free(&ts); + tsk_tbl_collection_free(&tables); +} + + +static void +test_single_tree_good_mutations(void) +{ + tsk_treeseq_t ts; + size_t j; + size_t num_sites = 3; + size_t num_mutations = 7; + tsk_site_t other_sites[num_sites]; + tsk_mutation_t other_mutations[num_mutations]; + int ret; + + tsk_treeseq_from_text(&ts, 1, single_tree_ex_nodes, single_tree_ex_edges, + NULL, single_tree_ex_sites, single_tree_ex_mutations, NULL, NULL); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_samples(&ts), 4); + CU_ASSERT_EQUAL(tsk_treeseq_get_sequence_length(&ts), 1.0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_nodes(&ts), 7); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_trees(&ts), 1); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_sites(&ts), num_sites); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_mutations(&ts), num_mutations); + + for (j = 0; j < num_sites; j++) { + ret = tsk_treeseq_get_site(&ts, j, other_sites + j); + CU_ASSERT_EQUAL(ret, 0); + } + for (j = 0; j < num_mutations; j++) { + ret = tsk_treeseq_get_mutation(&ts, j, other_mutations + j); + CU_ASSERT_EQUAL(ret, 0); + } + CU_ASSERT_EQUAL(other_sites[0].position, 0.1); + CU_ASSERT_NSTRING_EQUAL(other_sites[0].ancestral_state, "0", 1); + CU_ASSERT_EQUAL(other_sites[1].position, 0.2); + CU_ASSERT_NSTRING_EQUAL(other_sites[1].ancestral_state, "0", 1); + CU_ASSERT_EQUAL(other_sites[2].position, 0.3); + 
CU_ASSERT_NSTRING_EQUAL(other_sites[2].ancestral_state, "0", 1); + + CU_ASSERT_EQUAL(other_mutations[0].id, 0); + CU_ASSERT_EQUAL(other_mutations[0].node, 2); + CU_ASSERT_NSTRING_EQUAL(other_mutations[0].derived_state, "1", 1); + CU_ASSERT_EQUAL(other_mutations[1].id, 1); + CU_ASSERT_EQUAL(other_mutations[1].node, 4); + CU_ASSERT_NSTRING_EQUAL(other_mutations[1].derived_state, "1", 1); + CU_ASSERT_EQUAL(other_mutations[2].id, 2); + CU_ASSERT_EQUAL(other_mutations[2].node, 0); + CU_ASSERT_NSTRING_EQUAL(other_mutations[2].derived_state, "0", 1); + CU_ASSERT_EQUAL(other_mutations[3].id, 3); + CU_ASSERT_EQUAL(other_mutations[3].node, 0); + CU_ASSERT_NSTRING_EQUAL(other_mutations[3].derived_state, "1", 1); + CU_ASSERT_EQUAL(other_mutations[4].id, 4); + CU_ASSERT_EQUAL(other_mutations[4].node, 1); + CU_ASSERT_NSTRING_EQUAL(other_mutations[4].derived_state, "1", 1); + CU_ASSERT_EQUAL(other_mutations[5].id, 5); + CU_ASSERT_EQUAL(other_mutations[5].node, 2); + CU_ASSERT_NSTRING_EQUAL(other_mutations[5].derived_state, "1", 1); + CU_ASSERT_EQUAL(other_mutations[6].id, 6); + CU_ASSERT_EQUAL(other_mutations[6].node, 3); + CU_ASSERT_NSTRING_EQUAL(other_mutations[6].derived_state, "1", 1); + + tsk_treeseq_free(&ts); +} + +static void +test_single_tree_bad_mutations(void) +{ + int ret = 0; + const char *sites = + "0 0\n" + "0.1 0\n" + "0.2 0\n"; + const char *mutations = + "0 0 1 -1\n" + "1 1 1 -1\n" + "2 4 1 -1\n" + "2 1 0 2\n" + "2 1 1 3\n" + "2 2 1 -1\n"; + tsk_treeseq_t ts; + tsk_tbl_collection_t tables; + int load_flags = TSK_BUILD_INDEXES; + + ret = tsk_tbl_collection_alloc(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + tables.sequence_length = 1; + parse_nodes(single_tree_ex_nodes, tables.nodes); + CU_ASSERT_EQUAL_FATAL(tables.nodes->num_rows, 7); + parse_edges(single_tree_ex_edges, tables.edges); + CU_ASSERT_EQUAL_FATAL(tables.edges->num_rows, 6); + parse_sites(sites, tables.sites); + parse_mutations(mutations, tables.mutations); + 
CU_ASSERT_EQUAL_FATAL(tables.sites->num_rows, 3); + CU_ASSERT_EQUAL_FATAL(tables.mutations->num_rows, 6); + tables.sequence_length = 1.0; + + /* Check to make sure we have legal mutations */ + ret = tsk_treeseq_alloc(&ts, &tables, load_flags); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_sites(&ts), 3); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_mutations(&ts), 6); + tsk_treeseq_free(&ts); + + /* negative coordinate */ + tables.sites->position[0] = -1.0; + ret = tsk_treeseq_alloc(&ts, &tables, load_flags); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_SITE_POSITION); + tsk_treeseq_free(&ts); + tables.sites->position[0] = 0.0; + + /* coordinate == sequence length */ + tables.sites->position[2] = 1.0; + ret = tsk_treeseq_alloc(&ts, &tables, load_flags); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_SITE_POSITION); + tsk_treeseq_free(&ts); + tables.sites->position[2] = 0.2; + + /* coordinate > sequence length */ + tables.sites->position[2] = 1.1; + ret = tsk_treeseq_alloc(&ts, &tables, load_flags); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_SITE_POSITION); + tsk_treeseq_free(&ts); + tables.sites->position[2] = 0.2; + + /* Duplicate positions */ + tables.sites->position[0] = 0.1; + ret = tsk_treeseq_alloc(&ts, &tables, load_flags); + CU_ASSERT_EQUAL(ret, TSK_ERR_DUPLICATE_SITE_POSITION); + tsk_treeseq_free(&ts); + tables.sites->position[0] = 0.0; + + /* Unsorted positions */ + tables.sites->position[0] = 0.3; + ret = tsk_treeseq_alloc(&ts, &tables, load_flags); + CU_ASSERT_EQUAL(ret, TSK_ERR_UNSORTED_SITES); + tsk_treeseq_free(&ts); + tables.sites->position[0] = 0.0; + + /* site < 0 */ + tables.mutations->site[0] = -2; + ret = tsk_treeseq_alloc(&ts, &tables, load_flags); + CU_ASSERT_EQUAL(ret, TSK_ERR_SITE_OUT_OF_BOUNDS); + tsk_treeseq_free(&ts); + tables.mutations->site[0] = 0; + + /* site == num_sites */ + tables.mutations->site[0] = 3; + ret = tsk_treeseq_alloc(&ts, &tables, load_flags); + CU_ASSERT_EQUAL(ret, TSK_ERR_SITE_OUT_OF_BOUNDS); + tsk_treeseq_free(&ts); + 
tables.mutations->site[0] = 0; + + /* node = NULL */ + tables.mutations->node[0] = TSK_NULL; + ret = tsk_treeseq_alloc(&ts, &tables, load_flags); + CU_ASSERT_EQUAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS); + tsk_treeseq_free(&ts); + tables.mutations->node[0] = 0; + + /* node >= num_nodes */ + tables.mutations->node[0] = 7; + ret = tsk_treeseq_alloc(&ts, &tables, load_flags); + CU_ASSERT_EQUAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS); + tsk_treeseq_free(&ts); + tables.mutations->node[0] = 0; + + /* parent < -1 */ + tables.mutations->parent[0] = -2; + ret = tsk_treeseq_alloc(&ts, &tables, load_flags); + CU_ASSERT_EQUAL(ret, TSK_ERR_MUTATION_OUT_OF_BOUNDS); + tsk_treeseq_free(&ts); + tables.mutations->parent[0] = TSK_NULL; + + /* parent >= num_mutations */ + tables.mutations->parent[0] = 7; + ret = tsk_treeseq_alloc(&ts, &tables, load_flags); + CU_ASSERT_EQUAL(ret, TSK_ERR_MUTATION_OUT_OF_BOUNDS); + tsk_treeseq_free(&ts); + tables.mutations->parent[0] = TSK_NULL; + + /* parent on a different site */ + tables.mutations->parent[1] = 0; + ret = tsk_treeseq_alloc(&ts, &tables, load_flags); + CU_ASSERT_EQUAL(ret, TSK_ERR_MUTATION_PARENT_DIFFERENT_SITE); + tsk_treeseq_free(&ts); + tables.mutations->parent[1] = TSK_NULL; + + /* parent is the same mutation */ + tables.mutations->parent[0] = 0; + ret = tsk_treeseq_alloc(&ts, &tables, load_flags); + CU_ASSERT_EQUAL(ret, TSK_ERR_MUTATION_PARENT_EQUAL); + tsk_treeseq_free(&ts); + tables.mutations->parent[0] = TSK_NULL; + + /* parent_id > mutation id */ + tables.mutations->parent[2] = 3; + ret = tsk_treeseq_alloc(&ts, &tables, load_flags); + CU_ASSERT_EQUAL(ret, TSK_ERR_MUTATION_PARENT_AFTER_CHILD); + tsk_treeseq_free(&ts); + tables.mutations->parent[2] = TSK_NULL; + + /* Check to make sure we've maintained legal mutations */ + ret = tsk_treeseq_alloc(&ts, &tables, load_flags); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_sites(&ts), 3); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_mutations(&ts), 6); + tsk_treeseq_free(&ts); + + 
tsk_tbl_collection_free(&tables); +} + +static void +test_single_tree_iter(void) +{ + int ret; + const char *nodes = + "1 0 0\n" + "1 0 0\n" + "1 0 0\n" + "1 0 0\n" + "0 1 0\n" + "0 2 0\n" + "0 3 0\n"; + const char *edges = + "0 6 4 0,1\n" + "0 6 5 2,3\n" + "0 6 6 4,5\n"; + tsk_id_t parents[] = {4, 4, 5, 5, 6, 6, TSK_NULL}; + tsk_treeseq_t ts; + tsk_tree_t tree; + tsk_id_t u, v, w; + size_t num_samples; + uint32_t num_nodes = 7; + + tsk_treeseq_from_text(&ts, 6, nodes, edges, NULL, NULL, NULL, NULL, NULL); + ret = tsk_tree_alloc(&tree, &ts, 0); + CU_ASSERT_EQUAL(ret, 0); + + ret = tsk_tree_first(&tree); + CU_ASSERT_EQUAL(ret, 1); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_nodes(&ts), num_nodes); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_trees(&ts), 1); + tsk_tree_print_state(&tree, _devnull); + + for (u = 0; u < (tsk_id_t) num_nodes; u++) { + ret = tsk_tree_get_parent(&tree, u, &v); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_EQUAL(v, parents[u]); + } + ret = tsk_tree_get_num_samples(&tree, 0, &num_samples); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_EQUAL(num_samples, 1); + ret = tsk_tree_get_num_samples(&tree, 4, &num_samples); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_EQUAL(num_samples, 2); + ret = tsk_tree_get_num_samples(&tree, 6, &num_samples); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_EQUAL(num_samples, 4); + ret = tsk_tree_get_mrca(&tree, 0, 1, &w); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_EQUAL(w, 4); + ret = tsk_tree_get_mrca(&tree, 0, 2, &w); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_EQUAL(w, 6); + + ret = tsk_tree_next(&tree); + CU_ASSERT_EQUAL(ret, 0); + + tsk_tree_free(&tree); + tsk_treeseq_free(&ts); +} + +static void +test_single_nonbinary_tree_iter(void) +{ + int ret; + const char *nodes = + "1 0 0\n" + "1 0 0\n" + "1 0 0\n" + "1 0 0\n" + "1 0 0\n" + "1 0 0\n" + "1 0 0\n" + "0 1 0\n" + "0 2 0\n" + "0 3 0\n"; + const char *edges = + "0 1 7 0,1,2,3\n" + "0 1 8 4,5\n" + "0 1 9 6,7,8\n"; + tsk_id_t parents[] = {7, 7, 7, 7, 8, 8, 9, 9, 9, TSK_NULL}; + tsk_treeseq_t ts; + tsk_tree_t 
tree; + tsk_id_t u, v, w; + size_t num_samples; + size_t num_nodes = 10; + size_t total_samples = 7; + + tsk_treeseq_from_text(&ts, 1, nodes, edges, NULL, NULL, NULL, NULL, NULL); + ret = tsk_tree_alloc(&tree, &ts, 0); + CU_ASSERT_EQUAL(ret, 0); + + ret = tsk_tree_first(&tree); + CU_ASSERT_EQUAL_FATAL(ret, 1); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_nodes(&ts), num_nodes); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_trees(&ts), 1); + tsk_tree_print_state(&tree, _devnull); + + for (u = 0; u < (tsk_id_t) num_nodes; u++) { + ret = tsk_tree_get_parent(&tree, u, &v); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_EQUAL(v, parents[u]); + } + for (u = 0; u < (tsk_id_t) total_samples; u++) { + ret = tsk_tree_get_num_samples(&tree, u, &num_samples); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_EQUAL(num_samples, 1); + CU_ASSERT_EQUAL(tree.left_child[u], TSK_NULL); + } + + u = 7; + ret = tsk_tree_get_num_samples(&tree, u, &num_samples); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_EQUAL(num_samples, 4); + CU_ASSERT_EQUAL(tree.right_child[u], 3); + CU_ASSERT_EQUAL(tree.left_sib[3], 2); + CU_ASSERT_EQUAL(tree.left_sib[2], 1); + CU_ASSERT_EQUAL(tree.left_sib[1], 0); + CU_ASSERT_EQUAL(tree.left_sib[0], TSK_NULL); + + u = 8; + ret = tsk_tree_get_num_samples(&tree, u, &num_samples); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_EQUAL(num_samples, 2); + CU_ASSERT_EQUAL(tree.right_child[u], 5); + CU_ASSERT_EQUAL(tree.left_sib[5], 4); + CU_ASSERT_EQUAL(tree.left_sib[4], TSK_NULL); + + u = 9; + ret = tsk_tree_get_num_samples(&tree, u, &num_samples); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_EQUAL(num_samples, 7); + CU_ASSERT_EQUAL(tree.right_child[u], 8); + CU_ASSERT_EQUAL(tree.left_sib[8], 7); + CU_ASSERT_EQUAL(tree.left_sib[7], 6); + CU_ASSERT_EQUAL(tree.left_sib[6], TSK_NULL); + + CU_ASSERT_EQUAL(tsk_tree_get_num_roots(&tree), 1); + CU_ASSERT_EQUAL(tree.left_root, 9); + + ret = tsk_tree_get_mrca(&tree, 0, 1, &w); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_EQUAL(w, 7); + ret = tsk_tree_get_mrca(&tree, 0, 4, &w); + 
CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_EQUAL(w, 9); + + ret = tsk_tree_next(&tree); + CU_ASSERT_EQUAL(ret, 0); + + tsk_tree_free(&tree); + tsk_treeseq_free(&ts); +} + +static void +test_single_tree_general_samples_iter(void) +{ + int ret; + const char *nodes = + "0 3 0\n" + "0 2 0\n" + "0 1 0\n" + "1 0 0\n" + "1 0 0\n" + "1 0 0\n" + "1 0 0\n"; + const char *edges = + "0 6 2 3,4\n" + "0 6 1 5,6\n" + "0 6 0 1,2\n"; + tsk_id_t parents[] = {TSK_NULL, 0, 0, 2, 2, 1, 1}; + tsk_id_t *samples; + tsk_treeseq_t ts; + tsk_tree_t tree; + tsk_id_t u, v, w; + size_t num_samples; + uint32_t num_nodes = 7; + + tsk_treeseq_from_text(&ts, 6, nodes, edges, NULL, NULL, NULL, NULL, NULL); + ret = tsk_treeseq_get_samples(&ts, &samples); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL(samples[0], 3); + CU_ASSERT_EQUAL(samples[1], 4); + CU_ASSERT_EQUAL(samples[2], 5); + CU_ASSERT_EQUAL(samples[3], 6); + + ret = tsk_tree_alloc(&tree, &ts, 0); + CU_ASSERT_EQUAL(ret, 0); + ret = tsk_tree_first(&tree); + CU_ASSERT_EQUAL(ret, 1); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_nodes(&ts), num_nodes); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_trees(&ts), 1); + tsk_tree_print_state(&tree, _devnull); + + for (u = 0; u < (tsk_id_t) num_nodes; u++) { + ret = tsk_tree_get_parent(&tree, u, &v); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_EQUAL(v, parents[u]); + } + ret = tsk_tree_get_num_samples(&tree, 3, &num_samples); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_EQUAL(num_samples, 1); + ret = tsk_tree_get_num_samples(&tree, 2, &num_samples); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_EQUAL(num_samples, 2); + ret = tsk_tree_get_num_samples(&tree, 0, &num_samples); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_EQUAL(num_samples, 4); + ret = tsk_tree_get_mrca(&tree, 3, 4, &w); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_EQUAL(w, 2); + ret = tsk_tree_get_mrca(&tree, 3, 6, &w); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_EQUAL(w, 0); + + ret = tsk_tree_next(&tree); + CU_ASSERT_EQUAL(ret, 0); + + tsk_tree_free(&tree); + tsk_treeseq_free(&ts); +} + 
+static void +test_single_tree_iter_times(void) +{ + int ret = 0; + const char *nodes = + "1 0 0\n" + "1 0 0\n" + "1 2 0\n" + "1 3 0\n" + "0 1 0\n" + "0 4 0\n" + "0 5 0\n"; + const char *edges = + "0 6 4 0,1\n" + "0 6 5 2,3\n" + "0 6 6 4,5\n"; + tsk_id_t parents[] = {4, 4, 5, 5, 6, 6, TSK_NULL}; + double times[] = {0.0, 0.0, 2.0, 3.0, 1.0, 4.0, 5.0}; + double t; + tsk_treeseq_t ts; + tsk_tree_t tree; + tsk_id_t u, v; + uint32_t num_nodes = 7; + + tsk_treeseq_from_text(&ts, 6, nodes, edges, NULL, NULL, NULL, NULL, NULL); + ret = tsk_tree_alloc(&tree, &ts, 0); + CU_ASSERT_EQUAL(ret, 0); + ret = tsk_tree_first(&tree); + CU_ASSERT_EQUAL(ret, 1); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_nodes(&ts), num_nodes); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_trees(&ts), 1); + tsk_tree_print_state(&tree, _devnull); + + for (u = 0; u < (tsk_id_t) num_nodes; u++) { + ret = tsk_tree_get_parent(&tree, u, &v); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_EQUAL(v, parents[u]); + ret = tsk_tree_get_time(&tree, u, &t); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_EQUAL(t, times[u]); + } + ret = tsk_tree_next(&tree); + CU_ASSERT_EQUAL(ret, 0); + + tsk_tree_free(&tree); + tsk_treeseq_free(&ts); +} + + +static void +test_single_tree_simplify(void) +{ + tsk_treeseq_t ts; + tsk_tbl_collection_t tables; + tsk_id_t samples[] = {0, 1}; + int ret; + + tsk_treeseq_from_text(&ts, 1, single_tree_ex_nodes, single_tree_ex_edges, NULL, + single_tree_ex_sites, single_tree_ex_mutations, NULL, NULL); + verify_simplify(&ts); + ret = tsk_tbl_collection_alloc(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_treeseq_copy_tables(&ts, &tables); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret = tsk_tbl_collection_simplify(&tables, samples, 2, 0, NULL); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL(tables.nodes->num_rows, 3); + CU_ASSERT_EQUAL(tables.edges->num_rows, 2); + + /* Make sure we detect unsorted edges */ + ret = tsk_treeseq_copy_tables(&ts, &tables); + CU_ASSERT_EQUAL_FATAL(ret, 0); + unsort_edges(tables.edges, 
0); + ret = tsk_tbl_collection_simplify(&tables, samples, 2, 0, NULL); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_EDGES_NOT_SORTED_CHILD); + + /* detect bad parents */ + ret = tsk_treeseq_copy_tables(&ts, &tables); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tables.edges->parent[0] = -1; + ret = tsk_tbl_collection_simplify(&tables, samples, 2, 0, NULL); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_NULL_PARENT); + + /* detect bad children */ + ret = tsk_treeseq_copy_tables(&ts, &tables); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tables.edges->child[0] = -1; + ret = tsk_tbl_collection_simplify(&tables, samples, 2, 0, NULL); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_NULL_CHILD); + + /* detect loops */ + ret = tsk_treeseq_copy_tables(&ts, &tables); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tables.edges->child[0] = tables.edges->parent[0]; + ret = tsk_tbl_collection_simplify(&tables, samples, 2, 0, NULL); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_NODE_TIME_ORDERING); + + /* Test the interface for NULL inputs */ + ret = tsk_treeseq_copy_tables(&ts, &tables); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_tbl_collection_simplify(&tables, NULL, 2, 0, NULL); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_PARAM_VALUE); + + tsk_treeseq_free(&ts); + tsk_tbl_collection_free(&tables); +} + +static void +test_single_tree_compute_mutation_parents(void) +{ + int ret = 0; + const char *sites = + "0 0\n" + "0.1 0\n" + "0.2 0\n"; + const char *mutations = + "0 0 1 -1\n" + "1 1 1 -1\n" + "2 4 1 -1\n" + "2 1 0 2\n" + "2 1 1 3\n" + "2 2 1 -1\n"; + tsk_treeseq_t ts; + tsk_tbl_collection_t tables; + + ret = tsk_tbl_collection_alloc(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + tables.sequence_length = 1; + parse_nodes(single_tree_ex_nodes, tables.nodes); + CU_ASSERT_EQUAL_FATAL(tables.nodes->num_rows, 7); + parse_edges(single_tree_ex_edges, tables.edges); + CU_ASSERT_EQUAL_FATAL(tables.edges->num_rows, 6); + parse_sites(sites, tables.sites); + parse_mutations(mutations, tables.mutations); + CU_ASSERT_EQUAL_FATAL(tables.sites->num_rows, 
3); + CU_ASSERT_EQUAL_FATAL(tables.mutations->num_rows, 6); + tables.sequence_length = 1.0; + + ret = tsk_tbl_collection_build_indexes(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + /* Check to make sure we have legal mutations */ + ret = tsk_treeseq_alloc(&ts, &tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_sites(&ts), 3); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_mutations(&ts), 6); + + /* Compute the mutation parents */ + verify_compute_mutation_parents(&ts); + + /* Verify consistency of individuals */ + verify_individual_nodes(&ts); + tsk_treeseq_free(&ts); + + /* Bad site reference */ + tables.mutations->site[0] = -1; + ret = tsk_tbl_collection_compute_mutation_parents(&tables, 0); + CU_ASSERT_EQUAL(ret, TSK_ERR_SITE_OUT_OF_BOUNDS); + tables.mutations->site[0] = 0; + + /* Bad site reference */ + tables.mutations->site[0] = -1; + ret = tsk_tbl_collection_compute_mutation_parents(&tables, 0); + CU_ASSERT_EQUAL(ret, TSK_ERR_SITE_OUT_OF_BOUNDS); + tables.mutations->site[0] = 0; + + /* mutation sites out of order */ + tables.mutations->site[0] = 2; + ret = tsk_tbl_collection_compute_mutation_parents(&tables, 0); + CU_ASSERT_EQUAL(ret, TSK_ERR_UNSORTED_MUTATIONS); + tables.mutations->site[0] = 0; + + /* sites out of order */ + tables.sites->position[0] = 0.11; + ret = tsk_tbl_collection_compute_mutation_parents(&tables, 0); + CU_ASSERT_EQUAL(ret, TSK_ERR_UNSORTED_SITES); + tables.sites->position[0] = 0; + + /* Bad node reference */ + tables.mutations->node[0] = -1; + ret = tsk_tbl_collection_compute_mutation_parents(&tables, 0); + CU_ASSERT_EQUAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS); + tables.mutations->node[0] = 0; + + /* Bad node reference */ + tables.mutations->node[0] = (tsk_id_t) tables.nodes->num_rows; + ret = tsk_tbl_collection_compute_mutation_parents(&tables, 0); + CU_ASSERT_EQUAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS); + tables.mutations->node[0] = 0; + + /* Mutations not ordered by tree */ + tables.mutations->node[2] = 1; + 
tables.mutations->node[3] = 4; + ret = tsk_tbl_collection_compute_mutation_parents(&tables, 0); + CU_ASSERT_EQUAL(ret, TSK_ERR_MUTATION_PARENT_AFTER_CHILD); + tables.mutations->node[2] = 4; + tables.mutations->node[3] = 1; + + /* Need to reset the parent field here */ + memset(tables.mutations->parent, 0xff, + tables.mutations->num_rows * sizeof(tsk_id_t)); + /* Mutations not ordered by site */ + tables.mutations->site[3] = 1; + ret = tsk_tbl_collection_compute_mutation_parents(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_UNSORTED_MUTATIONS); + tables.mutations->site[3] = 2; + + /* Check to make sure we still have legal mutations */ + ret = tsk_tbl_collection_compute_mutation_parents(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret = tsk_treeseq_alloc(&ts, &tables, 0); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_sites(&ts), 3); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_mutations(&ts), 6); + tsk_treeseq_free(&ts); + + tsk_treeseq_free(&ts); + tsk_tbl_collection_free(&tables); +} + + +/*======================================================= + * Multi tree tests. 
 *======================================================*/

/* Verify the three trees of the PLOS-paper example (with its sites,
 * mutations and individuals) against hard-coded parent arrays. */
static void
test_simple_multi_tree(void)
{
    tsk_id_t parents[] = {
        6, 5, 8, 5, TSK_NULL, 6, 8, TSK_NULL, TSK_NULL,
        6, 5, 4, 4, 5, 6, TSK_NULL, TSK_NULL, TSK_NULL,
        7, 5, 4, 4, 5, 7, TSK_NULL, TSK_NULL, TSK_NULL,
    };
    uint32_t num_trees = 3;
    tsk_treeseq_t ts;

    tsk_treeseq_from_text(&ts, 10, paper_ex_nodes, paper_ex_edges, NULL,
            paper_ex_sites, paper_ex_mutations, paper_ex_individuals, NULL);
    verify_trees(&ts, num_trees, parents);
    tsk_treeseq_free(&ts);
}

/* Same check on the example containing unary nodes (nodes with a
 * single child present in some trees). */
static void
test_unary_multi_tree(void)
{
    tsk_id_t parents[] = {
        6, 5, 7, 5, TSK_NULL, 6, 8, 8, TSK_NULL,
        6, 5, 4, 4, 5, 6, 8, TSK_NULL, TSK_NULL,
        7, 5, 4, 4, 5, 7, TSK_NULL, TSK_NULL, TSK_NULL,
    };
    tsk_treeseq_t ts;
    uint32_t num_trees = 3;

    tsk_treeseq_from_text(&ts, 10, unary_ex_nodes, unary_ex_edges, NULL,
            unary_ex_sites, unary_ex_mutations, NULL, NULL);
    verify_trees(&ts, num_trees, parents);
    tsk_treeseq_free(&ts);
}

/* Same check on the example where one of the samples (node 5) is an
 * internal node of the trees. */
static void
test_internal_sample_multi_tree(void)
{
    tsk_id_t parents[] = {
        7, 5, 4, 4, 5, 7, TSK_NULL, TSK_NULL, TSK_NULL,
        4, 5, 4, 8, 5, 8, TSK_NULL, TSK_NULL, TSK_NULL,
        6, 5, 4, 4, 5, 6, TSK_NULL, TSK_NULL, TSK_NULL,
    };
    tsk_treeseq_t ts;
    uint32_t num_trees = 3;

    tsk_treeseq_from_text(&ts, 10, internal_sample_ex_nodes, internal_sample_ex_edges, NULL,
            internal_sample_ex_sites, internal_sample_ex_mutations, NULL, NULL);
    verify_trees(&ts, num_trees, parents);
    tsk_treeseq_free(&ts);
}

/* Simplify the internal-sample example down to samples {2, 3, 5},
 * check the returned node_map remaps them to {0, 1, 2}, then verify
 * the simplified topology against hard-coded parents. */
static void
test_internal_sample_simplified_multi_tree(void)
{
    int ret;
    tsk_treeseq_t ts, simplified;
    tsk_id_t samples[] = {2, 3, 5};
    tsk_id_t node_map[9];
    tsk_id_t z = TSK_NULL;
    tsk_id_t parents[] = {
        /* 0 1 2 3 4 */
        3, 3, z, 2, z,
        2, 4, 4, z, z,
        3, 3, z, 2, z,
    };
    uint32_t num_trees = 3;

    tsk_treeseq_from_text(&ts, 10, internal_sample_ex_nodes, internal_sample_ex_edges, NULL,
            internal_sample_ex_sites, internal_sample_ex_mutations, NULL, NULL);
    ret = tsk_treeseq_simplify(&ts, samples, 3, 0, &simplified, node_map);
    CU_ASSERT_EQUAL_FATAL(ret, 0);
    /* Sample IDs are remapped to 0..n-1 in input order. */
    CU_ASSERT_EQUAL(node_map[2], 0);
    CU_ASSERT_EQUAL(node_map[3], 1);
    CU_ASSERT_EQUAL(node_map[5], 2);

    verify_trees(&simplified, num_trees, parents);
    tsk_treeseq_free(&simplified);
    tsk_treeseq_free(&ts);
}

/* Verify the two trees of the nonbinary (multifurcating) example. */
static void
test_nonbinary_multi_tree(void)
{
    /* We make one mutation for each tree */
    tsk_id_t parents[] = {
        8, 8, 8, 8, 10, 10, 9, 10, 9, 12, 12, TSK_NULL, TSK_NULL,
        8, 8, 8, 8, 10, 11, 9, 10, 9, 11, 12, 12, TSK_NULL,
    };

    tsk_treeseq_t ts;
    uint32_t num_trees = 2;

    tsk_treeseq_from_text(&ts, 100, nonbinary_ex_nodes, nonbinary_ex_edges, NULL,
            nonbinary_ex_sites, nonbinary_ex_mutations, NULL, NULL);
    verify_trees(&ts, num_trees, parents);
    tsk_treeseq_free(&ts);
}

/* Local fixture whose edges are listed out of time order, exercising
 * forward (left-to-right) iteration; also checks next/prev symmetry
 * via verify_tree_next_prev. */
static void
test_left_to_right_multi_tree(void)
{
    const char *nodes =
        "1  0   0\n"
        "1  0   0\n"
        "1  0   0\n"
        "1  0   0\n"
        "0  0.090   0\n"
        "0  0.170   0\n"
        "0  0.253   0\n"
        "0  0.071   0\n"
        "0  0.202   0\n";
    const char *edges =
        "2 10 7 2,3\n"
        "0 2  4 1\n"
        "2 10 4 1\n"
        "0 2  4 3\n"
        "2 10 4 7\n"
        "0 7  5 0,4\n"
        "7 10 8 0,4\n"
        "0 2  6 2,5\n";
    const char *sites =
        "1      0\n"
        "4.5    0\n"
        "8.5    0\n";
    const char *mutations =
        "0    2     1\n"
        "1    0     1\n"
        "2    4     1\n";

    tsk_id_t parents[] = {
        5, 4, 6, 4, 5, 6, TSK_NULL, TSK_NULL, TSK_NULL,
        5, 4, 7, 7, 5, TSK_NULL, TSK_NULL, 4, TSK_NULL,
        8, 4, 7, 7, 8, TSK_NULL, TSK_NULL, 4, TSK_NULL,
    };
    tsk_treeseq_t ts;
    uint32_t num_trees = 3;

    tsk_treeseq_from_text(&ts, 10, nodes, edges, NULL, sites, mutations, NULL, NULL);
    verify_trees(&ts, num_trees, parents);
    verify_tree_next_prev(&ts);
    tsk_treeseq_free(&ts);
}

/* Fixture with gaps in the edge coverage: intervals [0,1), [7,8) and
 * [10,12) carry no edges, so the corresponding trees are entirely
 * empty (every parent is TSK_NULL). */
static void
test_gappy_multi_tree(void)
{
    const char *nodes =
        "1  0   0\n"
        "1  0   0\n"
        "1  0   0\n"
        "1  0   0\n"
        "0  0.090   0\n"
        "0  0.170   0\n"
        "0  0.253   0\n"
        "0  0.071   0\n"
        "0  0.202   0\n";
    const char *edges =
        "2 7  7 2\n"
        "8 10 7 2\n"
        "2 7  7 3\n"
        "8 10 7 3\n"
        "1 2  4 1\n"
        "2 7  4 1\n"
        "8 10 4 1\n"
        "1 2  4 3\n"
        "2 7  4 7\n"
        "8 10 4 7\n"
        "1 7  5 0,4\n"
        "8 10 8 0,4\n"
        "1 2  6 2,5\n";
    tsk_id_t z = TSK_NULL;
    tsk_id_t parents[] = {
        z, z, z, z, z, z, z, z, z,
        5, 4, 6, 4, 5, 6, z, z, z,
        5, 4, 7, 7, 5, z, z, 4, z,
        z, z, z, z, z, z, z, z, z,
        8, 4, 7, 7, 8, z, z, 4, z,
        z, z, z, z, z, z, z, z, z,
    };
    tsk_treeseq_t ts;
    uint32_t num_trees = 6;

    tsk_treeseq_from_text(&ts, 12, nodes, edges, NULL, NULL, NULL, NULL, NULL);
    verify_trees(&ts, num_trees, parents);
    verify_tree_next_prev(&ts);
    tsk_treeseq_free(&ts);
}

/* Error-path test: corrupt an edge interval in the paper example's
 * tables, check tsk_treeseq_alloc rejects it, restore the value and
 * check the tables load cleanly again. */
static void
test_tsk_treeseq_bad_records(void)
{
    int ret = 0;
    tsk_treeseq_t ts;
    tsk_tbl_collection_t tables;
    uint32_t num_trees = 3;
    tsk_id_t parents[] = {
        6, 5, 8, 5, TSK_NULL, 6, 8, TSK_NULL, TSK_NULL,
        6, 5, 4, 4, 5, 6, TSK_NULL, TSK_NULL, TSK_NULL,
        7, 5, 4, 4, 5, 7, TSK_NULL, TSK_NULL, TSK_NULL,
    };
    int load_flags = TSK_BUILD_INDEXES;

    ret = tsk_tbl_collection_alloc(&tables, 0);
    CU_ASSERT_EQUAL_FATAL(ret, 0);

    tables.sequence_length = 10;
    parse_nodes(paper_ex_nodes, tables.nodes);
    parse_edges(paper_ex_edges, tables.edges);
    parse_individuals(paper_ex_individuals, tables.individuals);

    /* Make sure we have a good set of records */
    ret = tsk_treeseq_alloc(&ts, &tables, load_flags);
    CU_ASSERT_EQUAL_FATAL(ret, 0);
    CU_ASSERT_EQUAL_FATAL(ts.num_trees, 3);
    verify_trees(&ts, num_trees, parents);
    tsk_treeseq_free(&ts);

    /* Left value greater than right */
    tables.edges->left[0] = 10.0;
    ret = tsk_treeseq_alloc(&ts, &tables, load_flags);
    CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_EDGE_INTERVAL);
    tsk_treeseq_free(&ts);
    tables.edges->left[0] = 2.0;

    /* Restored tables must load cleanly again. */
    ret = tsk_treeseq_alloc(&ts, &tables, load_flags);
    CU_ASSERT_EQUAL(ret, 0);
    verify_trees(&ts, num_trees, parents);
    tsk_treeseq_free(&ts);

    tsk_tbl_collection_free(&tables);
}

/*=======================================================
 * Diff iter tests.
+ *======================================================*/ + +static void +test_simple_diff_iter(void) +{ + int ret; + tsk_treeseq_t ts; + + tsk_treeseq_from_text(&ts, 10, paper_ex_nodes, paper_ex_edges, NULL, NULL, NULL, + paper_ex_individuals, NULL); + verify_tree_diffs(&ts); + + ret = tsk_treeseq_free(&ts); + CU_ASSERT_EQUAL(ret, 0); +} + +static void +test_nonbinary_diff_iter(void) +{ + int ret; + tsk_treeseq_t ts; + + tsk_treeseq_from_text(&ts, 100, nonbinary_ex_nodes, nonbinary_ex_edges, NULL, + NULL, NULL, NULL, NULL); + verify_tree_diffs(&ts); + + ret = tsk_treeseq_free(&ts); + CU_ASSERT_EQUAL(ret, 0); +} + +static void +test_unary_diff_iter(void) +{ + int ret; + tsk_treeseq_t ts; + + tsk_treeseq_from_text(&ts, 10, unary_ex_nodes, unary_ex_edges, NULL, + NULL, NULL, NULL, NULL); + verify_tree_diffs(&ts); + + ret = tsk_treeseq_free(&ts); + CU_ASSERT_EQUAL(ret, 0); +} + +static void +test_internal_sample_diff_iter(void) +{ + int ret; + tsk_treeseq_t ts; + + tsk_treeseq_from_text(&ts, 10, internal_sample_ex_nodes, internal_sample_ex_edges, NULL, + NULL, NULL, NULL, NULL); + verify_tree_diffs(&ts); + + ret = tsk_treeseq_free(&ts); + CU_ASSERT_EQUAL(ret, 0); +} + +/*======================================================= + * Sample sets + *======================================================*/ + +static void +test_simple_sample_sets(void) +{ + sample_count_test_t tests[] = { + {0, 0, 1}, {0, 5, 2}, {0, 6, 3}, + {1, 4, 2}, {1, 5, 3}, {1, 6, 4}}; + uint32_t num_tests = 6; + tsk_treeseq_t ts; + + tsk_treeseq_from_text(&ts, 10, paper_ex_nodes, paper_ex_edges, + NULL, NULL, NULL, paper_ex_individuals, NULL); + verify_sample_counts(&ts, num_tests, tests); + verify_sample_sets(&ts); + + tsk_treeseq_free(&ts); +} + +static void +test_nonbinary_sample_sets(void) +{ + sample_count_test_t tests[] = { + {0, 0, 1}, {0, 8, 4}, {0, 9, 5}, {0, 10, 3}, {0, 12, 8}, + {1, 5, 1}, {1, 8, 4}, {1, 9, 5}, {0, 10, 2}, {0, 11, 1}}; + uint32_t num_tests = 8; + tsk_treeseq_t ts; + + 
tsk_treeseq_from_text(&ts, 100, nonbinary_ex_nodes, nonbinary_ex_edges, NULL, + NULL, NULL, NULL, NULL); + verify_sample_counts(&ts, num_tests, tests); + verify_sample_sets(&ts); + + tsk_treeseq_free(&ts); +} + +static void +test_internal_sample_sample_sets(void) +{ + sample_count_test_t tests[] = { + {0, 0, 1}, {0, 5, 4}, {0, 4, 2}, {0, 7, 5}, + {1, 4, 2}, {1, 5, 4}, {1, 8, 5}, + {2, 5, 4}, {2, 6, 5}}; + uint32_t num_tests = 9; + tsk_treeseq_t ts; + + tsk_treeseq_from_text(&ts, 10, internal_sample_ex_nodes, internal_sample_ex_edges, + NULL, NULL, NULL, NULL, NULL); + verify_sample_counts(&ts, num_tests, tests); + verify_sample_sets(&ts); + + tsk_treeseq_free(&ts); +} + + +/*======================================================= + * Miscellaneous tests. + *======================================================*/ + +static void +test_genealogical_nearest_neighbours_errors(void) +{ + int ret; + tsk_treeseq_t ts; + tsk_id_t *reference_sets[2]; + tsk_id_t reference_set_0[4], reference_set_1[4]; + tsk_id_t focal[] = {0, 1, 2, 3}; + size_t reference_set_size[2]; + size_t num_focal = 4; + double *A = malloc(2 * num_focal * sizeof(double)); + CU_ASSERT_FATAL(A != NULL); + + tsk_treeseq_from_text(&ts, 1, single_tree_ex_nodes, single_tree_ex_edges, + NULL, NULL, NULL, NULL, NULL); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_samples(&ts), 4); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_trees(&ts), 1); + + ret = tsk_treeseq_genealogical_nearest_neighbours(&ts, + focal, num_focal, reference_sets, reference_set_size, 0, 0, A); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_PARAM_VALUE); + ret = tsk_treeseq_genealogical_nearest_neighbours(&ts, + focal, num_focal, reference_sets, reference_set_size, INT16_MAX, 0, A); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_PARAM_VALUE); + + /* Overlapping sample sets */ + reference_sets[0] = focal; + reference_set_size[0] = 1; + reference_sets[1] = focal; + reference_set_size[1] = num_focal; + ret = tsk_treeseq_genealogical_nearest_neighbours(&ts, + focal, 
num_focal, reference_sets, reference_set_size, 2, 0, A); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_DUPLICATE_SAMPLE); + + /* bad values in the sample sets */ + reference_set_0[0] = 0; + reference_set_0[1] = 1; + reference_set_1[0] = 2; + reference_set_1[1] = 3; + reference_set_size[0] = 2; + reference_set_size[1] = 2; + reference_sets[0] = reference_set_0; + reference_sets[1] = reference_set_1; + ret = tsk_treeseq_genealogical_nearest_neighbours(&ts, + focal, num_focal, reference_sets, reference_set_size, 2, 0, A); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + reference_set_0[0] = -1; + ret = tsk_treeseq_genealogical_nearest_neighbours(&ts, + focal, num_focal, reference_sets, reference_set_size, 2, 0, A); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS); + reference_set_0[0] = (tsk_id_t) tsk_treeseq_get_num_nodes(&ts); + ret = tsk_treeseq_genealogical_nearest_neighbours(&ts, + focal, num_focal, reference_sets, reference_set_size, 2, 0, A); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS); + reference_set_0[0] = (tsk_id_t) tsk_treeseq_get_num_nodes(&ts) + 1; + ret = tsk_treeseq_genealogical_nearest_neighbours(&ts, + focal, num_focal, reference_sets, reference_set_size, 2, 0, A); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS); + + /* Duplicate values in the focal sets */ + reference_set_0[0] = 1; + ret = tsk_treeseq_genealogical_nearest_neighbours(&ts, + focal, num_focal, reference_sets, reference_set_size, 2, 0, A); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_DUPLICATE_SAMPLE); + reference_set_0[0] = 3; + ret = tsk_treeseq_genealogical_nearest_neighbours(&ts, + focal, num_focal, reference_sets, reference_set_size, 2, 0, A); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_DUPLICATE_SAMPLE); + + /* Bad sample ID */ + reference_sets[0] = focal; + reference_set_size[0] = 1; + reference_sets[1] = focal + 1; + reference_set_size[1] = num_focal - 1; + focal[0] = -1; + ret = tsk_treeseq_genealogical_nearest_neighbours(&ts, + focal, num_focal, reference_sets, reference_set_size, 
2, 0, A); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS); + focal[0] = (tsk_id_t) tsk_treeseq_get_num_nodes(&ts); + ret = tsk_treeseq_genealogical_nearest_neighbours(&ts, + focal, num_focal, reference_sets, reference_set_size, 2, 0, A); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS); + focal[0] = (tsk_id_t) tsk_treeseq_get_num_nodes(&ts) + 100; + ret = tsk_treeseq_genealogical_nearest_neighbours(&ts, + focal, num_focal, reference_sets, reference_set_size, 2, 0, A); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS); + + tsk_treeseq_free(&ts); + free(A); +} + +static void +test_tree_errors(void) +{ + int ret; + size_t j; + tsk_id_t num_nodes = 9; + tsk_id_t u; + tsk_node_t node; + tsk_treeseq_t ts, other_ts; + tsk_tree_t t, other_t; + tsk_id_t bad_nodes[] = {num_nodes, num_nodes + 1, -1}; + tsk_id_t tracked_samples[] = {0, 0, 0}; + + tsk_treeseq_from_text(&ts, 10, paper_ex_nodes, paper_ex_edges, NULL, NULL, NULL, + paper_ex_individuals, NULL); + + ret = tsk_tree_alloc(&t, NULL, 0); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_PARAM_VALUE); + ret = tsk_tree_alloc(&t, &ts, TSK_SAMPLE_COUNTS); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_tree_first(&t); + CU_ASSERT_EQUAL_FATAL(ret, 1); + + /* Out-of-bounds queries */ + for (j = 0; j < sizeof(bad_nodes) / sizeof(tsk_id_t); j++) { + u = bad_nodes[j]; + ret = tsk_tree_get_parent(&t, u, NULL); + CU_ASSERT_EQUAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS); + ret = tsk_tree_get_time(&t, u, NULL); + CU_ASSERT_EQUAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS); + ret = tsk_tree_get_mrca(&t, u, 0, NULL); + CU_ASSERT_EQUAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS); + ret = tsk_tree_get_mrca(&t, 0, u, NULL); + CU_ASSERT_EQUAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS); + ret = tsk_tree_get_num_samples(&t, u, NULL); + CU_ASSERT_EQUAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS); + ret = tsk_tree_get_num_tracked_samples(&t, u, NULL); + CU_ASSERT_EQUAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS); + /* Also check tree sequence methods */ + ret = tsk_treeseq_get_node(&ts, 
(tsk_tbl_size_t) u, &node); + CU_ASSERT_EQUAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS); + CU_ASSERT(!tsk_treeseq_is_sample(&ts, u)); + CU_ASSERT(!tsk_tree_is_sample(&t, u)); + } + + tracked_samples[0] = 0; + tracked_samples[1] = (tsk_id_t) tsk_treeseq_get_num_samples(&ts); + ret = tsk_tree_set_tracked_samples(&t, 2, tracked_samples); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_SAMPLES); + tracked_samples[1] = (tsk_id_t) tsk_treeseq_get_num_nodes(&ts); + ret = tsk_tree_set_tracked_samples(&t, 2, tracked_samples); + CU_ASSERT_EQUAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS); + tracked_samples[1] = 0; + ret = tsk_tree_set_tracked_samples(&t, 2, tracked_samples); + CU_ASSERT_EQUAL(ret, TSK_ERR_DUPLICATE_SAMPLE); + + tsk_treeseq_from_text(&other_ts, 10, paper_ex_nodes, paper_ex_edges, NULL, NULL, NULL, + paper_ex_individuals, NULL); + ret = tsk_tree_alloc(&other_t, &other_ts, 0); + CU_ASSERT_EQUAL(ret, 0); + ret = tsk_tree_copy(&t, &t); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_PARAM_VALUE); + ret = tsk_tree_copy(&t, &other_t); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_PARAM_VALUE); + + /* TODO run checks for the various unsupported operations with flags */ + + tsk_tree_free(&t); + tsk_tree_free(&other_t); + tsk_treeseq_free(&other_ts); + tsk_treeseq_free(&ts); +} + +static void +test_deduplicate_sites(void) +{ + int ret; + // Modified from paper_ex + const char *tidy_sites = + "1 0\n" + "4.5 0\n" + "8.5 0\n"; + const char *tidy_mutations = + "0 2 1\n" + "0 1 2\n" + "0 6 3\n" + "0 3 4\n" + "1 0 1\n" + "1 2 2\n" + "1 4 3\n" + "1 5 4\n" + "2 5 1\n" + "2 7 2\n" + "2 1 3\n" + "2 0 4\n"; + const char *messy_sites = + "1 0\n" + "1 0\n" + "1 0\n" + "1 0\n" + "4.5 0\n" + "4.5 0\n" + "4.5 0\n" + "4.5 0\n" + "8.5 0\n" + "8.5 0\n" + "8.5 0\n" + "8.5 0\n"; + const char *messy_mutations = + "0 2 1\n" + "1 1 2\n" + "2 6 3\n" + "3 3 4\n" + "4 0 1\n" + "5 2 2\n" + "6 4 3\n" + "7 5 4\n" + "8 5 1\n" + "9 7 2\n" + "10 1 3\n" + "11 0 4\n"; + tsk_tbl_collection_t tidy, messy; + + ret = tsk_tbl_collection_alloc(&tidy, 0); + 
CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_tbl_collection_alloc(&messy, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + messy.sequence_length = 10; + tidy.sequence_length = 10; + parse_individuals(paper_ex_individuals, tidy.individuals); + parse_nodes(paper_ex_nodes, tidy.nodes); + parse_sites(tidy_sites, tidy.sites); + parse_mutations(tidy_mutations, tidy.mutations); + // test cleaning doesn't mess up the tidy one + parse_individuals(paper_ex_individuals, messy.individuals); + parse_nodes(paper_ex_nodes, messy.nodes); + parse_sites(tidy_sites, messy.sites); + parse_mutations(tidy_mutations, messy.mutations); + + ret = tsk_tbl_collection_deduplicate_sites(&messy, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_site_tbl_equals(tidy.sites, messy.sites)); + CU_ASSERT_TRUE(tsk_mutation_tbl_equals(tidy.mutations, messy.mutations)); + + tsk_site_tbl_clear(messy.sites); + tsk_mutation_tbl_clear(messy.mutations); + + // test with the actual messy one + parse_sites(messy_sites, messy.sites); + parse_mutations(messy_mutations, messy.mutations); + + ret = tsk_tbl_collection_deduplicate_sites(&messy, 0); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_TRUE(tsk_site_tbl_equals(tidy.sites, messy.sites)); + CU_ASSERT_TRUE(tsk_mutation_tbl_equals(tidy.mutations, messy.mutations)); + + tsk_tbl_collection_free(&tidy); + tsk_tbl_collection_free(&messy); +} + +static void +test_deduplicate_sites_errors(void) +{ + int ret; + tsk_tbl_collection_t tables; + + ret = tsk_tbl_collection_alloc(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + tables.sequence_length = 10; + ret = tsk_site_tbl_add_row(tables.sites, 2, "A", 1, "m", 1); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_site_tbl_add_row(tables.sites, 2, "TT", 2, "MM", 2); + CU_ASSERT_EQUAL_FATAL(ret, 1); + ret = tsk_mutation_tbl_add_row(tables.mutations, 0, 0, -1, + "T", 1, NULL, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_node_tbl_add_row(tables.nodes, 0, 0, TSK_NULL, + TSK_NULL, NULL, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + /* 
Negative position */ + tables.sites->position[0] = -1; + ret = tsk_tbl_collection_deduplicate_sites(&tables, 0); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_SITE_POSITION); + tables.sites->position[0] = 2; + + /* unsorted position */ + tables.sites->position[1] = 0.5; + ret = tsk_tbl_collection_deduplicate_sites(&tables, 0); + CU_ASSERT_EQUAL(ret, TSK_ERR_UNSORTED_SITES); + tables.sites->position[1] = 2; + + /* negative site ID */ + tables.mutations->site[0] = -1; + ret = tsk_tbl_collection_deduplicate_sites(&tables, 0); + CU_ASSERT_EQUAL(ret, TSK_ERR_SITE_OUT_OF_BOUNDS); + tables.mutations->site[0] = 0; + + /* site ID out of bounds */ + tables.mutations->site[0] = 2; + ret = tsk_tbl_collection_deduplicate_sites(&tables, 0); + CU_ASSERT_EQUAL(ret, TSK_ERR_SITE_OUT_OF_BOUNDS); + tables.mutations->site[0] = 0; + + /* Bad offset in metadata */ + tables.sites->metadata_offset[0] = 2; + ret = tsk_tbl_collection_deduplicate_sites(&tables, 0); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_OFFSET); + tables.sites->metadata_offset[0] = 0; + + /* Bad length in metadata */ + tables.sites->metadata_offset[2] = 100; + ret = tsk_tbl_collection_deduplicate_sites(&tables, 0); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_OFFSET); + tables.sites->metadata_offset[2] = 3; + + /* Bad offset in ancestral_state */ + tables.sites->ancestral_state_offset[0] = 2; + ret = tsk_tbl_collection_deduplicate_sites(&tables, 0); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_OFFSET); + tables.sites->ancestral_state_offset[0] = 0; + + /* Bad length in ancestral_state */ + tables.sites->ancestral_state_offset[2] = 100; + ret = tsk_tbl_collection_deduplicate_sites(&tables, 0); + CU_ASSERT_EQUAL(ret, TSK_ERR_BAD_OFFSET); + tables.sites->ancestral_state_offset[2] = 3; + + ret = tsk_tbl_collection_deduplicate_sites(&tables, 0); + CU_ASSERT_EQUAL(ret, 0); + + tsk_tbl_collection_free(&tables); +} + +static void +test_deduplicate_sites_multichar(void) +{ + int ret; + tsk_tbl_collection_t tables; + + ret = tsk_tbl_collection_alloc(&tables, 0); + 
CU_ASSERT_EQUAL_FATAL(ret, 0); + + tables.sequence_length = 10; + ret = tsk_site_tbl_add_row(tables.sites, 0, "AA", 1, "M", 1); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_site_tbl_add_row(tables.sites, 0, "0", 1, NULL, 0); + CU_ASSERT_EQUAL_FATAL(ret, 1); + ret = tsk_site_tbl_add_row(tables.sites, 1, "BBBBB", 5, "NNNNN", 5); + CU_ASSERT_EQUAL_FATAL(ret, 2); + ret = tsk_site_tbl_add_row(tables.sites, 1, "0", 1, NULL, 0); + CU_ASSERT_EQUAL_FATAL(ret, 3); + + ret = tsk_tbl_collection_deduplicate_sites(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(tables.sites->num_rows, 2); + CU_ASSERT_EQUAL_FATAL(tables.sites->position[0], 0); + CU_ASSERT_EQUAL_FATAL(tables.sites->position[1], 1); + CU_ASSERT_EQUAL_FATAL(tables.sites->ancestral_state[0], 'A'); + CU_ASSERT_EQUAL_FATAL(tables.sites->ancestral_state_offset[1], 1); + CU_ASSERT_EQUAL_FATAL(tables.sites->metadata[0], 'M'); + CU_ASSERT_EQUAL_FATAL(tables.sites->metadata_offset[1], 1); + + CU_ASSERT_NSTRING_EQUAL(tables.sites->ancestral_state + 1, "BBBBB", 5); + CU_ASSERT_EQUAL_FATAL(tables.sites->ancestral_state_offset[2], 6); + CU_ASSERT_NSTRING_EQUAL(tables.sites->metadata + 1, "NNNNN", 5); + CU_ASSERT_EQUAL_FATAL(tables.sites->metadata_offset[2], 6); + + tsk_tbl_collection_free(&tables); +} + +static void +test_empty_tree_sequence(void) +{ + tsk_treeseq_t ts; + tsk_tbl_collection_t tables; + tsk_tree_t t; + tsk_id_t v; + int ret; + + ret = tsk_tbl_collection_alloc(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_treeseq_alloc(&ts, &tables, TSK_BUILD_INDEXES); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_SEQUENCE_LENGTH); + tsk_treeseq_free(&ts); + tables.sequence_length = 1.0; + ret = tsk_treeseq_alloc(&ts, &tables, TSK_BUILD_INDEXES); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + verify_empty_tree_sequence(&ts, 1.0); + + ret = tsk_tree_alloc(&t, &ts, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_tree_first(&t); + CU_ASSERT_EQUAL_FATAL(ret, 1); + CU_ASSERT_EQUAL_FATAL(t.left_root, TSK_NULL); + 
CU_ASSERT_EQUAL_FATAL(t.left, 0); + CU_ASSERT_EQUAL_FATAL(t.right, 1); + CU_ASSERT_EQUAL_FATAL(tsk_tree_get_parent(&t, 0, &v), TSK_ERR_NODE_OUT_OF_BOUNDS); + tsk_tree_free(&t); + + ret = tsk_tree_alloc(&t, &ts, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_tree_last(&t); + CU_ASSERT_EQUAL_FATAL(ret, 1); + CU_ASSERT_EQUAL_FATAL(t.left_root, TSK_NULL); + CU_ASSERT_EQUAL_FATAL(t.left, 0); + CU_ASSERT_EQUAL_FATAL(t.right, 1); + CU_ASSERT_EQUAL_FATAL(tsk_tree_get_parent(&t, 0, &v), TSK_ERR_NODE_OUT_OF_BOUNDS); + tsk_tree_free(&t); + + tsk_treeseq_free(&ts); + tsk_tbl_collection_free(&tables); +} + +static void +test_zero_edges(void) +{ + const char *nodes = + "1 0 0\n" + "1 0 0\n"; + const char *edges = ""; + const char *sites = + "0.1 0\n" + "0.2 0\n"; + const char *mutations = + "0 0 1\n" + "1 1 1\n"; + tsk_treeseq_t ts, tss; + tsk_tree_t t; + const char *haplotypes[] = {"10", "01"}; + char *haplotype; + tsk_hapgen_t hapgen; + unsigned int j; + tsk_id_t samples, node_map; + const tsk_id_t z = TSK_NULL; + tsk_id_t parents[] = { + z, z, + }; + int ret; + + tsk_treeseq_from_text(&ts, 2, nodes, edges, NULL, sites, mutations, NULL, NULL); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_samples(&ts), 2); + CU_ASSERT_EQUAL(tsk_treeseq_get_sequence_length(&ts), 2.0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_nodes(&ts), 2); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_sites(&ts), 2); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_mutations(&ts), 2); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_trees(&ts), 1); + tsk_treeseq_print_state(&ts, _devnull); + + verify_trees(&ts, 1, parents); + + ret = tsk_tree_alloc(&t, &ts, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_tree_first(&t); + CU_ASSERT_EQUAL_FATAL(ret, 1); + CU_ASSERT_EQUAL(t.left, 0); + CU_ASSERT_EQUAL(t.right, 2); + CU_ASSERT_EQUAL(t.parent[0], TSK_NULL); + CU_ASSERT_EQUAL(t.parent[1], TSK_NULL); + CU_ASSERT_EQUAL(t.left_root, 0); + CU_ASSERT_EQUAL(t.left_sib[0], TSK_NULL); + CU_ASSERT_EQUAL(t.right_sib[0], 1); + tsk_tree_print_state(&t, _devnull); 
+ tsk_tree_free(&t); + + ret = tsk_tree_alloc(&t, &ts, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_tree_last(&t); + CU_ASSERT_EQUAL_FATAL(ret, 1); + CU_ASSERT_EQUAL(t.left, 0); + CU_ASSERT_EQUAL(t.right, 2); + CU_ASSERT_EQUAL(t.parent[0], TSK_NULL); + CU_ASSERT_EQUAL(t.parent[1], TSK_NULL); + CU_ASSERT_EQUAL(t.left_root, 0); + CU_ASSERT_EQUAL(t.left_sib[0], TSK_NULL); + CU_ASSERT_EQUAL(t.right_sib[0], 1); + tsk_tree_print_state(&t, _devnull); + tsk_tree_free(&t); + + /* We give pointers ot samples and node_map here as they must be non null */ + ret = tsk_treeseq_simplify(&ts, &samples, 0, 0, &tss, &node_map); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_samples(&tss), 0); + CU_ASSERT_EQUAL(tsk_treeseq_get_sequence_length(&tss), 2.0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_nodes(&tss), 0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_sites(&tss), 2); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_mutations(&tss), 0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_trees(&tss), 1); + tsk_treeseq_print_state(&ts, _devnull); + + ret = tsk_hapgen_alloc(&hapgen, &ts); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tsk_hapgen_print_state(&hapgen, _devnull); + for (j = 0; j < 2; j++) { + ret = tsk_hapgen_get_haplotype(&hapgen, (tsk_id_t) j, &haplotype); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_STRING_EQUAL(haplotype, haplotypes[j]); + } + tsk_hapgen_free(&hapgen); + tsk_treeseq_free(&ts); + tsk_treeseq_free(&tss); +} + + + + +int +main(int argc, char **argv) +{ + CU_TestInfo tests[] = { + /* simplest example tests */ + {"test_simplest_records", test_simplest_records}, + {"test_simplest_nonbinary_records", test_simplest_nonbinary_records}, + {"test_simplest_unary_records", test_simplest_unary_records}, + {"test_simplest_non_sample_leaf_records", test_simplest_non_sample_leaf_records}, + {"test_simplest_degenerate_multiple_root_records", + test_simplest_degenerate_multiple_root_records}, + {"test_simplest_multiple_root_records", test_simplest_multiple_root_records}, + 
{"test_simplest_zero_root_tree", test_simplest_zero_root_tree}, + {"test_simplest_root_mutations", test_simplest_root_mutations}, + {"test_simplest_back_mutations", test_simplest_back_mutations}, + {"test_simplest_general_samples", test_simplest_general_samples}, + {"test_simplest_holey_tree_sequence", test_simplest_holey_tree_sequence}, + {"test_simplest_holey_tsk_treeseq_zero_roots", + test_simplest_holey_tsk_treeseq_zero_roots}, + {"test_simplest_holey_tsk_treeseq_mutation_parents", + test_simplest_holey_tsk_treeseq_mutation_parents}, + {"test_simplest_initial_gap_tree_sequence", test_simplest_initial_gap_tree_sequence}, + {"test_simplest_initial_gap_zero_roots", test_simplest_initial_gap_zero_roots}, + {"test_simplest_initial_gap_tsk_treeseq_mutation_parents", + test_simplest_initial_gap_tsk_treeseq_mutation_parents}, + {"test_simplest_final_gap_tree_sequence", test_simplest_final_gap_tree_sequence}, + {"test_simplest_final_gap_tsk_treeseq_mutation_parents", + test_simplest_final_gap_tsk_treeseq_mutation_parents}, + {"test_simplest_individuals", test_simplest_individuals}, + {"test_simplest_bad_individuals", test_simplest_bad_individuals}, + {"test_simplest_bad_edges", test_simplest_bad_edges}, + {"test_simplest_bad_indexes", test_simplest_bad_indexes}, + {"test_simplest_bad_migrations", test_simplest_bad_migrations}, + {"test_simplest_migration_simplify", test_simplest_migration_simplify}, + {"test_simplest_overlapping_parents", test_simplest_overlapping_parents}, + {"test_simplest_contradictory_children", test_simplest_contradictory_children}, + {"test_simplest_overlapping_edges_simplify", + test_simplest_overlapping_edges_simplify}, + {"test_simplest_overlapping_unary_edges_simplify", + test_simplest_overlapping_unary_edges_simplify}, + {"test_simplest_overlapping_unary_edges_internal_samples_simplify", + test_simplest_overlapping_unary_edges_internal_samples_simplify}, + {"test_simplest_reduce_site_topology", test_simplest_reduce_site_topology}, + 
{"test_simplest_population_filter", test_simplest_population_filter}, + {"test_simplest_individual_filter", test_simplest_individual_filter}, + + /* Single tree tests */ + {"test_single_tree_good_records", test_single_tree_good_records}, + {"test_single_nonbinary_tree_good_records", + test_single_nonbinary_tree_good_records}, + {"test_single_tree_bad_records", test_single_tree_bad_records}, + {"test_single_tree_good_mutations", test_single_tree_good_mutations}, + {"test_single_tree_bad_mutations", test_single_tree_bad_mutations}, + {"test_single_tree_iter", test_single_tree_iter}, + {"test_single_tree_general_samples_iter", test_single_tree_general_samples_iter}, + {"test_single_nonbinary_tree_iter", test_single_nonbinary_tree_iter}, + {"test_single_tree_iter_times", test_single_tree_iter_times}, + {"test_single_tree_simplify", test_single_tree_simplify}, + {"test_single_tree_compute_mutation_parents", test_single_tree_compute_mutation_parents}, + + /* Multi tree tests */ + {"test_simple_multi_tree", test_simple_multi_tree}, + {"test_nonbinary_multi_tree", test_nonbinary_multi_tree}, + {"test_unary_multi_tree", test_unary_multi_tree}, + {"test_internal_sample_multi_tree", test_internal_sample_multi_tree}, + {"test_internal_sample_simplified_multi_tree", + test_internal_sample_simplified_multi_tree}, + {"test_left_to_right_multi_tree", test_left_to_right_multi_tree}, + {"test_gappy_multi_tree", test_gappy_multi_tree}, + {"test_tsk_treeseq_bad_records", test_tsk_treeseq_bad_records}, + + /* Diff iter tests */ + {"test_simple_diff_iter", test_simple_diff_iter}, + {"test_nonbinary_diff_iter", test_nonbinary_diff_iter}, + {"test_unary_diff_iter", test_unary_diff_iter}, + {"test_internal_sample_diff_iter", test_internal_sample_diff_iter}, + + /* Sample sets */ + {"test_simple_sample_sets", test_simple_sample_sets}, + {"test_nonbinary_sample_sets", test_nonbinary_sample_sets}, + {"test_internal_sample_sample_sets", test_internal_sample_sample_sets}, + + /* Misc */ + 
{"test_tree_errors", test_tree_errors}, + {"test_genealogical_nearest_neighbours_errors", + test_genealogical_nearest_neighbours_errors}, + {"test_deduplicate_sites", test_deduplicate_sites}, + {"test_deduplicate_sites_errors", test_deduplicate_sites_errors}, + {"test_deduplicate_sites_multichar", test_deduplicate_sites_multichar}, + {"test_empty_tree_sequence", test_empty_tree_sequence}, + {"test_zero_edges", test_zero_edges}, + + {NULL}, + }; + + return test_main(tests, argc, argv); +} diff --git a/c/testlib.c b/c/testlib.c new file mode 100644 index 0000000000..4685748094 --- /dev/null +++ b/c/testlib.c @@ -0,0 +1,643 @@ +#define _GNU_SOURCE +#include +#include +#include + +#include "testlib.h" + +/* Simple single tree example. */ +const char *single_tree_ex_nodes =/* 6 */ + "1 0 -1 -1\n" /* / \ */ + "1 0 -1 -1\n" /* / \ */ + "1 0 -1 -1\n" /* / \ */ + "1 0 -1 -1\n" /* / 5 */ + "0 1 -1 -1\n" /* 4 / \ */ + "0 2 -1 -1\n" /* / \ / \ */ + "0 3 -1 -1\n"; /* 0 1 2 3 */ +const char *single_tree_ex_edges = + "0 1 4 0,1\n" + "0 1 5 2,3\n" + "0 1 6 4,5\n"; +const char *single_tree_ex_sites = + "0.1 0\n" + "0.2 0\n" + "0.3 0\n"; +const char *single_tree_ex_mutations = + "0 2 1 -1\n" + "1 4 1 -1\n" + "1 0 0 1\n" /* Back mutation over 0 */ + "2 0 1 -1\n" /* recurrent mutations over samples */ + "2 1 1 -1\n" + "2 2 1 -1\n" + "2 3 1 -1\n"; + +/* Example from the PLOS paper */ +const char *paper_ex_nodes = + "1 0 -1 0\n" + "1 0 -1 0\n" + "1 0 -1 1\n" + "1 0 -1 1\n" + "0 0.071 -1 -1\n" + "0 0.090 -1 -1\n" + "0 0.170 -1 -1\n" + "0 0.202 -1 -1\n" + "0 0.253 -1 -1\n"; +const char *paper_ex_edges = + "2 10 4 2\n" + "2 10 4 3\n" + "0 10 5 1\n" + "0 2 5 3\n" + "2 10 5 4\n" + "0 7 6 0,5\n" + "7 10 7 0,5\n" + "0 2 8 2,6\n"; +/* We make one mutation for each tree */ +const char *paper_ex_sites = + "1 0\n" + "4.5 0\n" + "8.5 0\n"; +const char *paper_ex_mutations = + "0 2 1\n" + "1 0 1\n" + "2 5 1\n"; +/* Two (diploid) indivduals */ +const char *paper_ex_individuals = + "0 0.2,1.5\n" + "0 
0.0,0.0\n"; + +/* An example of a nonbinary tree sequence */ +const char *nonbinary_ex_nodes = + "1 0 0 -1\n" + "1 0 0 -1\n" + "1 0 0 -1\n" + "1 0 0 -1\n" + "1 0 0 -1\n" + "1 0 0 -1\n" + "1 0 0 -1\n" + "1 0 0 -1\n" + "0 0.01 0 -1\n" + "0 0.068 0 -1\n" + "0 0.130 0 -1\n" + "0 0.279 0 -1\n" + "0 0.405 0 -1\n"; +const char *nonbinary_ex_edges = + "0 100 8 0,1,2,3\n" + "0 100 9 6,8\n" + "0 100 10 4\n" + "0 17 10 5\n" + "0 100 10 7\n" + "17 100 11 5,9\n" + "0 17 12 9\n" + "0 100 12 10\n" + "17 100 12 11"; +const char *nonbinary_ex_sites = + "1 0\n" + "18 0\n"; +const char *nonbinary_ex_mutations = + "0 2 1\n" + "1 11 1"; + +/* An example of a tree sequence with unary nodes. */ +const char *unary_ex_nodes = + "1 0 0 -1\n" + "1 0 0 -1\n" + "1 0 0 -1\n" + "1 0 0 -1\n" + "0 0.071 0 -1\n" + "0 0.090 0 -1\n" + "0 0.170 0 -1\n" + "0 0.202 0 -1\n" + "0 0.253 0 -1\n"; +const char *unary_ex_edges = + "2 10 4 2,3\n" + "0 10 5 1\n" + "0 2 5 3\n" + "2 10 5 4\n" + "0 7 6 0,5\n" + "7 10 7 0\n" + "0 2 7 2\n" + "7 10 7 5\n" + "0 7 8 6\n" + "0 2 8 7\n"; + +/* We make one mutation for each tree, over unary nodes if this exist */ +const char *unary_ex_sites = + "1.0 0\n" + "4.5 0\n" + "8.5 0\n"; +const char *unary_ex_mutations = + "0 2 1\n" + "1 6 1\n" + "2 5 1\n"; + +/* An example of a tree sequence with internally sampled nodes. 
*/ + +/* TODO: find a way to draw these side-by-side */ +/* + 7 ++-+-+ +| 5 +| +-++ +| | 4 +| | +++ +| | | 3 +| | | +| 1 2 +| +0 + + 8 ++-+-+ +| 5 +| +-++ +| | 4 +| | +++ +3 | | | + | | | + 1 2 | + | + 0 + + 6 ++-+-+ +| 5 +| +-++ +| | 4 +| | +++ +| | | 3 +| | | +| 1 2 +| +0 +*/ + +const char *internal_sample_ex_nodes = + "1 0.0 0 -1\n" + "1 0.1 0 -1\n" + "1 0.1 0 -1\n" + "1 0.2 0 -1\n" + "0 0.4 0 -1\n" + "1 0.5 0 -1\n" + "0 0.7 0 -1\n" + "0 1.0 0 -1\n" + "0 1.2 0 -1\n"; +const char *internal_sample_ex_edges = + "2 8 4 0\n" + "0 10 4 2\n" + "0 2 4 3\n" + "8 10 4 3\n" + "0 10 5 1,4\n" + "8 10 6 0,5\n" + "0 2 7 0,5\n" + "2 8 8 3,5\n"; +/* We make one mutation for each tree, some above the internal node */ +const char *internal_sample_ex_sites = + "1.0 0\n" + "4.5 0\n" + "8.5 0\n"; +const char *internal_sample_ex_mutations = + "0 2 1\n" + "1 5 1\n" + "2 5 1\n"; + + +/* Simple utilities to parse text so we can write declaritive + * tests. This is not intended as a robust general input mechanism. 
+ */ + +void +parse_nodes(const char *text, tsk_node_tbl_t *node_table) +{ + int ret; + size_t c, k; + size_t MAX_LINE = 1024; + char line[MAX_LINE]; + const char *whitespace = " \t"; + char *p; + double time; + int flags, population, individual; + char *name; + + c = 0; + while (text[c] != '\0') { + /* Fill in the line */ + k = 0; + while (text[c] != '\n' && text[c] != '\0') { + CU_ASSERT_FATAL(k < MAX_LINE - 1); + line[k] = text[c]; + c++; + k++; + } + if (text[c] == '\n') { + c++; + } + line[k] = '\0'; + p = strtok(line, whitespace); + CU_ASSERT_FATAL(p != NULL); + flags = atoi(p); + p = strtok(NULL, whitespace); + CU_ASSERT_FATAL(p != NULL); + time = atof(p); + p = strtok(NULL, whitespace); + CU_ASSERT_FATAL(p != NULL); + population = atoi(p); + p = strtok(NULL, whitespace); + if (p == NULL) { + individual = -1; + } else { + individual = atoi(p); + p = strtok(NULL, whitespace); + } + if (p == NULL) { + name = ""; + } else { + name = p; + } + ret = tsk_node_tbl_add_row(node_table, flags, time, population, + individual, name, strlen(name)); + CU_ASSERT_FATAL(ret >= 0); + } +} + +void +parse_edges(const char *text, tsk_edge_tbl_t *edge_table) +{ + int ret; + size_t c, k; + size_t MAX_LINE = 1024; + char line[MAX_LINE], sub_line[MAX_LINE]; + const char *whitespace = " \t"; + char *p, *q; + double left, right; + tsk_id_t parent, child; + uint32_t num_children; + + c = 0; + while (text[c] != '\0') { + /* Fill in the line */ + k = 0; + while (text[c] != '\n' && text[c] != '\0') { + CU_ASSERT_FATAL(k < MAX_LINE - 1); + line[k] = text[c]; + c++; + k++; + } + if (text[c] == '\n') { + c++; + } + line[k] = '\0'; + p = strtok(line, whitespace); + CU_ASSERT_FATAL(p != NULL); + left = atof(p); + p = strtok(NULL, whitespace); + CU_ASSERT_FATAL(p != NULL); + right = atof(p); + p = strtok(NULL, whitespace); + CU_ASSERT_FATAL(p != NULL); + parent = atoi(p); + num_children = 0; + p = strtok(NULL, whitespace); + CU_ASSERT_FATAL(p != NULL); + + num_children = 1; + q = p; + while (*q 
!= '\0') { + if (*q == ',') { + num_children++; + } + q++; + } + CU_ASSERT_FATAL(num_children >= 1); + strncpy(sub_line, p, MAX_LINE); + q = strtok(sub_line, ","); + for (k = 0; k < num_children; k++) { + CU_ASSERT_FATAL(q != NULL); + child = atoi(q); + ret = tsk_edge_tbl_add_row(edge_table, left, right, parent, child); + CU_ASSERT_FATAL(ret >= 0); + q = strtok(NULL, ","); + } + CU_ASSERT_FATAL(q == NULL); + } +} + +void +parse_sites(const char *text, tsk_site_tbl_t *site_table) +{ + int ret; + size_t c, k; + size_t MAX_LINE = 1024; + char line[MAX_LINE]; + double position; + char ancestral_state[MAX_LINE]; + const char *whitespace = " \t"; + char *p; + + c = 0; + while (text[c] != '\0') { + /* Fill in the line */ + k = 0; + while (text[c] != '\n' && text[c] != '\0') { + CU_ASSERT_FATAL(k < MAX_LINE - 1); + line[k] = text[c]; + c++; + k++; + } + if (text[c] == '\n') { + c++; + } + line[k] = '\0'; + p = strtok(line, whitespace); + CU_ASSERT_FATAL(p != NULL); + position = atof(p); + p = strtok(NULL, whitespace); + CU_ASSERT_FATAL(p != NULL); + strncpy(ancestral_state, p, MAX_LINE); + ret = tsk_site_tbl_add_row(site_table, position, ancestral_state, + strlen(ancestral_state), NULL, 0); + CU_ASSERT_FATAL(ret >= 0); + } +} + +void +parse_mutations(const char *text, tsk_mutation_tbl_t *mutation_table) +{ + int ret; + size_t c, k; + size_t MAX_LINE = 1024; + char line[MAX_LINE]; + const char *whitespace = " \t"; + char *p; + tsk_id_t node, site, parent; + char derived_state[MAX_LINE]; + + c = 0; + while (text[c] != '\0') { + /* Fill in the line */ + k = 0; + while (text[c] != '\n' && text[c] != '\0') { + CU_ASSERT_FATAL(k < MAX_LINE - 1); + line[k] = text[c]; + c++; + k++; + } + if (text[c] == '\n') { + c++; + } + line[k] = '\0'; + p = strtok(line, whitespace); + site = atoi(p); + CU_ASSERT_FATAL(p != NULL); + p = strtok(NULL, whitespace); + CU_ASSERT_FATAL(p != NULL); + node = atoi(p); + p = strtok(NULL, whitespace); + CU_ASSERT_FATAL(p != NULL); + strncpy(derived_state, 
p, MAX_LINE); + parent = TSK_NULL; + p = strtok(NULL, whitespace); + if (p != NULL) { + parent = atoi(p); + } + ret = tsk_mutation_tbl_add_row(mutation_table, site, node, parent, + derived_state, strlen(derived_state), NULL, 0); + CU_ASSERT_FATAL(ret >= 0); + } +} + +void +parse_individuals(const char *text, tsk_individual_tbl_t *individual_table) +{ + int ret; + size_t c, k; + size_t MAX_LINE = 1024; + char line[MAX_LINE]; + char sub_line[MAX_LINE]; + const char *whitespace = " \t"; + char *p, *q; + double location[MAX_LINE]; + int location_len; + int flags; + char *name; + + c = 0; + while (text[c] != '\0') { + /* Fill in the line */ + k = 0; + while (text[c] != '\n' && text[c] != '\0') { + CU_ASSERT_FATAL(k < MAX_LINE - 1); + line[k] = text[c]; + c++; + k++; + } + if (text[c] == '\n') { + c++; + } + line[k] = '\0'; + p = strtok(line, whitespace); + CU_ASSERT_FATAL(p != NULL); + flags = atoi(p); + + p = strtok(NULL, whitespace); + CU_ASSERT_FATAL(p != NULL); + // the locations are comma-separated + location_len = 1; + q = p; + while (*q != '\0') { + if (*q == ',') { + location_len++; + } + q++; + } + CU_ASSERT_FATAL(location_len >= 1); + strncpy(sub_line, p, MAX_LINE); + q = strtok(sub_line, ","); + for (k = 0; k < location_len; k++) { + CU_ASSERT_FATAL(q != NULL); + location[k] = atof(q); + q = strtok(NULL, ","); + } + CU_ASSERT_FATAL(q == NULL); + p = strtok(NULL, whitespace); + if (p == NULL) { + name = ""; + } else { + name = p; + } + ret = tsk_individual_tbl_add_row(individual_table, flags, location, location_len, + name, strlen(name)); + CU_ASSERT_FATAL(ret >= 0); + } +} + +void +tsk_treeseq_from_text(tsk_treeseq_t *ts, double sequence_length, + const char *nodes, const char *edges, + const char *migrations, const char *sites, const char *mutations, + const char *individuals, const char *provenance) +{ + int ret; + tsk_tbl_collection_t tables; + tsk_id_t max_population_id; + tsk_tbl_size_t j; + + CU_ASSERT_FATAL(ts != NULL); + CU_ASSERT_FATAL(nodes != 
NULL); + CU_ASSERT_FATAL(edges != NULL); + /* Not supporting provenance here for now */ + CU_ASSERT_FATAL(provenance == NULL); + + ret = tsk_tbl_collection_alloc(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tables.sequence_length = sequence_length; + parse_nodes(nodes, tables.nodes); + parse_edges(edges, tables.edges); + if (sites != NULL) { + parse_sites(sites, tables.sites); + } + if (mutations != NULL) { + parse_mutations(mutations, tables.mutations); + } + if (individuals != NULL) { + parse_individuals(individuals, tables.individuals); + } + /* We need to add in populations if they are referenced */ + max_population_id = -1; + for (j = 0; j < tables.nodes->num_rows; j++) { + max_population_id = TSK_MAX(max_population_id, tables.nodes->population[j]); + } + if (max_population_id >= 0) { + for (j = 0; j <= (tsk_tbl_size_t) max_population_id; j++) { + ret = tsk_population_tbl_add_row(tables.populations, NULL, 0); + CU_ASSERT_EQUAL_FATAL(ret, j); + } + } + + ret = tsk_treeseq_alloc(ts, &tables, TSK_BUILD_INDEXES); + /* tsk_treeseq_print_state(ts, stdout); */ + /* printf("ret = %s\n", tsk_strerror(ret)); */ + CU_ASSERT_EQUAL_FATAL(ret, 0); + tsk_tbl_collection_free(&tables); +} + +void +unsort_edges(tsk_edge_tbl_t *edges, size_t start) +{ + size_t j, k; + size_t n = edges->num_rows - start; + tsk_edge_t *buff = malloc(n * sizeof(tsk_edge_t)); + CU_ASSERT_FATAL(buff != NULL); + + for (j = 0; j < n; j++) { + k = start + j; + buff[j].left = edges->left[k]; + buff[j].right = edges->right[k]; + buff[j].parent = edges->parent[k]; + buff[j].child = edges->child[k]; + } + for (j = 0; j < n; j++) { + k = start + j; + edges->left[k] = buff[n - j - 1].left; + edges->right[k] = buff[n - j - 1].right; + edges->parent[k] = buff[n - j - 1].parent; + edges->child[k] = buff[n - j - 1].child; + } + free(buff); +} + +static int +tskit_suite_init(void) +{ + int fd = -1; + static char template[] = "/tmp/tsk_c_test_XXXXXX"; + + _tmp_file_name = NULL; + _devnull = NULL; + + 
_tmp_file_name = malloc(sizeof(template)); + if (_tmp_file_name == NULL) { + return CUE_NOMEMORY; + } + strcpy(_tmp_file_name, template); + fd = mkstemp(_tmp_file_name); + if (fd == -1) { + return CUE_SINIT_FAILED; + } + close(fd); + _devnull = fopen("/dev/null", "w"); + if (_devnull == NULL) { + return CUE_SINIT_FAILED; + } + return CUE_SUCCESS; +} + +static int +tskit_suite_cleanup(void) +{ + if (_tmp_file_name != NULL) { + unlink(_tmp_file_name); + free(_tmp_file_name); + } + if (_devnull != NULL) { + fclose(_devnull); + } + return CUE_SUCCESS; +} + +static void +handle_cunit_error(void) +{ + fprintf(stderr, "CUnit error occured: %d: %s\n", CU_get_error(), CU_get_error_msg()); + exit(EXIT_FAILURE); +} + +int +test_main(CU_TestInfo *tests, int argc, char **argv) +{ + int ret; + CU_pTest test; + CU_pSuite suite; + CU_SuiteInfo suites[] = { + { + .pName = "tskit", + .pInitFunc = tskit_suite_init, + .pCleanupFunc = tskit_suite_cleanup, + .pTests = tests, + }, + CU_SUITE_INFO_NULL, + }; + if (CUE_SUCCESS != CU_initialize_registry()) { + handle_cunit_error(); + } + if (CUE_SUCCESS != CU_register_suites(suites)) { + handle_cunit_error(); + } + CU_basic_set_mode(CU_BRM_VERBOSE); + + if (argc == 1) { + CU_basic_run_tests(); + } else if (argc == 2) { + suite = CU_get_suite_by_name("tskit", CU_get_registry()); + if (suite == NULL) { + printf("Suite not found\n"); + return EXIT_FAILURE; + } + test = CU_get_test_by_name(argv[1], suite); + if (test == NULL) { + printf("Test '%s' not found\n", argv[1]); + return EXIT_FAILURE; + } + CU_basic_run_test(suite, test); + } else { + printf("usage: %s \n", argv[0]); + return EXIT_FAILURE; + } + + ret = EXIT_SUCCESS; + if (CU_get_number_of_tests_failed() != 0) { + printf("Test failed!\n"); + ret = EXIT_FAILURE; + } + CU_cleanup_registry(); + return ret; +} diff --git a/c/testlib.h b/c/testlib.h new file mode 100644 index 0000000000..e4fe301f93 --- /dev/null +++ b/c/testlib.h @@ -0,0 +1,56 @@ +#ifndef __TESTLIB_H__ +#define 
__TESTLIB_H__ + +#include + +#include +#include "tsk_trees.h" + +/* Global variables used in the test suite */ + +char * _tmp_file_name; +FILE * _devnull; + +int test_main(CU_TestInfo *tests, int argc, char **argv); + +void tsk_treeseq_from_text(tsk_treeseq_t *ts, + double sequence_length, + const char *nodes, const char *edges, const char *migrations, + const char *sites, const char *mutations, + const char *individuals, const char *provenance); + +void parse_nodes(const char *text, tsk_node_tbl_t *node_table); +void parse_edges(const char *text, tsk_edge_tbl_t *edge_table); +void parse_sites(const char *text, tsk_site_tbl_t *site_table); +void parse_mutations(const char *text, tsk_mutation_tbl_t *mutation_table); +void parse_individuals(const char *text, tsk_individual_tbl_t *individual_table); + +void unsort_edges(tsk_edge_tbl_t *edges, size_t start); + +extern const char *single_tree_ex_nodes; +extern const char *single_tree_ex_edges; +extern const char *single_tree_ex_sites; +extern const char *single_tree_ex_mutations; + +extern const char *nonbinary_ex_nodes; +extern const char *nonbinary_ex_edges; +extern const char *nonbinary_ex_sites; +extern const char *nonbinary_ex_mutations; + +extern const char *unary_ex_nodes; +extern const char *unary_ex_edges; +extern const char *unary_ex_sites; +extern const char *unary_ex_mutations; + +extern const char *internal_sample_ex_nodes; +extern const char *internal_sample_ex_edges; +extern const char *internal_sample_ex_sites; +extern const char *internal_sample_ex_mutations; + +extern const char *paper_ex_nodes; +extern const char *paper_ex_edges; +extern const char *paper_ex_sites; +extern const char *paper_ex_mutations; +extern const char *paper_ex_individuals; + +#endif diff --git a/c/tsk_convert.c b/c/tsk_convert.c new file mode 100644 index 0000000000..e44e475ce2 --- /dev/null +++ b/c/tsk_convert.c @@ -0,0 +1,454 @@ +#include +#include +#include +#include +#include +#include + +#include "tsk_convert.h" + +/* If we 
want the tskit library version embedded in the output, we need to + * define it at compile time. */ +/* TODO need to refine this a bit for embedded applications. */ +#ifndef TSK_LIBRARY_VERSION_STR +#define TSK_LIBRARY_VERSION_STR "undefined" +#endif + +/* ======================================================== * + * Newick output. + * ======================================================== */ + +/* This infrastructure is left-over from an earlier more complex version + * of this algorithm that worked over a tree sequence and cached the newick + * subtrees, updating according to diffs. It's unclear whether this complexity + * was of any real-world use, since newick output for large trees is pretty + * pointless. */ + +typedef struct { + size_t precision; + int flags; + char *newick; + tsk_tree_t *tree; +} tsk_newick_converter_t; + +static int +tsk_newick_converter_run(tsk_newick_converter_t *self, tsk_id_t root, + size_t buffer_size, char *buffer) +{ + int ret = TSK_ERR_GENERIC; + tsk_tree_t *tree = self->tree; + tsk_id_t *stack = self->tree->stack1; + const double *time = self->tree->tree_sequence->tables->nodes->time; + int stack_top = 0; + int label; + size_t s = 0; + int r; + tsk_id_t u, v, w, root_parent; + double branch_length; + + if (root < 0 || root >= (tsk_id_t) self->tree->num_nodes) { + ret = TSK_ERR_NODE_OUT_OF_BOUNDS; + goto out; + } + if (buffer == NULL) { + ret = TSK_ERR_BAD_PARAM_VALUE; + goto out; + } + root_parent = tree->parent[root]; + stack[0] = root; + u = root_parent; + while (stack_top >= 0) { + v = stack[stack_top]; + if (tree->left_child[v] != TSK_NULL && v != u) { + if (s >= buffer_size) { + ret = TSK_ERR_BUFFER_OVERFLOW; + goto out; + } + buffer[s] = '('; + s++; + for (w = tree->right_child[v]; w != TSK_NULL; w = tree->left_sib[w]) { + stack_top++; + stack[stack_top] = w; + } + } else { + u = tree->parent[v]; + stack_top--; + if (tree->left_child[v] == TSK_NULL) { + if (s >= buffer_size) { + ret = TSK_ERR_BUFFER_OVERFLOW; + goto out; 
+ } + /* We do this for ms-compatability. This should be a configurable option + * via the flags attribute */ + label = v + 1; + r = snprintf(buffer + s, buffer_size - s, "%d", label); + if (r < 0) { + ret = TSK_ERR_IO; + goto out; + } + s += (size_t) r; + if (s >= buffer_size) { + ret = TSK_ERR_BUFFER_OVERFLOW; + goto out; + } + } + if (u != root_parent) { + branch_length = (time[u] - time[v]); + r = snprintf(buffer + s, buffer_size - s, ":%.*f", (int) self->precision, + branch_length); + if (r < 0) { + ret = TSK_ERR_IO; + goto out; + } + s += (size_t) r; + if (s >= buffer_size) { + ret = TSK_ERR_BUFFER_OVERFLOW; + goto out; + } + if (v == tree->right_child[u]) { + buffer[s] = ')'; + } else { + buffer[s] = ','; + } + s++; + } + } + } + if ((s + 1) >= buffer_size) { + ret = TSK_ERR_BUFFER_OVERFLOW; + goto out; + } + buffer[s] = ';'; + buffer[s + 1] = '\0'; + ret = 0; +out: + return ret; +} + +static int +tsk_newick_converter_alloc(tsk_newick_converter_t *self, tsk_tree_t *tree, + size_t precision, int flags) +{ + int ret = 0; + + memset(self, 0, sizeof(tsk_newick_converter_t)); + self->precision = precision; + self->flags = flags; + self->tree = tree; + return ret; +} + +static int +tsk_newick_converter_free(tsk_newick_converter_t *TSK_UNUSED(self)) +{ + return 0; +} + +int +tsk_convert_newick(tsk_tree_t *tree, tsk_id_t root, size_t precision, int flags, + size_t buffer_size, char *buffer) +{ + int ret = 0; + tsk_newick_converter_t nc; + + ret = tsk_newick_converter_alloc(&nc, tree, precision, flags); + if (ret != 0) { + goto out; + } + ret = tsk_newick_converter_run(&nc, root, buffer_size, buffer); +out: + tsk_newick_converter_free(&nc); + return ret; +} + +/* ======================================================== * + * VCF conversion. 
+ * ======================================================== */ + +void +tsk_vcf_converter_print_state(tsk_vcf_converter_t *self, FILE* out) +{ + fprintf(out, "VCF converter state\n"); + fprintf(out, "ploidy = %d\n", self->ploidy); + fprintf(out, "num_samples = %d\n", (int) self->num_samples); + fprintf(out, "contig_length = %lu\n", self->contig_length); + fprintf(out, "num_vcf_samples = %d\n", (int) self->num_vcf_samples); + fprintf(out, "header = %d bytes\n", (int) strlen(self->header)); + fprintf(out, "vcf_genotypes = %d bytes: %s", (int) self->vcf_genotypes_size, + self->vcf_genotypes); + fprintf(out, "record = %d bytes\n", (int) self->record_size); +} + +static int TSK_WARN_UNUSED +tsk_vcf_converter_make_header(tsk_vcf_converter_t *self, const char *contig_id) +{ + int ret = TSK_ERR_GENERIC; + const char *header_prefix_template = + "##fileformat=VCFv4.2\n" + "##source=msprime " TSK_LIBRARY_VERSION_STR "\n" + "##FILTER=\n" + "##contig=\n" + "##FORMAT=\n" + "#CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT"; + char* header_prefix = NULL; + const char *sample_pattern = "\tmsp_%d"; + size_t buffer_size, offset; + uint32_t j; + int written; + + written = snprintf(NULL, 0, header_prefix_template, contig_id, self->contig_length); + if (written < 0) { + ret = TSK_ERR_IO; + goto out; + } + buffer_size = (size_t) written + 1; + header_prefix = malloc(buffer_size); + if (header_prefix == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + written = snprintf(header_prefix, buffer_size, header_prefix_template, + contig_id, self->contig_length); + if (written < 0) { + ret = TSK_ERR_IO; + goto out; + } + offset = buffer_size - 1; + for (j = 0; j < self->num_vcf_samples; j++) { + written = snprintf(NULL, 0, sample_pattern, j); + if (written < 0) { + ret = TSK_ERR_IO; + goto out; + } + buffer_size += (size_t) written; + } + buffer_size += 1; /* make room for \n */ + self->header = malloc(buffer_size); + if (self->header == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + 
} + memcpy(self->header, header_prefix, offset); + for (j = 0; j < self->num_vcf_samples; j++) { + written = snprintf(self->header + offset, buffer_size - offset, + sample_pattern, j); + if (written < 0) { + ret = TSK_ERR_IO; + goto out; + } + offset += (size_t) written; + assert(offset < buffer_size); + } + self->header[buffer_size - 2] = '\n'; + self->header[buffer_size - 1] = '\0'; + ret = 0; +out: + if (header_prefix != NULL) { + free(header_prefix); + } + return ret; +} + +static int TSK_WARN_UNUSED +tsk_vcf_converter_make_record(tsk_vcf_converter_t *self, const char *contig_id) +{ + int ret = TSK_ERR_GENERIC; + unsigned int ploidy = self->ploidy; + size_t n = self->num_vcf_samples; + size_t j, k; + + self->vcf_genotypes_size = 2 * self->num_samples + 1; + /* it's not worth working out exactly what size the record prefix + * will be. 1K is plenty for us */ + self->record_size = 1024 + self->contig_id_size + self->vcf_genotypes_size; + self->record = malloc(self->record_size); + self->vcf_genotypes = malloc(self->vcf_genotypes_size); + self->genotypes = malloc(self->num_samples * sizeof(char)); + if (self->record == NULL || self->vcf_genotypes == NULL + || self->genotypes == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + memcpy(self->record, contig_id, self->contig_id_size); + /* Set up the vcf_genotypes string. We don't want to have to put + * in tabs and |s for every row so we insert them at the start. 
+ */ + for (j = 0; j < n; j++) { + for (k = 0; k < ploidy; k++) { + self->vcf_genotypes[2 * ploidy * j + 2 * k] = '0'; + self->vcf_genotypes[2 * ploidy * j + 2 * k + 1] = '|'; + } + self->vcf_genotypes[2 * ploidy * (j + 1) - 1] = '\t'; + } + self->vcf_genotypes[self->vcf_genotypes_size - 2] = '\n'; + self->vcf_genotypes[self->vcf_genotypes_size - 1] = '\0'; + ret = 0; +out: + return ret; +} + +static int TSK_WARN_UNUSED +tsk_vcf_converter_write_record(tsk_vcf_converter_t *self, tsk_variant_t *variant) +{ + int ret = TSK_ERR_GENERIC; + int written; + uint32_t j, k; + size_t offset; + unsigned int p = self->ploidy; + /* TODO update this to use "%.*s", len, alleles[0] etc to write out the + * alleles properly. */ + const char *template = "\t%lu\t.\tA\tT\t.\tPASS\t.\tGT\t"; + unsigned long pos = self->positions[variant->site->id]; + + /* CHROM was written at init time as it is constant */ + written = snprintf(self->record + self->contig_id_size, + self->record_size - self->contig_id_size, template, pos); + if (written < 0) { + ret = TSK_ERR_IO; + goto out; + } + offset = self->contig_id_size + (size_t) written; + + for (j = 0; j < self->num_vcf_samples; j++) { + for (k = 0; k < p; k++) { + self->vcf_genotypes[2 * p * j + 2 * k] = + (char) ('0' + variant->genotypes.u8[j * p + k]); + } + } + assert(offset + self->vcf_genotypes_size < self->record_size); + memcpy(self->record + offset, self->vcf_genotypes, self->vcf_genotypes_size); + ret = 0; +out: + return ret; +} + +static int TSK_WARN_UNUSED +tsk_vcf_converter_convert_positions(tsk_vcf_converter_t *self, tsk_treeseq_t *tree_sequence) +{ + int ret = 0; + unsigned long pos; + tsk_site_t site; + /* VCF is 1-based, so we must make sure we never have a 0 coordinate */ + unsigned long last_position = 0; + size_t j; + + for (j = 0; j < self->num_sites; j++) { + ret = tsk_treeseq_get_site(tree_sequence, j, &site); + if (ret != 0) { + goto out; + } + /* FIXME: we shouldn't be doing this. 
Round to the nearest integer + * instead. https://github.com/tskit-dev/tskit/issues/2 */ + + /* update pos. We use a simple algorithm to ensure positions + * are unique. */ + pos = (unsigned long) round(site.position); + if (pos <= last_position) { + pos = last_position + 1; + } + last_position = pos; + self->positions[j] = pos; + } +out: + return ret; +} + +int TSK_WARN_UNUSED +tsk_vcf_converter_get_header(tsk_vcf_converter_t *self, char **header) +{ + *header = self->header; + return 0; +} + +int TSK_WARN_UNUSED +tsk_vcf_converter_next(tsk_vcf_converter_t *self, char **record) +{ + int ret = -1; + int err; + tsk_variant_t *variant; + + ret = tsk_vargen_next(self->vargen, &variant); + if (ret < 0) { + goto out; + } + if (ret == 1) { + err = tsk_vcf_converter_write_record(self, variant); + if (err != 0) { + ret = err; + goto out; + } + *record = self->record; + } +out: + return ret; +} + +int TSK_WARN_UNUSED +tsk_vcf_converter_alloc(tsk_vcf_converter_t *self, + tsk_treeseq_t *tree_sequence, unsigned int ploidy, const char *contig_id) +{ + int ret = -1; + + memset(self, 0, sizeof(tsk_vcf_converter_t)); + self->ploidy = ploidy; + self->contig_id_size = strlen(contig_id); + self->num_samples = tsk_treeseq_get_num_samples(tree_sequence); + if (ploidy < 1 || self->num_samples % ploidy != 0) { + ret = TSK_ERR_BAD_PARAM_VALUE; + goto out; + } + self->num_vcf_samples = self->num_samples / self->ploidy; + self->vargen = malloc(sizeof(tsk_vargen_t)); + if (self->vargen == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + ret = tsk_vargen_alloc(self->vargen, tree_sequence, NULL, 0, 0); + if (ret != 0) { + goto out; + } + self->num_sites = tsk_treeseq_get_num_sites(tree_sequence); + self->positions = malloc(self->num_sites * sizeof(unsigned long)); + if (self->positions == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + ret = tsk_vcf_converter_convert_positions(self, tree_sequence); + if (ret != 0) { + goto out; + } + self->contig_length = + (unsigned long) 
round(tsk_treeseq_get_sequence_length(tree_sequence)); + if (self->num_sites > 0) { + self->contig_length = TSK_MAX( + self->contig_length, + self->positions[self->num_sites - 1]); + } + ret = tsk_vcf_converter_make_header(self, contig_id); + if (ret != 0) { + goto out; + } + if (tsk_treeseq_get_num_edges(tree_sequence) > 0) { + ret = tsk_vcf_converter_make_record(self, contig_id); + if (ret != 0) { + goto out; + } + } +out: + return ret; +} + +int +tsk_vcf_converter_free(tsk_vcf_converter_t *self) +{ + tsk_safe_free(self->genotypes); + tsk_safe_free(self->header); + tsk_safe_free(self->vcf_genotypes); + tsk_safe_free(self->record); + tsk_safe_free(self->positions); + if (self->vargen != NULL) { + tsk_vargen_free(self->vargen); + free(self->vargen); + } + return 0; +} diff --git a/c/tsk_convert.h b/c/tsk_convert.h new file mode 100644 index 0000000000..520f8b236f --- /dev/null +++ b/c/tsk_convert.h @@ -0,0 +1,44 @@ +#ifndef TSK_CONVERT_H +#define TSK_CONVERT_H + +#ifdef __cplusplus +extern "C" { +#endif + +#include "tsk_genotypes.h" + +/* TODO do we really need to expose this or would a simpler function be + * more appropriate? Depends on how we use it at the Python level probably. 
*/ + +typedef struct { + size_t num_samples; + size_t num_vcf_samples; + unsigned int ploidy; + char *genotypes; + char *header; + char *record; + char *vcf_genotypes; + size_t vcf_genotypes_size; + size_t contig_id_size; + size_t record_size; + size_t num_sites; + unsigned long contig_length; + unsigned long *positions; + tsk_vargen_t *vargen; +} tsk_vcf_converter_t; + +int tsk_vcf_converter_alloc(tsk_vcf_converter_t *self, + tsk_treeseq_t *tree_sequence, unsigned int ploidy, const char *chrom); +int tsk_vcf_converter_get_header(tsk_vcf_converter_t *self, char **header); +int tsk_vcf_converter_next(tsk_vcf_converter_t *self, char **record); +int tsk_vcf_converter_free(tsk_vcf_converter_t *self); +void tsk_vcf_converter_print_state(tsk_vcf_converter_t *self, FILE *out); + + +int tsk_convert_newick(tsk_tree_t *tree, tsk_id_t root, size_t precision, int flags, + size_t buffer_size, char *buffer); + +#ifdef __cplusplus +} +#endif +#endif diff --git a/c/tsk_core.c b/c/tsk_core.c new file mode 100644 index 0000000000..b4a39b5fa6 --- /dev/null +++ b/c/tsk_core.c @@ -0,0 +1,448 @@ +#include +#include +#include +#include +#include + +#include +#include "tsk_core.h" + +#define UUID_NUM_BYTES 16 + +#if defined(_WIN32) + +#include +#include + +static int TSK_WARN_UNUSED +get_random_bytes(uint8_t *buf) +{ + /* Based on CPython's code in bootstrap_hash.c */ + int ret = TSK_ERR_GENERATE_UUID; + HCRYPTPROV hCryptProv = NULL; + + if (!CryptAcquireContext(&hCryptProv, NULL, NULL, PROV_RSA_FULL, CRYPT_VERIFYCONTEXT)) { + goto out; + } + if (!CryptGenRandom(hCryptProv, (DWORD) UUID_NUM_BYTES, buf)) { + goto out; + } + if (!CryptReleaseContext(hCryptProv, 0)) { + hCryptProv = NULL; + goto out; + } + hCryptProv = NULL; + ret = 0; +out: + if (hCryptProv != NULL) { + CryptReleaseContext(hCryptProv, 0); + } + return ret; +} + +#else + +/* Assuming the existance of /dev/urandom on Unix platforms */ +static int TSK_WARN_UNUSED +get_random_bytes(uint8_t *buf) +{ + int ret = 
TSK_ERR_GENERATE_UUID; + FILE *f = fopen("/dev/urandom", "r"); + + if (f == NULL) { + goto out; + } + if (fread(buf, UUID_NUM_BYTES, 1, f) != 1) { + goto out; + } + if (fclose(f) != 0) { + goto out; + } + ret = 0; +out: + return ret; +} + +#endif + +/* Generate a new UUID4 using a system-generated source of randomness. + * Note that this function writes a NULL terminator to the end of this + * string, so that the total length of the buffer must be 37 bytes. + */ +int +tsk_generate_uuid(char *dest, int TSK_UNUSED(flags)) +{ + int ret = 0; + uint8_t buf[UUID_NUM_BYTES]; + const char *pattern = + "%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x"; + + ret = get_random_bytes(buf); + if (ret != 0) { + goto out; + } + if (snprintf(dest, TSK_UUID_SIZE + 1, pattern, + buf[0], buf[1], buf[2], buf[3], + buf[4], buf[5], buf[6], buf[7], + buf[8], buf[9], buf[10], buf[11], + buf[12], buf[13], buf[14], buf[15]) < 0) { + ret = TSK_ERR_GENERATE_UUID; + goto out; + } +out: + return ret; +} +static const char * +tsk_strerror_internal(int err) +{ + const char *ret = "Unknown error"; + + switch (err) { + case 0: + ret = "Normal exit condition. 
This is not an error!"; + goto out; + + /* General errors */ + case TSK_ERR_GENERIC: + ret = "Generic error; please file a bug report"; + break; + case TSK_ERR_NO_MEMORY: + ret = "Out of memory."; + break; + case TSK_ERR_IO: + if (errno != 0) { + ret = strerror(errno); + } else { + ret = "Unspecified IO error"; + } + break; + case TSK_ERR_BAD_PARAM_VALUE: + ret = "Bad parameter value provided"; + break; + case TSK_ERR_BUFFER_OVERFLOW: + ret = "Supplied buffer is too small."; + break; + case TSK_ERR_UNSUPPORTED_OPERATION: + ret = "Operation cannot be performed in current configuration"; + break; + case TSK_ERR_GENERATE_UUID: + ret = "Error generating UUID"; + break; + + /* File format errors */ + case TSK_ERR_FILE_FORMAT: + ret = "File format error"; + break; + case TSK_ERR_FILE_VERSION_TOO_OLD: + ret = "tskit file version too old. Please upgrade using the " + "'tskit upgrade' command"; + break; + case TSK_ERR_FILE_VERSION_TOO_NEW: + ret = "tskit file version is too new for this instance. 
" + "Please upgrade tskit to the latest version."; + break; + + /* Out of bounds errors */ + case TSK_ERR_BAD_OFFSET: + ret = "Bad offset provided in input array."; + break; + case TSK_ERR_OUT_OF_BOUNDS: + ret = "Object reference out of bounds"; + break; + case TSK_ERR_NODE_OUT_OF_BOUNDS: + ret = "Node out of bounds"; + break; + case TSK_ERR_EDGE_OUT_OF_BOUNDS: + ret = "Edge out of bounds"; + break; + case TSK_ERR_POPULATION_OUT_OF_BOUNDS: + ret = "Population out of bounds"; + break; + case TSK_ERR_SITE_OUT_OF_BOUNDS: + ret = "Site out of bounds"; + break; + case TSK_ERR_MUTATION_OUT_OF_BOUNDS: + ret = "Mutation out of bounds"; + break; + case TSK_ERR_MIGRATION_OUT_OF_BOUNDS: + ret = "Migration out of bounds"; + break; + case TSK_ERR_INDIVIDUAL_OUT_OF_BOUNDS: + ret = "Individual out of bounds"; + break; + case TSK_ERR_PROVENANCE_OUT_OF_BOUNDS: + ret = "Provenance out of bounds"; + break; + + /* Edge errors */ + case TSK_ERR_NULL_PARENT: + ret = "Edge in parent is null."; + break; + case TSK_ERR_NULL_CHILD: + ret = "Edge in parent is null."; + break; + case TSK_ERR_EDGES_NOT_SORTED_PARENT_TIME: + ret = "Edges must be listed in (time[parent], child, left) order;" + " time[parent] order violated"; + break; + case TSK_ERR_EDGES_NONCONTIGUOUS_PARENTS: + ret = "All edges for a given parent must be contiguous"; + break; + case TSK_ERR_EDGES_NOT_SORTED_CHILD: + ret = "Edges must be listed in (time[parent], child, left) order;" + " child order violated"; + break; + case TSK_ERR_EDGES_NOT_SORTED_LEFT: + ret = "Edges must be listed in (time[parent], child, left) order;" + " left order violated"; + break; + case TSK_ERR_BAD_NODE_TIME_ORDERING: + ret = "time[parent] must be greater than time[child]"; + break; + case TSK_ERR_BAD_EDGE_INTERVAL: + ret = "Bad edge interval where right <= left"; + break; + case TSK_ERR_DUPLICATE_EDGES: + ret = "Duplicate edges provided."; + break; + case TSK_ERR_RIGHT_GREATER_SEQ_LENGTH: + ret = "Right coordinate > sequence length."; + break; + case 
TSK_ERR_LEFT_LESS_ZERO: + ret = "Left coordinate must be >= 0"; + break; + case TSK_ERR_BAD_EDGES_CONTRADICTORY_CHILDREN: + ret = "Bad edges: contradictory children for a given parent over " + "an interval."; + break; + + /* Site errors */ + case TSK_ERR_UNSORTED_SITES: + ret = "Sites must be provided in strictly increasing position order."; + break; + case TSK_ERR_DUPLICATE_SITE_POSITION: + ret = "Duplicate site positions."; + break; + case TSK_ERR_BAD_SITE_POSITION: + ret = "Sites positions must be between 0 and sequence_length"; + break; + + /* Mutation errors */ + case TSK_ERR_MUTATION_PARENT_DIFFERENT_SITE: + ret = "Specified parent mutation is at a different site."; + break; + case TSK_ERR_MUTATION_PARENT_EQUAL: + ret = "Parent mutation refers to itself."; + break; + case TSK_ERR_MUTATION_PARENT_AFTER_CHILD: + ret = "Parent mutation ID must be < current ID."; + break; + case TSK_ERR_TOO_MANY_ALLELES: + ret = "Cannot have more than 255 alleles."; + break; + case TSK_ERR_INCONSISTENT_MUTATIONS: + ret = "Inconsistent mutations: state already equal to derived state."; + break; + case TSK_ERR_NON_SINGLE_CHAR_MUTATION: + ret = "Only single char mutations supported."; + break; + case TSK_ERR_UNSORTED_MUTATIONS: + ret = "Mutations must be provided in non-decreasing site order"; + break; + + /* Sample errors */ + case TSK_ERR_DUPLICATE_SAMPLE: + ret = "Duplicate value provided in tracked leaf list."; + break; + case TSK_ERR_BAD_SAMPLES: + ret = "Bad sample configuration provided."; + break; + + /* Table errors */ + case TSK_ERR_BAD_TABLE_POSITION: + ret = "Bad table position provided to truncate/reset."; + break; + case TSK_ERR_BAD_SEQUENCE_LENGTH: + ret = "Sequence length must be > 0."; + break; + case TSK_ERR_TABLES_NOT_INDEXED: + ret = "Table collection must be indexed."; + break; + + /* Limitations */ + case TSK_ERR_ONLY_INFINITE_SITES: + ret = "Only infinite sites mutations are supported for this operation."; + break; + case 
TSK_ERR_SIMPLIFY_MIGRATIONS_NOT_SUPPORTED: + ret = "Migrations not currently supported by simplify."; + break; + case TSK_ERR_NONBINARY_MUTATIONS_UNSUPPORTED: + ret = "Only binary mutations are supported for this operation."; + break; + } +out: + return ret; +} + +int +tsk_set_kas_error(int err) +{ + /* Flip this bit. As the error is negative, this sets the bit to 0 */ + return err ^ (1 << TSK_KAS_ERR_BIT); +} + +bool +tsk_is_kas_error(int err) +{ + return !(err & (1 << TSK_KAS_ERR_BIT)); +} + +const char * +tsk_strerror(int err) +{ + if (tsk_is_kas_error(err)) { + err ^= (1 << TSK_KAS_ERR_BIT); + return kas_strerror(err); + } else { + return tsk_strerror_internal(err); + } +} + +void +__tsk_safe_free(void **ptr) { + if (ptr != NULL) { + if (*ptr != NULL) { + free(*ptr); + *ptr = NULL; + } + } +} + + +/* Block allocator. Simple allocator when we lots of chunks of memory + * and don't need to free them individually. + */ + +void +tsk_blkalloc_print_state(tsk_blkalloc_t *self, FILE *out) +{ + fprintf(out, "Block allocator%p::\n", (void *) self); + fprintf(out, "\ttop = %d\n", (int) self->top); + fprintf(out, "\tchunk_size = %d\n", (int) self->chunk_size); + fprintf(out, "\tnum_chunks = %d\n", (int) self->num_chunks); + fprintf(out, "\ttotal_allocated = %d\n", (int) self->total_allocated); + fprintf(out, "\ttotal_size = %d\n", (int) self->total_size); +} + +int TSK_WARN_UNUSED +tsk_blkalloc_reset(tsk_blkalloc_t *self) +{ + int ret = 0; + + self->top = 0; + self->current_chunk = 0; + self->total_allocated = 0; + return ret; +} + +int TSK_WARN_UNUSED +tsk_blkalloc_alloc(tsk_blkalloc_t *self, size_t chunk_size) +{ + int ret = 0; + + assert(chunk_size > 0); + memset(self, 0, sizeof(tsk_blkalloc_t)); + self->chunk_size = chunk_size; + self->top = 0; + self->current_chunk = 0; + self->total_allocated = 0; + self->total_size = 0; + self->num_chunks = 0; + self->mem_chunks = malloc(sizeof(char *)); + if (self->mem_chunks == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + 
self->mem_chunks[0] = malloc(chunk_size); + if (self->mem_chunks[0] == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + self->num_chunks = 1; + self->total_size = chunk_size + sizeof(void *); +out: + return ret; +} + +void * TSK_WARN_UNUSED +tsk_blkalloc_get(tsk_blkalloc_t *self, size_t size) +{ + void *ret = NULL; + void *p; + + assert(size < self->chunk_size); + if ((self->top + size) > self->chunk_size) { + if (self->current_chunk == (self->num_chunks - 1)) { + p = realloc(self->mem_chunks, (self->num_chunks + 1) * sizeof(void *)); + if (p == NULL) { + goto out; + } + self->mem_chunks = p; + p = malloc(self->chunk_size); + if (p == NULL) { + goto out; + } + self->mem_chunks[self->num_chunks] = p; + self->num_chunks++; + self->total_size += self->chunk_size + sizeof(void *); + } + self->current_chunk++; + self->top = 0; + } + ret = self->mem_chunks[self->current_chunk] + self->top; + self->top += size; + self->total_allocated += size; +out: + return ret; +} + +void +tsk_blkalloc_free(tsk_blkalloc_t *self) +{ + size_t j; + + for (j = 0; j < self->num_chunks; j++) { + if (self->mem_chunks[j] != NULL) { + free(self->mem_chunks[j]); + } + } + if (self->mem_chunks != NULL) { + free(self->mem_chunks); + } +} + +/* Mirrors the semantics of numpy's searchsorted function. Uses binary + * search to find the index of the closest value in the array. 
*/ +size_t +tsk_search_sorted(const double *restrict array, size_t size, double value) +{ + int64_t upper = (int64_t) size; + int64_t lower = 0; + int64_t offset = 0; + int64_t mid; + + if (upper == 0) { + return 0; + } + + while (upper - lower > 1) { + mid = (upper + lower) / 2; + if (value >= array[mid]) { + lower = mid; + } else { + upper = mid; + } + } + offset = (int64_t) (array[lower] < value); + return (size_t) (lower + offset); +} diff --git a/c/tsk_core.h b/c/tsk_core.h new file mode 100644 index 0000000000..59e6bbcbc1 --- /dev/null +++ b/c/tsk_core.h @@ -0,0 +1,184 @@ +#ifndef __TSK_CORE_H__ +#define __TSK_CORE_H__ + +#include + +#ifdef __GNUC__ + #define TSK_WARN_UNUSED __attribute__ ((warn_unused_result)) + #define TSK_UNUSED(x) TSK_UNUSED_ ## x __attribute__((__unused__)) +#else + #define TSK_WARN_UNUSED + #define TSK_UNUSED(x) TSK_UNUSED_ ## x + /* Don't bother with restrict for MSVC */ + #define restrict +#endif + +/* This sets up TSK_DBL_DECIMAL_DIG, which can then be used as a + * precision specifier when writing out doubles, if you want sufficient + * decimal digits to be written to guarantee a lossless round-trip + * after being read back in. 
Usage: + * + * printf("%.*g", TSK_DBL_DECIMAL_DIG, foo); + * + * See https://stackoverflow.com/a/19897395/2752221 + */ +#ifdef DBL_DECIMAL_DIG +#define TSK_DBL_DECIMAL_DIG (DBL_DECIMAL_DIG) +#else +#define TSK_DBL_DECIMAL_DIG (DBL_DIG + 3) +#endif + + +/* Node flags */ +#define TSK_NODE_IS_SAMPLE 1u + +/* The null ID */ +#define TSK_NULL (-1) + +/* Flags for simplify() */ +#define TSK_FILTER_SITES (1 << 0) +#define TSK_REDUCE_TO_SITE_TOPOLOGY (1 << 1) +#define TSK_FILTER_POPULATIONS (1 << 2) +#define TSK_FILTER_INDIVIDUALS (1 << 3) + +/* Flags for check_integrity */ +#define TSK_CHECK_OFFSETS (1 << 0) +#define TSK_CHECK_EDGE_ORDERING (1 << 1) +#define TSK_CHECK_SITE_ORDERING (1 << 2) +#define TSK_CHECK_SITE_DUPLICATES (1 << 3) +#define TSK_CHECK_MUTATION_ORDERING (1 << 4) +#define TSK_CHECK_INDEXES (1 << 5) +#define TSK_CHECK_ALL \ + (TSK_CHECK_OFFSETS | TSK_CHECK_EDGE_ORDERING | TSK_CHECK_SITE_ORDERING | \ + TSK_CHECK_SITE_DUPLICATES | TSK_CHECK_MUTATION_ORDERING | TSK_CHECK_INDEXES) + +/* Flags for dump tables */ +/* #define TSK_ALLOC_TABLES 1 */ + +/* Flags for load tables */ +#define TSK_BUILD_INDEXES 1 + +/* Generic debug flag shared across all calls. Uses + * the top bit to avoid clashes with other flags. 
*/ +#define TSK_DEBUG (1 << 31) + +#define TSK_LOAD_EXTENDED_CHECKS 1 + +#define TSK_FILE_FORMAT_NAME "tskit.trees" +#define TSK_FILE_FORMAT_NAME_LENGTH 11 +#define TSK_FILE_FORMAT_VERSION_MAJOR 12 +#define TSK_FILE_FORMAT_VERSION_MINOR 0 + +/* Error codes */ + +/* General errors */ +#define TSK_ERR_GENERIC -1 +#define TSK_ERR_NO_MEMORY -2 +#define TSK_ERR_IO -3 +#define TSK_ERR_BAD_PARAM_VALUE -4 +#define TSK_ERR_BUFFER_OVERFLOW -5 +#define TSK_ERR_UNSUPPORTED_OPERATION -6 +#define TSK_ERR_GENERATE_UUID -7 + +/* File format errors */ +#define TSK_ERR_FILE_FORMAT -100 +#define TSK_ERR_FILE_VERSION_TOO_OLD -101 +#define TSK_ERR_FILE_VERSION_TOO_NEW -102 + +/* Out of bounds errors */ +#define TSK_ERR_BAD_OFFSET -200 +#define TSK_ERR_OUT_OF_BOUNDS -201 +#define TSK_ERR_NODE_OUT_OF_BOUNDS -202 +#define TSK_ERR_EDGE_OUT_OF_BOUNDS -203 +#define TSK_ERR_POPULATION_OUT_OF_BOUNDS -204 +#define TSK_ERR_SITE_OUT_OF_BOUNDS -205 +#define TSK_ERR_MUTATION_OUT_OF_BOUNDS -206 +#define TSK_ERR_INDIVIDUAL_OUT_OF_BOUNDS -207 +#define TSK_ERR_MIGRATION_OUT_OF_BOUNDS -208 +#define TSK_ERR_PROVENANCE_OUT_OF_BOUNDS -209 + +/* Edge errors */ +#define TSK_ERR_NULL_PARENT -300 +#define TSK_ERR_NULL_CHILD -301 +#define TSK_ERR_EDGES_NOT_SORTED_PARENT_TIME -302 +#define TSK_ERR_EDGES_NONCONTIGUOUS_PARENTS -303 +#define TSK_ERR_EDGES_NOT_SORTED_CHILD -304 +#define TSK_ERR_EDGES_NOT_SORTED_LEFT -305 +#define TSK_ERR_BAD_NODE_TIME_ORDERING -306 +#define TSK_ERR_BAD_EDGE_INTERVAL -307 +#define TSK_ERR_DUPLICATE_EDGES -308 +#define TSK_ERR_RIGHT_GREATER_SEQ_LENGTH -309 +#define TSK_ERR_LEFT_LESS_ZERO -310 +#define TSK_ERR_BAD_EDGES_CONTRADICTORY_CHILDREN -311 + +/* Site errors */ +#define TSK_ERR_UNSORTED_SITES -400 +#define TSK_ERR_DUPLICATE_SITE_POSITION -401 +#define TSK_ERR_BAD_SITE_POSITION -402 + +/* Mutation errors */ +#define TSK_ERR_MUTATION_PARENT_DIFFERENT_SITE -500 +#define TSK_ERR_MUTATION_PARENT_EQUAL -501 +#define TSK_ERR_MUTATION_PARENT_AFTER_CHILD -502 +#define 
TSK_ERR_TOO_MANY_ALLELES -503 +#define TSK_ERR_INCONSISTENT_MUTATIONS -504 +#define TSK_ERR_NON_SINGLE_CHAR_MUTATION -505 +#define TSK_ERR_UNSORTED_MUTATIONS -506 + +/* Sample errors */ +#define TSK_ERR_DUPLICATE_SAMPLE -600 +#define TSK_ERR_BAD_SAMPLES -601 + +/* Table errors */ +#define TSK_ERR_BAD_TABLE_POSITION -700 +#define TSK_ERR_BAD_SEQUENCE_LENGTH -701 +#define TSK_ERR_TABLES_NOT_INDEXED -702 + +/* Limitations */ +#define TSK_ERR_ONLY_INFINITE_SITES -800 +#define TSK_ERR_SIMPLIFY_MIGRATIONS_NOT_SUPPORTED -801 +#define TSK_ERR_NONBINARY_MUTATIONS_UNSUPPORTED -802 + +/* This bit is 0 for any errors originating from kastore */ +#define TSK_KAS_ERR_BIT 14 + +int tsk_set_kas_error(int err); +bool tsk_is_kas_error(int err); +const char * tsk_strerror(int err); +void __tsk_safe_free(void **ptr); + +#define tsk_safe_free(pointer) __tsk_safe_free((void **) &(pointer)) +#define TSK_MAX(a,b) ((a) > (b) ? (a) : (b)) +#define TSK_MIN(a,b) ((a) < (b) ? (a) : (b)) + +/* This is a simple allocator that is optimised to efficiently allocate a + * large number of small objects without large numbers of calls to malloc. + * The allocator mallocs memory in chunks of a configurable size. When + * responding to calls to get(), it will return a chunk of this memory. + * This memory cannot be subsequently handed back to the allocator. However, + * all memory allocated by the allocator can be returned at once by calling + * reset. + */ + +typedef struct { + size_t chunk_size; /* number of bytes per chunk */ + size_t top; /* the offset of the next available byte in the current chunk */ + size_t current_chunk; /* the index of the chunk currently being used */ + size_t total_size; /* the total number of bytes allocated + overhead. */ + size_t total_allocated; /* the total number of bytes allocated. */ + size_t num_chunks; /* the number of memory chunks. 
*/ + char **mem_chunks; /* the memory chunks */ +} tsk_blkalloc_t; + +extern void tsk_blkalloc_print_state(tsk_blkalloc_t *self, FILE *out); +extern int tsk_blkalloc_reset(tsk_blkalloc_t *self); +extern int tsk_blkalloc_alloc(tsk_blkalloc_t *self, size_t chunk_size); +extern void * tsk_blkalloc_get(tsk_blkalloc_t *self, size_t size); +extern void tsk_blkalloc_free(tsk_blkalloc_t *self); + +size_t tsk_search_sorted(const double *array, size_t size, double value); + +#define TSK_UUID_SIZE 36 +int tsk_generate_uuid(char *dest, int flags); + +#endif diff --git a/c/tsk_genotypes.c b/c/tsk_genotypes.c new file mode 100644 index 0000000000..1034afeffc --- /dev/null +++ b/c/tsk_genotypes.c @@ -0,0 +1,609 @@ +#include +#include +#include +#include +#include + +#include "tsk_genotypes.h" + + +/* ======================================================== * + * Haplotype generator + * ======================================================== */ + +/* Ensure the tree is in a consistent state */ +static void +tsk_hapgen_check_state(tsk_hapgen_t *TSK_UNUSED(self)) +{ + /* TODO some checks! 
*/ +} + +void +tsk_hapgen_print_state(tsk_hapgen_t *self, FILE *out) +{ + size_t j; + + fprintf(out, "Hapgen state\n"); + fprintf(out, "num_samples = %d\n", (int) self->num_samples); + fprintf(out, "num_sites = %d\n", (int) self->num_sites); + fprintf(out, "haplotype matrix\n"); + for (j = 0; j < self->num_samples; j++) { + fprintf(out, "%s\n", + self->haplotype_matrix + (j * (self->num_sites + 1))); + } + tsk_hapgen_check_state(self); +} + + +static inline int TSK_WARN_UNUSED +tsk_hapgen_update_sample(tsk_hapgen_t * self, size_t sample_index, tsk_id_t site, + const char *derived_state) +{ + int ret = 0; + size_t index = sample_index * (self->num_sites + 1) + (size_t) site; + + if (self->haplotype_matrix[index] == derived_state[0]) { + ret = TSK_ERR_INCONSISTENT_MUTATIONS; + goto out; + } + self->haplotype_matrix[index] = derived_state[0]; +out: + return ret; +} + +static int +tsk_hapgen_apply_tree_site(tsk_hapgen_t *self, tsk_site_t *site) +{ + int ret = 0; + const tsk_id_t *restrict list_left = self->tree.left_sample; + const tsk_id_t *restrict list_right = self->tree.right_sample; + const tsk_id_t *restrict list_next = self->tree.next_sample; + tsk_id_t node, index, stop; + tsk_tbl_size_t j; + const char *derived_state; + + for (j = 0; j < site->mutations_length; j++) { + if (site->mutations[j].derived_state_length != 1) { + ret = TSK_ERR_NON_SINGLE_CHAR_MUTATION; + goto out; + } + derived_state = site->mutations[j].derived_state; + node = site->mutations[j].node; + index = list_left[node]; + if (index != TSK_NULL) { + stop = list_right[node]; + while (true) { + ret = tsk_hapgen_update_sample(self, (size_t) index, site->id, derived_state); + if (ret != 0) { + goto out; + } + if (index == stop) { + break; + } + index = list_next[index]; + } + } + } +out: + return ret; +} + +static int +tsk_hapgen_generate_all_haplotypes(tsk_hapgen_t *self) +{ + int ret = 0; + tsk_tbl_size_t j; + tsk_tbl_size_t num_sites = 0; + tsk_site_t *sites = NULL; + tsk_tree_t *t = 
&self->tree; + + for (ret = tsk_tree_first(t); ret == 1; ret = tsk_tree_next(t)) { + ret = tsk_tree_get_sites(t, &sites, &num_sites); + if (ret != 0) { + goto out; + } + for (j = 0; j < num_sites; j++) { + ret = tsk_hapgen_apply_tree_site(self, &sites[j]); + if (ret != 0) { + goto out; + } + } + } +out: + return ret; +} + +int +tsk_hapgen_alloc(tsk_hapgen_t *self, tsk_treeseq_t *tree_sequence) +{ + int ret = 0; + size_t j, k; + tsk_site_t site; + + assert(tree_sequence != NULL); + memset(self, 0, sizeof(tsk_hapgen_t)); + self->num_samples = tsk_treeseq_get_num_samples(tree_sequence); + self->sequence_length = tsk_treeseq_get_sequence_length(tree_sequence); + self->num_sites = tsk_treeseq_get_num_sites(tree_sequence); + self->tree_sequence = tree_sequence; + + ret = tsk_treeseq_get_sample_index_map(tree_sequence, &self->sample_index_map); + if (ret != 0) { + goto out; + } + ret = tsk_tree_alloc(&self->tree, tree_sequence, TSK_SAMPLE_LISTS); + if (ret != 0) { + goto out; + } + self->haplotype_matrix = malloc( + self->num_samples * (self->num_sites + 1) * sizeof(char)); + if (self->haplotype_matrix == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + /* Set the NULL string ends. 
*/ + for (j = 0; j < self->num_samples; j++) { + self->haplotype_matrix[ + (j + 1) * (self->num_sites + 1) - 1] = '\0'; + } + /* For each site set the ancestral type */ + for (k = 0; k < self->num_sites; k++) { + ret = tsk_treeseq_get_site(self->tree_sequence, k, &site); + if (ret != 0) { + goto out; + } + if (site.ancestral_state_length != 1) { + ret = TSK_ERR_NON_SINGLE_CHAR_MUTATION; + goto out; + } + for (j = 0; j < self->num_samples; j++) { + self->haplotype_matrix[j * (self->num_sites + 1) + k] = + site.ancestral_state[0]; + } + } + ret = tsk_hapgen_generate_all_haplotypes(self); +out: + return ret; +} + +int +tsk_hapgen_free(tsk_hapgen_t *self) +{ + tsk_safe_free(self->output_haplotype); + tsk_safe_free(self->haplotype_matrix); + tsk_tree_free(&self->tree); + return 0; +} + +int +tsk_hapgen_get_haplotype(tsk_hapgen_t *self, tsk_id_t sample_index, char **haplotype) +{ + int ret = 0; + + if (sample_index >= (tsk_id_t) self->num_samples) { + ret = TSK_ERR_OUT_OF_BOUNDS; + goto out; + } + *haplotype = self->haplotype_matrix + ((size_t) sample_index) * (self->num_sites + 1); +out: + return ret; +} + +/* ======================================================== * + * Variant generator + * ======================================================== */ + +void +tsk_vargen_print_state(tsk_vargen_t *self, FILE *out) +{ + fprintf(out, "tsk_vargen state\n"); + fprintf(out, "tree_site_index = %d\n", (int) self->tree_site_index); +} + +static int +tsk_vargen_next_tree(tsk_vargen_t *self) +{ + int ret = 0; + + ret = tsk_tree_next(&self->tree); + if (ret == 0) { + self->finished = 1; + } else if (ret < 0) { + goto out; + } + self->tree_site_index = 0; +out: + return ret; +} + +int +tsk_vargen_alloc(tsk_vargen_t *self, tsk_treeseq_t *tree_sequence, + tsk_id_t *samples, size_t num_samples, int flags) +{ + int ret = TSK_ERR_NO_MEMORY; + int tree_flags; + size_t j, num_nodes, num_samples_alloc; + tsk_tbl_size_t max_alleles = 4; + + assert(tree_sequence != NULL); + memset(self, 0, 
sizeof(tsk_vargen_t)); + + if (samples == NULL) { + self->num_samples = tsk_treeseq_get_num_samples(tree_sequence); + num_samples_alloc = self->num_samples; + } else { + /* Take a copy of the samples for simplicity */ + num_nodes = tsk_treeseq_get_num_nodes(tree_sequence); + /* We can have num_samples = 0 here, so guard against malloc(0) */ + num_samples_alloc = num_samples + 1; + self->samples = malloc(num_samples_alloc * sizeof(*self->samples)); + self->sample_index_map = malloc(num_nodes * sizeof(*self->sample_index_map)); + if (self->samples == NULL || self->sample_index_map == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + memcpy(self->samples, samples, num_samples * sizeof(*self->samples)); + memset(self->sample_index_map, 0xff, num_nodes * sizeof(*self->sample_index_map)); + /* Create the reverse mapping */ + for (j = 0; j < num_samples; j++) { + if (samples[j] < 0 || samples[j] >= (tsk_id_t) num_nodes) { + ret = TSK_ERR_OUT_OF_BOUNDS; + goto out; + } + if (self->sample_index_map[samples[j]] != TSK_NULL) { + ret = TSK_ERR_DUPLICATE_SAMPLE; + goto out; + } + self->sample_index_map[samples[j]] = (tsk_id_t) j; + } + self->num_samples = num_samples; + } + self->num_sites = tsk_treeseq_get_num_sites(tree_sequence); + self->tree_sequence = tree_sequence; + self->flags = flags; + if (self->flags & TSK_16_BIT_GENOTYPES) { + self->variant.genotypes.u16 = malloc( + num_samples_alloc * sizeof(*self->variant.genotypes.u16)); + } else { + self->variant.genotypes.u8 = malloc( + num_samples_alloc * sizeof(*self->variant.genotypes.u8)); + } + self->variant.max_alleles = max_alleles; + self->variant.alleles = malloc(max_alleles * sizeof(*self->variant.alleles)); + self->variant.allele_lengths = malloc(max_alleles + * sizeof(*self->variant.allele_lengths)); + /* Because genotypes is a union we can check the pointer */ + if (self->variant.genotypes.u8 == NULL || self->variant.alleles == NULL + || self->variant.allele_lengths == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto 
out; + } + /* When a list of samples is given, we use the traversal based algorithm + * and turn off the sample list tracking in the tree */ + tree_flags = 0; + if (self->samples == NULL) { + tree_flags = TSK_SAMPLE_LISTS; + } + ret = tsk_tree_alloc(&self->tree, tree_sequence, tree_flags); + if (ret != 0) { + goto out; + } + self->finished = 0; + self->tree_site_index = 0; + ret = tsk_tree_first(&self->tree); + if (ret < 0) { + goto out; + } + ret = 0; +out: + return ret; +} + +int +tsk_vargen_free(tsk_vargen_t *self) +{ + tsk_tree_free(&self->tree); + tsk_safe_free(self->variant.genotypes.u8); + tsk_safe_free(self->variant.alleles); + tsk_safe_free(self->variant.allele_lengths); + tsk_safe_free(self->samples); + tsk_safe_free(self->sample_index_map); + return 0; +} + +static int +tsk_vargen_expand_alleles(tsk_vargen_t *self) +{ + int ret = 0; + tsk_variant_t *var = &self->variant; + void *p; + tsk_tbl_size_t hard_limit = UINT8_MAX; + + if (self->flags & TSK_16_BIT_GENOTYPES) { + hard_limit = UINT16_MAX; + } + if (var->max_alleles == hard_limit) { + ret = TSK_ERR_TOO_MANY_ALLELES; + goto out; + } + var->max_alleles = TSK_MIN(hard_limit, var->max_alleles * 2); + p = realloc(var->alleles, var->max_alleles * sizeof(*var->alleles)); + if (p == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + var->alleles = p; + p = realloc(var->allele_lengths, var->max_alleles * sizeof(*var->allele_lengths)); + if (p == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + var->allele_lengths = p; +out: + return ret; +} + +/* The following pair of functions are identical except one handles 8 bit + * genotypes and the other handles 16 bit genotypes. This is done for performance + * reasons as this is a key function and for common alleles can entail + * iterating over millions of samples. The compiler hints are included for the + * same reason. 
+ */ +static int TSK_WARN_UNUSED +tsk_vargen_update_genotypes_u8_sample_list(tsk_vargen_t *self, tsk_id_t node, tsk_tbl_size_t derived) +{ + uint8_t *restrict genotypes = self->variant.genotypes.u8; + const tsk_id_t *restrict list_left = self->tree.left_sample; + const tsk_id_t *restrict list_right = self->tree.right_sample; + const tsk_id_t *restrict list_next = self->tree.next_sample; + tsk_id_t index, stop; + int ret = 0; + + assert(derived < UINT8_MAX); + + index = list_left[node]; + if (index != TSK_NULL) { + stop = list_right[node]; + while (true) { + if (genotypes[index] == derived) { + ret = TSK_ERR_INCONSISTENT_MUTATIONS; + goto out; + } + genotypes[index] = (uint8_t) derived; + if (index == stop) { + break; + } + index = list_next[index]; + } + } +out: + return ret; +} + +static int TSK_WARN_UNUSED +tsk_vargen_update_genotypes_u16_sample_list(tsk_vargen_t *self, tsk_id_t node, tsk_tbl_size_t derived) +{ + uint16_t *restrict genotypes = self->variant.genotypes.u16; + const tsk_id_t *restrict list_left = self->tree.left_sample; + const tsk_id_t *restrict list_right = self->tree.right_sample; + const tsk_id_t *restrict list_next = self->tree.next_sample; + tsk_id_t index, stop; + int ret = 0; + + assert(derived < UINT16_MAX); + + index = list_left[node]; + if (index != TSK_NULL) { + stop = list_right[node]; + while (true) { + if (genotypes[index] == derived) { + ret = TSK_ERR_INCONSISTENT_MUTATIONS; + goto out; + } + genotypes[index] = (uint16_t) derived; + if (index == stop) { + break; + } + index = list_next[index]; + } + } +out: + return ret; +} + +/* The following functions implement the genotype setting by traversing + * down the tree to the samples. We're not so worried about performance here + * because this should only be used when we have a very small number of samples, + * and so we use a visit function to avoid duplicating code. 
+ */ + +typedef int (*visit_func_t)(tsk_vargen_t *, tsk_id_t, tsk_tbl_size_t); + +static int TSK_WARN_UNUSED +tsk_vargen_traverse(tsk_vargen_t *self, tsk_id_t node, tsk_tbl_size_t derived, visit_func_t visit) +{ + int ret = 0; + tsk_id_t * restrict stack = self->tree.stack1; + const tsk_id_t * restrict left_child = self->tree.left_child; + const tsk_id_t * restrict right_sib = self->tree.right_sib; + const tsk_id_t *restrict sample_index_map = self->sample_index_map; + tsk_id_t u, v, sample_index; + int stack_top; + + stack_top = 0; + stack[0] = node; + while (stack_top >= 0) { + u = stack[stack_top]; + sample_index = sample_index_map[u]; + if (sample_index != TSK_NULL) { + ret = visit(self, sample_index, derived); + if (ret != 0) { + goto out; + } + } + stack_top--; + for (v = left_child[u]; v != TSK_NULL; v = right_sib[v]) { + stack_top++; + stack[stack_top] = v; + } + } +out: + return ret; +} + +static int +tsk_vargen_visit_u8(tsk_vargen_t *self, tsk_id_t sample_index, tsk_tbl_size_t derived) +{ + int ret = 0; + uint8_t *restrict genotypes = self->variant.genotypes.u8; + + assert(derived < UINT8_MAX); + assert(sample_index != -1); + if (genotypes[sample_index] == derived) { + ret = TSK_ERR_INCONSISTENT_MUTATIONS; + goto out; + } + genotypes[sample_index] = (uint8_t) derived; +out: + return ret; +} + +static int +tsk_vargen_visit_u16(tsk_vargen_t *self, tsk_id_t sample_index, tsk_tbl_size_t derived) +{ + int ret = 0; + uint16_t *restrict genotypes = self->variant.genotypes.u16; + + assert(derived < UINT16_MAX); + assert(sample_index != -1); + if (genotypes[sample_index] == derived) { + ret = TSK_ERR_INCONSISTENT_MUTATIONS; + goto out; + } + genotypes[sample_index] = (uint16_t) derived; +out: + return ret; +} + +static int TSK_WARN_UNUSED +tsk_vargen_update_genotypes_u8_traversal(tsk_vargen_t *self, tsk_id_t node, tsk_tbl_size_t derived) +{ + return tsk_vargen_traverse(self, node, derived, tsk_vargen_visit_u8); +} + +static int TSK_WARN_UNUSED 
+tsk_vargen_update_genotypes_u16_traversal(tsk_vargen_t *self, tsk_id_t node, tsk_tbl_size_t derived) +{ + return tsk_vargen_traverse(self, node, derived, tsk_vargen_visit_u16); +} + +static int +tsk_vargen_update_site(tsk_vargen_t *self) +{ + int ret = 0; + tsk_tbl_size_t j, derived; + tsk_variant_t *var = &self->variant; + tsk_site_t *site = var->site; + tsk_mutation_t mutation; + bool genotypes16 = !!(self->flags & TSK_16_BIT_GENOTYPES); + bool by_traversal = self->samples != NULL; + int (*update_genotypes)(tsk_vargen_t *, tsk_id_t, tsk_tbl_size_t); + + /* For now we use a traversal method to find genotypes when we have a + * specified set of samples, but we should provide the option to do it + * via tracked_samples in the tree also. There will be a tradeoff: if + * we only have a small number of samples, it's probably better to + * do it by traversal. For large sets of samples though, it'll be + * definitely better to use the sample list infrastructure. */ + if (genotypes16) { + update_genotypes = tsk_vargen_update_genotypes_u16_sample_list; + if (by_traversal) { + update_genotypes = tsk_vargen_update_genotypes_u16_traversal; + } + } else { + update_genotypes = tsk_vargen_update_genotypes_u8_sample_list; + if (by_traversal) { + update_genotypes = tsk_vargen_update_genotypes_u8_traversal; + } + } + + /* Ancestral state is always allele 0 */ + var->alleles[0] = site->ancestral_state; + var->allele_lengths[0] = site->ancestral_state_length; + var->num_alleles = 1; + + /* The algorithm for generating the allelic state of every sample works by + * examining each mutation in order, and setting the state for all the + * samples under the mutation's node. For complex sites where there is + * more than one mutation, we depend on the ordering of mutations being + * correct. Specifically, any mutation that is above another mutation in + * the tree must be visited first. 
This is enforced using the mutation.parent + * field, where we require that a mutation's parent must appear before it + * in the list of mutations. This guarantees the correctness of this algorithm. + */ + if (genotypes16) { + memset(self->variant.genotypes.u16, 0, 2 * self->num_samples); + } else { + memset(self->variant.genotypes.u8, 0, self->num_samples); + } + for (j = 0; j < site->mutations_length; j++) { + mutation = site->mutations[j]; + /* Compute the allele index for this derived state value. */ + derived = 0; + while (derived < var->num_alleles) { + if (mutation.derived_state_length == var->allele_lengths[derived] + && memcmp(mutation.derived_state, var->alleles[derived], + var->allele_lengths[derived]) == 0) { + break; + } + derived++; + } + if (derived == var->num_alleles) { + if (var->num_alleles == var->max_alleles) { + ret = tsk_vargen_expand_alleles(self); + if (ret != 0) { + goto out; + } + } + var->alleles[derived] = mutation.derived_state; + var->allele_lengths[derived] = mutation.derived_state_length; + var->num_alleles++; + } + ret = update_genotypes(self, mutation.node, derived); + if (ret != 0) { + goto out; + } + } +out: + return ret; +} + +int +tsk_vargen_next(tsk_vargen_t *self, tsk_variant_t **variant) +{ + int ret = 0; + + bool not_done = true; + + if (!self->finished) { + while (not_done && self->tree_site_index == self->tree.sites_length) { + ret = tsk_vargen_next_tree(self); + if (ret < 0) { + goto out; + } + not_done = ret == 1; + } + if (not_done) { + self->variant.site = &self->tree.sites[self->tree_site_index]; + ret = tsk_vargen_update_site(self); + if (ret != 0) { + goto out; + } + self->tree_site_index++; + *variant = &self->variant; + ret = 1; + } + } +out: + return ret; +} diff --git a/c/tsk_genotypes.h b/c/tsk_genotypes.h new file mode 100644 index 0000000000..ebeb01df9f --- /dev/null +++ b/c/tsk_genotypes.h @@ -0,0 +1,64 @@ +#ifndef TSK_GENOTYPES_H +#define TSK_GENOTYPES_H + +#ifdef __cplusplus +extern "C" { +#endif + 
+#include "tsk_trees.h" + +#define TSK_16_BIT_GENOTYPES 1 + +typedef struct { + size_t num_samples; + double sequence_length; + size_t num_sites; + tsk_treeseq_t *tree_sequence; + tsk_id_t *sample_index_map; + char *output_haplotype; + char *haplotype_matrix; + tsk_tree_t tree; +} tsk_hapgen_t; + +typedef struct { + tsk_site_t *site; + const char **alleles; + tsk_tbl_size_t *allele_lengths; + tsk_tbl_size_t num_alleles; + tsk_tbl_size_t max_alleles; + union { + uint8_t *u8; + uint16_t *u16; + } genotypes; +} tsk_variant_t; + +typedef struct { + size_t num_samples; + size_t num_sites; + tsk_treeseq_t *tree_sequence; + tsk_id_t *samples; + tsk_id_t *sample_index_map; + size_t tree_site_index; + int finished; + tsk_tree_t tree; + int flags; + tsk_variant_t variant; +} tsk_vargen_t; + +int tsk_hapgen_alloc(tsk_hapgen_t *self, tsk_treeseq_t *tree_sequence); +/* FIXME this is inconsistent with the tables API which uses size_t for + * IDs in functions. Not clear which is better */ +int tsk_hapgen_get_haplotype(tsk_hapgen_t *self, tsk_id_t j, char **haplotype); +int tsk_hapgen_free(tsk_hapgen_t *self); +void tsk_hapgen_print_state(tsk_hapgen_t *self, FILE *out); + +int tsk_vargen_alloc(tsk_vargen_t *self, tsk_treeseq_t *tree_sequence, + tsk_id_t *samples, size_t num_samples, int flags); +int tsk_vargen_next(tsk_vargen_t *self, tsk_variant_t **variant); +int tsk_vargen_free(tsk_vargen_t *self); +void tsk_vargen_print_state(tsk_vargen_t *self, FILE *out); + +#ifdef __cplusplus +} +#endif +#endif diff --git a/c/tsk_stats.c b/c/tsk_stats.c new file mode 100644 index 0000000000..330c9eea05 --- /dev/null +++ b/c/tsk_stats.c @@ -0,0 +1,455 @@ +#include +#include +#include +#include +#include + +#include "tsk_stats.h" + +static void +tsk_ld_calc_check_state(tsk_ld_calc_t *self) +{ + uint32_t u; + uint32_t num_nodes = (uint32_t) tsk_treeseq_get_num_nodes(self->tree_sequence); + tsk_tree_t *tA = self->outer_tree; + tsk_tree_t *tB = self->inner_tree; + + assert(tA->index == 
tB->index); + + /* The inner tree's mark values should all be zero. */ + for (u = 0; u < num_nodes; u++) { + assert(tA->marked[u] == 0); + assert(tB->marked[u] == 0); + } +} + +void +tsk_ld_calc_print_state(tsk_ld_calc_t *self, FILE *out) +{ + fprintf(out, "tree_sequence = %p\n", (void *) self->tree_sequence); + fprintf(out, "outer tree index = %d\n", (int) self->outer_tree->index); + fprintf(out, "outer tree interval = (%f, %f)\n", + self->outer_tree->left, self->outer_tree->right); + fprintf(out, "inner tree index = %d\n", (int) self->inner_tree->index); + fprintf(out, "inner tree interval = (%f, %f)\n", + self->inner_tree->left, self->inner_tree->right); + tsk_ld_calc_check_state(self); +} + +int TSK_WARN_UNUSED +tsk_ld_calc_alloc(tsk_ld_calc_t *self, tsk_treeseq_t *tree_sequence) +{ + int ret = TSK_ERR_GENERIC; + + memset(self, 0, sizeof(tsk_ld_calc_t)); + self->tree_sequence = tree_sequence; + self->num_sites = tsk_treeseq_get_num_sites(tree_sequence); + self->outer_tree = malloc(sizeof(tsk_tree_t)); + self->inner_tree = malloc(sizeof(tsk_tree_t)); + if (self->outer_tree == NULL || self->inner_tree == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + ret = tsk_tree_alloc(self->outer_tree, self->tree_sequence, + TSK_SAMPLE_COUNTS|TSK_SAMPLE_LISTS); + if (ret != 0) { + goto out; + } + ret = tsk_tree_alloc(self->inner_tree, self->tree_sequence, + TSK_SAMPLE_COUNTS); + if (ret != 0) { + goto out; + } + ret = tsk_tree_first(self->outer_tree); + if (ret < 0) { + goto out; + } + ret = tsk_tree_first(self->inner_tree); + if (ret < 0) { + goto out; + } + ret = 0; +out: + return ret; +} + +int +tsk_ld_calc_free(tsk_ld_calc_t *self) +{ + if (self->inner_tree != NULL) { + tsk_tree_free(self->inner_tree); + free(self->inner_tree); + } + if (self->outer_tree != NULL) { + tsk_tree_free(self->outer_tree); + free(self->outer_tree); + } + return 0; +} + +/* Position the two trees so that the specified site is within their + * interval. 
+ */ +static int TSK_WARN_UNUSED +tsk_ld_calc_position_trees(tsk_ld_calc_t *self, size_t site_index) +{ + int ret = TSK_ERR_GENERIC; + tsk_site_t mut; + double x; + tsk_tree_t *tA = self->outer_tree; + tsk_tree_t *tB = self->inner_tree; + + ret = tsk_treeseq_get_site(self->tree_sequence, site_index, &mut); + if (ret != 0) { + goto out; + } + x = mut.position; + assert(tA->index == tB->index); + while (x >= tA->right) { + ret = tsk_tree_next(tA); + if (ret < 0) { + goto out; + } + assert(ret == 1); + ret = tsk_tree_next(tB); + if (ret < 0) { + goto out; + } + assert(ret == 1); + } + while (x < tA->left) { + ret = tsk_tree_prev(tA); + if (ret < 0) { + goto out; + } + assert(ret == 1); + ret = tsk_tree_prev(tB); + if (ret < 0) { + goto out; + } + assert(ret == 1); + } + ret = 0; + assert(x >= tA->left && x < tB->right); + assert(tA->index == tB->index); +out: + return ret; +} + +static double +tsk_ld_calc_overlap_within_tree(tsk_ld_calc_t *self, tsk_site_t sA, tsk_site_t sB) +{ + const tsk_tree_t *t = self->inner_tree; + const tsk_node_tbl_t *nodes = self->tree_sequence->tables->nodes; + tsk_id_t u, v, nAB; + + assert(sA.mutations_length == 1); + assert(sB.mutations_length == 1); + u = sA.mutations[0].node; + v = sB.mutations[0].node; + if (nodes->time[u] > nodes->time[v]) { + v = sA.mutations[0].node; + u = sB.mutations[0].node; + } + while (u != v && u != TSK_NULL) { + u = t->parent[u]; + } + nAB = 0; + if (u == v) { + nAB = TSK_MIN(t->num_samples[sA.mutations[0].node], t->num_samples[sB.mutations[0].node]); + } + return (double) nAB; +} + +static inline int TSK_WARN_UNUSED +tsk_ld_calc_set_tracked_samples(tsk_ld_calc_t *self, tsk_site_t sA) +{ + int ret = 0; + + assert(sA.mutations_length == 1); + ret = tsk_tree_set_tracked_samples_from_sample_list(self->inner_tree, + self->outer_tree, sA.mutations[0].node); + return ret; +} + +static int TSK_WARN_UNUSED +tsk_ld_calc_get_r2_array_forward(tsk_ld_calc_t *self, size_t source_index, + size_t max_sites, double 
max_distance, double *r2, + size_t *num_r2_values) +{ + int ret = TSK_ERR_GENERIC; + tsk_site_t sA, sB; + double fA, fB, fAB, D; + int tracked_samples_set = 0; + tsk_tree_t *tA, *tB; + double n = (double) tsk_treeseq_get_num_samples(self->tree_sequence); + size_t j; + double nAB; + + tA = self->outer_tree; + tB = self->inner_tree; + ret = tsk_treeseq_get_site(self->tree_sequence, source_index, &sA); + if (ret != 0) { + goto out; + } + if (sA.mutations_length > 1) { + ret = TSK_ERR_ONLY_INFINITE_SITES; + goto out; + } + fA = ((double) tA->num_samples[sA.mutations[0].node]) / n; + assert(fA > 0); + tB->mark = 1; + for (j = 0; j < max_sites; j++) { + if (source_index + j + 1 >= self->num_sites) { + break; + } + ret = tsk_treeseq_get_site(self->tree_sequence, (source_index + j + 1), &sB); + if (ret != 0) { + goto out; + } + if (sB.mutations_length > 1) { + ret = TSK_ERR_ONLY_INFINITE_SITES; + goto out; + } + if (sB.position - sA.position > max_distance) { + break; + } + while (sB.position >= tB->right) { + ret = tsk_tree_next(tB); + if (ret < 0) { + goto out; + } + assert(ret == 1); + } + fB = ((double) tB->num_samples[sB.mutations[0].node]) / n; + assert(fB > 0); + if (sB.position < tA->right) { + nAB = tsk_ld_calc_overlap_within_tree(self, sA, sB); + } else { + if (!tracked_samples_set && tB->marked[sA.mutations[0].node] == 1) { + tracked_samples_set = 1; + ret = tsk_ld_calc_set_tracked_samples(self, sA); + if (ret != 0) { + goto out; + } + } + if (tracked_samples_set) { + nAB = (double)tB->num_tracked_samples[sB.mutations[0].node]; + } else { + nAB = tsk_ld_calc_overlap_within_tree(self, sA, sB); + } + } + fAB = nAB / n; + D = fAB - fA * fB; + r2[j] = D * D / (fA * fB * (1 - fA) * (1 - fB)); + } + + /* Now rewind back the inner iterator and unmark all nodes that + * were set to 1 as we moved forward. 
*/ + tB->mark = 0; + while (tB->index > tA->index) { + ret = tsk_tree_prev(tB); + if (ret < 0) { + goto out; + } + assert(ret == 1); + } + *num_r2_values = j; + ret = 0; +out: + return ret; +} + +static int TSK_WARN_UNUSED +tsk_ld_calc_get_r2_array_reverse(tsk_ld_calc_t *self, size_t source_index, + size_t max_sites, double max_distance, double *r2, + size_t *num_r2_values) +{ + int ret = TSK_ERR_GENERIC; + tsk_site_t sA, sB; + double fA, fB, fAB, D; + int tracked_samples_set = 0; + tsk_tree_t *tA, *tB; + double n = (double) tsk_treeseq_get_num_samples(self->tree_sequence); + size_t j; + double nAB; + int64_t site_index; + + tA = self->outer_tree; + tB = self->inner_tree; + ret = tsk_treeseq_get_site(self->tree_sequence, source_index, &sA); + if (ret != 0) { + goto out; + } + if (sA.mutations_length > 1) { + ret = TSK_ERR_ONLY_INFINITE_SITES; + goto out; + } + fA = ((double) tA->num_samples[sA.mutations[0].node]) / n; + assert(fA > 0); + tB->mark = 1; + for (j = 0; j < max_sites; j++) { + site_index = ((int64_t) source_index) - ((int64_t) j) - 1; + if (site_index < 0) { + break; + } + ret = tsk_treeseq_get_site(self->tree_sequence, (size_t) site_index, &sB); + if (ret != 0) { + goto out; + } + if (sB.mutations_length > 1) { + ret = TSK_ERR_ONLY_INFINITE_SITES; + goto out; + } + if (sA.position - sB.position > max_distance) { + break; + } + while (sB.position < tB->left) { + ret = tsk_tree_prev(tB); + if (ret < 0) { + goto out; + } + assert(ret == 1); + } + fB = ((double) tB->num_samples[sB.mutations[0].node]) / n; + assert(fB > 0); + if (sB.position >= tA->left) { + nAB = tsk_ld_calc_overlap_within_tree(self, sA, sB); + } else { + if (!tracked_samples_set && tB->marked[sA.mutations[0].node] == 1) { + tracked_samples_set = 1; + ret = tsk_ld_calc_set_tracked_samples(self, sA); + if (ret != 0) { + goto out; + } + } + if (tracked_samples_set) { + nAB = (double) tB->num_tracked_samples[sB.mutations[0].node]; + } else { + nAB = tsk_ld_calc_overlap_within_tree(self, sA, 
sB); + } + } + fAB = nAB / n; + D = fAB - fA * fB; + r2[j] = D * D / (fA * fB * (1 - fA) * (1 - fB)); + } + + /* Now fast forward the inner iterator and unmark all nodes that + * were set to 1 as we moved back. */ + tB->mark = 0; + while (tB->index < tA->index) { + ret = tsk_tree_next(tB); + if (ret < 0) { + goto out; + } + assert(ret == 1); + } + *num_r2_values = j; + ret = 0; +out: + return ret; +} + +int TSK_WARN_UNUSED +tsk_ld_calc_get_r2_array(tsk_ld_calc_t *self, size_t a, int direction, + size_t max_sites, double max_distance, double *r2, + size_t *num_r2_values) +{ + int ret = TSK_ERR_GENERIC; + + if (a >= self->num_sites) { + ret = TSK_ERR_OUT_OF_BOUNDS; + goto out; + } + ret = tsk_ld_calc_position_trees(self, a); + if (ret != 0) { + goto out; + } + if (direction == TSK_DIR_FORWARD) { + ret = tsk_ld_calc_get_r2_array_forward(self, a, max_sites, max_distance, + r2, num_r2_values); + } else if (direction == TSK_DIR_REVERSE) { + ret = tsk_ld_calc_get_r2_array_reverse(self, a, max_sites, max_distance, + r2, num_r2_values); + } else { + ret = TSK_ERR_BAD_PARAM_VALUE; + } +out: + return ret; +} + +int TSK_WARN_UNUSED +tsk_ld_calc_get_r2(tsk_ld_calc_t *self, size_t a, size_t b, double *r2) +{ + int ret = TSK_ERR_GENERIC; + tsk_site_t sA, sB; + double fA, fB, fAB, D; + tsk_tree_t *tA, *tB; + double n = (double) tsk_treeseq_get_num_samples(self->tree_sequence); + double nAB; + size_t tmp; + + if (a >= self->num_sites || b >= self->num_sites) { + ret = TSK_ERR_OUT_OF_BOUNDS; + goto out; + } + if (a > b) { + tmp = a; + a = b; + b = tmp; + } + ret = tsk_ld_calc_position_trees(self, a); + if (ret != 0) { + goto out; + } + /* We can probably do a lot better than this implementation... 
*/ + tA = self->outer_tree; + tB = self->inner_tree; + ret = tsk_treeseq_get_site(self->tree_sequence, a, &sA); + if (ret != 0) { + goto out; + } + ret = tsk_treeseq_get_site(self->tree_sequence, b, &sB); + if (ret != 0) { + goto out; + } + if (sA.mutations_length > 1 || sB.mutations_length > 1) { + ret = TSK_ERR_ONLY_INFINITE_SITES; + goto out; + } + assert(sA.mutations_length == 1); + /* assert(tA->parent[sA.mutations[0].node] != TSK_NULL); */ + fA = ((double) tA->num_samples[sA.mutations[0].node]) / n; + assert(fA > 0); + ret = tsk_ld_calc_set_tracked_samples(self, sA); + if (ret != 0) { + goto out; + } + + while (sB.position >= tB->right) { + ret = tsk_tree_next(tB); + if (ret < 0) { + goto out; + } + assert(ret == 1); + } + /* assert(tB->parent[sB.mutations[0].node] != TSK_NULL); */ + fB = ((double) tB->num_samples[sB.mutations[0].node]) / n; + assert(fB > 0); + nAB = (double) tB->num_tracked_samples[sB.mutations[0].node]; + fAB = nAB / n; + D = fAB - fA * fB; + *r2 = D * D / (fA * fB * (1 - fA) * (1 - fB)); + + /* Now rewind the inner iterator back. 
*/ + while (tB->index > tA->index) { + ret = tsk_tree_prev(tB); + if (ret < 0) { + goto out; + } + assert(ret == 1); + } + ret = 0; +out: + return ret; +} diff --git a/c/tsk_stats.h b/c/tsk_stats.h new file mode 100644 index 0000000000..f16ddc9785 --- /dev/null +++ b/c/tsk_stats.h @@ -0,0 +1,30 @@ +#ifndef TSK_STATS_H +#define TSK_STATS_H + +#ifdef __cplusplus +extern "C" { +#endif + +#include "tsk_trees.h" + +typedef struct { + tsk_tree_t *outer_tree; + tsk_tree_t *inner_tree; + size_t num_sites; + int tree_changed; + tsk_treeseq_t *tree_sequence; +} tsk_ld_calc_t; + +int tsk_ld_calc_alloc(tsk_ld_calc_t *self, tsk_treeseq_t *tree_sequence); +int tsk_ld_calc_free(tsk_ld_calc_t *self); +void tsk_ld_calc_print_state(tsk_ld_calc_t *self, FILE *out); +int tsk_ld_calc_get_r2(tsk_ld_calc_t *self, size_t a, size_t b, double *r2); +int tsk_ld_calc_get_r2_array(tsk_ld_calc_t *self, size_t a, int direction, + size_t max_mutations, double max_distance, + double *r2, size_t *num_r2_values); + + +#ifdef __cplusplus +} +#endif +#endif diff --git a/c/tsk_tables.c b/c/tsk_tables.c new file mode 100644 index 0000000000..21ad331aad --- /dev/null +++ b/c/tsk_tables.c @@ -0,0 +1,6073 @@ +#include +#include +#include +#include +#include +#include +#include + +#include "tsk_tables.h" + +/* This is a flag for tsk_tbl_collection_alloc used by tsk_tbl_collection_load to + * avoid allocating the table columns. It's defined internally for now as it's + * not clear how this would be useful outside of tskit. 
*/ +#define TSK_NO_ALLOC_TABLES (1 << 30) + +#define DEFAULT_SIZE_INCREMENT 1024 +#define TABLE_SEP "-----------------------------------------\n" + +typedef struct { + const char *name; + void **array_dest; + tsk_tbl_size_t *len_dest; + tsk_tbl_size_t len_offset; + int type; +} read_table_col_t; + +typedef struct { + const char *name; + void *array; + tsk_tbl_size_t len; + int type; +} write_table_col_t; + + +static int +read_table_cols(kastore_t *store, read_table_col_t *read_cols, size_t num_cols) +{ + int ret = 0; + size_t len; + int type; + size_t j; + tsk_tbl_size_t last_len; + + /* Set all the size destinations to -1 so we can detect the first time we + * read it. Therefore, destinations that are supposed to have the same + * length will take the value of the first instance, and we check each + * subsequent value against this. */ + for (j = 0; j < num_cols; j++) { + *read_cols[j].len_dest = (tsk_tbl_size_t) -1; + } + for (j = 0; j < num_cols; j++) { + ret = kastore_gets(store, read_cols[j].name, read_cols[j].array_dest, + &len, &type); + if (ret != 0) { + ret = tsk_set_kas_error(ret); + goto out; + } + last_len = *read_cols[j].len_dest; + if (last_len == (tsk_tbl_size_t) -1) { + *read_cols[j].len_dest = (tsk_tbl_size_t) (len - read_cols[j].len_offset); + } else if ((last_len + read_cols[j].len_offset) != (tsk_tbl_size_t) len) { + ret = TSK_ERR_FILE_FORMAT; + goto out; + } + if (type != read_cols[j].type) { + ret = TSK_ERR_FILE_FORMAT; + goto out; + } + } +out: + return ret; +} + + +static int +write_table_cols(kastore_t *store, write_table_col_t *write_cols, size_t num_cols) +{ + int ret = 0; + size_t j; + + for (j = 0; j < num_cols; j++) { + ret = kastore_puts(store, write_cols[j].name, write_cols[j].array, + write_cols[j].len, write_cols[j].type, 0); + if (ret != 0) { + ret = tsk_set_kas_error(ret); + goto out; + } + } +out: + return ret; +} + +/* Checks that the specified list of offsets is well-formed. 
*/ +static int +check_offsets(size_t num_rows, tsk_tbl_size_t *offsets, + tsk_tbl_size_t length, bool check_length) +{ + int ret = TSK_ERR_BAD_OFFSET; + size_t j; + + if (offsets[0] != 0) { + goto out; + } + if (check_length && offsets[num_rows] != length) { + goto out; + } + for (j = 0; j < num_rows; j++) { + if (offsets[j] > offsets[j + 1]) { + goto out; + } + } + ret = 0; +out: + return ret; +} + +static int +expand_column(void **column, size_t new_max_rows, size_t element_size) +{ + int ret = 0; + void *tmp; + + tmp = realloc((void **) *column, new_max_rows * element_size); + if (tmp == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + *column = tmp; +out: + return ret; +} + +/************************* + * individual table + *************************/ + +static int +tsk_individual_tbl_expand_main_columns(tsk_individual_tbl_t *self, + tsk_tbl_size_t additional_rows) +{ + int ret = 0; + tsk_tbl_size_t increment = TSK_MAX(additional_rows, self->max_rows_increment); + tsk_tbl_size_t new_size = self->max_rows + increment; + + if ((self->num_rows + additional_rows) > self->max_rows) { + ret = expand_column((void **) &self->flags, new_size, sizeof(uint32_t)); + if (ret != 0) { + goto out; + } + ret = expand_column((void **) &self->location_offset, new_size + 1, + sizeof(tsk_tbl_size_t)); + if (ret != 0) { + goto out; + } + ret = expand_column((void **) &self->metadata_offset, new_size + 1, + sizeof(tsk_tbl_size_t)); + if (ret != 0) { + goto out; + } + self->max_rows = new_size; + } +out: + return ret; +} + +static int +tsk_individual_tbl_expand_location(tsk_individual_tbl_t *self, tsk_tbl_size_t additional_length) +{ + int ret = 0; + tsk_tbl_size_t increment = TSK_MAX(additional_length, + self->max_location_length_increment); + tsk_tbl_size_t new_size = self->max_location_length + increment; + + if ((self->location_length + additional_length) > self->max_location_length) { + ret = expand_column((void **) &self->location, new_size, sizeof(double)); + if (ret != 0) { 
+ goto out; + } + self->max_location_length = new_size; + } +out: + return ret; +} + +static int +tsk_individual_tbl_expand_metadata(tsk_individual_tbl_t *self, tsk_tbl_size_t additional_length) +{ + int ret = 0; + tsk_tbl_size_t increment = TSK_MAX(additional_length, + self->max_metadata_length_increment); + tsk_tbl_size_t new_size = self->max_metadata_length + increment; + + if ((self->metadata_length + additional_length) > self->max_metadata_length) { + ret = expand_column((void **) &self->metadata, new_size, sizeof(char)); + if (ret != 0) { + goto out; + } + self->max_metadata_length = new_size; + } +out: + return ret; +} + +int +tsk_individual_tbl_set_max_rows_increment(tsk_individual_tbl_t *self, size_t max_rows_increment) +{ + if (max_rows_increment == 0) { + max_rows_increment = DEFAULT_SIZE_INCREMENT; + } + self->max_rows_increment = (tsk_tbl_size_t) max_rows_increment; + return 0; +} + +int +tsk_individual_tbl_set_max_metadata_length_increment(tsk_individual_tbl_t *self, + size_t max_metadata_length_increment) +{ + if (max_metadata_length_increment == 0) { + max_metadata_length_increment = DEFAULT_SIZE_INCREMENT; + } + self->max_metadata_length_increment = (tsk_tbl_size_t) max_metadata_length_increment; + return 0; +} + +int +tsk_individual_tbl_set_max_location_length_increment(tsk_individual_tbl_t *self, + size_t max_location_length_increment) +{ + if (max_location_length_increment == 0) { + max_location_length_increment = DEFAULT_SIZE_INCREMENT; + } + self->max_location_length_increment = (tsk_tbl_size_t) max_location_length_increment; + return 0; +} + +int +tsk_individual_tbl_alloc(tsk_individual_tbl_t *self, int TSK_UNUSED(flags)) +{ + int ret = 0; + + memset(self, 0, sizeof(tsk_individual_tbl_t)); + /* Allocate space for one row initially, ensuring we always have valid pointers + * even if the table is empty */ + self->max_rows_increment = 1; + self->max_location_length_increment = 1; + self->max_metadata_length_increment = 1; + ret = 
tsk_individual_tbl_expand_main_columns(self, 1); + if (ret != 0) { + goto out; + } + ret = tsk_individual_tbl_expand_location(self, 1); + if (ret != 0) { + goto out; + } + self->location_offset[0] = 0; + ret = tsk_individual_tbl_expand_metadata(self, 1); + if (ret != 0) { + goto out; + } + self->metadata_offset[0] = 0; + self->max_rows_increment = DEFAULT_SIZE_INCREMENT; + self->max_location_length_increment = DEFAULT_SIZE_INCREMENT; + self->max_metadata_length_increment = DEFAULT_SIZE_INCREMENT; +out: + return ret; +} + +int TSK_WARN_UNUSED +tsk_individual_tbl_copy(tsk_individual_tbl_t *self, tsk_individual_tbl_t *dest) +{ + return tsk_individual_tbl_set_columns(dest, self->num_rows, self->flags, + self->location, self->location_offset, self->metadata, self->metadata_offset); +} + +int TSK_WARN_UNUSED +tsk_individual_tbl_set_columns(tsk_individual_tbl_t *self, size_t num_rows, uint32_t *flags, + double *location, uint32_t *location_offset, + const char *metadata, uint32_t *metadata_offset) +{ + int ret; + + ret = tsk_individual_tbl_clear(self); + if (ret != 0) { + goto out; + } + ret = tsk_individual_tbl_append_columns(self, num_rows, flags, location, location_offset, + metadata, metadata_offset); +out: + return ret; +} + +int +tsk_individual_tbl_append_columns(tsk_individual_tbl_t *self, size_t num_rows, uint32_t *flags, + double *location, uint32_t *location_offset, const char *metadata, uint32_t *metadata_offset) +{ + int ret; + tsk_tbl_size_t j, metadata_length, location_length; + + if (flags == NULL) { + ret = TSK_ERR_BAD_PARAM_VALUE; + goto out; + } + if ((location == NULL) != (location_offset == NULL)) { + ret = TSK_ERR_BAD_PARAM_VALUE; + goto out; + } + if ((metadata == NULL) != (metadata_offset == NULL)) { + ret = TSK_ERR_BAD_PARAM_VALUE; + goto out; + } + ret = tsk_individual_tbl_expand_main_columns(self, (tsk_tbl_size_t) num_rows); + if (ret != 0) { + goto out; + } + memcpy(self->flags + self->num_rows, flags, num_rows * sizeof(uint32_t)); + if 
(location == NULL) { + for (j = 0; j < num_rows; j++) { + self->location_offset[self->num_rows + j + 1] = (tsk_tbl_size_t) self->location_length; + } + } else { + ret = check_offsets(num_rows, location_offset, 0, false); + if (ret != 0) { + goto out; + } + for (j = 0; j < num_rows; j++) { + self->location_offset[self->num_rows + j] = + (tsk_tbl_size_t) self->location_length + location_offset[j]; + } + location_length = location_offset[num_rows]; + ret = tsk_individual_tbl_expand_location(self, location_length); + if (ret != 0) { + goto out; + } + memcpy(self->location + self->location_length, location, location_length * sizeof(double)); + self->location_length += location_length; + } + if (metadata == NULL) { + for (j = 0; j < num_rows; j++) { + self->metadata_offset[self->num_rows + j + 1] = (tsk_tbl_size_t) self->metadata_length; + } + } else { + ret = check_offsets(num_rows, metadata_offset, 0, false); + if (ret != 0) { + goto out; + } + for (j = 0; j < num_rows; j++) { + self->metadata_offset[self->num_rows + j] = + (tsk_tbl_size_t) self->metadata_length + metadata_offset[j]; + } + metadata_length = metadata_offset[num_rows]; + ret = tsk_individual_tbl_expand_metadata(self, metadata_length); + if (ret != 0) { + goto out; + } + memcpy(self->metadata + self->metadata_length, metadata, metadata_length * sizeof(char)); + self->metadata_length += metadata_length; + } + self->num_rows += (tsk_tbl_size_t) num_rows; + self->location_offset[self->num_rows] = self->location_length; + self->metadata_offset[self->num_rows] = self->metadata_length; +out: + return ret; +} + +static tsk_id_t +tsk_individual_tbl_add_row_internal(tsk_individual_tbl_t *self, uint32_t flags, double *location, + tsk_tbl_size_t location_length, const char *metadata, tsk_tbl_size_t metadata_length) +{ + assert(self->num_rows < self->max_rows); + assert(self->metadata_length + metadata_length <= self->max_metadata_length); + assert(self->location_length + location_length <= 
self->max_location_length); + self->flags[self->num_rows] = flags; + memcpy(self->location + self->location_length, location, location_length * sizeof(double)); + self->location_offset[self->num_rows + 1] = self->location_length + location_length; + self->location_length += location_length; + memcpy(self->metadata + self->metadata_length, metadata, metadata_length * sizeof(char)); + self->metadata_offset[self->num_rows + 1] = self->metadata_length + metadata_length; + self->metadata_length += metadata_length; + self->num_rows++; + return (tsk_id_t) self->num_rows - 1; +} + +tsk_id_t +tsk_individual_tbl_add_row(tsk_individual_tbl_t *self, uint32_t flags, double *location, + size_t location_length, const char *metadata, size_t metadata_length) +{ + int ret = 0; + + ret = tsk_individual_tbl_expand_main_columns(self, 1); + if (ret != 0) { + goto out; + } + ret = tsk_individual_tbl_expand_location(self, (tsk_tbl_size_t) location_length); + if (ret != 0) { + goto out; + } + ret = tsk_individual_tbl_expand_metadata(self, (tsk_tbl_size_t) metadata_length); + if (ret != 0) { + goto out; + } + ret = tsk_individual_tbl_add_row_internal(self, flags, location, + (tsk_tbl_size_t) location_length, metadata, (tsk_tbl_size_t) metadata_length); +out: + return ret; +} + +int +tsk_individual_tbl_clear(tsk_individual_tbl_t *self) +{ + return tsk_individual_tbl_truncate(self, 0); +} + +int +tsk_individual_tbl_truncate(tsk_individual_tbl_t *self, size_t num_rows) +{ + int ret = 0; + tsk_tbl_size_t n = (tsk_tbl_size_t) num_rows; + + if (n > self->num_rows) { + ret = TSK_ERR_BAD_TABLE_POSITION; + goto out; + } + self->num_rows = n; + self->location_length = self->location_offset[n]; + self->metadata_length = self->metadata_offset[n]; +out: + return ret; +} + +int +tsk_individual_tbl_free(tsk_individual_tbl_t *self) +{ + if (self->max_rows > 0) { + tsk_safe_free(self->flags); + tsk_safe_free(self->location); + tsk_safe_free(self->location_offset); + tsk_safe_free(self->metadata); + 
tsk_safe_free(self->metadata_offset); + } + return 0; +} + +void +tsk_individual_tbl_print_state(tsk_individual_tbl_t *self, FILE *out) +{ + size_t j, k; + + fprintf(out, TABLE_SEP); + fprintf(out, "tsk_individual_tbl: %p:\n", (void *) self); + fprintf(out, "num_rows = %d\tmax= %d\tincrement = %d)\n", + (int) self->num_rows, (int) self->max_rows, (int) self->max_rows_increment); + fprintf(out, "metadata_length = %d\tmax= %d\tincrement = %d)\n", + (int) self->metadata_length, + (int) self->max_metadata_length, + (int) self->max_metadata_length_increment); + fprintf(out, TABLE_SEP); + /* We duplicate the dump_text code here because we want to output + * the offset columns. */ + fprintf(out, "id\tflags\tlocation_offset\tlocation\t"); + fprintf(out, "metadata_offset\tmetadata\n"); + for (j = 0; j < self->num_rows; j++) { + fprintf(out, "%d\t%d\t", (int) j, self->flags[j]); + fprintf(out, "%d\t", self->location_offset[j]); + for (k = self->location_offset[j]; k < self->location_offset[j + 1]; k++) { + fprintf(out, "%f", self->location[k]); + if (k + 1 < self->location_offset[j + 1]) { + fprintf(out, ","); + } + } + fprintf(out, "\t"); + fprintf(out, "%d\t", self->metadata_offset[j]); + for (k = self->metadata_offset[j]; k < self->metadata_offset[j + 1]; k++) { + fprintf(out, "%c", self->metadata[k]); + } + fprintf(out, "\n"); + } +} + +int +tsk_individual_tbl_get_row(tsk_individual_tbl_t *self, size_t index, + tsk_individual_t *row) +{ + int ret = 0; + + if (index >= self->num_rows) { + ret = TSK_ERR_INDIVIDUAL_OUT_OF_BOUNDS; + goto out; + } + row->id = (tsk_id_t) index; + row->flags = self->flags[index]; + row->location_length = self->location_offset[index + 1] + - self->location_offset[index]; + row->location = self->location + self->location_offset[index]; + row->metadata_length = self->metadata_offset[index + 1] + - self->metadata_offset[index]; + row->metadata = self->metadata + self->metadata_offset[index]; + /* Also have referencing individuals here. 
Should this be a different struct? + * See also site. */ + row->nodes_length = 0; + row->nodes = NULL; +out: + return ret; +} + +int +tsk_individual_tbl_dump_text(tsk_individual_tbl_t *self, FILE *out) +{ + int ret = TSK_ERR_IO; + size_t j, k; + tsk_tbl_size_t metadata_len; + int err; + + err = fprintf(out, "id\tflags\tlocation\tmetadata\n"); + if (err < 0) { + goto out; + } + for (j = 0; j < self->num_rows; j++) { + metadata_len = self->metadata_offset[j + 1] - self->metadata_offset[j]; + err = fprintf(out, "%d\t%d\t", (int) j, (int) self->flags[j]); + if (err < 0) { + goto out; + } + for (k = self->location_offset[j]; k < self->location_offset[j + 1]; k++) { + err = fprintf(out, "%.*g", TSK_DBL_DECIMAL_DIG, self->location[k]); + if (err < 0) { + goto out; + } + if (k + 1 < self->location_offset[j + 1]) { + err = fprintf(out, ","); + if (err < 0) { + goto out; + } + } + } + err = fprintf(out, "\t%.*s\n", + metadata_len, self->metadata + self->metadata_offset[j]); + if (err < 0) { + goto out; + } + } + ret = 0; +out: + return ret; +} + +bool +tsk_individual_tbl_equals(tsk_individual_tbl_t *self, tsk_individual_tbl_t *other) +{ + bool ret = false; + if (self->num_rows == other->num_rows + && self->metadata_length == other->metadata_length) { + ret = memcmp(self->flags, other->flags, + self->num_rows * sizeof(uint32_t)) == 0 + && memcmp(self->location_offset, other->location_offset, + (self->num_rows + 1) * sizeof(tsk_tbl_size_t)) == 0 + && memcmp(self->location, other->location, + self->location_length * sizeof(double)) == 0 + && memcmp(self->metadata_offset, other->metadata_offset, + (self->num_rows + 1) * sizeof(tsk_tbl_size_t)) == 0 + && memcmp(self->metadata, other->metadata, + self->metadata_length * sizeof(char)) == 0; + } + return ret; +} + +static int +tsk_individual_tbl_dump(tsk_individual_tbl_t *self, kastore_t *store) +{ + write_table_col_t write_cols[] = { + {"individuals/flags", (void *) self->flags, self->num_rows, KAS_UINT32}, + 
{"individuals/location", (void *) self->location, self->location_length, KAS_FLOAT64}, + {"individuals/location_offset", (void *) self->location_offset, self->num_rows + 1, + KAS_UINT32}, + {"individuals/metadata", (void *) self->metadata, self->metadata_length, KAS_UINT8}, + {"individuals/metadata_offset", (void *) self->metadata_offset, self->num_rows + 1, + KAS_UINT32}, + }; + return write_table_cols(store, write_cols, sizeof(write_cols) / sizeof(*write_cols)); +} + +static int +tsk_individual_tbl_load(tsk_individual_tbl_t *self, kastore_t *store) +{ + read_table_col_t read_cols[] = { + {"individuals/flags", (void **) &self->flags, &self->num_rows, 0, KAS_UINT32}, + {"individuals/location", (void **) &self->location, &self->location_length, 0, + KAS_FLOAT64}, + {"individuals/location_offset", (void **) &self->location_offset, &self->num_rows, + 1, KAS_UINT32}, + {"individuals/metadata", (void **) &self->metadata, &self->metadata_length, 0, + KAS_UINT8}, + {"individuals/metadata_offset", (void **) &self->metadata_offset, &self->num_rows, + 1, KAS_UINT32}, + }; + return read_table_cols(store, read_cols, sizeof(read_cols) / sizeof(*read_cols)); +} + +/************************* + * node table + *************************/ + +static int +tsk_node_tbl_expand_main_columns(tsk_node_tbl_t *self, tsk_tbl_size_t additional_rows) +{ + int ret = 0; + tsk_tbl_size_t increment = TSK_MAX(additional_rows, self->max_rows_increment); + tsk_tbl_size_t new_size = self->max_rows + increment; + + if ((self->num_rows + additional_rows) > self->max_rows) { + ret = expand_column((void **) &self->flags, new_size, sizeof(uint32_t)); + if (ret != 0) { + goto out; + } + ret = expand_column((void **) &self->time, new_size, sizeof(double)); + if (ret != 0) { + goto out; + } + ret = expand_column((void **) &self->population, new_size, sizeof(tsk_id_t)); + if (ret != 0) { + goto out; + } + ret = expand_column((void **) &self->individual, new_size, sizeof(tsk_id_t)); + if (ret != 0) { + goto out; 
+ } + ret = expand_column((void **) &self->metadata_offset, new_size + 1, + sizeof(tsk_tbl_size_t)); + if (ret != 0) { + goto out; + } + self->max_rows = new_size; + } +out: + return ret; +} + +static int +tsk_node_tbl_expand_metadata(tsk_node_tbl_t *self, tsk_tbl_size_t additional_length) +{ + int ret = 0; + tsk_tbl_size_t increment = TSK_MAX(additional_length, + self->max_metadata_length_increment); + tsk_tbl_size_t new_size = self->max_metadata_length + increment; + + if ((self->metadata_length + additional_length) > self->max_metadata_length) { + ret = expand_column((void **) &self->metadata, new_size, sizeof(char)); + if (ret != 0) { + goto out; + } + self->max_metadata_length = new_size; + } +out: + return ret; +} + +int +tsk_node_tbl_set_max_rows_increment(tsk_node_tbl_t *self, size_t max_rows_increment) +{ + if (max_rows_increment == 0) { + max_rows_increment = DEFAULT_SIZE_INCREMENT; + } + self->max_rows_increment = (tsk_tbl_size_t) max_rows_increment; + return 0; +} + +int +tsk_node_tbl_set_max_metadata_length_increment(tsk_node_tbl_t *self, + size_t max_metadata_length_increment) +{ + if (max_metadata_length_increment == 0) { + max_metadata_length_increment = DEFAULT_SIZE_INCREMENT; + } + self->max_metadata_length_increment = (tsk_tbl_size_t) max_metadata_length_increment; + return 0; +} + +int +tsk_node_tbl_alloc(tsk_node_tbl_t *self, int TSK_UNUSED(flags)) +{ + int ret = 0; + + memset(self, 0, sizeof(tsk_node_tbl_t)); + /* Allocate space for one row initially, ensuring we always have valid pointers + * even if the table is empty */ + self->max_rows_increment = 1; + self->max_metadata_length_increment = 1; + ret = tsk_node_tbl_expand_main_columns(self, 1); + if (ret != 0) { + goto out; + } + ret = tsk_node_tbl_expand_metadata(self, 1); + if (ret != 0) { + goto out; + } + self->metadata_offset[0] = 0; + self->max_rows_increment = DEFAULT_SIZE_INCREMENT; + self->max_metadata_length_increment = DEFAULT_SIZE_INCREMENT; +out: + return ret; +} + +int 
TSK_WARN_UNUSED +tsk_node_tbl_copy(tsk_node_tbl_t *self, tsk_node_tbl_t *dest) +{ + return tsk_node_tbl_set_columns(dest, self->num_rows, self->flags, + self->time, self->population, self->individual, + self->metadata, self->metadata_offset); +} + +int TSK_WARN_UNUSED +tsk_node_tbl_set_columns(tsk_node_tbl_t *self, size_t num_rows, uint32_t *flags, double *time, + tsk_id_t *population, tsk_id_t *individual, const char *metadata, + uint32_t *metadata_offset) +{ + int ret; + + ret = tsk_node_tbl_clear(self); + if (ret != 0) { + goto out; + } + ret = tsk_node_tbl_append_columns(self, num_rows, flags, time, population, individual, + metadata, metadata_offset); +out: + return ret; +} + +int +tsk_node_tbl_append_columns(tsk_node_tbl_t *self, size_t num_rows, uint32_t *flags, double *time, + tsk_id_t *population, tsk_id_t *individual, const char *metadata, + uint32_t *metadata_offset) +{ + int ret; + tsk_tbl_size_t j, metadata_length; + + if (flags == NULL || time == NULL) { + ret = TSK_ERR_BAD_PARAM_VALUE; + goto out; + } + if ((metadata == NULL) != (metadata_offset == NULL)) { + ret = TSK_ERR_BAD_PARAM_VALUE; + goto out; + } + ret = tsk_node_tbl_expand_main_columns(self, (tsk_tbl_size_t) num_rows); + if (ret != 0) { + goto out; + } + memcpy(self->time + self->num_rows, time, num_rows * sizeof(double)); + memcpy(self->flags + self->num_rows, flags, num_rows * sizeof(uint32_t)); + if (metadata == NULL) { + for (j = 0; j < num_rows; j++) { + self->metadata_offset[self->num_rows + j + 1] = (tsk_tbl_size_t) self->metadata_length; + } + } else { + ret = check_offsets(num_rows, metadata_offset, 0, false); + if (ret != 0) { + goto out; + } + for (j = 0; j < num_rows; j++) { + self->metadata_offset[self->num_rows + j] = + (tsk_tbl_size_t) self->metadata_length + metadata_offset[j]; + } + metadata_length = metadata_offset[num_rows]; + ret = tsk_node_tbl_expand_metadata(self, metadata_length); + if (ret != 0) { + goto out; + } + memcpy(self->metadata + self->metadata_length, 
metadata, metadata_length * sizeof(char)); + self->metadata_length += metadata_length; + } + if (population == NULL) { + /* Set population to NULL_POPULATION (-1) if not specified */ + memset(self->population + self->num_rows, 0xff, + num_rows * sizeof(tsk_id_t)); + } else { + memcpy(self->population + self->num_rows, population, + num_rows * sizeof(tsk_id_t)); + } + if (individual == NULL) { + /* Set individual to NULL_INDIVIDUAL (-1) if not specified */ + memset(self->individual + self->num_rows, 0xff, + num_rows * sizeof(tsk_id_t)); + } else { + memcpy(self->individual + self->num_rows, individual, + num_rows * sizeof(tsk_id_t)); + } + self->num_rows += (tsk_tbl_size_t) num_rows; + self->metadata_offset[self->num_rows] = self->metadata_length; +out: + return ret; +} + +static tsk_id_t +tsk_node_tbl_add_row_internal(tsk_node_tbl_t *self, uint32_t flags, double time, + tsk_id_t population, tsk_id_t individual, + const char *metadata, tsk_tbl_size_t metadata_length) +{ + assert(self->num_rows < self->max_rows); + assert(self->metadata_length + metadata_length <= self->max_metadata_length); + memcpy(self->metadata + self->metadata_length, metadata, metadata_length); + self->flags[self->num_rows] = flags; + self->time[self->num_rows] = time; + self->population[self->num_rows] = population; + self->individual[self->num_rows] = individual; + self->metadata_offset[self->num_rows + 1] = self->metadata_length + metadata_length; + self->metadata_length += metadata_length; + self->num_rows++; + return (tsk_id_t) self->num_rows - 1; +} + +tsk_id_t +tsk_node_tbl_add_row(tsk_node_tbl_t *self, uint32_t flags, double time, + tsk_id_t population, tsk_id_t individual, + const char *metadata, size_t metadata_length) +{ + int ret = 0; + + ret = tsk_node_tbl_expand_main_columns(self, 1); + if (ret != 0) { + goto out; + } + ret = tsk_node_tbl_expand_metadata(self, (tsk_tbl_size_t) metadata_length); + if (ret != 0) { + goto out; + } + ret = tsk_node_tbl_add_row_internal(self, flags, 
time, population, individual, + metadata, (tsk_tbl_size_t) metadata_length); +out: + return ret; +} + +int TSK_WARN_UNUSED +tsk_node_tbl_clear(tsk_node_tbl_t *self) +{ + return tsk_node_tbl_truncate(self, 0); +} + +int +tsk_node_tbl_truncate(tsk_node_tbl_t *self, size_t num_rows) +{ + int ret = 0; + tsk_tbl_size_t n = (tsk_tbl_size_t) num_rows; + + if (n > self->num_rows) { + ret = TSK_ERR_BAD_TABLE_POSITION; + goto out; + } + self->num_rows = n; + self->metadata_length = self->metadata_offset[n]; +out: + return ret; +} + +int +tsk_node_tbl_free(tsk_node_tbl_t *self) +{ + if (self->max_rows > 0) { + tsk_safe_free(self->flags); + tsk_safe_free(self->time); + tsk_safe_free(self->population); + tsk_safe_free(self->individual); + tsk_safe_free(self->metadata); + tsk_safe_free(self->metadata_offset); + } + return 0; +} + +void +tsk_node_tbl_print_state(tsk_node_tbl_t *self, FILE *out) +{ + size_t j, k; + + fprintf(out, TABLE_SEP); + fprintf(out, "tsk_node_tbl: %p:\n", (void *) self); + fprintf(out, "num_rows = %d\tmax= %d\tincrement = %d)\n", + (int) self->num_rows, (int) self->max_rows, (int) self->max_rows_increment); + fprintf(out, "metadata_length = %d\tmax= %d\tincrement = %d)\n", + (int) self->metadata_length, + (int) self->max_metadata_length, + (int) self->max_metadata_length_increment); + fprintf(out, TABLE_SEP); + /* We duplicate the dump_text code here for simplicity because we want to output + * the flags column directly. 
*/ + fprintf(out, "id\tflags\ttime\tpopulation\tindividual\tmetadata_offset\tmetadata\n"); + for (j = 0; j < self->num_rows; j++) { + fprintf(out, "%d\t%d\t%f\t%d\t%d\t%d\t", (int) j, self->flags[j], self->time[j], + (int) self->population[j], self->individual[j], self->metadata_offset[j]); + for (k = self->metadata_offset[j]; k < self->metadata_offset[j + 1]; k++) { + fprintf(out, "%c", self->metadata[k]); + } + fprintf(out, "\n"); + } + assert(self->metadata_offset[0] == 0); + assert(self->metadata_offset[self->num_rows] == self->metadata_length); +} + +int +tsk_node_tbl_dump_text(tsk_node_tbl_t *self, FILE *out) +{ + int ret = TSK_ERR_IO; + size_t j; + tsk_tbl_size_t metadata_len; + int err; + + err = fprintf(out, "id\tis_sample\ttime\tpopulation\tindividual\tmetadata\n"); + if (err < 0) { + goto out; + } + for (j = 0; j < self->num_rows; j++) { + metadata_len = self->metadata_offset[j + 1] - self->metadata_offset[j]; + err = fprintf(out, "%d\t%d\t%f\t%d\t%d\t%.*s\n", (int) j, + (int) (self->flags[j] & TSK_NODE_IS_SAMPLE), + self->time[j], self->population[j], self->individual[j], + metadata_len, self->metadata + self->metadata_offset[j]); + if (err < 0) { + goto out; + } + } + ret = 0; +out: + return ret; +} + +bool +tsk_node_tbl_equals(tsk_node_tbl_t *self, tsk_node_tbl_t *other) +{ + bool ret = false; + if (self->num_rows == other->num_rows + && self->metadata_length == other->metadata_length) { + ret = memcmp(self->time, other->time, + self->num_rows * sizeof(double)) == 0 + && memcmp(self->flags, other->flags, + self->num_rows * sizeof(uint32_t)) == 0 + && memcmp(self->population, other->population, + self->num_rows * sizeof(tsk_id_t)) == 0 + && memcmp(self->individual, other->individual, + self->num_rows * sizeof(tsk_id_t)) == 0 + && memcmp(self->metadata_offset, other->metadata_offset, + (self->num_rows + 1) * sizeof(tsk_tbl_size_t)) == 0 + && memcmp(self->metadata, other->metadata, + self->metadata_length * sizeof(char)) == 0; + } + return ret; +} + +int 
+tsk_node_tbl_get_row(tsk_node_tbl_t *self, size_t index, tsk_node_t *row) +{ + int ret = 0; + if (index >= self->num_rows) { + ret = TSK_ERR_NODE_OUT_OF_BOUNDS; + goto out; + } + row->id = (tsk_id_t) index; + row->flags = self->flags[index]; + row->time = self->time[index]; + row->population = self->population[index]; + row->individual = self->individual[index]; + row->metadata_length = self->metadata_offset[index + 1] + - self->metadata_offset[index]; + row->metadata = self->metadata + self->metadata_offset[index]; +out: + return ret; +} + +static int +tsk_node_tbl_dump(tsk_node_tbl_t *self, kastore_t *store) +{ + write_table_col_t write_cols[] = { + {"nodes/time", (void *) self->time, self->num_rows, KAS_FLOAT64}, + {"nodes/flags", (void *) self->flags, self->num_rows, KAS_UINT32}, + {"nodes/population", (void *) self->population, self->num_rows, KAS_INT32}, + {"nodes/individual", (void *) self->individual, self->num_rows, KAS_INT32}, + {"nodes/metadata", (void *) self->metadata, self->metadata_length, KAS_UINT8}, + {"nodes/metadata_offset", (void *) self->metadata_offset, self->num_rows + 1, + KAS_UINT32}, + }; + return write_table_cols(store, write_cols, sizeof(write_cols) / sizeof(*write_cols)); +} + +static int +tsk_node_tbl_load(tsk_node_tbl_t *self, kastore_t *store) +{ + read_table_col_t read_cols[] = { + {"nodes/time", (void **) &self->time, &self->num_rows, 0, KAS_FLOAT64}, + {"nodes/flags", (void **) &self->flags, &self->num_rows, 0, KAS_UINT32}, + {"nodes/population", (void **) &self->population, &self->num_rows, 0, + KAS_INT32}, + {"nodes/individual", (void **) &self->individual, &self->num_rows, 0, + KAS_INT32}, + {"nodes/metadata", (void **) &self->metadata, &self->metadata_length, 0, + KAS_UINT8}, + {"nodes/metadata_offset", (void **) &self->metadata_offset, &self->num_rows, + 1, KAS_UINT32}, + }; + return read_table_cols(store, read_cols, sizeof(read_cols) / sizeof(*read_cols)); +} + +/************************* + * edge table + 
*************************/ + +static int +tsk_edge_tbl_expand_columns(tsk_edge_tbl_t *self, size_t additional_rows) +{ + int ret = 0; + tsk_tbl_size_t increment = TSK_MAX( + (tsk_tbl_size_t) additional_rows, self->max_rows_increment); + tsk_tbl_size_t new_size = self->max_rows + increment; + + if ((self->num_rows + additional_rows) > self->max_rows) { + ret = expand_column((void **) &self->left, new_size, sizeof(double)); + if (ret != 0) { + goto out; + } + ret = expand_column((void **) &self->right, new_size, sizeof(double)); + if (ret != 0) { + goto out; + } + ret = expand_column((void **) &self->parent, new_size, sizeof(tsk_id_t)); + if (ret != 0) { + goto out; + } + ret = expand_column((void **) &self->child, new_size, sizeof(tsk_id_t)); + if (ret != 0) { + goto out; + } + self->max_rows = new_size; + } +out: + return ret; +} + +int +tsk_edge_tbl_set_max_rows_increment(tsk_edge_tbl_t *self, size_t max_rows_increment) +{ + if (max_rows_increment == 0) { + max_rows_increment = DEFAULT_SIZE_INCREMENT; + } + self->max_rows_increment = (tsk_tbl_size_t) max_rows_increment; + return 0; +} + +int +tsk_edge_tbl_alloc(tsk_edge_tbl_t *self, int TSK_UNUSED(flags)) +{ + int ret = 0; + + memset(self, 0, sizeof(tsk_edge_tbl_t)); + + /* Allocate space for one row initially, ensuring we always have valid pointers + * even if the table is empty */ + self->max_rows_increment = 1; + ret = tsk_edge_tbl_expand_columns(self, 1); + if (ret != 0) { + goto out; + } + self->max_rows_increment = DEFAULT_SIZE_INCREMENT; +out: + return ret; +} + +tsk_id_t +tsk_edge_tbl_add_row(tsk_edge_tbl_t *self, double left, double right, tsk_id_t parent, + tsk_id_t child) +{ + int ret = 0; + + ret = tsk_edge_tbl_expand_columns(self, 1); + if (ret != 0) { + goto out; + } + self->left[self->num_rows] = left; + self->right[self->num_rows] = right; + self->parent[self->num_rows] = parent; + self->child[self->num_rows] = child; + ret = (tsk_id_t) self->num_rows; + self->num_rows++; +out: + return ret; +} + 
+int TSK_WARN_UNUSED +tsk_edge_tbl_copy(tsk_edge_tbl_t *self, tsk_edge_tbl_t *dest) +{ + return tsk_edge_tbl_set_columns(dest, self->num_rows, self->left, self->right, + self->parent, self->child); +} + +int +tsk_edge_tbl_set_columns(tsk_edge_tbl_t *self, + size_t num_rows, double *left, double *right, tsk_id_t *parent, tsk_id_t *child) +{ + int ret = 0; + + ret = tsk_edge_tbl_clear(self); + if (ret != 0) { + goto out; + } + ret = tsk_edge_tbl_append_columns(self, num_rows, left, right, parent, child); +out: + return ret; +} + +int +tsk_edge_tbl_append_columns(tsk_edge_tbl_t *self, + size_t num_rows, double *left, double *right, tsk_id_t *parent, tsk_id_t *child) +{ + int ret; + + if (left == NULL || right == NULL || parent == NULL || child == NULL) { + ret = TSK_ERR_BAD_PARAM_VALUE; + goto out; + } + ret = tsk_edge_tbl_expand_columns(self, num_rows); + if (ret != 0) { + goto out; + } + memcpy(self->left + self->num_rows, left, num_rows * sizeof(double)); + memcpy(self->right + self->num_rows, right, num_rows * sizeof(double)); + memcpy(self->parent + self->num_rows, parent, num_rows * sizeof(tsk_id_t)); + memcpy(self->child + self->num_rows, child, num_rows * sizeof(tsk_id_t)); + self->num_rows += (tsk_tbl_size_t) num_rows; +out: + return ret; +} + +int +tsk_edge_tbl_clear(tsk_edge_tbl_t *self) +{ + return tsk_edge_tbl_truncate(self, 0); +} + +int +tsk_edge_tbl_truncate(tsk_edge_tbl_t *self, size_t num_rows) +{ + int ret = 0; + tsk_tbl_size_t n = (tsk_tbl_size_t) num_rows; + + if (n > self->num_rows) { + ret = TSK_ERR_BAD_TABLE_POSITION; + goto out; + } + self->num_rows = n; +out: + return ret; +} + +int +tsk_edge_tbl_free(tsk_edge_tbl_t *self) +{ + if (self->max_rows > 0) { + tsk_safe_free(self->left); + tsk_safe_free(self->right); + tsk_safe_free(self->parent); + tsk_safe_free(self->child); + } + return 0; +} + +void +tsk_edge_tbl_print_state(tsk_edge_tbl_t *self, FILE *out) +{ + int ret; + + fprintf(out, TABLE_SEP); + fprintf(out, "edge_table: %p:\n", (void *) 
self); + fprintf(out, "num_rows = %d\tmax= %d\tincrement = %d)\n", + (int) self->num_rows, (int) self->max_rows, (int) self->max_rows_increment); + fprintf(out, TABLE_SEP); + ret = tsk_edge_tbl_dump_text(self, out); + assert(ret == 0); +} + +int +tsk_edge_tbl_dump_text(tsk_edge_tbl_t *self, FILE *out) +{ + size_t j; + int ret = TSK_ERR_IO; + int err; + + err = fprintf(out, "left\tright\tparent\tchild\n"); + if (err < 0) { + goto out; + } + for (j = 0; j < self->num_rows; j++) { + err = fprintf(out, "%.3f\t%.3f\t%d\t%d\n", self->left[j], self->right[j], + self->parent[j], self->child[j]); + if (err < 0) { + goto out; + } + } + ret = 0; +out: + return ret; +} + +bool +tsk_edge_tbl_equals(tsk_edge_tbl_t *self, tsk_edge_tbl_t *other) +{ + bool ret = false; + if (self->num_rows == other->num_rows) { + ret = memcmp(self->left, other->left, + self->num_rows * sizeof(double)) == 0 + && memcmp(self->right, other->right, + self->num_rows * sizeof(double)) == 0 + && memcmp(self->parent, other->parent, + self->num_rows * sizeof(tsk_id_t)) == 0 + && memcmp(self->child, other->child, + self->num_rows * sizeof(tsk_id_t)) == 0; + } + return ret; +} + +int +tsk_edge_tbl_get_row(tsk_edge_tbl_t *self, size_t index, tsk_edge_t *row) +{ + int ret = 0; + if (index >= self->num_rows) { + ret = TSK_ERR_EDGE_OUT_OF_BOUNDS; + goto out; + } + row->id = (tsk_id_t) index; + row->left = self->left[index]; + row->right = self->right[index]; + row->parent = self->parent[index]; + row->child = self->child[index]; +out: + return ret; +} + +static int +tsk_edge_tbl_dump(tsk_edge_tbl_t *self, kastore_t *store) +{ + write_table_col_t write_cols[] = { + {"edges/left", (void *) self->left, self->num_rows, KAS_FLOAT64}, + {"edges/right", (void *) self->right, self->num_rows, KAS_FLOAT64}, + {"edges/parent", (void *) self->parent, self->num_rows, KAS_INT32}, + {"edges/child", (void *) self->child, self->num_rows, KAS_INT32}, + }; + return write_table_cols(store, write_cols, sizeof(write_cols) / 
sizeof(*write_cols)); +} + +static int +tsk_edge_tbl_load(tsk_edge_tbl_t *self, kastore_t *store) +{ + read_table_col_t read_cols[] = { + {"edges/left", (void **) &self->left, &self->num_rows, 0, KAS_FLOAT64}, + {"edges/right", (void **) &self->right, &self->num_rows, 0, KAS_FLOAT64}, + {"edges/parent", (void **) &self->parent, &self->num_rows, 0, KAS_INT32}, + {"edges/child", (void **) &self->child, &self->num_rows, 0, KAS_INT32}, + }; + return read_table_cols(store, read_cols, sizeof(read_cols) / sizeof(*read_cols)); +} + +/************************* + * site table + *************************/ + +static int +tsk_site_tbl_expand_main_columns(tsk_site_tbl_t *self, size_t additional_rows) +{ + int ret = 0; + tsk_tbl_size_t increment = (tsk_tbl_size_t) TSK_MAX(additional_rows, self->max_rows_increment); + tsk_tbl_size_t new_size = self->max_rows + increment; + + if ((self->num_rows + additional_rows) > self->max_rows) { + ret = expand_column((void **) &self->position, new_size, sizeof(double)); + if (ret != 0) { + goto out; + } + ret = expand_column((void **) &self->ancestral_state_offset, new_size + 1, + sizeof(uint32_t)); + if (ret != 0) { + goto out; + } + ret = expand_column((void **) &self->metadata_offset, new_size + 1, + sizeof(uint32_t)); + if (ret != 0) { + goto out; + } + self->max_rows = new_size; + } +out: + return ret; +} + +static int +tsk_site_tbl_expand_ancestral_state(tsk_site_tbl_t *self, size_t additional_length) +{ + int ret = 0; + tsk_tbl_size_t increment = (tsk_tbl_size_t) TSK_MAX(additional_length, + self->max_ancestral_state_length_increment); + tsk_tbl_size_t new_size = self->max_ancestral_state_length + increment; + + if ((self->ancestral_state_length + additional_length) + > self->max_ancestral_state_length) { + ret = expand_column((void **) &self->ancestral_state, new_size, sizeof(char)); + if (ret != 0) { + goto out; + } + self->max_ancestral_state_length = new_size; + } +out: + return ret; +} + +static int 
+tsk_site_tbl_expand_metadata(tsk_site_tbl_t *self, size_t additional_length) +{ + int ret = 0; + tsk_tbl_size_t increment = (tsk_tbl_size_t) TSK_MAX(additional_length, + self->max_metadata_length_increment); + tsk_tbl_size_t new_size = self->max_metadata_length + increment; + + if ((self->metadata_length + additional_length) + > self->max_metadata_length) { + ret = expand_column((void **) &self->metadata, new_size, sizeof(char)); + if (ret != 0) { + goto out; + } + self->max_metadata_length = new_size; + } +out: + return ret; +} + +int +tsk_site_tbl_set_max_rows_increment(tsk_site_tbl_t *self, size_t max_rows_increment) +{ + if (max_rows_increment == 0) { + max_rows_increment = DEFAULT_SIZE_INCREMENT; + } + self->max_rows_increment = (tsk_tbl_size_t) max_rows_increment; + return 0; +} + +int +tsk_site_tbl_set_max_metadata_length_increment(tsk_site_tbl_t *self, + size_t max_metadata_length_increment) +{ + if (max_metadata_length_increment == 0) { + max_metadata_length_increment = DEFAULT_SIZE_INCREMENT; + } + self->max_metadata_length_increment = (tsk_tbl_size_t) max_metadata_length_increment; + return 0; +} + +int +tsk_site_tbl_set_max_ancestral_state_length_increment(tsk_site_tbl_t *self, + size_t max_ancestral_state_length_increment) +{ + if (max_ancestral_state_length_increment == 0) { + max_ancestral_state_length_increment = DEFAULT_SIZE_INCREMENT; + } + self->max_ancestral_state_length_increment = + (tsk_tbl_size_t) max_ancestral_state_length_increment; + return 0; +} + +int +tsk_site_tbl_alloc(tsk_site_tbl_t *self, int TSK_UNUSED(flags)) +{ + int ret = 0; + + memset(self, 0, sizeof(tsk_site_tbl_t)); + + /* Allocate space for one row initially, ensuring we always have valid pointers + * even if the table is empty */ + self->max_rows_increment = 1; + self->max_ancestral_state_length_increment = 1; + self->max_metadata_length_increment = 1; + ret = tsk_site_tbl_expand_main_columns(self, 1); + if (ret != 0) { + goto out; + } + ret = 
tsk_site_tbl_expand_ancestral_state(self, 1); + if (ret != 0) { + goto out; + } + ret = tsk_site_tbl_expand_metadata(self, 1); + if (ret != 0) { + goto out; + } + self->ancestral_state_offset[0] = 0; + self->metadata_offset[0] = 0; + self->max_rows_increment = DEFAULT_SIZE_INCREMENT; + self->max_ancestral_state_length_increment = DEFAULT_SIZE_INCREMENT; + self->max_metadata_length_increment = DEFAULT_SIZE_INCREMENT; +out: + return ret; +} + +tsk_id_t +tsk_site_tbl_add_row(tsk_site_tbl_t *self, double position, + const char *ancestral_state, tsk_tbl_size_t ancestral_state_length, + const char *metadata, tsk_tbl_size_t metadata_length) +{ + int ret = 0; + tsk_tbl_size_t ancestral_state_offset, metadata_offset; + + ret = tsk_site_tbl_expand_main_columns(self, 1); + if (ret != 0) { + goto out; + } + self->position[self->num_rows] = position; + + ancestral_state_offset = (tsk_tbl_size_t) self->ancestral_state_length; + assert(self->ancestral_state_offset[self->num_rows] == ancestral_state_offset); + ret = tsk_site_tbl_expand_ancestral_state(self, ancestral_state_length); + if (ret != 0) { + goto out; + } + self->ancestral_state_length += ancestral_state_length; + memcpy(self->ancestral_state + ancestral_state_offset, ancestral_state, + ancestral_state_length); + self->ancestral_state_offset[self->num_rows + 1] = self->ancestral_state_length; + + metadata_offset = (tsk_tbl_size_t) self->metadata_length; + assert(self->metadata_offset[self->num_rows] == metadata_offset); + ret = tsk_site_tbl_expand_metadata(self, metadata_length); + if (ret != 0) { + goto out; + } + self->metadata_length += metadata_length; + memcpy(self->metadata + metadata_offset, metadata, metadata_length); + self->metadata_offset[self->num_rows + 1] = self->metadata_length; + + ret = (tsk_id_t) self->num_rows; + self->num_rows++; +out: + return ret; +} + +int +tsk_site_tbl_append_columns(tsk_site_tbl_t *self, size_t num_rows, double *position, + const char *ancestral_state, tsk_tbl_size_t 
*ancestral_state_offset, + const char *metadata, tsk_tbl_size_t *metadata_offset) +{ + int ret = 0; + tsk_tbl_size_t j, ancestral_state_length, metadata_length; + + if (position == NULL || ancestral_state == NULL || ancestral_state_offset == NULL) { + ret = TSK_ERR_BAD_PARAM_VALUE; + goto out; + } + if ((metadata == NULL) != (metadata_offset == NULL)) { + ret = TSK_ERR_BAD_PARAM_VALUE; + goto out; + } + + ret = tsk_site_tbl_expand_main_columns(self, num_rows); + if (ret != 0) { + goto out; + } + memcpy(self->position + self->num_rows, position, num_rows * sizeof(double)); + + /* Metadata column */ + if (metadata == NULL) { + for (j = 0; j < num_rows; j++) { + self->metadata_offset[self->num_rows + j + 1] = (tsk_tbl_size_t) self->metadata_length; + } + } else { + ret = check_offsets(num_rows, metadata_offset, 0, false); + if (ret != 0) { + goto out; + } + metadata_length = metadata_offset[num_rows]; + ret = tsk_site_tbl_expand_metadata(self, metadata_length); + if (ret != 0) { + goto out; + } + memcpy(self->metadata + self->metadata_length, metadata, + metadata_length * sizeof(char)); + for (j = 0; j < num_rows; j++) { + self->metadata_offset[self->num_rows + j] = + (tsk_tbl_size_t) self->metadata_length + metadata_offset[j]; + } + self->metadata_length += metadata_length; + } + self->metadata_offset[self->num_rows + num_rows] = self->metadata_length; + + /* Ancestral state column */ + ret = check_offsets(num_rows, ancestral_state_offset, 0, false); + if (ret != 0) { + goto out; + } + ancestral_state_length = ancestral_state_offset[num_rows]; + ret = tsk_site_tbl_expand_ancestral_state(self, ancestral_state_length); + if (ret != 0) { + goto out; + } + memcpy(self->ancestral_state + self->ancestral_state_length, ancestral_state, + ancestral_state_length * sizeof(char)); + for (j = 0; j < num_rows; j++) { + self->ancestral_state_offset[self->num_rows + j] = + (tsk_tbl_size_t) self->ancestral_state_length + ancestral_state_offset[j]; + } + self->ancestral_state_length 
+= ancestral_state_length; + self->ancestral_state_offset[self->num_rows + num_rows] = self->ancestral_state_length; + + self->num_rows += (tsk_tbl_size_t) num_rows; +out: + return ret; +} + +int TSK_WARN_UNUSED +tsk_site_tbl_copy(tsk_site_tbl_t *self, tsk_site_tbl_t *dest) +{ + return tsk_site_tbl_set_columns(dest, self->num_rows, self->position, + self->ancestral_state, self->ancestral_state_offset, + self->metadata, self->metadata_offset); +} + +int +tsk_site_tbl_set_columns(tsk_site_tbl_t *self, size_t num_rows, double *position, + const char *ancestral_state, tsk_tbl_size_t *ancestral_state_offset, + const char *metadata, tsk_tbl_size_t *metadata_offset) +{ + int ret = 0; + + ret = tsk_site_tbl_clear(self); + if (ret != 0) { + goto out; + } + ret = tsk_site_tbl_append_columns(self, num_rows, position, ancestral_state, + ancestral_state_offset, metadata, metadata_offset); +out: + return ret; +} + +bool +tsk_site_tbl_equals(tsk_site_tbl_t *self, tsk_site_tbl_t *other) +{ + bool ret = false; + if (self->num_rows == other->num_rows + && self->ancestral_state_length == other->ancestral_state_length + && self->metadata_length == other->metadata_length) { + ret = memcmp(self->position, other->position, + self->num_rows * sizeof(double)) == 0 + && memcmp(self->ancestral_state_offset, other->ancestral_state_offset, + (self->num_rows + 1) * sizeof(tsk_tbl_size_t)) == 0 + && memcmp(self->ancestral_state, other->ancestral_state, + self->ancestral_state_length * sizeof(char)) == 0 + && memcmp(self->metadata_offset, other->metadata_offset, + (self->num_rows + 1) * sizeof(tsk_tbl_size_t)) == 0 + && memcmp(self->metadata, other->metadata, + self->metadata_length * sizeof(char)) == 0; + } + return ret; +} + +int +tsk_site_tbl_clear(tsk_site_tbl_t *self) +{ + return tsk_site_tbl_truncate(self, 0); +} + +int +tsk_site_tbl_truncate(tsk_site_tbl_t *self, size_t num_rows) +{ + int ret = 0; + tsk_tbl_size_t n = (tsk_tbl_size_t) num_rows; + if (n > self->num_rows) { + ret = 
TSK_ERR_BAD_TABLE_POSITION; + goto out; + } + self->num_rows = n; + self->ancestral_state_length = self->ancestral_state_offset[n]; + self->metadata_length = self->metadata_offset[n]; +out: + return ret; +} + +int +tsk_site_tbl_free(tsk_site_tbl_t *self) +{ + if (self->max_rows > 0) { + tsk_safe_free(self->position); + tsk_safe_free(self->ancestral_state); + tsk_safe_free(self->ancestral_state_offset); + tsk_safe_free(self->metadata); + tsk_safe_free(self->metadata_offset); + } + return 0; +} + +void +tsk_site_tbl_print_state(tsk_site_tbl_t *self, FILE *out) +{ + int ret; + + fprintf(out, TABLE_SEP); + fprintf(out, "site_table: %p:\n", (void *) self); + fprintf(out, "num_rows = %d\t(max= %d\tincrement = %d)\n", + (int) self->num_rows, (int) self->max_rows, (int) self->max_rows_increment); + fprintf(out, "ancestral_state_length = %d\t(max= %d\tincrement = %d)\n", + (int) self->ancestral_state_length, + (int) self->max_ancestral_state_length, + (int) self->max_ancestral_state_length_increment); + fprintf(out, "metadata_length = %d(\tmax= %d\tincrement = %d)\n", + (int) self->metadata_length, + (int) self->max_metadata_length, + (int) self->max_metadata_length_increment); + fprintf(out, TABLE_SEP); + ret = tsk_site_tbl_dump_text(self, out); + assert(ret == 0); + + assert(self->ancestral_state_offset[0] == 0); + assert(self->ancestral_state_length + == self->ancestral_state_offset[self->num_rows]); + assert(self->metadata_offset[0] == 0); + assert(self->metadata_length == self->metadata_offset[self->num_rows]); +} + +int +tsk_site_tbl_get_row(tsk_site_tbl_t *self, size_t index, tsk_site_t *row) +{ + int ret = 0; + if (index >= self->num_rows) { + ret = TSK_ERR_SITE_OUT_OF_BOUNDS; + goto out; + } + row->id = (tsk_id_t) index; + row->position = self->position[index]; + row->ancestral_state_length = self->ancestral_state_offset[index + 1] + - self->ancestral_state_offset[index]; + row->ancestral_state = self->ancestral_state + self->ancestral_state_offset[index]; + 
row->metadata_length = self->metadata_offset[index + 1] + - self->metadata_offset[index]; + row->metadata = self->metadata + self->metadata_offset[index]; + /* This struct has a placeholder for mutations. Probably should be separate + * structs for this (tsk_site_tbl_row_t?) */ + row->mutations_length = 0; + row->mutations = NULL; +out: + return ret; +} + +int +tsk_site_tbl_dump_text(tsk_site_tbl_t *self, FILE *out) +{ + size_t j; + int ret = TSK_ERR_IO; + int err; + tsk_tbl_size_t ancestral_state_len, metadata_len; + + err = fprintf(out, "id\tposition\tancestral_state\tmetadata\n"); + if (err < 0) { + goto out; + } + for (j = 0; j < self->num_rows; j++) { + ancestral_state_len = self->ancestral_state_offset[j + 1] - + self->ancestral_state_offset[j]; + metadata_len = self->metadata_offset[j + 1] - self->metadata_offset[j]; + err = fprintf(out, "%d\t%f\t%.*s\t%.*s\n", (int) j, self->position[j], + ancestral_state_len, self->ancestral_state + self->ancestral_state_offset[j], + metadata_len, self->metadata + self->metadata_offset[j]); + if (err < 0) { + goto out; + } + } + ret = 0; +out: + return ret; +} + +static int +tsk_site_tbl_dump(tsk_site_tbl_t *self, kastore_t *store) +{ + write_table_col_t write_cols[] = { + {"sites/position", (void *) self->position, self->num_rows, KAS_FLOAT64}, + {"sites/ancestral_state", (void *) self->ancestral_state, + self->ancestral_state_length, KAS_UINT8}, + {"sites/ancestral_state_offset", (void *) self->ancestral_state_offset, + self->num_rows + 1, KAS_UINT32}, + {"sites/metadata", (void *) self->metadata, self->metadata_length, KAS_UINT8}, + {"sites/metadata_offset", (void *) self->metadata_offset, + self->num_rows + 1, KAS_UINT32}, + }; + return write_table_cols(store, write_cols, sizeof(write_cols) / sizeof(*write_cols)); +} + +static int +tsk_site_tbl_load(tsk_site_tbl_t *self, kastore_t *store) +{ + read_table_col_t read_cols[] = { + {"sites/position", (void **) &self->position, &self->num_rows, 0, KAS_FLOAT64}, + 
{"sites/ancestral_state", (void **) &self->ancestral_state, + &self->ancestral_state_length, 0, KAS_UINT8}, + {"sites/ancestral_state_offset", (void **) &self->ancestral_state_offset, + &self->num_rows, 1, KAS_UINT32}, + {"sites/metadata", (void **) &self->metadata, + &self->metadata_length, 0, KAS_UINT8}, + {"sites/metadata_offset", (void **) &self->metadata_offset, + &self->num_rows, 1, KAS_UINT32}, + }; + return read_table_cols(store, read_cols, sizeof(read_cols) / sizeof(*read_cols)); +} + +/************************* + * mutation table + *************************/ + +static int +tsk_mutation_tbl_expand_main_columns(tsk_mutation_tbl_t *self, size_t additional_rows) +{ + int ret = 0; + tsk_tbl_size_t increment = (tsk_tbl_size_t) TSK_MAX(additional_rows, self->max_rows_increment); + tsk_tbl_size_t new_size = self->max_rows + increment; + + if ((self->num_rows + additional_rows) > self->max_rows) { + ret = expand_column((void **) &self->site, new_size, sizeof(tsk_id_t)); + if (ret != 0) { + goto out; + } + ret = expand_column((void **) &self->node, new_size, sizeof(tsk_id_t)); + if (ret != 0) { + goto out; + } + ret = expand_column((void **) &self->parent, new_size, sizeof(tsk_id_t)); + if (ret != 0) { + goto out; + } + ret = expand_column((void **) &self->derived_state_offset, new_size + 1, + sizeof(tsk_tbl_size_t)); + if (ret != 0) { + goto out; + } + ret = expand_column((void **) &self->metadata_offset, new_size + 1, + sizeof(tsk_tbl_size_t)); + if (ret != 0) { + goto out; + } + self->max_rows = new_size; + } +out: + return ret; +} + +static int +tsk_mutation_tbl_expand_derived_state(tsk_mutation_tbl_t *self, size_t additional_length) +{ + int ret = 0; + tsk_tbl_size_t increment = (tsk_tbl_size_t) TSK_MAX(additional_length, + self->max_derived_state_length_increment); + tsk_tbl_size_t new_size = self->max_derived_state_length + increment; + + if ((self->derived_state_length + additional_length) + > self->max_derived_state_length) { + ret = expand_column((void 
**) &self->derived_state, new_size, sizeof(char)); + if (ret != 0) { + goto out; + } + self->max_derived_state_length = (tsk_tbl_size_t) new_size; + } +out: + return ret; +} + +static int +tsk_mutation_tbl_expand_metadata(tsk_mutation_tbl_t *self, size_t additional_length) +{ + int ret = 0; + tsk_tbl_size_t increment = (tsk_tbl_size_t) TSK_MAX(additional_length, + self->max_metadata_length_increment); + tsk_tbl_size_t new_size = self->max_metadata_length + increment; + + if ((self->metadata_length + additional_length) + > self->max_metadata_length) { + ret = expand_column((void **) &self->metadata, new_size, sizeof(char)); + if (ret != 0) { + goto out; + } + self->max_metadata_length = new_size; + } +out: + return ret; +} + +int +tsk_mutation_tbl_set_max_rows_increment(tsk_mutation_tbl_t *self, size_t max_rows_increment) +{ + if (max_rows_increment == 0) { + max_rows_increment = DEFAULT_SIZE_INCREMENT; + } + self->max_rows_increment = (tsk_tbl_size_t) max_rows_increment; + return 0; +} + +int +tsk_mutation_tbl_set_max_metadata_length_increment(tsk_mutation_tbl_t *self, + size_t max_metadata_length_increment) +{ + if (max_metadata_length_increment == 0) { + max_metadata_length_increment = DEFAULT_SIZE_INCREMENT; + } + self->max_metadata_length_increment = (tsk_tbl_size_t) max_metadata_length_increment; + return 0; +} + +int +tsk_mutation_tbl_set_max_derived_state_length_increment(tsk_mutation_tbl_t *self, + size_t max_derived_state_length_increment) +{ + if (max_derived_state_length_increment == 0) { + max_derived_state_length_increment = DEFAULT_SIZE_INCREMENT; + } + self->max_derived_state_length_increment = + (tsk_tbl_size_t) max_derived_state_length_increment; + return 0; +} + +int +tsk_mutation_tbl_alloc(tsk_mutation_tbl_t *self, int TSK_UNUSED(flags)) +{ + int ret = 0; + + memset(self, 0, sizeof(tsk_mutation_tbl_t)); + + /* Allocate space for one row initially, ensuring we always have valid pointers + * even if the table is empty */ + self->max_rows_increment 
= 1; + self->max_derived_state_length_increment = 1; + self->max_metadata_length_increment = 1; + ret = tsk_mutation_tbl_expand_main_columns(self, 1); + if (ret != 0) { + goto out; + } + ret = tsk_mutation_tbl_expand_derived_state(self, 1); + if (ret != 0) { + goto out; + } + ret = tsk_mutation_tbl_expand_metadata(self, 1); + if (ret != 0) { + goto out; + } + self->derived_state_offset[0] = 0; + self->metadata_offset[0] = 0; + self->max_rows_increment = DEFAULT_SIZE_INCREMENT; + self->max_derived_state_length_increment = DEFAULT_SIZE_INCREMENT; + self->max_metadata_length_increment = DEFAULT_SIZE_INCREMENT; +out: + return ret; +} + +tsk_id_t +tsk_mutation_tbl_add_row(tsk_mutation_tbl_t *self, tsk_id_t site, tsk_id_t node, + tsk_id_t parent, + const char *derived_state, tsk_tbl_size_t derived_state_length, + const char *metadata, tsk_tbl_size_t metadata_length) +{ + tsk_tbl_size_t derived_state_offset, metadata_offset; + int ret; + + ret = tsk_mutation_tbl_expand_main_columns(self, 1); + if (ret != 0) { + goto out; + } + self->site[self->num_rows] = site; + self->node[self->num_rows] = node; + self->parent[self->num_rows] = parent; + + derived_state_offset = (tsk_tbl_size_t) self->derived_state_length; + assert(self->derived_state_offset[self->num_rows] == derived_state_offset); + ret = tsk_mutation_tbl_expand_derived_state(self, derived_state_length); + if (ret != 0) { + goto out; + } + self->derived_state_length += derived_state_length; + memcpy(self->derived_state + derived_state_offset, derived_state, + derived_state_length); + self->derived_state_offset[self->num_rows + 1] = self->derived_state_length; + + metadata_offset = (tsk_tbl_size_t) self->metadata_length; + assert(self->metadata_offset[self->num_rows] == metadata_offset); + ret = tsk_mutation_tbl_expand_metadata(self, metadata_length); + if (ret != 0) { + goto out; + } + self->metadata_length += metadata_length; + memcpy(self->metadata + metadata_offset, metadata, metadata_length); + 
self->metadata_offset[self->num_rows + 1] = self->metadata_length; + + ret = (tsk_id_t) self->num_rows; + self->num_rows++; +out: + return ret; +} + +int +tsk_mutation_tbl_append_columns(tsk_mutation_tbl_t *self, size_t num_rows, tsk_id_t *site, + tsk_id_t *node, tsk_id_t *parent, + const char *derived_state, tsk_tbl_size_t *derived_state_offset, + const char *metadata, tsk_tbl_size_t *metadata_offset) +{ + int ret = 0; + tsk_tbl_size_t j, derived_state_length, metadata_length; + + if (site == NULL || node == NULL || derived_state == NULL + || derived_state_offset == NULL) { + ret = TSK_ERR_BAD_PARAM_VALUE; + goto out; + } + if ((metadata == NULL) != (metadata_offset == NULL)) { + ret = TSK_ERR_BAD_PARAM_VALUE; + goto out; + } + + ret = tsk_mutation_tbl_expand_main_columns(self, num_rows); + if (ret != 0) { + goto out; + } + memcpy(self->site + self->num_rows, site, num_rows * sizeof(tsk_id_t)); + memcpy(self->node + self->num_rows, node, num_rows * sizeof(tsk_id_t)); + if (parent == NULL) { + /* If parent is NULL, set all parents to the null mutation */ + memset(self->parent + self->num_rows, 0xff, num_rows * sizeof(tsk_id_t)); + } else { + memcpy(self->parent + self->num_rows, parent, num_rows * sizeof(tsk_id_t)); + } + + /* Metadata column */ + if (metadata == NULL) { + for (j = 0; j < num_rows; j++) { + self->metadata_offset[self->num_rows + j + 1] = (tsk_tbl_size_t) self->metadata_length; + } + } else { + ret = check_offsets(num_rows, metadata_offset, 0, false); + if (ret != 0) { + goto out; + } + metadata_length = metadata_offset[num_rows]; + ret = tsk_mutation_tbl_expand_metadata(self, metadata_length); + if (ret != 0) { + goto out; + } + memcpy(self->metadata + self->metadata_length, metadata, + metadata_length * sizeof(char)); + for (j = 0; j < num_rows; j++) { + self->metadata_offset[self->num_rows + j] = + (tsk_tbl_size_t) self->metadata_length + metadata_offset[j]; + } + self->metadata_length += metadata_length; + } + 
self->metadata_offset[self->num_rows + num_rows] = self->metadata_length; + + /* Derived state column */ + ret = check_offsets(num_rows, derived_state_offset, 0, false); + if (ret != 0) { + goto out; + } + derived_state_length = derived_state_offset[num_rows]; + ret = tsk_mutation_tbl_expand_derived_state(self, derived_state_length); + if (ret != 0) { + goto out; + } + memcpy(self->derived_state + self->derived_state_length, derived_state, + derived_state_length * sizeof(char)); + for (j = 0; j < num_rows; j++) { + self->derived_state_offset[self->num_rows + j] = + (tsk_tbl_size_t) self->derived_state_length + derived_state_offset[j]; + } + self->derived_state_length += derived_state_length; + self->derived_state_offset[self->num_rows + num_rows] = self->derived_state_length; + + self->num_rows += (tsk_tbl_size_t) num_rows; +out: + return ret; +} + +int TSK_WARN_UNUSED +tsk_mutation_tbl_copy(tsk_mutation_tbl_t *self, tsk_mutation_tbl_t *dest) +{ + return tsk_mutation_tbl_set_columns(dest, self->num_rows, + self->site, self->node, self->parent, + self->derived_state, self->derived_state_offset, + self->metadata, self->metadata_offset); +} + +int +tsk_mutation_tbl_set_columns(tsk_mutation_tbl_t *self, size_t num_rows, tsk_id_t *site, + tsk_id_t *node, tsk_id_t *parent, + const char *derived_state, tsk_tbl_size_t *derived_state_offset, + const char *metadata, tsk_tbl_size_t *metadata_offset) +{ + int ret = 0; + + ret = tsk_mutation_tbl_clear(self); + if (ret != 0) { + goto out; + } + ret = tsk_mutation_tbl_append_columns(self, num_rows, site, node, parent, + derived_state, derived_state_offset, metadata, metadata_offset); +out: + return ret; +} + +bool +tsk_mutation_tbl_equals(tsk_mutation_tbl_t *self, tsk_mutation_tbl_t *other) +{ + bool ret = false; + if (self->num_rows == other->num_rows + && self->derived_state_length == other->derived_state_length + && self->metadata_length == other->metadata_length) { + ret = memcmp(self->site, other->site, self->num_rows * 
sizeof(tsk_id_t)) == 0 + && memcmp(self->node, other->node, self->num_rows * sizeof(tsk_id_t)) == 0 + && memcmp(self->parent, other->parent, + self->num_rows * sizeof(tsk_id_t)) == 0 + && memcmp(self->derived_state_offset, other->derived_state_offset, + (self->num_rows + 1) * sizeof(tsk_tbl_size_t)) == 0 + && memcmp(self->derived_state, other->derived_state, + self->derived_state_length * sizeof(char)) == 0 + && memcmp(self->metadata_offset, other->metadata_offset, + (self->num_rows + 1) * sizeof(tsk_tbl_size_t)) == 0 + && memcmp(self->metadata, other->metadata, + self->metadata_length * sizeof(char)) == 0; + } + return ret; +} + +int +tsk_mutation_tbl_clear(tsk_mutation_tbl_t *self) +{ + return tsk_mutation_tbl_truncate(self, 0); +} + +int +tsk_mutation_tbl_truncate(tsk_mutation_tbl_t *mutations, size_t num_rows) +{ + int ret = 0; + tsk_tbl_size_t n = (tsk_tbl_size_t) num_rows; + + if (n > mutations->num_rows) { + ret = TSK_ERR_BAD_TABLE_POSITION; + goto out; + } + mutations->num_rows = n; + mutations->derived_state_length = mutations->derived_state_offset[n]; + mutations->metadata_length = mutations->metadata_offset[n]; +out: + return ret; +} + +int +tsk_mutation_tbl_free(tsk_mutation_tbl_t *self) +{ + if (self->max_rows > 0) { + tsk_safe_free(self->node); + tsk_safe_free(self->site); + tsk_safe_free(self->parent); + tsk_safe_free(self->derived_state); + tsk_safe_free(self->derived_state_offset); + tsk_safe_free(self->metadata); + tsk_safe_free(self->metadata_offset); + } + return 0; +} + +void +tsk_mutation_tbl_print_state(tsk_mutation_tbl_t *self, FILE *out) +{ + int ret; + + fprintf(out, TABLE_SEP); + fprintf(out, "mutation_table: %p:\n", (void *) self); + fprintf(out, "num_rows = %d\tmax= %d\tincrement = %d)\n", + (int) self->num_rows, (int) self->max_rows, (int) self->max_rows_increment); + fprintf(out, "derived_state_length = %d\tmax= %d\tincrement = %d)\n", + (int) self->derived_state_length, + (int) self->max_derived_state_length, + (int) 
self->max_derived_state_length_increment); + fprintf(out, "metadata_length = %d\tmax= %d\tincrement = %d)\n", + (int) self->metadata_length, + (int) self->max_metadata_length, + (int) self->max_metadata_length_increment); + fprintf(out, TABLE_SEP); + ret = tsk_mutation_tbl_dump_text(self, out); + assert(ret == 0); + assert(self->derived_state_offset[0] == 0); + assert(self->derived_state_length + == self->derived_state_offset[self->num_rows]); + assert(self->metadata_offset[0] == 0); + assert(self->metadata_length + == self->metadata_offset[self->num_rows]); +} + +int +tsk_mutation_tbl_get_row(tsk_mutation_tbl_t *self, size_t index, tsk_mutation_t *row) +{ + int ret = 0; + if (index >= self->num_rows) { + ret = TSK_ERR_MUTATION_OUT_OF_BOUNDS; + goto out; + } + row->id = (tsk_id_t) index; + row->site = self->site[index]; + row->node = self->node[index]; + row->parent = self->parent[index]; + row->derived_state_length = self->derived_state_offset[index + 1] + - self->derived_state_offset[index]; + row->derived_state = self->derived_state + self->derived_state_offset[index]; + row->metadata_length = self->metadata_offset[index + 1] + - self->metadata_offset[index]; + row->metadata = self->metadata + self->metadata_offset[index]; +out: + return ret; +} + +int +tsk_mutation_tbl_dump_text(tsk_mutation_tbl_t *self, FILE *out) +{ + size_t j; + int ret = TSK_ERR_IO; + int err; + tsk_tbl_size_t derived_state_len, metadata_len; + + err = fprintf(out, "id\tsite\tnode\tparent\tderived_state\tmetadata\n"); + if (err < 0) { + goto out; + } + for (j = 0; j < self->num_rows; j++) { + derived_state_len = self->derived_state_offset[j + 1] - + self->derived_state_offset[j]; + metadata_len = self->metadata_offset[j + 1] - self->metadata_offset[j]; + err = fprintf(out, "%d\t%d\t%d\t%d\t%.*s\t%.*s\n", (int) j, + self->site[j], self->node[j], self->parent[j], + derived_state_len, self->derived_state + self->derived_state_offset[j], + metadata_len, self->metadata + 
self->metadata_offset[j]); + if (err < 0) { + goto out; + } + } + ret = 0; +out: + return ret; +} + +static int +tsk_mutation_tbl_dump(tsk_mutation_tbl_t *self, kastore_t *store) +{ + write_table_col_t write_cols[] = { + {"mutations/site", (void *) self->site, self->num_rows, KAS_INT32}, + {"mutations/node", (void *) self->node, self->num_rows, KAS_INT32}, + {"mutations/parent", (void *) self->parent, self->num_rows, KAS_INT32}, + {"mutations/derived_state", (void *) self->derived_state, + self->derived_state_length, KAS_UINT8}, + {"mutations/derived_state_offset", (void *) self->derived_state_offset, + self->num_rows + 1, KAS_UINT32}, + {"mutations/metadata", (void *) self->metadata, + self->metadata_length, KAS_UINT8}, + {"mutations/metadata_offset", (void *) self->metadata_offset, + self->num_rows + 1, KAS_UINT32}, + }; + return write_table_cols(store, write_cols, sizeof(write_cols) / sizeof(*write_cols)); +} + + +static int +tsk_mutation_tbl_load(tsk_mutation_tbl_t *self, kastore_t *store) +{ + read_table_col_t read_cols[] = { + {"mutations/site", (void **) &self->site, &self->num_rows, 0, KAS_INT32}, + {"mutations/node", (void **) &self->node, &self->num_rows, 0, KAS_INT32}, + {"mutations/parent", (void **) &self->parent, &self->num_rows, 0, KAS_INT32}, + {"mutations/derived_state", (void **) &self->derived_state, + &self->derived_state_length, 0, KAS_UINT8}, + {"mutations/derived_state_offset", (void **) &self->derived_state_offset, + &self->num_rows, 1, KAS_UINT32}, + {"mutations/metadata", (void **) &self->metadata, + &self->metadata_length, 0, KAS_UINT8}, + {"mutations/metadata_offset", (void **) &self->metadata_offset, + &self->num_rows, 1, KAS_UINT32}, + }; + return read_table_cols(store, read_cols, sizeof(read_cols) / sizeof(*read_cols)); +} + +/************************* + * migration table + *************************/ + +static int +tsk_migration_tbl_expand(tsk_migration_tbl_t *self, size_t additional_rows) +{ + int ret = 0; + tsk_tbl_size_t increment 
= TSK_MAX( + (tsk_tbl_size_t) additional_rows, self->max_rows_increment); + tsk_tbl_size_t new_size = self->max_rows + increment; + + if ((self->num_rows + additional_rows) > self->max_rows) { + ret = expand_column((void **) &self->left, new_size, sizeof(double)); + if (ret != 0) { + goto out; + } + ret = expand_column((void **) &self->right, new_size, sizeof(double)); + if (ret != 0) { + goto out; + } + ret = expand_column((void **) &self->node, new_size, sizeof(tsk_id_t)); + if (ret != 0) { + goto out; + } + ret = expand_column((void **) &self->source, new_size, sizeof(tsk_id_t)); + if (ret != 0) { + goto out; + } + ret = expand_column((void **) &self->dest, new_size, sizeof(tsk_id_t)); + if (ret != 0) { + goto out; + } + ret = expand_column((void **) &self->time, new_size, sizeof(double)); + if (ret != 0) { + goto out; + } + self->max_rows = new_size; + } +out: + return ret; +} + +int +tsk_migration_tbl_set_max_rows_increment(tsk_migration_tbl_t *self, size_t max_rows_increment) +{ + if (max_rows_increment == 0) { + max_rows_increment = DEFAULT_SIZE_INCREMENT; + } + self->max_rows_increment = (tsk_tbl_size_t) max_rows_increment; + return 0; +} + +int +tsk_migration_tbl_alloc(tsk_migration_tbl_t *self, int TSK_UNUSED(flags)) +{ + int ret = 0; + + memset(self, 0, sizeof(tsk_migration_tbl_t)); + + /* Allocate space for one row initially, ensuring we always have valid pointers + * even if the table is empty */ + self->max_rows_increment = 1; + ret = tsk_migration_tbl_expand(self, 1); + if (ret != 0) { + goto out; + } + self->max_rows_increment = DEFAULT_SIZE_INCREMENT; +out: + return ret; +} + +int +tsk_migration_tbl_append_columns(tsk_migration_tbl_t *self, size_t num_rows, double *left, + double *right, tsk_id_t *node, tsk_id_t *source, tsk_id_t *dest, + double *time) +{ + int ret; + + ret = tsk_migration_tbl_expand(self, num_rows); + if (ret != 0) { + goto out; + } + if (left == NULL || right == NULL || node == NULL || source == NULL + || dest == NULL || time == 
NULL) { + ret = TSK_ERR_BAD_PARAM_VALUE; + goto out; + } + memcpy(self->left + self->num_rows, left, num_rows * sizeof(double)); + memcpy(self->right + self->num_rows, right, num_rows * sizeof(double)); + memcpy(self->node + self->num_rows, node, num_rows * sizeof(tsk_id_t)); + memcpy(self->source + self->num_rows, source, num_rows * sizeof(tsk_id_t)); + memcpy(self->dest + self->num_rows, dest, num_rows * sizeof(tsk_id_t)); + memcpy(self->time + self->num_rows, time, num_rows * sizeof(double)); + self->num_rows += (tsk_tbl_size_t) num_rows; +out: + return ret; +} + +int TSK_WARN_UNUSED +tsk_migration_tbl_copy(tsk_migration_tbl_t *self, tsk_migration_tbl_t *dest) +{ + return tsk_migration_tbl_set_columns(dest, self->num_rows, + self->left, self->right, self->node, + self->source, self->dest, self->time); +} + +int +tsk_migration_tbl_set_columns(tsk_migration_tbl_t *self, size_t num_rows, double *left, + double *right, tsk_id_t *node, tsk_id_t *source, tsk_id_t *dest, + double *time) +{ + int ret; + + ret = tsk_migration_tbl_clear(self); + if (ret != 0) { + goto out; + } + ret = tsk_migration_tbl_append_columns(self, num_rows, left, right, node, source, + dest, time); +out: + return ret; +} + +tsk_id_t +tsk_migration_tbl_add_row(tsk_migration_tbl_t *self, double left, double right, + tsk_id_t node, tsk_id_t source, tsk_id_t dest, double time) +{ + int ret = 0; + + ret = tsk_migration_tbl_expand(self, 1); + if (ret != 0) { + goto out; + } + self->left[self->num_rows] = left; + self->right[self->num_rows] = right; + self->node[self->num_rows] = node; + self->source[self->num_rows] = source; + self->dest[self->num_rows] = dest; + self->time[self->num_rows] = time; + ret = (tsk_id_t) self->num_rows; + self->num_rows++; +out: + return ret; +} + +int +tsk_migration_tbl_clear(tsk_migration_tbl_t *self) +{ + return tsk_migration_tbl_truncate(self, 0); +} + +int +tsk_migration_tbl_truncate(tsk_migration_tbl_t *self, size_t num_rows) +{ + int ret = 0; + tsk_tbl_size_t n = 
(tsk_tbl_size_t) num_rows; + + if (n > self->num_rows) { + ret = TSK_ERR_BAD_TABLE_POSITION; + goto out; + } + self->num_rows = n; +out: + return ret; +} + +int +tsk_migration_tbl_free(tsk_migration_tbl_t *self) +{ + if (self->max_rows > 0) { + tsk_safe_free(self->left); + tsk_safe_free(self->right); + tsk_safe_free(self->node); + tsk_safe_free(self->source); + tsk_safe_free(self->dest); + tsk_safe_free(self->time); + } + return 0; +} + +void +tsk_migration_tbl_print_state(tsk_migration_tbl_t *self, FILE *out) +{ + int ret; + + fprintf(out, TABLE_SEP); + fprintf(out, "migration_table: %p:\n", (void *) self); + fprintf(out, "num_rows = %d\tmax= %d\tincrement = %d)\n", + (int) self->num_rows, (int) self->max_rows, (int) self->max_rows_increment); + fprintf(out, TABLE_SEP); + ret = tsk_migration_tbl_dump_text(self, out); + assert(ret == 0); +} + +int +tsk_migration_tbl_get_row(tsk_migration_tbl_t *self, size_t index, tsk_migration_t *row) +{ + int ret = 0; + if (index >= self->num_rows) { + ret = TSK_ERR_MIGRATION_OUT_OF_BOUNDS; + goto out; + } + row->id = (tsk_id_t) index; + row->left = self->left[index]; + row->right = self->right[index]; + row->node = self->node[index]; + row->source = self->source[index]; + row->dest = self->dest[index]; + row->time = self->time[index]; +out: + return ret; +} + +int +tsk_migration_tbl_dump_text(tsk_migration_tbl_t *self, FILE *out) +{ + size_t j; + int ret = TSK_ERR_IO; + int err; + + err = fprintf(out, "left\tright\tnode\tsource\tdest\ttime\n"); + if (err < 0) { + goto out; + } + for (j = 0; j < self->num_rows; j++) { + err = fprintf(out, "%.3f\t%.3f\t%d\t%d\t%d\t%f\n", self->left[j], + self->right[j], (int) self->node[j], (int) self->source[j], + (int) self->dest[j], self->time[j]); + if (err < 0) { + goto out; + } + } + ret = 0; +out: + return ret; +} + +bool +tsk_migration_tbl_equals(tsk_migration_tbl_t *self, tsk_migration_tbl_t *other) +{ + bool ret = false; + if (self->num_rows == other->num_rows) { + ret = 
memcmp(self->left, other->left, + self->num_rows * sizeof(double)) == 0 + && memcmp(self->right, other->right, + self->num_rows * sizeof(double)) == 0 + && memcmp(self->node, other->node, + self->num_rows * sizeof(tsk_id_t)) == 0 + && memcmp(self->source, other->source, + self->num_rows * sizeof(tsk_id_t)) == 0 + && memcmp(self->dest, other->dest, + self->num_rows * sizeof(tsk_id_t)) == 0 + && memcmp(self->time, other->time, + self->num_rows * sizeof(double)) == 0; + } + return ret; +} + +static int +tsk_migration_tbl_dump(tsk_migration_tbl_t *self, kastore_t *store) +{ + write_table_col_t write_cols[] = { + {"migrations/left", (void *) self->left, self->num_rows, KAS_FLOAT64}, + {"migrations/right", (void *) self->right, self->num_rows, KAS_FLOAT64}, + {"migrations/node", (void *) self->node, self->num_rows, KAS_INT32}, + {"migrations/source", (void *) self->source, self->num_rows, KAS_INT32}, + {"migrations/dest", (void *) self->dest, self->num_rows, KAS_INT32}, + {"migrations/time", (void *) self->time, self->num_rows, KAS_FLOAT64}, + }; + return write_table_cols(store, write_cols, sizeof(write_cols) / sizeof(*write_cols)); +} + +static int +tsk_migration_tbl_load(tsk_migration_tbl_t *self, kastore_t *store) +{ + read_table_col_t read_cols[] = { + {"migrations/left", (void **) &self->left, &self->num_rows, 0, KAS_FLOAT64}, + {"migrations/right", (void **) &self->right, &self->num_rows, 0, KAS_FLOAT64}, + {"migrations/node", (void **) &self->node, &self->num_rows, 0, KAS_INT32}, + {"migrations/source", (void **) &self->source, &self->num_rows, 0, KAS_INT32}, + {"migrations/dest", (void **) &self->dest, &self->num_rows, 0, KAS_INT32}, + {"migrations/time", (void **) &self->time, &self->num_rows, 0, KAS_FLOAT64}, + }; + return read_table_cols(store, read_cols, sizeof(read_cols) / sizeof(*read_cols)); +} + +/************************* + * population table + *************************/ + +static int +tsk_population_tbl_expand_main_columns(tsk_population_tbl_t *self, 
tsk_tbl_size_t additional_rows) +{ + int ret = 0; + tsk_tbl_size_t increment = TSK_MAX(additional_rows, self->max_rows_increment); + tsk_tbl_size_t new_size = self->max_rows + increment; + + if ((self->num_rows + additional_rows) > self->max_rows) { + ret = expand_column((void **) &self->metadata_offset, new_size + 1, + sizeof(tsk_tbl_size_t)); + if (ret != 0) { + goto out; + } + self->max_rows = new_size; + } +out: + return ret; +} + +static int +tsk_population_tbl_expand_metadata(tsk_population_tbl_t *self, tsk_tbl_size_t additional_length) +{ + int ret = 0; + tsk_tbl_size_t increment = TSK_MAX(additional_length, + self->max_metadata_length_increment); + tsk_tbl_size_t new_size = self->max_metadata_length + increment; + + if ((self->metadata_length + additional_length) > self->max_metadata_length) { + ret = expand_column((void **) &self->metadata, new_size, sizeof(char)); + if (ret != 0) { + goto out; + } + self->max_metadata_length = new_size; + } +out: + return ret; +} + +int +tsk_population_tbl_set_max_rows_increment(tsk_population_tbl_t *self, size_t max_rows_increment) +{ + if (max_rows_increment == 0) { + max_rows_increment = DEFAULT_SIZE_INCREMENT; + } + self->max_rows_increment = (tsk_tbl_size_t) max_rows_increment; + return 0; +} + +int +tsk_population_tbl_set_max_metadata_length_increment(tsk_population_tbl_t *self, + size_t max_metadata_length_increment) +{ + if (max_metadata_length_increment == 0) { + max_metadata_length_increment = DEFAULT_SIZE_INCREMENT; + } + self->max_metadata_length_increment = (tsk_tbl_size_t) max_metadata_length_increment; + return 0; +} + +int +tsk_population_tbl_alloc(tsk_population_tbl_t *self, int TSK_UNUSED(flags)) +{ + int ret = 0; + + memset(self, 0, sizeof(tsk_population_tbl_t)); + /* Allocate space for one row initially, ensuring we always have valid pointers + * even if the table is empty */ + self->max_rows_increment = 1; + self->max_metadata_length_increment = 1; + ret = tsk_population_tbl_expand_main_columns(self, 
1); + if (ret != 0) { + goto out; + } + ret = tsk_population_tbl_expand_metadata(self, 1); + if (ret != 0) { + goto out; + } + self->metadata_offset[0] = 0; + self->max_rows_increment = DEFAULT_SIZE_INCREMENT; + self->max_metadata_length_increment = DEFAULT_SIZE_INCREMENT; +out: + return ret; +} + +int TSK_WARN_UNUSED +tsk_population_tbl_copy(tsk_population_tbl_t *self, tsk_population_tbl_t *dest) +{ + return tsk_population_tbl_set_columns(dest, self->num_rows, + self->metadata, self->metadata_offset); +} + +int +tsk_population_tbl_set_columns(tsk_population_tbl_t *self, size_t num_rows, + const char *metadata, uint32_t *metadata_offset) +{ + int ret; + + ret = tsk_population_tbl_clear(self); + if (ret != 0) { + goto out; + } + ret = tsk_population_tbl_append_columns(self, num_rows, metadata, metadata_offset); +out: + return ret; +} + +int +tsk_population_tbl_append_columns(tsk_population_tbl_t *self, size_t num_rows, + const char *metadata, uint32_t *metadata_offset) +{ + int ret; + tsk_tbl_size_t j, metadata_length; + + if (metadata == NULL || metadata_offset == NULL) { + ret = TSK_ERR_BAD_PARAM_VALUE; + goto out; + } + ret = tsk_population_tbl_expand_main_columns(self, (tsk_tbl_size_t) num_rows); + if (ret != 0) { + goto out; + } + + ret = check_offsets(num_rows, metadata_offset, 0, false); + if (ret != 0) { + goto out; + } + for (j = 0; j < num_rows; j++) { + self->metadata_offset[self->num_rows + j] = + (tsk_tbl_size_t) self->metadata_length + metadata_offset[j]; + } + metadata_length = metadata_offset[num_rows]; + ret = tsk_population_tbl_expand_metadata(self, metadata_length); + if (ret != 0) { + goto out; + } + memcpy(self->metadata + self->metadata_length, metadata, + metadata_length * sizeof(char)); + self->metadata_length += metadata_length; + + self->num_rows += (tsk_tbl_size_t) num_rows; + self->metadata_offset[self->num_rows] = self->metadata_length; +out: + return ret; +} + +static tsk_id_t +tsk_population_tbl_add_row_internal(tsk_population_tbl_t 
*self, + const char *metadata, tsk_tbl_size_t metadata_length) +{ + int ret = 0; + + assert(self->num_rows < self->max_rows); + assert(self->metadata_length + metadata_length <= self->max_metadata_length); + memcpy(self->metadata + self->metadata_length, metadata, metadata_length); + self->metadata_offset[self->num_rows + 1] = self->metadata_length + metadata_length; + self->metadata_length += metadata_length; + ret = (tsk_id_t) self->num_rows; + self->num_rows++; + return ret; +} + +tsk_id_t +tsk_population_tbl_add_row(tsk_population_tbl_t *self, + const char *metadata, size_t metadata_length) +{ + int ret = 0; + + ret = tsk_population_tbl_expand_main_columns(self, 1); + if (ret != 0) { + goto out; + } + ret = tsk_population_tbl_expand_metadata(self, (tsk_tbl_size_t) metadata_length); + if (ret != 0) { + goto out; + } + ret = tsk_population_tbl_add_row_internal(self, + metadata, (tsk_tbl_size_t) metadata_length); +out: + return ret; +} + +int +tsk_population_tbl_clear(tsk_population_tbl_t *self) +{ + return tsk_population_tbl_truncate(self, 0); +} + +int +tsk_population_tbl_truncate(tsk_population_tbl_t *self, size_t num_rows) +{ + int ret = 0; + tsk_tbl_size_t n = (tsk_tbl_size_t) num_rows; + + if (n > self->num_rows) { + ret = TSK_ERR_BAD_TABLE_POSITION; + goto out; + } + self->num_rows = n; + self->metadata_length = self->metadata_offset[n]; +out: + return ret; +} + +int +tsk_population_tbl_free(tsk_population_tbl_t *self) +{ + if (self->max_rows > 0) { + tsk_safe_free(self->metadata); + tsk_safe_free(self->metadata_offset); + } + return 0; +} + +void +tsk_population_tbl_print_state(tsk_population_tbl_t *self, FILE *out) +{ + size_t j, k; + + fprintf(out, TABLE_SEP); + fprintf(out, "population_table: %p:\n", (void *) self); + fprintf(out, "num_rows = %d\tmax= %d\tincrement = %d)\n", + (int) self->num_rows, (int) self->max_rows, (int) self->max_rows_increment); + fprintf(out, "metadata_length = %d\tmax= %d\tincrement = %d)\n", + (int) self->metadata_length, + 
(int) self->max_metadata_length, + (int) self->max_metadata_length_increment); + fprintf(out, TABLE_SEP); + fprintf(out, "index\tmetadata_offset\tmetadata\n"); + for (j = 0; j < self->num_rows; j++) { + fprintf(out, "%d\t%d\t", (int) j, self->metadata_offset[j]); + for (k = self->metadata_offset[j]; k < self->metadata_offset[j + 1]; k++) { + fprintf(out, "%c", self->metadata[k]); + } + fprintf(out, "\n"); + } + assert(self->metadata_offset[0] == 0); + assert(self->metadata_offset[self->num_rows] == self->metadata_length); +} + +int +tsk_population_tbl_get_row(tsk_population_tbl_t *self, size_t index, tsk_population_t *row) +{ + int ret = 0; + if (index >= self->num_rows) { + ret = TSK_ERR_POPULATION_OUT_OF_BOUNDS; + goto out; + } + row->id = (tsk_id_t) index; + row->metadata_length = self->metadata_offset[index + 1] + - self->metadata_offset[index]; + row->metadata = self->metadata + self->metadata_offset[index]; +out: + return ret; +} + +int +tsk_population_tbl_dump_text(tsk_population_tbl_t *self, FILE *out) +{ + int ret = TSK_ERR_IO; + int err; + size_t j; + tsk_tbl_size_t metadata_len; + + err = fprintf(out, "metadata\n"); + if (err < 0) { + goto out; + } + for (j = 0; j < self->num_rows; j++) { + metadata_len = self->metadata_offset[j + 1] - self->metadata_offset[j]; + err = fprintf(out, "%.*s\n", metadata_len, + self->metadata + self->metadata_offset[j]); + if (err < 0) { + goto out; + } + } + ret = 0; +out: + return ret; +} + +bool +tsk_population_tbl_equals(tsk_population_tbl_t *self, tsk_population_tbl_t *other) +{ + bool ret = false; + if (self->num_rows == other->num_rows + && self->metadata_length == other->metadata_length) { + ret = memcmp(self->metadata_offset, other->metadata_offset, + (self->num_rows + 1) * sizeof(tsk_tbl_size_t)) == 0 + && memcmp(self->metadata, other->metadata, + self->metadata_length * sizeof(char)) == 0; + } + return ret; +} + +static int +tsk_population_tbl_dump(tsk_population_tbl_t *self, kastore_t *store) +{ + 
write_table_col_t write_cols[] = { + {"populations/metadata", (void *) self->metadata, + self->metadata_length, KAS_UINT8}, + {"populations/metadata_offset", (void *) self->metadata_offset, + self->num_rows+ 1, KAS_UINT32}, + }; + return write_table_cols(store, write_cols, sizeof(write_cols) / sizeof(*write_cols)); +} + +static int +tsk_population_tbl_load(tsk_population_tbl_t *self, kastore_t *store) +{ + read_table_col_t read_cols[] = { + {"populations/metadata", (void **) &self->metadata, + &self->metadata_length, 0, KAS_UINT8}, + {"populations/metadata_offset", (void **) &self->metadata_offset, + &self->num_rows, 1, KAS_UINT32}, + }; + return read_table_cols(store, read_cols, sizeof(read_cols) / sizeof(*read_cols)); +} + +/************************* + * provenance table + *************************/ + +static int +tsk_provenance_tbl_expand_main_columns(tsk_provenance_tbl_t *self, tsk_tbl_size_t additional_rows) +{ + int ret = 0; + tsk_tbl_size_t increment = TSK_MAX(additional_rows, self->max_rows_increment); + tsk_tbl_size_t new_size = self->max_rows + increment; + + if ((self->num_rows + additional_rows) > self->max_rows) { + ret = expand_column((void **) &self->timestamp_offset, new_size + 1, + sizeof(tsk_tbl_size_t)); + if (ret != 0) { + goto out; + } + ret = expand_column((void **) &self->record_offset, new_size + 1, + sizeof(tsk_tbl_size_t)); + if (ret != 0) { + goto out; + } + self->max_rows = new_size; + } +out: + return ret; +} + +static int +tsk_provenance_tbl_expand_timestamp(tsk_provenance_tbl_t *self, tsk_tbl_size_t additional_length) +{ + int ret = 0; + tsk_tbl_size_t increment = TSK_MAX(additional_length, + self->max_timestamp_length_increment); + tsk_tbl_size_t new_size = self->max_timestamp_length + increment; + + if ((self->timestamp_length + additional_length) > self->max_timestamp_length) { + ret = expand_column((void **) &self->timestamp, new_size, sizeof(char)); + if (ret != 0) { + goto out; + } + self->max_timestamp_length = new_size; + } 
+out: + return ret; +} + +static int +tsk_provenance_tbl_expand_provenance(tsk_provenance_tbl_t *self, tsk_tbl_size_t additional_length) +{ + int ret = 0; + tsk_tbl_size_t increment = TSK_MAX(additional_length, + self->max_record_length_increment); + tsk_tbl_size_t new_size = self->max_record_length + increment; + + if ((self->record_length + additional_length) > self->max_record_length) { + ret = expand_column((void **) &self->record, new_size, sizeof(char)); + if (ret != 0) { + goto out; + } + self->max_record_length = new_size; + } +out: + return ret; +} + + +int +tsk_provenance_tbl_set_max_rows_increment(tsk_provenance_tbl_t *self, size_t max_rows_increment) +{ + if (max_rows_increment == 0) { + max_rows_increment = DEFAULT_SIZE_INCREMENT; + } + self->max_rows_increment = (tsk_tbl_size_t) max_rows_increment; + return 0; +} + +int +tsk_provenance_tbl_set_max_timestamp_length_increment(tsk_provenance_tbl_t *self, + size_t max_timestamp_length_increment) +{ + if (max_timestamp_length_increment == 0) { + max_timestamp_length_increment = DEFAULT_SIZE_INCREMENT; + } + self->max_timestamp_length_increment = (tsk_tbl_size_t) max_timestamp_length_increment; + return 0; +} + +int +tsk_provenance_tbl_set_max_record_length_increment(tsk_provenance_tbl_t *self, + size_t max_record_length_increment) +{ + if (max_record_length_increment == 0) { + max_record_length_increment = DEFAULT_SIZE_INCREMENT; + } + self->max_record_length_increment = (tsk_tbl_size_t) max_record_length_increment; + return 0; +} + +int +tsk_provenance_tbl_alloc(tsk_provenance_tbl_t *self, int TSK_UNUSED(flags)) +{ + int ret = 0; + + memset(self, 0, sizeof(tsk_provenance_tbl_t)); + /* Allocate space for one row initially, ensuring we always have valid pointers + * even if the table is empty */ + self->max_rows_increment = 1; + self->max_timestamp_length_increment = 1; + self->max_record_length_increment = 1; + ret = tsk_provenance_tbl_expand_main_columns(self, 1); + if (ret != 0) { + goto out; + } + ret = 
tsk_provenance_tbl_expand_timestamp(self, 1); + if (ret != 0) { + goto out; + } + self->timestamp_offset[0] = 0; + ret = tsk_provenance_tbl_expand_provenance(self, 1); + if (ret != 0) { + goto out; + } + self->record_offset[0] = 0; + self->max_rows_increment = DEFAULT_SIZE_INCREMENT; + self->max_timestamp_length_increment = DEFAULT_SIZE_INCREMENT; + self->max_record_length_increment = DEFAULT_SIZE_INCREMENT; +out: + return ret; +} + +int TSK_WARN_UNUSED +tsk_provenance_tbl_copy(tsk_provenance_tbl_t *self, tsk_provenance_tbl_t *dest) +{ + return tsk_provenance_tbl_set_columns(dest, self->num_rows, + self->timestamp, self->timestamp_offset, + self->record, self->record_offset); +} + +int +tsk_provenance_tbl_set_columns(tsk_provenance_tbl_t *self, size_t num_rows, + char *timestamp, uint32_t *timestamp_offset, + char *record, uint32_t *record_offset) +{ + int ret; + + ret = tsk_provenance_tbl_clear(self); + if (ret != 0) { + goto out; + } + ret = tsk_provenance_tbl_append_columns(self, num_rows, + timestamp, timestamp_offset, record, record_offset); +out: + return ret; +} + +int +tsk_provenance_tbl_append_columns(tsk_provenance_tbl_t *self, size_t num_rows, + char *timestamp, uint32_t *timestamp_offset, + char *record, uint32_t *record_offset) +{ + int ret; + tsk_tbl_size_t j, timestamp_length, record_length; + + if (timestamp == NULL || timestamp_offset == NULL || + record == NULL || record_offset == NULL) { + ret = TSK_ERR_BAD_PARAM_VALUE; + goto out; + } + ret = tsk_provenance_tbl_expand_main_columns(self, (tsk_tbl_size_t) num_rows); + if (ret != 0) { + goto out; + } + + ret = check_offsets(num_rows, timestamp_offset, 0, false); + if (ret != 0) { + goto out; + } + for (j = 0; j < num_rows; j++) { + self->timestamp_offset[self->num_rows + j] = + (tsk_tbl_size_t) self->timestamp_length + timestamp_offset[j]; + } + timestamp_length = timestamp_offset[num_rows]; + ret = tsk_provenance_tbl_expand_timestamp(self, timestamp_length); + if (ret != 0) { + goto out; + } + 
memcpy(self->timestamp + self->timestamp_length, timestamp, + timestamp_length * sizeof(char)); + self->timestamp_length += timestamp_length; + + ret = check_offsets(num_rows, record_offset, 0, false); + if (ret != 0) { + goto out; + } + for (j = 0; j < num_rows; j++) { + self->record_offset[self->num_rows + j] = + (tsk_tbl_size_t) self->record_length + record_offset[j]; + } + record_length = record_offset[num_rows]; + ret = tsk_provenance_tbl_expand_provenance(self, record_length); + if (ret != 0) { + goto out; + } + memcpy(self->record + self->record_length, record, record_length * sizeof(char)); + self->record_length += record_length; + + self->num_rows += (tsk_tbl_size_t) num_rows; + self->timestamp_offset[self->num_rows] = self->timestamp_length; + self->record_offset[self->num_rows] = self->record_length; +out: + return ret; +} + +static tsk_id_t +tsk_provenance_tbl_add_row_internal(tsk_provenance_tbl_t *self, + const char *timestamp, tsk_tbl_size_t timestamp_length, + const char *record, tsk_tbl_size_t record_length) +{ + int ret = 0; + + assert(self->num_rows < self->max_rows); + assert(self->timestamp_length + timestamp_length <= self->max_timestamp_length); + memcpy(self->timestamp + self->timestamp_length, timestamp, timestamp_length); + self->timestamp_offset[self->num_rows + 1] = self->timestamp_length + timestamp_length; + self->timestamp_length += timestamp_length; + assert(self->record_length + record_length <= self->max_record_length); + memcpy(self->record + self->record_length, record, record_length); + self->record_offset[self->num_rows + 1] = self->record_length + record_length; + self->record_length += record_length; + ret = (tsk_id_t) self->num_rows; + self->num_rows++; + return ret; +} + +tsk_id_t +tsk_provenance_tbl_add_row(tsk_provenance_tbl_t *self, + const char *timestamp, size_t timestamp_length, + const char *record, size_t record_length) +{ + int ret = 0; + + ret = tsk_provenance_tbl_expand_main_columns(self, 1); + if (ret != 0) { + 
goto out; + } + ret = tsk_provenance_tbl_expand_timestamp(self, (tsk_tbl_size_t) timestamp_length); + if (ret != 0) { + goto out; + } + ret = tsk_provenance_tbl_expand_provenance(self, (tsk_tbl_size_t) record_length); + if (ret != 0) { + goto out; + } + ret = tsk_provenance_tbl_add_row_internal(self, + timestamp, (tsk_tbl_size_t) timestamp_length, + record, (tsk_tbl_size_t) record_length); +out: + return ret; +} + +int +tsk_provenance_tbl_clear(tsk_provenance_tbl_t *self) +{ + return tsk_provenance_tbl_truncate(self, 0); +} + +int +tsk_provenance_tbl_truncate(tsk_provenance_tbl_t *self, size_t num_rows) +{ + int ret = 0; + tsk_tbl_size_t n = (tsk_tbl_size_t) num_rows; + + if (n > self->num_rows) { + ret = TSK_ERR_BAD_TABLE_POSITION; + goto out; + } + self->num_rows = n; + self->timestamp_length = self->timestamp_offset[n]; + self->record_length = self->record_offset[n]; +out: + return ret; +} + +int +tsk_provenance_tbl_free(tsk_provenance_tbl_t *self) +{ + if (self->max_rows > 0) { + tsk_safe_free(self->timestamp); + tsk_safe_free(self->timestamp_offset); + tsk_safe_free(self->record); + tsk_safe_free(self->record_offset); + } + return 0; +} + +void +tsk_provenance_tbl_print_state(tsk_provenance_tbl_t *self, FILE *out) +{ + size_t j, k; + + fprintf(out, TABLE_SEP); + fprintf(out, "provenance_table: %p:\n", (void *) self); + fprintf(out, "num_rows = %d\tmax= %d\tincrement = %d)\n", + (int) self->num_rows, (int) self->max_rows, (int) self->max_rows_increment); + fprintf(out, "timestamp_length = %d\tmax= %d\tincrement = %d)\n", + (int) self->timestamp_length, + (int) self->max_timestamp_length, + (int) self->max_timestamp_length_increment); + fprintf(out, "record_length = %d\tmax= %d\tincrement = %d)\n", + (int) self->record_length, + (int) self->max_record_length, + (int) self->max_record_length_increment); + fprintf(out, TABLE_SEP); + fprintf(out, "index\ttimestamp_offset\ttimestamp\trecord_offset\tprovenance\n"); + for (j = 0; j < self->num_rows; j++) { + 
        fprintf(out, "%d\t%d\t", (int) j, self->timestamp_offset[j]);
        for (k = self->timestamp_offset[j]; k < self->timestamp_offset[j + 1]; k++) {
            fprintf(out, "%c", self->timestamp[k]);
        }
        fprintf(out, "\t%d\t", self->record_offset[j]);
        for (k = self->record_offset[j]; k < self->record_offset[j + 1]; k++) {
            fprintf(out, "%c", self->record[k]);
        }
        fprintf(out, "\n");
    }
    assert(self->timestamp_offset[0] == 0);
    assert(self->timestamp_offset[self->num_rows] == self->timestamp_length);
    assert(self->record_offset[0] == 0);
    assert(self->record_offset[self->num_rows] == self->record_length);
}

/* Copy row `index` into *row. The char pointers alias the table's own
 * buffers, so they are invalidated by subsequent table modifications. */
int
tsk_provenance_tbl_get_row(tsk_provenance_tbl_t *self, size_t index, tsk_provenance_t *row)
{
    int ret = 0;
    if (index >= self->num_rows) {
        ret = TSK_ERR_PROVENANCE_OUT_OF_BOUNDS;
        goto out;
    }
    row->id = (tsk_id_t) index;
    row->timestamp_length = self->timestamp_offset[index + 1]
        - self->timestamp_offset[index];
    row->timestamp = self->timestamp + self->timestamp_offset[index];
    row->record_length = self->record_offset[index + 1]
        - self->record_offset[index];
    row->record = self->record + self->record_offset[index];
out:
    return ret;
}

/* Write the table in tab-separated text form.
 * Returns TSK_ERR_IO on any write failure. */
int
tsk_provenance_tbl_dump_text(tsk_provenance_tbl_t *self, FILE *out)
{
    int ret = TSK_ERR_IO;
    int err;
    size_t j;
    tsk_tbl_size_t timestamp_len, record_len;

    err = fprintf(out, "record\ttimestamp\n");
    if (err < 0) {
        goto out;
    }
    for (j = 0; j < self->num_rows; j++) {
        record_len = self->record_offset[j + 1] -
            self->record_offset[j];
        timestamp_len = self->timestamp_offset[j + 1] - self->timestamp_offset[j];
        /* NOTE(review): %.*s expects an int precision argument; these are
         * tsk_tbl_size_t values — confirm they are 32-bit on all targets. */
        err = fprintf(out, "%.*s\t%.*s\n", record_len, self->record + self->record_offset[j],
            timestamp_len, self->timestamp + self->timestamp_offset[j]);
        if (err < 0) {
            goto out;
        }
    }
    ret = 0;
out:
    return ret;
}

/* Return true iff the two tables hold identical data. */
bool
tsk_provenance_tbl_equals(tsk_provenance_tbl_t *self, tsk_provenance_tbl_t *other)
{
    bool ret = false;
    if (self->num_rows == other->num_rows
            && self->timestamp_length == other->timestamp_length) {
        /* Comparing the offset arrays first guarantees the record lengths
         * agree before the record buffers are memcmp'd (&& short-circuits). */
        ret = memcmp(self->timestamp_offset, other->timestamp_offset,
                (self->num_rows + 1) * sizeof(tsk_tbl_size_t)) == 0
            && memcmp(self->timestamp, other->timestamp,
                self->timestamp_length * sizeof(char)) == 0
            && memcmp(self->record_offset, other->record_offset,
                (self->num_rows + 1) * sizeof(tsk_tbl_size_t)) == 0
            && memcmp(self->record, other->record,
                self->record_length * sizeof(char)) == 0;
    }
    return ret;
}

/* Serialise the table's columns into the kastore. */
static int
tsk_provenance_tbl_dump(tsk_provenance_tbl_t *self, kastore_t *store)
{
    write_table_col_t write_cols[] = {
        {"provenances/timestamp", (void *) self->timestamp,
            self->timestamp_length, KAS_UINT8},
        {"provenances/timestamp_offset", (void *) self->timestamp_offset,
            self->num_rows + 1, KAS_UINT32},
        {"provenances/record", (void *) self->record,
            self->record_length, KAS_UINT8},
        {"provenances/record_offset", (void *) self->record_offset,
            self->num_rows + 1, KAS_UINT32},
    };
    return write_table_cols(store, write_cols, sizeof(write_cols) / sizeof(*write_cols));
}

/* Load the table's columns from the kastore. */
static int
tsk_provenance_tbl_load(tsk_provenance_tbl_t *self, kastore_t *store)
{
    read_table_col_t read_cols[] = {
        {"provenances/timestamp", (void **) &self->timestamp,
            &self->timestamp_length, 0, KAS_UINT8},
        {"provenances/timestamp_offset", (void **) &self->timestamp_offset,
            &self->num_rows, 1, KAS_UINT32},
        {"provenances/record", (void **) &self->record,
            &self->record_length, 0, KAS_UINT8},
        {"provenances/record_offset", (void **) &self->record_offset,
            &self->num_rows, 1, KAS_UINT32},
    };
    return read_table_cols(store, read_cols, sizeof(read_cols) / sizeof(*read_cols));
}

/*************************
 * sort_tables
 *************************/

/* Flattened edge used while sorting; `time` caches the parent node's time
 * so the comparator does not have to chase the node table. */
typedef struct {
    double left;
    double right;
    tsk_id_t parent;
    tsk_id_t child;
    double time;
} edge_sort_t;

typedef struct {
    /* Input tables.
     */
    tsk_node_tbl_t *nodes;
    tsk_edge_tbl_t *edges;
    tsk_site_tbl_t *sites;
    tsk_mutation_tbl_t *mutations;
    tsk_migration_tbl_t *migrations;
    /* Mapping from input site IDs to output site IDs */
    tsk_id_t *site_id_map;
} table_sorter_t;

/* qsort comparator: order sites by position, breaking ties by original ID. */
static int
cmp_site(const void *a, const void *b) {
    const tsk_site_t *ia = (const tsk_site_t *) a;
    const tsk_site_t *ib = (const tsk_site_t *) b;
    /* Compare sites by position */
    int ret = (ia->position > ib->position) - (ia->position < ib->position);
    if (ret == 0) {
        /* Within a particular position sort by ID. This ensures that relative ordering
         * of multiple sites at the same position is maintained; the redundant sites
         * will get compacted down by clean_tables(), but in the meantime if the order
         * of the redundant sites changes it will cause the sort order of mutations to
         * be corrupted, as the mutations will follow their sites. */
        ret = (ia->id > ib->id) - (ia->id < ib->id);
    }
    return ret;
}

/* qsort comparator: order mutations by site, breaking ties by original ID. */
static int
cmp_mutation(const void *a, const void *b) {
    const tsk_mutation_t *ia = (const tsk_mutation_t *) a;
    const tsk_mutation_t *ib = (const tsk_mutation_t *) b;
    /* Compare mutations by site */
    int ret = (ia->site > ib->site) - (ia->site < ib->site);
    if (ret == 0) {
        /* Within a particular site sort by ID. This ensures that relative ordering
         * within a site is maintained */
        ret = (ia->id > ib->id) - (ia->id < ib->id);
    }
    return ret;
}

/* qsort comparator: canonical edge order (time, parent, child, left). */
static int
cmp_edge(const void *a, const void *b) {
    const edge_sort_t *ca = (const edge_sort_t *) a;
    const edge_sort_t *cb = (const edge_sort_t *) b;

    int ret = (ca->time > cb->time) - (ca->time < cb->time);
    /* If time values are equal, sort by the parent node */
    if (ret == 0) {
        ret = (ca->parent > cb->parent) - (ca->parent < cb->parent);
        /* If the parent nodes are equal, sort by the child ID. */
        if (ret == 0) {
            ret = (ca->child > cb->child) - (ca->child < cb->child);
            /* If the child nodes are equal, sort by the left coordinate. */
            if (ret == 0) {
                ret = (ca->left > cb->left) - (ca->left < cb->left);
            }
        }
    }
    return ret;
}

/* Initialise the sorter against `tables`: verify offset integrity and
 * allocate the input-to-output site ID map. */
static int
table_sorter_alloc(table_sorter_t *self, tsk_tbl_collection_t *tables,
        int TSK_UNUSED(flags))
{
    int ret = 0;

    memset(self, 0, sizeof(table_sorter_t));
    if (tables == NULL) {
        ret = TSK_ERR_BAD_PARAM_VALUE;
        goto out;
    }
    ret = tsk_tbl_collection_check_integrity(tables, TSK_CHECK_OFFSETS);
    if (ret != 0) {
        goto out;
    }
    self->nodes = tables->nodes;
    self->edges = tables->edges;
    self->mutations = tables->mutations;
    self->sites = tables->sites;
    self->migrations = tables->migrations;

    self->site_id_map = malloc(self->sites->num_rows * sizeof(tsk_id_t));
    if (self->site_id_map == NULL) {
        ret = TSK_ERR_NO_MEMORY;
        goto out;
    }
out:
    return ret;
}

/* Sort edges[start..] into canonical order (see cmp_edge) via a scratch
 * array annotated with each parent's node time. */
static int
table_sorter_sort_edges(table_sorter_t *self, size_t start)
{
    int ret = 0;
    edge_sort_t *e;
    size_t j, k;
    size_t n = self->edges->num_rows - start;
    edge_sort_t *sorted_edges = malloc(n * sizeof(*sorted_edges));

    if (sorted_edges == NULL) {
        ret = TSK_ERR_NO_MEMORY;
        goto out;
    }
    for (j = 0; j < n; j++) {
        e = sorted_edges + j;
        k = start + j;
        e->left = self->edges->left[k];
        e->right = self->edges->right[k];
        e->parent = self->edges->parent[k];
        e->child = self->edges->child[k];
        e->time = self->nodes->time[e->parent];
    }
    qsort(sorted_edges, n, sizeof(edge_sort_t), cmp_edge);
    /* Copy the edges back into the table.
*/ + for (j = 0; j < n; j++) { + e = sorted_edges + j; + k = start + j; + self->edges->left[k] = e->left; + self->edges->right[k] = e->right; + self->edges->parent[k] = e->parent; + self->edges->child[k] = e->child; + } +out: + tsk_safe_free(sorted_edges); + return ret; +} + +static int +table_sorter_sort_sites(table_sorter_t *self) +{ + int ret = 0; + tsk_site_tbl_t copy; + tsk_tbl_size_t j; + tsk_tbl_size_t num_sites = self->sites->num_rows; + tsk_site_t *sorted_sites = malloc(num_sites * sizeof(*sorted_sites)); + + ret = tsk_site_tbl_alloc(©, 0); + if (ret != 0) { + goto out; + } + ret = tsk_site_tbl_copy(self->sites, ©); + if (ret != 0) { + goto out; + } + if (sorted_sites == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + for (j = 0; j < num_sites; j++) { + ret = tsk_site_tbl_get_row(©, j, sorted_sites + j); + if (ret != 0) { + goto out; + } + } + + /* Sort the sites by position */ + qsort(sorted_sites, self->sites->num_rows, sizeof(*sorted_sites), cmp_site); + + /* Build the mapping from old site IDs to new site IDs and copy back into the table */ + tsk_site_tbl_clear(self->sites); + for (j = 0; j < num_sites; j++) { + self->site_id_map[sorted_sites[j].id] = (tsk_id_t) j; + ret = tsk_site_tbl_add_row(self->sites, sorted_sites[j].position, + sorted_sites[j].ancestral_state, sorted_sites[j].ancestral_state_length, + sorted_sites[j].metadata, sorted_sites[j].metadata_length); + if (ret < 0) { + goto out; + } + } + ret = 0; +out: + tsk_safe_free(sorted_sites); + tsk_site_tbl_free(©); + return ret; +} + +static int +table_sorter_sort_mutations(table_sorter_t *self) +{ + int ret = 0; + size_t j; + tsk_id_t parent, mapped_parent; + size_t num_mutations = self->mutations->num_rows; + tsk_mutation_tbl_t copy; + tsk_mutation_t *sorted_mutations = malloc(num_mutations * sizeof(*sorted_mutations)); + tsk_id_t *mutation_id_map = malloc(num_mutations * sizeof(*mutation_id_map)); + + ret = tsk_mutation_tbl_alloc(©, 0); + if (ret != 0) { + goto out; + } + ret = 
tsk_mutation_tbl_copy(self->mutations, ©); + if (ret != 0) { + goto out; + } + if (mutation_id_map == NULL || sorted_mutations == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + + for (j = 0; j < num_mutations; j++) { + ret = tsk_mutation_tbl_get_row(©, j, sorted_mutations + j); + if (ret != 0) { + goto out; + } + sorted_mutations[j].site = self->site_id_map[sorted_mutations[j].site]; + } + ret = tsk_mutation_tbl_clear(self->mutations); + if (ret != 0) { + goto out; + } + + qsort(sorted_mutations, num_mutations, sizeof(*sorted_mutations), cmp_mutation); + + /* Make a first pass through the sorted mutations to build the ID map. */ + for (j = 0; j < num_mutations; j++) { + mutation_id_map[sorted_mutations[j].id] = (tsk_id_t) j; + } + + for (j = 0; j < num_mutations; j++) { + mapped_parent = TSK_NULL; + parent = sorted_mutations[j].parent; + if (parent != TSK_NULL) { + mapped_parent = mutation_id_map[parent]; + } + ret = tsk_mutation_tbl_add_row(self->mutations, + sorted_mutations[j].site, + sorted_mutations[j].node, + mapped_parent, + sorted_mutations[j].derived_state, + sorted_mutations[j].derived_state_length, + sorted_mutations[j].metadata, + sorted_mutations[j].metadata_length); + if (ret < 0) { + goto out; + } + } + ret = 0; + +out: + tsk_safe_free(mutation_id_map); + tsk_safe_free(sorted_mutations); + tsk_mutation_tbl_free(©); + return ret; +} + +static int +table_sorter_run(table_sorter_t *self, size_t edge_start) +{ + int ret = 0; + + ret = table_sorter_sort_edges(self, edge_start); + if (ret != 0) { + goto out; + } + ret = table_sorter_sort_sites(self); + if (ret != 0) { + goto out; + } + ret = table_sorter_sort_mutations(self); + if (ret != 0) { + goto out; + } +out: + return ret; +} + +static void +table_sorter_free(table_sorter_t *self) +{ + tsk_safe_free(self->site_id_map); +} + + +/************************* + * segment overlapper + *************************/ + +/* TODO: This should be renamed to tsk_segment_t when we move to tskit. 
 * msprime can then #include this and also use it. msprime needs a different
 * segment definition, which it can continue to call 'segment_t' as it
 * doesn't export a C API. */
typedef struct _simplify_segment_t {
    double left;
    double right;
    struct _simplify_segment_t *next;
    tsk_id_t node;
} simplify_segment_t;

/* Node in a per-child linked list of buffered edge intervals. */
typedef struct _interval_list_t {
    double left;
    double right;
    struct _interval_list_t *next;
} interval_list_t;

/* Node in a per-node linked list of input mutation IDs. */
typedef struct _mutation_id_list_t {
    tsk_id_t mutation;
    struct _mutation_id_list_t *next;
} mutation_id_list_t;

/* segment overlap finding algorithm */
typedef struct {
    /* The input segments. This buffer is sorted by the algorithm and we also
     * assume that there is space for an extra element at the end */
    simplify_segment_t *segments;
    size_t num_segments;
    size_t index;
    size_t num_overlapping;
    double left;
    double right;
    /* Output buffer */
    size_t max_overlapping;
    simplify_segment_t **overlapping;
} segment_overlapper_t;

typedef struct {
    tsk_id_t *samples;
    size_t num_samples;
    int flags;
    tsk_tbl_collection_t *tables;
    /* Keep a copy of the input tables */
    tsk_tbl_collection_t input_tables;
    /* State for topology */
    simplify_segment_t **ancestor_map_head;
    simplify_segment_t **ancestor_map_tail;
    tsk_id_t *node_id_map;
    bool *is_sample;
    /* Segments for a particular parent that are processed together */
    simplify_segment_t *segment_queue;
    size_t segment_queue_size;
    size_t max_segment_queue_size;
    segment_overlapper_t segment_overlapper;
    tsk_blkalloc_t segment_heap;
    /* Buffer for output edges. For each child we keep a linked list of
     * intervals, and also store the actual children that have been buffered. */
    tsk_blkalloc_t interval_list_heap;
    interval_list_t **child_edge_map_head;
    interval_list_t **child_edge_map_tail;
    tsk_id_t *buffered_children;
    size_t num_buffered_children;
    /* For each mutation, map its output node. */
    tsk_id_t *mutation_node_map;
    /* Map of input mutation IDs to output mutation IDs. */
    tsk_id_t *mutation_id_map;
    /* Map of input nodes to the list of input mutation IDs */
    mutation_id_list_t **node_mutation_list_map_head;
    mutation_id_list_t **node_mutation_list_map_tail;
    mutation_id_list_t *node_mutation_list_mem;
    /* When reducing topology, we need a map positions to their corresponding
     * sites.*/
    double *position_lookup;
} simplifier_t;

/* qsort comparator: order segments by left coordinate, ties by node. */
static int
cmp_segment(const void *a, const void *b) {
    const simplify_segment_t *ia = (const simplify_segment_t *) a;
    const simplify_segment_t *ib = (const simplify_segment_t *) b;
    int ret = (ia->left > ib->left) - (ia->left < ib->left);
    /* Break ties using the node */
    if (ret == 0) {
        ret = (ia->node > ib->node) - (ia->node < ib->node);
    }
    return ret;
}

/* Allocate the overlapper's output buffer with a small initial capacity. */
static int TSK_WARN_UNUSED
segment_overlapper_alloc(segment_overlapper_t *self)
{
    int ret = 0;

    memset(self, 0, sizeof(*self));
    self->max_overlapping = 8; /* Making sure we call realloc in tests */
    self->overlapping = malloc(self->max_overlapping * sizeof(*self->overlapping));
    if (self->overlapping == NULL) {
        ret = TSK_ERR_NO_MEMORY;
        goto out;
    }
out:
    return ret;
}

static int
segment_overlapper_free(segment_overlapper_t *self)
{
    tsk_safe_free(self->overlapping);
    return 0;
}

/* Initialise the segment overlapper for use. Note that the segments
 * array must have space for num_segments + 1 elements!
 */
static int TSK_WARN_UNUSED
segment_overlapper_init(segment_overlapper_t *self, simplify_segment_t *segments,
        size_t num_segments)
{
    int ret = 0;
    simplify_segment_t *sentinel;
    void *p;

    if (self->max_overlapping < num_segments) {
        self->max_overlapping = num_segments;
        p = realloc(self->overlapping,
            self->max_overlapping * sizeof(*self->overlapping));
        if (p == NULL) {
            ret = TSK_ERR_NO_MEMORY;
            goto out;
        }
        self->overlapping = p;

    }
    self->segments = segments;
    self->num_segments = num_segments;
    self->index = 0;
    self->num_overlapping = 0;
    self->left = 0;
    self->right = DBL_MAX;

    /* Sort the segments in the buffer by left coordinate */
    qsort(self->segments, self->num_segments, sizeof(simplify_segment_t), cmp_segment);
    /* NOTE! We are assuming that there's space for another element on the end
     * here. This is to insert a sentinel which simplifies the logic. */
    sentinel = self->segments + self->num_segments;
    sentinel->left = DBL_MAX;
out:
    return ret;
}

/* Sweep to the next distinct interval [*left, *right) and report the set of
 * segments overlapping it. Returns 1 while intervals remain, 0 when done. */
static int TSK_WARN_UNUSED
segment_overlapper_next(segment_overlapper_t *self,
        double *left, double *right, simplify_segment_t ***overlapping,
        size_t *num_overlapping)
{
    int ret = 0;
    size_t j, k;
    size_t n = self->num_segments;
    simplify_segment_t *S = self->segments;

    if (self->index < n) {
        self->left = self->right;
        /* Remove any elements of X with right <= left */
        k = 0;
        for (j = 0; j < self->num_overlapping; j++) {
            if (self->overlapping[j]->right > self->left) {
                self->overlapping[k] = self->overlapping[j];
                k++;
            }
        }
        self->num_overlapping = k;
        if (k == 0) {
            self->left = S[self->index].left;
        }
        while (self->index < n && S[self->index].left == self->left) {
            assert(self->num_overlapping < self->max_overlapping);
            self->overlapping[self->num_overlapping] = &S[self->index];
            self->num_overlapping++;
            self->index++;
        }
        self->index--;
        /* The sentinel guarantees S[index + 1] is always valid here. */
        self->right = S[self->index + 1].left;
        for (j = 0; j < self->num_overlapping; j++) {
            self->right = TSK_MIN(self->right, self->overlapping[j]->right);
        }
        assert(self->left < self->right);
        self->index++;
        ret = 1;
    } else {
        self->left = self->right;
        self->right = DBL_MAX;
        k = 0;
        for (j = 0; j < self->num_overlapping; j++) {
            if (self->overlapping[j]->right > self->left) {
                self->right = TSK_MIN(self->right, self->overlapping[j]->right);
                self->overlapping[k] = self->overlapping[j];
                k++;
            }
        }
        self->num_overlapping = k;
        if (k > 0) {
            ret = 1;
        }
    }

    *left = self->left;
    *right = self->right;
    *overlapping = self->overlapping;
    *num_overlapping = self->num_overlapping;
    return ret;
}

/*************************
 * simplifier
 *************************/

/* qsort comparator for tsk_id_t values. */
static int
cmp_node_id(const void *a, const void *b) {
    const tsk_id_t *ia = (const tsk_id_t *) a;
    const tsk_id_t *ib = (const tsk_id_t *) b;
    return (*ia > *ib) - (*ia < *ib);
}

/* Debug aid: assert the simplifier's internal invariants. */
static void
simplifier_check_state(simplifier_t *self)
{
    size_t j, k;
    simplify_segment_t *u;
    mutation_id_list_t *list_node;
    tsk_id_t site;
    interval_list_t *int_list;
    tsk_id_t child;
    double position, last_position;
    bool found;
    size_t num_intervals;

    for (j = 0; j < self->input_tables.nodes->num_rows; j++) {
        assert((self->ancestor_map_head[j] == NULL) ==
            (self->ancestor_map_tail[j] == NULL));
        for (u = self->ancestor_map_head[j]; u != NULL; u = u->next) {
            assert(u->left < u->right);
            if (u->next != NULL) {
                assert(u->right <= u->next->left);
                if (u->right == u->next->left) {
                    assert(u->node != u->next->node);
                }
            } else {
                assert(u == self->ancestor_map_tail[j]);
            }
        }
    }

    for (j = 0; j < self->segment_queue_size; j++) {
        assert(self->segment_queue[j].left < self->segment_queue[j].right);
    }

    for (j = 0; j < self->input_tables.nodes->num_rows; j++) {
        last_position = -1;
        for (list_node = self->node_mutation_list_map_head[j]; list_node != NULL;
                list_node = list_node->next) {
            assert(self->input_tables.mutations->node[list_node->mutation] == (tsk_id_t) j);
            site = self->input_tables.mutations->site[list_node->mutation];
            position = self->input_tables.sites->position[site];
            assert(last_position <= position);
            last_position = position;
        }
    }

    /* check the buffered edges */
    for (j = 0; j < self->input_tables.nodes->num_rows; j++) {
        assert((self->child_edge_map_head[j] == NULL) ==
            (self->child_edge_map_tail[j] == NULL));
        if (self->child_edge_map_head[j] != NULL) {
            /* Make sure that the child is in our list */
            found = false;
            for (k = 0; k < self->num_buffered_children; k++) {
                if (self->buffered_children[k] == (tsk_id_t) j) {
                    found = true;
                    break;
                }
            }
            assert(found);
        }
    }
    num_intervals = 0;
    for (j = 0; j < self->num_buffered_children; j++) {
        child = self->buffered_children[j];
        assert(self->child_edge_map_head[child] != NULL);
        for (int_list = self->child_edge_map_head[child]; int_list != NULL;
                int_list = int_list->next) {
            assert(int_list->left < int_list->right);
            if (int_list->next != NULL) {
                assert(int_list->right < int_list->next->left);
            }
            num_intervals++;
        }
    }
    assert(num_intervals ==
        self->interval_list_heap.total_allocated / (sizeof(interval_list_t)));
}

/* Print a segment chain as (left,right->node) tuples. */
static void
print_segment_chain(simplify_segment_t *head, FILE *out)
{
    simplify_segment_t *u;

    for (u = head; u != NULL; u = u->next) {
        fprintf(out, "(%f,%f->%d)", u->left, u->right, u->node);
    }
}

/* Debug aid: dump the simplifier's full state and run the consistency
 * checks. */
static void
simplifier_print_state(simplifier_t *self, FILE *out)
{
    size_t j;
    simplify_segment_t *u;
    mutation_id_list_t *list_node;
    interval_list_t *int_list;
    tsk_id_t child;

    fprintf(out, "--simplifier state--\n");
    fprintf(out, "flags:\n");
    fprintf(out, "\tfilter_unreferenced_sites: %d\n",
        !!(self->flags & TSK_FILTER_SITES));
    fprintf(out, "\treduce_to_site_topology : %d\n",
        !!(self->flags & TSK_REDUCE_TO_SITE_TOPOLOGY));

    fprintf(out, "===\nInput tables\n==\n");
    tsk_tbl_collection_print_state(&self->input_tables, out);
    fprintf(out, "===\nOutput tables\n==\n");
    tsk_tbl_collection_print_state(self->tables, out);
    fprintf(out, "===\nmemory heaps\n==\n");
    fprintf(out, "segment_heap:\n");
    tsk_blkalloc_print_state(&self->segment_heap, out);
    fprintf(out, "interval_list_heap:\n");
    tsk_blkalloc_print_state(&self->interval_list_heap, out);
    fprintf(out, "===\nancestors\n==\n");
    for (j = 0; j < self->input_tables.nodes->num_rows; j++) {
        fprintf(out, "%d:\t", (int) j);
        print_segment_chain(self->ancestor_map_head[j], out);
        fprintf(out, "\n");
    }
    fprintf(out, "===\nnode_id map (input->output)\n==\n");
    for (j = 0; j < self->input_tables.nodes->num_rows; j++) {
        if (self->node_id_map[j] != TSK_NULL) {
            fprintf(out, "%d->%d\n", (int) j, self->node_id_map[j]);
        }
    }
    fprintf(out, "===\nsegment queue\n==\n");
    for (j = 0; j < self->segment_queue_size; j++) {
        u = &self->segment_queue[j];
        fprintf(out, "(%f,%f->%d)", u->left, u->right, u->node);
        fprintf(out, "\n");
    }
    fprintf(out, "===\nbuffered children\n==\n");
    for (j = 0; j < self->num_buffered_children; j++) {
        child = self->buffered_children[j];
        fprintf(out, "%d -> ", (int) j);
        for (int_list = self->child_edge_map_head[child]; int_list != NULL;
                int_list = int_list->next) {
            fprintf(out, "(%f, %f), ", int_list->left, int_list->right);
        }
        fprintf(out, "\n");
    }
    fprintf(out, "===\nmutation node map\n==\n");
    for (j = 0; j < self->input_tables.mutations->num_rows; j++) {
        fprintf(out, "%d\t-> %d\n", (int) j, self->mutation_node_map[j]);
    }
    fprintf(out, "===\nnode mutation id list map\n==\n");
    for (j = 0; j < self->input_tables.nodes->num_rows; j++) {
        if (self->node_mutation_list_map_head[j] != NULL) {
            fprintf(out, "%d\t-> [", (int) j);
            for (list_node = self->node_mutation_list_map_head[j]; list_node != NULL;
                    list_node = list_node->next) {
                fprintf(out, "%d,", list_node->mutation);
            }
            fprintf(out, "]\n");
        }
    }
    if (!!(self->flags & TSK_REDUCE_TO_SITE_TOPOLOGY)) {
        fprintf(out, "===\nposition_lookup\n==\n");
        for (j = 0; j < self->input_tables.sites->num_rows + 2; j++) {
            fprintf(out, "%d\t-> %f\n", (int) j, self->position_lookup[j]);
        }
    }
    simplifier_check_state(self);
}

/* Allocate a segment from the heap; returns NULL on allocation failure. */
static simplify_segment_t * TSK_WARN_UNUSED
simplifier_alloc_segment(simplifier_t *self, double left, double right, tsk_id_t node)
{
    simplify_segment_t *seg = NULL;

    seg = tsk_blkalloc_get(&self->segment_heap, sizeof(*seg));
    if (seg == NULL) {
        goto out;
    }
    seg->next = NULL;
    seg->left = left;
    seg->right = right;
    seg->node = node;
out:
    return seg;
}

/* Allocate an interval-list node from the heap; NULL on failure. */
static interval_list_t * TSK_WARN_UNUSED
simplifier_alloc_interval_list(simplifier_t *self, double left, double right)
{
    interval_list_t *x = NULL;

    x = tsk_blkalloc_get(&self->interval_list_heap, sizeof(*x));
    if (x == NULL) {
        goto out;
    }
    x->next = NULL;
    x->left = left;
    x->right = right;
out:
    return x;
}

/* Add a new node to the output node table corresponding to the specified input id.
 * Returns the new ID. */
static int TSK_WARN_UNUSED
simplifier_record_node(simplifier_t *self, tsk_id_t input_id, bool is_sample)
{
    int ret = 0;
    tsk_node_t node;
    uint32_t flags;

    ret = tsk_node_tbl_get_row(self->input_tables.nodes, (size_t) input_id, &node);
    if (ret != 0) {
        goto out;
    }
    /* Zero out the sample bit */
    flags = node.flags & (uint32_t) ~TSK_NODE_IS_SAMPLE;
    if (is_sample) {
        flags |= TSK_NODE_IS_SAMPLE;
    }
    self->node_id_map[input_id] = (tsk_id_t) self->tables->nodes->num_rows;
    ret = tsk_node_tbl_add_row(self->tables->nodes, flags,
        node.time, node.population, node.individual,
        node.metadata, node.metadata_length);
out:
    return ret;
}

/* Remove the mapping for the last recorded node.
*/ +static int +simplifier_rewind_node(simplifier_t *self, tsk_id_t input_id, tsk_id_t output_id) +{ + self->node_id_map[input_id] = TSK_NULL; + return tsk_node_tbl_truncate(self->tables->nodes, (size_t) output_id); +} + +static int +simplifier_flush_edges(simplifier_t *self, tsk_id_t parent, size_t *ret_num_edges) +{ + int ret = 0; + size_t j; + tsk_id_t child; + interval_list_t *x; + size_t num_edges = 0; + + qsort(self->buffered_children, self->num_buffered_children, + sizeof(tsk_id_t), cmp_node_id); + for (j = 0; j < self->num_buffered_children; j++) { + child = self->buffered_children[j]; + for (x = self->child_edge_map_head[child]; x != NULL; x = x->next) { + ret = tsk_edge_tbl_add_row(self->tables->edges, x->left, x->right, parent, child); + if (ret < 0) { + goto out; + } + num_edges++; + } + self->child_edge_map_head[child] = NULL; + self->child_edge_map_tail[child] = NULL; + } + self->num_buffered_children = 0; + *ret_num_edges = num_edges; + ret = tsk_blkalloc_reset(&self->interval_list_heap); +out: + return ret; +} + +/* When we are reducing topology down to what is visible at the sites we need a + * lookup table to find the closest site position for each edge. We do this with + * a sorted array and binary search */ +static int +simplifier_init_position_lookup(simplifier_t *self) +{ + int ret = 0; + size_t num_sites = self->input_tables.sites->num_rows; + + self->position_lookup = malloc((num_sites + 2) * sizeof(*self->position_lookup)); + if (self->position_lookup == NULL) { + goto out; + } + self->position_lookup[0] = 0; + self->position_lookup[num_sites + 1] = self->tables->sequence_length; + memcpy(self->position_lookup + 1, self->input_tables.sites->position, + num_sites * sizeof(double)); +out: + return ret; +} +/* + * Find the smallest site position index greater than or equal to left + * and right, i.e., slide each endpoint of an interval to the right + * until they hit a site position. 
 If both left and right map to the
 * the same position then we discard this edge. We also discard an
 * edge if left = 0 and right is less than the first site position.
 */
static bool
simplifier_map_reduced_coordinates(simplifier_t *self, double *left, double *right)
{
    double *X = self->position_lookup;
    size_t N = self->input_tables.sites->num_rows + 2;
    size_t left_index, right_index;
    bool skip = false;

    left_index = tsk_search_sorted(X, N, *left);
    right_index = tsk_search_sorted(X, N, *right);
    if (left_index == right_index || (left_index == 0 && right_index == 1)) {
        skip = true;
    } else {
        /* Remap back to zero if the left end maps to the first site. */
        if (left_index == 1) {
            left_index = 0;
        }
        *left = X[left_index];
        *right = X[right_index];
    }
    return skip;
}

/* Records the specified edge for the current parent by buffering it */
static int
simplifier_record_edge(simplifier_t *self, double left, double right, tsk_id_t child)
{
    int ret = 0;
    interval_list_t *tail, *x;
    bool skip;

    if (!!(self->flags & TSK_REDUCE_TO_SITE_TOPOLOGY)) {
        skip = simplifier_map_reduced_coordinates(self, &left, &right);
        /* NOTE: we exit early here when reduce_coordindates has told us to
         * skip this edge, as it is not visible in the reduced tree sequence */
        if (skip) {
            goto out;
        }
    }

    tail = self->child_edge_map_tail[child];
    if (tail == NULL) {
        /* First interval for this child: register it in buffered_children. */
        assert(self->num_buffered_children < self->input_tables.nodes->num_rows);
        self->buffered_children[self->num_buffered_children] = child;
        self->num_buffered_children++;
        x = simplifier_alloc_interval_list(self, left, right);
        if (x == NULL) {
            ret = TSK_ERR_NO_MEMORY;
            goto out;
        }
        self->child_edge_map_head[child] = x;
        self->child_edge_map_tail[child] = x;
    } else {
        if (tail->right == left) {
            /* Contiguous with the buffered tail: coalesce rather than
             * appending a new interval. */
            tail->right = right;
        } else {
            x = simplifier_alloc_interval_list(self, left, right);
            if (x == NULL) {
                ret = TSK_ERR_NO_MEMORY;
                goto out;
            }
            tail->next = x;
            self->child_edge_map_tail[child] = x;
        }
    }
out:
    return ret;
}

/* Allocate and initialise the mutation and per-node mutation-list maps. */
static int
simplifier_init_sites(simplifier_t *self)
{
    int ret = 0;
    tsk_id_t node;
    mutation_id_list_t *list_node;
    size_t j;

    self->mutation_id_map = calloc(self->input_tables.mutations->num_rows,
        sizeof(tsk_id_t));
    self->mutation_node_map = calloc(self->input_tables.mutations->num_rows,
        sizeof(tsk_id_t));
    self->node_mutation_list_mem = malloc(self->input_tables.mutations->num_rows *
        sizeof(mutation_id_list_t));
    self->node_mutation_list_map_head = calloc(self->input_tables.nodes->num_rows,
        sizeof(mutation_id_list_t *));
    self->node_mutation_list_map_tail = calloc(self->input_tables.nodes->num_rows,
        sizeof(mutation_id_list_t *));
    if (self->mutation_id_map == NULL || self->mutation_node_map == NULL
            || self->node_mutation_list_mem == NULL
            || self->node_mutation_list_map_head == NULL
            || self->node_mutation_list_map_tail == NULL) {
        ret = TSK_ERR_NO_MEMORY;
        goto out;
    }
    /* 0xff-filling sets every tsk_id_t entry to -1 (TSK_NULL). */
    memset(self->mutation_id_map, 0xff,
        self->input_tables.mutations->num_rows * sizeof(tsk_id_t));
    memset(self->mutation_node_map, 0xff,
        self->input_tables.mutations->num_rows * sizeof(tsk_id_t));

    for (j = 0; j < self->input_tables.mutations->num_rows; j++) {
        node = self->input_tables.mutations->node[j];
        list_node = self->node_mutation_list_mem + j;
        list_node->mutation = (tsk_id_t) j;
        list_node->next = NULL;
        if (self->node_mutation_list_map_head[node] == NULL) {
            self->node_mutation_list_map_head[node] = list_node;
        } else {
            self->node_mutation_list_map_tail[node]->next = list_node;
        }
        self->node_mutation_list_map_tail[node] = list_node;
    }
out:
    return ret;

}

/* Append the ancestry mapping [left, right) -> output_id to input_id's
 * segment chain, coalescing with the tail when contiguous and same node. */
static int TSK_WARN_UNUSED
simplifier_add_ancestry(simplifier_t *self, tsk_id_t input_id, double left, double right,
        tsk_id_t output_id)
{
    int ret = 0;
    simplify_segment_t *tail = self->ancestor_map_tail[input_id];
    simplify_segment_t *x;

    assert(left < right);
    if (tail == NULL) {
        x =
simplifier_alloc_segment(self, left, right, output_id); + if (x == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + self->ancestor_map_head[input_id] = x; + self->ancestor_map_tail[input_id] = x; + } else { + if (tail->right == left && tail->node == output_id) { + tail->right = right; + } else { + x = simplifier_alloc_segment(self, left, right, output_id); + if (x == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + tail->next = x; + self->ancestor_map_tail[input_id] = x; + } + } +out: + return ret; +} + +static int +simplifier_init_samples(simplifier_t *self, tsk_id_t *samples) +{ + int ret = 0; + size_t j; + + /* Go through the samples to check for errors. */ + for (j = 0; j < self->num_samples; j++) { + if (samples[j] < 0 || samples[j] > (tsk_id_t) self->input_tables.nodes->num_rows) { + ret = TSK_ERR_NODE_OUT_OF_BOUNDS; + goto out; + } + if (!(self->input_tables.nodes->flags[self->samples[j]] & TSK_NODE_IS_SAMPLE)) { + ret = TSK_ERR_BAD_SAMPLES; + goto out; + } + if (self->is_sample[samples[j]]) { + ret = TSK_ERR_DUPLICATE_SAMPLE; + goto out; + } + self->is_sample[samples[j]] = true; + ret = simplifier_record_node(self, samples[j], true); + if (ret < 0) { + goto out; + } + ret = simplifier_add_ancestry(self, samples[j], 0, self->tables->sequence_length, + (tsk_id_t) ret); + if (ret != 0) { + goto out; + } + } +out: + return ret; +} + +static int +simplifier_alloc(simplifier_t *self, tsk_id_t *samples, size_t num_samples, + tsk_tbl_collection_t *tables, int flags) +{ + int ret = 0; + size_t num_nodes_alloc; + + memset(self, 0, sizeof(simplifier_t)); + if (samples == NULL || tables == NULL) { + ret = TSK_ERR_BAD_PARAM_VALUE; + goto out; + } + self->num_samples = num_samples; + self->flags = flags; + self->tables = tables; + + /* TODO we can add a flag to skip these checks for when we know they are + * unnecessary */ + /* TODO Current unit tests require TSK_CHECK_SITE_DUPLICATES but it's + * debateable whether we need it. 
If we remove, we definitely need explicit + * tests to ensure we're doing sensible things with duplicate sites. + * (Particularly, re TSK_REDUCE_TO_SITE_TOPOLOGY.) */ + ret = tsk_tbl_collection_check_integrity(tables, + TSK_CHECK_OFFSETS|TSK_CHECK_EDGE_ORDERING|TSK_CHECK_SITE_ORDERING| + TSK_CHECK_SITE_DUPLICATES); + if (ret != 0) { + goto out; + } + + ret = tsk_tbl_collection_alloc(&self->input_tables, 0); + if (ret != 0) { + goto out; + } + ret = tsk_tbl_collection_copy(self->tables, &self->input_tables); + if (ret != 0) { + goto out; + } + + /* Take a copy of the input samples */ + self->samples = malloc(num_samples * sizeof(tsk_id_t)); + if (self->samples == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + memcpy(self->samples, samples, num_samples * sizeof(tsk_id_t)); + + /* Allocate the heaps used for small objects-> Assuming 8K is a good chunk size */ + ret = tsk_blkalloc_alloc(&self->segment_heap, 8192); + if (ret != 0) { + goto out; + } + ret = tsk_blkalloc_alloc(&self->interval_list_heap, 8192); + if (ret != 0) { + goto out; + } + ret = segment_overlapper_alloc(&self->segment_overlapper); + if (ret != 0) { + goto out; + } + /* Need to avoid malloc(0) so make sure we have at least 1. 
*/ + num_nodes_alloc = 1 + tables->nodes->num_rows; + /* Make the maps and set the intial state */ + self->ancestor_map_head = calloc(num_nodes_alloc, sizeof(simplify_segment_t *)); + self->ancestor_map_tail = calloc(num_nodes_alloc, sizeof(simplify_segment_t *)); + self->child_edge_map_head = calloc(num_nodes_alloc, sizeof(interval_list_t *)); + self->child_edge_map_tail = calloc(num_nodes_alloc, sizeof(interval_list_t *)); + self->node_id_map = malloc(num_nodes_alloc * sizeof(tsk_id_t)); + self->buffered_children = malloc(num_nodes_alloc * sizeof(tsk_id_t)); + self->is_sample = calloc(num_nodes_alloc, sizeof(bool)); + self->max_segment_queue_size = 64; + self->segment_queue = malloc(self->max_segment_queue_size + * sizeof(simplify_segment_t)); + if (self->ancestor_map_head == NULL || self->ancestor_map_tail == NULL + || self->child_edge_map_head == NULL || self->child_edge_map_tail == NULL + || self->node_id_map == NULL || self->is_sample == NULL + || self->segment_queue == NULL || self->buffered_children == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + ret = tsk_tbl_collection_clear(self->tables); + if (ret != 0) { + goto out; + } + memset(self->node_id_map, 0xff, self->input_tables.nodes->num_rows * sizeof(tsk_id_t)); + ret = simplifier_init_samples(self, samples); + if (ret != 0) { + goto out; + } + ret = simplifier_init_sites(self); + if (ret != 0) { + goto out; + } + if (!!(self->flags & TSK_REDUCE_TO_SITE_TOPOLOGY)) { + ret = simplifier_init_position_lookup(self); + if (ret != 0) { + goto out; + } + } +out: + return ret; +} + +static int +simplifier_free(simplifier_t *self) +{ + tsk_tbl_collection_free(&self->input_tables); + tsk_blkalloc_free(&self->segment_heap); + tsk_blkalloc_free(&self->interval_list_heap); + segment_overlapper_free(&self->segment_overlapper); + tsk_safe_free(self->samples); + tsk_safe_free(self->ancestor_map_head); + tsk_safe_free(self->ancestor_map_tail); + tsk_safe_free(self->child_edge_map_head); + 
tsk_safe_free(self->child_edge_map_tail); + tsk_safe_free(self->node_id_map); + tsk_safe_free(self->segment_queue); + tsk_safe_free(self->is_sample); + tsk_safe_free(self->mutation_id_map); + tsk_safe_free(self->mutation_node_map); + tsk_safe_free(self->node_mutation_list_mem); + tsk_safe_free(self->node_mutation_list_map_head); + tsk_safe_free(self->node_mutation_list_map_tail); + tsk_safe_free(self->buffered_children); + tsk_safe_free(self->position_lookup); + return 0; +} + +static int TSK_WARN_UNUSED +simplifier_enqueue_segment(simplifier_t *self, double left, double right, tsk_id_t node) +{ + int ret = 0; + simplify_segment_t *seg; + void *p; + + assert(left < right); + /* Make sure we always have room for one more segment in the queue so we + * can put a tail sentinel on it */ + if (self->segment_queue_size == self->max_segment_queue_size - 1) { + self->max_segment_queue_size *= 2; + p = realloc(self->segment_queue, + self->max_segment_queue_size * sizeof(*self->segment_queue)); + if (p == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + self->segment_queue = p; + } + seg = self->segment_queue + self->segment_queue_size; + seg->left = left; + seg->right = right; + seg->node = node; + self->segment_queue_size++; +out: + return ret; +} + +static int TSK_WARN_UNUSED +simplifier_merge_ancestors(simplifier_t *self, tsk_id_t input_id) +{ + int ret = 0; + simplify_segment_t **X, *x; + size_t j, num_overlapping, num_flushed_edges; + double left, right, prev_right; + tsk_id_t ancestry_node; + tsk_id_t output_id = self->node_id_map[input_id]; + bool is_sample = output_id != TSK_NULL; + + if (is_sample) { + /* Free up the existing ancestry mapping. 
*/ + x = self->ancestor_map_tail[input_id]; + assert(x->left == 0 && x->right == self->tables->sequence_length); + self->ancestor_map_head[input_id] = NULL; + self->ancestor_map_tail[input_id] = NULL; + } + + ret = segment_overlapper_init(&self->segment_overlapper, + self->segment_queue, self->segment_queue_size); + if (ret != 0) { + goto out; + } + prev_right = 0; + while ((ret = segment_overlapper_next(&self->segment_overlapper, + &left, &right, &X, &num_overlapping)) == 1) { + assert(left < right); + assert(num_overlapping > 0); + if (num_overlapping == 1) { + ancestry_node = X[0]->node; + if (is_sample) { + ret = simplifier_record_edge(self, left, right, ancestry_node); + if (ret != 0) { + goto out; + } + ancestry_node = output_id; + } + } else { + if (output_id == TSK_NULL) { + ret = simplifier_record_node(self, input_id, false); + if (ret < 0) { + goto out; + } + output_id = (tsk_id_t) ret; + } + ancestry_node = output_id; + for (j = 0; j < num_overlapping; j++) { + ret = simplifier_record_edge(self, left, right, X[j]->node); + if (ret != 0) { + goto out; + } + } + + } + if (is_sample && left != prev_right) { + /* Fill in any gaps in ancestry for the sample */ + ret = simplifier_add_ancestry(self, input_id, prev_right, left, output_id); + if (ret != 0) { + goto out; + } + } + ret = simplifier_add_ancestry(self, input_id, left, right, ancestry_node); + if (ret != 0) { + goto out; + } + prev_right = right; + } + /* Check for errors occuring in the loop condition */ + if (ret != 0) { + goto out; + } + if (is_sample && prev_right != self->tables->sequence_length) { + /* If a trailing gap exists in the sample ancestry, fill it in. 
*/ + ret = simplifier_add_ancestry(self, input_id, prev_right, + self->tables->sequence_length, output_id); + if (ret != 0) { + goto out; + } + } + if (output_id != TSK_NULL) { + ret = simplifier_flush_edges(self, output_id, &num_flushed_edges); + if (ret != 0) { + goto out; + } + if (num_flushed_edges == 0 && !is_sample) { + ret = simplifier_rewind_node(self, input_id, output_id); + } + } +out: + return ret; +} + +static int TSK_WARN_UNUSED +simplifier_process_parent_edges(simplifier_t *self, tsk_id_t parent, size_t start, + size_t end) +{ + int ret = 0; + size_t j; + simplify_segment_t *x; + const tsk_edge_tbl_t *input_edges = self->input_tables.edges; + tsk_id_t child; + double left, right; + + /* Go through the edges and queue up ancestry segments for processing. */ + self->segment_queue_size = 0; + for (j = start; j < end; j++) { + assert(parent == input_edges->parent[j]); + child = input_edges->child[j]; + left = input_edges->left[j]; + right = input_edges->right[j]; + for (x = self->ancestor_map_head[child]; x != NULL; x = x->next) { + if (x->right > left && right > x->left) { + ret = simplifier_enqueue_segment(self, + TSK_MAX(x->left, left), TSK_MIN(x->right, right), x->node); + if (ret != 0) { + goto out; + } + } + } + } + /* We can now merge the ancestral segments for the parent */ + ret = simplifier_merge_ancestors(self, parent); + if (ret != 0) { + goto out; + } +out: + return ret; +} + +static int TSK_WARN_UNUSED +simplifier_map_mutation_nodes(simplifier_t *self) +{ + int ret = 0; + simplify_segment_t *seg; + mutation_id_list_t *m_node; + size_t input_node; + tsk_id_t site; + double position; + + for (input_node = 0; input_node < self->input_tables.nodes->num_rows; input_node++) { + seg = self->ancestor_map_head[input_node]; + m_node = self->node_mutation_list_map_head[input_node]; + /* Co-iterate over the segments and mutations; mutations must be listed + * in increasing order of site position */ + while (seg != NULL && m_node != NULL) { + site = 
self->input_tables.mutations->site[m_node->mutation]; + position = self->input_tables.sites->position[site]; + if (seg->left <= position && position < seg->right) { + self->mutation_node_map[m_node->mutation] = seg->node; + m_node = m_node->next; + } else if (position >= seg->right) { + seg = seg->next; + } else { + assert(position < seg->left); + m_node = m_node->next; + } + } + } + return ret; +} + +static int TSK_WARN_UNUSED +simplifier_output_sites(simplifier_t *self) +{ + int ret = 0; + tsk_id_t input_site; + tsk_id_t input_mutation, mapped_parent ,site_start, site_end; + tsk_id_t num_input_sites = (tsk_id_t) self->input_tables.sites->num_rows; + tsk_id_t num_input_mutations = (tsk_id_t) self->input_tables.mutations->num_rows; + tsk_id_t input_parent, num_output_mutations, num_output_site_mutations; + tsk_id_t mapped_node; + bool keep_site; + bool filter_sites = !!(self->flags & TSK_FILTER_SITES); + tsk_site_t site; + tsk_mutation_t mutation; + + input_mutation = 0; + num_output_mutations = 0; + for (input_site = 0; input_site < num_input_sites; input_site++) { + ret = tsk_site_tbl_get_row(self->input_tables.sites, (size_t) input_site, &site); + if (ret != 0) { + goto out; + } + site_start = input_mutation; + num_output_site_mutations = 0; + while (input_mutation < num_input_mutations + && self->input_tables.mutations->site[input_mutation] == site.id) { + mapped_node = self->mutation_node_map[input_mutation]; + if (mapped_node != TSK_NULL) { + input_parent = self->input_tables.mutations->parent[input_mutation]; + mapped_parent = TSK_NULL; + if (input_parent != TSK_NULL) { + mapped_parent = self->mutation_id_map[input_parent]; + } + self->mutation_id_map[input_mutation] = num_output_mutations; + num_output_mutations++; + num_output_site_mutations++; + } + input_mutation++; + } + site_end = input_mutation; + + keep_site = true; + if (filter_sites && num_output_site_mutations == 0) { + keep_site = false; + } + if (keep_site) { + for (input_mutation = site_start; 
input_mutation < site_end; input_mutation++) { + if (self->mutation_id_map[input_mutation] != TSK_NULL) { + assert(self->tables->mutations->num_rows + == (size_t) self->mutation_id_map[input_mutation]); + mapped_node = self->mutation_node_map[input_mutation]; + assert(mapped_node != TSK_NULL); + mapped_parent = self->input_tables.mutations->parent[input_mutation]; + if (mapped_parent != TSK_NULL) { + mapped_parent = self->mutation_id_map[mapped_parent]; + } + ret = tsk_mutation_tbl_get_row(self->input_tables.mutations, + (size_t) input_mutation, &mutation); + if (ret != 0) { + goto out; + } + ret = tsk_mutation_tbl_add_row(self->tables->mutations, + (tsk_id_t) self->tables->sites->num_rows, + mapped_node, mapped_parent, + mutation.derived_state, mutation.derived_state_length, + mutation.metadata, mutation.metadata_length); + if (ret < 0) { + goto out; + } + } + } + ret = tsk_site_tbl_add_row(self->tables->sites, site.position, + site.ancestral_state, site.ancestral_state_length, + site.metadata, site.metadata_length); + if (ret < 0) { + goto out; + } + } + assert(num_output_mutations == (tsk_id_t) self->tables->mutations->num_rows); + input_mutation = site_end; + } + assert(input_mutation == num_input_mutations); + ret = 0; +out: + return ret; +} + +static int TSK_WARN_UNUSED +simplifier_finalise_references(simplifier_t *self) +{ + int ret = 0; + tsk_tbl_size_t j; + bool keep; + tsk_tbl_size_t num_nodes = self->tables->nodes->num_rows; + + tsk_population_t pop; + tsk_id_t pop_id; + tsk_tbl_size_t num_populations = self->input_tables.populations->num_rows; + tsk_id_t *node_population = self->tables->nodes->population; + bool *population_referenced = calloc(num_populations, sizeof(*population_referenced)); + tsk_id_t *population_id_map = malloc( + num_populations * sizeof(*population_id_map)); + bool filter_populations = !!(self->flags & TSK_FILTER_POPULATIONS); + + tsk_individual_t ind; + tsk_id_t ind_id; + tsk_tbl_size_t num_individuals = 
self->input_tables.individuals->num_rows; + tsk_id_t *node_individual = self->tables->nodes->individual; + bool *individual_referenced = calloc(num_individuals, sizeof(*individual_referenced)); + tsk_id_t *individual_id_map = malloc( + num_individuals * sizeof(*individual_id_map)); + bool filter_individuals = !!(self->flags & TSK_FILTER_INDIVIDUALS); + + if (population_referenced == NULL || population_id_map == NULL + || individual_referenced == NULL || individual_id_map == NULL) { + goto out; + } + + /* TODO Migrations fit reasonably neatly into the pattern that we have here. We can + * consider references to populations from migration objects in the same way + * as from nodes, so that we only remove a population if its referenced by + * neither. Mapping the population IDs in migrations is then easy. In principle + * nodes are similar, but the semantics are slightly different because we've + * already allocated all the nodes by their references from edges. We then + * need to decide whether we remove migrations that reference unmapped nodes + * or whether to add these nodes back in (probably the former is the correct + * approach).*/ + if (self->input_tables.migrations->num_rows != 0) { + ret = TSK_ERR_SIMPLIFY_MIGRATIONS_NOT_SUPPORTED; + goto out; + } + + for (j = 0; j < num_nodes; j++) { + pop_id = node_population[j]; + if (pop_id != TSK_NULL) { + population_referenced[pop_id] = true; + } + ind_id = node_individual[j]; + if (ind_id != TSK_NULL) { + individual_referenced[ind_id] = true; + } + } + for (j = 0; j < num_populations; j++) { + ret = tsk_population_tbl_get_row(self->input_tables.populations, j, &pop); + if (ret != 0) { + goto out; + } + keep = true; + if (filter_populations && !population_referenced[j]) { + keep = false; + } + population_id_map[j] = TSK_NULL; + if (keep) { + ret = tsk_population_tbl_add_row(self->tables->populations, + pop.metadata, pop.metadata_length); + if (ret < 0) { + goto out; + } + population_id_map[j] = (tsk_id_t) ret; + } + } + 
+ for (j = 0; j < num_individuals; j++) { + ret = tsk_individual_tbl_get_row(self->input_tables.individuals, j, &ind); + if (ret != 0) { + goto out; + } + keep = true; + if (filter_individuals && !individual_referenced[j]) { + keep = false; + } + individual_id_map[j] = TSK_NULL; + if (keep) { + ret = tsk_individual_tbl_add_row(self->tables->individuals, + ind.flags, ind.location, ind.location_length, + ind.metadata, ind.metadata_length); + if (ret < 0) { + goto out; + } + individual_id_map[j] = (tsk_id_t) ret; + } + } + + /* Remap node IDs referencing the above */ + for (j = 0; j < num_nodes; j++) { + pop_id = node_population[j]; + if (pop_id != TSK_NULL) { + node_population[j] = population_id_map[pop_id]; + } + ind_id = node_individual[j]; + if (ind_id != TSK_NULL) { + node_individual[j] = individual_id_map[ind_id]; + } + } + + ret = tsk_provenance_tbl_copy(self->input_tables.provenances, self->tables->provenances); + if (ret != 0) { + goto out; + } +out: + tsk_safe_free(population_referenced); + tsk_safe_free(individual_referenced); + tsk_safe_free(population_id_map); + tsk_safe_free(individual_id_map); + return ret; +} + +static int TSK_WARN_UNUSED +simplifier_run(simplifier_t *self, tsk_id_t *node_map) +{ + int ret = 0; + size_t j, start; + tsk_id_t parent, current_parent; + const tsk_edge_tbl_t *input_edges = self->input_tables.edges; + size_t num_edges = input_edges->num_rows; + + if (num_edges > 0) { + start = 0; + current_parent = input_edges->parent[0]; + for (j = 0; j < num_edges; j++) { + parent = input_edges->parent[j]; + if (parent != current_parent) { + ret = simplifier_process_parent_edges(self, current_parent, start, j); + if (ret != 0) { + goto out; + } + current_parent = parent; + start = j; + } + } + ret = simplifier_process_parent_edges(self, current_parent, start, num_edges); + if (ret != 0) { + goto out; + } + } + ret = simplifier_map_mutation_nodes(self); + if (ret != 0) { + goto out; + } + ret = simplifier_output_sites(self); + if (ret != 0) 
{ + goto out; + } + ret = simplifier_finalise_references(self); + if (ret != 0) { + goto out; + } + if (node_map != NULL) { + /* Finally, output the new IDs for the nodes, if required. */ + memcpy(node_map, self->node_id_map, + self->input_tables.nodes->num_rows * sizeof(tsk_id_t)); + } +out: + return ret; +} + +/************************* + * table_collection + *************************/ + +typedef struct { + tsk_id_t index; + /* These are the sort keys in order */ + double first; + double second; + tsk_id_t third; + tsk_id_t fourth; +} index_sort_t; + +static int +cmp_index_sort(const void *a, const void *b) { + const index_sort_t *ca = (const index_sort_t *) a; + const index_sort_t *cb = (const index_sort_t *) b; + int ret = (ca->first > cb->first) - (ca->first < cb->first); + if (ret == 0) { + ret = (ca->second > cb->second) - (ca->second < cb->second); + if (ret == 0) { + ret = (ca->third > cb->third) - (ca->third < cb->third); + if (ret == 0) { + ret = (ca->fourth > cb->fourth) - (ca->fourth < cb->fourth); + } + } + } + return ret; +} + +static int +tsk_tbl_collection_check_offsets(tsk_tbl_collection_t *self) +{ + int ret = 0; + + ret = check_offsets(self->nodes->num_rows, self->nodes->metadata_offset, + self->nodes->metadata_length, true); + if (ret != 0) { + goto out; + } + ret = check_offsets(self->sites->num_rows, self->sites->ancestral_state_offset, + self->sites->ancestral_state_length, true); + if (ret != 0) { + goto out; + } + ret = check_offsets(self->sites->num_rows, self->sites->metadata_offset, + self->sites->metadata_length, true); + if (ret != 0) { + goto out; + } + ret = check_offsets(self->mutations->num_rows, self->mutations->derived_state_offset, + self->mutations->derived_state_length, true); + if (ret != 0) { + goto out; + } + ret = check_offsets(self->mutations->num_rows, self->mutations->metadata_offset, + self->mutations->metadata_length, true); + if (ret != 0) { + goto out; + } + ret = check_offsets(self->individuals->num_rows, 
self->individuals->metadata_offset, + self->individuals->metadata_length, true); + if (ret != 0) { + goto out; + } + ret = check_offsets(self->provenances->num_rows, self->provenances->timestamp_offset, + self->provenances->timestamp_length, true); + if (ret != 0) { + goto out; + } + ret = check_offsets(self->provenances->num_rows, self->provenances->record_offset, + self->provenances->record_length, true); + if (ret != 0) { + goto out; + } + ret = 0; +out: + return ret; +} + +static int +tsk_tbl_collection_check_edge_ordering(tsk_tbl_collection_t *self) +{ + int ret = 0; + tsk_tbl_size_t j; + tsk_id_t parent, last_parent, child, last_child; + double left, last_left; + const double *time = self->nodes->time; + bool *parent_seen = calloc(self->nodes->num_rows, sizeof(bool)); + + if (parent_seen == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + /* Just keeping compiler happy; these values don't matter. */ + last_left = 0; + last_parent = 0; + last_child = 0; + for (j = 0; j < self->edges->num_rows; j++) { + left = self->edges->left[j]; + parent = self->edges->parent[j]; + child = self->edges->child[j]; + if (parent_seen[parent]) { + ret = TSK_ERR_EDGES_NONCONTIGUOUS_PARENTS; + goto out; + } + if (j > 0) { + /* Input data must sorted by (time[parent], parent, child, left). */ + if (time[parent] < time[last_parent]) { + ret = TSK_ERR_EDGES_NOT_SORTED_PARENT_TIME; + goto out; + } + if (time[parent] == time[last_parent]) { + if (parent == last_parent) { + if (child < last_child) { + ret = TSK_ERR_EDGES_NOT_SORTED_CHILD; + goto out; + } + if (child == last_child) { + if (left == last_left) { + ret = TSK_ERR_DUPLICATE_EDGES; + goto out; + } else if (left < last_left) { + ret = TSK_ERR_EDGES_NOT_SORTED_LEFT; + goto out; + } + } + } else { + parent_seen[last_parent] = true; + } + } + } + last_parent = parent; + last_child = child; + last_left = left; + } +out: + tsk_safe_free(parent_seen); + return ret; +} + + +/* Checks the integrity of the table collection. 
What gets checked depends + * on the flags values: + * 0 Check the integrity of ID & spatial references. + * TSK_CHECK_OFFSETS Check offsets for ragged columns. + * TSK_CHECK_EDGE_ORDERING Check edge ordering contraints for a tree sequence. + * TSK_CHECK_SITE_ORDERING Check that sites are in nondecreasing position order. + * TSK_CHECK_SITE_DUPLICATES Check for any duplicate site positions. + * TSK_CHECK_MUTATION_ORDERING Check mutation ordering contraints for a tree sequence. + * TSK_CHECK_INDEXES Check indexes exist & reference integrity. + * TSK_CHECK_ALL All above checks. + */ +int TSK_WARN_UNUSED +tsk_tbl_collection_check_integrity(tsk_tbl_collection_t *self, int flags) +{ + int ret = TSK_ERR_GENERIC; + tsk_tbl_size_t j; + double left, right, position; + double L = self->sequence_length; + double *time = self->nodes->time; + tsk_id_t parent, child; + tsk_id_t parent_mut; + tsk_id_t population; + tsk_id_t individual; + tsk_id_t num_nodes = (tsk_id_t) self->nodes->num_rows; + tsk_id_t num_edges = (tsk_id_t) self->edges->num_rows; + tsk_id_t num_sites = (tsk_id_t) self->sites->num_rows; + tsk_id_t num_mutations = (tsk_id_t) self->mutations->num_rows; + tsk_id_t num_populations = (tsk_id_t) self->populations->num_rows; + tsk_id_t num_individuals = (tsk_id_t) self->individuals->num_rows; + bool check_site_ordering = !!(flags & TSK_CHECK_SITE_ORDERING); + bool check_site_duplicates = !!(flags & TSK_CHECK_SITE_DUPLICATES); + bool check_mutation_ordering = !!(flags & TSK_CHECK_MUTATION_ORDERING); + + if (self->sequence_length <= 0) { + ret = TSK_ERR_BAD_SEQUENCE_LENGTH; + goto out; + } + + /* Nodes */ + for (j = 0; j < self->nodes->num_rows; j++) { + population = self->nodes->population[j]; + if (population < TSK_NULL || population >= num_populations) { + ret = TSK_ERR_POPULATION_OUT_OF_BOUNDS; + goto out; + } + individual = self->nodes->individual[j]; + if (individual < TSK_NULL || individual >= num_individuals) { + ret = TSK_ERR_INDIVIDUAL_OUT_OF_BOUNDS; + goto out; 
+ } + } + + /* Edges */ + for (j = 0; j < self->edges->num_rows; j++) { + parent = self->edges->parent[j]; + child = self->edges->child[j]; + left = self->edges->left[j]; + right = self->edges->right[j]; + /* Node ID integrity */ + if (parent == TSK_NULL) { + ret = TSK_ERR_NULL_PARENT; + goto out; + } + if (parent < 0 || parent >= num_nodes) { + ret = TSK_ERR_NODE_OUT_OF_BOUNDS; + goto out; + } + if (child == TSK_NULL) { + ret = TSK_ERR_NULL_CHILD; + goto out; + } + if (child < 0 || child >= num_nodes) { + ret = TSK_ERR_NODE_OUT_OF_BOUNDS; + goto out; + } + /* Spatial requirements for edges */ + if (left < 0) { + ret = TSK_ERR_LEFT_LESS_ZERO; + goto out; + } + if (right > L) { + ret = TSK_ERR_RIGHT_GREATER_SEQ_LENGTH; + goto out; + } + if (left >= right) { + ret = TSK_ERR_BAD_EDGE_INTERVAL; + goto out; + } + /* time[child] must be < time[parent] */ + if (time[child] >= time[parent]) { + ret = TSK_ERR_BAD_NODE_TIME_ORDERING; + goto out; + } + } + for (j = 0; j < self->sites->num_rows; j++) { + position = self->sites->position[j]; + /* Spatial requirements */ + if (position < 0 || position >= L) { + ret = TSK_ERR_BAD_SITE_POSITION; + goto out; + } + if (j > 0) { + if (check_site_duplicates && self->sites->position[j - 1] == position) { + ret = TSK_ERR_DUPLICATE_SITE_POSITION; + goto out; + } + if (check_site_ordering && self->sites->position[j - 1] > position) { + ret = TSK_ERR_UNSORTED_SITES; + goto out; + } + } + } + + /* Mutations */ + for (j = 0; j < self->mutations->num_rows; j++) { + if (self->mutations->site[j] < 0 || self->mutations->site[j] >= num_sites) { + ret = TSK_ERR_SITE_OUT_OF_BOUNDS; + goto out; + } + if (self->mutations->node[j] < 0 || self->mutations->node[j] >= num_nodes) { + ret = TSK_ERR_NODE_OUT_OF_BOUNDS; + goto out; + } + parent_mut = self->mutations->parent[j]; + if (parent_mut < TSK_NULL || parent_mut >= num_mutations) { + ret = TSK_ERR_MUTATION_OUT_OF_BOUNDS; + goto out; + } + if (parent_mut == (tsk_id_t) j) { + ret = 
TSK_ERR_MUTATION_PARENT_EQUAL; + goto out; + } + if (check_mutation_ordering) { + if (parent_mut != TSK_NULL) { + /* Parents must be listed before their children */ + if (parent_mut > (tsk_id_t) j) { + ret = TSK_ERR_MUTATION_PARENT_AFTER_CHILD; + goto out; + } + if (self->mutations->site[parent_mut] != self->mutations->site[j]) { + ret = TSK_ERR_MUTATION_PARENT_DIFFERENT_SITE; + goto out; + } + } + if (j > 0) { + if (self->mutations->site[j - 1] > self->mutations->site[j]) { + ret = TSK_ERR_UNSORTED_MUTATIONS; + goto out; + } + } + } + } + + /* Migrations */ + for (j = 0; j < self->migrations->num_rows; j++) { + if (self->migrations->node[j] < 0 || self->migrations->node[j] >= num_nodes) { + ret = TSK_ERR_NODE_OUT_OF_BOUNDS; + goto out; + } + if (self->migrations->source[j] < 0 + || self->migrations->source[j] >= num_populations) { + ret = TSK_ERR_POPULATION_OUT_OF_BOUNDS; + goto out; + } + if (self->migrations->dest[j] < 0 + || self->migrations->dest[j] >= num_populations) { + ret = TSK_ERR_POPULATION_OUT_OF_BOUNDS; + goto out; + } + left = self->migrations->left[j]; + right = self->migrations->right[j]; + /* Spatial requirements */ + /* TODO it's a bit misleading to use the edge-specific errors here. 
*/ + if (left < 0) { + ret = TSK_ERR_LEFT_LESS_ZERO; + goto out; + } + if (right > L) { + ret = TSK_ERR_RIGHT_GREATER_SEQ_LENGTH; + goto out; + } + if (left >= right) { + ret = TSK_ERR_BAD_EDGE_INTERVAL; + goto out; + } + } + + if (!!(flags & TSK_CHECK_INDEXES)) { + if (!tsk_tbl_collection_is_indexed(self)) { + ret = TSK_ERR_TABLES_NOT_INDEXED; + goto out; + } + for (j = 0; j < self->edges->num_rows; j++) { + if (self->indexes.edge_insertion_order[j] < 0 || + self->indexes.edge_insertion_order[j] >= num_edges) { + ret = TSK_ERR_EDGE_OUT_OF_BOUNDS; + goto out; + } + if (self->indexes.edge_removal_order[j] < 0 || + self->indexes.edge_removal_order[j] >= num_edges) { + ret = TSK_ERR_EDGE_OUT_OF_BOUNDS; + goto out; + } + } + } + + ret = 0; + if (!!(flags & TSK_CHECK_OFFSETS)) { + ret = tsk_tbl_collection_check_offsets(self); + if (ret != 0) { + goto out; + } + } + if (!!(flags & TSK_CHECK_EDGE_ORDERING)) { + ret = tsk_tbl_collection_check_edge_ordering(self); + if (ret != 0) { + goto out; + } + } +out: + return ret; +} + +int +tsk_tbl_collection_print_state(tsk_tbl_collection_t *self, FILE *out) +{ + fprintf(out, "Table collection state\n"); + fprintf(out, "sequence_length = %f\n", self->sequence_length); + tsk_individual_tbl_print_state(self->individuals, out); + tsk_node_tbl_print_state(self->nodes, out); + tsk_edge_tbl_print_state(self->edges, out); + tsk_migration_tbl_print_state(self->migrations, out); + tsk_site_tbl_print_state(self->sites, out); + tsk_mutation_tbl_print_state(self->mutations, out); + tsk_population_tbl_print_state(self->populations, out); + tsk_provenance_tbl_print_state(self->provenances, out); + return 0; +} + +int +tsk_tbl_collection_alloc(tsk_tbl_collection_t *self, int flags) +{ + int ret = 0; + memset(self, 0, sizeof(*self)); + self->individuals = calloc(1, sizeof(*self->individuals)); + self->nodes = calloc(1, sizeof(*self->nodes)); + self->edges = calloc(1, sizeof(*self->edges)); + self->migrations = calloc(1, sizeof(*self->migrations)); 
+ self->sites = calloc(1, sizeof(*self->sites)); + self->mutations = calloc(1, sizeof(*self->mutations)); + self->mutations = calloc(1, sizeof(*self->mutations)); + self->populations = calloc(1, sizeof(*self->populations)); + self->provenances = calloc(1, sizeof(*self->provenances)); + if (self->individuals == NULL || self->nodes == NULL + || self->edges == NULL || self->migrations == NULL + || self->sites == NULL || self->mutations == NULL + || self->populations == NULL || self->provenances == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + if (! (flags & TSK_NO_ALLOC_TABLES)) { + /* Allocate all the tables with their default increments */ + ret = tsk_node_tbl_alloc(self->nodes, 0); + if (ret != 0) { + goto out; + } + ret = tsk_edge_tbl_alloc(self->edges, 0); + if (ret != 0) { + goto out; + } + ret = tsk_migration_tbl_alloc(self->migrations, 0); + if (ret != 0) { + goto out; + } + ret = tsk_site_tbl_alloc(self->sites, 0); + if (ret != 0) { + goto out; + } + ret = tsk_mutation_tbl_alloc(self->mutations, 0); + if (ret != 0) { + goto out; + } + ret = tsk_individual_tbl_alloc(self->individuals, 0); + if (ret != 0) { + goto out; + } + ret = tsk_population_tbl_alloc(self->populations, 0); + if (ret != 0) { + goto out; + } + ret = tsk_provenance_tbl_alloc(self->provenances, 0); + if (ret != 0) { + goto out; + } + } +out: + return ret; +} + +static int +tsk_tbl_collection_free_tables(tsk_tbl_collection_t *self) +{ + if (self->individuals != NULL) { + tsk_individual_tbl_free(self->individuals); + free(self->individuals); + self->individuals = NULL; + } + if (self->nodes != NULL) { + tsk_node_tbl_free(self->nodes); + free(self->nodes); + self->nodes = NULL; + } + if (self->edges != NULL) { + tsk_edge_tbl_free(self->edges); + free(self->edges); + self->edges = NULL; + } + if (self->migrations != NULL) { + tsk_migration_tbl_free(self->migrations); + free(self->migrations); + self->migrations = NULL; + } + if (self->sites != NULL) { + tsk_site_tbl_free(self->sites); + 
free(self->sites); + self->sites = NULL; + } + if (self->mutations != NULL) { + tsk_mutation_tbl_free(self->mutations); + free(self->mutations); + self->mutations = NULL; + } + if (self->populations != NULL) { + tsk_population_tbl_free(self->populations); + free(self->populations); + self->populations = NULL; + } + if (self->provenances != NULL) { + tsk_provenance_tbl_free(self->provenances); + free(self->provenances); + self->provenances = NULL; + } + return 0; +} + +int +tsk_tbl_collection_free(tsk_tbl_collection_t *self) +{ + int ret = 0; + tsk_tbl_collection_free_tables(self); + if (self->indexes.malloced_locally) { + tsk_safe_free(self->indexes.edge_insertion_order); + tsk_safe_free(self->indexes.edge_removal_order); + } + if (self->store != NULL) { + kastore_close(self->store); + free(self->store); + } + tsk_safe_free(self->file_uuid); + return ret; +} + +/* Returns true if all the tables and collection metadata are equal. Note + * this does *not* consider the indexes, since these are derived from the + * tables. We do not consider the file_uuids either, since this is a property of + * the file that set of tables is stored in. 
*/ +bool +tsk_tbl_collection_equals(tsk_tbl_collection_t *self, tsk_tbl_collection_t *other) +{ + bool ret = self->sequence_length == other->sequence_length + && tsk_individual_tbl_equals(self->individuals, other->individuals) + && tsk_node_tbl_equals(self->nodes, other->nodes) + && tsk_edge_tbl_equals(self->edges, other->edges) + && tsk_migration_tbl_equals(self->migrations, other->migrations) + && tsk_site_tbl_equals(self->sites, other->sites) + && tsk_mutation_tbl_equals(self->mutations, other->mutations) + && tsk_population_tbl_equals(self->populations, other->populations) + && tsk_provenance_tbl_equals(self->provenances, other->provenances); + return ret; +} + +int TSK_WARN_UNUSED +tsk_tbl_collection_copy(tsk_tbl_collection_t *self, tsk_tbl_collection_t *dest) +{ + int ret = 0; + size_t index_size; + + if (dest == NULL) { + ret = TSK_ERR_BAD_PARAM_VALUE; + goto out; + } + ret = tsk_node_tbl_copy(self->nodes, dest->nodes); + if (ret != 0) { + goto out; + } + ret = tsk_edge_tbl_copy(self->edges, dest->edges); + if (ret != 0) { + goto out; + } + ret = tsk_migration_tbl_copy(self->migrations, dest->migrations); + if (ret != 0) { + goto out; + } + ret = tsk_site_tbl_copy(self->sites, dest->sites); + if (ret != 0) { + goto out; + } + ret = tsk_mutation_tbl_copy(self->mutations, dest->mutations); + if (ret != 0) { + goto out; + } + ret = tsk_individual_tbl_copy(self->individuals, dest->individuals); + if (ret != 0) { + goto out; + } + ret = tsk_population_tbl_copy(self->populations, dest->populations); + if (ret != 0) { + goto out; + } + ret = tsk_provenance_tbl_copy(self->provenances, dest->provenances); + if (ret != 0) { + goto out; + } + dest->sequence_length = self->sequence_length; + if (tsk_tbl_collection_is_indexed(self)) { + tsk_tbl_collection_drop_indexes(dest); + index_size = self->edges->num_rows * sizeof(tsk_id_t); + dest->indexes.edge_insertion_order = malloc(index_size); + dest->indexes.edge_removal_order = malloc(index_size); + 
dest->indexes.malloced_locally = true; + if (dest->indexes.edge_insertion_order == NULL + || dest->indexes.edge_removal_order == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + memcpy(dest->indexes.edge_insertion_order, self->indexes.edge_insertion_order, + index_size); + memcpy(dest->indexes.edge_removal_order, self->indexes.edge_removal_order, + index_size); + } +out: + return ret; +} + +bool +tsk_tbl_collection_is_indexed(tsk_tbl_collection_t *self) +{ + return self->indexes.edge_insertion_order != NULL + && self->indexes.edge_removal_order != NULL; +} + +int +tsk_tbl_collection_drop_indexes(tsk_tbl_collection_t *self) +{ + if (self->indexes.malloced_locally) { + tsk_safe_free(self->indexes.edge_insertion_order); + tsk_safe_free(self->indexes.edge_removal_order); + } + self->indexes.edge_insertion_order = NULL; + self->indexes.edge_removal_order = NULL; + return 0; +} + +int TSK_WARN_UNUSED +tsk_tbl_collection_build_indexes(tsk_tbl_collection_t *self, int TSK_UNUSED(flags)) +{ + int ret = TSK_ERR_GENERIC; + size_t j; + double *time = self->nodes->time; + index_sort_t *sort_buff = NULL; + tsk_id_t parent; + + tsk_tbl_collection_drop_indexes(self); + self->indexes.malloced_locally = true; + self->indexes.edge_insertion_order = malloc(self->edges->num_rows * sizeof(tsk_id_t)); + self->indexes.edge_removal_order = malloc(self->edges->num_rows * sizeof(tsk_id_t)); + if (self->indexes.edge_insertion_order == NULL + || self->indexes.edge_removal_order == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + + /* Alloc the sort buffer */ + sort_buff = malloc(self->edges->num_rows * sizeof(index_sort_t)); + if (sort_buff == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + /* TODO we should probably drop these checks and call check_integrity instead. + * Do this when we're providing the Python API for build_indexes, so that + * we can test it properly. 
*/ + + /* sort by left and increasing time to give us the order in which + * records should be inserted */ + for (j = 0; j < self->edges->num_rows; j++) { + sort_buff[j].index = (tsk_id_t ) j; + sort_buff[j].first = self->edges->left[j]; + parent = self->edges->parent[j]; + if (parent == TSK_NULL) { + ret = TSK_ERR_NULL_PARENT; + goto out; + } + if (parent < 0 || parent >= (tsk_id_t) self->nodes->num_rows) { + ret = TSK_ERR_NODE_OUT_OF_BOUNDS; + goto out; + } + sort_buff[j].second = time[parent]; + sort_buff[j].third = parent; + sort_buff[j].fourth = self->edges->child[j]; + } + qsort(sort_buff, self->edges->num_rows, sizeof(index_sort_t), cmp_index_sort); + for (j = 0; j < self->edges->num_rows; j++) { + self->indexes.edge_insertion_order[j] = sort_buff[j].index; + } + /* sort by right and decreasing parent time to give us the order in which + * records should be removed. */ + for (j = 0; j < self->edges->num_rows; j++) { + sort_buff[j].index = (tsk_id_t ) j; + sort_buff[j].first = self->edges->right[j]; + parent = self->edges->parent[j]; + if (parent == TSK_NULL) { + ret = TSK_ERR_NULL_PARENT; + goto out; + } + if (parent < 0 || parent >= (tsk_id_t) self->nodes->num_rows) { + ret = TSK_ERR_NODE_OUT_OF_BOUNDS; + goto out; + } + sort_buff[j].second = -time[parent]; + sort_buff[j].third = -parent; + sort_buff[j].fourth = -self->edges->child[j]; + } + qsort(sort_buff, self->edges->num_rows, sizeof(index_sort_t), cmp_index_sort); + for (j = 0; j < self->edges->num_rows; j++) { + self->indexes.edge_removal_order[j] = sort_buff[j].index; + } + ret = 0; +out: + if (sort_buff != NULL) { + free(sort_buff); + } + return ret; +} + +static int TSK_WARN_UNUSED +tsk_tbl_collection_read_format_data(tsk_tbl_collection_t *self) +{ + int ret = 0; + size_t len; + uint32_t *version; + int8_t *format_name, *uuid; + double *L; + + ret = kastore_gets_int8(self->store, "format/name", &format_name, &len); + if (ret != 0) { + ret = tsk_set_kas_error(ret); + goto out; + } + if (len != 
TSK_FILE_FORMAT_NAME_LENGTH) { + ret = TSK_ERR_FILE_FORMAT; + goto out; + } + if (memcmp(TSK_FILE_FORMAT_NAME, format_name, TSK_FILE_FORMAT_NAME_LENGTH) != 0) { + ret = TSK_ERR_FILE_FORMAT; + goto out; + } + + ret = kastore_gets_uint32(self->store, "format/version", &version, &len); + if (ret != 0) { + ret = tsk_set_kas_error(ret); + goto out; + } + if (len != 2) { + ret = TSK_ERR_FILE_FORMAT; + goto out; + } + if (version[0] < TSK_FILE_FORMAT_VERSION_MAJOR) { + ret = TSK_ERR_FILE_VERSION_TOO_OLD; + goto out; + } + if (version[0] > TSK_FILE_FORMAT_VERSION_MAJOR) { + ret = TSK_ERR_FILE_VERSION_TOO_NEW; + goto out; + } + + ret = kastore_gets_float64(self->store, "sequence_length", &L, &len); + if (ret != 0) { + ret = tsk_set_kas_error(ret); + goto out; + } + if (len != 1) { + ret = TSK_ERR_FILE_FORMAT; + goto out; + } + if (L[0] <= 0.0) { + ret = TSK_ERR_BAD_SEQUENCE_LENGTH; + goto out; + } + self->sequence_length = L[0]; + + ret = kastore_gets_int8(self->store, "uuid", &uuid, &len); + if (ret != 0) { + ret = tsk_set_kas_error(ret); + goto out; + } + if (len != TSK_UUID_SIZE) { + ret = TSK_ERR_FILE_FORMAT; + goto out; + } + + /* Allow space for \0 so we can print it as a string */ + self->file_uuid = malloc(TSK_UUID_SIZE + 1); + if (self->file_uuid == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + memcpy(self->file_uuid, uuid, TSK_UUID_SIZE); + self->file_uuid[TSK_UUID_SIZE] = '\0'; +out: + return ret; +} + +static int TSK_WARN_UNUSED +tsk_tbl_collection_dump_indexes(tsk_tbl_collection_t *self, kastore_t *store) +{ + int ret = 0; + write_table_col_t write_cols[] = { + {"indexes/edge_insertion_order", NULL, self->edges->num_rows, KAS_INT32}, + {"indexes/edge_removal_order", NULL, self->edges->num_rows, KAS_INT32}, + }; + + if (! 
tsk_tbl_collection_is_indexed(self)) { + ret = tsk_tbl_collection_build_indexes(self, 0); + if (ret != 0) { + goto out; + } + } + write_cols[0].array = self->indexes.edge_insertion_order; + write_cols[1].array = self->indexes.edge_removal_order; + ret = write_table_cols(store, write_cols, sizeof(write_cols) / sizeof(*write_cols)); +out: + return ret; +} + +static int TSK_WARN_UNUSED +tsk_tbl_collection_load_indexes(tsk_tbl_collection_t *self) +{ + read_table_col_t read_cols[] = { + {"indexes/edge_insertion_order", (void **) &self->indexes.edge_insertion_order, + &self->edges->num_rows, 0, KAS_INT32}, + {"indexes/edge_removal_order", (void **) &self->indexes.edge_removal_order, + &self->edges->num_rows, 0, KAS_INT32}, + }; + self->indexes.malloced_locally = false; + return read_table_cols(self->store, read_cols, sizeof(read_cols) / sizeof(*read_cols)); +} + +int TSK_WARN_UNUSED +tsk_tbl_collection_load(tsk_tbl_collection_t *self, const char *filename, int TSK_UNUSED(flags)) +{ + int ret = 0; + + ret = tsk_tbl_collection_alloc(self, TSK_NO_ALLOC_TABLES); + if (ret != 0) { + goto out; + } + self->store = calloc(1, sizeof(*self->store)); + if (self->store == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + ret = kastore_open(self->store, filename, "r", KAS_READ_ALL); + if (ret != 0) { + ret = tsk_set_kas_error(ret); + goto out; + } + ret = tsk_tbl_collection_read_format_data(self); + if (ret != 0) { + goto out; + } + ret = tsk_node_tbl_load(self->nodes, self->store); + if (ret != 0) { + goto out; + } + ret = tsk_edge_tbl_load(self->edges, self->store); + if (ret != 0) { + goto out; + } + ret = tsk_site_tbl_load(self->sites, self->store); + if (ret != 0) { + goto out; + } + ret = tsk_mutation_tbl_load(self->mutations, self->store); + if (ret != 0) { + goto out; + } + ret = tsk_migration_tbl_load(self->migrations, self->store); + if (ret != 0) { + goto out; + } + ret = tsk_individual_tbl_load(self->individuals, self->store); + if (ret != 0) { + goto out; + } + ret = 
tsk_population_tbl_load(self->populations, self->store); + if (ret != 0) { + goto out; + } + ret = tsk_provenance_tbl_load(self->provenances, self->store); + if (ret != 0) { + goto out; + } + ret = tsk_tbl_collection_load_indexes(self); + if (ret != 0) { + goto out; + } + ret = tsk_tbl_collection_check_offsets(self); +out: + return ret; +} + +static int TSK_WARN_UNUSED +tsk_tbl_collection_write_format_data(tsk_tbl_collection_t *self, kastore_t *store) +{ + int ret = 0; + char format_name[TSK_FILE_FORMAT_NAME_LENGTH]; + char uuid[TSK_UUID_SIZE + 1]; // Must include space for trailing null. + uint32_t version[2] = { + TSK_FILE_FORMAT_VERSION_MAJOR, TSK_FILE_FORMAT_VERSION_MINOR}; + write_table_col_t write_cols[] = { + {"format/name", (void *) format_name, sizeof(format_name), KAS_INT8}, + {"format/version", (void *) version, 2, KAS_UINT32}, + {"sequence_length", (void *) &self->sequence_length, 1, KAS_FLOAT64}, + {"uuid", (void *) uuid, TSK_UUID_SIZE, KAS_INT8}, + }; + + ret = tsk_generate_uuid(uuid, 0); + if (ret != 0) { + goto out; + } + /* This stupid dance is to workaround the fact that compilers won't allow + * casts to discard the 'const' qualifier. 
*/ + memcpy(format_name, TSK_FILE_FORMAT_NAME, sizeof(format_name)); + ret = write_table_cols(store, write_cols, sizeof(write_cols) / sizeof(*write_cols)); +out: + return ret; +} + +int TSK_WARN_UNUSED +tsk_tbl_collection_dump(tsk_tbl_collection_t *self, const char *filename, int TSK_UNUSED(flags)) +{ + int ret = 0; + kastore_t store; + + ret = kastore_open(&store, filename, "w", 0); + if (ret != 0) { + ret = tsk_set_kas_error(ret); + goto out; + } + ret = tsk_tbl_collection_write_format_data(self, &store); + if (ret != 0) { + goto out; + } + ret = tsk_node_tbl_dump(self->nodes, &store); + if (ret != 0) { + goto out; + } + ret = tsk_edge_tbl_dump(self->edges, &store); + if (ret != 0) { + goto out; + } + ret = tsk_site_tbl_dump(self->sites, &store); + if (ret != 0) { + goto out; + } + ret = tsk_migration_tbl_dump(self->migrations, &store); + if (ret != 0) { + goto out; + } + ret = tsk_mutation_tbl_dump(self->mutations, &store); + if (ret != 0) { + goto out; + } + ret = tsk_individual_tbl_dump(self->individuals, &store); + if (ret != 0) { + goto out; + } + ret = tsk_population_tbl_dump(self->populations, &store); + if (ret != 0) { + goto out; + } + ret = tsk_provenance_tbl_dump(self->provenances, &store); + if (ret != 0) { + goto out; + } + ret = tsk_tbl_collection_dump_indexes(self, &store); + if (ret != 0) { + goto out; + } + ret = kastore_close(&store); +out: + if (ret != 0) { + kastore_close(&store); + } + return ret; +} + +int TSK_WARN_UNUSED +tsk_tbl_collection_simplify(tsk_tbl_collection_t *self, + tsk_id_t *samples, size_t num_samples, int flags, tsk_id_t *node_map) +{ + int ret = 0; + simplifier_t simplifier; + + ret = simplifier_alloc(&simplifier, samples, num_samples, self, flags); + if (ret != 0) { + goto out; + } + ret = simplifier_run(&simplifier, node_map); + if (ret != 0) { + goto out; + } + if (!! 
(flags & TSK_DEBUG)) { + simplifier_print_state(&simplifier, stdout); + } + /* The indexes are invalidated now so drop them */ + ret = tsk_tbl_collection_drop_indexes(self); +out: + simplifier_free(&simplifier); + return ret; +} + +int TSK_WARN_UNUSED +tsk_tbl_collection_sort(tsk_tbl_collection_t *self, size_t edge_start, int flags) +{ + int ret = 0; + table_sorter_t sorter; + + ret = table_sorter_alloc(&sorter, self, flags); + if (ret != 0) { + goto out; + } + ret = table_sorter_run(&sorter, edge_start); + if (ret != 0) { + goto out; + } + /* The indexes are invalidated now so drop them */ + ret = tsk_tbl_collection_drop_indexes(self); +out: + table_sorter_free(&sorter); + return ret; +} + +/* + * Remove any sites with duplicate positions, retaining only the *first* + * one. Assumes the tables have been sorted, throwing an error if not. + */ +int TSK_WARN_UNUSED +tsk_tbl_collection_deduplicate_sites(tsk_tbl_collection_t *self, int TSK_UNUSED(flags)) +{ + int ret = 0; + tsk_tbl_size_t j; + /* Map of old site IDs to new site IDs. 
*/ + tsk_id_t *site_id_map = NULL; + tsk_site_tbl_t copy; + tsk_site_t row, last_row; + + /* Must allocate the site table first for tsk_site_tbl_free to be safe */ + ret = tsk_site_tbl_alloc(©, 0); + if (ret != 0) { + goto out; + } + /* Check everything except site duplicates (which we expect) and + * edge indexes (which we don't use) */ + ret = tsk_tbl_collection_check_integrity(self, + TSK_CHECK_ALL & ~TSK_CHECK_SITE_DUPLICATES & ~TSK_CHECK_INDEXES); + if (ret != 0) { + goto out; + } + + ret = tsk_site_tbl_copy(self->sites, ©); + if (ret != 0) { + goto out; + } + site_id_map = malloc(copy.num_rows * sizeof(*site_id_map)); + if (site_id_map == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + ret = tsk_site_tbl_clear(self->sites); + if (ret != 0) { + goto out; + } + + last_row.position = -1; + site_id_map[0] = 0; + for (j = 0; j < copy.num_rows; j++) { + ret = tsk_site_tbl_get_row(©, j, &row); + if (ret != 0) { + goto out; + } + if (row.position != last_row.position) { + ret = tsk_site_tbl_add_row(self->sites, row.position, row.ancestral_state, + row.ancestral_state_length, row.metadata, row.metadata_length); + if (ret < 0) { + goto out; + } + } + site_id_map[j] = (tsk_id_t) self->sites->num_rows - 1; + last_row = row; + } + + if (self->sites->num_rows < copy.num_rows) { + // Remap sites in the mutation table + // (but only if there's been any changed sites) + for (j = 0; j < self->mutations->num_rows; j++) { + self->mutations->site[j] = site_id_map[self->mutations->site[j]]; + } + } + ret = 0; +out: + tsk_site_tbl_free(©); + tsk_safe_free(site_id_map); + return ret; +} + +int TSK_WARN_UNUSED +tsk_tbl_collection_compute_mutation_parents(tsk_tbl_collection_t *self, int TSK_UNUSED(flags)) +{ + int ret = 0; + const tsk_id_t *I, *O; + const tsk_edge_tbl_t edges = *self->edges; + const tsk_node_tbl_t nodes = *self->nodes; + const tsk_site_tbl_t sites = *self->sites; + const tsk_mutation_tbl_t mutations = *self->mutations; + const tsk_id_t M = (tsk_id_t) 
edges.num_rows; + tsk_id_t tj, tk; + tsk_id_t *parent = NULL; + tsk_id_t *bottom_mutation = NULL; + tsk_id_t u; + double left, right; + tsk_id_t site; + /* Using unsigned values here avoids potentially undefined behaviour */ + uint32_t j, mutation, first_mutation; + + /* Note that because we check everything here, any non-null mutation parents + * will also be checked, even though they are about to be overwritten. To + * ensure that his function always succeeds we must ensure that the + * parent field is set to -1 first. */ + ret = tsk_tbl_collection_check_integrity(self, TSK_CHECK_ALL); + if (ret != 0) { + goto out; + } + parent = malloc(nodes.num_rows * sizeof(*parent)); + bottom_mutation = malloc(nodes.num_rows * sizeof(*bottom_mutation)); + if (parent == NULL || bottom_mutation == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + memset(parent, 0xff, nodes.num_rows * sizeof(*parent)); + memset(bottom_mutation, 0xff, nodes.num_rows * sizeof(*bottom_mutation)); + memset(mutations.parent, 0xff, self->mutations->num_rows * sizeof(tsk_id_t)); + + I = self->indexes.edge_insertion_order; + O = self->indexes.edge_removal_order; + tj = 0; + tk = 0; + site = 0; + mutation = 0; + left = 0; + while (tj < M || left < self->sequence_length) { + while (tk < M && edges.right[O[tk]] == left) { + parent[edges.child[O[tk]]] = TSK_NULL; + tk++; + } + while (tj < M && edges.left[I[tj]] == left) { + parent[edges.child[I[tj]]] = edges.parent[I[tj]]; + tj++; + } + right = self->sequence_length; + if (tj < M) { + right = TSK_MIN(right, edges.left[I[tj]]); + } + if (tk < M) { + right = TSK_MIN(right, edges.right[O[tk]]); + } + + /* Tree is now ready. We look at each site on this tree in turn */ + while (site < (tsk_id_t) sites.num_rows && sites.position[site] < right) { + /* Create a mapping from mutations to nodes. If we see more than one + * mutation at a node, the previously seen one must be the parent + * of the current since we assume they are in order. 
*/ + first_mutation = mutation; + while (mutation < mutations.num_rows && mutations.site[mutation] == site) { + u = mutations.node[mutation]; + if (bottom_mutation[u] != TSK_NULL) { + mutations.parent[mutation] = bottom_mutation[u]; + } + bottom_mutation[u] = (tsk_id_t) mutation; + mutation++; + } + /* Make the common case of 1 mutation fast */ + if (mutation > first_mutation + 1) { + /* If we have more than one mutation, compute the parent for each + * one by traversing up the tree until we find a node that has a + * mutation. */ + for (j = first_mutation; j < mutation; j++) { + if (mutations.parent[j] == TSK_NULL) { + u = parent[mutations.node[j]]; + while (u != TSK_NULL && bottom_mutation[u] == TSK_NULL) { + u = parent[u]; + } + if (u != TSK_NULL) { + mutations.parent[j] = bottom_mutation[u]; + } + } + } + } + /* Reset the mapping for the next site */ + for (j = first_mutation; j < mutation; j++) { + u = mutations.node[j]; + bottom_mutation[u] = TSK_NULL; + /* Check that we haven't violated the sortedness property */ + if (mutations.parent[j] > (tsk_id_t) j) { + ret = TSK_ERR_MUTATION_PARENT_AFTER_CHILD; + goto out; + } + } + site++; + } + /* Move on to the next tree */ + left = right; + } + +out: + tsk_safe_free(parent); + tsk_safe_free(bottom_mutation); + return ret; +} + +/* Record the current "end" position of a table collection, + * which is the current number of rows in each table. + */ +int +tsk_tbl_collection_record_position(tsk_tbl_collection_t *self, + tsk_tbl_collection_position_t *position) +{ + position->individuals = self->individuals->num_rows; + position->nodes = self->nodes->num_rows; + position->edges = self->edges->num_rows; + position->migrations = self->migrations->num_rows; + position->sites = self->sites->num_rows; + position->mutations = self->mutations->num_rows; + position->populations = self->populations->num_rows; + position->provenances = self->provenances->num_rows; + return 0; +} + +/* Reset to the previously recorded position. 
 */
int TSK_WARN_UNUSED
tsk_tbl_collection_reset_position(tsk_tbl_collection_t *tables,
    tsk_tbl_collection_position_t *position)
{
    int ret = 0;

    /* Truncation invalidates the edge indexes, so drop them first. */
    ret = tsk_tbl_collection_drop_indexes(tables);
    if (ret != 0) {
        goto out;
    }
    /* Truncate each table back to the row count recorded in *position. */
    ret = tsk_individual_tbl_truncate(tables->individuals, position->individuals);
    if (ret != 0) {
        goto out;
    }
    ret = tsk_node_tbl_truncate(tables->nodes, position->nodes);
    if (ret != 0) {
        goto out;
    }
    ret = tsk_edge_tbl_truncate(tables->edges, position->edges);
    if (ret != 0) {
        goto out;
    }
    ret = tsk_migration_tbl_truncate(tables->migrations, position->migrations);
    if (ret != 0) {
        goto out;
    }
    ret = tsk_site_tbl_truncate(tables->sites, position->sites);
    if (ret != 0) {
        goto out;
    }
    ret = tsk_mutation_tbl_truncate(tables->mutations, position->mutations);
    if (ret != 0) {
        goto out;
    }
    ret = tsk_population_tbl_truncate(tables->populations, position->populations);
    if (ret != 0) {
        goto out;
    }
    ret = tsk_provenance_tbl_truncate(tables->provenances, position->provenances);
    if (ret != 0) {
        goto out;
    }
out:
    return ret;
}

/* Clear every table to zero rows by resetting to an all-zero position
 * (also drops the edge indexes, via reset_position). */
int TSK_WARN_UNUSED
tsk_tbl_collection_clear(tsk_tbl_collection_t *self)
{
    tsk_tbl_collection_position_t start;

    memset(&start, 0, sizeof(start));
    return tsk_tbl_collection_reset_position(self, &start);
}


/* qsort comparator ordering edges by (child, left), both ascending.
 * The subtraction-of-comparisons idiom avoids overflow and yields -1/0/1. */
static int
cmp_edge_cl(const void *a, const void *b) {
    const tsk_edge_t *ia = (const tsk_edge_t *) a;
    const tsk_edge_t *ib = (const tsk_edge_t *) b;
    int ret = (ia->child > ib->child) - (ia->child < ib->child);
    if (ret == 0) {
        ret = (ia->left > ib->left) - (ia->left < ib->left);
    }
    return ret;
}

/* Squash the edges in the specified array in place. The output edges will
 * be sorted by (child_id, left).
+ */ +int TSK_WARN_UNUSED +tsk_squash_edges(tsk_edge_t *edges, size_t num_edges, size_t *num_output_edges) +{ + int ret = 0; + size_t j, k, l; + tsk_edge_t e; + + qsort(edges, num_edges, sizeof(tsk_edge_t), cmp_edge_cl); + j = 0; + l = 0; + for (k = 1; k < num_edges; k++) { + assert(edges[k - 1].parent == edges[k].parent); + if (edges[k - 1].right != edges[k].left || edges[j].child != edges[k].child) { + e = edges[j]; + e.right = edges[k - 1].right; + edges[l] = e; + j = k; + l++; + } + } + e = edges[j]; + e.right = edges[k - 1].right; + edges[l] = e; + *num_output_edges = l + 1; + return ret; +} diff --git a/c/tsk_tables.h b/c/tsk_tables.h new file mode 100644 index 0000000000..26ed3c961a --- /dev/null +++ b/c/tsk_tables.h @@ -0,0 +1,469 @@ +#ifndef TSK_TABLES_H +#define TSK_TABLES_H + +#ifdef __cplusplus +extern "C" { +#endif + +#include +#include +#include + +#include + +#include "tsk_core.h" + +typedef int32_t tsk_id_t; +typedef uint32_t tsk_tbl_size_t; + +/****************************************************************************/ +/* Definitions for the basic objects */ +/****************************************************************************/ + +typedef struct { + tsk_id_t id; + uint32_t flags; + double *location; + tsk_tbl_size_t location_length; + const char *metadata; + tsk_tbl_size_t metadata_length; + tsk_id_t *nodes; + tsk_tbl_size_t nodes_length; +} tsk_individual_t; + +typedef struct { + tsk_id_t id; + uint32_t flags; + double time; + tsk_id_t population; + tsk_id_t individual; + const char *metadata; + tsk_tbl_size_t metadata_length; +} tsk_node_t; + +typedef struct { + tsk_id_t id; + tsk_id_t parent; + tsk_id_t child; + double left; + double right; +} tsk_edge_t; + +typedef struct { + tsk_id_t id; + tsk_id_t site; + tsk_id_t node; + tsk_id_t parent; + const char *derived_state; + tsk_tbl_size_t derived_state_length; + const char *metadata; + tsk_tbl_size_t metadata_length; +} tsk_mutation_t; + +typedef struct { + tsk_id_t id; + double 
position; + const char *ancestral_state; + tsk_tbl_size_t ancestral_state_length; + const char *metadata; + tsk_tbl_size_t metadata_length; + tsk_mutation_t *mutations; + tsk_tbl_size_t mutations_length; +} tsk_site_t; + +typedef struct { + tsk_id_t id; + tsk_id_t source; + tsk_id_t dest; + tsk_id_t node; + double left; + double right; + double time; +} tsk_migration_t; + +typedef struct { + tsk_id_t id; + const char *metadata; + tsk_tbl_size_t metadata_length; +} tsk_population_t; + +typedef struct { + tsk_id_t id; + const char *timestamp; + tsk_tbl_size_t timestamp_length; + const char *record; + tsk_tbl_size_t record_length; +} tsk_provenance_t; + +/****************************************************************************/ +/* Table definitions */ +/****************************************************************************/ + +typedef struct { + tsk_tbl_size_t num_rows; + tsk_tbl_size_t max_rows; + tsk_tbl_size_t max_rows_increment; + tsk_tbl_size_t location_length; + tsk_tbl_size_t max_location_length; + tsk_tbl_size_t max_location_length_increment; + tsk_tbl_size_t metadata_length; + tsk_tbl_size_t max_metadata_length; + tsk_tbl_size_t max_metadata_length_increment; + uint32_t *flags; + double *location; + tsk_tbl_size_t *location_offset; + char *metadata; + tsk_tbl_size_t *metadata_offset; +} tsk_individual_tbl_t; + +typedef struct { + tsk_tbl_size_t num_rows; + tsk_tbl_size_t max_rows; + tsk_tbl_size_t max_rows_increment; + tsk_tbl_size_t metadata_length; + tsk_tbl_size_t max_metadata_length; + tsk_tbl_size_t max_metadata_length_increment; + uint32_t *flags; + double *time; + tsk_id_t *population; + tsk_id_t *individual; + char *metadata; + tsk_tbl_size_t *metadata_offset; +} tsk_node_tbl_t; + +typedef struct { + tsk_tbl_size_t num_rows; + tsk_tbl_size_t max_rows; + tsk_tbl_size_t max_rows_increment; + tsk_tbl_size_t ancestral_state_length; + tsk_tbl_size_t max_ancestral_state_length; + tsk_tbl_size_t max_ancestral_state_length_increment; + 
tsk_tbl_size_t metadata_length; + tsk_tbl_size_t max_metadata_length; + tsk_tbl_size_t max_metadata_length_increment; + double *position; + char *ancestral_state; + tsk_tbl_size_t *ancestral_state_offset; + char *metadata; + tsk_tbl_size_t *metadata_offset; +} tsk_site_tbl_t; + +typedef struct { + tsk_tbl_size_t num_rows; + tsk_tbl_size_t max_rows; + tsk_tbl_size_t max_rows_increment; + tsk_tbl_size_t derived_state_length; + tsk_tbl_size_t max_derived_state_length; + tsk_tbl_size_t max_derived_state_length_increment; + tsk_tbl_size_t metadata_length; + tsk_tbl_size_t max_metadata_length; + tsk_tbl_size_t max_metadata_length_increment; + tsk_id_t *node; + tsk_id_t *site; + tsk_id_t *parent; + char *derived_state; + tsk_tbl_size_t *derived_state_offset; + char *metadata; + tsk_tbl_size_t *metadata_offset; +} tsk_mutation_tbl_t; + +typedef struct { + tsk_tbl_size_t num_rows; + tsk_tbl_size_t max_rows; + tsk_tbl_size_t max_rows_increment; + double *left; + double *right; + tsk_id_t *parent; + tsk_id_t *child; +} tsk_edge_tbl_t; + +typedef struct { + tsk_tbl_size_t num_rows; + tsk_tbl_size_t max_rows; + tsk_tbl_size_t max_rows_increment; + tsk_id_t *source; + tsk_id_t *dest; + tsk_id_t *node; + double *left; + double *right; + double *time; +} tsk_migration_tbl_t; + +typedef struct { + tsk_tbl_size_t num_rows; + tsk_tbl_size_t max_rows; + tsk_tbl_size_t max_rows_increment; + tsk_tbl_size_t metadata_length; + tsk_tbl_size_t max_metadata_length; + tsk_tbl_size_t max_metadata_length_increment; + char *metadata; + tsk_tbl_size_t *metadata_offset; +} tsk_population_tbl_t; + +typedef struct { + tsk_tbl_size_t num_rows; + tsk_tbl_size_t max_rows; + tsk_tbl_size_t max_rows_increment; + tsk_tbl_size_t timestamp_length; + tsk_tbl_size_t max_timestamp_length; + tsk_tbl_size_t max_timestamp_length_increment; + tsk_tbl_size_t record_length; + tsk_tbl_size_t max_record_length; + tsk_tbl_size_t max_record_length_increment; + char *timestamp; + tsk_tbl_size_t *timestamp_offset; + char 
*record; + tsk_tbl_size_t *record_offset; +} tsk_provenance_tbl_t; + +typedef struct { + double sequence_length; + char *file_uuid; + tsk_individual_tbl_t *individuals; + tsk_node_tbl_t *nodes; + tsk_edge_tbl_t *edges; + tsk_migration_tbl_t *migrations; + tsk_site_tbl_t *sites; + tsk_mutation_tbl_t *mutations; + tsk_population_tbl_t *populations; + tsk_provenance_tbl_t *provenances; + struct { + tsk_id_t *edge_insertion_order; + tsk_id_t *edge_removal_order; + bool malloced_locally; + } indexes; + kastore_t *store; + /* TODO Add in reserved space for future tables. */ +} tsk_tbl_collection_t; + +typedef struct { + tsk_tbl_size_t individuals; + tsk_tbl_size_t nodes; + tsk_tbl_size_t edges; + tsk_tbl_size_t migrations; + tsk_tbl_size_t sites; + tsk_tbl_size_t mutations; + tsk_tbl_size_t populations; + tsk_tbl_size_t provenances; + /* TODO add reserved space for future tables. */ +} tsk_tbl_collection_position_t; + + +/****************************************************************************/ +/* Function signatures */ +/****************************************************************************/ + +int tsk_individual_tbl_alloc(tsk_individual_tbl_t *self, int flags); +int tsk_individual_tbl_set_max_rows_increment(tsk_individual_tbl_t *self, size_t max_rows_increment); +int tsk_individual_tbl_set_max_metadata_length_increment(tsk_individual_tbl_t *self, + size_t max_metadata_length_increment); +int tsk_individual_tbl_set_max_location_length_increment(tsk_individual_tbl_t *self, + size_t max_location_length_increment); +tsk_id_t tsk_individual_tbl_add_row(tsk_individual_tbl_t *self, uint32_t flags, + double *location, size_t location_length, + const char *metadata, size_t metadata_length); +int tsk_individual_tbl_set_columns(tsk_individual_tbl_t *self, size_t num_rows, uint32_t *flags, + double *location, tsk_tbl_size_t *location_length, + const char *metadata, tsk_tbl_size_t *metadata_length); +int tsk_individual_tbl_append_columns(tsk_individual_tbl_t *self, 
size_t num_rows, uint32_t *flags, + double *location, tsk_tbl_size_t *location_length, + const char *metadata, tsk_tbl_size_t *metadata_length); +int tsk_individual_tbl_clear(tsk_individual_tbl_t *self); +int tsk_individual_tbl_truncate(tsk_individual_tbl_t *self, size_t num_rows); +int tsk_individual_tbl_free(tsk_individual_tbl_t *self); +int tsk_individual_tbl_dump_text(tsk_individual_tbl_t *self, FILE *out); +int tsk_individual_tbl_copy(tsk_individual_tbl_t *self, tsk_individual_tbl_t *dest); +void tsk_individual_tbl_print_state(tsk_individual_tbl_t *self, FILE *out); +bool tsk_individual_tbl_equals(tsk_individual_tbl_t *self, tsk_individual_tbl_t *other); +int tsk_individual_tbl_get_row(tsk_individual_tbl_t *self, size_t index, + tsk_individual_t *row); + +int tsk_node_tbl_alloc(tsk_node_tbl_t *self, int flags); +int tsk_node_tbl_set_max_rows_increment(tsk_node_tbl_t *self, size_t max_rows_increment); +int tsk_node_tbl_set_max_metadata_length_increment(tsk_node_tbl_t *self, + size_t max_metadata_length_increment); +tsk_id_t tsk_node_tbl_add_row(tsk_node_tbl_t *self, uint32_t flags, double time, + tsk_id_t population, tsk_id_t individual, + const char *metadata, size_t metadata_length); +int tsk_node_tbl_set_columns(tsk_node_tbl_t *self, size_t num_rows, + uint32_t *flags, double *time, + tsk_id_t *population, tsk_id_t *individual, + const char *metadata, tsk_tbl_size_t *metadata_length); +int tsk_node_tbl_append_columns(tsk_node_tbl_t *self, size_t num_rows, + uint32_t *flags, double *time, + tsk_id_t *population, tsk_id_t *individual, + const char *metadata, tsk_tbl_size_t *metadata_length); +int tsk_node_tbl_clear(tsk_node_tbl_t *self); +int tsk_node_tbl_truncate(tsk_node_tbl_t *self, size_t num_rows); +int tsk_node_tbl_free(tsk_node_tbl_t *self); +int tsk_node_tbl_dump_text(tsk_node_tbl_t *self, FILE *out); +int tsk_node_tbl_copy(tsk_node_tbl_t *self, tsk_node_tbl_t *dest); +void tsk_node_tbl_print_state(tsk_node_tbl_t *self, FILE *out); +bool 
tsk_node_tbl_equals(tsk_node_tbl_t *self, tsk_node_tbl_t *other); +int tsk_node_tbl_get_row(tsk_node_tbl_t *self, size_t index, tsk_node_t *row); + +int tsk_edge_tbl_alloc(tsk_edge_tbl_t *self, int flags); +int tsk_edge_tbl_set_max_rows_increment(tsk_edge_tbl_t *self, size_t max_rows_increment); +tsk_id_t tsk_edge_tbl_add_row(tsk_edge_tbl_t *self, double left, double right, tsk_id_t parent, + tsk_id_t child); +int tsk_edge_tbl_set_columns(tsk_edge_tbl_t *self, size_t num_rows, double *left, + double *right, tsk_id_t *parent, tsk_id_t *child); +int tsk_edge_tbl_append_columns(tsk_edge_tbl_t *self, size_t num_rows, double *left, + double *right, tsk_id_t *parent, tsk_id_t *child); +int tsk_edge_tbl_clear(tsk_edge_tbl_t *self); +int tsk_edge_tbl_truncate(tsk_edge_tbl_t *self, size_t num_rows); +int tsk_edge_tbl_free(tsk_edge_tbl_t *self); +int tsk_edge_tbl_dump_text(tsk_edge_tbl_t *self, FILE *out); +int tsk_edge_tbl_copy(tsk_edge_tbl_t *self, tsk_edge_tbl_t *dest); +void tsk_edge_tbl_print_state(tsk_edge_tbl_t *self, FILE *out); +bool tsk_edge_tbl_equals(tsk_edge_tbl_t *self, tsk_edge_tbl_t *other); +int tsk_edge_tbl_get_row(tsk_edge_tbl_t *self, size_t index, tsk_edge_t *row); + +int tsk_site_tbl_alloc(tsk_site_tbl_t *self, int flags); +int tsk_site_tbl_set_max_rows_increment(tsk_site_tbl_t *self, size_t max_rows_increment); +int tsk_site_tbl_set_max_metadata_length_increment(tsk_site_tbl_t *self, + size_t max_metadata_length_increment); +int tsk_site_tbl_set_max_ancestral_state_length_increment(tsk_site_tbl_t *self, + size_t max_ancestral_state_length_increment); +tsk_id_t tsk_site_tbl_add_row(tsk_site_tbl_t *self, double position, + const char *ancestral_state, tsk_tbl_size_t ancestral_state_length, + const char *metadata, tsk_tbl_size_t metadata_length); +int tsk_site_tbl_set_columns(tsk_site_tbl_t *self, size_t num_rows, double *position, + const char *ancestral_state, tsk_tbl_size_t *ancestral_state_length, + const char *metadata, tsk_tbl_size_t 
*metadata_length); +int tsk_site_tbl_append_columns(tsk_site_tbl_t *self, size_t num_rows, double *position, + const char *ancestral_state, tsk_tbl_size_t *ancestral_state_length, + const char *metadata, tsk_tbl_size_t *metadata_length); +bool tsk_site_tbl_equals(tsk_site_tbl_t *self, tsk_site_tbl_t *other); +int tsk_site_tbl_clear(tsk_site_tbl_t *self); +int tsk_site_tbl_truncate(tsk_site_tbl_t *self, size_t num_rows); +int tsk_site_tbl_copy(tsk_site_tbl_t *self, tsk_site_tbl_t *dest); +int tsk_site_tbl_free(tsk_site_tbl_t *self); +int tsk_site_tbl_dump_text(tsk_site_tbl_t *self, FILE *out); +void tsk_site_tbl_print_state(tsk_site_tbl_t *self, FILE *out); +int tsk_site_tbl_get_row(tsk_site_tbl_t *self, size_t index, tsk_site_t *row); + +void tsk_mutation_tbl_print_state(tsk_mutation_tbl_t *self, FILE *out); +int tsk_mutation_tbl_alloc(tsk_mutation_tbl_t *self, int flags); +int tsk_mutation_tbl_set_max_rows_increment(tsk_mutation_tbl_t *self, size_t max_rows_increment); +int tsk_mutation_tbl_set_max_metadata_length_increment(tsk_mutation_tbl_t *self, + size_t max_metadata_length_increment); +int tsk_mutation_tbl_set_max_derived_state_length_increment(tsk_mutation_tbl_t *self, + size_t max_derived_state_length_increment); +tsk_id_t tsk_mutation_tbl_add_row(tsk_mutation_tbl_t *self, tsk_id_t site, + tsk_id_t node, tsk_id_t parent, + const char *derived_state, tsk_tbl_size_t derived_state_length, + const char *metadata, tsk_tbl_size_t metadata_length); +int tsk_mutation_tbl_set_columns(tsk_mutation_tbl_t *self, size_t num_rows, + tsk_id_t *site, tsk_id_t *node, tsk_id_t *parent, + const char *derived_state, tsk_tbl_size_t *derived_state_length, + const char *metadata, tsk_tbl_size_t *metadata_length); +int tsk_mutation_tbl_append_columns(tsk_mutation_tbl_t *self, size_t num_rows, + tsk_id_t *site, tsk_id_t *node, tsk_id_t *parent, + const char *derived_state, tsk_tbl_size_t *derived_state_length, + const char *metadata, tsk_tbl_size_t *metadata_length); +bool 
tsk_mutation_tbl_equals(tsk_mutation_tbl_t *self, tsk_mutation_tbl_t *other); +int tsk_mutation_tbl_clear(tsk_mutation_tbl_t *self); +int tsk_mutation_tbl_truncate(tsk_mutation_tbl_t *self, size_t num_rows); +int tsk_mutation_tbl_copy(tsk_mutation_tbl_t *self, tsk_mutation_tbl_t *dest); +int tsk_mutation_tbl_free(tsk_mutation_tbl_t *self); +int tsk_mutation_tbl_dump_text(tsk_mutation_tbl_t *self, FILE *out); +void tsk_mutation_tbl_print_state(tsk_mutation_tbl_t *self, FILE *out); +int tsk_mutation_tbl_get_row(tsk_mutation_tbl_t *self, size_t index, tsk_mutation_t *row); + +int tsk_migration_tbl_alloc(tsk_migration_tbl_t *self, int flags); +int tsk_migration_tbl_set_max_rows_increment(tsk_migration_tbl_t *self, size_t max_rows_increment); +tsk_id_t tsk_migration_tbl_add_row(tsk_migration_tbl_t *self, double left, + double right, tsk_id_t node, tsk_id_t source, + tsk_id_t dest, double time); +int tsk_migration_tbl_set_columns(tsk_migration_tbl_t *self, size_t num_rows, + double *left, double *right, tsk_id_t *node, tsk_id_t *source, + tsk_id_t *dest, double *time); +int tsk_migration_tbl_append_columns(tsk_migration_tbl_t *self, size_t num_rows, + double *left, double *right, tsk_id_t *node, tsk_id_t *source, + tsk_id_t *dest, double *time); +int tsk_migration_tbl_clear(tsk_migration_tbl_t *self); +int tsk_migration_tbl_truncate(tsk_migration_tbl_t *self, size_t num_rows); +int tsk_migration_tbl_free(tsk_migration_tbl_t *self); +int tsk_migration_tbl_copy(tsk_migration_tbl_t *self, tsk_migration_tbl_t *dest); +int tsk_migration_tbl_dump_text(tsk_migration_tbl_t *self, FILE *out); +void tsk_migration_tbl_print_state(tsk_migration_tbl_t *self, FILE *out); +bool tsk_migration_tbl_equals(tsk_migration_tbl_t *self, tsk_migration_tbl_t *other); +int tsk_migration_tbl_get_row(tsk_migration_tbl_t *self, size_t index, tsk_migration_t *row); + +int tsk_population_tbl_alloc(tsk_population_tbl_t *self, int flags); +int 
tsk_population_tbl_set_max_rows_increment(tsk_population_tbl_t *self, size_t max_rows_increment); +int tsk_population_tbl_set_max_metadata_length_increment(tsk_population_tbl_t *self, + size_t max_metadata_length_increment); +tsk_id_t tsk_population_tbl_add_row(tsk_population_tbl_t *self, + const char *metadata, size_t metadata_length); +int tsk_population_tbl_set_columns(tsk_population_tbl_t *self, size_t num_rows, + const char *metadata, tsk_tbl_size_t *metadata_offset); +int tsk_population_tbl_append_columns(tsk_population_tbl_t *self, size_t num_rows, + const char *metadata, tsk_tbl_size_t *metadata_offset); +int tsk_population_tbl_clear(tsk_population_tbl_t *self); +int tsk_population_tbl_truncate(tsk_population_tbl_t *self, size_t num_rows); +int tsk_population_tbl_copy(tsk_population_tbl_t *self, tsk_population_tbl_t *dest); +int tsk_population_tbl_free(tsk_population_tbl_t *self); +void tsk_population_tbl_print_state(tsk_population_tbl_t *self, FILE *out); +int tsk_population_tbl_dump_text(tsk_population_tbl_t *self, FILE *out); +bool tsk_population_tbl_equals(tsk_population_tbl_t *self, tsk_population_tbl_t *other); +int tsk_population_tbl_get_row(tsk_population_tbl_t *self, size_t index, tsk_population_t *row); + +int tsk_provenance_tbl_alloc(tsk_provenance_tbl_t *self, int flags); +int tsk_provenance_tbl_set_max_rows_increment(tsk_provenance_tbl_t *self, size_t max_rows_increment); +int tsk_provenance_tbl_set_max_timestamp_length_increment(tsk_provenance_tbl_t *self, + size_t max_timestamp_length_increment); +int tsk_provenance_tbl_set_max_record_length_increment(tsk_provenance_tbl_t *self, + size_t max_record_length_increment); +tsk_id_t tsk_provenance_tbl_add_row(tsk_provenance_tbl_t *self, + const char *timestamp, size_t timestamp_length, + const char *record, size_t record_length); +int tsk_provenance_tbl_set_columns(tsk_provenance_tbl_t *self, size_t num_rows, + char *timestamp, tsk_tbl_size_t *timestamp_offset, + char *record, tsk_tbl_size_t 
*record_offset); +int tsk_provenance_tbl_append_columns(tsk_provenance_tbl_t *self, size_t num_rows, + char *timestamp, tsk_tbl_size_t *timestamp_offset, + char *record, tsk_tbl_size_t *record_offset); +int tsk_provenance_tbl_clear(tsk_provenance_tbl_t *self); +int tsk_provenance_tbl_truncate(tsk_provenance_tbl_t *self, size_t num_rows); +int tsk_provenance_tbl_copy(tsk_provenance_tbl_t *self, tsk_provenance_tbl_t *dest); +int tsk_provenance_tbl_free(tsk_provenance_tbl_t *self); +int tsk_provenance_tbl_dump_text(tsk_provenance_tbl_t *self, FILE *out); +void tsk_provenance_tbl_print_state(tsk_provenance_tbl_t *self, FILE *out); +bool tsk_provenance_tbl_equals(tsk_provenance_tbl_t *self, tsk_provenance_tbl_t *other); +int tsk_provenance_tbl_get_row(tsk_provenance_tbl_t *self, size_t index, tsk_provenance_t *row); + +/****************************************************************************/ +/* Table collection .*/ +/****************************************************************************/ + +int tsk_tbl_collection_alloc(tsk_tbl_collection_t *self, int flags); +int tsk_tbl_collection_load(tsk_tbl_collection_t *self, const char *filename, int flags); +int tsk_tbl_collection_dump(tsk_tbl_collection_t *tables, const char *filename, int flags); +int tsk_tbl_collection_copy(tsk_tbl_collection_t *self, tsk_tbl_collection_t *dest); +int tsk_tbl_collection_print_state(tsk_tbl_collection_t *self, FILE *out); +int tsk_tbl_collection_free(tsk_tbl_collection_t *self); + +bool tsk_tbl_collection_is_indexed(tsk_tbl_collection_t *self); +int tsk_tbl_collection_drop_indexes(tsk_tbl_collection_t *self); +int tsk_tbl_collection_build_indexes(tsk_tbl_collection_t *self, int flags); +int tsk_tbl_collection_simplify(tsk_tbl_collection_t *self, + tsk_id_t *samples, size_t num_samples, int flags, tsk_id_t *node_map); +int tsk_tbl_collection_sort(tsk_tbl_collection_t *self, size_t edge_start, int flags); +int tsk_tbl_collection_deduplicate_sites(tsk_tbl_collection_t *tables, int 
flags); +int tsk_tbl_collection_compute_mutation_parents(tsk_tbl_collection_t *self, int flags); +bool tsk_tbl_collection_equals(tsk_tbl_collection_t *self, tsk_tbl_collection_t *other); +int tsk_tbl_collection_record_position(tsk_tbl_collection_t *self, + tsk_tbl_collection_position_t *position); +int tsk_tbl_collection_reset_position(tsk_tbl_collection_t *self, + tsk_tbl_collection_position_t *position); +int tsk_tbl_collection_clear(tsk_tbl_collection_t *self); +int tsk_tbl_collection_check_integrity(tsk_tbl_collection_t *self, int flags); + +int tsk_squash_edges(tsk_edge_t *edges, size_t num_edges, size_t *num_output_edges); + +#ifdef __cplusplus +} +#endif +#endif diff --git a/c/tsk_trees.c b/c/tsk_trees.c new file mode 100644 index 0000000000..0b9c37778a --- /dev/null +++ b/c/tsk_trees.c @@ -0,0 +1,2190 @@ +#include +#include +#include +#include +#include + +#include "tsk_trees.h" + + +/* ======================================================== * + * tree sequence + * ======================================================== */ + +static void +tsk_treeseq_check_state(tsk_treeseq_t *self) +{ + size_t j; + tsk_tbl_size_t k, l; + tsk_site_t site; + tsk_id_t site_id = 0; + + for (j = 0; j < self->num_trees; j++) { + for (k = 0; k < self->tree_sites_length[j]; k++) { + site = self->tree_sites[j][k]; + assert(site.id == site_id); + site_id++; + for (l = 0; l < site.mutations_length; l++) { + assert(site.mutations[l].site == site.id); + } + } + } +} + +void +tsk_treeseq_print_state(tsk_treeseq_t *self, FILE *out) +{ + size_t j; + tsk_tbl_size_t k, l, m; + tsk_site_t site; + + fprintf(out, "tree_sequence state\n"); + fprintf(out, "num_trees = %d\n", (int) self->num_trees); + fprintf(out, "samples = (%d)\n", (int) self->num_samples); + for (j = 0; j < self->num_samples; j++) { + fprintf(out, "\t%d\n", (int) self->samples[j]); + } + tsk_tbl_collection_print_state(self->tables, out); + fprintf(out, "tree_sites = \n"); + for (j = 0; j < self->num_trees; j++) { + 
fprintf(out, "tree %d\t%d sites\n", (int) j, self->tree_sites_length[j]); + for (k = 0; k < self->tree_sites_length[j]; k++) { + site = self->tree_sites[j][k]; + fprintf(out, "\tsite %d pos = %f ancestral state = ", site.id, site.position); + for (l = 0; l < site.ancestral_state_length; l++) { + fprintf(out, "%c", site.ancestral_state[l]); + } + fprintf(out, " %d mutations\n", site.mutations_length); + for (l = 0; l < site.mutations_length; l++) { + fprintf(out, "\t\tmutation %d node = %d derived_state = ", + site.mutations[l].id, site.mutations[l].node); + for (m = 0; m < site.mutations[l].derived_state_length; m++) { + fprintf(out, "%c", site.mutations[l].derived_state[m]); + } + fprintf(out, "\n"); + } + } + } + tsk_treeseq_check_state(self); +} + +int +tsk_treeseq_free(tsk_treeseq_t *self) +{ + if (self->tables != NULL) { + tsk_tbl_collection_free(self->tables); + } + tsk_safe_free(self->tables); + tsk_safe_free(self->samples); + tsk_safe_free(self->sample_index_map); + tsk_safe_free(self->tree_sites); + tsk_safe_free(self->tree_sites_length); + tsk_safe_free(self->tree_sites_mem); + tsk_safe_free(self->site_mutations_mem); + tsk_safe_free(self->site_mutations_length); + tsk_safe_free(self->site_mutations); + tsk_safe_free(self->individual_nodes_mem); + tsk_safe_free(self->individual_nodes_length); + tsk_safe_free(self->individual_nodes); + return 0; +} + +static int +tsk_treeseq_init_sites(tsk_treeseq_t *self) +{ + size_t j; + tsk_tbl_size_t k; + int ret = 0; + size_t offset = 0; + const tsk_tbl_size_t num_mutations = self->tables->mutations->num_rows; + const tsk_tbl_size_t num_sites = self->tables->sites->num_rows; + const tsk_id_t *restrict mutation_site = self->tables->mutations->site; + + self->site_mutations_mem = malloc(num_mutations * sizeof(tsk_mutation_t)); + self->site_mutations_length = malloc(num_sites * sizeof(tsk_tbl_size_t)); + self->site_mutations = malloc(num_sites * sizeof(tsk_mutation_t *)); + self->tree_sites_mem = malloc(num_sites * 
sizeof(tsk_site_t)); + if (self->site_mutations_mem == NULL + || self->site_mutations_length == NULL + || self->site_mutations == NULL + || self->tree_sites_mem == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + + for (k = 0; k < num_mutations; k++) { + ret = tsk_treeseq_get_mutation(self, k, self->site_mutations_mem + k); + if (ret != 0) { + goto out; + } + } + k = 0; + for (j = 0; j < num_sites; j++) { + self->site_mutations[j] = self->site_mutations_mem + offset; + self->site_mutations_length[j] = 0; + /* Go through all mutations for this site */ + while (k < num_mutations && mutation_site[k] == (tsk_id_t) j) { + self->site_mutations_length[j]++; + offset++; + k++; + } + ret = tsk_treeseq_get_site(self, j, self->tree_sites_mem + j); + if (ret != 0) { + goto out; + } + } +out: + return ret; +} + +static int +tsk_treeseq_init_individuals(tsk_treeseq_t *self) +{ + int ret = 0; + tsk_id_t node; + tsk_id_t ind; + tsk_tbl_size_t offset = 0; + tsk_tbl_size_t total_node_refs = 0; + tsk_tbl_size_t *node_count = NULL; + tsk_id_t *node_array; + const size_t num_inds = self->tables->individuals->num_rows; + const size_t num_nodes = self->tables->nodes->num_rows; + const tsk_id_t *restrict node_individual = self->tables->nodes->individual; + + // First find number of nodes per individual + self->individual_nodes_length = calloc(TSK_MAX(1, num_inds), sizeof(tsk_tbl_size_t)); + node_count = calloc(TSK_MAX(1, num_inds), sizeof(size_t)); + + if (self->individual_nodes_length == NULL || node_count == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + + for (node = 0; node < (tsk_id_t) num_nodes; node++) { + ind = node_individual[node]; + if (ind != TSK_NULL) { + self->individual_nodes_length[ind]++; + total_node_refs++; + } + } + + self->individual_nodes_mem = malloc(TSK_MAX(1, total_node_refs) * sizeof(tsk_node_t)); + self->individual_nodes = malloc(TSK_MAX(1, num_inds) * sizeof(tsk_node_t *)); + if (self->individual_nodes_mem == NULL || self->individual_nodes == NULL) { 
+ ret = TSK_ERR_NO_MEMORY; + goto out; + } + + /* Now fill in the node IDs */ + for (ind = 0; ind < (tsk_id_t) num_inds; ind++) { + self->individual_nodes[ind] = self->individual_nodes_mem + offset; + offset += self->individual_nodes_length[ind]; + } + for (node = 0; node < (tsk_id_t) num_nodes; node++) { + ind = node_individual[node]; + if (ind != TSK_NULL) { + node_array = self->individual_nodes[ind]; + assert(node_array - self->individual_nodes_mem + < total_node_refs - node_count[ind]); + node_array[node_count[ind]] = node; + node_count[ind] += 1; + } + } +out: + tsk_safe_free(node_count); + return ret; +} + +/* Initialises memory associated with the trees. + */ +static int +tsk_treeseq_init_trees(tsk_treeseq_t *self) +{ + int ret = TSK_ERR_GENERIC; + size_t j, k, tree_index; + tsk_id_t site; + double tree_left, tree_right; + const double sequence_length = self->tables->sequence_length; + const tsk_id_t num_sites = (tsk_id_t) self->tables->sites->num_rows; + const size_t num_edges = self->tables->edges->num_rows; + const double * restrict site_position = self->tables->sites->position; + const tsk_id_t * restrict I = self->tables->indexes.edge_insertion_order; + const tsk_id_t * restrict O = self->tables->indexes.edge_removal_order; + const double * restrict edge_right = self->tables->edges->right; + const double * restrict edge_left = self->tables->edges->left; + + tree_left = 0; + tree_right = sequence_length; + self->num_trees = 0; + j = 0; + k = 0; + assert(I != NULL && O != NULL); + while (j < num_edges || tree_left < sequence_length) { + while (k < num_edges && edge_right[O[k]] == tree_left) { + k++; + } + while (j < num_edges && edge_left[I[j]] == tree_left) { + j++; + } + tree_right = sequence_length; + if (j < num_edges) { + tree_right = TSK_MIN(tree_right, edge_left[I[j]]); + } + if (k < num_edges) { + tree_right = TSK_MIN(tree_right, edge_right[O[k]]); + } + tree_left = tree_right; + self->num_trees++; + } + assert(self->num_trees > 0); + + 
self->tree_sites_length = malloc(self->num_trees * sizeof(tsk_tbl_size_t)); + self->tree_sites = malloc(self->num_trees * sizeof(tsk_site_t *)); + if (self->tree_sites == NULL || self->tree_sites_length == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + memset(self->tree_sites_length, 0, self->num_trees * sizeof(tsk_tbl_size_t)); + memset(self->tree_sites, 0, self->num_trees * sizeof(tsk_site_t *)); + + tree_left = 0; + tree_right = sequence_length; + tree_index = 0; + site = 0; + j = 0; + k = 0; + while (j < num_edges || tree_left < sequence_length) { + while (k < num_edges && edge_right[O[k]] == tree_left) { + k++; + } + while (j < num_edges && edge_left[I[j]] == tree_left) { + j++; + } + tree_right = sequence_length; + if (j < num_edges) { + tree_right = TSK_MIN(tree_right, edge_left[I[j]]); + } + if (k < num_edges) { + tree_right = TSK_MIN(tree_right, edge_right[O[k]]); + } + self->tree_sites[tree_index] = self->tree_sites_mem + site; + while (site < num_sites && site_position[site] < tree_right) { + self->tree_sites_length[tree_index]++; + site++; + } + tree_left = tree_right; + tree_index++; + } + assert(site == num_sites); + assert(tree_index == self->num_trees); + ret = 0; +out: + return ret; +} + +static int +tsk_treeseq_init_nodes(tsk_treeseq_t *self) +{ + size_t j, k; + size_t num_nodes = self->tables->nodes->num_rows; + const uint32_t *restrict node_flags = self->tables->nodes->flags; + int ret = 0; + + /* Determine the sample size */ + self->num_samples = 0; + for (j = 0; j < num_nodes; j++) { + if (!!(node_flags[j] & TSK_NODE_IS_SAMPLE)) { + self->num_samples++; + } + } + /* TODO raise an error if < 2 samples?? 
*/ + self->samples = malloc(self->num_samples * sizeof(tsk_id_t)); + self->sample_index_map = malloc(num_nodes * sizeof(tsk_id_t)); + if (self->samples == NULL || self->sample_index_map == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + k = 0; + for (j = 0; j < num_nodes; j++) { + self->sample_index_map[j] = -1; + if (!!(node_flags[j] & TSK_NODE_IS_SAMPLE)) { + self->samples[k] = (tsk_id_t) j; + self->sample_index_map[j] = (tsk_id_t) k; + k++; + } + } + assert(k == self->num_samples); +out: + return ret; +} + +/* TODO we need flags to be able to control how the input table is used. + * - The default behaviour is to take a copy. TSK_BUILD_INDEXES is allowed + * in this case because we have an independent copy. + * - Need an option to take 'ownership' of the tables so that we keep the + * tables and free them at the end of the treeseq's lifetime. This will be + * used in tsk_treeseq_load below, where we can take advantage of the read-only + * access directly into the store's memory and avoid copying the tree sequence. + * - We should also allow a read-only "borrowed reference" where we use the + * tables directly, but don't free it at the end. 
+ */ +int TSK_WARN_UNUSED +tsk_treeseq_alloc(tsk_treeseq_t *self, tsk_tbl_collection_t *tables, int flags) +{ + int ret = 0; + + memset(self, 0, sizeof(*self)); + if (tables == NULL) { + ret = TSK_ERR_BAD_PARAM_VALUE; + goto out; + } + self->tables = malloc(sizeof(*self->tables)); + if (self->tables == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + ret = tsk_tbl_collection_alloc(self->tables, 0); + if (ret != 0) { + goto out; + } + ret = tsk_tbl_collection_copy(tables, self->tables); + if (ret != 0) { + goto out; + } + if (!!(flags & TSK_BUILD_INDEXES)) { + ret = tsk_tbl_collection_build_indexes(self->tables, 0); + if (ret != 0) { + goto out; + } + } + ret = tsk_tbl_collection_check_integrity(self->tables, TSK_CHECK_ALL); + if (ret != 0) { + goto out; + } + assert(tsk_tbl_collection_is_indexed(self->tables)); + + /* This is a hack to workaround the fact we're copying the tables here. + * In general, we don't want the file_uuid to be copied, as this should + * only be present if the tables are genuinely backed by a file and in + * read-only mode (which we also need to do). So, we copy the file_uuid + * into the local copy of the table for now until we have proper read-only + * access to the tables set up, where any attempts to modify the tables + * will fail. 
*/ + if (tables->file_uuid != NULL) { + self->tables->file_uuid = malloc(TSK_UUID_SIZE + 1); + if (self->tables->file_uuid == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + memcpy(self->tables->file_uuid, tables->file_uuid, TSK_UUID_SIZE + 1); + } + + ret = tsk_treeseq_init_nodes(self); + if (ret != 0) { + goto out; + } + ret = tsk_treeseq_init_sites(self); + if (ret != 0) { + goto out; + } + ret = tsk_treeseq_init_individuals(self); + if (ret != 0) { + goto out; + } + ret = tsk_treeseq_init_trees(self); + if (ret != 0) { + goto out; + } +out: + return ret; +} + +int TSK_WARN_UNUSED +tsk_treeseq_copy_tables(tsk_treeseq_t *self, tsk_tbl_collection_t *tables) +{ + return tsk_tbl_collection_copy(self->tables, tables); +} + +int TSK_WARN_UNUSED +tsk_treeseq_load(tsk_treeseq_t *self, const char *filename, int TSK_UNUSED(flags)) +{ + int ret = 0; + tsk_tbl_collection_t tables; + + ret = tsk_tbl_collection_load(&tables, filename, 0); + if (ret != 0) { + goto out; + } + /* TODO the implementation is wasteful here, as we don't need to allocate + * a new table here but could load directly into the main table instead. + * See notes on the owned reference for treeseq_alloc above. 
+ */ + ret = tsk_treeseq_alloc(self, &tables, 0); + if (ret != 0) { + goto out; + } +out: + tsk_tbl_collection_free(&tables); + return ret; +} + +int TSK_WARN_UNUSED +tsk_treeseq_dump(tsk_treeseq_t *self, const char *filename, int flags) +{ + return tsk_tbl_collection_dump(self->tables, filename, flags); +} + +/* Simple attribute getters */ + +double +tsk_treeseq_get_sequence_length(tsk_treeseq_t *self) +{ + return self->tables->sequence_length; +} + +char * +tsk_treeseq_get_file_uuid(tsk_treeseq_t *self) +{ + return self->tables->file_uuid; +} + +size_t +tsk_treeseq_get_num_samples(tsk_treeseq_t *self) +{ + return self->num_samples; +} + +size_t +tsk_treeseq_get_num_nodes(tsk_treeseq_t *self) +{ + return self->tables->nodes->num_rows; +} + +size_t +tsk_treeseq_get_num_edges(tsk_treeseq_t *self) +{ + return self->tables->edges->num_rows; +} + +size_t +tsk_treeseq_get_num_migrations(tsk_treeseq_t *self) +{ + return self->tables->migrations->num_rows; +} + +size_t +tsk_treeseq_get_num_sites(tsk_treeseq_t *self) +{ + return self->tables->sites->num_rows; +} + +size_t +tsk_treeseq_get_num_mutations(tsk_treeseq_t *self) +{ + return self->tables->mutations->num_rows; +} + +size_t +tsk_treeseq_get_num_populations(tsk_treeseq_t *self) +{ + return self->tables->populations->num_rows; +} + +size_t +tsk_treeseq_get_num_individuals(tsk_treeseq_t *self) +{ + return self->tables->individuals->num_rows; +} + +size_t +tsk_treeseq_get_num_provenances(tsk_treeseq_t *self) +{ + return self->tables->provenances->num_rows; +} + +size_t +tsk_treeseq_get_num_trees(tsk_treeseq_t *self) +{ + return self->num_trees; +} + +bool +tsk_treeseq_is_sample(tsk_treeseq_t *self, tsk_id_t u) +{ + bool ret = false; + + if (u >= 0 && u < (tsk_id_t) self->tables->nodes->num_rows) { + ret = !!(self->tables->nodes->flags[u] & TSK_NODE_IS_SAMPLE); + } + return ret; +} + +/* Accessors for records */ + +int TSK_WARN_UNUSED +tsk_treeseq_get_pairwise_diversity(tsk_treeseq_t *self, + tsk_id_t *samples, size_t 
num_samples, double *pi) +{ + int ret = 0; + tsk_tree_t *tree = NULL; + double result, denom, n, count; + tsk_site_t *sites; + tsk_tbl_size_t j, k, num_sites; + + if (num_samples < 2 || num_samples > self->num_samples) { + ret = TSK_ERR_BAD_PARAM_VALUE; + goto out; + } + n = (double) num_samples; + tree = malloc(sizeof(tsk_tree_t)); + if (tree == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + ret = tsk_tree_alloc(tree, self, TSK_SAMPLE_COUNTS); + if (ret != 0) { + goto out; + } + ret = tsk_tree_set_tracked_samples(tree, num_samples, samples); + if (ret != 0) { + goto out; + } + /* Allocation done; move onto main algorithm. */ + result = 0.0; + for (ret = tsk_tree_first(tree); ret == 1; ret = tsk_tree_next(tree)) { + ret = tsk_tree_get_sites(tree, &sites, &num_sites); + if (ret != 0) { + goto out; + } + for (j = 0; j < num_sites; j++) { + if (sites[j].mutations_length != 1) { + ret = TSK_ERR_ONLY_INFINITE_SITES; + goto out; + } + for (k = 0; k < sites[j].mutations_length; k++) { + count = (double) tree->num_tracked_samples[sites[j].mutations[k].node]; + result += count * (n - count); + } + } + } + if (ret != 0) { + goto out; + } + denom = (n * (n - 1)) / 2.0; + *pi = result / denom; +out: + if (tree != NULL) { + tsk_tree_free(tree); + free(tree); + } + return ret; +} + +#define GET_2D_ROW(array, row_len, row) (array + (((size_t) (row_len)) * (size_t) row)) + +int TSK_WARN_UNUSED +tsk_treeseq_genealogical_nearest_neighbours(tsk_treeseq_t *self, + tsk_id_t *focal, size_t num_focal, + tsk_id_t **reference_sets, size_t *reference_set_size, size_t num_reference_sets, + int TSK_UNUSED(flags), double *ret_array) +{ + int ret = 0; + tsk_id_t u, v, p; + size_t j; + /* TODO It's probably not worth bothering with the int16_t here. */ + int16_t k, focal_reference_set; + /* We use the K'th element of the array for the total. 
*/ + const int16_t K = (int16_t) (num_reference_sets + 1); + size_t num_nodes = self->tables->nodes->num_rows; + const tsk_id_t num_edges = (tsk_id_t) self->tables->edges->num_rows; + const tsk_id_t *restrict I = self->tables->indexes.edge_insertion_order; + const tsk_id_t *restrict O = self->tables->indexes.edge_removal_order; + const double *restrict edge_left = self->tables->edges->left; + const double *restrict edge_right = self->tables->edges->right; + const tsk_id_t *restrict edge_parent = self->tables->edges->parent; + const tsk_id_t *restrict edge_child = self->tables->edges->child; + const double sequence_length = self->tables->sequence_length; + tsk_id_t tj, tk, h; + double left, right, *A_row, scale, tree_length; + tsk_id_t *restrict parent = malloc(num_nodes * sizeof(*parent)); + double *restrict length = calloc(num_focal, sizeof(*length)); + uint32_t *restrict ref_count = calloc(num_nodes * ((size_t) K), sizeof(*ref_count)); + int16_t *restrict reference_set_map = malloc(num_nodes * sizeof(*reference_set_map)); + uint32_t *restrict row, *restrict child_row, total; + + /* We support a max of 8K focal sets */ + if (num_reference_sets == 0 || num_reference_sets > (INT16_MAX - 1)) { + /* TODO: more specific error */ + ret = TSK_ERR_BAD_PARAM_VALUE; + goto out; + } + if (parent == NULL || ref_count == NULL || reference_set_map == NULL + || length == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + + memset(parent, 0xff, num_nodes * sizeof(*parent)); + memset(reference_set_map, 0xff, num_nodes * sizeof(*reference_set_map)); + memset(ret_array, 0, num_focal * num_reference_sets * sizeof(*ret_array)); + + /* Set the initial conditions and check the input. 
*/ + for (k = 0; k < (int16_t) num_reference_sets; k++) { + for (j = 0; j < reference_set_size[k]; j++) { + u = reference_sets[k][j]; + if (u < 0 || u >= (tsk_id_t) num_nodes) { + ret = TSK_ERR_NODE_OUT_OF_BOUNDS; + goto out; + } + if (reference_set_map[u] != TSK_NULL) { + /* FIXME Technically inaccurate here: duplicate focal not sample */ + ret = TSK_ERR_DUPLICATE_SAMPLE; + goto out; + } + reference_set_map[u] = k; + row = GET_2D_ROW(ref_count, K, u); + row[k] = 1; + /* Also set the count for the total among all sets */ + row[K - 1] = 1; + } + } + for (j = 0; j < num_focal; j++) { + u = focal[j]; + if (u < 0 || u >= (tsk_id_t) num_nodes) { + ret = TSK_ERR_NODE_OUT_OF_BOUNDS; + goto out; + } + } + + /* Iterate over the trees */ + tj = 0; + tk = 0; + left = 0; + while (tj < num_edges || left < sequence_length) { + while (tk < num_edges && edge_right[O[tk]] == left) { + h = O[tk]; + tk++; + u = edge_child[h]; + v = edge_parent[h]; + parent[u] = TSK_NULL; + child_row = GET_2D_ROW(ref_count, K, u); + while (v != TSK_NULL) { + row = GET_2D_ROW(ref_count, K, v); + for (k = 0; k < K; k++) { + row[k] -= child_row[k]; + } + v = parent[v]; + } + } + while (tj < num_edges && edge_left[I[tj]] == left) { + h = I[tj]; + tj++; + u = edge_child[h]; + v = edge_parent[h]; + parent[u] = v; + child_row = GET_2D_ROW(ref_count, K, u); + while (v != TSK_NULL) { + row = GET_2D_ROW(ref_count, K, v); + for (k = 0; k < K; k++) { + row[k] += child_row[k]; + } + v = parent[v]; + } + } + right = sequence_length; + if (tj < num_edges) { + right = TSK_MIN(right, edge_left[I[tj]]); + } + if (tk < num_edges) { + right = TSK_MIN(right, edge_right[O[tk]]); + } + + tree_length = right - left; + /* Process this tree */ + for (j = 0; j < num_focal; j++) { + u = focal[j]; + p = parent[u]; + while (p != TSK_NULL) { + row = GET_2D_ROW(ref_count, K, p); + total = row[K - 1]; + if (total > 1) { + break; + } + p = parent[p]; + } + if (p != TSK_NULL) { + length[j] += tree_length; + focal_reference_set = 
reference_set_map[u]; + scale = tree_length / (total - (focal_reference_set != -1)); + A_row = GET_2D_ROW(ret_array, num_reference_sets, j); + for (k = 0; k < K - 1; k++) { + A_row[k] += row[k] * scale; + } + if (focal_reference_set != -1) { + /* Remove the contribution for the reference set u belongs to and + * insert the correct value. The long-hand version is + * A_row[k] = A_row[k] - row[k] * scale + (row[k] - 1) * scale; + * which cancels to give: */ + A_row[focal_reference_set] -= scale; + } + } + } + + /* Move on to the next tree */ + left = right; + } + + /* Divide by the accumulated length for each node to normalise */ + for (j = 0; j < num_focal; j++) { + A_row = GET_2D_ROW(ret_array, num_reference_sets, j); + if (length[j] > 0) { + for (k = 0; k < K - 1; k++) { + A_row[k] /= length[j]; + } + } + } +out: + /* Can't use msp_safe_free here because of restrict */ + if (parent != NULL) { + free(parent); + } + if (ref_count != NULL) { + free(ref_count); + } + if (reference_set_map != NULL) { + free(reference_set_map); + } + if (length != NULL) { + free(length); + } + return ret; +} + +int TSK_WARN_UNUSED +tsk_treeseq_mean_descendants(tsk_treeseq_t *self, + tsk_id_t **reference_sets, size_t *reference_set_size, size_t num_reference_sets, + int TSK_UNUSED(flags), double *ret_array) +{ + int ret = 0; + tsk_id_t u, v; + size_t j; + int32_t k; + /* We use the K'th element of the array for the total. 
*/ + const int32_t K = (int32_t) (num_reference_sets + 1); + size_t num_nodes = self->tables->nodes->num_rows; + const tsk_id_t num_edges = (tsk_id_t) self->tables->edges->num_rows; + const tsk_id_t *restrict I = self->tables->indexes.edge_insertion_order; + const tsk_id_t *restrict O = self->tables->indexes.edge_removal_order; + const double *restrict edge_left = self->tables->edges->left; + const double *restrict edge_right = self->tables->edges->right; + const tsk_id_t *restrict edge_parent = self->tables->edges->parent; + const tsk_id_t *restrict edge_child = self->tables->edges->child; + const double sequence_length = self->tables->sequence_length; + tsk_id_t tj, tk, h; + double left, right, length, *restrict C_row; + tsk_id_t *restrict parent = malloc(num_nodes * sizeof(*parent)); + uint32_t *restrict ref_count = calloc(num_nodes * ((size_t) K), sizeof(*ref_count)); + double *restrict last_update = calloc(num_nodes, sizeof(*last_update)); + double *restrict total_length = calloc(num_nodes, sizeof(*total_length)); + uint32_t *restrict row, *restrict child_row; + + if (num_reference_sets == 0 || num_reference_sets > (INT32_MAX - 1)) { + /* TODO: more specific error */ + ret = TSK_ERR_BAD_PARAM_VALUE; + goto out; + } + if (parent == NULL || ref_count == NULL || last_update == NULL + || total_length == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + /* TODO add check for duplicate values in the reference sets */ + + memset(parent, 0xff, num_nodes * sizeof(*parent)); + memset(ret_array, 0, num_nodes * num_reference_sets * sizeof(*ret_array)); + + /* Set the initial conditions and check the input. 
*/ + for (k = 0; k < (int32_t) num_reference_sets; k++) { + for (j = 0; j < reference_set_size[k]; j++) { + u = reference_sets[k][j]; + if (u < 0 || u >= (tsk_id_t) num_nodes) { + ret = TSK_ERR_NODE_OUT_OF_BOUNDS; + goto out; + } + row = GET_2D_ROW(ref_count, K, u); + row[k] = 1; + /* Also set the count for the total among all sets */ + row[K - 1] = 1; + } + } + + /* Iterate over the trees */ + tj = 0; + tk = 0; + left = 0; + while (tj < num_edges || left < sequence_length) { + while (tk < num_edges && edge_right[O[tk]] == left) { + h = O[tk]; + tk++; + u = edge_child[h]; + v = edge_parent[h]; + parent[u] = TSK_NULL; + child_row = GET_2D_ROW(ref_count, K, u); + while (v != TSK_NULL) { + row = GET_2D_ROW(ref_count, K, v); + if (last_update[v] != left) { + if (row[K - 1] > 0) { + length = left - last_update[v]; + C_row = GET_2D_ROW(ret_array, num_reference_sets, v); + for (k = 0; k < (int32_t) num_reference_sets; k++) { + C_row[k] += length * row[k]; + } + total_length[v] += length; + } + last_update[v] = left; + } + for (k = 0; k < K; k++) { + row[k] -= child_row[k]; + } + v = parent[v]; + } + } + while (tj < num_edges && edge_left[I[tj]] == left) { + h = I[tj]; + tj++; + u = edge_child[h]; + v = edge_parent[h]; + parent[u] = v; + child_row = GET_2D_ROW(ref_count, K, u); + while (v != TSK_NULL) { + row = GET_2D_ROW(ref_count, K, v); + if (last_update[v] != left) { + if (row[K - 1] > 0) { + length = left - last_update[v]; + C_row = GET_2D_ROW(ret_array, num_reference_sets, v); + for (k = 0; k < (int32_t) num_reference_sets; k++) { + C_row[k] += length * row[k]; + } + total_length[v] += length; + } + last_update[v] = left; + } + for (k = 0; k < K; k++) { + row[k] += child_row[k]; + } + v = parent[v]; + } + } + right = sequence_length; + if (tj < num_edges) { + right = TSK_MIN(right, edge_left[I[tj]]); + } + if (tk < num_edges) { + right = TSK_MIN(right, edge_right[O[tk]]); + } + left = right; + } + + /* Add the stats for the last tree and divide by the total length 
that + * each node was an ancestor to > 0 of the reference nodes. */ + for (v = 0; v < (tsk_id_t) num_nodes; v++) { + row = GET_2D_ROW(ref_count, K, v); + C_row = GET_2D_ROW(ret_array, num_reference_sets, v); + if (row[K - 1] > 0) { + length = sequence_length - last_update[v]; + total_length[v] += length; + for (k = 0; k < (int32_t) num_reference_sets; k++) { + C_row[k] += length * row[k]; + } + } + if (total_length[v] > 0) { + length = total_length[v]; + for (k = 0; k < (int32_t) num_reference_sets; k++) { + C_row[k] /= length; + } + } + } + +out: + /* Can't use msp_safe_free here because of restrict */ + if (parent != NULL) { + free(parent); + } + if (ref_count != NULL) { + free(ref_count); + } + if (last_update != NULL) { + free(last_update); + } + if (total_length != NULL) { + free(total_length); + } + return ret; +} + +int TSK_WARN_UNUSED +tsk_treeseq_get_node(tsk_treeseq_t *self, size_t index, tsk_node_t *node) +{ + return tsk_node_tbl_get_row(self->tables->nodes, index, node); +} + +int TSK_WARN_UNUSED +tsk_treeseq_get_edge(tsk_treeseq_t *self, size_t index, tsk_edge_t *edge) +{ + return tsk_edge_tbl_get_row(self->tables->edges, index, edge); +} + +int TSK_WARN_UNUSED +tsk_treeseq_get_migration(tsk_treeseq_t *self, size_t index, tsk_migration_t *migration) +{ + return tsk_migration_tbl_get_row(self->tables->migrations, index, migration); +} + +int TSK_WARN_UNUSED +tsk_treeseq_get_mutation(tsk_treeseq_t *self, size_t index, tsk_mutation_t *mutation) +{ + return tsk_mutation_tbl_get_row(self->tables->mutations, index, mutation); +} + +int TSK_WARN_UNUSED +tsk_treeseq_get_site(tsk_treeseq_t *self, size_t index, tsk_site_t *site) +{ + int ret = 0; + + ret = tsk_site_tbl_get_row(self->tables->sites, index, site); + if (ret != 0) { + goto out; + } + site->mutations = self->site_mutations[index]; + site->mutations_length = self->site_mutations_length[index]; +out: + return ret; +} + +int TSK_WARN_UNUSED +tsk_treeseq_get_individual(tsk_treeseq_t *self, size_t index, 
tsk_individual_t *individual) +{ + int ret = 0; + + ret = tsk_individual_tbl_get_row(self->tables->individuals, index, individual); + if (ret != 0) { + goto out; + } + individual->nodes = self->individual_nodes[index]; + individual->nodes_length = self->individual_nodes_length[index]; +out: + return ret; +} + +int TSK_WARN_UNUSED +tsk_treeseq_get_population(tsk_treeseq_t *self, size_t index, + tsk_population_t *population) +{ + return tsk_population_tbl_get_row(self->tables->populations, index, population); +} + +int TSK_WARN_UNUSED +tsk_treeseq_get_provenance(tsk_treeseq_t *self, size_t index, tsk_provenance_t *provenance) +{ + return tsk_provenance_tbl_get_row(self->tables->provenances, index, provenance); +} + +int TSK_WARN_UNUSED +tsk_treeseq_get_samples(tsk_treeseq_t *self, tsk_id_t **samples) +{ + *samples = self->samples; + return 0; +} + +int TSK_WARN_UNUSED +tsk_treeseq_get_sample_index_map(tsk_treeseq_t *self, tsk_id_t **sample_index_map) +{ + *sample_index_map = self->sample_index_map; + return 0; +} + +int TSK_WARN_UNUSED +tsk_treeseq_simplify(tsk_treeseq_t *self, tsk_id_t *samples, size_t num_samples, + int flags, tsk_treeseq_t *output, tsk_id_t *node_map) +{ + int ret = 0; + tsk_tbl_collection_t tables; + + ret = tsk_tbl_collection_alloc(&tables, 0); + if (ret != 0) { + goto out; + } + ret = tsk_treeseq_copy_tables(self, &tables); + if (ret != 0) { + goto out; + } + ret = tsk_tbl_collection_simplify(&tables, samples, num_samples, flags, node_map); + if (ret != 0) { + goto out; + } + ret = tsk_treeseq_alloc(output, &tables, TSK_BUILD_INDEXES); +out: + tsk_tbl_collection_free(&tables); + return ret; +} + +/* ======================================================== * + * Tree + * ======================================================== */ + +static int TSK_WARN_UNUSED +tsk_tree_clear(tsk_tree_t *self) +{ + int ret = 0; + size_t j; + tsk_id_t u; + const size_t N = self->num_nodes; + const size_t num_samples = self->tree_sequence->num_samples; + const bool 
sample_counts = !!(self->flags & TSK_SAMPLE_COUNTS); + const bool sample_lists = !!(self->flags & TSK_SAMPLE_LISTS); + + self->left = 0; + self->right = 0; + self->index = (size_t) -1; + /* TODO we should profile this method to see if just doing a single loop over + * the nodes would be more efficient than multiple memsets. + */ + memset(self->parent, 0xff, N * sizeof(tsk_id_t)); + memset(self->left_child, 0xff, N * sizeof(tsk_id_t)); + memset(self->right_child, 0xff, N * sizeof(tsk_id_t)); + memset(self->left_sib, 0xff, N * sizeof(tsk_id_t)); + memset(self->right_sib, 0xff, N * sizeof(tsk_id_t)); + memset(self->above_sample, 0, N * sizeof(bool)); + if (sample_counts) { + memset(self->num_samples, 0, N * sizeof(tsk_id_t)); + memset(self->marked, 0, N * sizeof(uint8_t)); + /* We can't reset the tracked samples via memset because we don't + * know where the tracked samples are. + */ + for (j = 0; j < self->num_nodes; j++) { + if (! tsk_treeseq_is_sample(self->tree_sequence, (tsk_id_t) j)) { + self->num_tracked_samples[j] = 0; + } + } + } + if (sample_lists) { + memset(self->left_sample, 0xff, N * sizeof(tsk_id_t)); + memset(self->right_sample, 0xff, N * sizeof(tsk_id_t)); + memset(self->next_sample, 0xff, num_samples * sizeof(tsk_id_t)); + } + /* Set the sample attributes */ + self->left_root = TSK_NULL; + if (num_samples > 0) { + self->left_root = self->samples[0]; + } + for (j = 0; j < num_samples; j++) { + u = self->samples[j]; + self->above_sample[u] = true; + if (sample_counts) { + self->num_samples[u] = 1; + } + if (sample_lists) { + /* We are mapping to *indexes* into the list of samples here */ + self->left_sample[u] = (tsk_id_t) j; + self->right_sample[u] = (tsk_id_t) j; + } + /* Set initial roots */ + if (j < num_samples - 1) { + self->right_sib[u] = self->samples[j + 1]; + } + if (j > 0) { + self->left_sib[u] = self->samples[j - 1]; + } + } + return ret; +} + +int TSK_WARN_UNUSED +tsk_tree_alloc(tsk_tree_t *self, tsk_treeseq_t *tree_sequence, int flags) +{ 
+ int ret = TSK_ERR_NO_MEMORY; + size_t num_samples; + size_t num_nodes; + + memset(self, 0, sizeof(tsk_tree_t)); + if (tree_sequence == NULL) { + ret = TSK_ERR_BAD_PARAM_VALUE; + goto out; + } + num_nodes = tree_sequence->tables->nodes->num_rows; + num_samples = tree_sequence->num_samples; + self->num_nodes = num_nodes; + self->tree_sequence = tree_sequence; + self->samples = tree_sequence->samples; + self->flags = flags; + self->parent = malloc(num_nodes * sizeof(tsk_id_t)); + self->left_child = malloc(num_nodes * sizeof(tsk_id_t)); + self->right_child = malloc(num_nodes * sizeof(tsk_id_t)); + self->left_sib = malloc(num_nodes * sizeof(tsk_id_t)); + self->right_sib = malloc(num_nodes * sizeof(tsk_id_t)); + self->above_sample = malloc(num_nodes * sizeof(bool)); + if (self->parent == NULL || self->left_child == NULL || self->right_child == NULL + || self->left_sib == NULL || self->right_sib == NULL + || self->above_sample == NULL) { + goto out; + } + /* the maximum possible height of the tree is num_nodes + 1, including + * the null value. 
 */
    self->stack1 = malloc((num_nodes + 1) * sizeof(tsk_id_t));
    self->stack2 = malloc((num_nodes + 1) * sizeof(tsk_id_t));
    if (self->stack1 == NULL || self->stack2 == NULL) {
        goto out;
    }
    /* Optional per-node sample counting state (TSK_SAMPLE_COUNTS). */
    if (!!(self->flags & TSK_SAMPLE_COUNTS)) {
        self->num_samples = calloc(num_nodes, sizeof(tsk_id_t));
        self->num_tracked_samples = calloc(num_nodes, sizeof(tsk_id_t));
        self->marked = calloc(num_nodes, sizeof(uint8_t));
        if (self->num_samples == NULL || self->num_tracked_samples == NULL
                || self->marked == NULL) {
            goto out;
        }
    }
    /* Optional per-node sample list state (TSK_SAMPLE_LISTS). */
    if (!!(self->flags & TSK_SAMPLE_LISTS)) {
        self->left_sample = malloc(num_nodes * sizeof(*self->left_sample));
        self->right_sample = malloc(num_nodes * sizeof(*self->right_sample));
        self->next_sample = malloc(num_samples * sizeof(*self->next_sample));
        if (self->left_sample == NULL || self->right_sample == NULL
                || self->next_sample == NULL) {
            goto out;
        }
    }
    ret = tsk_tree_clear(self);
out:
    return ret;
}

/* Free all memory owned by this tree. Safe to call on a partially
 * allocated tree (tsk_safe_free tolerates NULL members). Always returns 0. */
int TSK_WARN_UNUSED
tsk_tree_free(tsk_tree_t *self)
{
    tsk_safe_free(self->parent);
    tsk_safe_free(self->left_child);
    tsk_safe_free(self->right_child);
    tsk_safe_free(self->left_sib);
    tsk_safe_free(self->right_sib);
    tsk_safe_free(self->above_sample);
    tsk_safe_free(self->stack1);
    tsk_safe_free(self->stack2);
    tsk_safe_free(self->num_samples);
    tsk_safe_free(self->num_tracked_samples);
    tsk_safe_free(self->marked);
    tsk_safe_free(self->left_sample);
    tsk_safe_free(self->right_sample);
    tsk_safe_free(self->next_sample);
    return 0;
}

/* True iff this tree was allocated with per-node sample list tracking. */
bool
tsk_tree_has_sample_lists(tsk_tree_t *self)
{
    return !!(self->flags & TSK_SAMPLE_LISTS);
}

/* True iff this tree was allocated with per-node sample counting. */
bool
tsk_tree_has_sample_counts(tsk_tree_t *self)
{
    return !!(self->flags & TSK_SAMPLE_COUNTS);
}

/* Zero all tracked-sample counts. Requires TSK_SAMPLE_COUNTS;
 * returns TSK_ERR_UNSUPPORTED_OPERATION otherwise. */
static int TSK_WARN_UNUSED
tsk_tree_reset_tracked_samples(tsk_tree_t *self)
{
    int ret = 0;

    if (!tsk_tree_has_sample_counts(self)) {
        ret = TSK_ERR_UNSUPPORTED_OPERATION;
        goto out;
    }
    memset(self->num_tracked_samples, 0, self->num_nodes * sizeof(tsk_id_t));
out:
    return ret;
}

/* Replace the set of tracked samples with the `num_tracked_samples` node IDs
 * in `tracked_samples`, propagating a +1 count up every ancestor path.
 * Errors: TSK_ERR_NODE_OUT_OF_BOUNDS, TSK_ERR_BAD_SAMPLES (node is not a
 * sample), TSK_ERR_DUPLICATE_SAMPLE, or TSK_ERR_UNSUPPORTED_OPERATION when
 * sample counts are not maintained. */
int TSK_WARN_UNUSED
tsk_tree_set_tracked_samples(tsk_tree_t *self, size_t num_tracked_samples,
    tsk_id_t *tracked_samples)
{
    int ret = TSK_ERR_GENERIC;
    size_t j;
    tsk_id_t u;

    /* TODO This is not needed when the sparse tree is new. We should use the
     * state machine to check and only reset the tracked samples when needed.
     */
    ret = tsk_tree_reset_tracked_samples(self);
    if (ret != 0) {
        goto out;
    }
    for (j = 0; j < num_tracked_samples; j++) {
        u = tracked_samples[j];
        if (u < 0 || u >= (tsk_id_t) self->num_nodes) {
            ret = TSK_ERR_NODE_OUT_OF_BOUNDS;
            goto out;
        }
        if (! tsk_treeseq_is_sample(self->tree_sequence, u)) {
            ret = TSK_ERR_BAD_SAMPLES;
            goto out;
        }
        if (self->num_tracked_samples[u] != 0) {
            ret = TSK_ERR_DUPLICATE_SAMPLE;
            goto out;
        }
        /* Propagate this upwards */
        while (u != TSK_NULL) {
            self->num_tracked_samples[u] += 1;
            u = self->parent[u];
        }
    }
out:
    return ret;
}

/* Set this tree's tracked samples to the samples below `node` in `other`,
 * walking other's left_sample/right_sample/next_sample list. Requires
 * sample lists on `other` and sample counts on self. */
int TSK_WARN_UNUSED
tsk_tree_set_tracked_samples_from_sample_list(tsk_tree_t *self,
    tsk_tree_t *other, tsk_id_t node)
{
    int ret = TSK_ERR_GENERIC;
    tsk_id_t u, stop, index;
    const tsk_id_t *next = other->next_sample;
    const tsk_id_t *samples = other->tree_sequence->samples;

    if (! tsk_tree_has_sample_lists(other)) {
        ret = TSK_ERR_UNSUPPORTED_OPERATION;
        goto out;
    }
    /* TODO This is not needed when the sparse tree is new. We should use the
     * state machine to check and only reset the tracked samples when needed.
+ */ + ret = tsk_tree_reset_tracked_samples(self); + if (ret != 0) { + goto out; + } + + index = other->left_sample[node]; + if (index != TSK_NULL) { + stop = other->right_sample[node]; + while (true) { + u = samples[index]; + assert(self->num_tracked_samples[u] == 0); + /* Propagate this upwards */ + while (u != TSK_NULL) { + self->num_tracked_samples[u] += 1; + u = self->parent[u]; + } + if (index == stop) { + break; + } + index = next[index]; + } + } +out: + return ret; +} + +int TSK_WARN_UNUSED +tsk_tree_copy(tsk_tree_t *self, tsk_tree_t *source) +{ + int ret = TSK_ERR_GENERIC; + size_t N = self->num_nodes; + + if (self == source) { + ret = TSK_ERR_BAD_PARAM_VALUE; + goto out; + } + if (self->tree_sequence != source->tree_sequence) { + ret = TSK_ERR_BAD_PARAM_VALUE; + goto out; + } + self->left = source->left; + self->right = source->right; + self->left_root = source->left_root; + self->index = source->index; + self->sites = source->sites; + self->sites_length = source->sites_length; + + memcpy(self->parent, source->parent, N * sizeof(tsk_id_t)); + memcpy(self->left_child, source->left_child, N * sizeof(tsk_id_t)); + memcpy(self->right_child, source->right_child, N * sizeof(tsk_id_t)); + memcpy(self->left_sib, source->left_sib, N * sizeof(tsk_id_t)); + memcpy(self->right_sib, source->right_sib, N * sizeof(tsk_id_t)); + if (self->flags & TSK_SAMPLE_COUNTS) { + if (! (source->flags & TSK_SAMPLE_COUNTS)) { + ret = TSK_ERR_UNSUPPORTED_OPERATION; + goto out; + } + memcpy(self->num_samples, source->num_samples, N * sizeof(tsk_id_t)); + } + if (self->flags & TSK_SAMPLE_LISTS) { + ret = TSK_ERR_UNSUPPORTED_OPERATION; + goto out; + } + ret = 0; +out: + return ret; +} + +/* Returns 0 if the specified sparse trees are equal, 1 if they are + * not equal, and < 0 if an error occurs. + * + * We only consider topological properties of the tree. Optional + * counts and sample lists are not considered for equality. 
 */
int TSK_WARN_UNUSED
tsk_tree_equal(tsk_tree_t *self, tsk_tree_t *other)
{
    int ret = 1;
    int condition;
    size_t N = self->num_nodes;

    if (self->tree_sequence != other->tree_sequence) {
        /* It is an error to compare trees from different tree sequences. */
        ret = TSK_ERR_BAD_PARAM_VALUE;
        goto out;
    }
    /* Equality is the tree index, interval, site list and the full parent
     * array; returns 0 when equal, 1 when not equal. */
    condition = self->index == other->index
        && self->left == other->left
        && self->right == other->right
        && self->sites_length == other->sites_length
        && self->sites == other->sites
        && memcmp(self->parent, other->parent, N * sizeof(tsk_id_t)) == 0;
    /* We do not check the children for equality here because
     * the ordering of the children within a parent are essentially irrelevant
     * in terms of topology. Depending on the way in which we approach a given
     * tree we can get different orderings within the children, and so the
     * same tree would not be equal to itself. */
    if (condition) {
        ret = 0;
    }
out:
    return ret;
}

/* Return 0 if u is a valid node ID for this tree, or
 * TSK_ERR_NODE_OUT_OF_BOUNDS otherwise. */
static int
tsk_tree_check_node(tsk_tree_t *self, tsk_id_t u)
{
    int ret = 0;
    if (u < 0 || u >= (tsk_id_t) self->num_nodes) {
        ret = TSK_ERR_NODE_OUT_OF_BOUNDS;
    }
    return ret;
}

/* Compute the most recent common ancestor of u and v by materialising both
 * root-paths on stack1/stack2 (TSK_NULL-terminated) and popping matching
 * suffixes; *mrca is TSK_NULL when u and v are in different root subtrees. */
int TSK_WARN_UNUSED
tsk_tree_get_mrca(tsk_tree_t *self, tsk_id_t u, tsk_id_t v,
    tsk_id_t *mrca)
{
    int ret = 0;
    tsk_id_t w = 0;
    tsk_id_t *s1 = self->stack1;
    tsk_id_t *s2 = self->stack2;
    tsk_id_t j;
    int l1, l2;

    ret = tsk_tree_check_node(self, u);
    if (ret != 0) {
        goto out;
    }
    ret = tsk_tree_check_node(self, v);
    if (ret != 0) {
        goto out;
    }
    /* Path from u up to its root. */
    j = u;
    l1 = 0;
    while (j != TSK_NULL) {
        assert(l1 < (int) self->num_nodes);
        s1[l1] = j;
        l1++;
        j = self->parent[j];
    }
    s1[l1] = TSK_NULL;
    /* Path from v up to its root. */
    j = v;
    l2 = 0;
    while (j != TSK_NULL) {
        assert(l2 < (int) self->num_nodes);
        s2[l2] = j;
        l2++;
        j = self->parent[j];
    }
    s2[l2] = TSK_NULL;
    /* Walk down both paths from the roots; w trails the deepest shared node. */
    do {
        w = s1[l1];
        l1--;
        l2--;
    } while (l1 >= 0 && l2 >= 0 && s1[l1] == s2[l2]);
    *mrca = w;
    ret = 0;
out:
    return ret;
}

/* Count the samples below u by an explicit preorder traversal using stack1.
 * Fallback for trees without TSK_SAMPLE_COUNTS. */
static int
tsk_tree_get_num_samples_by_traversal(tsk_tree_t *self, tsk_id_t u,
    size_t *num_samples)
{
    int ret = 0;
    tsk_id_t *stack = self->stack1;
    tsk_id_t v;
    size_t count = 0;
    int stack_top = 0;

    stack[0] = u;
    while (stack_top >= 0) {
        v = stack[stack_top];
        stack_top--;
        if (tsk_treeseq_is_sample(self->tree_sequence, v)) {
            count++;
        }
        v = self->left_child[v];
        while (v != TSK_NULL) {
            stack_top++;
            stack[stack_top] = v;
            v = self->right_sib[v];
        }
    }
    *num_samples = count;
    return ret;
}

/* Store in *num_samples the number of samples below u, using the O(1)
 * maintained counts when available and a traversal otherwise. */
int TSK_WARN_UNUSED
tsk_tree_get_num_samples(tsk_tree_t *self, tsk_id_t u, size_t *num_samples)
{
    int ret = 0;

    ret = tsk_tree_check_node(self, u);
    if (ret != 0) {
        goto out;
    }

    if (self->flags & TSK_SAMPLE_COUNTS) {
        *num_samples = (size_t) self->num_samples[u];
    } else {
        ret = tsk_tree_get_num_samples_by_traversal(self, u, num_samples);
    }
out:
    return ret;
}

/* Store in *num_tracked_samples the number of tracked samples below u.
 * Requires TSK_SAMPLE_COUNTS. */
int TSK_WARN_UNUSED
tsk_tree_get_num_tracked_samples(tsk_tree_t *self, tsk_id_t u,
    size_t *num_tracked_samples)
{
    int ret = 0;

    ret = tsk_tree_check_node(self, u);
    if (ret != 0) {
        goto out;
    }
    if (!
(self->flags & TSK_SAMPLE_COUNTS)) { + ret = TSK_ERR_UNSUPPORTED_OPERATION; + goto out; + } + *num_tracked_samples = (size_t) self->num_tracked_samples[u]; +out: + return ret; +} + +bool +tsk_tree_is_sample(tsk_tree_t *self, tsk_id_t u) +{ + return tsk_treeseq_is_sample(self->tree_sequence, u); +} + +size_t +tsk_tree_get_num_roots(tsk_tree_t *self) +{ + size_t num_roots = 0; + tsk_id_t u = self->left_root; + + while (u != TSK_NULL) { + u = self->right_sib[u]; + num_roots++; + } + return num_roots; +} + +int TSK_WARN_UNUSED +tsk_tree_get_parent(tsk_tree_t *self, tsk_id_t u, tsk_id_t *parent) +{ + int ret = 0; + + ret = tsk_tree_check_node(self, u); + if (ret != 0) { + goto out; + } + *parent = self->parent[u]; +out: + return ret; +} + +int TSK_WARN_UNUSED +tsk_tree_get_time(tsk_tree_t *self, tsk_id_t u, double *t) +{ + int ret = 0; + tsk_node_t node; + + ret = tsk_treeseq_get_node(self->tree_sequence, (size_t) u, &node); + if (ret != 0) { + goto out; + } + *t = node.time; +out: + return ret; +} + +int TSK_WARN_UNUSED +tsk_tree_get_sites(tsk_tree_t *self, tsk_site_t **sites, tsk_tbl_size_t *sites_length) +{ + *sites = self->sites; + *sites_length = self->sites_length; + return 0; +} + +static void +tsk_tree_check_state(tsk_tree_t *self) +{ + tsk_id_t u, v; + size_t j, num_samples; + int err, c; + tsk_site_t site; + tsk_id_t *children = malloc(self->num_nodes * sizeof(tsk_id_t)); + bool *is_root = calloc(self->num_nodes, sizeof(bool)); + + assert(children != NULL); + + for (j = 0; j < self->tree_sequence->num_samples; j++) { + u = self->samples[j]; + while (self->parent[u] != TSK_NULL) { + u = self->parent[u]; + } + is_root[u] = true; + } + if (self->tree_sequence->num_samples == 0) { + assert(self->left_root == TSK_NULL); + } else { + assert(self->left_sib[self->left_root] == TSK_NULL); + } + /* Iterate over the roots and make sure they are set */ + for (u = self->left_root; u != TSK_NULL; u = self->right_sib[u]) { + assert(is_root[u]); + is_root[u] = false; + } + 
for (u = 0; u < (tsk_id_t) self->num_nodes; u++) { + assert(!is_root[u]); + c = 0; + for (v = self->left_child[u]; v != TSK_NULL; v = self->right_sib[v]) { + assert(self->parent[v] == u); + children[c] = v; + c++; + } + for (v = self->right_child[u]; v != TSK_NULL; v = self->left_sib[v]) { + assert(c > 0); + c--; + assert(v == children[c]); + } + } + for (j = 0; j < self->sites_length; j++) { + site = self->sites[j]; + assert(self->left <= site.position); + assert(site.position < self->right); + } + + if (self->flags & TSK_SAMPLE_COUNTS) { + assert(self->num_samples != NULL); + assert(self->num_tracked_samples != NULL); + for (u = 0; u < (tsk_id_t) self->num_nodes; u++) { + err = tsk_tree_get_num_samples_by_traversal(self, u, &num_samples); + assert(err == 0); + assert(num_samples == (size_t) self->num_samples[u]); + } + } else { + assert(self->num_samples == NULL); + assert(self->num_tracked_samples == NULL); + } + if (self->flags & TSK_SAMPLE_LISTS) { + assert(self->right_sample != NULL); + assert(self->left_sample != NULL); + assert(self->next_sample != NULL); + } else { + assert(self->right_sample == NULL); + assert(self->left_sample == NULL); + assert(self->next_sample == NULL); + } + + free(children); + free(is_root); +} + +void +tsk_tree_print_state(tsk_tree_t *self, FILE *out) +{ + size_t j; + tsk_site_t site; + + fprintf(out, "Sparse tree state:\n"); + fprintf(out, "flags = %d\n", self->flags); + fprintf(out, "left = %f\n", self->left); + fprintf(out, "right = %f\n", self->right); + fprintf(out, "left_root = %d\n", (int) self->left_root); + fprintf(out, "index = %d\n", (int) self->index); + fprintf(out, "node\tparent\tlchild\trchild\tlsib\trsib"); + if (self->flags & TSK_SAMPLE_LISTS) { + fprintf(out, "\thead\ttail"); + } + fprintf(out, "\n"); + + for (j = 0; j < self->num_nodes; j++) { + fprintf(out, "%d\t%d\t%d\t%d\t%d\t%d", (int) j, self->parent[j], self->left_child[j], + self->right_child[j], self->left_sib[j], self->right_sib[j]); + if (self->flags & 
TSK_SAMPLE_LISTS) { + fprintf(out, "\t%d\t%d\t", self->left_sample[j], + self->right_sample[j]); + } + if (self->flags & TSK_SAMPLE_COUNTS) { + fprintf(out, "\t%d\t%d\t%d", (int) self->num_samples[j], + (int) self->num_tracked_samples[j], self->marked[j]); + } + fprintf(out, "\n"); + } + fprintf(out, "sites = \n"); + for (j = 0; j < self->sites_length; j++) { + site = self->sites[j]; + fprintf(out, "\t%d\t%f\n", site.id, site.position); + } + tsk_tree_check_state(self); +} + +/* Methods for positioning the tree along the sequence */ + +static inline void +tsk_tree_propagate_sample_count_loss(tsk_tree_t *self, tsk_id_t parent, + tsk_id_t child) +{ + tsk_id_t v; + const tsk_id_t all_samples_diff = self->num_samples[child]; + const tsk_id_t tracked_samples_diff = self->num_tracked_samples[child]; + const uint8_t mark = self->mark; + const tsk_id_t * restrict tree_parent = self->parent; + tsk_id_t * restrict num_samples = self->num_samples; + tsk_id_t * restrict num_tracked_samples = self->num_tracked_samples; + uint8_t * restrict marked = self->marked; + + /* propagate this loss up as far as we can */ + v = parent; + while (v != TSK_NULL) { + num_samples[v] -= all_samples_diff; + num_tracked_samples[v] -= tracked_samples_diff; + marked[v] = mark; + v = tree_parent[v]; + } +} + +static inline void +tsk_tree_propagate_sample_count_gain(tsk_tree_t *self, tsk_id_t parent, + tsk_id_t child) +{ + tsk_id_t v; + const tsk_id_t all_samples_diff = self->num_samples[child]; + const tsk_id_t tracked_samples_diff = self->num_tracked_samples[child]; + const uint8_t mark = self->mark; + const tsk_id_t * restrict tree_parent = self->parent; + tsk_id_t * restrict num_samples = self->num_samples; + tsk_id_t * restrict num_tracked_samples = self->num_tracked_samples; + uint8_t * restrict marked = self->marked; + + /* propogate this gain up as far as we can */ + v = parent; + while (v != TSK_NULL) { + num_samples[v] += all_samples_diff; + num_tracked_samples[v] += tracked_samples_diff; + 
marked[v] = mark; + v = tree_parent[v]; + } +} + +static inline void +tsk_tree_update_sample_lists(tsk_tree_t *self, tsk_id_t node) +{ + tsk_id_t u, v, sample_index; + tsk_id_t * restrict left = self->left_sample; + tsk_id_t * restrict right = self->right_sample; + tsk_id_t * restrict next = self->next_sample; + const tsk_id_t * restrict left_child = self->left_child; + const tsk_id_t * restrict right_sib = self->right_sib; + const tsk_id_t * restrict parent = self->parent; + const tsk_id_t * restrict sample_index_map = self->tree_sequence->sample_index_map; + + for (u = node; u != TSK_NULL; u = parent[u]) { + sample_index = sample_index_map[u]; + if (sample_index != TSK_NULL) { + right[u] = left[u]; + } else { + left[u] = TSK_NULL; + right[u] = TSK_NULL; + } + for (v = left_child[u]; v != TSK_NULL; v = right_sib[v]) { + if (left[v] != TSK_NULL) { + assert(right[v] != TSK_NULL); + if (left[u] == TSK_NULL) { + left[u] = left[v]; + right[u] = right[v]; + } else { + next[right[u]] = left[v]; + right[u] = right[v]; + } + } + } + } +} + +static int +tsk_tree_advance(tsk_tree_t *self, int direction, + const double * restrict out_breakpoints, + const tsk_id_t * restrict out_order, + tsk_id_t *out_index, + const double * restrict in_breakpoints, + const tsk_id_t * restrict in_order, + tsk_id_t *in_index) +{ + int ret = 0; + const int direction_change = direction * (direction != self->direction); + tsk_id_t in = *in_index + direction_change; + tsk_id_t out = *out_index + direction_change; + tsk_id_t k, p, c, u, v, root, lsib, rsib, lroot, rroot; + tsk_tbl_collection_t *tables = self->tree_sequence->tables; + const double sequence_length = tables->sequence_length; + const tsk_id_t num_edges = (tsk_id_t) tables->edges->num_rows; + const tsk_id_t * restrict edge_parent = tables->edges->parent; + const tsk_id_t * restrict edge_child = tables->edges->child; + const uint32_t * restrict node_flags = tables->nodes->flags; + double x; + bool above_sample; + + if (direction == 
TSK_DIR_FORWARD) { + x = self->right; + } else { + x = self->left; + } + while (out >= 0 && out < num_edges && out_breakpoints[out_order[out]] == x) { + assert(out < num_edges); + k = out_order[out]; + out += direction; + p = edge_parent[k]; + c = edge_child[k]; + lsib = self->left_sib[c]; + rsib = self->right_sib[c]; + if (lsib == TSK_NULL) { + self->left_child[p] = rsib; + } else { + self->right_sib[lsib] = rsib; + } + if (rsib == TSK_NULL) { + self->right_child[p] = lsib; + } else { + self->left_sib[rsib] = lsib; + } + self->parent[c] = TSK_NULL; + self->left_sib[c] = TSK_NULL; + self->right_sib[c] = TSK_NULL; + if (self->flags & TSK_SAMPLE_COUNTS) { + tsk_tree_propagate_sample_count_loss(self, p, c); + } + if (self->flags & TSK_SAMPLE_LISTS) { + tsk_tree_update_sample_lists(self, p); + } + + /* Update the roots. If c is not above a sample then we have nothing to do + * as we cannot affect the status of any roots. */ + if (self->above_sample[c]) { + /* Compute the new above sample status for the nodes from p up to root. */ + v = p; + root = v; + above_sample = false; + while (v != TSK_NULL && !above_sample) { + above_sample = !!(node_flags[v] & TSK_NODE_IS_SAMPLE); + u = self->left_child[v]; + while (u != TSK_NULL && !above_sample) { + above_sample = above_sample || self->above_sample[u]; + u = self->right_sib[u]; + } + self->above_sample[v] = above_sample; + root = v; + v = self->parent[v]; + } + if (!above_sample) { + /* root is no longer above samples. 
Remove it from the root list */ + lroot = self->left_sib[root]; + rroot = self->right_sib[root]; + self->left_root = TSK_NULL; + if (lroot != TSK_NULL) { + self->right_sib[lroot] = rroot; + self->left_root = lroot; + } + if (rroot != TSK_NULL) { + self->left_sib[rroot] = lroot; + self->left_root = rroot; + } + self->left_sib[root] = TSK_NULL; + self->right_sib[root] = TSK_NULL; + } + /* Add c to the root list */ + if (self->left_root != TSK_NULL) { + lroot = self->left_sib[self->left_root]; + if (lroot != TSK_NULL) { + self->right_sib[lroot] = c; + } + self->left_sib[c] = lroot; + self->left_sib[self->left_root] = c; + } + self->right_sib[c] = self->left_root; + self->left_root = c; + } + } + + while (in >= 0 && in < num_edges && in_breakpoints[in_order[in]] == x) { + k = in_order[in]; + in += direction; + p = edge_parent[k]; + c = edge_child[k]; + if (self->parent[c] != TSK_NULL) { + ret = TSK_ERR_BAD_EDGES_CONTRADICTORY_CHILDREN; + goto out; + } + self->parent[c] = p; + u = self->right_child[p]; + lsib = self->left_sib[c]; + rsib = self->right_sib[c]; + if (u == TSK_NULL) { + self->left_child[p] = c; + self->left_sib[c] = TSK_NULL; + self->right_sib[c] = TSK_NULL; + } else { + self->right_sib[u] = c; + self->left_sib[c] = u; + self->right_sib[c] = TSK_NULL; + } + self->right_child[p] = c; + if (self->flags & TSK_SAMPLE_COUNTS) { + tsk_tree_propagate_sample_count_gain(self, p, c); + } + if (self->flags & TSK_SAMPLE_LISTS) { + tsk_tree_update_sample_lists(self, p); + } + + /* Update the roots. */ + if (self->above_sample[c]) { + v = p; + root = v; + above_sample = false; + while (v != TSK_NULL && !above_sample) { + above_sample = self->above_sample[v]; + self->above_sample[v] = self->above_sample[v] || self->above_sample[c]; + root = v; + v = self->parent[v]; + } + if (! 
above_sample) { + /* Replace c with root in root list */ + if (lsib != TSK_NULL) { + self->right_sib[lsib] = root; + } + if (rsib != TSK_NULL) { + self->left_sib[rsib] = root; + } + self->left_sib[root] = lsib; + self->right_sib[root] = rsib; + self->left_root = root; + } else { + /* Remove c from root list */ + self->left_root = TSK_NULL; + if (lsib != TSK_NULL) { + self->right_sib[lsib] = rsib; + self->left_root = lsib; + } + if (rsib != TSK_NULL) { + self->left_sib[rsib] = lsib; + self->left_root = rsib; + } + } + } + } + + if (self->left_root != TSK_NULL) { + /* Ensure that left_root is the left-most root */ + while (self->left_sib[self->left_root] != TSK_NULL) { + self->left_root = self->left_sib[self->left_root]; + } + } + + self->direction = direction; + self->index = (size_t) ((int64_t) self->index + direction); + if (direction == TSK_DIR_FORWARD) { + self->left = x; + self->right = sequence_length; + if (out >= 0 && out < num_edges) { + self->right = TSK_MIN(self->right, out_breakpoints[out_order[out]]); + } + if (in >= 0 && in < num_edges) { + self->right = TSK_MIN(self->right, in_breakpoints[in_order[in]]); + } + } else { + self->right = x; + self->left = 0; + if (out >= 0 && out < num_edges) { + self->left = TSK_MAX(self->left, out_breakpoints[out_order[out]]); + } + if (in >= 0 && in < num_edges) { + self->left = TSK_MAX(self->left, in_breakpoints[in_order[in]]); + } + } + assert(self->left < self->right); + *out_index = out; + *in_index = in; + if (tables->sites->num_rows > 0) { + self->sites = self->tree_sequence->tree_sites[self->index]; + self->sites_length = self->tree_sequence->tree_sites_length[self->index]; + } + ret = 1; +out: + return ret; +} + +int TSK_WARN_UNUSED +tsk_tree_first(tsk_tree_t *self) +{ + int ret = 1; + tsk_tbl_collection_t *tables = self->tree_sequence->tables; + + self->left = 0; + self->index = 0; + self->right = tables->sequence_length; + self->sites = self->tree_sequence->tree_sites[0]; + self->sites_length = 
self->tree_sequence->tree_sites_length[0]; + + if (tables->edges->num_rows > 0) { + /* TODO this is redundant if this is the first usage of the tree. We + * should add a state machine here so we know what state the tree is + * in and can take the appropriate actions. + */ + ret = tsk_tree_clear(self); + if (ret != 0) { + goto out; + } + self->index = (size_t) -1; + self->left_index = 0; + self->right_index = 0; + self->direction = TSK_DIR_FORWARD; + self->right = 0; + + ret = tsk_tree_advance(self, TSK_DIR_FORWARD, + tables->edges->right, tables->indexes.edge_removal_order, + &self->right_index, tables->edges->left, + tables->indexes.edge_insertion_order, &self->left_index); + } +out: + return ret; +} + +int TSK_WARN_UNUSED +tsk_tree_last(tsk_tree_t *self) +{ + int ret = 1; + tsk_treeseq_t *ts = self->tree_sequence; + const tsk_tbl_collection_t *tables = ts->tables; + + self->left = 0; + self->right = tables->sequence_length; + self->index = 0; + self->sites = ts->tree_sites[0]; + self->sites_length = ts->tree_sites_length[0]; + + if (tables->edges->num_rows > 0) { + /* TODO this is redundant if this is the first usage of the tree. We + * should add a state machine here so we know what state the tree is + * in and can take the appropriate actions. 
 */
        ret = tsk_tree_clear(self);
        if (ret != 0) {
            goto out;
        }
        /* Position one past the last tree and seek backwards into it. */
        self->index = tsk_treeseq_get_num_trees(ts);
        self->left_index = (tsk_id_t) tables->edges->num_rows - 1;
        self->right_index = (tsk_id_t) tables->edges->num_rows - 1;
        self->direction = TSK_DIR_REVERSE;
        self->left = tables->sequence_length;
        self->right = 0;

        ret = tsk_tree_advance(self, TSK_DIR_REVERSE,
            tables->edges->left, tables->indexes.edge_insertion_order,
            &self->left_index, tables->edges->right,
            tables->indexes.edge_removal_order, &self->right_index);
    }
out:
    return ret;
}

/* Advance to the next tree. Returns 1 on moving to a new tree, 0 if this is
 * already the last tree (no-op), or a negative error code. */
int TSK_WARN_UNUSED
tsk_tree_next(tsk_tree_t *self)
{
    int ret = 0;
    tsk_treeseq_t *ts = self->tree_sequence;
    const tsk_tbl_collection_t *tables = ts->tables;
    size_t num_trees = tsk_treeseq_get_num_trees(ts);

    if (self->index < num_trees - 1) {
        ret = tsk_tree_advance(self, TSK_DIR_FORWARD,
            tables->edges->right, tables->indexes.edge_removal_order,
            &self->right_index, tables->edges->left,
            tables->indexes.edge_insertion_order, &self->left_index);
    }
    return ret;
}

/* Move back to the previous tree. Returns 1 on moving to a new tree, 0 if
 * this is already the first tree (no-op), or a negative error code. */
int TSK_WARN_UNUSED
tsk_tree_prev(tsk_tree_t *self)
{
    int ret = 0;
    const tsk_tbl_collection_t *tables = self->tree_sequence->tables;

    if (self->index > 0) {
        ret = tsk_tree_advance(self, TSK_DIR_REVERSE,
            tables->edges->left, tables->indexes.edge_insertion_order,
            &self->left_index, tables->edges->right,
            tables->indexes.edge_removal_order, &self->right_index);
    }
    return ret;
}

/* ======================================================== *
 * Tree diff iterator.
+ * ======================================================== */ + +int TSK_WARN_UNUSED +tsk_diff_iter_alloc(tsk_diff_iter_t *self, tsk_treeseq_t *tree_sequence) +{ + int ret = 0; + + assert(tree_sequence != NULL); + memset(self, 0, sizeof(tsk_diff_iter_t)); + self->num_nodes = tsk_treeseq_get_num_nodes(tree_sequence); + self->num_edges = tsk_treeseq_get_num_edges(tree_sequence); + self->tree_sequence = tree_sequence; + self->insertion_index = 0; + self->removal_index = 0; + self->tree_left = 0; + self->tree_index = (size_t) -1; + self->edge_list_nodes = malloc(self->num_edges * sizeof(tsk_edge_list_t)); + if (self->edge_list_nodes == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } +out: + return ret; +} + +int TSK_WARN_UNUSED +tsk_diff_iter_free(tsk_diff_iter_t *self) +{ + int ret = 0; + tsk_safe_free(self->edge_list_nodes); + return ret; +} + +void +tsk_diff_iter_print_state(tsk_diff_iter_t *self, FILE *out) +{ + fprintf(out, "tree_diff_iterator state\n"); + fprintf(out, "num_edges = %d\n", (int) self->num_edges); + fprintf(out, "insertion_index = %d\n", (int) self->insertion_index); + fprintf(out, "removal_index = %d\n", (int) self->removal_index); + fprintf(out, "tree_left = %f\n", self->tree_left); + fprintf(out, "tree_index = %d\n", (int) self->tree_index); +} + +int TSK_WARN_UNUSED +tsk_diff_iter_next(tsk_diff_iter_t *self, double *ret_left, double *ret_right, + tsk_edge_list_t **edges_out, tsk_edge_list_t **edges_in) +{ + int ret = 0; + tsk_id_t k; + const double sequence_length = self->tree_sequence->tables->sequence_length; + double left = self->tree_left; + double right = sequence_length; + size_t next_edge_list_node = 0; + tsk_treeseq_t *s = self->tree_sequence; + tsk_edge_list_t *out_head = NULL; + tsk_edge_list_t *out_tail = NULL; + tsk_edge_list_t *in_head = NULL; + tsk_edge_list_t *in_tail = NULL; + tsk_edge_list_t *w = NULL; + size_t num_trees = tsk_treeseq_get_num_trees(s); + const tsk_edge_tbl_t *edges = s->tables->edges; + const tsk_id_t 
*insertion_order = s->tables->indexes.edge_insertion_order; + const tsk_id_t *removal_order = s->tables->indexes.edge_removal_order; + + if (self->tree_index + 1 < num_trees) { + /* First we remove the stale records */ + while (self->removal_index < self->num_edges && + left == edges->right[removal_order[self->removal_index]]) { + k = removal_order[self->removal_index]; + assert(next_edge_list_node < self->num_edges); + w = &self->edge_list_nodes[next_edge_list_node]; + next_edge_list_node++; + w->edge.left = edges->left[k]; + w->edge.right = edges->right[k]; + w->edge.parent = edges->parent[k]; + w->edge.child = edges->child[k]; + w->next = NULL; + if (out_head == NULL) { + out_head = w; + out_tail = w; + } else { + out_tail->next = w; + out_tail = w; + } + self->removal_index++; + } + + /* Now insert the new records */ + while (self->insertion_index < self->num_edges && + left == edges->left[insertion_order[self->insertion_index]]) { + k = insertion_order[self->insertion_index]; + assert(next_edge_list_node < self->num_edges); + w = &self->edge_list_nodes[next_edge_list_node]; + next_edge_list_node++; + w->edge.left = edges->left[k]; + w->edge.right = edges->right[k]; + w->edge.parent = edges->parent[k]; + w->edge.child = edges->child[k]; + w->next = NULL; + if (in_head == NULL) { + in_head = w; + in_tail = w; + } else { + in_tail->next = w; + in_tail = w; + } + self->insertion_index++; + } + right = sequence_length; + if (self->insertion_index < self->num_edges) { + right = TSK_MIN(right, edges->left[ + insertion_order[self->insertion_index]]); + } + if (self->removal_index < self->num_edges) { + right = TSK_MIN(right, edges->right[ + removal_order[self->removal_index]]); + } + self->tree_index++; + ret = 1; + } + *edges_out = out_head; + *edges_in = in_head; + *ret_left = left; + *ret_right = right; + /* Set the left coordinate for the next tree */ + self->tree_left = right; + return ret; +} diff --git a/c/tsk_trees.h b/c/tsk_trees.h new file mode 100644 index 
0000000000..1eac3edb34 --- /dev/null +++ b/c/tsk_trees.h @@ -0,0 +1,209 @@ +#ifndef TSK_TREES_H +#define TSK_TREES_H + +#ifdef __cplusplus +extern "C" { +#endif + +#include "tsk_tables.h" + + +#define TSK_SAMPLE_COUNTS (1 << 0) +#define TSK_SAMPLE_LISTS (1 << 1) + +#define TSK_DIR_FORWARD 1 +#define TSK_DIR_REVERSE -1 + + +/* Tree sequences */ +typedef struct { + size_t num_trees; + size_t num_samples; + tsk_id_t *samples; + /* If a node is a sample, map to its index in the samples list */ + tsk_id_t *sample_index_map; + /* Map individuals to the list of nodes that reference them */ + tsk_id_t *individual_nodes_mem; + tsk_id_t **individual_nodes; + tsk_tbl_size_t *individual_nodes_length; + /* For each tree, a list of sites on that tree */ + tsk_site_t *tree_sites_mem; + tsk_site_t **tree_sites; + tsk_tbl_size_t *tree_sites_length; + /* For each site, a list of mutations at that site */ + tsk_mutation_t *site_mutations_mem; + tsk_mutation_t **site_mutations; + tsk_tbl_size_t *site_mutations_length; + /* The underlying tables */ + tsk_tbl_collection_t *tables; +} tsk_treeseq_t; + +typedef struct { + tsk_treeseq_t *tree_sequence; + size_t num_nodes; + int flags; + tsk_id_t *samples; + /* The left-most root in the forest. Roots are sibs and all roots are found + * via left_sib and right_sib */ + tsk_id_t left_root; + /* Left and right physical coordinates of the tree */ + double left; + double right; + tsk_id_t *parent; /* parent of node u */ + tsk_id_t *left_child; /* leftmost child of node u */ + tsk_id_t *right_child; /* rightmost child of node u */ + tsk_id_t *left_sib; /* sibling to right of node u */ + tsk_id_t *right_sib; /* sibling to the left of node u */ + bool *above_sample; + size_t index; + /* These are involved in the optional sample tracking; num_samples counts + * all samples below a give node, and num_tracked_samples counts those + * from a specific subset. 
*/ + tsk_id_t *num_samples; + tsk_id_t *num_tracked_samples; + /* All nodes that are marked during a particular transition are marked + * with a given value. */ + uint8_t *marked; + uint8_t mark; + /* These are for the optional sample list tracking. */ + tsk_id_t *left_sample; + tsk_id_t *right_sample; + tsk_id_t *next_sample; + tsk_id_t *sample_index_map; + /* traversal stacks */ + tsk_id_t *stack1; + tsk_id_t *stack2; + /* The sites on this tree */ + tsk_site_t *sites; + tsk_tbl_size_t sites_length; + /* Counters needed for next() and prev() transformations. */ + int direction; + tsk_id_t left_index; + tsk_id_t right_index; +} tsk_tree_t; + +/* Diff iterator. TODO Not sure if we want to keep this, as it's not used + * very much in the C code. */ +typedef struct _tsk_edge_list_t { + tsk_edge_t edge; + struct _tsk_edge_list_t *next; +} tsk_edge_list_t; + +typedef struct { + size_t num_nodes; + size_t num_edges; + double tree_left; + tsk_treeseq_t *tree_sequence; + size_t insertion_index; + size_t removal_index; + size_t tree_index; + tsk_edge_list_t *edge_list_nodes; +} tsk_diff_iter_t; + +/****************************************************************************/ +/* Tree sequence.*/ +/****************************************************************************/ + +int tsk_treeseq_alloc(tsk_treeseq_t *self, tsk_tbl_collection_t *tables, int flags); +int tsk_treeseq_load(tsk_treeseq_t *self, const char *filename, int flags); +int tsk_treeseq_dump(tsk_treeseq_t *self, const char *filename, int flags); +int tsk_treeseq_copy_tables(tsk_treeseq_t *self, tsk_tbl_collection_t *tables); +int tsk_treeseq_free(tsk_treeseq_t *self); +void tsk_treeseq_print_state(tsk_treeseq_t *self, FILE *out); + +size_t tsk_treeseq_get_num_nodes(tsk_treeseq_t *self); +size_t tsk_treeseq_get_num_edges(tsk_treeseq_t *self); +size_t tsk_treeseq_get_num_migrations(tsk_treeseq_t *self); +size_t tsk_treeseq_get_num_sites(tsk_treeseq_t *self); +size_t tsk_treeseq_get_num_mutations(tsk_treeseq_t 
*self); +size_t tsk_treeseq_get_num_provenances(tsk_treeseq_t *self); +size_t tsk_treeseq_get_num_populations(tsk_treeseq_t *self); +size_t tsk_treeseq_get_num_individuals(tsk_treeseq_t *self); +size_t tsk_treeseq_get_num_trees(tsk_treeseq_t *self); +size_t tsk_treeseq_get_num_samples(tsk_treeseq_t *self); +char * tsk_treeseq_get_file_uuid(tsk_treeseq_t *self); +double tsk_treeseq_get_sequence_length(tsk_treeseq_t *self); +bool tsk_treeseq_is_sample(tsk_treeseq_t *self, tsk_id_t u); + +int tsk_treeseq_get_node(tsk_treeseq_t *self, size_t index, tsk_node_t *node); +int tsk_treeseq_get_edge(tsk_treeseq_t *self, size_t index, tsk_edge_t *edge); +int tsk_treeseq_get_migration(tsk_treeseq_t *self, size_t index, + tsk_migration_t *migration); +int tsk_treeseq_get_site(tsk_treeseq_t *self, size_t index, tsk_site_t *site); +int tsk_treeseq_get_mutation(tsk_treeseq_t *self, size_t index, + tsk_mutation_t *mutation); +int tsk_treeseq_get_provenance(tsk_treeseq_t *self, size_t index, + tsk_provenance_t *provenance); +int tsk_treeseq_get_population(tsk_treeseq_t *self, size_t index, + tsk_population_t *population); +int tsk_treeseq_get_individual(tsk_treeseq_t *self, size_t index, + tsk_individual_t *individual); +int tsk_treeseq_get_samples(tsk_treeseq_t *self, tsk_id_t **samples); +int tsk_treeseq_get_sample_index_map(tsk_treeseq_t *self, + tsk_id_t **sample_index_map); + +int tsk_treeseq_simplify(tsk_treeseq_t *self, tsk_id_t *samples, + size_t num_samples, int flags, tsk_treeseq_t *output, + tsk_id_t *node_map); +/* TODO do these belong in trees or stats? They should probably be in stats. + * Keep them here for now until we figure out the correct interface. 
+ */ +int tsk_treeseq_get_pairwise_diversity(tsk_treeseq_t *self, + tsk_id_t *samples, size_t num_samples, double *pi); +int tsk_treeseq_genealogical_nearest_neighbours(tsk_treeseq_t *self, + tsk_id_t *focal, size_t num_focal, + tsk_id_t **reference_sets, size_t *reference_set_size, size_t num_reference_sets, + int flags, double *ret_array); +int tsk_treeseq_mean_descendants(tsk_treeseq_t *self, + tsk_id_t **reference_sets, size_t *reference_set_size, size_t num_reference_sets, + int flags, double *ret_array); + + +/****************************************************************************/ +/* Tree */ +/****************************************************************************/ + +int tsk_tree_alloc(tsk_tree_t *self, tsk_treeseq_t *tree_sequence, + int flags); +int tsk_tree_free(tsk_tree_t *self); +bool tsk_tree_has_sample_lists(tsk_tree_t *self); +bool tsk_tree_has_sample_counts(tsk_tree_t *self); +int tsk_tree_copy(tsk_tree_t *self, tsk_tree_t *source); +int tsk_tree_equal(tsk_tree_t *self, tsk_tree_t *other); +int tsk_tree_set_tracked_samples(tsk_tree_t *self, + size_t num_tracked_samples, tsk_id_t *tracked_samples); +int tsk_tree_set_tracked_samples_from_sample_list(tsk_tree_t *self, + tsk_tree_t *other, tsk_id_t node); +int tsk_tree_get_root(tsk_tree_t *self, tsk_id_t *root); +bool tsk_tree_is_sample(tsk_tree_t *self, tsk_id_t u); +size_t tsk_tree_get_num_roots(tsk_tree_t *self); +int tsk_tree_get_parent(tsk_tree_t *self, tsk_id_t u, tsk_id_t *parent); +int tsk_tree_get_time(tsk_tree_t *self, tsk_id_t u, double *t); +int tsk_tree_get_mrca(tsk_tree_t *self, tsk_id_t u, tsk_id_t v, tsk_id_t *mrca); +int tsk_tree_get_num_samples(tsk_tree_t *self, tsk_id_t u, size_t *num_samples); +int tsk_tree_get_num_tracked_samples(tsk_tree_t *self, tsk_id_t u, + size_t *num_tracked_samples); +int tsk_tree_get_sites(tsk_tree_t *self, tsk_site_t **sites, tsk_tbl_size_t *sites_length); + +void tsk_tree_print_state(tsk_tree_t *self, FILE *out); +/* Method for positioning the 
tree in the sequence. */ +int tsk_tree_first(tsk_tree_t *self); +int tsk_tree_last(tsk_tree_t *self); +int tsk_tree_next(tsk_tree_t *self); +int tsk_tree_prev(tsk_tree_t *self); + +/****************************************************************************/ +/* Diff iterator */ +/****************************************************************************/ + +int tsk_diff_iter_alloc(tsk_diff_iter_t *self, tsk_treeseq_t *tree_sequence); +int tsk_diff_iter_free(tsk_diff_iter_t *self); +int tsk_diff_iter_next(tsk_diff_iter_t *self, + double *left, double *right, + tsk_edge_list_t **edges_out, tsk_edge_list_t **edges_in); +void tsk_diff_iter_print_state(tsk_diff_iter_t *self, FILE *out); + +#ifdef __cplusplus +} +#endif +#endif diff --git a/c/tskit.h b/c/tskit.h new file mode 100644 index 0000000000..dd71ff8196 --- /dev/null +++ b/c/tskit.h @@ -0,0 +1,10 @@ +#ifndef __TSKIT_H__ +#define __TSKIT_H__ + +#include "tsk_core.h" +#include "tsk_trees.h" +#include "tsk_genotypes.h" +#include "tsk_convert.h" +#include "tsk_stats.h" + +#endif diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst new file mode 100644 index 0000000000..aef5332be3 --- /dev/null +++ b/python/CHANGELOG.rst @@ -0,0 +1,7 @@ +-------------------- +[0.0.0] - 2019-01-19 +-------------------- + +Initial extraction of tskit code from msprime. Relicense to MIT. 
+ +Code copied at hash 29921408661d5fe0b1a82b1ca302a8b87510fd23 diff --git a/python/LICENSE b/python/LICENSE new file mode 120000 index 0000000000..ea5b60640b --- /dev/null +++ b/python/LICENSE @@ -0,0 +1 @@ +../LICENSE \ No newline at end of file diff --git a/python/MANIFEST.in b/python/MANIFEST.in new file mode 100644 index 0000000000..e9d170ba1a --- /dev/null +++ b/python/MANIFEST.in @@ -0,0 +1,3 @@ +include lib/kastore/c/kastore.h +include LICENSE +include tskit/_version.py diff --git a/python/Makefile b/python/Makefile new file mode 100644 index 0000000000..f3d5a52694 --- /dev/null +++ b/python/Makefile @@ -0,0 +1,18 @@ + +all: ext3 + +ext3: _tskitmodule.c + # CFLAGS="-std=c99 -Wall -Wextra -Werror -Wno-unused-parameter" \ + # python3 setup.py build_ext --inplace + # Disable checks for now. + python3 setup.py build_ext --inplace + +ext2: _tskitmodule.c + python2 setup.py build_ext --inplace + +ctags: + ctags lib/*.c lib/*.h tskit/*.py + +clean: + rm -f *.so *.o tags + rm -fR build diff --git a/python/README.rst b/python/README.rst new file mode 100644 index 0000000000..5b0d0d6059 --- /dev/null +++ b/python/README.rst @@ -0,0 +1 @@ +The tree sequence toolkit. diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c new file mode 100644 index 0000000000..a63c1cc51c --- /dev/null +++ b/python/_tskitmodule.c @@ -0,0 +1,8563 @@ +/* +** Copyright (C) 2014-2018 University of Oxford +** +** This file is part of tskit. +** +** tskit is free software: you can redistribute it and/or modify +** it under the terms of the GNU General Public License as published by +** the Free Software Foundation, either version 3 of the License, or +** (at your option) any later version. +** +** tskit is distributed in the hope that it will be useful, +** but WITHOUT ANY WARRANTY; without even the implied warranty of +** MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +** GNU General Public License for more details. 
+** +** You should have received a copy of the GNU General Public License +** along with tskit. If not, see . +*/ + +#define PY_SSIZE_T_CLEAN +#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION + +#include +#include +#include +#include + +#include "kastore.h" +#include "tskit.h" + +#if PY_MAJOR_VERSION >= 3 +#define IS_PY3K +#endif + +#define MODULE_DOC \ +"Low level interface for tskit" + +#define SET_COLS 0 +#define APPEND_COLS 1 + +/* TskitException is the superclass of all exceptions that can be thrown by + * tskit. We define it here in the low-level library so that exceptions defined + * here and in the high-level library can inherit from it. + */ +static PyObject *TskitException; +static PyObject *TskitLibraryError; +static PyObject *TskitFileFormatError; +static PyObject *TskitVersionTooOldError; +static PyObject *TskitVersionTooNewError; + + +/* A lightweight wrapper for a table collection. This serves only as a wrapper + * around a pointer and a way to data in-and-out of the low level structures + * via the canonical dictionary encoding. + */ +typedef struct { + PyObject_HEAD + tsk_tbl_collection_t *tables; +} LightweightTableCollection; + +/* The XTable classes each have 'lock' attribute, which is used to + * raise an error if a Python thread attempts to access a table + * while another Python thread is operating on it. Because tables + * allocate memory dynamically, we cannot gaurantee safety otherwise. + * The locks are set before the GIL is released and unset afterwards. + * Because C code executed here represents atomic Python operations + * (while the GIL is held), this should be safe */ + +typedef struct _TableCollection { + PyObject_HEAD + tsk_tbl_collection_t *tables; +} TableCollection; + + /* The table pointer in each of the Table classes either points to locally + * allocated memory or to the table stored in a tbl_collection_t. 
If we're + * using the memory in a tbl_collection_t, we keep a reference to the + * TableCollection object to ensure that the memory isn't free'd while a + * reference to the table itself is live. */ +typedef struct { + PyObject_HEAD + bool locked; + tsk_individual_tbl_t *table; + TableCollection *tables; +} IndividualTable; + +typedef struct { + PyObject_HEAD + bool locked; + tsk_node_tbl_t *table; + TableCollection *tables; +} NodeTable; + +typedef struct { + PyObject_HEAD + bool locked; + tsk_edge_tbl_t *table; + TableCollection *tables; +} EdgeTable; + +typedef struct { + PyObject_HEAD + bool locked; + tsk_site_tbl_t *table; + TableCollection *tables; +} SiteTable; + +typedef struct { + PyObject_HEAD + bool locked; + tsk_mutation_tbl_t *table; + TableCollection *tables; +} MutationTable; + +typedef struct { + PyObject_HEAD + bool locked; + tsk_migration_tbl_t *table; + TableCollection *tables; +} MigrationTable; + +typedef struct { + PyObject_HEAD + bool locked; + tsk_population_tbl_t *table; + TableCollection *tables; +} PopulationTable; + +typedef struct { + PyObject_HEAD + bool locked; + tsk_provenance_tbl_t *table; + TableCollection *tables; +} ProvenanceTable; + +typedef struct { + PyObject_HEAD + tsk_treeseq_t *tree_sequence; +} TreeSequence; + +typedef struct { + PyObject_HEAD + TreeSequence *tree_sequence; + tsk_tree_t *tree; +} Tree; + +typedef struct { + PyObject_HEAD + Tree *tree; + int first; +} TreeIterator; + +typedef struct { + PyObject_HEAD + TreeSequence *tree_sequence; + tsk_diff_iter_t *tree_diff_iterator; +} TreeDiffIterator; + +typedef struct { + PyObject_HEAD + TreeSequence *tree_sequence; + tsk_vcf_converter_t *tsk_vcf_converter; +} VcfConverter; + +typedef struct { + PyObject_HEAD + TreeSequence *tree_sequence; + tsk_hapgen_t *haplotype_generator; +} HaplotypeGenerator; + +typedef struct { + PyObject_HEAD + TreeSequence *tree_sequence; + tsk_vargen_t *variant_generator; +} VariantGenerator; + +typedef struct { + PyObject_HEAD + 
TreeSequence *tree_sequence; + tsk_ld_calc_t *ld_calc; +} LdCalculator; + +static void +handle_library_error(int err) +{ + if (tsk_is_kas_error(err)) { + PyErr_SetString(TskitFileFormatError, tsk_strerror(err)); + } else { + switch (err) { + case TSK_ERR_FILE_VERSION_TOO_NEW: + PyErr_SetString(TskitVersionTooNewError, tsk_strerror(err)); + break; + case TSK_ERR_FILE_VERSION_TOO_OLD: + PyErr_SetString(TskitVersionTooOldError, tsk_strerror(err)); + break; + case TSK_ERR_FILE_FORMAT: + PyErr_SetString(TskitFileFormatError, tsk_strerror(err)); + break; + case TSK_ERR_OUT_OF_BOUNDS: + PyErr_SetString(PyExc_IndexError, tsk_strerror(err)); + break; + default: + PyErr_SetString(TskitLibraryError, tsk_strerror(err)); + } + } +} + +static int +parse_sample_ids(PyObject *py_samples, tsk_treeseq_t *ts, size_t *num_samples, + tsk_id_t **samples) +{ + int ret = -1; + PyObject *item; + Py_ssize_t j, num_samples_local; + tsk_id_t *samples_local = NULL; + + num_samples_local = PyList_Size(py_samples); + if (num_samples_local < 2) { + PyErr_SetString(PyExc_ValueError, "Must provide at least 2 samples"); + goto out; + } + samples_local = PyMem_Malloc(num_samples_local * sizeof(tsk_id_t)); + if (samples_local == NULL) { + PyErr_NoMemory(); + goto out; + } + for (j = 0; j < num_samples_local; j++) { + item = PyList_GetItem(py_samples, j); + if (!PyNumber_Check(item)) { + PyErr_SetString(PyExc_TypeError, "sample id must be a number"); + goto out; + } + samples_local[j] = (tsk_id_t) PyLong_AsLong(item); + if (samples_local[j] < 0 + || samples_local[j] >= (tsk_id_t) tsk_treeseq_get_num_nodes(ts)) { + PyErr_SetString(PyExc_ValueError, "node ID out of bounds"); + goto out; + } + if (! 
tsk_treeseq_is_sample(ts, samples_local[j])) { + PyErr_SetString(PyExc_ValueError, "Specified node is not a sample"); + goto out; + } + } + *num_samples = (size_t) num_samples_local; + *samples = samples_local; + samples_local = NULL; + ret = 0; +out: + if (samples_local != NULL) { + PyMem_Free(samples_local); + } + return ret; +} + + +static PyObject * +convert_node_id_list(tsk_id_t *children, size_t num_children) +{ + PyObject *ret = NULL; + PyObject *t; + PyObject *py_int; + size_t j; + + t = PyTuple_New(num_children); + if (t == NULL) { + goto out; + } + for (j = 0; j < num_children; j++) { + py_int = Py_BuildValue("i", (int) children[j]); + if (py_int == NULL) { + Py_DECREF(children); + goto out; + } + PyTuple_SET_ITEM(t, j, py_int); + } + ret = t; +out: + return ret; +} + +static PyObject * +make_metadata(const char *metadata, Py_ssize_t length) +{ + const char *m = metadata == NULL? "": metadata; + return PyBytes_FromStringAndSize(m, length); +} + +static PyObject * +make_mutation(tsk_mutation_t *mutation) +{ + PyObject *ret = NULL; + PyObject* metadata = NULL; + + metadata = make_metadata(mutation->metadata, (Py_ssize_t) mutation->metadata_length); + if (metadata == NULL) { + goto out; + } + ret = Py_BuildValue("iis#iO", mutation->site, mutation->node, mutation->derived_state, + (Py_ssize_t) mutation->derived_state_length, mutation->parent, + metadata); +out: + Py_XDECREF(metadata); + return ret; +} + +static PyObject * +make_mutation_id_list(tsk_mutation_t *mutations, size_t length) +{ + PyObject *ret = NULL; + PyObject *t; + PyObject *item; + size_t j; + + t = PyTuple_New(length); + if (t == NULL) { + goto out; + } + for (j = 0; j < length; j++) { + item = Py_BuildValue("i", mutations[j].id); + if (item == NULL) { + Py_DECREF(t); + goto out; + } + PyTuple_SET_ITEM(t, j, item); + } + ret = t; +out: + return ret; +} + +static PyObject * +make_population(tsk_population_t *population) +{ + PyObject *ret = NULL; + PyObject *metadata = 
make_metadata(population->metadata, + (Py_ssize_t) population->metadata_length); + + ret = Py_BuildValue("(O)", metadata); + return ret; +} + +static PyObject * +make_provenance(tsk_provenance_t *provenance) +{ + PyObject *ret = NULL; + + ret = Py_BuildValue("s#s#", + provenance->timestamp, (Py_ssize_t) provenance->timestamp_length, + provenance->record, (Py_ssize_t) provenance->record_length); + return ret; +} + +static PyObject * +make_individual_row(tsk_individual_t *r) +{ + PyObject *ret = NULL; + PyObject *metadata = make_metadata(r->metadata, (Py_ssize_t) r->metadata_length); + PyArrayObject *location = NULL; + npy_intp dims; + + dims = (npy_intp) r->location_length; + location = (PyArrayObject *) PyArray_SimpleNew(1, &dims, NPY_FLOAT64); + if (metadata == NULL || location == NULL) { + goto out; + } + memcpy(PyArray_DATA(location), r->location, r->location_length * sizeof(double)); + ret = Py_BuildValue("IOO", (unsigned int) r->flags, location, metadata); +out: + Py_XDECREF(location); + Py_XDECREF(metadata); + return ret; +} + +static PyObject * +make_individual_object(tsk_individual_t *r) +{ + PyObject *ret = NULL; + PyObject *metadata = make_metadata(r->metadata, (Py_ssize_t) r->metadata_length); + PyArrayObject *location = NULL; + PyArrayObject *nodes = NULL; + npy_intp dims; + + dims = (npy_intp) r->location_length; + location = (PyArrayObject *) PyArray_SimpleNew(1, &dims, NPY_FLOAT64); + dims = (npy_intp) r->nodes_length; + nodes = (PyArrayObject *) PyArray_SimpleNew(1, &dims, NPY_INT32); + if (metadata == NULL || location == NULL || nodes == NULL) { + goto out; + } + memcpy(PyArray_DATA(location), r->location, r->location_length * sizeof(double)); + memcpy(PyArray_DATA(nodes), r->nodes, r->nodes_length * sizeof(tsk_id_t)); + ret = Py_BuildValue("IOOO", (unsigned int) r->flags, location, metadata, nodes); +out: + Py_XDECREF(location); + Py_XDECREF(metadata); + Py_XDECREF(nodes); + return ret; +} + +static PyObject * +make_node(tsk_node_t *r) +{ + 
PyObject *ret = NULL; + PyObject* metadata = make_metadata(r->metadata, (Py_ssize_t) r->metadata_length); + if (metadata == NULL) { + goto out; + } + ret = Py_BuildValue("IdiiO", + (unsigned int) r->flags, r->time, (int) r->population, (int) r->individual, metadata); +out: + Py_XDECREF(metadata); + return ret; +} + +static PyObject * +make_edge(tsk_edge_t *edge) +{ + return Py_BuildValue("ddii", + edge->left, edge->right, (int) edge->parent, (int) edge->child); +} + +static PyObject * +make_migration(tsk_migration_t *r) +{ + int source = r->source == TSK_NULL ? -1: r->source; + int dest = r->dest == TSK_NULL ? -1: r->dest; + PyObject *ret = NULL; + + ret = Py_BuildValue("ddiiid", + r->left, r->right, (int) r->node, source, dest, r->time); + return ret; +} + +static PyObject * +make_site_row(tsk_site_t *site) +{ + PyObject *ret = NULL; + PyObject* metadata = NULL; + + metadata = make_metadata(site->metadata, (Py_ssize_t) site->metadata_length); + if (metadata == NULL) { + goto out; + } + ret = Py_BuildValue("ds#O", site->position, site->ancestral_state, + (Py_ssize_t) site->ancestral_state_length, metadata); +out: + Py_XDECREF(metadata); + return ret; +} + +static PyObject * +make_site_object(tsk_site_t *site) +{ + PyObject *ret = NULL; + PyObject *mutations = NULL; + PyObject* metadata = NULL; + + metadata = make_metadata(site->metadata, (Py_ssize_t) site->metadata_length); + if (metadata == NULL) { + goto out; + } + mutations = make_mutation_id_list(site->mutations, site->mutations_length); + if (mutations == NULL) { + goto out; + } + /* TODO should reorder this tuple, as it's not very logical. 
*/ + ret = Py_BuildValue("ds#OnO", site->position, site->ancestral_state, + (Py_ssize_t) site->ancestral_state_length, mutations, + (Py_ssize_t) site->id, metadata); +out: + Py_XDECREF(mutations); + Py_XDECREF(metadata); + return ret; +} + +static PyObject * +make_alleles(tsk_variant_t *variant) +{ + PyObject *ret = NULL; + PyObject *item, *t; + size_t j; + + t = PyTuple_New(variant->num_alleles); + if (t == NULL) { + goto out; + } + for (j = 0; j < variant->num_alleles; j++) { + item = Py_BuildValue("s#", variant->alleles[j], variant->allele_lengths[j]); + if (item == NULL) { + Py_DECREF(t); + goto out; + } + PyTuple_SET_ITEM(t, j, item); + } + ret = t; +out: + return ret; +} + +static PyObject * +make_variant(tsk_variant_t *variant, size_t num_samples) +{ + PyObject *ret = NULL; + npy_intp dims = num_samples; + PyObject *alleles = make_alleles(variant); + PyArrayObject *genotypes = (PyArrayObject *) PyArray_SimpleNew(1, &dims, NPY_UINT8); + + /* TODO update this to account for 16 bit variants when we provide the + * high-level interface. */ + if (genotypes == NULL || alleles == NULL) { + goto out; + } + memcpy(PyArray_DATA(genotypes), variant->genotypes.u8, num_samples * sizeof(uint8_t)); + ret = Py_BuildValue("iOO", variant->site->id, genotypes, alleles); +out: + Py_XDECREF(genotypes); + Py_XDECREF(alleles); + return ret; +} + +static PyObject * +convert_sites(tsk_site_t *sites, size_t num_sites) +{ + PyObject *ret = NULL; + PyObject *l = NULL; + PyObject *py_site = NULL; + size_t j; + + l = PyList_New(num_sites); + if (l == NULL) { + goto out; + } + for (j = 0; j < num_sites; j++) { + py_site = make_site_object(&sites[j]); + if (py_site == NULL) { + Py_DECREF(l); + goto out; + } + PyList_SET_ITEM(l, j, py_site); + } + ret = l; +out: + return ret; +} + +/*=================================================================== + * General table code. 
+ *=================================================================== + */ + +/* + * Retrieves the PyObject* corresponding the specified key in the + * specified dictionary. If required is true, raise a TypeError if the + * value is None. + * + * NB This returns a *borrowed reference*, so don't DECREF it! + */ +static PyObject * +get_table_dict_value(PyObject *dict, const char *key_str, bool required) +{ + PyObject *ret = NULL; + + ret = PyDict_GetItemString(dict, key_str); + if (ret == NULL) { + PyErr_Format(PyExc_ValueError, "'%s' not specified", key_str); + } + if (required && ret == Py_None) { + PyErr_Format(PyExc_TypeError, "'%s' is required", key_str); + ret = NULL; + } + return ret; +} + +static PyObject * +table_get_column_array(size_t num_rows, void *data, int npy_type, + size_t element_size) +{ + PyObject *ret = NULL; + PyArrayObject *array; + npy_intp dims = (npy_intp) num_rows; + + array = (PyArrayObject *) PyArray_EMPTY(1, &dims, npy_type, 0); + if (array == NULL) { + goto out; + } + memcpy(PyArray_DATA(array), data, num_rows * element_size); + ret = (PyObject *) array; +out: + return ret; +} + +static PyArrayObject * +table_read_column_array(PyObject *input, int npy_type, size_t *num_rows, bool check_num_rows) +{ + PyArrayObject *ret = NULL; + PyArrayObject *array = NULL; + npy_intp *shape; + + array = (PyArrayObject *) PyArray_FROMANY(input, npy_type, 1, 1, NPY_ARRAY_IN_ARRAY); + if (array == NULL) { + goto out; + } + shape = PyArray_DIMS(array); + if (check_num_rows) { + if (*num_rows != (size_t) shape[0]) { + PyErr_SetString(PyExc_ValueError, "Input array dimensions must be equal."); + goto out; + } + } else { + *num_rows = (size_t) shape[0]; + } + ret = array; + array = NULL; +out: + Py_XDECREF(array); + return ret; +} + +static PyArrayObject * +table_read_offset_array(PyObject *input, size_t *num_rows, size_t length, bool check_num_rows) +{ + PyArrayObject *ret = NULL; + PyArrayObject *array = NULL; + npy_intp *shape; + uint32_t *data; + + array 
= (PyArrayObject *) PyArray_FROMANY(input, NPY_UINT32, 1, 1, NPY_ARRAY_IN_ARRAY); + if (array == NULL) { + goto out; + } + shape = PyArray_DIMS(array); + if (! check_num_rows) { + *num_rows = shape[0]; + if (*num_rows == 0) { + PyErr_SetString(PyExc_ValueError, "Offset arrays must have at least one element"); + goto out; + } + *num_rows -= 1; + } + if (shape[0] != *num_rows + 1) { + PyErr_SetString(PyExc_ValueError, "offset columns must have n + 1 rows."); + goto out; + } + data = PyArray_DATA(array); + if (data[*num_rows] != (uint32_t) length) { + PyErr_SetString(PyExc_ValueError, "Bad offset column encoding"); + goto out; + } + ret = array; +out: + if (ret == NULL) { + Py_XDECREF(array); + } + return ret; +} + +static int +parse_individual_table_dict(tsk_individual_tbl_t *table, PyObject *dict, bool clear_table) +{ + int err; + int ret = -1; + size_t num_rows, metadata_length, location_length; + char *metadata_data = NULL; + double *location_data = NULL; + uint32_t *metadata_offset_data = NULL; + uint32_t *location_offset_data = NULL; + PyObject *flags_input = NULL; + PyArrayObject *flags_array = NULL; + PyObject *location_input = NULL; + PyArrayObject *location_array = NULL; + PyObject *location_offset_input = NULL; + PyArrayObject *location_offset_array = NULL; + PyObject *metadata_input = NULL; + PyArrayObject *metadata_array = NULL; + PyObject *metadata_offset_input = NULL; + PyArrayObject *metadata_offset_array = NULL; + + /* Get the input values */ + flags_input = get_table_dict_value(dict, "flags", true); + if (flags_input == NULL) { + goto out; + } + location_input = get_table_dict_value(dict, "location", false); + if (location_input == NULL) { + goto out; + } + location_offset_input = get_table_dict_value(dict, "location_offset", false); + if (location_offset_input == NULL) { + goto out; + } + metadata_input = get_table_dict_value(dict, "metadata", false); + if (metadata_input == NULL) { + goto out; + } + metadata_offset_input = 
get_table_dict_value(dict, "metadata_offset", false); + if (metadata_offset_input == NULL) { + goto out; + } + + /* Pull out the arrays */ + flags_array = table_read_column_array(flags_input, NPY_UINT32, &num_rows, false); + if (flags_array == NULL) { + goto out; + } + if ((location_input == Py_None) != (location_offset_input == Py_None)) { + PyErr_SetString(PyExc_TypeError, + "location and location_offset must be specified together"); + goto out; + } + if (location_input != Py_None) { + location_array = table_read_column_array(location_input, NPY_FLOAT64, + &location_length, false); + if (location_array == NULL) { + goto out; + } + location_data = PyArray_DATA(location_array); + location_offset_array = table_read_offset_array(location_offset_input, &num_rows, + location_length, true); + if (location_offset_array == NULL) { + goto out; + } + location_offset_data = PyArray_DATA(location_offset_array); + } + if ((metadata_input == Py_None) != (metadata_offset_input == Py_None)) { + PyErr_SetString(PyExc_TypeError, + "metadata and metadata_offset must be specified together"); + goto out; + } + if (metadata_input != Py_None) { + metadata_array = table_read_column_array(metadata_input, NPY_INT8, + &metadata_length, false); + if (metadata_array == NULL) { + goto out; + } + metadata_data = PyArray_DATA(metadata_array); + metadata_offset_array = table_read_offset_array(metadata_offset_input, &num_rows, + metadata_length, true); + if (metadata_offset_array == NULL) { + goto out; + } + metadata_offset_data = PyArray_DATA(metadata_offset_array); + } + + if (clear_table) { + err = tsk_individual_tbl_clear(table); + if (err != 0) { + handle_library_error(err); + goto out; + } + } + err = tsk_individual_tbl_append_columns(table, num_rows, + PyArray_DATA(flags_array), + location_data, location_offset_data, + metadata_data, metadata_offset_data); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = 0; +out: + Py_XDECREF(flags_array); + Py_XDECREF(location_array); 
+ Py_XDECREF(location_offset_array); + Py_XDECREF(metadata_array); + Py_XDECREF(metadata_offset_array); + return ret; +} + +static int +parse_node_table_dict(tsk_node_tbl_t *table, PyObject *dict, bool clear_table) +{ + int err; + int ret = -1; + size_t num_rows, metadata_length; + char *metadata_data = NULL; + uint32_t *metadata_offset_data = NULL; + void *population_data = NULL; + void *individual_data = NULL; + PyObject *time_input = NULL; + PyArrayObject *time_array = NULL; + PyObject *flags_input = NULL; + PyArrayObject *flags_array = NULL; + PyObject *population_input = NULL; + PyArrayObject *population_array = NULL; + PyObject *individual_input = NULL; + PyArrayObject *individual_array = NULL; + PyObject *metadata_input = NULL; + PyArrayObject *metadata_array = NULL; + PyObject *metadata_offset_input = NULL; + PyArrayObject *metadata_offset_array = NULL; + + /* Get the input values */ + flags_input = get_table_dict_value(dict, "flags", true); + if (flags_input == NULL) { + goto out; + } + time_input = get_table_dict_value(dict, "time", true); + if (time_input == NULL) { + goto out; + } + population_input = get_table_dict_value(dict, "population", false); + if (population_input == NULL) { + goto out; + } + individual_input = get_table_dict_value(dict, "individual", false); + if (individual_input == NULL) { + goto out; + } + metadata_input = get_table_dict_value(dict, "metadata", false); + if (metadata_input == NULL) { + goto out; + } + metadata_offset_input = get_table_dict_value(dict, "metadata_offset", false); + if (metadata_offset_input == NULL) { + goto out; + } + + /* Create the arrays */ + flags_array = table_read_column_array(flags_input, NPY_UINT32, &num_rows, false); + if (flags_array == NULL) { + goto out; + } + time_array = table_read_column_array(time_input, NPY_FLOAT64, &num_rows, true); + if (time_array == NULL) { + goto out; + } + if (population_input != Py_None) { + population_array = table_read_column_array(population_input, NPY_INT32, + 
&num_rows, true); + if (population_array == NULL) { + goto out; + } + population_data = PyArray_DATA(population_array); + } + if (individual_input != Py_None) { + individual_array = table_read_column_array(individual_input, NPY_INT32, + &num_rows, true); + if (individual_array == NULL) { + goto out; + } + individual_data = PyArray_DATA(individual_array); + } + if ((metadata_input == Py_None) != (metadata_offset_input == Py_None)) { + PyErr_SetString(PyExc_TypeError, + "metadata and metadata_offset must be specified together"); + goto out; + } + if (metadata_input != Py_None) { + metadata_array = table_read_column_array(metadata_input, NPY_INT8, + &metadata_length, false); + if (metadata_array == NULL) { + goto out; + } + metadata_data = PyArray_DATA(metadata_array); + metadata_offset_array = table_read_offset_array(metadata_offset_input, &num_rows, + metadata_length, true); + if (metadata_offset_array == NULL) { + goto out; + } + metadata_offset_data = PyArray_DATA(metadata_offset_array); + } + if (clear_table) { + err = tsk_node_tbl_clear(table); + if (err != 0) { + handle_library_error(err); + goto out; + } + } + err = tsk_node_tbl_append_columns(table, num_rows, + PyArray_DATA(flags_array), PyArray_DATA(time_array), population_data, + individual_data, metadata_data, metadata_offset_data); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = 0; +out: + Py_XDECREF(flags_array); + Py_XDECREF(time_array); + Py_XDECREF(population_array); + Py_XDECREF(individual_array); + Py_XDECREF(metadata_array); + Py_XDECREF(metadata_offset_array); + return ret; +} + +static int +parse_edge_table_dict(tsk_edge_tbl_t *table, PyObject *dict, bool clear_table) +{ + int ret = -1; + int err; + size_t num_rows = 0; + PyObject *left_input = NULL; + PyArrayObject *left_array = NULL; + PyObject *right_input = NULL; + PyArrayObject *right_array = NULL; + PyObject *parent_input = NULL; + PyArrayObject *parent_array = NULL; + PyObject *child_input = NULL; + PyArrayObject 
*child_array = NULL; + + /* Get the input values */ + left_input = get_table_dict_value(dict, "left", true); + if (left_input == NULL) { + goto out; + } + right_input = get_table_dict_value(dict, "right", true); + if (right_input == NULL) { + goto out; + } + parent_input = get_table_dict_value(dict, "parent", true); + if (parent_input == NULL) { + goto out; + } + child_input = get_table_dict_value(dict, "child", true); + if (child_input == NULL) { + goto out; + } + + /* Create the arrays */ + left_array = table_read_column_array(left_input, NPY_FLOAT64, &num_rows, false); + if (left_array == NULL) { + goto out; + } + right_array = table_read_column_array(right_input, NPY_FLOAT64, &num_rows, true); + if (right_array == NULL) { + goto out; + } + parent_array = table_read_column_array(parent_input, NPY_INT32, &num_rows, true); + if (parent_array == NULL) { + goto out; + } + child_array = table_read_column_array(child_input, NPY_INT32, &num_rows, true); + if (child_array == NULL) { + goto out; + } + + if (clear_table) { + err = tsk_edge_tbl_clear(table); + if (err != 0) { + handle_library_error(err); + goto out; + } + } + err = tsk_edge_tbl_append_columns(table, num_rows, + PyArray_DATA(left_array), PyArray_DATA(right_array), + PyArray_DATA(parent_array), PyArray_DATA(child_array)); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = 0; +out: + Py_XDECREF(left_array); + Py_XDECREF(right_array); + Py_XDECREF(parent_array); + Py_XDECREF(child_array); + return ret; +} + +static int +parse_migration_table_dict(tsk_migration_tbl_t *table, PyObject *dict, bool clear_table) +{ + int err; + int ret = -1; + size_t num_rows; + PyObject *left_input = NULL; + PyArrayObject *left_array = NULL; + PyObject *right_input = NULL; + PyArrayObject *right_array = NULL; + PyObject *node_input = NULL; + PyArrayObject *node_array = NULL; + PyObject *source_input = NULL; + PyArrayObject *source_array = NULL; + PyObject *dest_input = NULL; + PyArrayObject *dest_array = NULL; 
+ PyObject *time_input = NULL; + PyArrayObject *time_array = NULL; + + /* Get the input values */ + left_input = get_table_dict_value(dict, "left", true); + if (left_input == NULL) { + goto out; + } + right_input = get_table_dict_value(dict, "right", true); + if (right_input == NULL) { + goto out; + } + node_input = get_table_dict_value(dict, "node", true); + if (node_input == NULL) { + goto out; + } + source_input = get_table_dict_value(dict, "source", true); + if (source_input == NULL) { + goto out; + } + dest_input = get_table_dict_value(dict, "dest", true); + if (dest_input == NULL) { + goto out; + } + time_input = get_table_dict_value(dict, "time", true); + if (time_input == NULL) { + goto out; + } + + /* Build the arrays */ + left_array = table_read_column_array(left_input, NPY_FLOAT64, &num_rows, false); + if (left_array == NULL) { + goto out; + } + right_array = table_read_column_array(right_input, NPY_FLOAT64, &num_rows, true); + if (right_array == NULL) { + goto out; + } + node_array = table_read_column_array(node_input, NPY_INT32, &num_rows, true); + if (node_array == NULL) { + goto out; + } + source_array = table_read_column_array(source_input, NPY_INT32, &num_rows, true); + if (source_array == NULL) { + goto out; + } + dest_array = table_read_column_array(dest_input, NPY_INT32, &num_rows, true); + if (dest_array == NULL) { + goto out; + } + time_array = table_read_column_array(time_input, NPY_FLOAT64, &num_rows, true); + if (time_array == NULL) { + goto out; + } + + if (clear_table) { + err = tsk_migration_tbl_clear(table); + if (err != 0) { + handle_library_error(err); + goto out; + } + } + err = tsk_migration_tbl_append_columns(table, num_rows, + PyArray_DATA(left_array), PyArray_DATA(right_array), PyArray_DATA(node_array), + PyArray_DATA(source_array), PyArray_DATA(dest_array), PyArray_DATA(time_array)); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = 0; +out: + Py_XDECREF(left_array); + Py_XDECREF(right_array); + 
Py_XDECREF(node_array); + Py_XDECREF(source_array); + Py_XDECREF(dest_array); + Py_XDECREF(time_array); + return ret; +} + +static int +parse_site_table_dict(tsk_site_tbl_t *table, PyObject *dict, bool clear_table) +{ + int err; + int ret = -1; + size_t num_rows = 0; + size_t ancestral_state_length, metadata_length; + PyObject *position_input = NULL; + PyArrayObject *position_array = NULL; + PyObject *ancestral_state_input = NULL; + PyArrayObject *ancestral_state_array = NULL; + PyObject *ancestral_state_offset_input = NULL; + PyArrayObject *ancestral_state_offset_array = NULL; + PyObject *metadata_input = NULL; + PyArrayObject *metadata_array = NULL; + PyObject *metadata_offset_input = NULL; + PyArrayObject *metadata_offset_array = NULL; + char *metadata_data; + uint32_t *metadata_offset_data; + + /* Get the input values */ + position_input = get_table_dict_value(dict, "position", true); + if (position_input == NULL) { + goto out; + } + ancestral_state_input = get_table_dict_value(dict, "ancestral_state", true); + if (ancestral_state_input == NULL) { + goto out; + } + ancestral_state_offset_input = get_table_dict_value(dict, "ancestral_state_offset", true); + if (ancestral_state_offset_input == NULL) { + goto out; + } + metadata_input = get_table_dict_value(dict, "metadata", false); + if (metadata_input == NULL) { + goto out; + } + metadata_offset_input = get_table_dict_value(dict, "metadata_offset", false); + if (metadata_offset_input == NULL) { + goto out; + } + + /* Get the arrays */ + position_array = table_read_column_array(position_input, NPY_FLOAT64, &num_rows, false); + if (position_array == NULL) { + goto out; + } + ancestral_state_array = table_read_column_array(ancestral_state_input, NPY_INT8, + &ancestral_state_length, false); + if (ancestral_state_array == NULL) { + goto out; + } + ancestral_state_offset_array = table_read_offset_array(ancestral_state_offset_input, + &num_rows, ancestral_state_length, true); + if (ancestral_state_offset_array == NULL) 
{ + goto out; + } + + metadata_data = NULL; + metadata_offset_data = NULL; + if ((metadata_input == Py_None) != (metadata_offset_input == Py_None)) { + PyErr_SetString(PyExc_TypeError, + "metadata and metadata_offset must be specified together"); + goto out; + } + if (metadata_input != Py_None) { + metadata_array = table_read_column_array(metadata_input, NPY_INT8, + &metadata_length, false); + if (metadata_array == NULL) { + goto out; + } + metadata_data = PyArray_DATA(metadata_array); + metadata_offset_array = table_read_offset_array(metadata_offset_input, &num_rows, + metadata_length, false); + if (metadata_offset_array == NULL) { + goto out; + } + metadata_offset_data = PyArray_DATA(metadata_offset_array); + } + + if (clear_table) { + err = tsk_site_tbl_clear(table); + if (err != 0) { + handle_library_error(err); + goto out; + } + } + err = tsk_site_tbl_append_columns(table, num_rows, + PyArray_DATA(position_array), PyArray_DATA(ancestral_state_array), + PyArray_DATA(ancestral_state_offset_array), metadata_data, metadata_offset_data); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = 0; +out: + Py_XDECREF(position_array); + Py_XDECREF(ancestral_state_array); + Py_XDECREF(ancestral_state_offset_array); + Py_XDECREF(metadata_array); + Py_XDECREF(metadata_offset_array); + return ret; +} + +static int +parse_mutation_table_dict(tsk_mutation_tbl_t *table, PyObject *dict, bool clear_table) +{ + int err; + int ret = -1; + size_t num_rows = 0; + size_t derived_state_length = 0; + size_t metadata_length = 0; + PyObject *site_input = NULL; + PyArrayObject *site_array = NULL; + PyObject *derived_state_input = NULL; + PyArrayObject *derived_state_array = NULL; + PyObject *derived_state_offset_input = NULL; + PyArrayObject *derived_state_offset_array = NULL; + PyObject *node_input = NULL; + PyArrayObject *node_array = NULL; + PyObject *parent_input = NULL; + PyArrayObject *parent_array = NULL; + tsk_id_t *parent_data; + PyObject *metadata_input = NULL; + 
PyArrayObject *metadata_array = NULL; + PyObject *metadata_offset_input = NULL; + PyArrayObject *metadata_offset_array = NULL; + char *metadata_data; + uint32_t *metadata_offset_data; + + /* Get the input values */ + site_input = get_table_dict_value(dict, "site", true); + if (site_input == NULL) { + goto out; + } + node_input = get_table_dict_value(dict, "node", true); + if (node_input == NULL) { + goto out; + } + parent_input = get_table_dict_value(dict, "parent", false); + if (parent_input == NULL) { + goto out; + } + derived_state_input = get_table_dict_value(dict, "derived_state", true); + if (derived_state_input == NULL) { + goto out; + } + derived_state_offset_input = get_table_dict_value(dict, "derived_state_offset", true); + if (derived_state_offset_input == NULL) { + goto out; + } + metadata_input = get_table_dict_value(dict, "metadata", false); + if (metadata_input == NULL) { + goto out; + } + metadata_offset_input = get_table_dict_value(dict, "metadata_offset", false); + if (metadata_offset_input == NULL) { + goto out; + } + + /* Get the arrays */ + site_array = table_read_column_array(site_input, NPY_INT32, &num_rows, false); + if (site_array == NULL) { + goto out; + } + derived_state_array = table_read_column_array(derived_state_input, NPY_INT8, + &derived_state_length, false); + if (derived_state_array == NULL) { + goto out; + } + derived_state_offset_array = table_read_offset_array(derived_state_offset_input, + &num_rows, derived_state_length, true); + if (derived_state_offset_array == NULL) { + goto out; + } + node_array = table_read_column_array(node_input, NPY_INT32, &num_rows, true); + if (node_array == NULL) { + goto out; + } + + parent_data = NULL; + if (parent_input != Py_None) { + parent_array = table_read_column_array(parent_input, NPY_INT32, &num_rows, true); + if (parent_array == NULL) { + goto out; + } + parent_data = PyArray_DATA(parent_array); + } + + metadata_data = NULL; + metadata_offset_data = NULL; + if ((metadata_input == 
Py_None) != (metadata_offset_input == Py_None)) { + PyErr_SetString(PyExc_TypeError, + "metadata and metadata_offset must be specified together"); + goto out; + } + if (metadata_input != Py_None) { + metadata_array = table_read_column_array(metadata_input, NPY_INT8, + &metadata_length, false); + if (metadata_array == NULL) { + goto out; + } + metadata_data = PyArray_DATA(metadata_array); + metadata_offset_array = table_read_offset_array(metadata_offset_input, &num_rows, + metadata_length, false); + if (metadata_offset_array == NULL) { + goto out; + } + metadata_offset_data = PyArray_DATA(metadata_offset_array); + } + + if (clear_table) { + err = tsk_mutation_tbl_clear(table); + if (err != 0) { + handle_library_error(err); + goto out; + } + } + err = tsk_mutation_tbl_append_columns(table, num_rows, + PyArray_DATA(site_array), PyArray_DATA(node_array), + parent_data, PyArray_DATA(derived_state_array), + PyArray_DATA(derived_state_offset_array), + metadata_data, metadata_offset_data); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = 0; +out: + Py_XDECREF(site_array); + Py_XDECREF(derived_state_array); + Py_XDECREF(derived_state_offset_array); + Py_XDECREF(metadata_array); + Py_XDECREF(metadata_offset_array); + Py_XDECREF(node_array); + Py_XDECREF(parent_array); + return ret; +} + +static int +parse_population_table_dict(tsk_population_tbl_t *table, PyObject *dict, bool clear_table) +{ + int err; + int ret = -1; + size_t num_rows, metadata_length; + PyObject *metadata_input = NULL; + PyArrayObject *metadata_array = NULL; + PyObject *metadata_offset_input = NULL; + PyArrayObject *metadata_offset_array = NULL; + + /* Get the inputs */ + metadata_input = get_table_dict_value(dict, "metadata", true); + if (metadata_input == NULL) { + goto out; + } + metadata_offset_input = get_table_dict_value(dict, "metadata_offset", true); + if (metadata_offset_input == NULL) { + goto out; + } + + /* Get the arrays */ + metadata_array = 
table_read_column_array(metadata_input, NPY_INT8, + &metadata_length, false); + if (metadata_array == NULL) { + goto out; + } + metadata_offset_array = table_read_offset_array(metadata_offset_input, &num_rows, + metadata_length, false); + if (metadata_offset_array == NULL) { + goto out; + } + + if (clear_table) { + err = tsk_population_tbl_clear(table); + if (err != 0) { + handle_library_error(err); + goto out; + } + } + err = tsk_population_tbl_append_columns(table, num_rows, + PyArray_DATA(metadata_array), PyArray_DATA(metadata_offset_array)); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = 0; +out: + Py_XDECREF(metadata_array); + Py_XDECREF(metadata_offset_array); + return ret; +} + +static int +parse_provenance_table_dict(tsk_provenance_tbl_t *table, PyObject *dict, bool clear_table) +{ + int err; + int ret = -1; + size_t num_rows, timestamp_length, record_length; + PyObject *timestamp_input = NULL; + PyArrayObject *timestamp_array = NULL; + PyObject *timestamp_offset_input = NULL; + PyArrayObject *timestamp_offset_array = NULL; + PyObject *record_input = NULL; + PyArrayObject *record_array = NULL; + PyObject *record_offset_input = NULL; + PyArrayObject *record_offset_array = NULL; + + /* Get the inputs */ + timestamp_input = get_table_dict_value(dict, "timestamp", true); + if (timestamp_input == NULL) { + goto out; + } + timestamp_offset_input = get_table_dict_value(dict, "timestamp_offset", true); + if (timestamp_offset_input == NULL) { + goto out; + } + record_input = get_table_dict_value(dict, "record", true); + if (record_input == NULL) { + goto out; + } + record_offset_input = get_table_dict_value(dict, "record_offset", true); + if (record_offset_input == NULL) { + goto out; + } + + timestamp_array = table_read_column_array(timestamp_input, NPY_INT8, + ×tamp_length, false); + if (timestamp_array == NULL) { + goto out; + } + timestamp_offset_array = table_read_offset_array(timestamp_offset_input, &num_rows, + timestamp_length, 
false); + if (timestamp_offset_array == NULL) { + goto out; + } + record_array = table_read_column_array(record_input, NPY_INT8, + &record_length, false); + if (record_array == NULL) { + goto out; + } + record_offset_array = table_read_offset_array(record_offset_input, &num_rows, + record_length, true); + if (record_offset_array == NULL) { + goto out; + } + + if (clear_table) { + err = tsk_provenance_tbl_clear(table); + if (err != 0) { + handle_library_error(err); + goto out; + } + } + err = tsk_provenance_tbl_append_columns(table, num_rows, + PyArray_DATA(timestamp_array), PyArray_DATA(timestamp_offset_array), + PyArray_DATA(record_array), PyArray_DATA(record_offset_array)); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = 0; +out: + Py_XDECREF(timestamp_array); + Py_XDECREF(timestamp_offset_array); + Py_XDECREF(record_array); + Py_XDECREF(record_offset_array); + return ret; +} + +static int +parse_table_collection_dict(tsk_tbl_collection_t *tables, PyObject *tables_dict) +{ + int ret = -1; + PyObject *value = NULL; + + value = get_table_dict_value(tables_dict, "sequence_length", true); + if (value == NULL) { + goto out; + } + if (!PyNumber_Check(value)) { + PyErr_Format(PyExc_TypeError, "'sequence_length' is not number"); + goto out; + } + tables->sequence_length = PyFloat_AsDouble(value); + + /* individuals */ + value = get_table_dict_value(tables_dict, "individuals", true); + if (value == NULL) { + goto out; + } + if (!PyDict_Check(value)) { + PyErr_SetString(PyExc_TypeError, "not a dictionary"); + goto out; + } + if (parse_individual_table_dict(tables->individuals, value, true) != 0) { + goto out; + } + + /* nodes */ + value = get_table_dict_value(tables_dict, "nodes", true); + if (value == NULL) { + goto out; + } + if (!PyDict_Check(value)) { + PyErr_SetString(PyExc_TypeError, "not a dictionary"); + goto out; + } + if (parse_node_table_dict(tables->nodes, value, true) != 0) { + goto out; + } + + /* edges */ + value = 
get_table_dict_value(tables_dict, "edges", true); + if (value == NULL) { + goto out; + } + if (!PyDict_Check(value)) { + PyErr_SetString(PyExc_TypeError, "not a dictionary"); + goto out; + } + if (parse_edge_table_dict(tables->edges, value, true) != 0) { + goto out; + } + + /* migrations */ + value = get_table_dict_value(tables_dict, "migrations", true); + if (value == NULL) { + goto out; + } + if (!PyDict_Check(value)) { + PyErr_SetString(PyExc_TypeError, "not a dictionary"); + goto out; + } + if (parse_migration_table_dict(tables->migrations, value, true) != 0) { + goto out; + } + + /* sites */ + value = get_table_dict_value(tables_dict, "sites", true); + if (value == NULL) { + goto out; + } + if (!PyDict_Check(value)) { + PyErr_SetString(PyExc_TypeError, "not a dictionary"); + goto out; + } + if (parse_site_table_dict(tables->sites, value, true) != 0) { + goto out; + } + + /* mutations */ + value = get_table_dict_value(tables_dict, "mutations", true); + if (value == NULL) { + goto out; + } + if (!PyDict_Check(value)) { + PyErr_SetString(PyExc_TypeError, "not a dictionary"); + goto out; + } + if (parse_mutation_table_dict(tables->mutations, value, true) != 0) { + goto out; + } + + /* populations */ + value = get_table_dict_value(tables_dict, "populations", true); + if (value == NULL) { + goto out; + } + if (!PyDict_Check(value)) { + PyErr_SetString(PyExc_TypeError, "not a dictionary"); + goto out; + } + if (parse_population_table_dict(tables->populations, value, true) != 0) { + goto out; + } + + /* provenances */ + value = get_table_dict_value(tables_dict, "provenances", true); + if (value == NULL) { + goto out; + } + if (!PyDict_Check(value)) { + PyErr_SetString(PyExc_TypeError, "not a dictionary"); + goto out; + } + if (parse_provenance_table_dict(tables->provenances, value, true) != 0) { + goto out; + } + + ret = 0; +out: + return ret; +} + +static int +write_table_arrays(tsk_tbl_collection_t *tables, PyObject *dict) +{ + struct table_col { + const char *name; 
+ void *data; + npy_intp num_rows; + int type; + }; + struct table_desc { + const char *name; + struct table_col *cols; + }; + int ret = -1; + PyObject *array = NULL; + PyObject *table_dict = NULL; + size_t j; + struct table_col *col; + + struct table_col individual_cols[] = { + {"flags", + (void *) tables->individuals->flags, tables->individuals->num_rows, NPY_UINT32}, + {"location", + (void *) tables->individuals->location, tables->individuals->location_length, + NPY_FLOAT64}, + {"location_offset", + (void *) tables->individuals->location_offset, tables->individuals->num_rows + 1, + NPY_UINT32}, + {"metadata", + (void *) tables->individuals->metadata, tables->individuals->metadata_length, + NPY_INT8}, + {"metadata_offset", + (void *) tables->individuals->metadata_offset, tables->individuals->num_rows + 1, + NPY_UINT32}, + {NULL}, + }; + + struct table_col node_cols[] = { + {"time", + (void *) tables->nodes->time, tables->nodes->num_rows, NPY_FLOAT64}, + {"flags", + (void *) tables->nodes->flags, tables->nodes->num_rows, NPY_UINT32}, + {"population", + (void *) tables->nodes->population, tables->nodes->num_rows, NPY_INT32}, + {"individual", + (void *) tables->nodes->individual, tables->nodes->num_rows, NPY_INT32}, + {"metadata", + (void *) tables->nodes->metadata, tables->nodes->metadata_length, NPY_INT8}, + {"metadata_offset", + (void *) tables->nodes->metadata_offset, tables->nodes->num_rows + 1, NPY_UINT32}, + {NULL}, + }; + + struct table_col edge_cols[] = { + {"left", (void *) tables->edges->left, tables->edges->num_rows, NPY_FLOAT64}, + {"right", (void *) tables->edges->right, tables->edges->num_rows, NPY_FLOAT64}, + {"parent", (void *) tables->edges->parent, tables->edges->num_rows, NPY_INT32}, + {"child", (void *) tables->edges->child, tables->edges->num_rows, NPY_INT32}, + {NULL}, + }; + + struct table_col migration_cols[] = { + {"left", + (void *) tables->migrations->left, tables->migrations->num_rows, NPY_FLOAT64}, + {"right", + (void *) 
tables->migrations->right, tables->migrations->num_rows, NPY_FLOAT64}, + {"node", + (void *) tables->migrations->node, tables->migrations->num_rows, NPY_INT32}, + {"source", + (void *) tables->migrations->source, tables->migrations->num_rows, NPY_INT32}, + {"dest", + (void *) tables->migrations->dest, tables->migrations->num_rows, NPY_INT32}, + {"time", + (void *) tables->migrations->time, tables->migrations->num_rows, NPY_FLOAT64}, + {NULL}, + }; + + struct table_col site_cols[] = { + {"position", + (void *) tables->sites->position, tables->sites->num_rows, NPY_FLOAT64}, + {"ancestral_state", + (void *) tables->sites->ancestral_state, tables->sites->ancestral_state_length, + NPY_INT8}, + {"ancestral_state_offset", + (void *) tables->sites->ancestral_state_offset, tables->sites->num_rows + 1, + NPY_UINT32}, + {"metadata", + (void *) tables->sites->metadata, tables->sites->metadata_length, NPY_INT8}, + {"metadata_offset", + (void *) tables->sites->metadata_offset, tables->sites->num_rows + 1, NPY_UINT32}, + {NULL}, + }; + + struct table_col mutation_cols[] = { + {"site", + (void *) tables->mutations->site, tables->mutations->num_rows, NPY_INT32}, + {"node", + (void *) tables->mutations->node, tables->mutations->num_rows, NPY_INT32}, + {"parent", + (void *) tables->mutations->parent, tables->mutations->num_rows, NPY_INT32}, + {"derived_state", + (void *) tables->mutations->derived_state, + tables->mutations->derived_state_length, NPY_INT8}, + {"derived_state_offset", + (void *) tables->mutations->derived_state_offset, + tables->mutations->num_rows + 1, NPY_UINT32}, + {"metadata", + (void *) tables->mutations->metadata, + tables->mutations->metadata_length, NPY_INT8}, + {"metadata_offset", + (void *) tables->mutations->metadata_offset, + tables->mutations->num_rows + 1, NPY_UINT32}, + {NULL}, + }; + + struct table_col population_cols[] = { + {"metadata", (void *) tables->populations->metadata, + tables->populations->metadata_length, NPY_INT8}, + {"metadata_offset", 
(void *) tables->populations->metadata_offset, + tables->populations->num_rows+ 1, NPY_UINT32}, + {NULL}, + }; + + struct table_col provenance_cols[] = { + {"timestamp", (void *) tables->provenances->timestamp, + tables->provenances->timestamp_length, NPY_INT8}, + {"timestamp_offset", (void *) tables->provenances->timestamp_offset, + tables->provenances->num_rows+ 1, NPY_UINT32}, + {"record", (void *) tables->provenances->record, + tables->provenances->record_length, NPY_INT8}, + {"record_offset", (void *) tables->provenances->record_offset, + tables->provenances->num_rows + 1, NPY_UINT32}, + {NULL}, + }; + + struct table_desc table_descs[] = { + {"individuals", individual_cols}, + {"nodes", node_cols}, + {"edges", edge_cols}, + {"migrations", migration_cols}, + {"sites", site_cols}, + {"mutations", mutation_cols}, + {"populations", population_cols}, + {"provenances", provenance_cols}, + }; + + for (j = 0; j < sizeof(table_descs) / sizeof(*table_descs); j++) { + table_dict = PyDict_New(); + if (table_dict == NULL) { + goto out; + } + col = table_descs[j].cols; + while (col->name != NULL) { + array = PyArray_SimpleNewFromData(1, &col->num_rows, col->type, col->data); + if (array == NULL) { + goto out; + } + if (PyDict_SetItemString(table_dict, col->name, array) != 0) { + goto out; + } + Py_DECREF(array); + array = NULL; + col++; + } + if (PyDict_SetItemString(dict, table_descs[j].name, table_dict) != 0) { + goto out; + } + Py_DECREF(table_dict); + table_dict = NULL; + } + ret = 0; +out: + Py_XDECREF(array); + Py_XDECREF(table_dict); + return ret; +} + +/* Returns a dictionary encoding of the specified table collection */ +static PyObject* +dump_tables_dict(tsk_tbl_collection_t *tables) +{ + PyObject *ret = NULL; + PyObject *dict = NULL; + PyObject *val = NULL; + int err; + + dict = PyDict_New(); + if (dict == NULL) { + goto out; + } + val = Py_BuildValue("d", tables->sequence_length); + if (val == NULL) { + goto out; + } + if (PyDict_SetItemString(dict, 
"sequence_length", val) != 0) { + goto out; + } + Py_DECREF(val); + val = NULL; + + err = write_table_arrays(tables, dict); + if (err != 0) { + goto out; + } + ret = dict; + dict = NULL; +out: + Py_XDECREF(dict); + Py_XDECREF(val); + return ret; +} + +/*=================================================================== + * LightweightTableCollection + *=================================================================== + */ + +static int +LightweightTableCollection_check_state(LightweightTableCollection *self) +{ + int ret = 0; + if (self->tables == NULL) { + PyErr_SetString(PyExc_SystemError, "LightweightTableCollection not initialised"); + ret = -1; + } + return ret; +} + +static void +LightweightTableCollection_dealloc(LightweightTableCollection* self) +{ + if (self->tables != NULL) { + tsk_tbl_collection_free(self->tables); + PyMem_Free(self->tables); + self->tables = NULL; + } + Py_TYPE(self)->tp_free((PyObject*)self); +} + +static int +LightweightTableCollection_init(LightweightTableCollection *self, PyObject *args, PyObject *kwds) +{ + int ret = -1; + int err; + static char *kwlist[] = {"sequence_length", NULL}; + double sequence_length = -1; + + self->tables = NULL; + if (!PyArg_ParseTupleAndKeywords(args, kwds, "|d", kwlist, &sequence_length)) { + goto out; + } + self->tables = PyMem_Malloc(sizeof(*self->tables)); + if (self->tables == NULL) { + PyErr_NoMemory(); + goto out; + } + err = tsk_tbl_collection_alloc(self->tables, 0); + if (err != 0) { + handle_library_error(err); + goto out; + } + self->tables->sequence_length = sequence_length; + ret = 0; +out: + return ret; +} + +static PyObject * +LightweightTableCollection_asdict(LightweightTableCollection *self) +{ + PyObject *ret = NULL; + + if (LightweightTableCollection_check_state(self) != 0) { + goto out; + } + ret = dump_tables_dict(self->tables); +out: + return ret; +} + +static PyObject * +LightweightTableCollection_fromdict(LightweightTableCollection *self, PyObject *args) +{ + int err; + 
PyObject *ret = NULL; + PyObject *dict = NULL; + + if (!PyArg_ParseTuple(args, "O!", &PyDict_Type, &dict)) { + goto out; + } + err = parse_table_collection_dict(self->tables, dict); + if (err != 0) { + goto out; + } + ret = Py_BuildValue(""); +out: + return ret; +} + +static PyMemberDef LightweightTableCollection_members[] = { + {NULL} /* Sentinel */ +}; + +static PyMethodDef LightweightTableCollection_methods[] = { + {"asdict", (PyCFunction) LightweightTableCollection_asdict, + METH_NOARGS, "Returns the tables encoded as a dictionary."}, + {"fromdict", (PyCFunction) LightweightTableCollection_fromdict, + METH_VARARGS, "Populates the internal tables using the specified dictionary."}, + {NULL} /* Sentinel */ +}; + +static PyTypeObject LightweightTableCollectionType = { + PyVarObject_HEAD_INIT(NULL, 0) + "_msprime.LightweightTableCollection", /* tp_name */ + sizeof(LightweightTableCollection), /* tp_basicsize */ + 0, /* tp_itemsize */ + (destructor)LightweightTableCollection_dealloc, /* tp_dealloc */ + 0, /* tp_print */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_reserved */ + 0, /* tp_repr */ + 0, /* tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + 0, /* tp_str */ + 0, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT, /* tp_flags */ + "LightweightTableCollection objects", /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + LightweightTableCollection_methods, /* tp_methods */ + LightweightTableCollection_members, /* tp_members */ + 0, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + (initproc)LightweightTableCollection_init, /* tp_init */ +}; + +/*=================================================================== + * IndividualTable + 
*=================================================================== + */ + +static int +IndividualTable_check_state(IndividualTable *self) +{ + int ret = -1; + if (self->table == NULL) { + PyErr_SetString(PyExc_SystemError, "IndividualTable not initialised"); + goto out; + } + if (self->locked) { + PyErr_SetString(PyExc_RuntimeError, "IndividualTable in use by other thread."); + goto out; + } + ret = 0; +out: + return ret; +} + +static void +IndividualTable_dealloc(IndividualTable* self) +{ + if (self->tables != NULL) { + Py_DECREF(self->tables); + } else if (self->table != NULL) { + tsk_individual_tbl_free(self->table); + PyMem_Free(self->table); + self->table = NULL; + } + Py_TYPE(self)->tp_free((PyObject*)self); +} + +static int +IndividualTable_init(IndividualTable *self, PyObject *args, PyObject *kwds) +{ + int ret = -1; + int err; + static char *kwlist[] = {"max_rows_increment", NULL}; + Py_ssize_t max_rows_increment = 0; + + self->table = NULL; + self->locked = false; + if (!PyArg_ParseTupleAndKeywords(args, kwds, "|n", kwlist, + &max_rows_increment)) { + goto out; + } + if (max_rows_increment < 0) { + PyErr_SetString(PyExc_ValueError, "max_rows_increment must be positive"); + goto out; + } + self->table = PyMem_Malloc(sizeof(tsk_individual_tbl_t)); + if (self->table == NULL) { + PyErr_NoMemory(); + goto out; + } + err = tsk_individual_tbl_alloc(self->table, 0); + if (err != 0) { + handle_library_error(err); + goto out; + } + tsk_individual_tbl_set_max_rows_increment(self->table, max_rows_increment); + ret = 0; +out: + return ret; +} + +static PyObject * +IndividualTable_add_row(IndividualTable *self, PyObject *args, PyObject *kwds) +{ + PyObject *ret = NULL; + int err; + unsigned int flags = 0; + PyObject *py_metadata = Py_None; + PyObject *py_location = Py_None; + PyArrayObject *location_array = NULL; + double *location_data = NULL; + tsk_tbl_size_t location_length = 0; + char *metadata = ""; + Py_ssize_t metadata_length = 0; + npy_intp *shape; + static 
char *kwlist[] = {"flags", "location", "metadata", NULL}; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "|iOO", kwlist, + &flags, &py_location, &py_metadata)) { + goto out; + } + if (IndividualTable_check_state(self) != 0) { + goto out; + } + if (py_metadata != Py_None) { + if (PyBytes_AsStringAndSize(py_metadata, &metadata, &metadata_length) < 0) { + goto out; + } + } + if (py_location != Py_None) { + /* This ensures that only 1D arrays are accepted. */ + location_array = (PyArrayObject *) PyArray_FromAny(py_location, + PyArray_DescrFromType(NPY_FLOAT64), 1, 1, + NPY_ARRAY_IN_ARRAY, NULL); + if (location_array == NULL) { + goto out; + } + shape = PyArray_DIMS(location_array); + location_length = (tsk_tbl_size_t) shape[0]; + location_data = PyArray_DATA(location_array); + } + err = tsk_individual_tbl_add_row(self->table, (uint32_t) flags, + location_data, location_length, metadata, metadata_length); + if (err < 0) { + handle_library_error(err); + goto out; + } + ret = Py_BuildValue("i", err); +out: + Py_XDECREF(location_array); + return ret; +} + + +/* Forward declaration */ +static PyTypeObject IndividualTableType; + +static PyObject * +IndividualTable_equals(IndividualTable *self, PyObject *args) +{ + PyObject *ret = NULL; + IndividualTable *other = NULL; + + if (IndividualTable_check_state(self) != 0) { + goto out; + } + if (!PyArg_ParseTuple(args, "O!", &IndividualTableType, &other)) { + goto out; + } + ret = Py_BuildValue("i", tsk_individual_tbl_equals(self->table, other->table)); +out: + return ret; +} + +static PyObject * +IndividualTable_get_row(IndividualTable *self, PyObject *args) +{ + PyObject *ret = NULL; + int err; + Py_ssize_t row_id; + tsk_individual_t individual; + + if (IndividualTable_check_state(self) != 0) { + goto out; + } + if (!PyArg_ParseTuple(args, "n", &row_id)) { + goto out; + } + err = tsk_individual_tbl_get_row(self->table, (size_t) row_id, &individual); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = 
make_individual_row(&individual); +out: + return ret; +} + +static PyObject * +IndividualTable_parse_dict_arg(IndividualTable *self, PyObject *args, bool clear_table) +{ + int err; + PyObject *ret = NULL; + PyObject *dict = NULL; + + if (!PyArg_ParseTuple(args, "O!", &PyDict_Type, &dict)) { + goto out; + } + if (IndividualTable_check_state(self) != 0) { + goto out; + } + err = parse_individual_table_dict(self->table, dict, clear_table); + if (err != 0) { + goto out; + } + ret = Py_BuildValue(""); +out: + return ret; +} + +static PyObject * +IndividualTable_append_columns(IndividualTable *self, PyObject *args) +{ + return IndividualTable_parse_dict_arg(self, args, false); +} + +static PyObject * +IndividualTable_set_columns(IndividualTable *self, PyObject *args) +{ + return IndividualTable_parse_dict_arg(self, args, true); +} + +static PyObject * +IndividualTable_clear(IndividualTable *self) +{ + PyObject *ret = NULL; + int err; + + if (IndividualTable_check_state(self) != 0) { + goto out; + } + err = tsk_individual_tbl_clear(self->table); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = Py_BuildValue(""); +out: + return ret; +} + +static PyObject * +IndividualTable_truncate(IndividualTable *self, PyObject *args) +{ + PyObject *ret = NULL; + Py_ssize_t num_rows; + int err; + + if (IndividualTable_check_state(self) != 0) { + goto out; + } + if (!PyArg_ParseTuple(args, "n", &num_rows)) { + goto out; + } + if (num_rows < 0 || num_rows > (Py_ssize_t) self->table->num_rows) { + PyErr_SetString(PyExc_ValueError, "num_rows out of bounds"); + goto out; + } + err = tsk_individual_tbl_truncate(self->table, (size_t) num_rows); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = Py_BuildValue(""); +out: + return ret; +} + +static PyObject * +IndividualTable_get_max_rows_increment(IndividualTable *self, void *closure) +{ + PyObject *ret = NULL; + if (IndividualTable_check_state(self) != 0) { + goto out; + } + ret = Py_BuildValue("n", 
(Py_ssize_t) self->table->max_rows_increment); +out: + return ret; +} + +static PyObject * +IndividualTable_get_num_rows(IndividualTable *self, void *closure) +{ + PyObject *ret = NULL; + if (IndividualTable_check_state(self) != 0) { + goto out; + } + ret = Py_BuildValue("n", (Py_ssize_t) self->table->num_rows); +out: + return ret; +} + +static PyObject * +IndividualTable_get_max_rows(IndividualTable *self, void *closure) +{ + PyObject *ret = NULL; + if (IndividualTable_check_state(self) != 0) { + goto out; + } + ret = Py_BuildValue("n", (Py_ssize_t) self->table->max_rows); +out: + return ret; +} + +static PyObject * +IndividualTable_get_flags(IndividualTable *self, void *closure) +{ + PyObject *ret = NULL; + + if (IndividualTable_check_state(self) != 0) { + goto out; + } + ret = table_get_column_array(self->table->num_rows, self->table->flags, + NPY_UINT32, sizeof(uint32_t)); +out: + return ret; +} + +static PyObject * +IndividualTable_get_location(IndividualTable *self, void *closure) +{ + PyObject *ret = NULL; + + if (IndividualTable_check_state(self) != 0) { + goto out; + } + ret = table_get_column_array(self->table->location_length, + self->table->location, NPY_FLOAT64, sizeof(double)); +out: + return ret; +} + +static PyObject * +IndividualTable_get_location_offset(IndividualTable *self, void *closure) +{ + PyObject *ret = NULL; + + if (IndividualTable_check_state(self) != 0) { + goto out; + } + ret = table_get_column_array(self->table->num_rows + 1, + self->table->location_offset, NPY_UINT32, sizeof(uint32_t)); +out: + return ret; +} + +static PyObject * +IndividualTable_get_metadata(IndividualTable *self, void *closure) +{ + PyObject *ret = NULL; + + if (IndividualTable_check_state(self) != 0) { + goto out; + } + ret = table_get_column_array(self->table->metadata_length, + self->table->metadata, NPY_INT8, sizeof(char)); +out: + return ret; +} + +static PyObject * +IndividualTable_get_metadata_offset(IndividualTable *self, void *closure) +{ + PyObject *ret = 
NULL; + + if (IndividualTable_check_state(self) != 0) { + goto out; + } + ret = table_get_column_array(self->table->num_rows + 1, + self->table->metadata_offset, NPY_UINT32, sizeof(uint32_t)); +out: + return ret; +} + +static PyGetSetDef IndividualTable_getsetters[] = { + {"max_rows_increment", + (getter) IndividualTable_get_max_rows_increment, NULL, "The size increment"}, + {"num_rows", (getter) IndividualTable_get_num_rows, NULL, + "The number of rows in the table."}, + {"max_rows", (getter) IndividualTable_get_max_rows, NULL, + "The current maximum number of rows in the table."}, + {"flags", (getter) IndividualTable_get_flags, NULL, "The flags array"}, + {"location", (getter) IndividualTable_get_location, NULL, "The location array"}, + {"location_offset", (getter) IndividualTable_get_location_offset, NULL, + "The location offset array"}, + {"metadata", (getter) IndividualTable_get_metadata, NULL, "The metadata array"}, + {"metadata_offset", (getter) IndividualTable_get_metadata_offset, NULL, + "The metadata offset array"}, + {NULL} /* Sentinel */ +}; + +static PyMethodDef IndividualTable_methods[] = { + {"add_row", (PyCFunction) IndividualTable_add_row, METH_VARARGS|METH_KEYWORDS, + "Adds a new row to this table."}, + {"get_row", (PyCFunction) IndividualTable_get_row, METH_VARARGS, + "Returns the kth row in this table."}, + {"equals", (PyCFunction) IndividualTable_equals, METH_VARARGS, + "Returns true if the specified individual table is equal."}, + {"append_columns", (PyCFunction) IndividualTable_append_columns, + METH_VARARGS|METH_KEYWORDS, + "Appends the data in the specified arrays into the columns."}, + {"set_columns", (PyCFunction) IndividualTable_set_columns, METH_VARARGS|METH_KEYWORDS, + "Copies the data in the specified arrays into the columns."}, + {"clear", (PyCFunction) IndividualTable_clear, METH_NOARGS, + "Clears this table."}, + {"truncate", (PyCFunction) IndividualTable_truncate, METH_VARARGS, + "Truncates this table to the specified number of 
rows."}, + {NULL} /* Sentinel */ +}; + +static PyTypeObject IndividualTableType = { + PyVarObject_HEAD_INIT(NULL, 0) + "_tskit.IndividualTable", /* tp_name */ + sizeof(IndividualTable), /* tp_basicsize */ + 0, /* tp_itemsize */ + (destructor)IndividualTable_dealloc, /* tp_dealloc */ + 0, /* tp_print */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_reserved */ + 0, /* tp_repr */ + 0, /* tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + 0, /* tp_str */ + 0, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT | + Py_TPFLAGS_BASETYPE, /* tp_flags */ + "IndividualTable objects", /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + IndividualTable_methods, /* tp_methods */ + 0, /* tp_members */ + IndividualTable_getsetters, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + (initproc)IndividualTable_init, /* tp_init */ +}; + + +/*=================================================================== + * NodeTable + *=================================================================== + */ + +static int +NodeTable_check_state(NodeTable *self) +{ + int ret = -1; + if (self->table == NULL) { + PyErr_SetString(PyExc_SystemError, "NodeTable not initialised"); + goto out; + } + if (self->locked) { + PyErr_SetString(PyExc_RuntimeError, "NodeTable in use by other thread."); + goto out; + } + ret = 0; +out: + return ret; +} + +static void +NodeTable_dealloc(NodeTable* self) +{ + if (self->tables != NULL) { + Py_DECREF(self->tables); + } else if (self->table != NULL) { + tsk_node_tbl_free(self->table); + PyMem_Free(self->table); + self->table = NULL; + } + Py_TYPE(self)->tp_free((PyObject*)self); +} + +static int +NodeTable_init(NodeTable *self, PyObject *args, PyObject *kwds) +{ + int ret = -1; + int err; + 
static char *kwlist[] = {"max_rows_increment", NULL}; + Py_ssize_t max_rows_increment = 0; + + self->table = NULL; + self->locked = false; + self->tables = NULL; + if (!PyArg_ParseTupleAndKeywords(args, kwds, "|n", kwlist, + &max_rows_increment)) { + goto out; + } + if (max_rows_increment < 0) { + PyErr_SetString(PyExc_ValueError, "max_rows_increment must be positive"); + goto out; + } + self->table = PyMem_Malloc(sizeof(tsk_node_tbl_t)); + if (self->table == NULL) { + PyErr_NoMemory(); + goto out; + } + err = tsk_node_tbl_alloc(self->table, 0); + if (err != 0) { + handle_library_error(err); + goto out; + } + tsk_node_tbl_set_max_rows_increment(self->table, max_rows_increment); + ret = 0; +out: + return ret; +} + +static PyObject * +NodeTable_add_row(NodeTable *self, PyObject *args, PyObject *kwds) +{ + PyObject *ret = NULL; + int err; + unsigned int flags = 0; + double time = 0; + int population = -1; + int individual = -1; + PyObject *py_metadata = Py_None; + char *metadata = ""; + Py_ssize_t metadata_length = 0; + static char *kwlist[] = {"flags", "time", "population", "individual", "metadata", NULL}; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "|idiiO", kwlist, + &flags, &time, &population, &individual, &py_metadata)) { + goto out; + } + if (NodeTable_check_state(self) != 0) { + goto out; + } + if (py_metadata != Py_None) { + if (PyBytes_AsStringAndSize(py_metadata, &metadata, &metadata_length) < 0) { + goto out; + } + } + err = tsk_node_tbl_add_row(self->table, (uint32_t) flags, time, + (tsk_id_t) population, individual, metadata, metadata_length); + if (err < 0) { + handle_library_error(err); + goto out; + } + ret = Py_BuildValue("i", err); +out: + return ret; +} + +/* Forward declaration */ +static PyTypeObject NodeTableType; + +static PyObject * +NodeTable_equals(NodeTable *self, PyObject *args) +{ + PyObject *ret = NULL; + NodeTable *other = NULL; + + if (NodeTable_check_state(self) != 0) { + goto out; + } + if (!PyArg_ParseTuple(args, "O!", 
&NodeTableType, &other)) { + goto out; + } + ret = Py_BuildValue("i", tsk_node_tbl_equals(self->table, other->table)); +out: + return ret; +} + +static PyObject * +NodeTable_get_row(NodeTable *self, PyObject *args) +{ + PyObject *ret = NULL; + int err; + Py_ssize_t row_id; + tsk_node_t node; + + if (NodeTable_check_state(self) != 0) { + goto out; + } + if (!PyArg_ParseTuple(args, "n", &row_id)) { + goto out; + } + err = tsk_node_tbl_get_row(self->table, (size_t) row_id, &node); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = make_node(&node); +out: + return ret; +} + +static PyObject * +NodeTable_parse_dict_arg(NodeTable *self, PyObject *args, bool clear_table) +{ + int err; + PyObject *ret = NULL; + PyObject *dict = NULL; + + if (!PyArg_ParseTuple(args, "O!", &PyDict_Type, &dict)) { + goto out; + } + if (NodeTable_check_state(self) != 0) { + goto out; + } + err = parse_node_table_dict(self->table, dict, clear_table); + if (err != 0) { + goto out; + } + ret = Py_BuildValue(""); +out: + return ret; +} + +static PyObject * +NodeTable_append_columns(NodeTable *self, PyObject *args) +{ + return NodeTable_parse_dict_arg(self, args, false); +} + +static PyObject * +NodeTable_set_columns(NodeTable *self, PyObject *args) +{ + return NodeTable_parse_dict_arg(self, args, true); +} + +static PyObject * +NodeTable_clear(NodeTable *self) +{ + PyObject *ret = NULL; + int err; + + if (NodeTable_check_state(self) != 0) { + goto out; + } + err = tsk_node_tbl_clear(self->table); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = Py_BuildValue(""); +out: + return ret; +} + +static PyObject * +NodeTable_truncate(NodeTable *self, PyObject *args) +{ + PyObject *ret = NULL; + Py_ssize_t num_rows; + int err; + + if (NodeTable_check_state(self) != 0) { + goto out; + } + if (!PyArg_ParseTuple(args, "n", &num_rows)) { + goto out; + } + if (num_rows < 0 || num_rows > (Py_ssize_t) self->table->num_rows) { + PyErr_SetString(PyExc_ValueError, "num_rows 
out of bounds"); + goto out; + } + err = tsk_node_tbl_truncate(self->table, (size_t) num_rows); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = Py_BuildValue(""); +out: + return ret; +} + +static PyObject * +NodeTable_get_max_rows_increment(NodeTable *self, void *closure) +{ + PyObject *ret = NULL; + if (NodeTable_check_state(self) != 0) { + goto out; + } + ret = Py_BuildValue("n", (Py_ssize_t) self->table->max_rows_increment); +out: + return ret; +} + +static PyObject * +NodeTable_get_num_rows(NodeTable *self, void *closure) +{ + PyObject *ret = NULL; + if (NodeTable_check_state(self) != 0) { + goto out; + } + ret = Py_BuildValue("n", (Py_ssize_t) self->table->num_rows); +out: + return ret; +} + +static PyObject * +NodeTable_get_max_rows(NodeTable *self, void *closure) +{ + PyObject *ret = NULL; + if (NodeTable_check_state(self) != 0) { + goto out; + } + ret = Py_BuildValue("n", (Py_ssize_t) self->table->max_rows); +out: + return ret; +} + +static PyObject * +NodeTable_get_time(NodeTable *self, void *closure) +{ + PyObject *ret = NULL; + + if (NodeTable_check_state(self) != 0) { + goto out; + } + ret = table_get_column_array(self->table->num_rows, self->table->time, + NPY_FLOAT64, sizeof(double)); +out: + return ret; +} + +static PyObject * +NodeTable_get_flags(NodeTable *self, void *closure) +{ + PyObject *ret = NULL; + + if (NodeTable_check_state(self) != 0) { + goto out; + } + ret = table_get_column_array(self->table->num_rows, self->table->flags, + NPY_UINT32, sizeof(uint32_t)); +out: + return ret; +} + +static PyObject * +NodeTable_get_population(NodeTable *self, void *closure) +{ + PyObject *ret = NULL; + + if (NodeTable_check_state(self) != 0) { + goto out; + } + ret = table_get_column_array(self->table->num_rows, self->table->population, + NPY_INT32, sizeof(int32_t)); +out: + return ret; +} + +static PyObject * +NodeTable_get_individual(NodeTable *self, void *closure) +{ + PyObject *ret = NULL; + + if (NodeTable_check_state(self) != 
0) { + goto out; + } + ret = table_get_column_array(self->table->num_rows, self->table->individual, + NPY_INT32, sizeof(int32_t)); +out: + return ret; +} + +static PyObject * +NodeTable_get_metadata(NodeTable *self, void *closure) +{ + PyObject *ret = NULL; + + if (NodeTable_check_state(self) != 0) { + goto out; + } + ret = table_get_column_array(self->table->metadata_length, + self->table->metadata, NPY_INT8, sizeof(char)); +out: + return ret; +} + +static PyObject * +NodeTable_get_metadata_offset(NodeTable *self, void *closure) +{ + PyObject *ret = NULL; + + if (NodeTable_check_state(self) != 0) { + goto out; + } + ret = table_get_column_array(self->table->num_rows + 1, + self->table->metadata_offset, NPY_UINT32, sizeof(uint32_t)); +out: + return ret; +} + +static PyGetSetDef NodeTable_getsetters[] = { + {"max_rows_increment", + (getter) NodeTable_get_max_rows_increment, NULL, "The size increment"}, + {"num_rows", (getter) NodeTable_get_num_rows, NULL, + "The number of rows in the table."}, + {"max_rows", (getter) NodeTable_get_max_rows, NULL, + "The current maximum number of rows in the table."}, + {"time", (getter) NodeTable_get_time, NULL, "The time array"}, + {"flags", (getter) NodeTable_get_flags, NULL, "The flags array"}, + {"population", (getter) NodeTable_get_population, NULL, "The population array"}, + {"individual", (getter) NodeTable_get_individual, NULL, "The individual array"}, + {"metadata", (getter) NodeTable_get_metadata, NULL, "The metadata array"}, + {"metadata_offset", (getter) NodeTable_get_metadata_offset, NULL, + "The metadata offset array"}, + {NULL} /* Sentinel */ +}; + +static PyMethodDef NodeTable_methods[] = { + {"add_row", (PyCFunction) NodeTable_add_row, METH_VARARGS|METH_KEYWORDS, + "Adds a new row to this table."}, + {"equals", (PyCFunction) NodeTable_equals, METH_VARARGS, + "Returns True if the specified NodeTable is equal to this one."}, + {"get_row", (PyCFunction) NodeTable_get_row, METH_VARARGS, + "Returns the kth row in this 
table."}, + {"append_columns", (PyCFunction) NodeTable_append_columns, METH_VARARGS, + "Appends the data in the specified arrays into the columns."}, + {"set_columns", (PyCFunction) NodeTable_set_columns, METH_VARARGS, + "Copies the data in the specified arrays into the columns."}, + {"clear", (PyCFunction) NodeTable_clear, METH_NOARGS, + "Clears this table."}, + {"truncate", (PyCFunction) NodeTable_truncate, METH_VARARGS, + "Truncates this table to the specified number of rows."}, + {NULL} /* Sentinel */ +}; + +static PyTypeObject NodeTableType = { + PyVarObject_HEAD_INIT(NULL, 0) + "_tskit.NodeTable", /* tp_name */ + sizeof(NodeTable), /* tp_basicsize */ + 0, /* tp_itemsize */ + (destructor)NodeTable_dealloc, /* tp_dealloc */ + 0, /* tp_print */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_reserved */ + 0, /* tp_repr */ + 0, /* tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + 0, /* tp_str */ + 0, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT | + Py_TPFLAGS_BASETYPE, /* tp_flags */ + "NodeTable objects", /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + NodeTable_methods, /* tp_methods */ + 0, /* tp_members */ + NodeTable_getsetters, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + (initproc)NodeTable_init, /* tp_init */ +}; + +/*=================================================================== + * EdgeTable + *=================================================================== + */ + +static int +EdgeTable_check_state(EdgeTable *self) +{ + int ret = -1; + if (self->table == NULL) { + PyErr_SetString(PyExc_SystemError, "EdgeTable not initialised"); + goto out; + } + if (self->locked) { + PyErr_SetString(PyExc_RuntimeError, "EdgeTable in use by other thread."); + goto out; + 
} + ret = 0; +out: + return ret; +} + +static void +EdgeTable_dealloc(EdgeTable* self) +{ + if (self->tables != NULL) { + Py_DECREF(self->tables); + } else if (self->table != NULL) { + tsk_edge_tbl_free(self->table); + PyMem_Free(self->table); + self->table = NULL; + } + Py_TYPE(self)->tp_free((PyObject*)self); +} + +static int +EdgeTable_init(EdgeTable *self, PyObject *args, PyObject *kwds) +{ + int ret = -1; + int err; + static char *kwlist[] = {"max_rows_increment", NULL}; + Py_ssize_t max_rows_increment = 0; + + self->table = NULL; + if (!PyArg_ParseTupleAndKeywords(args, kwds, "|n", kwlist, &max_rows_increment)) { + goto out; + } + if (max_rows_increment < 0) { + PyErr_SetString(PyExc_ValueError, "max_rows_increment must be positive"); + goto out; + } + self->table = PyMem_Malloc(sizeof(tsk_edge_tbl_t)); + if (self->table == NULL) { + PyErr_NoMemory(); + goto out; + } + err = tsk_edge_tbl_alloc(self->table, 0); + if (err != 0) { + handle_library_error(err); + goto out; + } + tsk_edge_tbl_set_max_rows_increment(self->table, max_rows_increment); + ret = 0; +out: + return ret; +} + +static PyObject * +EdgeTable_add_row(EdgeTable *self, PyObject *args, PyObject *kwds) +{ + PyObject *ret = NULL; + int err; + double left = 0.0; + double right = 1.0; + int parent; + int child; + static char *kwlist[] = {"left", "right", "parent", "child", NULL}; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "ddii", kwlist, + &left, &right, &parent, &child)) { + goto out; + } + if (EdgeTable_check_state(self) != 0) { + goto out; + } + err = tsk_edge_tbl_add_row(self->table, left, right, parent, child); + if (err < 0) { + handle_library_error(err); + goto out; + } + ret = Py_BuildValue("i", err); +out: + return ret; +} + +/* Forward declaration */ +static PyTypeObject EdgeTableType; + +static PyObject * +EdgeTable_equals(EdgeTable *self, PyObject *args) +{ + PyObject *ret = NULL; + EdgeTable *other = NULL; + + if (EdgeTable_check_state(self) != 0) { + goto out; + } + if 
(!PyArg_ParseTuple(args, "O!", &EdgeTableType, &other)) { + goto out; + } + ret = Py_BuildValue("i", tsk_edge_tbl_equals(self->table, other->table)); +out: + return ret; +} + +static PyObject * +EdgeTable_get_row(EdgeTable *self, PyObject *args) +{ + PyObject *ret = NULL; + Py_ssize_t row_id; + int err; + tsk_edge_t edge; + + if (EdgeTable_check_state(self) != 0) { + goto out; + } + if (!PyArg_ParseTuple(args, "n", &row_id)) { + goto out; + } + err = tsk_edge_tbl_get_row(self->table, (size_t) row_id, &edge); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = make_edge(&edge); +out: + return ret; +} + +static PyObject * +EdgeTable_parse_dict_arg(EdgeTable *self, PyObject *args, bool clear_table) +{ + int err; + PyObject *ret = NULL; + PyObject *dict = NULL; + + if (!PyArg_ParseTuple(args, "O!", &PyDict_Type, &dict)) { + goto out; + } + if (EdgeTable_check_state(self) != 0) { + goto out; + } + err = parse_edge_table_dict(self->table, dict, clear_table); + if (err != 0) { + goto out; + } + ret = Py_BuildValue(""); +out: + return ret; +} + +static PyObject * +EdgeTable_append_columns(EdgeTable *self, PyObject *args) +{ + return EdgeTable_parse_dict_arg(self, args, false); +} + +static PyObject * +EdgeTable_set_columns(EdgeTable *self, PyObject *args) +{ + return EdgeTable_parse_dict_arg(self, args, true); +} + +static PyObject * +EdgeTable_clear(EdgeTable *self) +{ + PyObject *ret = NULL; + int err; + + if (EdgeTable_check_state(self) != 0) { + goto out; + } + err = tsk_edge_tbl_clear(self->table); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = Py_BuildValue(""); +out: + return ret; +} + +static PyObject * +EdgeTable_truncate(EdgeTable *self, PyObject *args) +{ + PyObject *ret = NULL; + Py_ssize_t num_rows; + int err; + + if (EdgeTable_check_state(self) != 0) { + goto out; + } + if (!PyArg_ParseTuple(args, "n", &num_rows)) { + goto out; + } + if (num_rows < 0 || num_rows > (Py_ssize_t) self->table->num_rows) { + 
PyErr_SetString(PyExc_ValueError, "num_rows out of bounds"); + goto out; + } + err = tsk_edge_tbl_truncate(self->table, (size_t) num_rows); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = Py_BuildValue(""); +out: + return ret; +} + +static PyObject * +EdgeTable_get_max_rows_increment(EdgeTable *self, void *closure) +{ + PyObject *ret = NULL; + if (EdgeTable_check_state(self) != 0) { + goto out; + } + ret = Py_BuildValue("n", (Py_ssize_t) self->table->max_rows_increment); +out: + return ret; +} + +static PyObject * +EdgeTable_get_num_rows(EdgeTable *self, void *closure) +{ + PyObject *ret = NULL; + if (EdgeTable_check_state(self) != 0) { + goto out; + } + ret = Py_BuildValue("n", (Py_ssize_t) self->table->num_rows); +out: + return ret; +} + +static PyObject * +EdgeTable_get_max_rows(EdgeTable *self, void *closure) +{ + PyObject *ret = NULL; + if (EdgeTable_check_state(self) != 0) { + goto out; + } + ret = Py_BuildValue("n", (Py_ssize_t) self->table->max_rows); +out: + return ret; +} + +static PyObject * +EdgeTable_get_left(EdgeTable *self, void *closure) +{ + PyObject *ret = NULL; + + if (EdgeTable_check_state(self) != 0) { + goto out; + } + ret = table_get_column_array( + self->table->num_rows, self->table->left, NPY_FLOAT64, + sizeof(double)); +out: + return ret; +} + +static PyObject * +EdgeTable_get_right(EdgeTable *self, void *closure) +{ + PyObject *ret = NULL; + + if (EdgeTable_check_state(self) != 0) { + goto out; + } + ret = table_get_column_array( + self->table->num_rows, self->table->right, NPY_FLOAT64, + sizeof(double)); +out: + return ret; +} + +static PyObject * +EdgeTable_get_parent(EdgeTable *self, void *closure) +{ + PyObject *ret = NULL; + + if (EdgeTable_check_state(self) != 0) { + goto out; + } + ret = table_get_column_array( + self->table->num_rows, self->table->parent, NPY_INT32, + sizeof(int32_t)); +out: + return ret; +} + +static PyObject * +EdgeTable_get_child(EdgeTable *self, void *closure) +{ + PyObject *ret = NULL; 
+ + if (EdgeTable_check_state(self) != 0) { + goto out; + } + ret = table_get_column_array( + self->table->num_rows, self->table->child, NPY_INT32, + sizeof(int32_t)); +out: + return ret; +} + +static PyGetSetDef EdgeTable_getsetters[] = { + {"max_rows_increment", + (getter) EdgeTable_get_max_rows_increment, NULL, + "The size increment"}, + {"num_rows", (getter) EdgeTable_get_num_rows, NULL, + "The number of rows in the table."}, + {"max_rows", (getter) EdgeTable_get_max_rows, NULL, + "The current maximum number of rows in the table."}, + {"left", (getter) EdgeTable_get_left, NULL, "The left array"}, + {"right", (getter) EdgeTable_get_right, NULL, "The right array"}, + {"parent", (getter) EdgeTable_get_parent, NULL, "The parent array"}, + {"child", (getter) EdgeTable_get_child, NULL, "The child array"}, + {NULL} /* Sentinel */ +}; + +static PyMethodDef EdgeTable_methods[] = { + {"add_row", (PyCFunction) EdgeTable_add_row, METH_VARARGS|METH_KEYWORDS, + "Adds a new row to this table."}, + {"equals", (PyCFunction) EdgeTable_equals, METH_VARARGS, + "Returns True if the specified EdgeTable is equal to this one."}, + {"get_row", (PyCFunction) EdgeTable_get_row, METH_VARARGS, + "Returns the kth row in this table."}, + {"set_columns", (PyCFunction) EdgeTable_set_columns, METH_VARARGS|METH_KEYWORDS, + "Copies the data in the specified arrays into the columns."}, + {"append_columns", (PyCFunction) EdgeTable_append_columns, METH_VARARGS|METH_KEYWORDS, + "Copies the data in the specified arrays into the columns."}, + {"clear", (PyCFunction) EdgeTable_clear, METH_NOARGS, + "Clears this table."}, + {"truncate", (PyCFunction) EdgeTable_truncate, METH_VARARGS, + "Truncates this table to the specified number of rows."}, + {NULL} /* Sentinel */ +}; + +static PyTypeObject EdgeTableType = { + PyVarObject_HEAD_INIT(NULL, 0) + "_tskit.EdgeTable", /* tp_name */ + sizeof(EdgeTable), /* tp_basicsize */ + 0, /* tp_itemsize */ + (destructor)EdgeTable_dealloc, /* tp_dealloc */ + 0, /* 
tp_print */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_reserved */ + 0, /* tp_repr */ + 0, /* tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + 0, /* tp_str */ + 0, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT | + Py_TPFLAGS_BASETYPE, /* tp_flags */ + "EdgeTable objects", /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + EdgeTable_methods, /* tp_methods */ + 0, /* tp_members */ + EdgeTable_getsetters, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + (initproc)EdgeTable_init, /* tp_init */ +}; + +/*=================================================================== + * MigrationTable + *=================================================================== + */ + +static int +MigrationTable_check_state(MigrationTable *self) +{ + int ret = -1; + if (self->table == NULL) { + PyErr_SetString(PyExc_SystemError, "MigrationTable not initialised"); + goto out; + } + if (self->locked) { + PyErr_SetString(PyExc_RuntimeError, "MigrationTable in use by other thread."); + goto out; + } + ret = 0; +out: + return ret; +} + +static void +MigrationTable_dealloc(MigrationTable* self) +{ + if (self->tables != NULL) { + Py_DECREF(self->tables); + } else if (self->table != NULL) { + tsk_migration_tbl_free(self->table); + PyMem_Free(self->table); + self->table = NULL; + } + Py_TYPE(self)->tp_free((PyObject*)self); +} + +static int +MigrationTable_init(MigrationTable *self, PyObject *args, PyObject *kwds) +{ + int ret = -1; + int err; + static char *kwlist[] = {"max_rows_increment", NULL}; + Py_ssize_t max_rows_increment = 0; + + self->table = NULL; + if (!PyArg_ParseTupleAndKeywords(args, kwds, "|n", kwlist, + &max_rows_increment)) { + goto out; + } + if (max_rows_increment < 0) { + 
PyErr_SetString(PyExc_ValueError, "max_rows_increment must be positive"); + goto out; + } + self->table = PyMem_Malloc(sizeof(tsk_migration_tbl_t)); + if (self->table == NULL) { + PyErr_NoMemory(); + goto out; + } + err = tsk_migration_tbl_alloc(self->table, 0); + if (err != 0) { + handle_library_error(err); + goto out; + } + tsk_migration_tbl_set_max_rows_increment(self->table, max_rows_increment); + ret = 0; +out: + return ret; +} + +static PyObject * +MigrationTable_add_row(MigrationTable *self, PyObject *args, PyObject *kwds) +{ + PyObject *ret = NULL; + int err; + double left, right, time; + int node, source, dest; + static char *kwlist[] = {"left", "right", "node", "source", "dest", "time", NULL}; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "ddiiid", kwlist, + &left, &right, &node, &source, &dest, &time)) { + goto out; + } + if (MigrationTable_check_state(self) != 0) { + goto out; + } + err = tsk_migration_tbl_add_row(self->table, left, right, node, + source, dest, time); + if (err < 0) { + handle_library_error(err); + goto out; + } + ret = Py_BuildValue("i", err); +out: + return ret; +} + +/* Forward declaration */ +static PyTypeObject MigrationTableType; + +static PyObject * +MigrationTable_equals(MigrationTable *self, PyObject *args) +{ + PyObject *ret = NULL; + MigrationTable *other = NULL; + + if (MigrationTable_check_state(self) != 0) { + goto out; + } + if (!PyArg_ParseTuple(args, "O!", &MigrationTableType, &other)) { + goto out; + } + ret = Py_BuildValue("i", tsk_migration_tbl_equals(self->table, other->table)); +out: + return ret; +} + +static PyObject * +MigrationTable_get_row(MigrationTable *self, PyObject *args) +{ + PyObject *ret = NULL; + Py_ssize_t row_id; + int err; + tsk_migration_t migration; + + if (MigrationTable_check_state(self) != 0) { + goto out; + } + if (!PyArg_ParseTuple(args, "n", &row_id)) { + goto out; + } + err = tsk_migration_tbl_get_row(self->table, (size_t) row_id, &migration); + if (err != 0) { + 
handle_library_error(err); + goto out; + } + ret = make_migration(&migration); +out: + return ret; +} + +static PyObject * +MigrationTable_parse_dict_arg(MigrationTable *self, PyObject *args, bool clear_table) +{ + int err; + PyObject *ret = NULL; + PyObject *dict = NULL; + + if (!PyArg_ParseTuple(args, "O!", &PyDict_Type, &dict)) { + goto out; + } + if (MigrationTable_check_state(self) != 0) { + goto out; + } + err = parse_migration_table_dict(self->table, dict, clear_table); + if (err != 0) { + goto out; + } + ret = Py_BuildValue(""); +out: + return ret; +} + +static PyObject * +MigrationTable_append_columns(MigrationTable *self, PyObject *args) +{ + return MigrationTable_parse_dict_arg(self, args, false); +} + +static PyObject * +MigrationTable_set_columns(MigrationTable *self, PyObject *args) +{ + return MigrationTable_parse_dict_arg(self, args, true); +} + +static PyObject * +MigrationTable_clear(MigrationTable *self) +{ + PyObject *ret = NULL; + int err; + + if (MigrationTable_check_state(self) != 0) { + goto out; + } + err = tsk_migration_tbl_clear(self->table); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = Py_BuildValue(""); +out: + return ret; +} + +static PyObject * +MigrationTable_truncate(MigrationTable *self, PyObject *args) +{ + PyObject *ret = NULL; + Py_ssize_t num_rows; + int err; + + if (MigrationTable_check_state(self) != 0) { + goto out; + } + if (!PyArg_ParseTuple(args, "n", &num_rows)) { + goto out; + } + if (num_rows < 0 || num_rows > (Py_ssize_t) self->table->num_rows) { + PyErr_SetString(PyExc_ValueError, "num_rows out of bounds"); + goto out; + } + err = tsk_migration_tbl_truncate(self->table, (size_t) num_rows); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = Py_BuildValue(""); +out: + return ret; +} + +static PyObject * +MigrationTable_get_max_rows_increment(MigrationTable *self, void *closure) +{ + PyObject *ret = NULL; + if (MigrationTable_check_state(self) != 0) { + goto out; + } + ret = 
Py_BuildValue("n", (Py_ssize_t) self->table->max_rows_increment); +out: + return ret; +} + +static PyObject * +MigrationTable_get_num_rows(MigrationTable *self, void *closure) +{ + PyObject *ret = NULL; + if (MigrationTable_check_state(self) != 0) { + goto out; + } + ret = Py_BuildValue("n", (Py_ssize_t) self->table->num_rows); +out: + return ret; +} + +static PyObject * +MigrationTable_get_max_rows(MigrationTable *self, void *closure) +{ + PyObject *ret = NULL; + if (MigrationTable_check_state(self) != 0) { + goto out; + } + ret = Py_BuildValue("n", (Py_ssize_t) self->table->max_rows); +out: + return ret; +} + +static PyObject * +MigrationTable_get_left(MigrationTable *self, void *closure) +{ + PyObject *ret = NULL; + + if (MigrationTable_check_state(self) != 0) { + goto out; + } + ret = table_get_column_array(self->table->num_rows, self->table->left, + NPY_FLOAT64, sizeof(double)); +out: + return ret; +} + +static PyObject * +MigrationTable_get_right(MigrationTable *self, void *closure) +{ + PyObject *ret = NULL; + + if (MigrationTable_check_state(self) != 0) { + goto out; + } + ret = table_get_column_array(self->table->num_rows, self->table->right, + NPY_FLOAT64, sizeof(double)); +out: + return ret; +} + +static PyObject * +MigrationTable_get_time(MigrationTable *self, void *closure) +{ + PyObject *ret = NULL; + + if (MigrationTable_check_state(self) != 0) { + goto out; + } + ret = table_get_column_array(self->table->num_rows, self->table->time, + NPY_FLOAT64, sizeof(double)); +out: + return ret; +} + +static PyObject * +MigrationTable_get_node(MigrationTable *self, void *closure) +{ + PyObject *ret = NULL; + + if (MigrationTable_check_state(self) != 0) { + goto out; + } + ret = table_get_column_array(self->table->num_rows, self->table->node, + NPY_INT32, sizeof(uint32_t)); +out: + return ret; +} + +static PyObject * +MigrationTable_get_source(MigrationTable *self, void *closure) +{ + PyObject *ret = NULL; + + if (MigrationTable_check_state(self) != 0) { + goto 
out; + } + ret = table_get_column_array(self->table->num_rows, self->table->source, + NPY_INT32, sizeof(uint32_t)); +out: + return ret; +} + +static PyObject * +MigrationTable_get_dest(MigrationTable *self, void *closure) +{ + PyObject *ret = NULL; + + if (MigrationTable_check_state(self) != 0) { + goto out; + } + ret = table_get_column_array(self->table->num_rows, self->table->dest, + NPY_INT32, sizeof(uint32_t)); +out: + return ret; +} + +static PyGetSetDef MigrationTable_getsetters[] = { + {"max_rows_increment", + (getter) MigrationTable_get_max_rows_increment, NULL, "The size increment"}, + {"num_rows", (getter) MigrationTable_get_num_rows, NULL, + "The number of rows in the table."}, + {"max_rows", (getter) MigrationTable_get_max_rows, NULL, + "The current maximum number of rows in the table."}, + {"left", (getter) MigrationTable_get_left, NULL, "The left array"}, + {"right", (getter) MigrationTable_get_right, NULL, "The right array"}, + {"node", (getter) MigrationTable_get_node, NULL, "The node array"}, + {"source", (getter) MigrationTable_get_source, NULL, "The source array"}, + {"dest", (getter) MigrationTable_get_dest, NULL, "The dest array"}, + {"time", (getter) MigrationTable_get_time, NULL, "The time array"}, + {NULL} /* Sentinel */ +}; + +static PyMethodDef MigrationTable_methods[] = { + {"add_row", (PyCFunction) MigrationTable_add_row, METH_VARARGS|METH_KEYWORDS, + "Adds a new row to this table."}, + {"equals", (PyCFunction) MigrationTable_equals, METH_VARARGS, + "Returns True if the specified MigrationTable is equal to this one."}, + {"get_row", (PyCFunction) MigrationTable_get_row, METH_VARARGS, + "Returns the kth row in this table."}, + {"set_columns", (PyCFunction) MigrationTable_set_columns, METH_VARARGS|METH_KEYWORDS, + "Copies the data in the specified arrays into the columns."}, + {"append_columns", (PyCFunction) MigrationTable_append_columns, METH_VARARGS|METH_KEYWORDS, + "Appends the data in the specified arrays into the columns."}, + 
{"clear", (PyCFunction) MigrationTable_clear, METH_NOARGS, + "Clears this table."}, + {"truncate", (PyCFunction) MigrationTable_truncate, METH_VARARGS, + "Truncates this table to the specified number of rows."}, + {NULL} /* Sentinel */ +}; + +static PyTypeObject MigrationTableType = { + PyVarObject_HEAD_INIT(NULL, 0) + "_tskit.MigrationTable", /* tp_name */ + sizeof(MigrationTable), /* tp_basicsize */ + 0, /* tp_itemsize */ + (destructor)MigrationTable_dealloc, /* tp_dealloc */ + 0, /* tp_print */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_reserved */ + 0, /* tp_repr */ + 0, /* tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + 0, /* tp_str */ + 0, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT | + Py_TPFLAGS_BASETYPE, /* tp_flags */ + "MigrationTable objects", /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + MigrationTable_methods, /* tp_methods */ + 0, /* tp_members */ + MigrationTable_getsetters, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + (initproc)MigrationTable_init, /* tp_init */ +}; + + +/*=================================================================== + * SiteTable + *=================================================================== + */ + +static int +SiteTable_check_state(SiteTable *self) +{ + int ret = -1; + if (self->table == NULL) { + PyErr_SetString(PyExc_SystemError, "SiteTable not initialised"); + goto out; + } + if (self->locked) { + PyErr_SetString(PyExc_RuntimeError, "SiteTable in use by other thread."); + goto out; + } + ret = 0; +out: + return ret; +} + +static void +SiteTable_dealloc(SiteTable* self) +{ + if (self->tables != NULL) { + Py_DECREF(self->tables); + } else if (self->table != NULL) { + tsk_site_tbl_free(self->table); + 
PyMem_Free(self->table); + self->table = NULL; + } + Py_TYPE(self)->tp_free((PyObject*)self); +} + +static int +SiteTable_init(SiteTable *self, PyObject *args, PyObject *kwds) +{ + int ret = -1; + int err; + static char *kwlist[] = {"max_rows_increment", NULL}; + Py_ssize_t max_rows_increment = 0; + + self->table = NULL; + if (!PyArg_ParseTupleAndKeywords(args, kwds, "|n", kwlist, &max_rows_increment)) { + goto out; + } + if (max_rows_increment < 0) { + PyErr_SetString(PyExc_ValueError, "max_rows_increment must be positive"); + goto out; + } + self->table = PyMem_Malloc(sizeof(tsk_site_tbl_t)); + if (self->table == NULL) { + PyErr_NoMemory(); + goto out; + } + err = tsk_site_tbl_alloc(self->table, 0); + if (err != 0) { + handle_library_error(err); + goto out; + } + tsk_site_tbl_set_max_rows_increment(self->table, max_rows_increment); + ret = 0; +out: + return ret; +} + +static PyObject * +SiteTable_add_row(SiteTable *self, PyObject *args, PyObject *kwds) +{ + PyObject *ret = NULL; + int err; + double position; + char *ancestral_state = NULL; + Py_ssize_t ancestral_state_length = 0; + PyObject *py_metadata = Py_None; + char *metadata = NULL; + Py_ssize_t metadata_length = 0; + static char *kwlist[] = {"position", "ancestral_state", "metadata", NULL}; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "ds#|O", kwlist, + &position, &ancestral_state, &ancestral_state_length, &py_metadata)) { + goto out; + } + if (SiteTable_check_state(self) != 0) { + goto out; + } + if (py_metadata != Py_None) { + if (PyBytes_AsStringAndSize(py_metadata, &metadata, &metadata_length) < 0) { + goto out; + } + } + err = tsk_site_tbl_add_row(self->table, position, ancestral_state, + ancestral_state_length, metadata, metadata_length); + if (err < 0) { + handle_library_error(err); + goto out; + } + ret = Py_BuildValue("i", err); +out: + return ret; +} + +/* Forward declaration */ +static PyTypeObject SiteTableType; + +static PyObject * +SiteTable_equals(SiteTable *self, PyObject *args) +{ + 
PyObject *ret = NULL; + SiteTable *other = NULL; + + if (SiteTable_check_state(self) != 0) { + goto out; + } + if (!PyArg_ParseTuple(args, "O!", &SiteTableType, &other)) { + goto out; + } + ret = Py_BuildValue("i", tsk_site_tbl_equals(self->table, other->table)); +out: + return ret; +} + +static PyObject * +SiteTable_get_row(SiteTable *self, PyObject *args) +{ + PyObject *ret = NULL; + Py_ssize_t row_id; + int err; + tsk_site_t site; + + if (SiteTable_check_state(self) != 0) { + goto out; + } + if (!PyArg_ParseTuple(args, "n", &row_id)) { + goto out; + } + err = tsk_site_tbl_get_row(self->table, (size_t) row_id, &site); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = make_site_row(&site); +out: + return ret; +} + +static PyObject * +SiteTable_parse_dict_arg(SiteTable *self, PyObject *args, bool clear_table) +{ + int err; + PyObject *ret = NULL; + PyObject *dict = NULL; + + if (!PyArg_ParseTuple(args, "O!", &PyDict_Type, &dict)) { + goto out; + } + if (SiteTable_check_state(self) != 0) { + goto out; + } + err = parse_site_table_dict(self->table, dict, clear_table); + if (err != 0) { + goto out; + } + ret = Py_BuildValue(""); +out: + return ret; +} + +static PyObject * +SiteTable_append_columns(SiteTable *self, PyObject *args) +{ + return SiteTable_parse_dict_arg(self, args, false); +} + +static PyObject * +SiteTable_set_columns(SiteTable *self, PyObject *args) +{ + return SiteTable_parse_dict_arg(self, args, true); +} + +static PyObject * +SiteTable_clear(SiteTable *self) +{ + PyObject *ret = NULL; + int err; + + if (SiteTable_check_state(self) != 0) { + goto out; + } + err = tsk_site_tbl_clear(self->table); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = Py_BuildValue(""); +out: + return ret; +} + +static PyObject * +SiteTable_truncate(SiteTable *self, PyObject *args) +{ + PyObject *ret = NULL; + Py_ssize_t num_rows; + int err; + + if (SiteTable_check_state(self) != 0) { + goto out; + } + if (!PyArg_ParseTuple(args, 
"n", &num_rows)) { + goto out; + } + if (num_rows < 0 || num_rows > (Py_ssize_t) self->table->num_rows) { + PyErr_SetString(PyExc_ValueError, "num_rows out of bounds"); + goto out; + } + err = tsk_site_tbl_truncate(self->table, (size_t) num_rows); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = Py_BuildValue(""); +out: + return ret; +} + +static PyObject * +SiteTable_get_max_rows_increment(SiteTable *self, void *closure) +{ + PyObject *ret = NULL; + if (SiteTable_check_state(self) != 0) { + goto out; + } + ret = Py_BuildValue("n", (Py_ssize_t) self->table->max_rows_increment); +out: + return ret; +} + +static PyObject * +SiteTable_get_num_rows(SiteTable *self, void *closure) +{ + PyObject *ret = NULL; + if (SiteTable_check_state(self) != 0) { + goto out; + } + ret = Py_BuildValue("n", (Py_ssize_t) self->table->num_rows); +out: + return ret; +} + +static PyObject * +SiteTable_get_max_rows(SiteTable *self, void *closure) +{ + PyObject *ret = NULL; + if (SiteTable_check_state(self) != 0) { + goto out; + } + ret = Py_BuildValue("n", (Py_ssize_t) self->table->max_rows); +out: + return ret; +} + +static PyObject * +SiteTable_get_position(SiteTable *self, void *closure) +{ + PyObject *ret = NULL; + + if (SiteTable_check_state(self) != 0) { + goto out; + } + ret = table_get_column_array( + self->table->num_rows, + self->table->position, NPY_FLOAT64, sizeof(double)); +out: + return ret; +} + +static PyObject * +SiteTable_get_ancestral_state(SiteTable *self, void *closure) +{ + PyObject *ret = NULL; + + if (SiteTable_check_state(self) != 0) { + goto out; + } + ret = table_get_column_array( + self->table->ancestral_state_length, + self->table->ancestral_state, NPY_INT8, sizeof(char)); +out: + return ret; +} + +static PyObject * +SiteTable_get_ancestral_state_offset(SiteTable *self, void *closure) +{ + PyObject *ret = NULL; + + if (SiteTable_check_state(self) != 0) { + goto out; + } + ret = table_get_column_array( + self->table->num_rows + 1, + 
self->table->ancestral_state_offset, NPY_UINT32, sizeof(uint32_t)); +out: + return ret; +} + +static PyObject * +SiteTable_get_metadata(SiteTable *self, void *closure) +{ + PyObject *ret = NULL; + + if (SiteTable_check_state(self) != 0) { + goto out; + } + ret = table_get_column_array( + self->table->metadata_length, + self->table->metadata, NPY_INT8, sizeof(char)); +out: + return ret; +} + +static PyObject * +SiteTable_get_metadata_offset(SiteTable *self, void *closure) +{ + PyObject *ret = NULL; + + if (SiteTable_check_state(self) != 0) { + goto out; + } + ret = table_get_column_array( + self->table->num_rows + 1, + self->table->metadata_offset, NPY_UINT32, sizeof(uint32_t)); +out: + return ret; +} + +static PyGetSetDef SiteTable_getsetters[] = { + {"max_rows_increment", + (getter) SiteTable_get_max_rows_increment, NULL, + "The size increment"}, + {"num_rows", + (getter) SiteTable_get_num_rows, NULL, + "The number of rows in the table."}, + {"max_rows", + (getter) SiteTable_get_max_rows, NULL, + "The current maximum number of rows in the table."}, + {"position", (getter) SiteTable_get_position, NULL, + "The position array."}, + {"ancestral_state", (getter) SiteTable_get_ancestral_state, NULL, + "The ancestral state array."}, + {"ancestral_state_offset", (getter) SiteTable_get_ancestral_state_offset, NULL, + "The ancestral state offset array."}, + {"metadata", (getter) SiteTable_get_metadata, NULL, + "The metadata array."}, + {"metadata_offset", (getter) SiteTable_get_metadata_offset, NULL, + "The metadata offset array."}, + {NULL} /* Sentinel */ +}; + +static PyMethodDef SiteTable_methods[] = { + {"add_row", (PyCFunction) SiteTable_add_row, METH_VARARGS|METH_KEYWORDS, + "Adds a new row to this table."}, + {"equals", (PyCFunction) SiteTable_equals, METH_VARARGS, + "Returns True if the specified SiteTable is equal to this one."}, + {"get_row", (PyCFunction) SiteTable_get_row, METH_VARARGS, + "Returns the kth row in this table."}, + {"set_columns", (PyCFunction) 
SiteTable_set_columns, METH_VARARGS|METH_KEYWORDS, + "Copies the data in the specified arrays into the columns."}, + {"append_columns", (PyCFunction) SiteTable_append_columns, METH_VARARGS|METH_KEYWORDS, + "Appends the data in the specified arrays into the columns."}, + {"clear", (PyCFunction) SiteTable_clear, METH_NOARGS, + "Clears this table."}, + {"truncate", (PyCFunction) SiteTable_truncate, METH_VARARGS, + "Truncates this table to the specified number of rows."}, + {NULL} /* Sentinel */ +}; + +static PyTypeObject SiteTableType = { + PyVarObject_HEAD_INIT(NULL, 0) + "_tskit.SiteTable", /* tp_name */ + sizeof(SiteTable), /* tp_basicsize */ + 0, /* tp_itemsize */ + (destructor)SiteTable_dealloc, /* tp_dealloc */ + 0, /* tp_print */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_reserved */ + 0, /* tp_repr */ + 0, /* tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + 0, /* tp_str */ + 0, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT | + Py_TPFLAGS_BASETYPE, /* tp_flags */ + "SiteTable objects", /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + SiteTable_methods, /* tp_methods */ + 0, /* tp_members */ + SiteTable_getsetters, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + (initproc)SiteTable_init, /* tp_init */ +}; + + +/*=================================================================== + * MutationTable + *=================================================================== + */ + +static int +MutationTable_check_state(MutationTable *self) +{ + int ret = -1; + if (self->table == NULL) { + PyErr_SetString(PyExc_SystemError, "MutationTable not initialised"); + goto out; + } + if (self->locked) { + PyErr_SetString(PyExc_RuntimeError, "MutationTable in use by other thread."); + goto 
out; + } + ret = 0; +out: + return ret; +} + +static void +MutationTable_dealloc(MutationTable* self) +{ + if (self->tables != NULL) { + Py_DECREF(self->tables); + } else if (self->table != NULL) { + tsk_mutation_tbl_free(self->table); + PyMem_Free(self->table); + self->table = NULL; + } + Py_TYPE(self)->tp_free((PyObject*)self); +} + +static int +MutationTable_init(MutationTable *self, PyObject *args, PyObject *kwds) +{ + int ret = -1; + int err; + static char *kwlist[] = {"max_rows_increment", NULL}; + Py_ssize_t max_rows_increment = 0; + + self->table = NULL; + if (!PyArg_ParseTupleAndKeywords(args, kwds, "|n", kwlist, &max_rows_increment)) { + goto out; + } + if (max_rows_increment < 0) { + PyErr_SetString(PyExc_ValueError, "max_rows_increment must be positive"); + goto out; + } + self->table = PyMem_Malloc(sizeof(tsk_mutation_tbl_t)); + if (self->table == NULL) { + PyErr_NoMemory(); + goto out; + } + err = tsk_mutation_tbl_alloc(self->table, 0); + if (err != 0) { + handle_library_error(err); + goto out; + } + tsk_mutation_tbl_set_max_rows_increment(self->table, max_rows_increment); + ret = 0; +out: + return ret; +} + +static PyObject * +MutationTable_add_row(MutationTable *self, PyObject *args, PyObject *kwds) +{ + PyObject *ret = NULL; + int err; + int site; + int node; + int parent = TSK_NULL; + char *derived_state; + Py_ssize_t derived_state_length; + PyObject *py_metadata = Py_None; + char *metadata = NULL; + Py_ssize_t metadata_length = 0; + static char *kwlist[] = {"site", "node", "derived_state", "parent", "metadata", NULL}; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "iis#|iO", kwlist, + &site, &node, &derived_state, &derived_state_length, &parent, + &py_metadata)) { + goto out; + } + if (MutationTable_check_state(self) != 0) { + goto out; + } + if (py_metadata != Py_None) { + if (PyBytes_AsStringAndSize(py_metadata, &metadata, &metadata_length) < 0) { + goto out; + } + } + err = tsk_mutation_tbl_add_row(self->table, (tsk_id_t) site, + (tsk_id_t) 
node, (tsk_id_t) parent, + derived_state, derived_state_length, + metadata, metadata_length); + if (err < 0) { + handle_library_error(err); + goto out; + } + ret = Py_BuildValue("i", err); +out: + return ret; +} + +/* Forward declaration */ +static PyTypeObject MutationTableType; + +static PyObject * +MutationTable_equals(MutationTable *self, PyObject *args) +{ + PyObject *ret = NULL; + MutationTable *other = NULL; + + if (MutationTable_check_state(self) != 0) { + goto out; + } + if (!PyArg_ParseTuple(args, "O!", &MutationTableType, &other)) { + goto out; + } + ret = Py_BuildValue("i", tsk_mutation_tbl_equals(self->table, other->table)); +out: + return ret; +} + +static PyObject * +MutationTable_get_row(MutationTable *self, PyObject *args) +{ + PyObject *ret = NULL; + Py_ssize_t row_id; + int err; + tsk_mutation_t mutation; + + if (MutationTable_check_state(self) != 0) { + goto out; + } + if (!PyArg_ParseTuple(args, "n", &row_id)) { + goto out; + } + err = tsk_mutation_tbl_get_row(self->table, (size_t) row_id, &mutation); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = make_mutation(&mutation); +out: + return ret; +} + +static PyObject * +MutationTable_parse_dict_arg(MutationTable *self, PyObject *args, bool clear_table) +{ + int err; + PyObject *ret = NULL; + PyObject *dict = NULL; + + if (!PyArg_ParseTuple(args, "O!", &PyDict_Type, &dict)) { + goto out; + } + if (MutationTable_check_state(self) != 0) { + goto out; + } + err = parse_mutation_table_dict(self->table, dict, clear_table); + if (err != 0) { + goto out; + } + ret = Py_BuildValue(""); +out: + return ret; +} + +static PyObject * +MutationTable_append_columns(MutationTable *self, PyObject *args) +{ + return MutationTable_parse_dict_arg(self, args, false); +} + +static PyObject * +MutationTable_set_columns(MutationTable *self, PyObject *args) +{ + return MutationTable_parse_dict_arg(self, args, true); +} + +static PyObject * +MutationTable_clear(MutationTable *self) +{ + PyObject *ret 
= NULL; + int err; + + if (MutationTable_check_state(self) != 0) { + goto out; + } + err = tsk_mutation_tbl_clear(self->table); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = Py_BuildValue(""); +out: + return ret; +} + +static PyObject * +MutationTable_truncate(MutationTable *self, PyObject *args) +{ + PyObject *ret = NULL; + Py_ssize_t num_rows; + int err; + + if (MutationTable_check_state(self) != 0) { + goto out; + } + if (!PyArg_ParseTuple(args, "n", &num_rows)) { + goto out; + } + if (num_rows < 0 || num_rows > (Py_ssize_t) self->table->num_rows) { + PyErr_SetString(PyExc_ValueError, "num_rows out of bounds"); + goto out; + } + err = tsk_mutation_tbl_truncate(self->table, (size_t) num_rows); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = Py_BuildValue(""); +out: + return ret; +} + +static PyObject * +MutationTable_get_max_rows_increment(MutationTable *self, void *closure) +{ + PyObject *ret = NULL; + if (MutationTable_check_state(self) != 0) { + goto out; + } + ret = Py_BuildValue("n", (Py_ssize_t) self->table->max_rows_increment); +out: + return ret; +} + +static PyObject * +MutationTable_get_num_rows(MutationTable *self, void *closure) +{ + PyObject *ret = NULL; + if (MutationTable_check_state(self) != 0) { + goto out; + } + ret = Py_BuildValue("n", (Py_ssize_t) self->table->num_rows); +out: + return ret; +} + +static PyObject * +MutationTable_get_max_rows(MutationTable *self, void *closure) +{ + PyObject *ret = NULL; + if (MutationTable_check_state(self) != 0) { + goto out; + } + ret = Py_BuildValue("n", (Py_ssize_t) self->table->max_rows); +out: + return ret; +} + +static PyObject * +MutationTable_get_site(MutationTable *self, void *closure) +{ + PyObject *ret = NULL; + + if (MutationTable_check_state(self) != 0) { + goto out; + } + ret = table_get_column_array( + self->table->num_rows, self->table->site, NPY_INT32, + sizeof(int32_t)); +out: + return ret; +} + +static PyObject * 
+MutationTable_get_node(MutationTable *self, void *closure) +{ + PyObject *ret = NULL; + + if (MutationTable_check_state(self) != 0) { + goto out; + } + ret = table_get_column_array( + self->table->num_rows, self->table->node, NPY_INT32, + sizeof(int32_t)); +out: + return ret; +} + +static PyObject * +MutationTable_get_parent(MutationTable *self, void *closure) +{ + PyObject *ret = NULL; + + if (MutationTable_check_state(self) != 0) { + goto out; + } + ret = table_get_column_array( + self->table->num_rows, self->table->parent, NPY_INT32, + sizeof(int32_t)); +out: + return ret; +} + +static PyObject * +MutationTable_get_derived_state(MutationTable *self, void *closure) +{ + PyObject *ret = NULL; + + if (MutationTable_check_state(self) != 0) { + goto out; + } + ret = table_get_column_array( + self->table->derived_state_length, self->table->derived_state, + NPY_INT8, sizeof(char)); +out: + return ret; +} + +static PyObject * +MutationTable_get_derived_state_offset(MutationTable *self, void *closure) +{ + PyObject *ret = NULL; + + if (MutationTable_check_state(self) != 0) { + goto out; + } + ret = table_get_column_array( + self->table->num_rows + 1, self->table->derived_state_offset, + NPY_UINT32, sizeof(uint32_t)); +out: + return ret; +} + +static PyObject * +MutationTable_get_metadata(MutationTable *self, void *closure) +{ + PyObject *ret = NULL; + + if (MutationTable_check_state(self) != 0) { + goto out; + } + ret = table_get_column_array( + self->table->metadata_length, self->table->metadata, + NPY_INT8, sizeof(char)); +out: + return ret; +} + +static PyObject * +MutationTable_get_metadata_offset(MutationTable *self, void *closure) +{ + PyObject *ret = NULL; + + if (MutationTable_check_state(self) != 0) { + goto out; + } + ret = table_get_column_array( + self->table->num_rows + 1, self->table->metadata_offset, + NPY_UINT32, sizeof(uint32_t)); +out: + return ret; +} + +static PyGetSetDef MutationTable_getsetters[] = { + {"max_rows_increment", + (getter) 
MutationTable_get_max_rows_increment, NULL, + "The size increment"}, + {"num_rows", + (getter) MutationTable_get_num_rows, NULL, + "The number of rows in the table."}, + {"max_rows", + (getter) MutationTable_get_max_rows, NULL, + "The curret maximum number of rows in the table."}, + {"site", (getter) MutationTable_get_site, NULL, "The site array"}, + {"node", (getter) MutationTable_get_node, NULL, "The node array"}, + {"parent", (getter) MutationTable_get_parent, NULL, "The parent array"}, + {"derived_state", (getter) MutationTable_get_derived_state, NULL, + "The derived_state array"}, + {"derived_state_offset", (getter) MutationTable_get_derived_state_offset, NULL, + "The derived_state_offset array"}, + {"metadata", (getter) MutationTable_get_metadata, NULL, + "The metadata array"}, + {"metadata_offset", (getter) MutationTable_get_metadata_offset, NULL, + "The metadata_offset array"}, + {NULL} /* Sentinel */ +}; + +static PyMethodDef MutationTable_methods[] = { + {"add_row", (PyCFunction) MutationTable_add_row, METH_VARARGS|METH_KEYWORDS, + "Adds a new row to this table."}, + {"equals", (PyCFunction) MutationTable_equals, METH_VARARGS, + "Returns True if the specified MutationTable is equal to this one."}, + {"get_row", (PyCFunction) MutationTable_get_row, METH_VARARGS, + "Returns the kth row in this table."}, + {"set_columns", (PyCFunction) MutationTable_set_columns, METH_VARARGS|METH_KEYWORDS, + "Copies the data in the specified arrays into the columns."}, + {"append_columns", (PyCFunction) MutationTable_append_columns, METH_VARARGS|METH_KEYWORDS, + "Appends the data in the specified arrays into the columns."}, + {"clear", (PyCFunction) MutationTable_clear, METH_NOARGS, + "Clears this table."}, + {"truncate", (PyCFunction) MutationTable_truncate, METH_VARARGS, + "Truncates this table to the specified number of rows."}, + {NULL} /* Sentinel */ +}; + +static PyTypeObject MutationTableType = { + PyVarObject_HEAD_INIT(NULL, 0) + "_tskit.MutationTable", /* tp_name */ 
+ sizeof(MutationTable), /* tp_basicsize */ + 0, /* tp_itemsize */ + (destructor)MutationTable_dealloc, /* tp_dealloc */ + 0, /* tp_print */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_reserved */ + 0, /* tp_repr */ + 0, /* tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + 0, /* tp_str */ + 0, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT | + Py_TPFLAGS_BASETYPE, /* tp_flags */ + "MutationTable objects", /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + MutationTable_methods, /* tp_methods */ + 0, /* tp_members */ + MutationTable_getsetters, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + (initproc)MutationTable_init, /* tp_init */ +}; + +/*=================================================================== + * PopulationTable + *=================================================================== + */ + +static int +PopulationTable_check_state(PopulationTable *self) +{ + int ret = -1; + if (self->table == NULL) { + PyErr_SetString(PyExc_SystemError, "PopulationTable not initialised"); + goto out; + } + if (self->locked) { + PyErr_SetString(PyExc_RuntimeError, "PopulationTable in use by other thread."); + goto out; + } + ret = 0; +out: + return ret; +} + +static void +PopulationTable_dealloc(PopulationTable* self) +{ + if (self->tables != NULL) { + Py_DECREF(self->tables); + } else if (self->table != NULL) { + tsk_population_tbl_free(self->table); + PyMem_Free(self->table); + self->table = NULL; + } + Py_TYPE(self)->tp_free((PyObject*)self); +} + +static int +PopulationTable_init(PopulationTable *self, PyObject *args, PyObject *kwds) +{ + int ret = -1; + int err; + static char *kwlist[] = {"max_rows_increment", NULL}; + Py_ssize_t max_rows_increment = 0; + + self->table = NULL; 
+ self->locked = false; + if (!PyArg_ParseTupleAndKeywords(args, kwds, "|n", kwlist, + &max_rows_increment)) { + goto out; + } + if (max_rows_increment < 0) { + PyErr_SetString(PyExc_ValueError, "max_rows_increment must be positive"); + goto out; + } + self->table = PyMem_Malloc(sizeof(tsk_population_tbl_t)); + if (self->table == NULL) { + PyErr_NoMemory(); + goto out; + } + err = tsk_population_tbl_alloc(self->table, 0); + if (err != 0) { + handle_library_error(err); + goto out; + } + tsk_population_tbl_set_max_rows_increment(self->table, max_rows_increment); + ret = 0; +out: + return ret; +} + +static PyObject * +PopulationTable_add_row(PopulationTable *self, PyObject *args, PyObject *kwds) +{ + PyObject *ret = NULL; + int err; + PyObject *py_metadata = Py_None; + char *metadata = NULL; + Py_ssize_t metadata_length = 0; + static char *kwlist[] = {"metadata", NULL}; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "|O", kwlist, &py_metadata)) { + goto out; + } + if (PopulationTable_check_state(self) != 0) { + goto out; + } + + if (py_metadata != Py_None) { + if (PyBytes_AsStringAndSize(py_metadata, &metadata, &metadata_length) < 0) { + goto out; + } + } + err = tsk_population_tbl_add_row(self->table, metadata, metadata_length); + if (err < 0) { + handle_library_error(err); + goto out; + } + ret = Py_BuildValue("i", err); +out: + return ret; +} + +/* Forward declaration */ +static PyTypeObject PopulationTableType; + +static PyObject * +PopulationTable_equals(PopulationTable *self, PyObject *args) +{ + PyObject *ret = NULL; + PopulationTable *other = NULL; + + if (PopulationTable_check_state(self) != 0) { + goto out; + } + if (!PyArg_ParseTuple(args, "O!", &PopulationTableType, &other)) { + goto out; + } + ret = Py_BuildValue("i", tsk_population_tbl_equals(self->table, other->table)); +out: + return ret; +} + +static PyObject * +PopulationTable_get_row(PopulationTable *self, PyObject *args) +{ + PyObject *ret = NULL; + Py_ssize_t row_id; + int err; + 
tsk_population_t population; + + if (PopulationTable_check_state(self) != 0) { + goto out; + } + if (!PyArg_ParseTuple(args, "n", &row_id)) { + goto out; + } + err = tsk_population_tbl_get_row(self->table, (size_t) row_id, &population); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = make_population(&population); +out: + return ret; +} + +static PyObject * +PopulationTable_parse_dict_arg(PopulationTable *self, PyObject *args, bool clear_table) +{ + int err; + PyObject *ret = NULL; + PyObject *dict = NULL; + + if (!PyArg_ParseTuple(args, "O!", &PyDict_Type, &dict)) { + goto out; + } + if (PopulationTable_check_state(self) != 0) { + goto out; + } + err = parse_population_table_dict(self->table, dict, clear_table); + if (err != 0) { + goto out; + } + ret = Py_BuildValue(""); +out: + return ret; +} + +static PyObject * +PopulationTable_append_columns(PopulationTable *self, PyObject *args) +{ + return PopulationTable_parse_dict_arg(self, args, false); +} + +static PyObject * +PopulationTable_set_columns(PopulationTable *self, PyObject *args) +{ + return PopulationTable_parse_dict_arg(self, args, true); +} + +static PyObject * +PopulationTable_clear(PopulationTable *self) +{ + PyObject *ret = NULL; + int err; + + if (PopulationTable_check_state(self) != 0) { + goto out; + } + err = tsk_population_tbl_clear(self->table); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = Py_BuildValue(""); +out: + return ret; +} + +static PyObject * +PopulationTable_truncate(PopulationTable *self, PyObject *args) +{ + PyObject *ret = NULL; + Py_ssize_t num_rows; + int err; + + if (PopulationTable_check_state(self) != 0) { + goto out; + } + if (!PyArg_ParseTuple(args, "n", &num_rows)) { + goto out; + } + if (num_rows < 0 || num_rows > (Py_ssize_t) self->table->num_rows) { + PyErr_SetString(PyExc_ValueError, "num_rows out of bounds"); + goto out; + } + err = tsk_population_tbl_truncate(self->table, (size_t) num_rows); + if (err != 0) { + 
handle_library_error(err); + goto out; + } + ret = Py_BuildValue(""); +out: + return ret; +} + +static PyObject * +PopulationTable_get_max_rows_increment(PopulationTable *self, void *closure) +{ + PyObject *ret = NULL; + if (PopulationTable_check_state(self) != 0) { + goto out; + } + ret = Py_BuildValue("n", (Py_ssize_t) self->table->max_rows_increment); +out: + return ret; +} + +static PyObject * +PopulationTable_get_num_rows(PopulationTable *self, void *closure) +{ + PyObject *ret = NULL; + if (PopulationTable_check_state(self) != 0) { + goto out; + } + ret = Py_BuildValue("n", (Py_ssize_t) self->table->num_rows); +out: + return ret; +} + +static PyObject * +PopulationTable_get_max_rows(PopulationTable *self, void *closure) +{ + PyObject *ret = NULL; + if (PopulationTable_check_state(self) != 0) { + goto out; + } + ret = Py_BuildValue("n", (Py_ssize_t) self->table->max_rows); +out: + return ret; +} + +static PyObject * +PopulationTable_get_metadata(PopulationTable *self, void *closure) +{ + PyObject *ret = NULL; + + if (PopulationTable_check_state(self) != 0) { + goto out; + } + ret = table_get_column_array(self->table->metadata_length, + self->table->metadata, NPY_INT8, sizeof(char)); +out: + return ret; +} + +static PyObject * +PopulationTable_get_metadata_offset(PopulationTable *self, void *closure) +{ + PyObject *ret = NULL; + + if (PopulationTable_check_state(self) != 0) { + goto out; + } + ret = table_get_column_array(self->table->num_rows + 1, + self->table->metadata_offset, NPY_UINT32, sizeof(uint32_t)); +out: + return ret; +} + +static PyGetSetDef PopulationTable_getsetters[] = { + {"max_rows_increment", + (getter) PopulationTable_get_max_rows_increment, NULL, "The size increment"}, + {"num_rows", (getter) PopulationTable_get_num_rows, NULL, + "The number of rows in the table."}, + {"max_rows", (getter) PopulationTable_get_max_rows, NULL, + "The current maximum number of rows in the table."}, + {"metadata", (getter) PopulationTable_get_metadata, NULL, 
"The metadata array"}, + {"metadata_offset", (getter) PopulationTable_get_metadata_offset, NULL, + "The metadata offset array"}, + {NULL} /* Sentinel */ +}; + +static PyMethodDef PopulationTable_methods[] = { + {"add_row", (PyCFunction) PopulationTable_add_row, METH_VARARGS|METH_KEYWORDS, + "Adds a new row to this table."}, + {"equals", (PyCFunction) PopulationTable_equals, METH_VARARGS, + "Returns True if the specified PopulationTable is equal to this one."}, + {"get_row", (PyCFunction) PopulationTable_get_row, METH_VARARGS, + "Returns the kth row in this table."}, + {"append_columns", (PyCFunction) PopulationTable_append_columns, + METH_VARARGS|METH_KEYWORDS, + "Appends the data in the specified arrays into the columns."}, + {"set_columns", (PyCFunction) PopulationTable_set_columns, METH_VARARGS|METH_KEYWORDS, + "Copies the data in the specified arrays into the columns."}, + {"clear", (PyCFunction) PopulationTable_clear, METH_NOARGS, + "Clears this table."}, + {"truncate", (PyCFunction) PopulationTable_truncate, METH_VARARGS, + "Truncates this table to the specified number of rows."}, + {NULL} /* Sentinel */ +}; + +static PyTypeObject PopulationTableType = { + PyVarObject_HEAD_INIT(NULL, 0) + "_tskit.PopulationTable", /* tp_name */ + sizeof(PopulationTable), /* tp_basicsize */ + 0, /* tp_itemsize */ + (destructor)PopulationTable_dealloc, /* tp_dealloc */ + 0, /* tp_print */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_reserved */ + 0, /* tp_repr */ + 0, /* tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + 0, /* tp_str */ + 0, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT | + Py_TPFLAGS_BASETYPE, /* tp_flags */ + "PopulationTable objects", /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + PopulationTable_methods, /* tp_methods */ + 0, /* tp_members */ + 
PopulationTable_getsetters, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + (initproc)PopulationTable_init, /* tp_init */ +}; + + +/*=================================================================== + * ProvenanceTable + *=================================================================== + */ + +static int +ProvenanceTable_check_state(ProvenanceTable *self) +{ + int ret = -1; + if (self->table == NULL) { + PyErr_SetString(PyExc_SystemError, "ProvenanceTable not initialised"); + goto out; + } + if (self->locked) { + PyErr_SetString(PyExc_RuntimeError, "ProvenanceTable in use by other thread."); + goto out; + } + ret = 0; +out: + return ret; +} + +static void +ProvenanceTable_dealloc(ProvenanceTable* self) +{ + if (self->tables != NULL) { + Py_DECREF(self->tables); + } else if (self->table != NULL) { + tsk_provenance_tbl_free(self->table); + PyMem_Free(self->table); + self->table = NULL; + } + Py_TYPE(self)->tp_free((PyObject*)self); +} + +static int +ProvenanceTable_init(ProvenanceTable *self, PyObject *args, PyObject *kwds) +{ + int ret = -1; + int err; + static char *kwlist[] = {"max_rows_increment", NULL}; + Py_ssize_t max_rows_increment = 0; + + self->table = NULL; + self->locked = false; + if (!PyArg_ParseTupleAndKeywords(args, kwds, "|n", kwlist, + &max_rows_increment)) { + goto out; + } + if (max_rows_increment < 0) { + PyErr_SetString(PyExc_ValueError, "max_rows_increment must be positive"); + goto out; + } + self->table = PyMem_Malloc(sizeof(tsk_provenance_tbl_t)); + if (self->table == NULL) { + PyErr_NoMemory(); + goto out; + } + + err = tsk_provenance_tbl_alloc(self->table, 0); + if (err != 0) { + handle_library_error(err); + goto out; + } + tsk_provenance_tbl_set_max_rows_increment(self->table, max_rows_increment); + ret = 0; +out: + return ret; +} +static PyObject * +ProvenanceTable_add_row(ProvenanceTable *self, PyObject *args, PyObject *kwds) +{ + PyObject *ret = NULL; 
+ int err; + char *timestamp = ""; + Py_ssize_t timestamp_length = 0; + char *record = ""; + Py_ssize_t record_length = 0; + static char *kwlist[] = {"timestamp", "record", NULL}; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "s#s#", kwlist, + ×tamp, ×tamp_length, &record, &record_length)){ + goto out; + } + if (ProvenanceTable_check_state(self) != 0) { + goto out; + } + err = tsk_provenance_tbl_add_row(self->table, + timestamp, timestamp_length, record, record_length); + if (err < 0) { + handle_library_error(err); + goto out; + } + ret = Py_BuildValue("i", err); +out: + return ret; +} + +/* Forward declaration */ +static PyTypeObject ProvenanceTableType; + +static PyObject * +ProvenanceTable_equals(ProvenanceTable *self, PyObject *args) +{ + PyObject *ret = NULL; + ProvenanceTable *other = NULL; + + if (ProvenanceTable_check_state(self) != 0) { + goto out; + } + if (!PyArg_ParseTuple(args, "O!", &ProvenanceTableType, &other)) { + goto out; + } + ret = Py_BuildValue("i", tsk_provenance_tbl_equals(self->table, other->table)); +out: + return ret; +} + +static PyObject * +ProvenanceTable_get_row(ProvenanceTable *self, PyObject *args) +{ + PyObject *ret = NULL; + Py_ssize_t row_id; + int err; + tsk_provenance_t provenance; + + if (ProvenanceTable_check_state(self) != 0) { + goto out; + } + if (!PyArg_ParseTuple(args, "n", &row_id)) { + goto out; + } + err = tsk_provenance_tbl_get_row(self->table, (size_t) row_id, &provenance); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = make_provenance(&provenance); +out: + return ret; +} + +static PyObject * +ProvenanceTable_parse_dict_arg(ProvenanceTable *self, PyObject *args, bool clear_table) +{ + int err; + PyObject *ret = NULL; + PyObject *dict = NULL; + + if (!PyArg_ParseTuple(args, "O!", &PyDict_Type, &dict)) { + goto out; + } + if (ProvenanceTable_check_state(self) != 0) { + goto out; + } + err = parse_provenance_table_dict(self->table, dict, clear_table); + if (err != 0) { + goto out; + } + ret = 
Py_BuildValue(""); +out: + return ret; +} + +static PyObject * +ProvenanceTable_append_columns(ProvenanceTable *self, PyObject *args) +{ + return ProvenanceTable_parse_dict_arg(self, args, false); +} + +static PyObject * +ProvenanceTable_set_columns(ProvenanceTable *self, PyObject *args) +{ + return ProvenanceTable_parse_dict_arg(self, args, true); +} + +static PyObject * +ProvenanceTable_clear(ProvenanceTable *self) +{ + PyObject *ret = NULL; + int err; + + if (ProvenanceTable_check_state(self) != 0) { + goto out; + } + err = tsk_provenance_tbl_clear(self->table); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = Py_BuildValue(""); +out: + return ret; +} + +static PyObject * +ProvenanceTable_truncate(ProvenanceTable *self, PyObject *args) +{ + PyObject *ret = NULL; + Py_ssize_t num_rows; + int err; + + if (ProvenanceTable_check_state(self) != 0) { + goto out; + } + if (!PyArg_ParseTuple(args, "n", &num_rows)) { + goto out; + } + if (num_rows < 0 || num_rows > (Py_ssize_t) self->table->num_rows) { + PyErr_SetString(PyExc_ValueError, "num_rows out of bounds"); + goto out; + } + err = tsk_provenance_tbl_truncate(self->table, (size_t) num_rows); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = Py_BuildValue(""); +out: + return ret; +} + +static PyObject * +ProvenanceTable_get_max_rows_increment(ProvenanceTable *self, void *closure) +{ + PyObject *ret = NULL; + if (ProvenanceTable_check_state(self) != 0) { + goto out; + } + ret = Py_BuildValue("n", (Py_ssize_t) self->table->max_rows_increment); +out: + return ret; +} + +static PyObject * +ProvenanceTable_get_num_rows(ProvenanceTable *self, void *closure) +{ + PyObject *ret = NULL; + if (ProvenanceTable_check_state(self) != 0) { + goto out; + } + ret = Py_BuildValue("n", (Py_ssize_t) self->table->num_rows); +out: + return ret; +} + +static PyObject * +ProvenanceTable_get_max_rows(ProvenanceTable *self, void *closure) +{ + PyObject *ret = NULL; + if 
(ProvenanceTable_check_state(self) != 0) { + goto out; + } + ret = Py_BuildValue("n", (Py_ssize_t) self->table->max_rows); +out: + return ret; +} + +static PyObject * +ProvenanceTable_get_timestamp(ProvenanceTable *self, void *closure) +{ + PyObject *ret = NULL; + + if (ProvenanceTable_check_state(self) != 0) { + goto out; + } + ret = table_get_column_array(self->table->timestamp_length, + self->table->timestamp, NPY_INT8, sizeof(char)); +out: + return ret; +} + +static PyObject * +ProvenanceTable_get_timestamp_offset(ProvenanceTable *self, void *closure) +{ + PyObject *ret = NULL; + + if (ProvenanceTable_check_state(self) != 0) { + goto out; + } + ret = table_get_column_array(self->table->num_rows + 1, + self->table->timestamp_offset, NPY_UINT32, sizeof(uint32_t)); +out: + return ret; +} + +static PyObject * +ProvenanceTable_get_record(ProvenanceTable *self, void *closure) +{ + PyObject *ret = NULL; + + if (ProvenanceTable_check_state(self) != 0) { + goto out; + } + ret = table_get_column_array(self->table->record_length, + self->table->record, NPY_INT8, sizeof(char)); +out: + return ret; +} + +static PyObject * +ProvenanceTable_get_record_offset(ProvenanceTable *self, void *closure) +{ + PyObject *ret = NULL; + + if (ProvenanceTable_check_state(self) != 0) { + goto out; + } + ret = table_get_column_array(self->table->num_rows + 1, + self->table->record_offset, NPY_UINT32, sizeof(uint32_t)); +out: + return ret; +} + +static PyGetSetDef ProvenanceTable_getsetters[] = { + {"max_rows_increment", + (getter) ProvenanceTable_get_max_rows_increment, NULL, "The size increment"}, + {"num_rows", (getter) ProvenanceTable_get_num_rows, NULL, + "The number of rows in the table."}, + {"max_rows", (getter) ProvenanceTable_get_max_rows, NULL, + "The current maximum number of rows in the table."}, + {"timestamp", (getter) ProvenanceTable_get_timestamp, NULL, "The timestamp array"}, + {"timestamp_offset", (getter) ProvenanceTable_get_timestamp_offset, NULL, + "The timestamp offset 
array"}, + {"record", (getter) ProvenanceTable_get_record, NULL, "The record array"}, + {"record_offset", (getter) ProvenanceTable_get_record_offset, NULL, + "The record offset array"}, + {NULL} /* Sentinel */ +}; + +static PyMethodDef ProvenanceTable_methods[] = { + {"add_row", (PyCFunction) ProvenanceTable_add_row, METH_VARARGS|METH_KEYWORDS, + "Adds a new row to this table."}, + {"equals", (PyCFunction) ProvenanceTable_equals, METH_VARARGS, + "Returns True if the specified ProvenanceTable is equal to this one."}, + {"get_row", (PyCFunction) ProvenanceTable_get_row, METH_VARARGS, + "Returns the kth row in this table."}, + {"append_columns", (PyCFunction) ProvenanceTable_append_columns, + METH_VARARGS|METH_KEYWORDS, + "Appends the data in the specified arrays into the columns."}, + {"set_columns", (PyCFunction) ProvenanceTable_set_columns, METH_VARARGS|METH_KEYWORDS, + "Copies the data in the specified arrays into the columns."}, + {"clear", (PyCFunction) ProvenanceTable_clear, METH_NOARGS, + "Clears this table."}, + {"truncate", (PyCFunction) ProvenanceTable_truncate, METH_VARARGS, + "Truncates this table to the specified number of rows."}, + {NULL} /* Sentinel */ +}; + +static PyTypeObject ProvenanceTableType = { + PyVarObject_HEAD_INIT(NULL, 0) + "_tskit.ProvenanceTable", /* tp_name */ + sizeof(ProvenanceTable), /* tp_basicsize */ + 0, /* tp_itemsize */ + (destructor)ProvenanceTable_dealloc, /* tp_dealloc */ + 0, /* tp_print */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_reserved */ + 0, /* tp_repr */ + 0, /* tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + 0, /* tp_str */ + 0, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT | + Py_TPFLAGS_BASETYPE, /* tp_flags */ + "ProvenanceTable objects", /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + 
ProvenanceTable_methods, /* tp_methods */ + 0, /* tp_members */ + ProvenanceTable_getsetters, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + (initproc)ProvenanceTable_init, /* tp_init */ +}; + +/*=================================================================== + * TableCollection + *=================================================================== + */ + +static void +TableCollection_dealloc(TableCollection* self) +{ + if (self->tables != NULL) { + tsk_tbl_collection_free(self->tables); + PyMem_Free(self->tables); + self->tables = NULL; + } + Py_TYPE(self)->tp_free((PyObject*)self); +} + +static int +TableCollection_init(TableCollection *self, PyObject *args, PyObject *kwds) +{ + int ret = -1; + int err; + static char *kwlist[] = {"sequence_length", NULL}; + double sequence_length = -1; + + self->tables = NULL; + if (!PyArg_ParseTupleAndKeywords(args, kwds, "|d", kwlist, &sequence_length)) { + goto out; + } + + self->tables = PyMem_Malloc(sizeof(tsk_tbl_collection_t)); + if (self->tables == NULL) { + PyErr_NoMemory(); + } + err = tsk_tbl_collection_alloc(self->tables, 0); + if (err != 0) { + handle_library_error(err); + goto out; + } + self->tables->sequence_length = sequence_length; + ret = 0; +out: + return ret; +} + +/* The getters for each of the tables returns a new reference which we + * set up here. These references use a pointer to the table stored in + * the table collection, so to guard against this memory getting freed + * we the Python Table classes keep a reference to the TableCollection + * and INCREF it. We don't keep permanent references to the Table classes + * in the TableCollection as this gives a circular references which would + * require implementing support for cyclic garbage collection. 
+ */ + +static PyObject * +TableCollection_get_individuals(TableCollection *self, void *closure) +{ + IndividualTable *individuals = NULL; + + individuals = PyObject_New(IndividualTable, &IndividualTableType); + if (individuals == NULL) { + goto out; + } + individuals->table = self->tables->individuals; + individuals->locked = false; + individuals->tables = self; + Py_INCREF(self); +out: + return (PyObject *) individuals; +} + +static PyObject * +TableCollection_get_nodes(TableCollection *self, void *closure) +{ + NodeTable *nodes = NULL; + + nodes = PyObject_New(NodeTable, &NodeTableType); + if (nodes == NULL) { + goto out; + } + nodes->table = self->tables->nodes; + nodes->locked = false; + nodes->tables = self; + Py_INCREF(self); +out: + return (PyObject *) nodes; +} + +static PyObject * +TableCollection_get_edges(TableCollection *self, void *closure) +{ + EdgeTable *edges = NULL; + + edges = PyObject_New(EdgeTable, &EdgeTableType); + if (edges == NULL) { + goto out; + } + edges->table = self->tables->edges; + edges->locked = false; + edges->tables = self; + Py_INCREF(self); +out: + return (PyObject *) edges; +} + +static PyObject * +TableCollection_get_migrations(TableCollection *self, void *closure) +{ + MigrationTable *migrations = NULL; + + migrations = PyObject_New(MigrationTable, &MigrationTableType); + if (migrations == NULL) { + goto out; + } + migrations->table = self->tables->migrations; + migrations->locked = false; + migrations->tables = self; + Py_INCREF(self); +out: + return (PyObject *) migrations; +} + +static PyObject * +TableCollection_get_sites(TableCollection *self, void *closure) +{ + SiteTable *sites = NULL; + + sites = PyObject_New(SiteTable, &SiteTableType); + if (sites == NULL) { + goto out; + } + sites->table = self->tables->sites; + sites->locked = false; + sites->tables = self; + Py_INCREF(self); +out: + return (PyObject *) sites; +} + +static PyObject * +TableCollection_get_mutations(TableCollection *self, void *closure) +{ + 
MutationTable *mutations = NULL; + + mutations = PyObject_New(MutationTable, &MutationTableType); + if (mutations == NULL) { + goto out; + } + mutations->table = self->tables->mutations; + mutations->locked = false; + mutations->tables = self; + Py_INCREF(self); +out: + return (PyObject *) mutations; +} + +static PyObject * +TableCollection_get_populations(TableCollection *self, void *closure) +{ + PopulationTable *populations = NULL; + + populations = PyObject_New(PopulationTable, &PopulationTableType); + if (populations == NULL) { + goto out; + } + populations->table = self->tables->populations; + populations->locked = false; + populations->tables = self; + Py_INCREF(self); +out: + return (PyObject *) populations; +} + +static PyObject * +TableCollection_get_provenances(TableCollection *self, void *closure) +{ + ProvenanceTable *provenances = NULL; + + provenances = PyObject_New(ProvenanceTable, &ProvenanceTableType); + if (provenances == NULL) { + goto out; + } + provenances->table = self->tables->provenances; + provenances->locked = false; + provenances->tables = self; + Py_INCREF(self); +out: + return (PyObject *) provenances; +} + +static PyObject * +TableCollection_get_sequence_length(TableCollection *self, void *closure) +{ + return Py_BuildValue("f", self->tables->sequence_length); +} + +static PyObject * +TableCollection_get_file_uuid(TableCollection *self, void *closure) +{ + return Py_BuildValue("s", self->tables->file_uuid); +} + +static PyObject * +TableCollection_simplify(TableCollection *self, PyObject *args, PyObject *kwds) +{ + int err; + PyObject *ret = NULL; + PyObject *samples = NULL; + PyArrayObject *samples_array = NULL; + PyArrayObject *node_map_array = NULL; + npy_intp *shape, dims; + size_t num_samples; + int flags = 0; + int filter_sites = true; + int filter_individuals = false; + int filter_populations = false; + int reduce_to_site_topology = false; + static char *kwlist[] = { + "samples", "filter_sites", "filter_populations", 
"filter_individuals", + "reduce_to_site_topology", NULL}; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|iiii", kwlist, + &samples, &filter_sites, &filter_populations, &filter_individuals, + &reduce_to_site_topology)) { + goto out; + } + samples_array = (PyArrayObject *) PyArray_FROMANY(samples, NPY_INT32, 1, 1, + NPY_ARRAY_IN_ARRAY); + if (samples_array == NULL) { + goto out; + } + shape = PyArray_DIMS(samples_array); + num_samples = shape[0]; + if (filter_sites) { + flags |= TSK_FILTER_SITES; + } + if (filter_individuals) { + flags |= TSK_FILTER_INDIVIDUALS; + } + if (filter_populations) { + flags |= TSK_FILTER_POPULATIONS; + } + if (reduce_to_site_topology) { + flags |= TSK_REDUCE_TO_SITE_TOPOLOGY; + } + + /* Allocate a new array to hold the node map. */ + dims = self->tables->nodes->num_rows; + node_map_array = (PyArrayObject *) PyArray_SimpleNew(1, &dims, NPY_INT32); + if (node_map_array == NULL) { + goto out; + } + err = tsk_tbl_collection_simplify(self->tables, + PyArray_DATA(samples_array), num_samples, flags, + PyArray_DATA(node_map_array)); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = (PyObject *) node_map_array; + node_map_array = NULL; +out: + Py_XDECREF(samples_array); + Py_XDECREF(node_map_array); + return ret; +} + +static PyObject * +TableCollection_sort(TableCollection *self, PyObject *args, PyObject *kwds) +{ + int err; + PyObject *ret = NULL; + Py_ssize_t edge_start = 0; + int flags = 0; + + static char *kwlist[] = {"edge_start", NULL}; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "|n", kwlist, &edge_start)) { + goto out; + } + err = tsk_tbl_collection_sort(self->tables, (size_t) edge_start, flags); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = Py_BuildValue(""); +out: + return ret; +} + +static PyObject * +TableCollection_compute_mutation_parents(TableCollection *self) +{ + int err; + PyObject *ret = NULL; + + err = tsk_tbl_collection_compute_mutation_parents(self->tables, 0); + if (err 
!= 0) { + handle_library_error(err); + goto out; + } + ret = Py_BuildValue(""); +out: + return ret; +} + +static PyObject * +TableCollection_deduplicate_sites(TableCollection *self) +{ + int err; + PyObject *ret = NULL; + + err = tsk_tbl_collection_deduplicate_sites(self->tables, 0); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = Py_BuildValue(""); +out: + return ret; +} + +/* Forward declaration */ +static PyTypeObject TableCollectionType; + +static PyObject * +TableCollection_equals(TableCollection *self, PyObject *args) +{ + PyObject *ret = NULL; + TableCollection *other = NULL; + + if (!PyArg_ParseTuple(args, "O!", &TableCollectionType, &other)) { + goto out; + } + ret = Py_BuildValue("i", tsk_tbl_collection_equals(self->tables, other->tables)); +out: + return ret; +} + +static PyGetSetDef TableCollection_getsetters[] = { + {"individuals", (getter) TableCollection_get_individuals, NULL, "The individual table."}, + {"nodes", (getter) TableCollection_get_nodes, NULL, "The node table."}, + {"edges", (getter) TableCollection_get_edges, NULL, "The edge table."}, + {"migrations", (getter) TableCollection_get_migrations, NULL, "The migration table."}, + {"sites", (getter) TableCollection_get_sites, NULL, "The site table."}, + {"mutations", (getter) TableCollection_get_mutations, NULL, "The mutation table."}, + {"populations", (getter) TableCollection_get_populations, NULL, "The population table."}, + {"provenances", (getter) TableCollection_get_provenances, NULL, "The provenance table."}, + {"sequence_length", (getter) TableCollection_get_sequence_length, NULL, + "The sequence length."}, + {"file_uuid", (getter) TableCollection_get_file_uuid, NULL, + "The UUID of the corresponding file."}, + {NULL} /* Sentinel */ +}; + +static PyMethodDef TableCollection_methods[] = { + {"simplify", (PyCFunction) TableCollection_simplify, METH_VARARGS|METH_KEYWORDS, + "Simplifies for a given sample subset." 
}, + {"sort", (PyCFunction) TableCollection_sort, METH_VARARGS|METH_KEYWORDS, + "Sorts the tables to satisfy tree sequence requirements." }, + {"equals", (PyCFunction) TableCollection_equals, METH_VARARGS, + "Returns True if the parameter table collection is equal to this one." }, + {"compute_mutation_parents", (PyCFunction) TableCollection_compute_mutation_parents, + METH_NOARGS, "Computes the mutation parents for a the tables." }, + {"deduplicate_sites", (PyCFunction) TableCollection_deduplicate_sites, + METH_NOARGS, "Removes sites with duplicate positions." }, + {NULL} /* Sentinel */ +}; + +static PyTypeObject TableCollectionType = { + PyVarObject_HEAD_INIT(NULL, 0) + "_tskit.TableCollection", /* tp_name */ + sizeof(TableCollection), /* tp_basicsize */ + 0, /* tp_itemsize */ + (destructor)TableCollection_dealloc, /* tp_dealloc */ + 0, /* tp_print */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_reserved */ + 0, /* tp_repr */ + 0, /* tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + 0, /* tp_str */ + 0, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT | + Py_TPFLAGS_BASETYPE, /* tp_flags */ + "TableCollection objects", /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + TableCollection_methods, /* tp_methods */ + 0, /* tp_members */ + TableCollection_getsetters, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + (initproc)TableCollection_init, /* tp_init */ +}; + +/*=================================================================== + * TreeSequence + *=================================================================== + */ + +static int +TreeSequence_check_tree_sequence(TreeSequence *self) +{ + int ret = 0; + if (self->tree_sequence == NULL) { + PyErr_SetString(PyExc_ValueError, 
"tree_sequence not initialised"); + ret = -1; + } + return ret; +} + +static void +TreeSequence_dealloc(TreeSequence* self) +{ + if (self->tree_sequence != NULL) { + tsk_treeseq_free(self->tree_sequence); + PyMem_Free(self->tree_sequence); + self->tree_sequence = NULL; + } + Py_TYPE(self)->tp_free((PyObject*)self); +} + +static int +TreeSequence_alloc(TreeSequence *self) +{ + int ret = -1; + + if (self->tree_sequence != NULL) { + tsk_treeseq_free(self->tree_sequence); + PyMem_Free(self->tree_sequence); + } + self->tree_sequence = PyMem_Malloc(sizeof(tsk_treeseq_t)); + if (self->tree_sequence == NULL) { + PyErr_NoMemory(); + goto out; + } + memset(self->tree_sequence, 0, sizeof(*self->tree_sequence)); + ret = 0; +out: + return ret; +} + +static int +TreeSequence_init(TreeSequence *self, PyObject *args, PyObject *kwds) +{ + self->tree_sequence = NULL; + return 0; +} + +static PyObject * +TreeSequence_dump(TreeSequence *self, PyObject *args, PyObject *kwds) +{ + int err; + char *path; + PyObject *ret = NULL; + int flags = 0; + static char *kwlist[] = {"path", NULL}; + + if (TreeSequence_check_tree_sequence(self) != 0) { + goto out; + } + if (!PyArg_ParseTupleAndKeywords(args, kwds, "s", kwlist, &path)) { + goto out; + } + err = tsk_treeseq_dump(self->tree_sequence, path, flags); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = Py_BuildValue(""); +out: + return ret; +} + +static PyObject * +TreeSequence_load_tables(TreeSequence *self, PyObject *args, PyObject *kwds) +{ + int err; + PyObject *ret = NULL; + TableCollection *tables = NULL; + static char *kwlist[] = {"tables", NULL}; + /* TODO add an interface to turn this on and off. 
*/ + int flags = TSK_BUILD_INDEXES; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O!", kwlist, + &TableCollectionType, &tables)) { + goto out; + } + err = TreeSequence_alloc(self); + if (err != 0) { + goto out; + } + err = tsk_treeseq_alloc(self->tree_sequence, tables->tables, flags); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = Py_BuildValue(""); +out: + return ret; +} + +static PyObject * +TreeSequence_dump_tables(TreeSequence *self, PyObject *args, PyObject *kwds) +{ + int err; + PyObject *ret = NULL; + TableCollection *tables = NULL; + static char *kwlist[] = {"tables", NULL}; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O!", kwlist, + &TableCollectionType, &tables)) { + goto out; + } + if (TreeSequence_check_tree_sequence(self) != 0) { + goto out; + } + err = tsk_treeseq_copy_tables(self->tree_sequence, tables->tables); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = Py_BuildValue(""); +out: + return ret; +} + +static PyObject * +TreeSequence_load(TreeSequence *self, PyObject *args, PyObject *kwds) +{ + int err; + char *path; + int flags = 0; + PyObject *ret = NULL; + static char *kwlist[] = {"path", NULL}; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "s", kwlist, &path)) { + goto out; + } + err = TreeSequence_alloc(self); + if (err != 0) { + goto out; + } + err = tsk_treeseq_load(self->tree_sequence, path, flags); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = Py_BuildValue(""); +out: + return ret; +} + +static PyObject * +TreeSequence_get_node(TreeSequence *self, PyObject *args) +{ + int err; + PyObject *ret = NULL; + Py_ssize_t record_index, num_records; + tsk_node_t record; + + if (TreeSequence_check_tree_sequence(self) != 0) { + goto out; + } + if (!PyArg_ParseTuple(args, "n", &record_index)) { + goto out; + } + num_records = (Py_ssize_t) tsk_treeseq_get_num_nodes(self->tree_sequence); + if (record_index < 0 || record_index >= num_records) { + 
PyErr_SetString(PyExc_IndexError, "record index out of bounds"); + goto out; + } + err = tsk_treeseq_get_node(self->tree_sequence, (size_t) record_index, &record); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = make_node(&record); +out: + return ret; +} + +static PyObject * +TreeSequence_get_edge(TreeSequence *self, PyObject *args) +{ + int err; + PyObject *ret = NULL; + Py_ssize_t record_index, num_records; + tsk_edge_t record; + + if (TreeSequence_check_tree_sequence(self) != 0) { + goto out; + } + if (!PyArg_ParseTuple(args, "n", &record_index)) { + goto out; + } + num_records = (Py_ssize_t) tsk_treeseq_get_num_edges(self->tree_sequence); + if (record_index < 0 || record_index >= num_records) { + PyErr_SetString(PyExc_IndexError, "record index out of bounds"); + goto out; + } + err = tsk_treeseq_get_edge(self->tree_sequence, (size_t) record_index, &record); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = make_edge(&record); +out: + return ret; +} + +static PyObject * +TreeSequence_get_migration(TreeSequence *self, PyObject *args) +{ + int err; + PyObject *ret = NULL; + Py_ssize_t record_index, num_records; + tsk_migration_t record; + + if (TreeSequence_check_tree_sequence(self) != 0) { + goto out; + } + if (!PyArg_ParseTuple(args, "n", &record_index)) { + goto out; + } + num_records = (Py_ssize_t) tsk_treeseq_get_num_migrations( + self->tree_sequence); + if (record_index < 0 || record_index >= num_records) { + PyErr_SetString(PyExc_IndexError, "record index out of bounds"); + goto out; + } + err = tsk_treeseq_get_migration(self->tree_sequence, + (size_t) record_index, &record); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = make_migration(&record); +out: + return ret; +} + +static PyObject * +TreeSequence_get_site(TreeSequence *self, PyObject *args) +{ + int err; + PyObject *ret = NULL; + Py_ssize_t record_index, num_records; + tsk_site_t record; + + if (TreeSequence_check_tree_sequence(self) 
!= 0) { + goto out; + } + if (!PyArg_ParseTuple(args, "n", &record_index)) { + goto out; + } + num_records = (Py_ssize_t) tsk_treeseq_get_num_sites(self->tree_sequence); + if (record_index < 0 || record_index >= num_records) { + PyErr_SetString(PyExc_IndexError, "record index out of bounds"); + goto out; + } + err = tsk_treeseq_get_site(self->tree_sequence, (tsk_id_t) record_index, &record); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = make_site_object(&record); +out: + return ret; +} + +static PyObject * +TreeSequence_get_mutation(TreeSequence *self, PyObject *args) +{ + int err; + PyObject *ret = NULL; + Py_ssize_t record_index, num_records; + tsk_mutation_t record; + + if (TreeSequence_check_tree_sequence(self) != 0) { + goto out; + } + if (!PyArg_ParseTuple(args, "n", &record_index)) { + goto out; + } + num_records = (Py_ssize_t) tsk_treeseq_get_num_mutations( + self->tree_sequence); + if (record_index < 0 || record_index >= num_records) { + PyErr_SetString(PyExc_IndexError, "record index out of bounds"); + goto out; + } + err = tsk_treeseq_get_mutation(self->tree_sequence, + (size_t) record_index, &record); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = make_mutation(&record); +out: + return ret; +} + +static PyObject * +TreeSequence_get_individual(TreeSequence *self, PyObject *args) +{ + int err; + PyObject *ret = NULL; + Py_ssize_t record_index, num_records; + tsk_individual_t record; + + if (TreeSequence_check_tree_sequence(self) != 0) { + goto out; + } + if (!PyArg_ParseTuple(args, "n", &record_index)) { + goto out; + } + num_records = (Py_ssize_t) tsk_treeseq_get_num_individuals(self->tree_sequence); + if (record_index < 0 || record_index >= num_records) { + PyErr_SetString(PyExc_IndexError, "record index out of bounds"); + goto out; + } + err = tsk_treeseq_get_individual(self->tree_sequence, (size_t) record_index, &record); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = 
make_individual_object(&record); +out: + return ret; +} + +static PyObject * +TreeSequence_get_population(TreeSequence *self, PyObject *args) +{ + int err; + PyObject *ret = NULL; + Py_ssize_t record_index, num_records; + tsk_population_t record; + + if (TreeSequence_check_tree_sequence(self) != 0) { + goto out; + } + if (!PyArg_ParseTuple(args, "n", &record_index)) { + goto out; + } + num_records = (Py_ssize_t) tsk_treeseq_get_num_populations(self->tree_sequence); + if (record_index < 0 || record_index >= num_records) { + PyErr_SetString(PyExc_IndexError, "record index out of bounds"); + goto out; + } + err = tsk_treeseq_get_population(self->tree_sequence, (size_t) record_index, &record); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = make_population(&record); +out: + return ret; +} + +static PyObject * +TreeSequence_get_provenance(TreeSequence *self, PyObject *args) +{ + int err; + PyObject *ret = NULL; + Py_ssize_t record_index, num_records; + tsk_provenance_t record; + + if (TreeSequence_check_tree_sequence(self) != 0) { + goto out; + } + if (!PyArg_ParseTuple(args, "n", &record_index)) { + goto out; + } + num_records = (Py_ssize_t) tsk_treeseq_get_num_provenances(self->tree_sequence); + if (record_index < 0 || record_index >= num_records) { + PyErr_SetString(PyExc_IndexError, "record index out of bounds"); + goto out; + } + err = tsk_treeseq_get_provenance(self->tree_sequence, (size_t) record_index, &record); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = make_provenance(&record); +out: + return ret; +} + +static PyObject * +TreeSequence_get_num_edges(TreeSequence *self, PyObject *args) +{ + PyObject *ret = NULL; + size_t num_records; + + if (TreeSequence_check_tree_sequence(self) != 0) { + goto out; + } + num_records = tsk_treeseq_get_num_edges(self->tree_sequence); + ret = Py_BuildValue("n", (Py_ssize_t) num_records); +out: + return ret; +} + +static PyObject * +TreeSequence_get_num_migrations(TreeSequence 
*self, PyObject *args) +{ + PyObject *ret = NULL; + size_t num_records; + + if (TreeSequence_check_tree_sequence(self) != 0) { + goto out; + } + num_records = tsk_treeseq_get_num_migrations(self->tree_sequence); + ret = Py_BuildValue("n", (Py_ssize_t) num_records); +out: + return ret; +} + +static PyObject * +TreeSequence_get_num_individuals(TreeSequence *self, PyObject *args) +{ + PyObject *ret = NULL; + size_t num_records; + + if (TreeSequence_check_tree_sequence(self) != 0) { + goto out; + } + num_records = tsk_treeseq_get_num_individuals(self->tree_sequence); + ret = Py_BuildValue("n", (Py_ssize_t) num_records); +out: + return ret; +} + +static PyObject * +TreeSequence_get_num_populations(TreeSequence *self, PyObject *args) +{ + PyObject *ret = NULL; + size_t num_records; + + if (TreeSequence_check_tree_sequence(self) != 0) { + goto out; + } + num_records = tsk_treeseq_get_num_populations(self->tree_sequence); + ret = Py_BuildValue("n", (Py_ssize_t) num_records); +out: + return ret; +} + +static PyObject * +TreeSequence_get_num_trees(TreeSequence *self, PyObject *args) +{ + PyObject *ret = NULL; + size_t num_trees; + + if (TreeSequence_check_tree_sequence(self) != 0) { + goto out; + } + num_trees = tsk_treeseq_get_num_trees(self->tree_sequence); + ret = Py_BuildValue("n", (Py_ssize_t) num_trees); +out: + return ret; +} + +static PyObject * +TreeSequence_get_sequence_length(TreeSequence *self) +{ + PyObject *ret = NULL; + + if (TreeSequence_check_tree_sequence(self) != 0) { + goto out; + } + ret = Py_BuildValue("d", + tsk_treeseq_get_sequence_length(self->tree_sequence)); +out: + return ret; +} + +static PyObject * +TreeSequence_get_file_uuid(TreeSequence *self) +{ + PyObject *ret = NULL; + + if (TreeSequence_check_tree_sequence(self) != 0) { + goto out; + } + ret = Py_BuildValue("s", tsk_treeseq_get_file_uuid(self->tree_sequence)); +out: + return ret; +} + +static PyObject * +TreeSequence_get_num_samples(TreeSequence *self) +{ + PyObject *ret = NULL; + size_t 
num_samples; + + if (TreeSequence_check_tree_sequence(self) != 0) { + goto out; + } + num_samples = tsk_treeseq_get_num_samples(self->tree_sequence); + ret = Py_BuildValue("n", (Py_ssize_t) num_samples); +out: + return ret; +} + +static PyObject * +TreeSequence_get_num_nodes(TreeSequence *self) +{ + PyObject *ret = NULL; + size_t num_nodes; + + if (TreeSequence_check_tree_sequence(self) != 0) { + goto out; + } + num_nodes = tsk_treeseq_get_num_nodes(self->tree_sequence); + ret = Py_BuildValue("n", (Py_ssize_t) num_nodes); +out: + return ret; +} + +static PyObject * +TreeSequence_get_samples(TreeSequence *self) +{ + PyObject *ret = NULL; + tsk_id_t *samples; + PyObject *py_samples = NULL; + PyObject *py_int = NULL; + size_t j, n; + int err; + + if (TreeSequence_check_tree_sequence(self) != 0) { + goto out; + } + n = tsk_treeseq_get_num_samples(self->tree_sequence); + err = tsk_treeseq_get_samples(self->tree_sequence, &samples); + if (err != 0) { + handle_library_error(err); + } + py_samples = PyList_New(n); + if (py_samples == NULL) { + goto out; + } + for (j = 0; j < n; j++) { + py_int = Py_BuildValue("i", (int) samples[j]); + if (py_int == NULL) { + Py_DECREF(py_samples); + goto out; + } + PyList_SET_ITEM(py_samples, j, py_int); + } + ret = py_samples; +out: + return ret; +} + +static PyObject * +TreeSequence_get_pairwise_diversity(TreeSequence *self, PyObject *args, PyObject *kwds) +{ + PyObject *ret = NULL; + PyObject *py_samples = NULL; + static char *kwlist[] = {"samples", NULL}; + tsk_id_t *samples = NULL; + size_t num_samples = 0; + double pi; + int err; + + if (TreeSequence_check_tree_sequence(self) != 0) { + goto out; + } + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O!", kwlist, + &PyList_Type, &py_samples)) { + goto out; + } + if (parse_sample_ids(py_samples, self->tree_sequence, &num_samples, &samples) != 0) { + goto out; + } + err = tsk_treeseq_get_pairwise_diversity( + self->tree_sequence, samples, (uint32_t) num_samples, &pi); + if (err != 0) { + 
handle_library_error(err); + goto out; + } + ret = Py_BuildValue("d", pi); +out: + if (samples != NULL) { + PyMem_Free(samples); + } + + return ret; +} + +static PyObject * +TreeSequence_genealogical_nearest_neighbours(TreeSequence *self, PyObject *args, PyObject *kwds) +{ + PyObject *ret = NULL; + static char *kwlist[] = {"focal", "reference_sets", NULL}; + tsk_id_t **reference_sets = NULL; + size_t *reference_set_size = NULL; + PyObject *focal = NULL; + PyObject *reference_sets_list = NULL; + PyArrayObject *focal_array = NULL; + PyArrayObject **reference_set_arrays = NULL; + PyArrayObject *ret_array = NULL; + npy_intp *shape, dims[2]; + size_t num_focal = 0; + size_t num_reference_sets = 0; + size_t j; + int err; + + if (TreeSequence_check_tree_sequence(self) != 0) { + goto out; + } + if (!PyArg_ParseTupleAndKeywords(args, kwds, "OO!", kwlist, + &focal, &PyList_Type, &reference_sets_list)) { + goto out; + } + + /* We're releasing the GIL here so we need to make sure that the memory we + * pass to the low-level code doesn't change while it's in use. This is + * why we take copies of the input arrays. 
*/ + focal_array = (PyArrayObject *) PyArray_FROMANY(focal, NPY_INT32, 1, 1, + NPY_ARRAY_IN_ARRAY|NPY_ARRAY_ENSURECOPY); + if (focal_array == NULL) { + goto out; + } + shape = PyArray_DIMS(focal_array); + num_focal = shape[0]; + num_reference_sets = PyList_Size(reference_sets_list); + if (num_reference_sets == 0) { + PyErr_SetString(PyExc_ValueError, "Must have at least one sample set"); + goto out; + } + reference_set_size = PyMem_Malloc(num_reference_sets * sizeof(*reference_set_size)); + reference_sets = PyMem_Malloc(num_reference_sets * sizeof(*reference_sets)); + reference_set_arrays = PyMem_Malloc(num_reference_sets * sizeof(*reference_set_arrays)); + if (reference_sets == NULL || reference_set_size == NULL || reference_set_arrays == NULL) { + goto out; + } + memset(reference_set_arrays, 0, num_reference_sets * sizeof(*reference_set_arrays)); + for (j = 0; j < num_reference_sets; j++) { + reference_set_arrays[j] = (PyArrayObject *) PyArray_FROMANY( + PyList_GetItem(reference_sets_list, j), NPY_INT32, 1, 1, + NPY_ARRAY_IN_ARRAY|NPY_ARRAY_ENSURECOPY); + if (reference_set_arrays[j] == NULL) { + goto out; + } + reference_sets[j] = PyArray_DATA(reference_set_arrays[j]); + shape = PyArray_DIMS(reference_set_arrays[j]); + reference_set_size[j] = shape[0]; + } + + /* Allocate the return array */ + dims[0] = num_focal; + dims[1] = num_reference_sets; + ret_array = (PyArrayObject *) PyArray_SimpleNew(2, dims, NPY_FLOAT64); + if (ret_array == NULL) { + goto out; + } + + Py_BEGIN_ALLOW_THREADS + err = tsk_treeseq_genealogical_nearest_neighbours(self->tree_sequence, + PyArray_DATA(focal_array), num_focal, + reference_sets, reference_set_size, num_reference_sets, + 0, PyArray_DATA(ret_array)); + Py_END_ALLOW_THREADS + if (err != 0) { + handle_library_error(err); + goto out; + } + + ret = (PyObject *) ret_array; + ret_array = NULL; +out: + if (reference_sets != NULL) { + PyMem_Free(reference_sets); + } + if (reference_set_size != NULL) { + PyMem_Free(reference_set_size); + 
} + if (reference_set_arrays != NULL) { + for (j = 0; j < num_reference_sets; j++) { + Py_XDECREF(reference_set_arrays[j]); + } + PyMem_Free(reference_set_arrays); + } + Py_XDECREF(focal_array); + Py_XDECREF(ret_array); + return ret; +} + +static PyObject * +TreeSequence_mean_descendants(TreeSequence *self, PyObject *args, PyObject *kwds) +{ + PyObject *ret = NULL; + static char *kwlist[] = {"reference_sets", NULL}; + tsk_id_t **reference_sets = NULL; + size_t *reference_set_size = NULL; + PyObject *reference_sets_list = NULL; + PyArrayObject **reference_set_arrays = NULL; + PyArrayObject *ret_array = NULL; + npy_intp *shape, dims[2]; + size_t num_reference_sets = 0; + size_t j; + int err; + + if (TreeSequence_check_tree_sequence(self) != 0) { + goto out; + } + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O!", kwlist, + &PyList_Type, &reference_sets_list)) { + goto out; + } + + num_reference_sets = PyList_Size(reference_sets_list); + if (num_reference_sets == 0) { + PyErr_SetString(PyExc_ValueError, "Must have at least one sample set"); + goto out; + } + reference_set_size = PyMem_Malloc(num_reference_sets * sizeof(*reference_set_size)); + reference_sets = PyMem_Malloc(num_reference_sets * sizeof(*reference_sets)); + reference_set_arrays = PyMem_Malloc(num_reference_sets * sizeof(*reference_set_arrays)); + if (reference_sets == NULL || reference_set_size == NULL || reference_set_arrays == NULL) { + goto out; + } + memset(reference_set_arrays, 0, num_reference_sets * sizeof(*reference_set_arrays)); + for (j = 0; j < num_reference_sets; j++) { + /* We're releasing the GIL here so we need to make sure that the memory we + * pass to the low-level code doesn't change while it's in use. This is + * why we take copies of the input arrays. 
*/ + reference_set_arrays[j] = (PyArrayObject *) PyArray_FROMANY( + PyList_GetItem(reference_sets_list, j), NPY_INT32, 1, 1, + NPY_ARRAY_IN_ARRAY|NPY_ARRAY_ENSURECOPY); + if (reference_set_arrays[j] == NULL) { + goto out; + } + reference_sets[j] = PyArray_DATA(reference_set_arrays[j]); + shape = PyArray_DIMS(reference_set_arrays[j]); + reference_set_size[j] = shape[0]; + } + + /* Allocate the return array */ + dims[0] = tsk_treeseq_get_num_nodes(self->tree_sequence); + dims[1] = num_reference_sets; + ret_array = (PyArrayObject *) PyArray_SimpleNew(2, dims, NPY_FLOAT64); + if (ret_array == NULL) { + goto out; + } + + Py_BEGIN_ALLOW_THREADS + err = tsk_treeseq_mean_descendants(self->tree_sequence, + reference_sets, reference_set_size, num_reference_sets, + 0, PyArray_DATA(ret_array)); + Py_END_ALLOW_THREADS + if (err != 0) { + handle_library_error(err); + goto out; + } + + ret = (PyObject *) ret_array; + ret_array = NULL; +out: + if (reference_sets != NULL) { + PyMem_Free(reference_sets); + } + if (reference_set_size != NULL) { + PyMem_Free(reference_set_size); + } + if (reference_set_arrays != NULL) { + for (j = 0; j < num_reference_sets; j++) { + Py_XDECREF(reference_set_arrays[j]); + } + PyMem_Free(reference_set_arrays); + } + Py_XDECREF(ret_array); + return ret; +} + +static PyObject * +TreeSequence_get_num_mutations(TreeSequence *self) +{ + PyObject *ret = NULL; + size_t num_mutations; + + if (TreeSequence_check_tree_sequence(self) != 0) { + goto out; + } + num_mutations = tsk_treeseq_get_num_mutations(self->tree_sequence); + ret = Py_BuildValue("n", (Py_ssize_t) num_mutations); +out: + return ret; +} + +static PyObject * +TreeSequence_get_num_sites(TreeSequence *self) +{ + PyObject *ret = NULL; + size_t num_sites; + + if (TreeSequence_check_tree_sequence(self) != 0) { + goto out; + } + num_sites = tsk_treeseq_get_num_sites(self->tree_sequence); + ret = Py_BuildValue("n", (Py_ssize_t) num_sites); +out: + return ret; +} + +static PyObject * 
+TreeSequence_get_num_provenances(TreeSequence *self) +{ + PyObject *ret = NULL; + size_t num_provenances; + + if (TreeSequence_check_tree_sequence(self) != 0) { + goto out; + } + num_provenances = tsk_treeseq_get_num_provenances(self->tree_sequence); + ret = Py_BuildValue("n", (Py_ssize_t) num_provenances); +out: + return ret; +} + +static PyObject * +TreeSequence_get_genotype_matrix(TreeSequence *self) +{ + PyObject *ret = NULL; + int err; + size_t num_sites; + size_t num_samples; + npy_intp dims[2]; + PyArrayObject *genotype_matrix = NULL; + tsk_vargen_t *vg = NULL; + char *V; + tsk_variant_t *variant; + size_t j; + + /* TODO add option for 16 bit genotypes */ + + if (TreeSequence_check_tree_sequence(self) != 0) { + goto out; + } + num_sites = tsk_treeseq_get_num_sites(self->tree_sequence); + num_samples = tsk_treeseq_get_num_samples(self->tree_sequence); + dims[0] = num_sites; + dims[1] = num_samples; + + genotype_matrix = (PyArrayObject *) PyArray_SimpleNew(2, dims, NPY_UINT8); + if (genotype_matrix == NULL) { + goto out; + } + V = (char *) PyArray_DATA(genotype_matrix); + vg = PyMem_Malloc(sizeof(tsk_vargen_t)); + if (vg == NULL) { + PyErr_NoMemory(); + goto out; + } + err = tsk_vargen_alloc(vg, self->tree_sequence, NULL, 0, 0); + if (err != 0) { + handle_library_error(err); + goto out; + } + j = 0; + while ((err = tsk_vargen_next(vg, &variant)) == 1) { + memcpy(V + (j * num_samples), variant->genotypes.u8, num_samples * sizeof(uint8_t)); + j++; + } + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = (PyObject *) genotype_matrix; + genotype_matrix = NULL; +out: + if (vg != NULL) { + tsk_vargen_free(vg); + PyMem_Free(vg); + } + Py_XDECREF(genotype_matrix); + return ret; +} + +static PyMemberDef TreeSequence_members[] = { + {NULL} /* Sentinel */ +}; + +static PyMethodDef TreeSequence_methods[] = { + {"dump", (PyCFunction) TreeSequence_dump, + METH_VARARGS|METH_KEYWORDS, + "Writes the tree sequence out to the specified path."}, + {"load", 
(PyCFunction) TreeSequence_load, + METH_VARARGS|METH_KEYWORDS, + "Loads a tree sequence from the specified path."}, + {"load_tables", (PyCFunction) TreeSequence_load_tables, + METH_VARARGS|METH_KEYWORDS, + "Loads a tree sequence from the specified set of tables"}, + {"dump_tables", (PyCFunction) TreeSequence_dump_tables, + METH_VARARGS|METH_KEYWORDS, + "Dumps the tree sequence to the specified set of tables"}, + {"get_node", + (PyCFunction) TreeSequence_get_node, METH_VARARGS, + "Returns the node record at the specified index."}, + {"get_edge", + (PyCFunction) TreeSequence_get_edge, METH_VARARGS, + "Returns the edge record at the specified index."}, + {"get_migration", + (PyCFunction) TreeSequence_get_migration, METH_VARARGS, + "Returns the migration record at the specified index."}, + {"get_site", + (PyCFunction) TreeSequence_get_site, METH_VARARGS, + "Returns the mutation type record at the specified index."}, + {"get_mutation", + (PyCFunction) TreeSequence_get_mutation, METH_VARARGS, + "Returns the mutation record at the specified index."}, + {"get_individual", + (PyCFunction) TreeSequence_get_individual, METH_VARARGS, + "Returns the individual record at the specified index."}, + {"get_population", + (PyCFunction) TreeSequence_get_population, METH_VARARGS, + "Returns the population record at the specified index."}, + {"get_provenance", + (PyCFunction) TreeSequence_get_provenance, METH_VARARGS, + "Returns the provenance record at the specified index."}, + {"get_num_edges", (PyCFunction) TreeSequence_get_num_edges, + METH_NOARGS, "Returns the number of coalescence records." }, + {"get_num_migrations", (PyCFunction) TreeSequence_get_num_migrations, + METH_NOARGS, "Returns the number of migration records." }, + {"get_num_populations", (PyCFunction) TreeSequence_get_num_populations, + METH_NOARGS, "Returns the number of population records." 
}, + {"get_num_individuals", (PyCFunction) TreeSequence_get_num_individuals, + METH_NOARGS, "Returns the number of individual records." }, + {"get_num_trees", (PyCFunction) TreeSequence_get_num_trees, + METH_NOARGS, "Returns the number of trees in the tree sequence." }, + {"get_sequence_length", (PyCFunction) TreeSequence_get_sequence_length, + METH_NOARGS, "Returns the sequence length in bases." }, + {"get_file_uuid", (PyCFunction) TreeSequence_get_file_uuid, + METH_NOARGS, "Returns the UUID of the underlying file, if present." }, + {"get_num_sites", (PyCFunction) TreeSequence_get_num_sites, + METH_NOARGS, "Returns the number of sites" }, + {"get_num_mutations", (PyCFunction) TreeSequence_get_num_mutations, METH_NOARGS, + "Returns the number of mutations" }, + {"get_num_provenances", (PyCFunction) TreeSequence_get_num_provenances, + METH_NOARGS, "Returns the number of provenances" }, + {"get_num_nodes", (PyCFunction) TreeSequence_get_num_nodes, METH_NOARGS, + "Returns the number of unique nodes in the tree sequence." }, + {"get_num_samples", (PyCFunction) TreeSequence_get_num_samples, METH_NOARGS, + "Returns the sample size" }, + {"get_samples", (PyCFunction) TreeSequence_get_samples, METH_NOARGS, + "Returns the samples." }, + {"get_pairwise_diversity", + (PyCFunction) TreeSequence_get_pairwise_diversity, + METH_VARARGS|METH_KEYWORDS, "Returns the average pairwise diversity." }, + {"genealogical_nearest_neighbours", + (PyCFunction) TreeSequence_genealogical_nearest_neighbours, + METH_VARARGS|METH_KEYWORDS, "Returns the genealogical nearest neighbours statistic." }, + {"mean_descendants", + (PyCFunction) TreeSequence_mean_descendants, + METH_VARARGS|METH_KEYWORDS, "Returns the mean number of nodes descending from each node." }, + {"get_genotype_matrix", (PyCFunction) TreeSequence_get_genotype_matrix, METH_NOARGS, + "Returns the genotypes matrix." 
}, + {NULL} /* Sentinel */ +}; + +static PyTypeObject TreeSequenceType = { + PyVarObject_HEAD_INIT(NULL, 0) + "_tskit.TreeSequence", /* tp_name */ + sizeof(TreeSequence), /* tp_basicsize */ + 0, /* tp_itemsize */ + (destructor)TreeSequence_dealloc, /* tp_dealloc */ + 0, /* tp_print */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_reserved */ + 0, /* tp_repr */ + 0, /* tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + 0, /* tp_str */ + 0, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT, /* tp_flags */ + "TreeSequence objects", /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + TreeSequence_methods, /* tp_methods */ + TreeSequence_members, /* tp_members */ + 0, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + (initproc)TreeSequence_init, /* tp_init */ +}; + +/*=================================================================== + * Tree + *=================================================================== + */ + +static int +Tree_check_tree(Tree *self) +{ + int ret = 0; + if (self->tree == NULL) { + PyErr_SetString(PyExc_RuntimeError, "tree not initialised"); + ret = -1; + } + return ret; +} + +static int +Tree_check_bounds(Tree *self, int node) +{ + int ret = 0; + if (node < 0 || node >= self->tree->num_nodes) { + PyErr_SetString(PyExc_ValueError, "Node index out of bounds"); + ret = -1; + } + return ret; +} + +static void +Tree_dealloc(Tree* self) +{ + if (self->tree != NULL) { + tsk_tree_free(self->tree); + PyMem_Free(self->tree); + self->tree = NULL; + } + Py_XDECREF(self->tree_sequence); + Py_TYPE(self)->tp_free((PyObject*)self); +} + +/* TODO this API should be updated to remove the TreeIterator object + * and instead support the first(), last() etc methods. 
Until some seeking + * function has been called, we should be in a state that errors if any + * methods are called. + * + * The _free method below is also probably redundant now and should be + * removed. + */ +static int +Tree_init(Tree *self, PyObject *args, PyObject *kwds) +{ + int ret = -1; + int err; + static char *kwlist[] = {"tree_sequence", "flags", "tracked_samples", + NULL}; + PyObject *py_tracked_samples = NULL; + TreeSequence *tree_sequence = NULL; + tsk_id_t *tracked_samples = NULL; + int flags = 0; + uint32_t j, num_tracked_samples, num_nodes; + PyObject *item; + + self->tree = NULL; + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O!|iO!", kwlist, + &TreeSequenceType, &tree_sequence, + &flags, &PyList_Type, &py_tracked_samples)) { + goto out; + } + self->tree_sequence = tree_sequence; + Py_INCREF(self->tree_sequence); + if (TreeSequence_check_tree_sequence(tree_sequence) != 0) { + goto out; + } + num_nodes = tsk_treeseq_get_num_nodes(tree_sequence->tree_sequence); + num_tracked_samples = 0; + if (py_tracked_samples != NULL) { + if (!(flags & TSK_SAMPLE_COUNTS)) { + PyErr_SetString(PyExc_ValueError, + "Cannot specified tracked_samples without count_samples flag"); + goto out; + } + num_tracked_samples = PyList_Size(py_tracked_samples); + } + tracked_samples = PyMem_Malloc(num_tracked_samples * sizeof(tsk_id_t)); + if (tracked_samples == NULL) { + PyErr_NoMemory(); + goto out; + } + for (j = 0; j < num_tracked_samples; j++) { + item = PyList_GetItem(py_tracked_samples, j); + if (!PyNumber_Check(item)) { + PyErr_SetString(PyExc_TypeError, "sample must be a number"); + goto out; + } + tracked_samples[j] = (tsk_id_t) PyLong_AsLong(item); + if (tracked_samples[j] >= num_nodes) { + PyErr_SetString(PyExc_ValueError, "samples must be valid nodes"); + goto out; + } + } + self->tree = PyMem_Malloc(sizeof(tsk_tree_t)); + if (self->tree == NULL) { + PyErr_NoMemory(); + goto out; + } + err = tsk_tree_alloc(self->tree, tree_sequence->tree_sequence, + flags); + if 
(err != 0) { + handle_library_error(err); + goto out; + } + if (!!(flags & TSK_SAMPLE_COUNTS)) { + err = tsk_tree_set_tracked_samples(self->tree, num_tracked_samples, + tracked_samples); + if (err != 0) { + handle_library_error(err); + goto out; + } + } + ret = 0; +out: + if (tracked_samples != NULL) { + PyMem_Free(tracked_samples); + } + return ret; +} + +/* TODO this should be redundant; remove */ +static PyObject * +Tree_free(Tree *self) +{ + PyObject *ret = NULL; + + if (Tree_check_tree(self) != 0) { + goto out; + } + /* This method is need because we have dangling references to + * trees after a for loop and we can't run set_sites. + */ + tsk_tree_free(self->tree); + PyMem_Free(self->tree); + self->tree = NULL; + ret = Py_BuildValue(""); +out: + return ret; +} + +static PyObject * +Tree_get_sample_size(Tree *self) +{ + PyObject *ret = NULL; + + if (Tree_check_tree(self) != 0) { + goto out; + } + ret = Py_BuildValue("n", (Py_ssize_t) self->tree->tree_sequence->num_samples); +out: + return ret; +} + +static PyObject * +Tree_get_num_nodes(Tree *self) +{ + PyObject *ret = NULL; + + if (Tree_check_tree(self) != 0) { + goto out; + } + ret = Py_BuildValue("n", (Py_ssize_t) self->tree->num_nodes); +out: + return ret; +} + +static PyObject * +Tree_get_num_roots(Tree *self) +{ + PyObject *ret = NULL; + + if (Tree_check_tree(self) != 0) { + goto out; + } + ret = Py_BuildValue("n", (Py_ssize_t) tsk_tree_get_num_roots(self->tree)); +out: + return ret; +} + +static PyObject * +Tree_get_index(Tree *self) +{ + PyObject *ret = NULL; + + if (Tree_check_tree(self) != 0) { + goto out; + } + ret = Py_BuildValue("n", (Py_ssize_t) self->tree->index); +out: + return ret; +} + +static PyObject * +Tree_get_left_root(Tree *self) +{ + PyObject *ret = NULL; + + if (Tree_check_tree(self) != 0) { + goto out; + } + ret = Py_BuildValue("i", (int) self->tree->left_root); +out: + return ret; +} + +static PyObject * +Tree_get_left(Tree *self) +{ + PyObject *ret = NULL; + + if 
(Tree_check_tree(self) != 0) { + goto out; + } + ret = Py_BuildValue("d", self->tree->left); +out: + return ret; +} + +static PyObject * +Tree_get_right(Tree *self) +{ + PyObject *ret = NULL; + + if (Tree_check_tree(self) != 0) { + goto out; + } + ret = Py_BuildValue("d", self->tree->right); +out: + return ret; +} + +static PyObject * +Tree_get_flags(Tree *self) +{ + PyObject *ret = NULL; + + if (Tree_check_tree(self) != 0) { + goto out; + } + ret = Py_BuildValue("i", self->tree->flags); +out: + return ret; +} + +static int +Tree_get_node_argument(Tree *self, PyObject *args, int *node) +{ + int ret = -1; + if (Tree_check_tree(self) != 0) { + goto out; + } + if (!PyArg_ParseTuple(args, "I", node)) { + goto out; + } + if (Tree_check_bounds(self, *node)) { + goto out; + } + ret = 0; +out: + return ret; +} + +static PyObject * +Tree_is_sample(Tree *self, PyObject *args) +{ + PyObject *ret = NULL; + int node; + + if (Tree_get_node_argument(self, args, &node) != 0) { + goto out; + } + ret = Py_BuildValue("i", tsk_tree_is_sample(self->tree, (tsk_id_t) node)); +out: + return ret; +} + +static PyObject * +Tree_get_parent(Tree *self, PyObject *args) +{ + PyObject *ret = NULL; + tsk_id_t parent; + int node; + + if (Tree_get_node_argument(self, args, &node) != 0) { + goto out; + } + parent = self->tree->parent[node]; + ret = Py_BuildValue("i", (int) parent); +out: + return ret; +} + +static PyObject * +Tree_get_population(Tree *self, PyObject *args) +{ + PyObject *ret = NULL; + tsk_node_t node; + int node_id, err; + + if (Tree_get_node_argument(self, args, &node_id) != 0) { + goto out; + } + err = tsk_treeseq_get_node(self->tree->tree_sequence, node_id, &node); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = Py_BuildValue("i", (int) node.population); +out: + return ret; +} + +static PyObject * +Tree_get_time(Tree *self, PyObject *args) +{ + PyObject *ret = NULL; + double time; + int node, err; + + if (Tree_get_node_argument(self, args, &node) != 0) { + 
goto out; + } + err = tsk_tree_get_time(self->tree, node, &time); + if (ret != 0) { + handle_library_error(err); + goto out; + } + ret = Py_BuildValue("d", time); +out: + return ret; +} + +static PyObject * +Tree_get_left_child(Tree *self, PyObject *args) +{ + PyObject *ret = NULL; + tsk_id_t child; + int node; + + if (Tree_get_node_argument(self, args, &node) != 0) { + goto out; + } + child = self->tree->left_child[node]; + ret = Py_BuildValue("i", (int) child); +out: + return ret; +} + +static PyObject * +Tree_get_right_child(Tree *self, PyObject *args) +{ + PyObject *ret = NULL; + tsk_id_t child; + int node; + + if (Tree_get_node_argument(self, args, &node) != 0) { + goto out; + } + child = self->tree->right_child[node]; + ret = Py_BuildValue("i", (int) child); +out: + return ret; +} + +static PyObject * +Tree_get_left_sib(Tree *self, PyObject *args) +{ + PyObject *ret = NULL; + tsk_id_t sib; + int node; + + if (Tree_get_node_argument(self, args, &node) != 0) { + goto out; + } + sib = self->tree->left_sib[node]; + ret = Py_BuildValue("i", (int) sib); +out: + return ret; +} + +static PyObject * +Tree_get_right_sib(Tree *self, PyObject *args) +{ + PyObject *ret = NULL; + tsk_id_t sib; + int node; + + if (Tree_get_node_argument(self, args, &node) != 0) { + goto out; + } + sib = self->tree->right_sib[node]; + ret = Py_BuildValue("i", (int) sib); +out: + return ret; +} + +static PyObject * +Tree_get_children(Tree *self, PyObject *args) +{ + PyObject *ret = NULL; + int node; + tsk_id_t u; + size_t j, num_children; + tsk_id_t *children = NULL; + + if (Tree_get_node_argument(self, args, &node) != 0) { + goto out; + } + num_children = 0; + for (u = self->tree->left_child[node]; u != TSK_NULL; u = self->tree->right_sib[u]) { + num_children++; + } + children = PyMem_Malloc(num_children * sizeof(tsk_id_t)); + if (children == NULL) { + PyErr_NoMemory(); + goto out; + } + j = 0; + for (u = self->tree->left_child[node]; u != TSK_NULL; u = self->tree->right_sib[u]) { + 
children[j] = u; + j++; + } + ret = convert_node_id_list(children, num_children); +out: + if (children != NULL) { + PyMem_Free(children); + } + return ret; +} + +static bool +Tree_check_sample_list(Tree *self) +{ + bool ret = tsk_tree_has_sample_lists(self->tree); + if (! ret) { + PyErr_SetString(PyExc_ValueError, + "Sample lists not supported. Please set sample_lists=True."); + } + return ret; +} + +static PyObject * +Tree_get_right_sample(Tree *self, PyObject *args) +{ + PyObject *ret = NULL; + tsk_id_t sample_index; + int node; + + if (Tree_get_node_argument(self, args, &node) != 0) { + goto out; + } + if (!Tree_check_sample_list(self)) { + goto out; + } + sample_index = self->tree->right_sample[node]; + ret = Py_BuildValue("i", (int) sample_index); +out: + return ret; +} + +static PyObject * +Tree_get_left_sample(Tree *self, PyObject *args) +{ + PyObject *ret = NULL; + tsk_id_t sample_index; + int node; + + if (Tree_get_node_argument(self, args, &node) != 0) { + goto out; + } + if (!Tree_check_sample_list(self)) { + goto out; + } + sample_index = self->tree->left_sample[node]; + ret = Py_BuildValue("i", (int) sample_index); +out: + return ret; +} + +static PyObject * +Tree_get_next_sample(Tree *self, PyObject *args) +{ + PyObject *ret = NULL; + tsk_id_t out_index; + int in_index, num_samples; + + if (Tree_check_tree(self) != 0) { + goto out; + } + if (!PyArg_ParseTuple(args, "I", &in_index)) { + goto out; + } + num_samples = (int) tsk_treeseq_get_num_samples(self->tree->tree_sequence); + if (in_index < 0 || in_index >= num_samples) { + PyErr_SetString(PyExc_ValueError, "Sample index out of bounds"); + goto out; + } + if (!Tree_check_sample_list(self)) { + goto out; + } + out_index = self->tree->next_sample[in_index]; + ret = Py_BuildValue("i", (int) out_index); +out: + return ret; +} + +static PyObject * +Tree_get_mrca(Tree *self, PyObject *args) +{ + PyObject *ret = NULL; + int err; + tsk_id_t mrca; + int u, v; + + if (Tree_check_tree(self) != 0) { + goto out; 
+ } + if (!PyArg_ParseTuple(args, "ii", &u, &v)) { + goto out; + } + if (Tree_check_bounds(self, u)) { + goto out; + } + if (Tree_check_bounds(self, v)) { + goto out; + } + err = tsk_tree_get_mrca(self->tree, (tsk_id_t) u, + (tsk_id_t) v, &mrca); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = Py_BuildValue("i", (int) mrca); +out: + return ret; +} + +static PyObject * +Tree_get_num_samples(Tree *self, PyObject *args) +{ + PyObject *ret = NULL; + size_t num_samples; + int err, node; + + if (Tree_get_node_argument(self, args, &node) != 0) { + goto out; + } + err = tsk_tree_get_num_samples(self->tree, (tsk_id_t) node, + &num_samples); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = Py_BuildValue("I", (unsigned int) num_samples); +out: + return ret; +} + +static PyObject * +Tree_get_num_tracked_samples(Tree *self, PyObject *args) +{ + PyObject *ret = NULL; + size_t num_tracked_samples; + int err, node; + + if (Tree_get_node_argument(self, args, &node) != 0) { + goto out; + } + err = tsk_tree_get_num_tracked_samples(self->tree, (tsk_id_t) node, + &num_tracked_samples); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = Py_BuildValue("I", (unsigned int) num_tracked_samples); +out: + return ret; +} + +static PyObject * +Tree_get_sites(Tree *self, PyObject *args) +{ + PyObject *ret = NULL; + + if (Tree_check_tree(self) != 0) { + goto out; + } + ret = convert_sites(self->tree->sites, self->tree->sites_length); +out: + return ret; +} + +static PyObject * +Tree_get_num_sites(Tree *self) +{ + PyObject *ret = NULL; + + if (Tree_check_tree(self) != 0) { + goto out; + } + ret = Py_BuildValue("n", (Py_ssize_t) self->tree->sites_length); +out: + return ret; +} + +static PyObject * +Tree_get_newick(Tree *self, PyObject *args, PyObject *kwds) +{ + PyObject *ret = NULL; + static char *kwlist[] = {"root", "precision", NULL}; + int precision = 14; + int root, err; + size_t buffer_size; + char *buffer = NULL; + + if 
(Tree_check_tree(self) != 0) { + goto out; + } + if (!PyArg_ParseTupleAndKeywords(args, kwds, "i|i", kwlist, &root, &precision)) { + goto out; + } + if (precision < 0 || precision > 16) { + PyErr_SetString(PyExc_ValueError, "Precision must be between 0 and 16, inclusive"); + goto out; + } + buffer_size = tsk_treeseq_get_num_nodes(self->tree->tree_sequence); + /* For every node, we have roughly precision bytes, plus bracketing and leading values. + * This is a rough guess, so add 10 just to be on the safe side. We might need + * to be more precise with this though if we have large time values. + */ + buffer_size *= precision + 10; + buffer = PyMem_Malloc(buffer_size); + if (buffer == NULL) { + PyErr_NoMemory(); + } + err = tsk_convert_newick(self->tree, (tsk_id_t) root, precision, 0, + buffer_size, buffer); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = PyBytes_FromString(buffer); +out: + if (buffer != NULL) { + PyMem_Free(buffer); + } + return ret; +} + +static PyMemberDef Tree_members[] = { + {NULL} /* Sentinel */ +}; + +static PyMethodDef Tree_methods[] = { + {"free", (PyCFunction) Tree_free, METH_NOARGS, + "Frees the underlying tree object." }, + {"get_sample_size", (PyCFunction) Tree_get_sample_size, METH_NOARGS, + "Returns the number of samples in this tree." }, + {"get_num_nodes", (PyCFunction) Tree_get_num_nodes, METH_NOARGS, + "Returns the number of nodes in the sparse tree." }, + {"get_num_roots", (PyCFunction) Tree_get_num_roots, METH_NOARGS, + "Returns the number of roots in the sparse tree." }, + {"get_index", (PyCFunction) Tree_get_index, METH_NOARGS, + "Returns the index this tree occupies within the tree sequence." }, + {"get_left_root", (PyCFunction) Tree_get_left_root, METH_NOARGS, + "Returns the root of the tree." }, + {"get_left", (PyCFunction) Tree_get_left, METH_NOARGS, + "Returns the left-most coordinate (inclusive)." 
}, + {"get_right", (PyCFunction) Tree_get_right, METH_NOARGS, + "Returns the right-most coordinate (exclusive)." }, + {"get_sites", (PyCFunction) Tree_get_sites, METH_NOARGS, + "Returns the list of sites on this tree." }, + {"get_flags", (PyCFunction) Tree_get_flags, METH_NOARGS, + "Returns the value of the flags variable." }, + {"get_num_sites", (PyCFunction) Tree_get_num_sites, METH_NOARGS, + "Returns the number of sites on this tree." }, + {"is_sample", (PyCFunction) Tree_is_sample, METH_VARARGS, + "Returns True if the specified node is a sample." }, + {"get_parent", (PyCFunction) Tree_get_parent, METH_VARARGS, + "Returns the parent of node u" }, + {"get_time", (PyCFunction) Tree_get_time, METH_VARARGS, + "Returns the time of node u" }, + {"get_population", (PyCFunction) Tree_get_population, METH_VARARGS, + "Returns the population of node u" }, + {"get_left_child", (PyCFunction) Tree_get_left_child, METH_VARARGS, + "Returns the left-most child of node u" }, + {"get_right_child", (PyCFunction) Tree_get_right_child, METH_VARARGS, + "Returns the right-most child of node u" }, + {"get_left_sib", (PyCFunction) Tree_get_left_sib, METH_VARARGS, + "Returns the left-most sib of node u" }, + {"get_right_sib", (PyCFunction) Tree_get_right_sib, METH_VARARGS, + "Returns the right-most sib of node u" }, + {"get_children", (PyCFunction) Tree_get_children, METH_VARARGS, + "Returns the children of u in left-right order." }, + {"get_left_sample", (PyCFunction) Tree_get_left_sample, METH_VARARGS, + "Returns the index of the left-most sample descending from u." }, + {"get_right_sample", (PyCFunction) Tree_get_right_sample, METH_VARARGS, + "Returns the index of the right-most sample descending from u." }, + {"get_next_sample", (PyCFunction) Tree_get_next_sample, METH_VARARGS, + "Returns the index of the next sample after the specified sample index." 
}, + {"get_mrca", (PyCFunction) Tree_get_mrca, METH_VARARGS, + "Returns the MRCA of nodes u and v" }, + {"get_num_samples", (PyCFunction) Tree_get_num_samples, METH_VARARGS, + "Returns the number of samples below node u." }, + {"get_num_tracked_samples", (PyCFunction) Tree_get_num_tracked_samples, + METH_VARARGS, + "Returns the number of tracked samples below node u." }, + {"get_newick", (PyCFunction) Tree_get_newick, + METH_VARARGS|METH_KEYWORDS, + "Returns the newick representation of this tree." }, + {NULL} /* Sentinel */ +}; + +static PyTypeObject TreeType = { + PyVarObject_HEAD_INIT(NULL, 0) + "_tskit.Tree", /* tp_name */ + sizeof(Tree), /* tp_basicsize */ + 0, /* tp_itemsize */ + (destructor)Tree_dealloc, /* tp_dealloc */ + 0, /* tp_print */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_reserved */ + 0, /* tp_repr */ + 0, /* tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + 0, /* tp_str */ + 0, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT, /* tp_flags */ + "Tree objects", /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + Tree_methods, /* tp_methods */ + Tree_members, /* tp_members */ + 0, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + (initproc)Tree_init, /* tp_init */ +}; + + + +/*=================================================================== + * TreeDiffIterator + *=================================================================== + */ + +static int +TreeDiffIterator_check_state(TreeDiffIterator *self) +{ + int ret = 0; + if (self->tree_diff_iterator == NULL) { + PyErr_SetString(PyExc_SystemError, "iterator not initialised"); + ret = -1; + } + return ret; +} + +static void +TreeDiffIterator_dealloc(TreeDiffIterator* self) +{ + if (self->tree_diff_iterator != 
NULL) { + tsk_diff_iter_free(self->tree_diff_iterator); + PyMem_Free(self->tree_diff_iterator); + self->tree_diff_iterator = NULL; + } + Py_XDECREF(self->tree_sequence); + Py_TYPE(self)->tp_free((PyObject*)self); +} + +static int +TreeDiffIterator_init(TreeDiffIterator *self, PyObject *args, PyObject *kwds) +{ + int ret = -1; + int err; + static char *kwlist[] = {"tree_sequence", NULL}; + TreeSequence *tree_sequence; + + self->tree_diff_iterator = NULL; + self->tree_sequence = NULL; + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O!", kwlist, + &TreeSequenceType, &tree_sequence)) { + goto out; + } + self->tree_sequence = tree_sequence; + Py_INCREF(self->tree_sequence); + if (TreeSequence_check_tree_sequence(self->tree_sequence) != 0) { + goto out; + } + self->tree_diff_iterator = PyMem_Malloc(sizeof(tsk_diff_iter_t)); + if (self->tree_diff_iterator == NULL) { + PyErr_NoMemory(); + goto out; + } + memset(self->tree_diff_iterator, 0, sizeof(tsk_diff_iter_t)); + err = tsk_diff_iter_alloc(self->tree_diff_iterator, + self->tree_sequence->tree_sequence); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = 0; +out: + return ret; +} + +static PyObject * +TreeDiffIterator_next(TreeDiffIterator *self) +{ + PyObject *ret = NULL; + PyObject *out_list = NULL; + PyObject *in_list = NULL; + PyObject *value = NULL; + int err; + double left, right; + size_t list_size, j; + tsk_edge_list_t *records_out, *records_in, *record; + + if (TreeDiffIterator_check_state(self) != 0) { + goto out; + } + err = tsk_diff_iter_next(self->tree_diff_iterator, &left, &right, + &records_out, &records_in); + if (err < 0) { + handle_library_error(err); + goto out; + } + if (err == 1) { + /* out records */ + record = records_out; + list_size = 0; + while (record != NULL) { + list_size++; + record = record->next; + } + out_list = PyList_New(list_size); + if (out_list == NULL) { + goto out; + } + record = records_out; + j = 0; + while (record != NULL) { + value = Py_BuildValue("ddii", 
record->edge.left, record->edge.right, + record->edge.parent, record->edge.child); + if (value == NULL) { + goto out; + } + PyList_SET_ITEM(out_list, j, value); + record = record->next; + j++; + } + /* in records */ + record = records_in; + list_size = 0; + while (record != NULL) { + list_size++; + record = record->next; + } + in_list = PyList_New(list_size); + if (in_list == NULL) { + goto out; + } + record = records_in; + j = 0; + while (record != NULL) { + value = Py_BuildValue("ddii", record->edge.left, record->edge.right, + record->edge.parent, record->edge.child); + if (value == NULL) { + goto out; + } + PyList_SET_ITEM(in_list, j, value); + record = record->next; + j++; + } + ret = Py_BuildValue("(dd)OO", left, right, out_list, in_list); + } +out: + Py_XDECREF(out_list); + Py_XDECREF(in_list); + return ret; +} + +static PyMemberDef TreeDiffIterator_members[] = { + {NULL} /* Sentinel */ +}; + +static PyMethodDef TreeDiffIterator_methods[] = { + {NULL} /* Sentinel */ +}; + +static PyTypeObject TreeDiffIteratorType = { + PyVarObject_HEAD_INIT(NULL, 0) + "_tskit.TreeDiffIterator", /* tp_name */ + sizeof(TreeDiffIterator), /* tp_basicsize */ + 0, /* tp_itemsize */ + (destructor)TreeDiffIterator_dealloc, /* tp_dealloc */ + 0, /* tp_print */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_reserved */ + 0, /* tp_repr */ + 0, /* tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + 0, /* tp_str */ + 0, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT, /* tp_flags */ + "TreeDiffIterator objects", /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + PyObject_SelfIter, /* tp_iter */ + (iternextfunc) TreeDiffIterator_next, /* tp_iternext */ + TreeDiffIterator_methods, /* tp_methods */ + TreeDiffIterator_members, /* tp_members */ + 0, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, 
/* tp_descr_set */ + 0, /* tp_dictoffset */ + (initproc)TreeDiffIterator_init, /* tp_init */ +}; + +/*=================================================================== + * TreeIterator + *=================================================================== + */ + +static int +TreeIterator_check_state(TreeIterator *self) +{ + int ret = 0; + return ret; +} + +static void +TreeIterator_dealloc(TreeIterator* self) +{ + Py_XDECREF(self->tree); + Py_TYPE(self)->tp_free((PyObject*)self); +} + +static int +TreeIterator_init(TreeIterator *self, PyObject *args, PyObject *kwds) +{ + int ret = -1; + static char *kwlist[] = {"tree", NULL}; + Tree *tree; + + self->first = 1; + self->tree = NULL; + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O!", kwlist, + &TreeType, &tree)) { + goto out; + } + self->tree = tree; + Py_INCREF(self->tree); + if (Tree_check_tree(self->tree) != 0) { + goto out; + } + ret = 0; +out: + return ret; +} + +static PyObject * +TreeIterator_next(TreeIterator *self) +{ + PyObject *ret = NULL; + int err; + + if (TreeIterator_check_state(self) != 0) { + goto out; + } + + if (self->first) { + err = tsk_tree_first(self->tree->tree); + self->first = 0; + } else { + err = tsk_tree_next(self->tree->tree); + } + if (err < 0) { + handle_library_error(err); + goto out; + } + if (err == 1) { + ret = (PyObject *) self->tree; + Py_INCREF(ret); + } +out: + return ret; +} + +static PyMemberDef TreeIterator_members[] = { + {NULL} /* Sentinel */ +}; + +static PyMethodDef TreeIterator_methods[] = { + {NULL} /* Sentinel */ +}; + +static PyTypeObject TreeIteratorType = { + PyVarObject_HEAD_INIT(NULL, 0) + "_tskit.TreeIterator", /* tp_name */ + sizeof(TreeIterator), /* tp_basicsize */ + 0, /* tp_itemsize */ + (destructor)TreeIterator_dealloc, /* tp_dealloc */ + 0, /* tp_print */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_reserved */ + 0, /* tp_repr */ + 0, /* tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + 
0, /* tp_str */ + 0, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT, /* tp_flags */ + "TreeIterator objects", /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + PyObject_SelfIter, /* tp_iter */ + (iternextfunc) TreeIterator_next, /* tp_iternext */ + TreeIterator_methods, /* tp_methods */ + TreeIterator_members, /* tp_members */ + 0, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + (initproc)TreeIterator_init, /* tp_init */ +}; + +/*=================================================================== + * VcfConverter + *=================================================================== + */ + +static int +VcfConverter_check_state(VcfConverter *self) +{ + int ret = 0; + if (self->tsk_vcf_converter == NULL) { + PyErr_SetString(PyExc_SystemError, "converter not initialised"); + ret = -1; + } + return ret; +} + +static void +VcfConverter_dealloc(VcfConverter* self) +{ + if (self->tsk_vcf_converter != NULL) { + tsk_vcf_converter_free(self->tsk_vcf_converter); + PyMem_Free(self->tsk_vcf_converter); + self->tsk_vcf_converter = NULL; + } + Py_XDECREF(self->tree_sequence); + Py_TYPE(self)->tp_free((PyObject*)self); +} + +static int +VcfConverter_init(VcfConverter *self, PyObject *args, PyObject *kwds) +{ + int ret = -1; + int err; + static char *kwlist[] = {"tree_sequence", "ploidy", "contig_id", NULL}; + unsigned int ploidy = 1; + const char *contig_id = "1"; + TreeSequence *tree_sequence; + + self->tsk_vcf_converter = NULL; + self->tree_sequence = NULL; + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O!|Is", kwlist, + &TreeSequenceType, &tree_sequence, &ploidy, &contig_id)) { + goto out; + } + self->tree_sequence = tree_sequence; + Py_INCREF(self->tree_sequence); + if (TreeSequence_check_tree_sequence(self->tree_sequence) != 0) { + goto out; + } + if (ploidy < 1) { + 
PyErr_SetString(PyExc_ValueError, "Ploidy must be >= 1"); + goto out; + } + if (strlen(contig_id) == 0) { + PyErr_SetString(PyExc_ValueError, "contig_id cannot be the empty string"); + goto out; + } + self->tsk_vcf_converter = PyMem_Malloc(sizeof(tsk_vcf_converter_t)); + if (self->tsk_vcf_converter == NULL) { + PyErr_NoMemory(); + goto out; + } + err = tsk_vcf_converter_alloc(self->tsk_vcf_converter, + self->tree_sequence->tree_sequence, ploidy, contig_id); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = 0; +out: + return ret; +} + +static PyObject * +VcfConverter_next(VcfConverter *self) +{ + PyObject *ret = NULL; + char *record; + int err; + + if (VcfConverter_check_state(self) != 0) { + goto out; + } + err = tsk_vcf_converter_next(self->tsk_vcf_converter, &record); + if (err < 0) { + handle_library_error(err); + goto out; + } + if (err == 1) { + ret = Py_BuildValue("s", record); + } +out: + return ret; +} + +static PyObject * +VcfConverter_get_header(VcfConverter *self) +{ + PyObject *ret = NULL; + int err; + char *header; + + if (VcfConverter_check_state(self) != 0) { + goto out; + } + err = tsk_vcf_converter_get_header(self->tsk_vcf_converter, &header); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = Py_BuildValue("s", header); +out: + return ret; +} + +static PyMemberDef VcfConverter_members[] = { + {NULL} /* Sentinel */ +}; + +static PyMethodDef VcfConverter_methods[] = { + {"get_header", (PyCFunction) VcfConverter_get_header, METH_NOARGS, + "Returns the VCF header as plain text." 
}, + {NULL} /* Sentinel */ +}; + +static PyTypeObject VcfConverterType = { + PyVarObject_HEAD_INIT(NULL, 0) + "_tskit.VcfConverter", /* tp_name */ + sizeof(VcfConverter), /* tp_basicsize */ + 0, /* tp_itemsize */ + (destructor)VcfConverter_dealloc, /* tp_dealloc */ + 0, /* tp_print */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_reserved */ + 0, /* tp_repr */ + 0, /* tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + 0, /* tp_str */ + 0, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT, /* tp_flags */ + "VcfConverter objects", /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + PyObject_SelfIter, /* tp_iter */ + (iternextfunc) VcfConverter_next, /* tp_iternext */ + VcfConverter_methods, /* tp_methods */ + VcfConverter_members, /* tp_members */ + 0, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + (initproc)VcfConverter_init, /* tp_init */ +}; + +/*=================================================================== + * HaplotypeGenerator + *=================================================================== + */ + +static int +HaplotypeGenerator_check_state(HaplotypeGenerator *self) +{ + int ret = 0; + if (self->haplotype_generator == NULL) { + PyErr_SetString(PyExc_SystemError, "converter not initialised"); + ret = -1; + } + return ret; +} + +static void +HaplotypeGenerator_dealloc(HaplotypeGenerator* self) +{ + if (self->haplotype_generator != NULL) { + tsk_hapgen_free(self->haplotype_generator); + PyMem_Free(self->haplotype_generator); + self->haplotype_generator = NULL; + } + Py_XDECREF(self->tree_sequence); + Py_TYPE(self)->tp_free((PyObject*)self); +} + +static int +HaplotypeGenerator_init(HaplotypeGenerator *self, PyObject *args, PyObject *kwds) +{ + int ret = -1; + int err; + static char *kwlist[] = {"tree_sequence", 
NULL}; + TreeSequence *tree_sequence; + + self->haplotype_generator = NULL; + self->tree_sequence = NULL; + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O!", kwlist, + &TreeSequenceType, &tree_sequence)) { + goto out; + } + self->tree_sequence = tree_sequence; + Py_INCREF(self->tree_sequence); + if (TreeSequence_check_tree_sequence(self->tree_sequence) != 0) { + goto out; + } + self->haplotype_generator = PyMem_Malloc(sizeof(tsk_hapgen_t)); + if (self->haplotype_generator == NULL) { + PyErr_NoMemory(); + goto out; + } + memset(self->haplotype_generator, 0, sizeof(tsk_hapgen_t)); + err = tsk_hapgen_alloc(self->haplotype_generator, + self->tree_sequence->tree_sequence); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = 0; +out: + return ret; +} + +static PyObject * +HaplotypeGenerator_get_haplotype(HaplotypeGenerator *self, PyObject *args) +{ + int err; + PyObject *ret = NULL; + char *haplotype; + unsigned int sample_id; + + if (HaplotypeGenerator_check_state(self) != 0) { + goto out; + } + if (!PyArg_ParseTuple(args, "I", &sample_id)) { + goto out; + } + err = tsk_hapgen_get_haplotype(self->haplotype_generator, + (uint32_t) sample_id, &haplotype); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = Py_BuildValue("s", haplotype); +out: + return ret; +} + +static PyMemberDef HaplotypeGenerator_members[] = { + {NULL} /* Sentinel */ +}; + +static PyMethodDef HaplotypeGenerator_methods[] = { + {"get_haplotype", (PyCFunction) HaplotypeGenerator_get_haplotype, + METH_VARARGS, "Returns the haplotype for the specified sample"}, + {NULL} /* Sentinel */ +}; + +static PyTypeObject HaplotypeGeneratorType = { + PyVarObject_HEAD_INIT(NULL, 0) + "_tskit.HaplotypeGenerator", /* tp_name */ + sizeof(HaplotypeGenerator), /* tp_basicsize */ + 0, /* tp_itemsize */ + (destructor)HaplotypeGenerator_dealloc, /* tp_dealloc */ + 0, /* tp_print */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_reserved */ + 0, /* tp_repr */ + 0, /* tp_as_number 
*/ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + 0, /* tp_str */ + 0, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT, /* tp_flags */ + "HaplotypeGenerator objects", /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + HaplotypeGenerator_methods, /* tp_methods */ + HaplotypeGenerator_members, /* tp_members */ + 0, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + (initproc)HaplotypeGenerator_init, /* tp_init */ +}; + + +/*=================================================================== + * VariantGenerator + *=================================================================== + */ + +static int +VariantGenerator_check_state(VariantGenerator *self) +{ + int ret = 0; + if (self->variant_generator == NULL) { + PyErr_SetString(PyExc_SystemError, "converter not initialised"); + ret = -1; + } + return ret; +} + +static void +VariantGenerator_dealloc(VariantGenerator* self) +{ + if (self->variant_generator != NULL) { + tsk_vargen_free(self->variant_generator); + PyMem_Free(self->variant_generator); + self->variant_generator = NULL; + } + Py_XDECREF(self->tree_sequence); + Py_TYPE(self)->tp_free((PyObject*)self); +} + +static int +VariantGenerator_init(VariantGenerator *self, PyObject *args, PyObject *kwds) +{ + int ret = -1; + int err; + static char *kwlist[] = {"tree_sequence", "samples", NULL}; + TreeSequence *tree_sequence = NULL; + PyObject *samples_input = Py_None; + PyArrayObject *samples_array = NULL; + tsk_id_t *samples = NULL; + size_t num_samples = 0; + npy_intp *shape; + + /* TODO add option for 16 bit genotypes */ + self->variant_generator = NULL; + self->tree_sequence = NULL; + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O!|O", kwlist, + &TreeSequenceType, &tree_sequence, &samples_input)) { + 
goto out; + } + self->tree_sequence = tree_sequence; + Py_INCREF(self->tree_sequence); + if (TreeSequence_check_tree_sequence(self->tree_sequence) != 0) { + goto out; + } + if (samples_input != Py_None) { + samples_array = (PyArrayObject *) PyArray_FROMANY(samples_input, NPY_INT32, 1, 1, + NPY_ARRAY_IN_ARRAY); + if (samples_array == NULL) { + goto out; + } + shape = PyArray_DIMS(samples_array); + num_samples = (size_t) shape[0]; + samples = PyArray_DATA(samples_array); + } + self->variant_generator = PyMem_Malloc(sizeof(tsk_vargen_t)); + if (self->variant_generator == NULL) { + PyErr_NoMemory(); + goto out; + } + /* Note: the vargen currently takes a copy of the samples list. If we wanted + * to avoid this we would INCREF the samples array above and keep a reference + * to in the object struct */ + err = tsk_vargen_alloc(self->variant_generator, + self->tree_sequence->tree_sequence, samples, num_samples, 0); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = 0; +out: + Py_XDECREF(samples_array); + return ret; +} + +static PyObject * +VariantGenerator_next(VariantGenerator *self) +{ + PyObject *ret = NULL; + tsk_variant_t *var; + int err; + + if (VariantGenerator_check_state(self) != 0) { + goto out; + } + err = tsk_vargen_next(self->variant_generator, &var); + if (err < 0) { + handle_library_error(err); + goto out; + } + if (err == 1) { + ret = make_variant(var, self->variant_generator->num_samples); + } +out: + return ret; +} + +static PyMemberDef VariantGenerator_members[] = { + {NULL} /* Sentinel */ +}; + +static PyMethodDef VariantGenerator_methods[] = { + {NULL} /* Sentinel */ +}; + +static PyTypeObject VariantGeneratorType = { + PyVarObject_HEAD_INIT(NULL, 0) + "_tskit.VariantGenerator", /* tp_name */ + sizeof(VariantGenerator), /* tp_basicsize */ + 0, /* tp_itemsize */ + (destructor)VariantGenerator_dealloc, /* tp_dealloc */ + 0, /* tp_print */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_reserved */ + 0, /* tp_repr */ + 0, /* 
tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + 0, /* tp_str */ + 0, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT, /* tp_flags */ + "VariantGenerator objects", /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + PyObject_SelfIter, /* tp_iter */ + (iternextfunc) VariantGenerator_next, /* tp_iternext */ + VariantGenerator_methods, /* tp_methods */ + VariantGenerator_members, /* tp_members */ + 0, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + (initproc)VariantGenerator_init, /* tp_init */ +}; + +/*=================================================================== + * LdCalculator + *=================================================================== + */ + +static int +LdCalculator_check_state(LdCalculator *self) +{ + int ret = 0; + if (self->ld_calc == NULL) { + PyErr_SetString(PyExc_SystemError, "converter not initialised"); + ret = -1; + } + return ret; +} + +static void +LdCalculator_dealloc(LdCalculator* self) +{ + if (self->ld_calc != NULL) { + tsk_ld_calc_free(self->ld_calc); + PyMem_Free(self->ld_calc); + self->ld_calc = NULL; + } + Py_XDECREF(self->tree_sequence); + Py_TYPE(self)->tp_free((PyObject*)self); +} + +static int +LdCalculator_init(LdCalculator *self, PyObject *args, PyObject *kwds) +{ + int ret = -1; + int err; + static char *kwlist[] = {"tree_sequence", NULL}; + TreeSequence *tree_sequence; + + self->ld_calc = NULL; + self->tree_sequence = NULL; + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O!", kwlist, + &TreeSequenceType, &tree_sequence)) { + goto out; + } + self->tree_sequence = tree_sequence; + Py_INCREF(self->tree_sequence); + if (TreeSequence_check_tree_sequence(self->tree_sequence) != 0) { + goto out; + } + self->ld_calc = PyMem_Malloc(sizeof(tsk_ld_calc_t)); + if (self->ld_calc == NULL) { + 
PyErr_NoMemory(); + goto out; + } + memset(self->ld_calc, 0, sizeof(tsk_ld_calc_t)); + err = tsk_ld_calc_alloc(self->ld_calc, self->tree_sequence->tree_sequence); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = 0; +out: + return ret; +} + +static PyObject * +LdCalculator_get_r2(LdCalculator *self, PyObject *args) +{ + int err; + PyObject *ret = NULL; + Py_ssize_t a, b; + double r2; + + if (LdCalculator_check_state(self) != 0) { + goto out; + } + if (!PyArg_ParseTuple(args, "nn", &a, &b)) { + goto out; + } + Py_BEGIN_ALLOW_THREADS + err = tsk_ld_calc_get_r2(self->ld_calc, (size_t) a, (size_t) b, &r2); + Py_END_ALLOW_THREADS + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = Py_BuildValue("d", r2); +out: + return ret; +} + +static PyObject * +LdCalculator_get_r2_array(LdCalculator *self, PyObject *args, PyObject *kwds) +{ + int err; + PyObject *ret = NULL; + static char *kwlist[] = { + "dest", "source_index", "direction", "max_mutations", + "max_distance", NULL}; + PyObject *dest = NULL; + Py_buffer buffer; + Py_ssize_t source_index; + Py_ssize_t max_mutations = -1; + double max_distance = DBL_MAX; + int direction = TSK_DIR_FORWARD; + size_t num_r2_values = 0; + int buffer_acquired = 0; + + if (LdCalculator_check_state(self) != 0) { + goto out; + } + if (!PyArg_ParseTupleAndKeywords(args, kwds, "On|ind", kwlist, + &dest, &source_index, &direction, &max_mutations, &max_distance)) { + goto out; + } + if (direction != TSK_DIR_FORWARD && direction != TSK_DIR_REVERSE) { + PyErr_SetString(PyExc_ValueError, + "direction must be FORWARD or REVERSE"); + goto out; + } + if (max_distance < 0) { + PyErr_SetString(PyExc_ValueError, "max_distance must be >= 0"); + goto out; + } + if (!PyObject_CheckBuffer(dest)) { + PyErr_SetString(PyExc_TypeError, + "dest buffer must support the Python buffer protocol."); + goto out; + } + if (PyObject_GetBuffer(dest, &buffer, PyBUF_SIMPLE|PyBUF_WRITABLE) != 0) { + goto out; + } + buffer_acquired = 1; + 
if (max_mutations == -1) { + max_mutations = buffer.len / sizeof(double); + } else if (max_mutations * sizeof(double) > buffer.len) { + PyErr_SetString(PyExc_BufferError, + "dest buffer is too small for the results"); + goto out; + } + + Py_BEGIN_ALLOW_THREADS + err = tsk_ld_calc_get_r2_array( + self->ld_calc, (size_t) source_index, direction, + (size_t) max_mutations, max_distance, + (double *) buffer.buf, &num_r2_values); + Py_END_ALLOW_THREADS + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = Py_BuildValue("n", (Py_ssize_t) num_r2_values); +out: + if (buffer_acquired) { + PyBuffer_Release(&buffer); + } + return ret; +} + +static PyMemberDef LdCalculator_members[] = { + {NULL} /* Sentinel */ +}; + +static PyMethodDef LdCalculator_methods[] = { + {"get_r2", (PyCFunction) LdCalculator_get_r2, METH_VARARGS, + "Returns the value of the r2 statistic between the specified pair of " + "mutation indexes"}, + {"get_r2_array", (PyCFunction) LdCalculator_get_r2_array, + METH_VARARGS|METH_KEYWORDS, + "Returns r2 statistic for a given mutation over specified range"}, + {NULL} /* Sentinel */ +}; + +static PyTypeObject LdCalculatorType = { + PyVarObject_HEAD_INIT(NULL, 0) + "_tskit.LdCalculator", /* tp_name */ + sizeof(LdCalculator), /* tp_basicsize */ + 0, /* tp_itemsize */ + (destructor)LdCalculator_dealloc, /* tp_dealloc */ + 0, /* tp_print */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_reserved */ + 0, /* tp_repr */ + 0, /* tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + 0, /* tp_str */ + 0, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT, /* tp_flags */ + "LdCalculator objects", /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + LdCalculator_methods, /* tp_methods */ + LdCalculator_members, /* tp_members */ + 0, /* tp_getset */ + 0, /* tp_base 
*/ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + (initproc)LdCalculator_init, /* tp_init */ +}; + +/*=================================================================== + * Module level functions + *=================================================================== + */ + +static PyObject * +tskit_get_kastore_version(PyObject *self) +{ + /* TODO if we provide the option of linking against kastore separately, we + * should return the link time version using kas_get_version */ + return Py_BuildValue("iii", KAS_VERSION_MAJOR, KAS_VERSION_MINOR, KAS_VERSION_PATCH); +} + +static PyMethodDef tskit_methods[] = { + {"get_kastore_version", (PyCFunction) tskit_get_kastore_version, METH_NOARGS, + "Returns the version of kastore we have built in." }, + {NULL} /* Sentinel */ +}; + +/* Initialisation code supports Python 2.x and 3.x. The framework uses the + * recommended structure from http://docs.python.org/howto/cporting.html. + * I've ignored the point about storing state in globals, as the examples + * from the Python documentation still use this idiom. 
+ */ + +#if PY_MAJOR_VERSION >= 3 + +static struct PyModuleDef tskitmodule = { + PyModuleDef_HEAD_INIT, + "_tskit", /* name of module */ + MODULE_DOC, /* module documentation, may be NULL */ + -1, + tskit_methods, + NULL, NULL, NULL, NULL +}; + +#define INITERROR return NULL + +PyObject * +PyInit__tskit(void) + +#else +#define INITERROR return + +void +init_tskit(void) +#endif +{ +#if PY_MAJOR_VERSION >= 3 + PyObject *module = PyModule_Create(&tskitmodule); +#else + PyObject *module = Py_InitModule3("_tskit", tskit_methods, MODULE_DOC); +#endif + if (module == NULL) { + INITERROR; + } + import_array(); + + /* LightweightTableCollection type */ + LightweightTableCollectionType.tp_new = PyType_GenericNew; + if (PyType_Ready(&LightweightTableCollectionType) < 0) { + INITERROR; + } + Py_INCREF(&LightweightTableCollectionType); + PyModule_AddObject(module, "LightweightTableCollection", + (PyObject *) &LightweightTableCollectionType); + + /* IndividualTable type */ + IndividualTableType.tp_new = PyType_GenericNew; + if (PyType_Ready(&IndividualTableType) < 0) { + INITERROR; + } + Py_INCREF(&IndividualTableType); + PyModule_AddObject(module, "IndividualTable", (PyObject *) &IndividualTableType); + + /* NodeTable type */ + NodeTableType.tp_new = PyType_GenericNew; + if (PyType_Ready(&NodeTableType) < 0) { + INITERROR; + } + Py_INCREF(&NodeTableType); + PyModule_AddObject(module, "NodeTable", (PyObject *) &NodeTableType); + + /* EdgeTable type */ + EdgeTableType.tp_new = PyType_GenericNew; + if (PyType_Ready(&EdgeTableType) < 0) { + INITERROR; + } + Py_INCREF(&EdgeTableType); + PyModule_AddObject(module, "EdgeTable", (PyObject *) &EdgeTableType); + + /* MigrationTable type */ + MigrationTableType.tp_new = PyType_GenericNew; + if (PyType_Ready(&MigrationTableType) < 0) { + INITERROR; + } + Py_INCREF(&MigrationTableType); + PyModule_AddObject(module, "MigrationTable", (PyObject *) &MigrationTableType); + + /* SiteTable type */ + SiteTableType.tp_new = PyType_GenericNew; + if 
(PyType_Ready(&SiteTableType) < 0) { + INITERROR; + } + Py_INCREF(&SiteTableType); + PyModule_AddObject(module, "SiteTable", (PyObject *) &SiteTableType); + + /* MutationTable type */ + MutationTableType.tp_new = PyType_GenericNew; + if (PyType_Ready(&MutationTableType) < 0) { + INITERROR; + } + Py_INCREF(&MutationTableType); + PyModule_AddObject(module, "MutationTable", (PyObject *) &MutationTableType); + + /* PopulationTable type */ + PopulationTableType.tp_new = PyType_GenericNew; + if (PyType_Ready(&PopulationTableType) < 0) { + INITERROR; + } + Py_INCREF(&PopulationTableType); + PyModule_AddObject(module, "PopulationTable", (PyObject *) &PopulationTableType); + + /* ProvenanceTable type */ + ProvenanceTableType.tp_new = PyType_GenericNew; + if (PyType_Ready(&ProvenanceTableType) < 0) { + INITERROR; + } + Py_INCREF(&ProvenanceTableType); + PyModule_AddObject(module, "ProvenanceTable", (PyObject *) &ProvenanceTableType); + + /* TableCollectionTable type */ + TableCollectionType.tp_new = PyType_GenericNew; + if (PyType_Ready(&TableCollectionType) < 0) { + INITERROR; + } + Py_INCREF(&TableCollectionType); + PyModule_AddObject(module, "TableCollection", (PyObject *) &TableCollectionType); + + /* TreeSequence type */ + TreeSequenceType.tp_new = PyType_GenericNew; + if (PyType_Ready(&TreeSequenceType) < 0) { + INITERROR; + } + Py_INCREF(&TreeSequenceType); + PyModule_AddObject(module, "TreeSequence", (PyObject *) &TreeSequenceType); + + /* Tree type */ + TreeType.tp_new = PyType_GenericNew; + if (PyType_Ready(&TreeType) < 0) { + INITERROR; + } + Py_INCREF(&TreeType); + PyModule_AddObject(module, "Tree", (PyObject *) &TreeType); + + /* TreeIterator type */ + TreeIteratorType.tp_new = PyType_GenericNew; + if (PyType_Ready(&TreeIteratorType) < 0) { + INITERROR; + } + Py_INCREF(&TreeIteratorType); + PyModule_AddObject(module, "TreeIterator", (PyObject *) &TreeIteratorType); + + /* TreeDiffIterator type */ + TreeDiffIteratorType.tp_new = PyType_GenericNew; + if 
(PyType_Ready(&TreeDiffIteratorType) < 0) { + INITERROR; + } + Py_INCREF(&TreeDiffIteratorType); + PyModule_AddObject(module, "TreeDiffIterator", (PyObject *) &TreeDiffIteratorType); + + /* VcfConverter type */ + VcfConverterType.tp_new = PyType_GenericNew; + if (PyType_Ready(&VcfConverterType) < 0) { + INITERROR; + } + Py_INCREF(&VcfConverterType); + PyModule_AddObject(module, "VcfConverter", (PyObject *) &VcfConverterType); + + /* HaplotypeGenerator type */ + HaplotypeGeneratorType.tp_new = PyType_GenericNew; + if (PyType_Ready(&HaplotypeGeneratorType) < 0) { + INITERROR; + } + Py_INCREF(&HaplotypeGeneratorType); + PyModule_AddObject(module, "HaplotypeGenerator", + (PyObject *) &HaplotypeGeneratorType); + + /* VariantGenerator type */ + VariantGeneratorType.tp_new = PyType_GenericNew; + if (PyType_Ready(&VariantGeneratorType) < 0) { + INITERROR; + } + Py_INCREF(&VariantGeneratorType); + PyModule_AddObject(module, "VariantGenerator", (PyObject *) &VariantGeneratorType); + + /* LdCalculator type */ + LdCalculatorType.tp_new = PyType_GenericNew; + if (PyType_Ready(&LdCalculatorType) < 0) { + INITERROR; + } + Py_INCREF(&LdCalculatorType); + PyModule_AddObject(module, "LdCalculator", (PyObject *) &LdCalculatorType); + + /* Errors and constants */ + TskitException = PyErr_NewException("_tskit.TskitException", NULL, NULL); + Py_INCREF(TskitException); + PyModule_AddObject(module, "TskitException", TskitException); + TskitLibraryError = PyErr_NewException("_tskit.LibraryError", TskitException, NULL); + Py_INCREF(TskitLibraryError); + PyModule_AddObject(module, "LibraryError", TskitLibraryError); + TskitFileFormatError = PyErr_NewException("_tskit.FileFormatError", NULL, NULL); + Py_INCREF(TskitFileFormatError); + PyModule_AddObject(module, "FileFormatError", TskitFileFormatError); + TskitVersionTooNewError = PyErr_NewException("_tskit.VersionTooNewError", + TskitException, NULL); + Py_INCREF(TskitVersionTooNewError); + PyModule_AddObject(module, "VersionTooNewError", 
TskitVersionTooNewError); + TskitVersionTooOldError = PyErr_NewException("_tskit.VersionTooOldError", + TskitException, NULL); + Py_INCREF(TskitVersionTooOldError); + PyModule_AddObject(module, "VersionTooOldError", TskitVersionTooOldError); + + PyModule_AddIntConstant(module, "NULL", TSK_NULL); + /* Node flags */ + PyModule_AddIntConstant(module, "NODE_IS_SAMPLE", TSK_NODE_IS_SAMPLE); + /* Tree flags */ + PyModule_AddIntConstant(module, "SAMPLE_COUNTS", TSK_SAMPLE_COUNTS); + PyModule_AddIntConstant(module, "SAMPLE_LISTS", TSK_SAMPLE_LISTS); + /* Directions */ + PyModule_AddIntConstant(module, "FORWARD", TSK_DIR_FORWARD); + PyModule_AddIntConstant(module, "REVERSE", TSK_DIR_REVERSE); + +#if PY_MAJOR_VERSION >= 3 + return module; +#endif +} diff --git a/python/lib b/python/lib new file mode 120000 index 0000000000..9e8cf0c2ee --- /dev/null +++ b/python/lib @@ -0,0 +1 @@ +../c/ \ No newline at end of file diff --git a/python/setup.cfg b/python/setup.cfg new file mode 100644 index 0000000000..1a55f624ed --- /dev/null +++ b/python/setup.cfg @@ -0,0 +1,6 @@ +[bdist_wheel] +# This flag says to generate wheels that support both Python 2 and Python +# 3. If your code will not run unchanged on both Python 2 and 3, you will +# need to generate separate wheels for each Python version that you +# support. +universal=0 diff --git a/python/setup.py b/python/setup.py new file mode 100644 index 0000000000..d6b9feb1cb --- /dev/null +++ b/python/setup.py @@ -0,0 +1,99 @@ +import sys +import os.path +import codecs +from setuptools import setup, Extension +from setuptools.command.build_ext import build_ext + + +# Obscure magic required to allow numpy be used as a 'setup_requires'. 
+# Based on https://stackoverflow.com/questions/19919905 +class local_build_ext(build_ext): + def finalize_options(self): + build_ext.finalize_options(self) + if sys.version_info[0] >= 3: + import builtins + else: + import __builtin__ as builtins + # Prevent numpy from thinking it is still in its setup process: + builtins.__NUMPY_SETUP__ = False + import numpy + self.include_dirs.append(numpy.get_include()) + + +libdir = "lib" +kastore_dir = os.path.join(libdir, "kastore", "c") +tsk_source_files = [ + "tsk_core.c", + "tsk_tables.c", + "tsk_trees.c", + "tsk_genotypes.c", + "tsk_stats.c", + "tsk_convert.c", +] +sources = ["_tskitmodule.c"] + [ + os.path.join(libdir, f) for f in tsk_source_files] + [ + os.path.join(kastore_dir, "kastore.c")] + +_tskit_module = Extension( + '_tskit', + sources=sources, + extra_compile_args=["-std=c99"], + # Enable asserts + undef_macros=["NDEBUG"], + include_dirs=[libdir, kastore_dir], +) + +here = os.path.abspath(os.path.dirname(__file__)) +with codecs.open(os.path.join(here, 'README.rst'), encoding='utf-8') as f: + long_description = f.read() + +# After exec'ing this file we have tskit_version defined. +tskit_version = None # Keep PEP8 happy. +version_file = os.path.join("tskit", "_version.py") +with open(version_file) as f: + exec(f.read()) + +numpy_ver = "numpy>=1.7" + +setup( + name='tskit', + description='The tree sequence toolkit.', + long_description=long_description, + url='https://github.com/tskit-dev/tskit', + author='tskit developers', + version=tskit_version, + # TODO setup a tskit developers email address. 
+ author_email='jerome.kelleher@well.ox.ac.uk', + classifiers=[ + 'Development Status :: 2 - Beta', + 'Intended Audience :: Developers', + 'Topic :: Scientific/Engineering :: Bio-Informatics', + 'License :: OSI Approved :: MIT License', + 'Programming Language :: Python :: 2', + 'Programming Language :: Python :: 2.7', + 'Programming Language :: Python :: 3', + 'Programming Language :: Python :: 3.5', + 'Programming Language :: Python :: 3.6', + 'Programming Language :: Python :: 3.7', + ], + keywords='tree sequence', + packages=['tskit'], + include_package_data=True, + ext_modules=[_tskit_module], + install_requires=[numpy_ver, "h5py", "svgwrite", "six", "jsonschema"], + entry_points={ + 'console_scripts': [ + 'tskit=tskit.__main__:main', + ], + }, + project_urls={ + 'Bug Reports': 'https://github.com/tskit-dev/tskit/issues', + 'Source': 'https://github.com/tskit-dev/tskit', + }, + setup_requires=[numpy_ver], + cmdclass={ + 'build_ext': local_build_ext + }, + license="MIT", + platforms=["POSIX", "Windows", "MacOS X"], +) diff --git a/python/stress_lowlevel.py b/python/stress_lowlevel.py new file mode 100644 index 0000000000..a92be9cb2c --- /dev/null +++ b/python/stress_lowlevel.py @@ -0,0 +1,93 @@ +""" +Code to stress the low-level API as much as possible to expose +any memory leaks or error handling issues. 
+""" +from __future__ import print_function +from __future__ import division + +import argparse +import unittest +import random +import resource +import os +import sys +import time +import logging + +import tests.test_demography as test_demography +import tests.test_highlevel as test_highlevel +import tests.test_lowlevel as test_lowlevel +import tskit_tests.test_vcf as test_vcf +import tskit_tests.test_threads as test_threads +import tskit_tests.test_stats as test_stats +import tskit_tests.test_tables as test_tables +import tskit_tests.test_topology as test_topology +import tskit_tests.test_file_format as test_file_format +import tskit_tests.test_dict_encoding as test_dict_encoding + + +def main(): + modules = { + "demography": test_demography, + "highlevel": test_highlevel, + "lowlevel": test_lowlevel, + "vcf": test_vcf, + "threads": test_threads, + "stats": test_stats, + "tables": test_tables, + "file_format": test_file_format, + "topology": test_topology, + "dict_encoding": test_dict_encoding, + } + parser = argparse.ArgumentParser( + description="Run tests in a loop to stress low-level interface") + parser.add_argument( + "-m", "--module", help="Run tests only on this module", + choices=list(modules.keys())) + args = parser.parse_args() + test_modules = list(modules.values()) + if args.module is not None: + test_modules = [modules[args.module]] + + # Need to do this to silence the errors from the file_format tests. + logging.basicConfig(level=logging.ERROR) + + print("iter\ttests\terr\tfail\tskip\tRSS\tmin\tmax\tmax@iter") + max_rss = 0 + max_rss_iter = 0 + min_rss = 1e100 + iteration = 0 + last_print = time.time() + devnull = open(os.devnull, 'w') + while True: + # We don't want any random variation in the amount of memory + # used from test-to-test. 
+ random.seed(1) + testloader = unittest.TestLoader() + suite = testloader.loadTestsFromModule(test_modules[0]) + for mod in test_modules[1:]: + suite.addTests(testloader.loadTestsFromModule(mod)) + runner = unittest.TextTestRunner(verbosity=0, stream=devnull) + result = runner.run(suite) + rusage = resource.getrusage(resource.RUSAGE_SELF) + if max_rss < rusage.ru_maxrss: + max_rss = rusage.ru_maxrss + max_rss_iter = iteration + if min_rss > rusage.ru_maxrss: + min_rss = rusage.ru_maxrss + + # We don't want to flood stdout, so we rate-limit to 1 per second. + if time.time() - last_print > 1: + print( + iteration, result.testsRun, len(result.failures), + len(result.errors), len(result.skipped), + rusage.ru_maxrss, min_rss, max_rss, max_rss_iter, + sep="\t", end="\r") + last_print = time.time() + sys.stdout.flush() + + iteration += 1 + + +if __name__ == "__main__": + main() diff --git a/python/tests/__init__.py b/python/tests/__init__.py new file mode 100644 index 0000000000..3da9fc063a --- /dev/null +++ b/python/tests/__init__.py @@ -0,0 +1,534 @@ +from __future__ import print_function +from __future__ import unicode_literals +from __future__ import division + +import base64 + +# TODO remove this code and refactor elsewhere. + +from .simplify import * # NOQA + +import tskit + + +class PythonTree(object): + """ + Presents the same interface as the Tree object for testing. This + is tightly coupled with the PythonTreeSequence object below which updates + the internal structures during iteration. 
+ """ + def __init__(self, num_nodes): + self.num_nodes = num_nodes + self.parent = [tskit.NULL for _ in range(num_nodes)] + self.left_child = [tskit.NULL for _ in range(num_nodes)] + self.right_child = [tskit.NULL for _ in range(num_nodes)] + self.left_sib = [tskit.NULL for _ in range(num_nodes)] + self.right_sib = [tskit.NULL for _ in range(num_nodes)] + self.above_sample = [False for _ in range(num_nodes)] + self.is_sample = [False for _ in range(num_nodes)] + self.left = 0 + self.right = 0 + self.root = 0 + self.index = -1 + self.left_root = -1 + # We need a sites function, so this name is taken. + self.site_list = [] + + @classmethod + def from_tree(cls, tree): + ret = PythonTree(tree.num_nodes) + ret.left, ret.right = tree.get_interval() + ret.site_list = list(tree.sites()) + ret.index = tree.get_index() + ret.left_root = tree.left_root + ret.tree = tree + for u in range(ret.num_nodes): + ret.parent[u] = tree.parent(u) + ret.left_child[u] = tree.left_child(u) + ret.right_child[u] = tree.right_child(u) + ret.left_sib[u] = tree.left_sib(u) + ret.right_sib[u] = tree.right_sib(u) + assert ret == tree + return ret + + @property + def roots(self): + u = self.left_root + roots = [] + while u != tskit.NULL: + roots.append(u) + u = self.right_sib[u] + return roots + + def children(self, u): + v = self.left_child[u] + ret = [] + while v != tskit.NULL: + ret.append(v) + v = self.right_sib[v] + return ret + + def _preorder_nodes(self, u, l): + l.append(u) + for c in self.children(u): + self._preorder_nodes(c, l) + + def _postorder_nodes(self, u, l): + for c in self.children(u): + self._postorder_nodes(c, l) + l.append(u) + + def _inorder_nodes(self, u, l): + children = self.children(u) + if len(children) > 0: + mid = len(children) // 2 + for v in children[:mid]: + self._inorder_nodes(v, l) + l.append(u) + for v in children[mid:]: + self._inorder_nodes(v, l) + else: + l.append(u) + + def _levelorder_nodes(self, u, l, level): + l[level].append(u) if level < len(l) else 
l.append([u]) + for c in self.children(u): + self._levelorder_nodes(c, l, level + 1) + + def nodes(self, root=None, order="preorder"): + roots = [root] + if root is None: + roots = self.roots + for u in roots: + node_list = [] + if order == "preorder": + self._preorder_nodes(u, node_list) + elif order == "inorder": + self._inorder_nodes(u, node_list) + elif order == "postorder": + self._postorder_nodes(u, node_list) + elif order == "levelorder" or order == "breadthfirst": + # Returns nodes in their respective levels + # Nested list comprehension flattens node_list in order + self._levelorder_nodes(u, node_list, 0) + node_list = iter([i for level in node_list for i in level]) + else: + raise ValueError("order not supported") + for v in node_list: + yield v + + def get_interval(self): + return self.left, self.right + + def get_parent(self, node): + return self.parent[node] + + def get_children(self, node): + return self.children[node] + + def get_index(self): + return self.index + + def get_parent_dict(self): + d = { + u: self.parent[u] for u in range(self.num_nodes) + if self.parent[u] != tskit.NULL} + return d + + def sites(self): + return iter(self.site_list) + + def __eq__(self, other): + return ( + self.get_parent_dict() == other.get_parent_dict() and + self.get_interval() == other.get_interval() and + self.roots == other.roots and + self.get_index() == other.get_index() and + list(self.sites()) == list(other.sites())) + + def __ne__(self, other): + return not self.__eq__(other) + + def newick(self, root=None, precision=16, node_labels=None): + if node_labels is None: + node_labels = {u: str(u + 1) for u in self.tree.leaves()} + if root is None: + root = self.left_root + return self._build_newick(root, precision, node_labels) + ";" + + def _build_newick(self, node, precision, node_labels): + label = node_labels.get(node, "") + if self.left_child[node] == tskit.NULL: + s = label + else: + s = "(" + for child in self.children(node): + branch_length = 
self.tree.branch_length(child) + subtree = self._build_newick(child, precision, node_labels) + s += subtree + ":{0:.{1}f},".format(branch_length, precision) + s = s[:-1] + label + ")" + return s + + +class PythonTreeSequence(object): + """ + A python implementation of the TreeSequence object. + """ + def __init__(self, tree_sequence, breakpoints=None): + self._tree_sequence = tree_sequence + self._num_samples = tree_sequence.get_num_samples() + self._breakpoints = breakpoints + self._sites = [] + + def make_mutation(id_): + site, node, derived_state, parent, metadata = tree_sequence.get_mutation(id_) + return tskit.Mutation( + id_=id_, site=site, node=node, derived_state=derived_state, + parent=parent, metadata=metadata) + for j in range(tree_sequence.get_num_sites()): + pos, ancestral_state, ll_mutations, id_, metadata = tree_sequence.get_site(j) + self._sites.append(tskit.Site( + id_=id_, position=pos, ancestral_state=ancestral_state, + mutations=[make_mutation(ll_mut) for ll_mut in ll_mutations], + metadata=metadata)) + + def edge_diffs(self): + M = self._tree_sequence.get_num_edges() + sequence_length = self._tree_sequence.get_sequence_length() + edges = [tskit.Edge(*self._tree_sequence.get_edge(j)) for j in range(M)] + time = [self._tree_sequence.get_node(edge.parent)[1] for edge in edges] + in_order = sorted(range(M), key=lambda j: ( + edges[j].left, time[j], edges[j].parent, edges[j].child)) + out_order = sorted(range(M), key=lambda j: ( + edges[j].right, -time[j], -edges[j].parent, -edges[j].child)) + j = 0 + k = 0 + left = 0 + while j < M or left < sequence_length: + e_out = [] + e_in = [] + while k < M and edges[out_order[k]].right == left: + h = out_order[k] + e_out.append(edges[h]) + k += 1 + while j < M and edges[in_order[j]].left == left: + h = in_order[j] + e_in.append(edges[h]) + j += 1 + right = sequence_length + if j < M: + right = min(right, edges[in_order[j]].left) + if k < M: + right = min(right, edges[out_order[k]].right) + yield (left, 
right), e_out, e_in + left = right + + def trees(self): + M = self._tree_sequence.get_num_edges() + sequence_length = self._tree_sequence.get_sequence_length() + edges = [ + tskit.Edge(*self._tree_sequence.get_edge(j)) for j in range(M)] + t = [ + self._tree_sequence.get_node(j)[1] + for j in range(self._tree_sequence.get_num_nodes())] + in_order = sorted( + range(M), key=lambda j: ( + edges[j].left, t[edges[j].parent], edges[j].parent, edges[j].child)) + out_order = sorted( + range(M), key=lambda j: ( + edges[j].right, -t[edges[j].parent], -edges[j].parent, -edges[j].child)) + j = 0 + k = 0 + N = self._tree_sequence.get_num_nodes() + st = PythonTree(N) + + samples = list(self._tree_sequence.get_samples()) + for l in range(len(samples)): + if l < len(samples) - 1: + st.right_sib[samples[l]] = samples[l + 1] + if l > 0: + st.left_sib[samples[l]] = samples[l - 1] + st.above_sample[samples[l]] = True + st.is_sample[samples[l]] = True + + st.left_root = tskit.NULL + if len(samples) > 0: + st.left_root = samples[0] + + u = st.left_root + roots = [] + while u != -1: + roots.append(u) + v = st.right_sib[u] + if v != -1: + assert st.left_sib[v] == u + u = v + + st.left = 0 + while j < M or st.left < sequence_length: + while k < M and edges[out_order[k]].right == st.left: + p = edges[out_order[k]].parent + c = edges[out_order[k]].child + k += 1 + + lsib = st.left_sib[c] + rsib = st.right_sib[c] + if lsib == tskit.NULL: + st.left_child[p] = rsib + else: + st.right_sib[lsib] = rsib + if rsib == tskit.NULL: + st.right_child[p] = lsib + else: + st.left_sib[rsib] = lsib + st.parent[c] = tskit.NULL + st.left_sib[c] = tskit.NULL + st.right_sib[c] = tskit.NULL + + # If c is not above a sample then we have nothing to do as we + # cannot affect the status of any roots. + if st.above_sample[c]: + # Compute the new above sample status for the nodes from + # p up to root. 
+ v = p + above_sample = False + while v != tskit.NULL and not above_sample: + above_sample = st.is_sample[v] + u = st.left_child[v] + while u != tskit.NULL: + above_sample = above_sample or st.above_sample[u] + u = st.right_sib[u] + st.above_sample[v] = above_sample + root = v + v = st.parent[v] + + if not above_sample: + # root is no longer above samples. Remove it from the root list. + lroot = st.left_sib[root] + rroot = st.right_sib[root] + st.left_root = tskit.NULL + if lroot != tskit.NULL: + st.right_sib[lroot] = rroot + st.left_root = lroot + if rroot != tskit.NULL: + st.left_sib[rroot] = lroot + st.left_root = rroot + st.left_sib[root] = tskit.NULL + st.right_sib[root] = tskit.NULL + + # Add c to the root list. + # print("Insert ", c, "into root list") + if st.left_root != tskit.NULL: + lroot = st.left_sib[st.left_root] + if lroot != tskit.NULL: + st.right_sib[lroot] = c + st.left_sib[c] = lroot + st.left_sib[st.left_root] = c + st.right_sib[c] = st.left_root + st.left_root = c + + while j < M and edges[in_order[j]].left == st.left: + p = edges[in_order[j]].parent + c = edges[in_order[j]].child + j += 1 + + # print("insert ", c, "->", p) + st.parent[c] = p + u = st.right_child[p] + lsib = st.left_sib[c] + rsib = st.right_sib[c] + if u == tskit.NULL: + st.left_child[p] = c + st.left_sib[c] = tskit.NULL + st.right_sib[c] = tskit.NULL + else: + st.right_sib[u] = c + st.left_sib[c] = u + st.right_sib[c] = tskit.NULL + st.right_child[p] = c + + if st.above_sample[c]: + v = p + above_sample = False + while v != tskit.NULL and not above_sample: + above_sample = st.above_sample[v] + st.above_sample[v] = st.above_sample[v] or st.above_sample[c] + root = v + v = st.parent[v] + # print("root = ", root, st.above_sample[root]) + + if not above_sample: + # Replace c with root in root list. 
+ # print("replacing", root, "with ", c ," in root list") + if lsib != tskit.NULL: + st.right_sib[lsib] = root + if rsib != tskit.NULL: + st.left_sib[rsib] = root + st.left_sib[root] = lsib + st.right_sib[root] = rsib + st.left_root = root + else: + # Remove c from root list. + # print("remove ", c ," from root list") + st.left_root = tskit.NULL + if lsib != tskit.NULL: + st.right_sib[lsib] = rsib + st.left_root = lsib + if rsib != tskit.NULL: + st.left_sib[rsib] = lsib + st.left_root = rsib + + st.right = sequence_length + if j < M: + st.right = min(st.right, edges[in_order[j]].left) + if k < M: + st.right = min(st.right, edges[out_order[k]].right) + assert st.left_root != tskit.NULL + while st.left_sib[st.left_root] != tskit.NULL: + st.left_root = st.left_sib[st.left_root] + st.index += 1 + # Add in all the sites + st.site_list = [ + site for site in self._sites if st.left <= site.position < st.right] + yield st + st.left = st.right + + +class MRCACalculator(object): + """ + Class to that allows us to compute the nearest common ancestor of arbitrary + nodes in an oriented forest. + + This is an implementation of Schieber and Vishkin's nearest common ancestor + algorithm from TAOCP volume 4A, pg.164-167 [K11]_. Preprocesses the + input tree into a sideways heap in O(n) time and processes queries for the + nearest common ancestor between an arbitary pair of nodes in O(1) time. + + :param oriented_forest: the input oriented forest + :type oriented_forest: list of integers + """ + LAMBDA = 0 + + def __init__(self, oriented_forest): + # We turn this oriened forest into a 1 based array by adding 1 + # to everything + converted = [0] + [x + 1 for x in oriented_forest] + self.__preprocess(converted) + + def __preprocess(self, oriented_forest): + """ + Preprocess the oriented forest, so that we can answer mrca queries + in constant time. 
+ """ + n = len(oriented_forest) + child = [self.LAMBDA for i in range(n)] + parent = [self.LAMBDA for i in range(n)] + sib = [self.LAMBDA for i in range(n)] + self.__lambda = [0 for i in range(n)] + self.__pi = [0 for i in range(n)] + self.__tau = [0 for i in range(n)] + self.__beta = [0 for i in range(n)] + self.__alpha = [0 for i in range(n)] + for u in range(n): + v = oriented_forest[u] + sib[u] = child[v] + child[v] = u + parent[u] = v + p = child[self.LAMBDA] + n = 0 + self.__lambda[0] = -1 + while p != self.LAMBDA: + notDone = True + while notDone: + n += 1 + self.__pi[p] = n + self.__tau[n] = self.LAMBDA + self.__lambda[n] = 1 + self.__lambda[n >> 1] + if child[p] != self.LAMBDA: + p = child[p] + else: + notDone = False + self.__beta[p] = n + notDone = True + while notDone: + self.__tau[self.__beta[p]] = parent[p] + if sib[p] != self.LAMBDA: + p = sib[p] + notDone = False + else: + p = parent[p] + if p != self.LAMBDA: + h = self.__lambda[n & -self.__pi[p]] + self.__beta[p] = ((n >> h) | 1) << h + else: + notDone = False + # Begin the second traversal + self.__lambda[0] = self.__lambda[n] + self.__pi[self.LAMBDA] = 0 + self.__beta[self.LAMBDA] = 0 + self.__alpha[self.LAMBDA] = 0 + p = child[self.LAMBDA] + while p != self.LAMBDA: + notDone = True + while notDone: + a = ( + self.__alpha[parent[p]] | + (self.__beta[p] & -self.__beta[p]) + ) + self.__alpha[p] = a + if child[p] != self.LAMBDA: + p = child[p] + else: + notDone = False + notDone = True + while notDone: + if sib[p] != self.LAMBDA: + p = sib[p] + notDone = False + else: + p = parent[p] + notDone = p != self.LAMBDA + + def get_mrca(self, x, y): + """ + Returns the most recent common ancestor of the nodes x and y, + or -1 if the nodes belong to different trees. + + :param x: the first node + :param y: the second node + :return: the MRCA of nodes x and y + """ + # WE need to rescale here because SV expects 1-based arrays. 
+ return self._sv_mrca(x + 1, y + 1) - 1 + + def _sv_mrca(self, x, y): + if self.__beta[x] <= self.__beta[y]: + h = self.__lambda[self.__beta[y] & -self.__beta[x]] + else: + h = self.__lambda[self.__beta[x] & -self.__beta[y]] + k = self.__alpha[x] & self.__alpha[y] & -(1 << h) + h = self.__lambda[k & -k] + j = ((self.__beta[x] >> h) | 1) << h + if j == self.__beta[x]: + xhat = x + else: + ell = self.__lambda[self.__alpha[x] & ((1 << h) - 1)] + xhat = self.__tau[((self.__beta[x] >> ell) | 1) << ell] + if j == self.__beta[y]: + yhat = y + else: + ell = self.__lambda[self.__alpha[y] & ((1 << h) - 1)] + yhat = self.__tau[((self.__beta[y] >> ell) | 1) << ell] + if self.__pi[xhat] <= self.__pi[yhat]: + z = xhat + else: + z = yhat + return z + + +def base64_encode(metadata): + """ + Returns the specified metadata bytes object encoded as an ASCII-safe + string. + """ + return base64.b64encode(metadata).decode('utf8') diff --git a/python/tests/data/SLiM/README b/python/tests/data/SLiM/README new file mode 100644 index 0000000000..28f8d4abb5 --- /dev/null +++ b/python/tests/data/SLiM/README @@ -0,0 +1 @@ +The files in this directory are generated by SLiM. 
diff --git a/python/tests/data/SLiM/minimal-example.trees b/python/tests/data/SLiM/minimal-example.trees new file mode 100644 index 0000000000..47cc0794a8 Binary files /dev/null and b/python/tests/data/SLiM/minimal-example.trees differ diff --git a/python/tests/data/SLiM/minimal-example.txt b/python/tests/data/SLiM/minimal-example.txt new file mode 100644 index 0000000000..d867d40926 --- /dev/null +++ b/python/tests/data/SLiM/minimal-example.txt @@ -0,0 +1,15 @@ +initialize() { + initializeTreeSeq(); + initializeMutationRate(0.0); + initializeMutationType("m1", 0.5, "f", -0.1); + initializeGenomicElementType("g1", m1, 1.0); + initializeGenomicElement(g1, 0, 9); + initializeRecombinationRate(1e-1); +} +1 { + sim.addSubpop("p1", 5); +} +3 { + sim.treeSeqOutput("tests/data/SLiM/minimal-example.trees"); + sim.simulationFinished(); +} diff --git a/python/tests/data/SLiM/single-locus-example.trees b/python/tests/data/SLiM/single-locus-example.trees new file mode 100644 index 0000000000..3ca57aed20 Binary files /dev/null and b/python/tests/data/SLiM/single-locus-example.trees differ diff --git a/python/tests/data/SLiM/single-locus-example.txt b/python/tests/data/SLiM/single-locus-example.txt new file mode 100644 index 0000000000..15e7e9757b --- /dev/null +++ b/python/tests/data/SLiM/single-locus-example.txt @@ -0,0 +1,15 @@ +initialize() { + initializeTreeSeq(); + initializeMutationRate(0.0); + initializeMutationType("m1", 0.5, "f", -0.1); + initializeGenomicElementType("g1", m1, 1.0); + initializeGenomicElement(g1, 0, 9); + initializeRecombinationRate(0); +} +1 { + sim.addSubpop("p1", 5); +} +3 { + sim.treeSeqOutput("tests/data/SLiM/single-locus-example.trees"); + sim.simulationFinished(); +} diff --git a/python/tests/data/hdf5-formats/msprime-0.3.0_v2.0.hdf5 b/python/tests/data/hdf5-formats/msprime-0.3.0_v2.0.hdf5 new file mode 100644 index 0000000000..6e294ac5bd Binary files /dev/null and b/python/tests/data/hdf5-formats/msprime-0.3.0_v2.0.hdf5 differ diff --git 
a/python/tests/data/hdf5-formats/msprime-0.4.0_v3.1.hdf5 b/python/tests/data/hdf5-formats/msprime-0.4.0_v3.1.hdf5 new file mode 100644 index 0000000000..d6a23150db Binary files /dev/null and b/python/tests/data/hdf5-formats/msprime-0.4.0_v3.1.hdf5 differ diff --git a/python/tests/data/hdf5-formats/msprime-0.5.0_v10.0.hdf5 b/python/tests/data/hdf5-formats/msprime-0.5.0_v10.0.hdf5 new file mode 100644 index 0000000000..16e9c7a2b6 Binary files /dev/null and b/python/tests/data/hdf5-formats/msprime-0.5.0_v10.0.hdf5 differ diff --git a/python/tests/data/simplify-bugs/01-edges.txt b/python/tests/data/simplify-bugs/01-edges.txt new file mode 100644 index 0000000000..3bb7f16d73 --- /dev/null +++ b/python/tests/data/simplify-bugs/01-edges.txt @@ -0,0 +1,24 @@ +left right parent child +0.000000 4.000000 5 2,3 +4.000000 9.000000 5 3 +22.000000 28.000000 5 3 +0.000000 18.000000 6 0,1,4 +18.000000 19.000000 6 0,1,4,5 +19.000000 28.000000 6 0,1,5 +0.000000 19.000000 7 6 +19.000000 28.000000 7 2,6 +0.000000 28.000000 8 7 +0.000000 28.000000 9 8 +0.000000 18.000000 10 5,9 +18.000000 28.000000 10 9 +0.000000 19.000000 11 10 +19.000000 28.000000 11 4,10 +0.000000 9.000000 12 11 +9.000000 22.000000 12 3,11 +22.000000 28.000000 12 11 +0.000000 28.000000 13 12 +0.000000 28.000000 14 13 +0.000000 28.000000 15 14 +0.000000 4.000000 16 15 +4.000000 19.000000 16 2,15 +19.000000 28.000000 16 15 diff --git a/python/tests/data/simplify-bugs/01-mutations.txt b/python/tests/data/simplify-bugs/01-mutations.txt new file mode 100644 index 0000000000..789b7b2e1c --- /dev/null +++ b/python/tests/data/simplify-bugs/01-mutations.txt @@ -0,0 +1 @@ +site node derived_state diff --git a/python/tests/data/simplify-bugs/01-nodes.txt b/python/tests/data/simplify-bugs/01-nodes.txt new file mode 100644 index 0000000000..ee4aa7c878 --- /dev/null +++ b/python/tests/data/simplify-bugs/01-nodes.txt @@ -0,0 +1,18 @@ +is_sample time +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +0 5.000000 +0 6.000000 
+0 7.000000 +0 8.000000 +0 9.000000 +0 10.000000 +0 11.000000 +0 12.000000 +0 13.000000 +0 14.000000 +0 15.000000 +0 16.000000 diff --git a/python/tests/data/simplify-bugs/01-sites.txt b/python/tests/data/simplify-bugs/01-sites.txt new file mode 100644 index 0000000000..3f12d4059a --- /dev/null +++ b/python/tests/data/simplify-bugs/01-sites.txt @@ -0,0 +1 @@ +position ancestral_state diff --git a/python/tests/data/simplify-bugs/02-edges.txt b/python/tests/data/simplify-bugs/02-edges.txt new file mode 100644 index 0000000000..8ddd2fddc6 --- /dev/null +++ b/python/tests/data/simplify-bugs/02-edges.txt @@ -0,0 +1,378 @@ +left right parent child +63.824647 100.000000 100 47,85 +41.913156 100.000000 101 81,83 +0.000000 100.000000 102 88,98 +0.000000 76.111147 103 38,76 +0.000000 100.000000 104 12,63 +0.000000 100.000000 105 3,79 +0.000000 100.000000 106 23,95 +0.000000 100.000000 107 46,61 +0.000000 63.824647 108 47,87 +0.000000 41.913156 109 19,81 +41.913156 100.000000 109 19,101 +65.841615 100.000000 110 62,68 +0.000000 42.476761 111 59,72 +0.000000 79.893379 112 65,92 +0.000000 100.000000 113 60,107 +38.209860 53.470013 114 15,86 +99.748128 100.000000 115 30,96 +86.643371 95.255452 116 51,54 +0.000000 100.000000 117 2,21 +95.255452 100.000000 118 49,51 +91.315428 100.000000 119 74,78 +27.039936 100.000000 120 24,80 +27.368300 63.824647 121 94,108 +63.824647 100.000000 121 87,94 +92.895855 100.000000 122 7,64 +0.000000 46.386608 123 9,74 +0.000000 85.936096 124 56,91 +0.000000 47.640188 125 67,77 +30.422602 43.241239 126 40,55 +0.000000 6.538844 127 4,41 +51.451422 100.000000 128 9,52 +62.577097 90.247778 129 93,99 +0.000000 7.108776 130 85,90 +60.658379 100.000000 131 57,67 +0.000000 100.000000 132 18,109 +11.309867 79.893379 133 112,117 +79.893379 100.000000 133 92,117 +96.284451 99.748128 134 59,96 +42.476761 43.241239 135 72,126 +43.241239 45.405170 135 55,72 +0.000000 63.435890 136 89,104 +4.419653 34.689388 137 30,64 +3.568615 27.368300 138 62,108 +23.785721 
60.658379 139 7,57 +60.658379 92.895855 139 7,131 +92.895855 100.000000 139 122,131 +37.043091 83.777115 140 5,49 +63.435890 100.000000 141 33,104 +0.000000 34.240481 142 14,43,45,66,113,136 +34.240481 58.841850 142 14,29,43,45,66,113,136 +58.841850 63.435890 142 14,43,45,66,113,136 +63.435890 75.398652 142 14,43,45,66,89,113 +75.398652 83.394074 142 14,43,45,66,89,96,113 +83.394074 85.706099 142 43,45,66,89,96,113 +85.706099 85.936096 142 45,66,89,96,113 +85.936096 92.349588 142 45,56,66,89,96,113 +92.349588 93.764750 142 45,56,66,89,96 +93.764750 96.284451 142 45,56,89,96 +96.284451 99.748128 142 45,56,89,134 +99.748128 100.000000 142 45,56,59,89 +0.000000 20.449369 143 24,97 +60.445610 74.287710 144 42,59 +52.173268 61.719977 145 64,99 +25.474716 25.584357 146 51,96 +88.701340 95.255452 147 49,106 +95.255452 100.000000 147 106,118 +0.000000 16.128261 148 55,82 +46.492800 51.451422 149 9,133 +43.241239 100.000000 150 40,142 +16.128261 20.956228 151 82,86 +83.693827 92.003768 152 72,75 +7.108776 23.785721 153 7,90 +23.785721 93.565018 153 90,139 +0.000000 23.444517 154 33,35 +50.484324 52.173268 155 0,64 +14.560625 15.750456 156 70,102 +72.191272 73.218934 157 6,54 +23.051339 38.269459 158 6,105 +83.122541 89.826128 158 25,105 +19.401604 21.440137 159 0,32 +83.534148 100.000000 159 13,65 +86.401076 100.000000 160 1,82 +55.154018 60.905848 161 28,93 +1.397707 4.419653 162 30,37 +4.419653 30.811847 162 37,137 +55.451624 83.122541 163 11,105 +83.122541 89.826128 163 11,158 +89.826128 100.000000 163 11,105 +50.640219 60.445610 164 53,59 +60.445610 74.287710 164 53,144 +74.287710 96.284451 164 53,59 +0.000000 7.108776 165 7,32 +7.108776 19.401604 165 32,153 +20.956228 27.039936 166 80,86 +27.039936 38.209860 166 86,120 +54.284561 55.451624 166 11,120 +55.451624 100.000000 166 120,163 +0.000000 37.043091 167 5,78 +45.914485 46.492800 168 99,133 +46.492800 51.451422 168 99,149 +51.451422 52.173268 168 99,133 +52.173268 54.775535 168 133,145 +0.000000 1.397707 169 13,30 
+1.397707 18.000399 169 13,162 +13.905651 20.721380 170 34,69 +23.444517 55.154018 171 33,93 +21.440137 30.600251 172 0,4 +99.132998 100.000000 173 27,160 +0.000000 22.327425 174 96,105 +26.281985 32.637756 175 54,124 +32.070756 38.209860 176 85,166 +38.209860 51.803381 176 85,120 +37.135096 73.914710 177 103,132 +30.600251 39.656727 178 4,50 +41.362197 55.154018 179 34,171 +55.154018 63.435890 179 33,34 +63.435890 77.115170 179 34,141 +79.777617 83.693827 180 72,133 +39.656727 47.640188 181 50,125 +47.640188 60.658379 181 50,67 +99.753960 100.000000 182 20,93 +74.287710 100.000000 183 42,121 +22.902252 37.043091 184 13,49 +37.043091 83.534148 184 13,140 +34.067864 34.689388 185 137,184 +34.689388 51.690836 185 30,184 +0.000000 11.300701 186 8,17,20,22,42,52,57,71,75,94,117,142,143 +11.300701 11.309867 186 8,17,20,22,42,57,71,75,94,117,142,143 +11.309867 18.561545 186 8,17,20,22,42,57,71,75,94,133,142,143 +18.561545 20.449369 186 17,20,22,42,57,71,75,94,133,142,143 +20.449369 20.575506 186 17,20,22,42,57,71,75,94,97,133,142 +20.575506 23.785721 186 17,20,22,42,57,71,75,97,133,142 +23.785721 24.710938 186 17,20,22,42,71,75,97,133,142 +24.710938 25.328103 186 17,20,22,31,42,71,75,97,133,142 +25.328103 27.114357 186 17,20,22,31,42,71,75,97,133,138,142 +27.114357 27.368300 186 1,17,20,22,31,42,71,75,97,133,138,142 +27.368300 27.631858 186 1,17,20,22,31,42,62,71,75,97,133,142 +27.631858 32.637756 186 1,17,20,31,42,62,71,75,97,133,142 +32.637756 33.673379 186 1,17,20,31,42,54,62,71,75,97,133,142 +33.673379 34.689388 186 1,17,20,31,42,62,71,75,97,133,142 +34.689388 36.134011 186 1,17,20,31,42,62,64,71,75,97,133,142 +36.134011 42.476761 186 1,17,20,31,42,62,64,71,75,97,126,133,142 +42.476761 42.751756 186 1,17,20,31,42,62,64,71,75,97,133,135,142 +42.751756 43.241239 186 1,17,20,31,42,62,64,71,75,97,121,133,135,142 +43.241239 44.698548 186 1,17,20,31,42,62,64,71,75,97,121,133,135,150 +44.698548 45.405170 186 1,11,17,20,31,42,62,64,71,75,97,121,133,135,150 +45.405170 
45.914485 186 1,11,17,20,31,42,62,64,71,72,75,97,121,133,150 +45.914485 46.386608 186 1,11,17,20,31,42,62,64,71,72,75,97,121,150 +46.386608 50.484324 186 1,11,17,20,31,42,62,64,71,72,74,75,97,121,150 +50.484324 51.344527 186 1,11,17,20,31,42,62,71,72,74,75,97,121,150 +51.344527 53.470013 186 1,11,20,31,42,62,71,72,74,75,97,121,150 +53.470013 54.284561 186 1,11,20,31,42,62,71,72,74,75,86,97,121,150 +54.284561 54.414499 186 1,20,31,42,62,71,72,74,75,86,97,121,150 +54.414499 54.775535 186 1,20,31,42,62,72,74,75,86,97,121,150 +54.775535 58.841850 186 1,20,31,42,62,72,74,75,86,97,121,133,150 +58.841850 58.916166 186 1,20,29,31,42,62,72,74,75,86,97,121,133,150 +58.916166 60.445610 186 1,20,31,42,62,72,74,75,86,97,121,133,150 +60.445610 65.841615 186 1,20,31,62,72,74,75,86,97,121,133,150 +65.841615 67.041287 186 1,20,31,72,74,75,86,97,110,121,133,150 +67.041287 71.852095 186 1,20,31,72,74,75,78,86,97,110,121,133,150 +71.852095 74.244135 186 1,20,31,72,74,75,78,86,97,110,121,124,133,150 +74.244135 74.287710 186 1,20,31,74,75,78,86,97,110,121,124,133,150 +74.287710 76.111147 186 1,20,31,74,75,78,86,97,110,124,133,150,183 +76.111147 79.777617 186 1,20,31,38,74,75,78,86,97,110,124,133,150,183 +79.777617 80.236310 186 1,20,31,38,74,75,78,86,97,110,124,150,183 +80.236310 81.076006 186 1,20,31,38,41,74,75,78,86,97,110,124,150,183 +81.076006 83.693827 186 1,20,31,38,41,74,75,78,86,97,110,150,183 +83.693827 84.317521 186 1,20,31,38,41,74,78,86,97,110,150,152,183 +84.317521 86.401076 186 1,20,31,38,41,74,78,86,97,110,152,183 +86.401076 87.032912 186 20,31,38,41,74,78,86,97,110,152,160,183 +87.032912 91.315428 186 20,31,38,41,71,74,78,86,97,110,152,160,183 +91.315428 92.003768 186 20,31,38,41,71,86,97,110,119,152,160,183 +92.003768 92.357754 186 20,31,38,41,71,72,75,86,97,110,119,160,183 +92.357754 94.331891 186 20,31,38,41,58,71,72,75,86,97,110,119,160,183 +94.331891 94.869872 186 20,31,38,39,41,58,71,72,75,86,97,110,119,160,183 +94.869872 99.132998 186 
20,31,38,39,41,58,71,72,75,86,97,102,110,119,160,183 +99.132998 99.753960 186 20,31,38,39,41,58,71,72,75,86,97,102,110,119,173,183 +99.753960 100.000000 186 31,38,39,41,58,71,72,75,86,97,102,110,119,173,183 +81.203782 82.034281 187 30,50 +93.836844 99.748128 187 30,36 +99.748128 100.000000 187 36,115 +93.565018 94.899812 188 14,90 +21.243630 30.422602 189 25,55 +43.757852 69.373708 189 8,25 +88.947794 100.000000 190 28,77 +85.951325 94.331891 191 39,84 +57.108058 93.565018 192 153,164 +93.565018 96.284451 192 139,164 +96.284451 100.000000 192 53,139 +18.561545 21.243630 193 8,55 +23.120737 43.757852 193 8,82 +31.693641 38.209860 194 15,68 +38.209860 38.778292 194 68,114 +81.076006 85.936096 195 124,192 +85.936096 100.000000 195 91,192 +16.654066 18.561545 196 55,68 +18.561545 21.243630 196 68,193 +21.243630 23.120737 196 8,68 +23.120737 31.693641 196 68,193 +81.752010 95.899004 197 26,27 +20.721380 41.362197 198 34,44 +41.362197 77.115170 198 44,179 +77.115170 79.005768 198 34,44 +0.000000 3.568615 199 62,73 +3.568615 19.497026 199 73,138 +0.000000 3.475795 200 16,69 +30.811847 36.777478 201 10,37 +0.000000 22.902252 202 49,84 +51.803381 63.824647 202 84,85 +63.824647 85.951325 202 84,100 +85.951325 94.331891 202 100,191 +94.331891 94.444668 202 84,100 +73.914710 78.645366 203 10,132 +85.706099 100.000000 204 8,43 +0.000000 21.243630 205 25,27 +21.243630 23.725869 205 27,189 +61.095745 72.191272 206 6,177 +72.191272 73.218934 206 157,177 +73.218934 73.914710 206 6,177 +73.914710 76.111147 206 6,103 +76.111147 100.000000 206 6,76 +39.199376 42.476761 207 58,111 +42.476761 49.108935 207 58,59 +0.000000 14.560625 208 70,111 +14.560625 15.750456 208 111,156 +15.750456 39.199376 208 70,111 +39.199376 46.891849 208 70,207 +39.263142 55.154018 209 28,54 +55.154018 60.905848 209 54,161 +60.905848 72.191272 209 28,54 +0.000000 22.902252 210 31,202 +22.902252 24.710938 210 31,84 +34.233492 35.877983 211 16,102 +62.308779 63.126608 212 0,202 +0.000000 4.851737 213 103,169 
+38.778292 53.470013 214 78,114 +53.470013 67.041287 214 15,78 +78.884401 99.468331 215 4,186 +0.000000 1.397707 216 0,1,6,10,11,15,26,28,29,34,36,37,39,40,44,48,50,51,53,54,58,64,68,80,83,86,93,99,102,106,108,112,123,124,125,127,130,132,148,154,165,167,174,186,199,200,205,208,210,213 +1.397707 3.475795 216 0,1,6,10,11,15,26,28,29,34,36,39,40,44,48,50,51,53,54,58,64,68,80,83,86,93,99,102,106,108,112,123,124,125,127,130,132,148,154,165,167,174,186,199,200,205,208,210,213 +3.475795 3.568615 216 0,1,6,10,11,15,16,26,28,29,34,36,39,40,44,48,50,51,53,54,58,64,68,69,80,83,86,93,99,102,106,108,112,123,124,125,127,130,132,148,154,165,167,174,186,199,205,208,210,213 +3.568615 4.419653 216 0,1,6,10,11,15,16,26,28,29,34,36,39,40,44,48,50,51,53,54,58,64,68,69,80,83,86,93,99,102,106,112,123,124,125,127,130,132,148,154,165,167,174,186,199,205,208,210,213 +4.419653 4.851737 216 0,1,6,10,11,15,16,26,28,29,34,36,39,40,44,48,50,51,53,54,58,68,69,80,83,86,93,99,102,106,112,123,124,125,127,130,132,148,154,165,167,174,186,199,205,208,210,213 +4.851737 6.538844 216 0,1,6,10,11,15,16,26,28,29,34,36,39,40,44,48,50,51,53,54,58,68,69,80,83,86,93,99,102,103,106,112,123,124,125,127,130,132,148,154,165,167,169,174,186,199,205,208,210 +6.538844 7.108776 216 0,1,4,6,10,11,15,16,26,28,29,34,36,39,40,41,44,48,50,51,53,54,58,68,69,80,83,86,93,99,102,103,106,112,123,124,125,130,132,148,154,165,167,169,174,186,199,205,208,210 +7.108776 11.300701 216 0,1,4,6,10,11,15,16,26,28,29,34,36,39,40,41,44,48,50,51,53,54,58,68,69,80,83,85,86,93,99,102,103,106,112,123,124,125,132,148,154,165,167,169,174,186,199,205,208,210 +11.300701 11.309867 216 0,1,4,6,10,11,15,16,26,28,29,34,36,39,40,41,44,48,50,51,52,53,54,58,68,69,80,83,85,86,93,99,102,103,106,112,123,124,125,132,148,154,165,167,169,174,186,199,205,208,210 +11.309867 13.905651 216 0,1,4,6,10,11,15,16,26,28,29,34,36,39,40,41,44,48,50,51,52,53,54,58,68,69,80,83,85,86,93,99,102,103,106,123,124,125,132,148,154,165,167,169,174,186,199,205,208,210 +13.905651 
14.560625 216 0,1,4,6,10,11,15,16,26,28,29,36,39,40,41,44,48,50,51,52,53,54,58,68,80,83,85,86,93,99,102,103,106,123,124,125,132,148,154,165,167,169,170,174,186,199,205,208,210 +14.560625 15.750456 216 0,1,4,6,10,11,15,16,26,28,29,36,39,40,41,44,48,50,51,52,53,54,58,68,80,83,85,86,93,99,103,106,123,124,125,132,148,154,165,167,169,170,174,186,199,205,208,210 +15.750456 16.128261 216 0,1,4,6,10,11,15,16,26,28,29,36,39,40,41,44,48,50,51,52,53,54,58,68,80,83,85,86,93,99,102,103,106,123,124,125,132,148,154,165,167,169,170,174,186,199,205,208,210 +16.128261 16.654066 216 0,1,4,6,10,11,15,16,26,28,29,36,39,40,41,44,48,50,51,52,53,54,55,58,68,80,83,85,93,99,102,103,106,123,124,125,132,151,154,165,167,169,170,174,186,199,205,208,210 +16.654066 18.000399 216 0,1,4,6,10,11,15,16,26,28,29,36,39,40,41,44,48,50,51,52,53,54,58,80,83,85,93,99,102,103,106,123,124,125,132,151,154,165,167,169,170,174,186,196,199,205,208,210 +18.000399 19.401604 216 0,1,4,6,10,11,13,15,16,26,28,29,36,39,40,41,44,48,50,51,52,53,54,58,80,83,85,93,99,102,103,106,123,124,125,132,151,154,162,165,167,170,174,186,196,199,205,208,210 +19.401604 19.497026 216 1,4,6,10,11,13,15,16,26,28,29,36,39,40,41,44,48,50,51,52,53,54,58,80,83,85,93,99,102,103,106,123,124,125,132,151,153,154,159,162,167,170,174,186,196,199,205,208,210 +19.497026 20.449369 216 1,4,6,10,11,13,15,16,26,28,29,36,39,40,41,44,48,50,51,52,53,54,58,73,80,83,85,93,99,102,103,106,123,124,125,132,138,151,153,154,159,162,167,170,174,186,196,205,208,210 +20.449369 20.575506 216 1,4,6,10,11,13,15,16,24,26,28,29,36,39,40,41,44,48,50,51,52,53,54,58,73,80,83,85,93,99,102,103,106,123,124,125,132,138,151,153,154,159,162,167,170,174,186,196,205,208,210 +20.575506 20.721380 216 1,4,6,10,11,13,15,16,24,26,28,29,36,39,40,41,44,48,50,51,52,53,54,58,73,80,83,85,93,94,99,102,103,106,123,124,125,132,138,151,153,154,159,162,167,170,174,186,196,205,208,210 +20.721380 20.956228 216 
1,4,6,10,11,13,15,16,24,26,28,29,36,39,40,41,48,50,51,52,53,54,58,69,73,80,83,85,93,94,99,102,103,106,123,124,125,132,138,151,153,154,159,162,167,174,186,196,198,205,208,210 +20.956228 21.440137 216 1,4,6,10,11,13,15,16,24,26,28,29,36,39,40,41,48,50,51,52,53,54,58,69,73,82,83,85,93,94,99,102,103,106,123,124,125,132,138,153,154,159,162,166,167,174,186,196,198,205,208,210 +21.440137 22.327425 216 1,6,10,11,13,15,16,24,26,28,29,32,36,39,40,41,48,50,51,52,53,54,58,69,73,82,83,85,93,94,99,102,103,106,123,124,125,132,138,153,154,162,166,167,172,174,186,196,198,205,208,210 +22.327425 22.902252 216 1,6,10,11,13,15,16,24,26,28,29,32,36,39,40,41,48,50,51,52,53,54,58,69,73,82,83,85,93,94,96,99,102,103,105,106,123,124,125,132,138,153,154,162,166,167,172,186,196,198,205,208,210 +22.902252 23.051339 216 1,6,10,11,15,16,24,26,28,29,32,36,39,40,41,48,50,51,52,53,54,58,69,73,82,83,85,93,94,96,99,102,103,105,106,123,124,125,132,138,153,154,162,166,167,172,184,186,196,198,205,208,210 +23.051339 23.120737 216 1,10,11,15,16,24,26,28,29,32,36,39,40,41,48,50,51,52,53,54,58,69,73,82,83,85,93,94,96,99,102,103,106,123,124,125,132,138,153,154,158,162,166,167,172,184,186,196,198,205,208,210 +23.120737 23.444517 216 1,10,11,15,16,24,26,28,29,32,36,39,40,41,48,50,51,52,53,54,58,69,73,83,85,93,94,96,99,102,103,106,123,124,125,132,138,153,154,158,162,166,167,172,184,186,196,198,205,208,210 +23.444517 23.725869 216 1,10,11,15,16,24,26,28,29,32,35,36,39,40,41,48,50,51,52,53,54,58,69,73,83,85,94,96,99,102,103,106,123,124,125,132,138,153,158,162,166,167,171,172,184,186,196,198,205,208,210 +23.725869 24.710938 216 1,10,11,15,16,24,26,27,28,29,32,35,36,39,40,41,48,50,51,52,53,54,58,69,73,83,85,94,96,99,102,103,106,123,124,125,132,138,153,158,162,166,167,171,172,184,186,189,196,198,208,210 +24.710938 25.328103 216 1,10,11,15,16,24,26,27,28,29,32,35,36,39,40,41,48,50,51,52,53,54,58,69,73,83,84,85,94,96,99,102,103,106,123,124,125,132,138,153,158,162,166,167,171,172,184,186,189,196,198,208 +25.328103 
25.474716 216 1,10,11,15,16,24,26,27,28,29,32,35,36,39,40,41,48,50,51,52,53,54,58,69,73,83,84,85,94,96,99,102,103,106,123,124,125,132,153,158,162,166,167,171,172,184,186,189,196,198,208 +25.474716 25.584357 216 1,10,11,15,16,24,26,27,28,29,32,35,36,39,40,41,48,50,52,53,54,58,69,73,83,84,85,94,99,102,103,106,123,124,125,132,146,153,158,162,166,167,171,172,184,186,189,196,198,208 +25.584357 26.281985 216 1,10,11,15,16,24,26,27,28,29,32,35,36,39,40,41,48,50,51,52,53,54,58,69,73,83,84,85,94,96,99,102,103,106,123,124,125,132,153,158,162,166,167,171,172,184,186,189,196,198,208 +26.281985 27.039936 216 1,10,11,15,16,24,26,27,28,29,32,35,36,39,40,41,48,50,51,52,53,58,69,73,83,84,85,94,96,99,102,103,106,123,125,132,153,158,162,166,167,171,172,175,184,186,189,196,198,208 +27.039936 27.114357 216 1,10,11,15,16,26,27,28,29,32,35,36,39,40,41,48,50,51,52,53,58,69,73,83,84,85,94,96,99,102,103,106,123,125,132,153,158,162,166,167,171,172,175,184,186,189,196,198,208 +27.114357 27.368300 216 10,11,15,16,26,27,28,29,32,35,36,39,40,41,48,50,51,52,53,58,69,73,83,84,85,94,96,99,102,103,106,123,125,132,153,158,162,166,167,171,172,175,184,186,189,196,198,208 +27.368300 27.631858 216 10,11,15,16,26,27,28,29,32,35,36,39,40,41,48,50,51,52,53,58,69,73,83,84,85,96,99,102,103,106,121,123,125,132,153,158,162,166,167,171,172,175,184,186,189,196,198,208 +27.631858 30.422602 216 10,11,15,16,22,26,27,28,29,32,35,36,39,40,41,48,50,51,52,53,58,69,73,83,84,85,96,99,102,103,106,121,123,125,132,153,158,162,166,167,171,172,175,184,186,189,196,198,208 +30.422602 30.600251 216 10,11,15,16,22,25,26,27,28,29,32,35,36,39,41,48,50,51,52,53,58,69,73,83,84,85,96,99,102,103,106,121,123,125,126,132,153,158,162,166,167,171,172,175,184,186,196,198,208 +30.600251 30.811847 216 0,10,11,15,16,22,25,26,27,28,29,32,35,36,39,41,48,51,52,53,58,69,73,83,84,85,96,99,102,103,106,121,123,125,126,132,153,158,162,166,167,171,175,178,184,186,196,198,208 +30.811847 31.693641 216 
0,11,15,16,22,25,26,27,28,29,32,35,36,39,41,48,51,52,53,58,69,73,83,84,85,96,99,102,103,106,121,123,125,126,132,137,153,158,166,167,171,175,178,184,186,196,198,201,208 +31.693641 32.070756 216 0,11,16,22,25,26,27,28,29,32,35,36,39,41,48,51,52,53,58,69,73,83,84,85,96,99,102,103,106,121,123,125,126,132,137,153,158,166,167,171,175,178,184,186,193,194,198,201,208 +32.070756 32.637756 216 0,11,16,22,25,26,27,28,29,32,35,36,39,41,48,51,52,53,58,69,73,83,84,96,99,102,103,106,121,123,125,126,132,137,153,158,167,171,175,176,178,184,186,193,194,198,201,208 +32.637756 33.673379 216 0,11,16,22,25,26,27,28,29,32,35,36,39,41,48,51,52,53,58,69,73,83,84,96,99,102,103,106,121,123,124,125,126,132,137,153,158,167,171,176,178,184,186,193,194,198,201,208 +33.673379 34.067864 216 0,11,16,22,25,26,27,28,29,32,35,36,39,41,48,51,52,53,54,58,69,73,83,84,96,99,102,103,106,121,123,124,125,126,132,137,153,158,167,171,176,178,184,186,193,194,198,201,208 +34.067864 34.233492 216 0,11,16,22,25,26,27,28,29,32,35,36,39,41,48,51,52,53,54,58,69,73,83,84,96,99,102,103,106,121,123,124,125,126,132,153,158,167,171,176,178,185,186,193,194,198,201,208 +34.233492 34.240481 216 0,11,22,25,26,27,28,29,32,35,36,39,41,48,51,52,53,54,58,69,73,83,84,96,99,103,106,121,123,124,125,126,132,153,158,167,171,176,178,185,186,193,194,198,201,208,211 +34.240481 35.877983 216 0,11,22,25,26,27,28,32,35,36,39,41,48,51,52,53,54,58,69,73,83,84,96,99,103,106,121,123,124,125,126,132,153,158,167,171,176,178,185,186,193,194,198,201,208,211 +35.877983 36.134011 216 0,11,16,22,25,26,27,28,32,35,36,39,41,48,51,52,53,54,58,69,73,83,84,96,99,102,103,106,121,123,124,125,126,132,153,158,167,171,176,178,185,186,193,194,198,201,208 +36.134011 36.777478 216 0,11,16,22,25,26,27,28,32,35,36,39,41,48,51,52,53,54,58,69,73,83,84,96,99,102,103,106,121,123,124,125,132,153,158,167,171,176,178,185,186,193,194,198,201,208 +36.777478 37.043091 216 
0,10,11,16,22,25,26,27,28,32,35,36,37,39,41,48,51,52,53,54,58,69,73,83,84,96,99,102,103,106,121,123,124,125,132,153,158,167,171,176,178,185,186,193,194,198,208 +37.043091 37.135096 216 0,10,11,16,22,25,26,27,28,32,35,36,37,39,41,48,51,52,53,54,58,69,73,78,83,84,96,99,102,103,106,121,123,124,125,132,153,158,171,176,178,185,186,193,194,198,208 +37.135096 38.269459 216 0,10,11,16,22,25,26,27,28,32,35,36,37,39,41,48,51,52,53,54,58,69,73,78,83,84,96,99,102,106,121,123,124,125,153,158,171,176,177,178,185,186,193,194,198,208 +38.269459 38.778292 216 0,6,10,11,16,22,25,26,27,28,32,35,36,37,39,41,48,51,52,53,54,58,69,73,78,83,84,96,99,102,105,106,121,123,124,125,153,171,176,177,178,185,186,193,194,198,208 +38.778292 39.199376 216 0,6,10,11,16,22,25,26,27,28,32,35,36,37,39,41,48,51,52,53,54,58,68,69,73,83,84,96,99,102,105,106,121,123,124,125,153,171,176,177,178,185,186,193,198,208,214 +39.199376 39.263142 216 0,6,10,11,16,22,25,26,27,28,32,35,36,37,39,41,48,51,52,53,54,68,69,73,83,84,96,99,102,105,106,121,123,124,125,153,171,176,177,178,185,186,193,198,208,214 +39.263142 39.656727 216 0,6,10,11,16,22,25,26,27,32,35,36,37,39,41,48,51,52,53,68,69,73,83,84,96,99,102,105,106,121,123,124,125,153,171,176,177,178,185,186,193,198,208,209,214 +39.656727 41.362197 216 0,4,6,10,11,16,22,25,26,27,32,35,36,37,39,41,48,51,52,53,68,69,73,83,84,96,99,102,105,106,121,123,124,153,171,176,177,181,185,186,193,198,208,209,214 +41.362197 41.913156 216 0,4,6,10,11,16,22,25,26,27,32,35,36,37,39,41,48,51,52,53,68,69,73,83,84,96,99,102,105,106,121,123,124,153,176,177,181,185,186,193,198,208,209,214 +41.913156 42.751756 216 0,4,6,10,11,16,22,25,26,27,32,35,36,37,39,41,48,51,52,53,68,69,73,84,96,99,102,105,106,121,123,124,153,176,177,181,185,186,193,198,208,209,214 +42.751756 43.757852 216 0,4,6,10,11,16,22,25,26,27,32,35,36,37,39,41,48,51,52,53,68,69,73,84,96,99,102,105,106,123,124,153,176,177,181,185,186,193,198,208,209,214 +43.757852 44.698548 216 
0,4,6,10,11,16,22,26,27,32,35,36,37,39,41,48,51,52,53,68,69,73,82,84,96,99,102,105,106,123,124,153,176,177,181,185,186,189,198,208,209,214 +44.698548 45.405170 216 0,4,6,10,16,22,26,27,32,35,36,37,39,41,48,51,52,53,68,69,73,82,84,96,99,102,105,106,123,124,153,176,177,181,185,186,189,198,208,209,214 +45.405170 45.914485 216 0,4,6,10,16,22,26,27,32,35,36,37,39,41,48,51,52,53,55,68,69,73,82,84,96,99,102,105,106,123,124,153,176,177,181,185,186,189,198,208,209,214 +45.914485 46.386608 216 0,4,6,10,16,22,26,27,32,35,36,37,39,41,48,51,52,53,55,68,69,73,82,84,96,102,105,106,123,124,153,168,176,177,181,185,186,189,198,208,209,214 +46.386608 46.492800 216 0,4,6,9,10,16,22,26,27,32,35,36,37,39,41,48,51,52,53,55,68,69,73,82,84,96,102,105,106,124,153,168,176,177,181,185,186,189,198,208,209,214 +46.492800 46.891849 216 0,4,6,10,16,22,26,27,32,35,36,37,39,41,48,51,52,53,55,68,69,73,82,84,96,102,105,106,124,153,168,176,177,181,185,186,189,198,208,209,214 +46.891849 47.640188 216 0,4,6,10,16,22,26,27,32,35,36,37,39,41,48,51,52,53,55,68,69,70,73,82,84,96,102,105,106,124,153,168,176,177,181,185,186,189,198,207,209,214 +47.640188 49.108935 216 0,4,6,10,16,22,26,27,32,35,36,37,39,41,48,51,52,53,55,68,69,70,73,77,82,84,96,102,105,106,124,153,168,176,177,181,185,186,189,198,207,209,214 +49.108935 50.484324 216 0,4,6,10,16,22,26,27,32,35,36,37,39,41,48,51,52,53,55,58,59,68,69,70,73,77,82,84,96,102,105,106,124,153,168,176,177,181,185,186,189,198,209,214 +50.484324 50.640219 216 4,6,10,16,22,26,27,32,35,36,37,39,41,48,51,52,53,55,58,59,68,69,70,73,77,82,84,96,102,105,106,124,153,155,168,176,177,181,185,186,189,198,209,214 +50.640219 51.344527 216 4,6,10,16,22,26,27,32,35,36,37,39,41,48,51,52,55,58,68,69,70,73,77,82,84,96,102,105,106,124,153,155,164,168,176,177,181,185,186,189,198,209,214 +51.344527 51.451422 216 4,6,10,16,17,22,26,27,32,35,36,37,39,41,48,51,52,55,58,68,69,70,73,77,82,84,96,102,105,106,124,153,155,164,168,176,177,181,185,186,189,198,209,214 +51.451422 51.690836 216 
4,6,10,16,17,22,26,27,32,35,36,37,39,41,48,51,55,58,68,69,70,73,77,82,84,96,102,105,106,124,128,153,155,164,168,176,177,181,185,186,189,198,209,214 +51.690836 51.803381 216 4,6,10,16,17,22,26,27,30,32,35,36,37,39,41,48,51,55,58,68,69,70,73,77,82,84,96,102,105,106,124,128,153,155,164,168,176,177,181,184,186,189,198,209,214 +51.803381 52.173268 216 4,6,10,16,17,22,26,27,30,32,35,36,37,39,41,48,51,55,58,68,69,70,73,77,82,96,102,105,106,120,124,128,153,155,164,168,177,181,184,186,189,198,202,209,214 +52.173268 54.284561 216 0,4,6,10,16,17,22,26,27,30,32,35,36,37,39,41,48,51,55,58,68,69,70,73,77,82,96,102,105,106,120,124,128,153,164,168,177,181,184,186,189,198,202,209,214 +54.284561 54.414499 216 0,4,6,10,16,17,22,26,27,30,32,35,36,37,39,41,48,51,55,58,68,69,70,73,77,82,96,102,105,106,124,128,153,164,166,168,177,181,184,186,189,198,202,209,214 +54.414499 54.775535 216 0,4,6,10,16,17,22,26,27,30,32,35,36,37,39,41,48,51,55,58,68,69,70,71,73,77,82,96,102,105,106,124,128,153,164,166,168,177,181,184,186,189,198,202,209,214 +54.775535 55.451624 216 0,4,6,10,16,17,22,26,27,30,32,35,36,37,39,41,48,51,55,58,68,69,70,71,73,77,82,96,102,105,106,124,128,145,153,164,166,177,181,184,186,189,198,202,209,214 +55.451624 57.108058 216 0,4,6,10,16,17,22,26,27,30,32,35,36,37,39,41,48,51,55,58,68,69,70,71,73,77,82,96,102,106,124,128,145,153,164,166,177,181,184,186,189,198,202,209,214 +57.108058 58.916166 216 0,4,6,10,16,17,22,26,27,30,32,35,36,37,39,41,48,51,55,58,68,69,70,71,73,77,82,96,102,106,124,128,145,166,177,181,184,186,189,192,198,202,209,214 +58.916166 60.658379 216 0,4,6,10,16,17,22,26,27,29,30,32,35,36,37,39,41,48,51,55,58,68,69,70,71,73,77,82,96,102,106,124,128,145,166,177,181,184,186,189,192,198,202,209,214 +60.658379 60.905848 216 0,4,6,10,16,17,22,26,27,29,30,32,35,36,37,39,41,48,50,51,55,58,68,69,70,71,73,77,82,96,102,106,124,128,145,166,177,184,186,189,192,198,202,209,214 +60.905848 61.095745 216 
0,4,6,10,16,17,22,26,27,29,30,32,35,36,37,39,41,48,50,51,55,58,68,69,70,71,73,77,82,93,96,102,106,124,128,145,166,177,184,186,189,192,198,202,209,214 +61.095745 61.719977 216 0,4,10,16,17,22,26,27,29,30,32,35,36,37,39,41,48,50,51,55,58,68,69,70,71,73,77,82,93,96,102,106,124,128,145,166,184,186,189,192,198,202,206,209,214 +61.719977 62.308779 216 0,4,10,16,17,22,26,27,29,30,32,35,36,37,39,41,48,50,51,55,58,64,68,69,70,71,73,77,82,93,96,99,102,106,124,128,166,184,186,189,192,198,202,206,209,214 +62.308779 62.577097 216 4,10,16,17,22,26,27,29,30,32,35,36,37,39,41,48,50,51,55,58,64,68,69,70,71,73,77,82,93,96,99,102,106,124,128,166,184,186,189,192,198,206,209,212,214 +62.577097 63.126608 216 4,10,16,17,22,26,27,29,30,32,35,36,37,39,41,48,50,51,55,58,64,68,69,70,71,73,77,82,96,102,106,124,128,129,166,184,186,189,192,198,206,209,212,214 +63.126608 65.841615 216 0,4,10,16,17,22,26,27,29,30,32,35,36,37,39,41,48,50,51,55,58,64,68,69,70,71,73,77,82,96,102,106,124,128,129,166,184,186,189,192,198,202,206,209,214 +65.841615 67.041287 216 0,4,10,16,17,22,26,27,29,30,32,35,36,37,39,41,48,50,51,55,58,64,69,70,71,73,77,82,96,102,106,124,128,129,166,184,186,189,192,198,202,206,209,214 +67.041287 69.373708 216 0,4,10,15,16,17,22,26,27,29,30,32,35,36,37,39,41,48,50,51,55,58,64,69,70,71,73,77,82,96,102,106,124,128,129,166,184,186,189,192,198,202,206,209 +69.373708 71.852095 216 0,4,8,10,15,16,17,22,25,26,27,29,30,32,35,36,37,39,41,48,50,51,55,58,64,69,70,71,73,77,82,96,102,106,124,128,129,166,184,186,192,198,202,206,209 +71.852095 72.191272 216 0,4,8,10,15,16,17,22,25,26,27,29,30,32,35,36,37,39,41,48,50,51,55,58,64,69,70,71,73,77,82,96,102,106,128,129,166,184,186,192,198,202,206,209 +72.191272 73.218934 216 0,4,8,10,15,16,17,22,25,26,27,28,29,30,32,35,36,37,39,41,48,50,51,55,58,64,69,70,71,73,77,82,96,102,106,128,129,166,184,186,192,198,202,206 +73.218934 73.914710 216 
0,4,8,10,15,16,17,22,25,26,27,28,29,30,32,35,36,37,39,41,48,50,51,54,55,58,64,69,70,71,73,77,82,96,102,106,128,129,166,184,186,192,198,202,206 +73.914710 74.244135 216 0,4,8,15,16,17,22,25,26,27,28,29,30,32,35,36,37,39,41,48,50,51,54,55,58,64,69,70,71,73,77,82,96,102,106,128,129,166,184,186,192,198,202,203,206 +74.244135 75.398652 216 0,4,8,15,16,17,22,25,26,27,28,29,30,32,35,36,37,39,41,48,50,51,54,55,58,64,69,70,71,72,73,77,82,96,102,106,128,129,166,184,186,192,198,202,203,206 +75.398652 77.115170 216 0,4,8,15,16,17,22,25,26,27,28,29,30,32,35,36,37,39,41,48,50,51,54,55,58,64,69,70,71,72,73,77,82,102,106,128,129,166,184,186,192,198,202,203,206 +77.115170 78.645366 216 0,4,8,15,16,17,22,25,26,27,28,29,30,32,35,36,37,39,41,48,50,51,54,55,58,64,69,70,71,72,73,77,82,102,106,128,129,141,166,184,186,192,198,202,203,206 +78.645366 78.884401 216 0,4,8,10,15,16,17,22,25,26,27,28,29,30,32,35,36,37,39,41,48,50,51,54,55,58,64,69,70,71,72,73,77,82,102,106,128,129,132,141,166,184,186,192,198,202,206 +78.884401 79.005768 216 0,8,10,15,16,17,22,25,26,27,28,29,30,32,35,36,37,39,41,48,50,51,54,55,58,64,69,70,71,72,73,77,82,102,106,128,129,132,141,166,184,192,198,202,206,215 +79.005768 79.777617 216 0,8,10,15,16,17,22,25,26,27,28,29,30,32,34,35,36,37,39,41,44,48,50,51,54,55,58,64,69,70,71,72,73,77,82,102,106,128,129,132,141,166,184,192,202,206,215 +79.777617 79.893379 216 0,8,10,15,16,17,22,25,26,27,28,29,30,32,34,35,36,37,39,41,44,48,50,51,54,55,58,64,69,70,71,73,77,82,102,106,128,129,132,141,166,180,184,192,202,206,215 +79.893379 80.236310 216 0,8,10,15,16,17,22,25,26,27,28,29,30,32,34,35,36,37,39,41,44,48,50,51,54,55,58,64,65,69,70,71,73,77,82,102,106,128,129,132,141,166,180,184,192,202,206,215 +80.236310 81.076006 216 0,8,10,15,16,17,22,25,26,27,28,29,30,32,34,35,36,37,39,44,48,50,51,54,55,58,64,65,69,70,71,73,77,82,102,106,128,129,132,141,166,180,184,192,202,206,215 +81.076006 81.203782 216 
0,8,10,15,16,17,22,25,26,27,28,29,30,32,34,35,36,37,39,44,48,50,51,54,55,58,64,65,69,70,71,73,77,82,102,106,128,129,132,141,166,180,184,195,202,206,215 +81.203782 81.752010 216 0,8,10,15,16,17,22,25,26,27,28,29,32,34,35,36,37,39,44,48,51,54,55,58,64,65,69,70,71,73,77,82,102,106,128,129,132,141,166,180,184,187,195,202,206,215 +81.752010 82.034281 216 0,8,10,15,16,17,22,25,28,29,32,34,35,36,37,39,44,48,51,54,55,58,64,65,69,70,71,73,77,82,102,106,128,129,132,141,166,180,184,187,195,197,202,206,215 +82.034281 83.122541 216 0,8,10,15,16,17,22,25,28,29,30,32,34,35,36,37,39,44,48,50,51,54,55,58,64,65,69,70,71,73,77,82,102,106,128,129,132,141,166,180,184,195,197,202,206,215 +83.122541 83.394074 216 0,8,10,15,16,17,22,28,29,30,32,34,35,36,37,39,44,48,50,51,54,55,58,64,65,69,70,71,73,77,82,102,106,128,129,132,141,166,180,184,195,197,202,206,215 +83.394074 83.534148 216 0,8,10,14,15,16,17,22,28,29,30,32,34,35,36,37,39,44,48,50,51,54,55,58,64,65,69,70,71,73,77,82,102,106,128,129,132,141,166,180,184,195,197,202,206,215 +83.534148 83.693827 216 0,8,10,14,15,16,17,22,28,29,30,32,34,35,36,37,39,44,48,50,51,54,55,58,64,69,70,71,73,77,82,102,106,128,129,132,140,141,159,166,180,195,197,202,206,215 +83.693827 83.777115 216 0,8,10,14,15,16,17,22,28,29,30,32,34,35,36,37,39,44,48,50,51,54,55,58,64,69,70,71,73,77,82,102,106,128,129,132,133,140,141,159,166,195,197,202,206,215 +83.777115 84.317521 216 0,5,8,10,14,15,16,17,22,28,29,30,32,34,35,36,37,39,44,48,49,50,51,54,55,58,64,69,70,71,73,77,82,102,106,128,129,132,133,141,159,166,195,197,202,206,215 +84.317521 85.706099 216 0,5,8,10,14,15,16,17,22,28,29,30,32,34,35,36,37,39,44,48,49,50,51,54,55,58,64,69,70,71,73,77,82,102,106,128,129,132,133,141,150,159,166,195,197,202,206,215 +85.706099 85.951325 216 0,5,10,14,15,16,17,22,28,29,30,32,34,35,36,37,39,44,48,49,50,51,54,55,58,64,69,70,71,73,77,82,102,106,128,129,132,133,141,150,159,166,195,197,202,204,206,215 +85.951325 86.401076 216 
0,5,10,14,15,16,17,22,28,29,30,32,34,35,36,37,44,48,49,50,51,54,55,58,64,69,70,71,73,77,82,102,106,128,129,132,133,141,150,159,166,195,197,202,204,206,215 +86.401076 86.643371 216 0,5,10,14,15,16,17,22,28,29,30,32,34,35,36,37,44,48,49,50,51,54,55,58,64,69,70,71,73,77,102,106,128,129,132,133,141,150,159,166,195,197,202,204,206,215 +86.643371 87.032912 216 0,5,10,14,15,16,17,22,28,29,30,32,34,35,36,37,44,48,49,50,55,58,64,69,70,71,73,77,102,106,116,128,129,132,133,141,150,159,166,195,197,202,204,206,215 +87.032912 88.701340 216 0,5,10,14,15,16,17,22,28,29,30,32,34,35,36,37,44,48,49,50,55,58,64,69,70,73,77,102,106,116,128,129,132,133,141,150,159,166,195,197,202,204,206,215 +88.701340 88.947794 216 0,5,10,14,15,16,17,22,28,29,30,32,34,35,36,37,44,48,50,55,58,64,69,70,73,77,102,116,128,129,132,133,141,147,150,159,166,195,197,202,204,206,215 +88.947794 89.826128 216 0,5,10,14,15,16,17,22,29,30,32,34,35,36,37,44,48,50,55,58,64,69,70,73,102,116,128,129,132,133,141,147,150,159,166,190,195,197,202,204,206,215 +89.826128 90.247778 216 0,5,10,14,15,16,17,22,25,29,30,32,34,35,36,37,44,48,50,55,58,64,69,70,73,102,116,128,129,132,133,141,147,150,159,166,190,195,197,202,204,206,215 +90.247778 92.349588 216 0,5,10,14,15,16,17,22,25,29,30,32,34,35,36,37,44,48,50,55,58,64,69,70,73,93,99,102,116,128,132,133,141,147,150,159,166,190,195,197,202,204,206,215 +92.349588 92.357754 216 0,5,10,14,15,16,17,22,25,29,30,32,34,35,36,37,44,48,50,55,58,64,69,70,73,93,99,102,113,116,128,132,133,141,147,150,159,166,190,195,197,202,204,206,215 +92.357754 92.895855 216 0,5,10,14,15,16,17,22,25,29,30,32,34,35,36,37,44,48,50,55,64,69,70,73,93,99,102,113,116,128,132,133,141,147,150,159,166,190,195,197,202,204,206,215 +92.895855 93.565018 216 0,5,10,14,15,16,17,22,25,29,30,32,34,35,36,37,44,48,50,55,69,70,73,93,99,102,113,116,128,132,133,141,147,150,159,166,190,195,197,202,204,206,215 +93.565018 93.764750 216 
0,5,10,15,16,17,22,25,29,30,32,34,35,36,37,44,48,50,55,69,70,73,93,99,102,113,116,128,132,133,141,147,150,159,166,188,190,195,197,202,204,206,215 +93.764750 93.836844 216 0,5,10,15,16,17,22,25,29,30,32,34,35,36,37,44,48,50,55,66,69,70,73,93,99,102,113,116,128,132,133,141,147,150,159,166,188,190,195,197,202,204,206,215 +93.836844 94.444668 216 0,5,10,15,16,17,22,25,29,32,34,35,37,44,48,50,55,66,69,70,73,93,99,102,113,116,128,132,133,141,147,150,159,166,187,188,190,195,197,202,204,206,215 +94.444668 94.869872 216 0,5,10,15,16,17,22,25,29,32,34,35,37,44,48,50,55,66,69,70,73,84,93,99,100,102,113,116,128,132,133,141,147,150,159,166,187,188,190,195,197,204,206,215 +94.869872 94.899812 216 0,5,10,15,16,17,22,25,29,32,34,35,37,44,48,50,55,66,69,70,73,84,93,99,100,113,116,128,132,133,141,147,150,159,166,187,188,190,195,197,204,206,215 +94.899812 95.255452 216 0,5,10,14,15,16,17,22,25,29,32,34,35,37,44,48,50,55,66,69,70,73,84,90,93,99,100,113,116,128,132,133,141,147,150,159,166,187,190,195,197,204,206,215 +95.255452 95.899004 216 0,5,10,14,15,16,17,22,25,29,32,34,35,37,44,48,50,54,55,66,69,70,73,84,90,93,99,100,113,128,132,133,141,147,150,159,166,187,190,195,197,204,206,215 +95.899004 99.132998 216 0,5,10,14,15,16,17,22,25,26,27,29,32,34,35,37,44,48,50,54,55,66,69,70,73,84,90,93,99,100,113,128,132,133,141,147,150,159,166,187,190,195,204,206,215 +99.132998 99.468331 216 0,5,10,14,15,16,17,22,25,26,29,32,34,35,37,44,48,50,54,55,66,69,70,73,84,90,93,99,100,113,128,132,133,141,147,150,159,166,187,190,195,204,206,215 +99.468331 99.753960 216 0,4,5,10,14,15,16,17,22,25,26,29,32,34,35,37,44,48,50,54,55,66,69,70,73,84,90,93,99,100,113,128,132,133,141,147,150,159,166,186,187,190,195,204,206 +99.753960 100.000000 216 0,4,5,10,14,15,16,17,22,25,26,29,32,34,35,37,44,48,50,54,55,66,69,70,73,84,90,99,100,113,128,132,133,141,147,150,159,166,182,186,187,190,195,204,206 diff --git a/python/tests/data/simplify-bugs/02-mutations.txt b/python/tests/data/simplify-bugs/02-mutations.txt new file 
mode 100644 index 0000000000..789b7b2e1c --- /dev/null +++ b/python/tests/data/simplify-bugs/02-mutations.txt @@ -0,0 +1 @@ +site node derived_state diff --git a/python/tests/data/simplify-bugs/02-nodes.txt b/python/tests/data/simplify-bugs/02-nodes.txt new file mode 100644 index 0000000000..f09836feb1 --- /dev/null +++ b/python/tests/data/simplify-bugs/02-nodes.txt @@ -0,0 +1,218 @@ +is_sample time +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +1 0.000000 +0 0.000194 +0 0.000317 +0 0.000403 +0 0.000539 +0 0.001031 +0 0.001435 +0 0.001762 +0 0.001774 +0 0.001809 +0 0.002119 +0 0.002788 +0 0.002811 +0 0.003626 +0 0.003640 +0 0.003920 +0 0.003996 +0 0.004180 +0 0.004187 +0 0.004326 +0 0.004453 +0 0.005014 +0 0.005035 +0 0.005512 +0 0.005679 +0 0.005842 +0 0.006024 +0 0.006182 +0 0.006282 +0 0.006540 +0 0.006850 +0 0.006989 +0 0.007400 +0 0.007440 
+0 0.007559 +0 0.007880 +0 0.008043 +0 0.008337 +0 0.008406 +0 0.008968 +0 0.009216 +0 0.009236 +0 0.009300 +0 0.010000 +0 0.010592 +0 0.011448 +0 0.011471 +0 0.011991 +0 0.012237 +0 0.012290 +0 0.012429 +0 0.012484 +0 0.013078 +0 0.013189 +0 0.014031 +0 0.014208 +0 0.014449 +0 0.014731 +0 0.015388 +0 0.015556 +0 0.015588 +0 0.015727 +0 0.015773 +0 0.015945 +0 0.016374 +0 0.016542 +0 0.016560 +0 0.016713 +0 0.017029 +0 0.017180 +0 0.017280 +0 0.017546 +0 0.017637 +0 0.017806 +0 0.017943 +0 0.017983 +0 0.018078 +0 0.018319 +0 0.018490 +0 0.018598 +0 0.018688 +0 0.019008 +0 0.019012 +0 0.019112 +0 0.019190 +0 0.019191 +0 0.019477 +0 0.020000 +0 0.020659 +0 0.020952 +0 0.021267 +0 0.021289 +0 0.021641 +0 0.021823 +0 0.022321 +0 0.022553 +0 0.022602 +0 0.023120 +0 0.023233 +0 0.024210 +0 0.024342 +0 0.024893 +0 0.024922 +0 0.024934 +0 0.025736 +0 0.025806 +0 0.025938 +0 0.026345 +0 0.026486 +0 0.026561 +0 0.026877 +0 0.027657 +0 0.028587 +0 0.029557 +0 0.029563 +0 0.029588 +0 0.029963 +0 0.030000 diff --git a/python/tests/data/simplify-bugs/02-sites.txt b/python/tests/data/simplify-bugs/02-sites.txt new file mode 100644 index 0000000000..3f12d4059a --- /dev/null +++ b/python/tests/data/simplify-bugs/02-sites.txt @@ -0,0 +1 @@ +position ancestral_state diff --git a/python/tests/data/simplify-bugs/03-edges.txt b/python/tests/data/simplify-bugs/03-edges.txt new file mode 100644 index 0000000000..1a6181527c --- /dev/null +++ b/python/tests/data/simplify-bugs/03-edges.txt @@ -0,0 +1,16 @@ +left right parent child +0.000000 10000.000000 50 29,31 +0.000000 10000.000000 51 11,15 +0.000000 1554.123401 52 1,51 +1554.123401 10000.000000 52 1 +0.000000 1736.203571 53 52 +1736.203571 10000.000000 53 51,52 +0.000000 10000.000000 54 4,12,27,38,39,40 +0.000000 10000.000000 55 17,25,45,48,49,50 +0.000000 10000.000000 56 24,55 +0.000000 1554.123401 57 56 +1554.123401 1736.203571 57 51,56 +1736.203571 10000.000000 57 56 +0.000000 10000.000000 58 0,13,22,57 +0.000000 10000.000000 59 
2,3,5,6,7,8,9,10,14,16,18,19,20,21,23,26,28,30,32,33,34,35,36,37,41,42,43,44,46,47,53,54,58 +0.000000 10000.000000 60 59 diff --git a/python/tests/data/simplify-bugs/03-mutations.txt b/python/tests/data/simplify-bugs/03-mutations.txt new file mode 100644 index 0000000000..0fc3b2073a --- /dev/null +++ b/python/tests/data/simplify-bugs/03-mutations.txt @@ -0,0 +1,14 @@ +site node derived_state +0 52 1 +1 34 1 +2 57 1 +2 3 1 +3 58 1 +4 34 1 +5 56 1 +6 55 1 +6 1 1 +7 51 1 +8 43 1 +9 54 1 +9 0 1 diff --git a/python/tests/data/simplify-bugs/03-nodes.txt b/python/tests/data/simplify-bugs/03-nodes.txt new file mode 100644 index 0000000000..93ab916666 --- /dev/null +++ b/python/tests/data/simplify-bugs/03-nodes.txt @@ -0,0 +1,62 @@ +is_sample time population +1 0.000000 -1 +1 0.000000 -1 +1 0.000000 -1 +1 0.000000 -1 +1 0.000000 -1 +1 0.000000 -1 +1 0.000000 -1 +1 0.000000 -1 +1 0.000000 -1 +1 0.000000 -1 +1 0.000000 -1 +1 0.000000 -1 +1 0.000000 -1 +1 0.000000 -1 +1 0.000000 -1 +1 0.000000 -1 +1 0.000000 -1 +1 0.000000 -1 +1 0.000000 -1 +1 0.000000 -1 +1 0.000000 -1 +1 0.000000 -1 +1 0.000000 -1 +1 0.000000 -1 +1 0.000000 -1 +1 0.000000 -1 +1 0.000000 -1 +1 0.000000 -1 +1 0.000000 -1 +1 0.000000 -1 +1 0.000000 -1 +1 0.000000 -1 +1 0.000000 -1 +1 0.000000 -1 +1 0.000000 -1 +1 0.000000 -1 +1 0.000000 -1 +1 0.000000 -1 +1 0.000000 -1 +1 0.000000 -1 +1 0.000000 -1 +1 0.000000 -1 +1 0.000000 -1 +1 0.000000 -1 +1 0.000000 -1 +1 0.000000 -1 +1 0.000000 -1 +1 0.000000 -1 +1 0.000000 -1 +1 0.000000 -1 +0 50.000000 -1 +0 51.000000 -1 +0 52.000000 -1 +0 53.000000 -1 +0 54.000000 -1 +0 55.000000 -1 +0 56.000000 -1 +0 57.000000 -1 +0 58.000000 -1 +0 59.000000 -1 +0 60.000000 -1 diff --git a/python/tests/data/simplify-bugs/03-sites.txt b/python/tests/data/simplify-bugs/03-sites.txt new file mode 100644 index 0000000000..bd2b03ebe8 --- /dev/null +++ b/python/tests/data/simplify-bugs/03-sites.txt @@ -0,0 +1,11 @@ +position ancestral_state +284.252209 0 +1313.686815 0 +1554.123401 0 
+1736.203571 0 +3310.290546 0 +4208.672558 0 +4995.288904 0 +5187.559857 0 +5211.162157 0 +5483.889413 0 diff --git a/python/tests/data/simplify-bugs/04-edges.txt b/python/tests/data/simplify-bugs/04-edges.txt new file mode 100644 index 0000000000..1849beb940 --- /dev/null +++ b/python/tests/data/simplify-bugs/04-edges.txt @@ -0,0 +1,14 @@ +left right parent child +0.000000 0.500000 6 0,1 +0.500000 1.000000 6 4,5 +0.000000 0.400000 7 2,3 +0.000000 0.500000 8 4,5 +0.500000 1.000000 8 0,1 +0.400000 1.000000 9 2,3 +0.400000 1.000000 10 8,9 +0.000000 0.100000 13 6,14 +0.100000 0.400000 15 7,14 +0.000000 0.100000 11 7,13 +0.100000 0.400000 11 6,15 +0.000000 0.400000 12 8,11 +0.400000 1.000000 12 6,10 diff --git a/python/tests/data/simplify-bugs/04-mutations.txt b/python/tests/data/simplify-bugs/04-mutations.txt new file mode 100644 index 0000000000..789b7b2e1c --- /dev/null +++ b/python/tests/data/simplify-bugs/04-mutations.txt @@ -0,0 +1 @@ +site node derived_state diff --git a/python/tests/data/simplify-bugs/04-nodes.txt b/python/tests/data/simplify-bugs/04-nodes.txt new file mode 100644 index 0000000000..1154050df3 --- /dev/null +++ b/python/tests/data/simplify-bugs/04-nodes.txt @@ -0,0 +1,17 @@ +is_sample time population +1 0.000000 -1 +1 0.000000 -1 +1 0.000000 -1 +1 0.000000 -1 +1 0.000000 -1 +1 0.000000 -1 +0 1.000000 -1 +0 1.000000 -1 +0 1.000000 -1 +0 1.000000 -1 +0 2.000000 -1 +0 3.000000 -1 +0 4.000000 -1 +0 2.000000 -1 +0 1.000000 -1 +0 2.000000 -1 diff --git a/python/tests/data/simplify-bugs/04-sites.txt b/python/tests/data/simplify-bugs/04-sites.txt new file mode 100644 index 0000000000..3f12d4059a --- /dev/null +++ b/python/tests/data/simplify-bugs/04-sites.txt @@ -0,0 +1 @@ +position ancestral_state diff --git a/python/tests/data/simplify-bugs/05-edges.txt b/python/tests/data/simplify-bugs/05-edges.txt new file mode 100644 index 0000000000..96f2c53d1f --- /dev/null +++ b/python/tests/data/simplify-bugs/05-edges.txt @@ -0,0 +1,13 @@ +left right parent 
child +0.0 0.8 5 9 +0.3 1.0 5 10 +0.0 1.0 6 8 +0.0 0.3 6 10 +0.0 0.9 7 11 +0.0 1.0 7 12 +0.8 1.0 7 9 +0.9 1.0 1 11 +0.4 1.0 1 6 +0.0 0.4 4 6 +0.0 1.0 4 7 +0.0 1.0 0 1,2,4,5 diff --git a/python/tests/data/simplify-bugs/05-mutations.txt b/python/tests/data/simplify-bugs/05-mutations.txt new file mode 100644 index 0000000000..789b7b2e1c --- /dev/null +++ b/python/tests/data/simplify-bugs/05-mutations.txt @@ -0,0 +1 @@ +site node derived_state diff --git a/python/tests/data/simplify-bugs/05-nodes.txt b/python/tests/data/simplify-bugs/05-nodes.txt new file mode 100644 index 0000000000..0988a2e2db --- /dev/null +++ b/python/tests/data/simplify-bugs/05-nodes.txt @@ -0,0 +1,14 @@ +id is_sample population time +0 0 0 6.0 +1 0 0 2.0 +2 0 0 2.0 +3 0 0 2.0 +4 0 0 2.0 +5 0 0 1.0 +6 0 0 1.0 +7 0 0 1.0 +8 1 0 0.0 +9 1 0 0.0 +10 1 0 0.0 +11 1 0 0.0 +12 1 0 0.0 diff --git a/python/tests/data/simplify-bugs/05-sites.txt b/python/tests/data/simplify-bugs/05-sites.txt new file mode 100644 index 0000000000..3f12d4059a --- /dev/null +++ b/python/tests/data/simplify-bugs/05-sites.txt @@ -0,0 +1 @@ +position ancestral_state diff --git a/python/tests/simplify.py b/python/tests/simplify.py new file mode 100644 index 0000000000..8967783f6b --- /dev/null +++ b/python/tests/simplify.py @@ -0,0 +1,439 @@ +""" +Python implementation of the simplify algorithm. +""" +from __future__ import print_function +from __future__ import unicode_literals +from __future__ import division + +import sys + +import numpy as np + +import tskit + + +def overlapping_segments(segments): + """ + Returns an iterator over the (left, right, X) tuples describing the + distinct overlapping segments in the specified set. + """ + S = sorted(segments, key=lambda x: x.left) + n = len(S) + # Insert a sentinel at the end for convenience. 
+    S.append(Segment(sys.float_info.max, 0))
+    right = S[0].left
+    X = []
+    j = 0
+    while j < n:
+        # Remove any elements of X with right <= left
+        left = right
+        X = [x for x in X if x.right > left]
+        if len(X) == 0:
+            left = S[j].left
+        while j < n and S[j].left == left:
+            X.append(S[j])
+            j += 1
+        j -= 1
+        right = min(x.right for x in X)
+        right = min(right, S[j + 1].left)
+        yield left, right, X
+        j += 1
+
+    while len(X) > 0:
+        left = right
+        X = [x for x in X if x.right > left]
+        if len(X) > 0:
+            right = min(x.right for x in X)
+            yield left, right, X
+
+
+class Segment(object):
+    """
+    A class representing a single segment. Each segment has a left and right,
+    denoting the loci over which it spans, a node and a next, giving the next
+    in the chain.
+
+    The node it records is the *output* node ID.
+    """
+    def __init__(self, left=None, right=None, node=None, next=None):
+        self.left = left
+        self.right = right
+        self.node = node
+        self.next = next
+
+    def __str__(self):
+        s = "({}-{}->{}:next={})".format(
+            self.left, self.right, self.node, repr(self.next))
+        return s
+
+    def __repr__(self):
+        return repr((self.left, self.right, self.node))
+
+    def __lt__(self, other):
+        # NOTE(review): fixed — the right-hand tuple previously ended with
+        # self.node, so ties on (left, right) compared a node with itself.
+        return (self.left, self.right, self.node) < (other.left, other.right, other.node)
+
+
+class Simplifier(object):
+    """
+    Simplifies a tree sequence to its minimal representation given a subset
+    of the leaves.
+ """ + def __init__( + self, ts, sample, reduce_to_site_topology=False, filter_sites=True, + filter_populations=True, filter_individuals=True): + self.ts = ts + self.n = len(sample) + self.reduce_to_site_topology = reduce_to_site_topology + self.sequence_length = ts.sequence_length + self.filter_sites = filter_sites + self.filter_populations = filter_populations + self.filter_individuals = filter_individuals + self.num_mutations = ts.num_mutations + self.input_sites = list(ts.sites()) + self.A_head = [None for _ in range(ts.num_nodes)] + self.A_tail = [None for _ in range(ts.num_nodes)] + self.tables = tskit.TableCollection(sequence_length=ts.sequence_length) + self.edge_buffer = {} + self.node_id_map = np.zeros(ts.num_nodes, dtype=np.int32) - 1 + self.mutation_node_map = [-1 for _ in range(self.num_mutations)] + self.samples = set(sample) + for sample_id in sample: + output_id = self.record_node(sample_id, is_sample=True) + self.add_ancestry(sample_id, 0, self.sequence_length, output_id) + # We keep a map of input nodes to mutations. + self.mutation_map = [[] for _ in range(ts.num_nodes)] + position = ts.tables.sites.position + site = ts.tables.mutations.site + node = ts.tables.mutations.node + for mutation_id in range(ts.num_mutations): + site_position = position[site[mutation_id]] + self.mutation_map[node[mutation_id]].append((site_position, mutation_id)) + self.position_lookup = None + if self.reduce_to_site_topology: + self.position_lookup = np.hstack([[0], position, [self.sequence_length]]) + + def record_node(self, input_id, is_sample=False): + """ + Adds a new node to the output table corresponding to the specified input + node ID. 
+ """ + node = self.ts.node(input_id) + flags = node.flags + # Need to zero out the sample flag + flags &= ~tskit.NODE_IS_SAMPLE + if is_sample: + flags |= tskit.NODE_IS_SAMPLE + output_id = self.tables.nodes.add_row( + flags=flags, time=node.time, population=node.population, + metadata=node.metadata, individual=node.individual) + self.node_id_map[input_id] = output_id + return output_id + + def rewind_node(self, input_id, output_id): + """ + Remove the mapping for the specified input and output node pair. This is + done because there are no edges referring to the node. + """ + assert output_id == len(self.tables.nodes) - 1 + assert output_id == self.node_id_map[input_id] + self.tables.nodes.truncate(output_id) + self.node_id_map[input_id] = -1 + + def flush_edges(self): + """ + Flush the edges to the output table after sorting and squashing + any redundant records. + """ + num_edges = 0 + for child in sorted(self.edge_buffer.keys()): + for edge in self.edge_buffer[child]: + self.tables.edges.add_row(edge.left, edge.right, edge.parent, edge.child) + num_edges += 1 + self.edge_buffer.clear() + return num_edges + + def record_edge(self, left, right, parent, child): + """ + Adds an edge to the output list. + """ + if self.reduce_to_site_topology: + X = self.position_lookup + left_index = np.searchsorted(X, left) + right_index = np.searchsorted(X, right) + # Find the smallest site position index greater than or equal to left + # and right, i.e., slide each endpoint of an interval to the right + # until they hit a site position. If both left and right map to the + # the same position then we discard this edge. We also discard an + # edge if left = 0 and right is less than the first site position. + if left_index == right_index or (left_index == 0 and right_index == 1): + return + # Remap back to zero if the left end maps to the first site. 
+ if left_index == 1: + left_index = 0 + left = X[left_index] + right = X[right_index] + if child not in self.edge_buffer: + self.edge_buffer[child] = [tskit.Edge(left, right, parent, child)] + else: + last = self.edge_buffer[child][-1] + if last.right == left: + last.right = right + else: + self.edge_buffer[child].append(tskit.Edge(left, right, parent, child)) + + def print_state(self): + print(".................") + print("Ancestors: ") + num_nodes = len(self.A_tail) + for j in range(num_nodes): + print("\t", j, "->", end="") + x = self.A_head[j] + while x is not None: + print("({}-{}->{})".format(x.left, x.right, x.node), end="") + x = x.next + print() + print("Mutation map:") + for u in range(len(self.mutation_map)): + v = self.mutation_map[u] + if len(v) > 0: + print("\t", u, "->", v) + print("Node ID map: (input->output)") + for input_id, output_id in enumerate(self.node_id_map): + print("\t", input_id, "->", output_id) + print("Mutation node map") + for j in range(self.num_mutations): + print("\t", j, "->", self.mutation_node_map[j]) + print("Output:") + print(self.tables) + self.check_state() + + def add_ancestry(self, input_id, left, right, node): + tail = self.A_tail[input_id] + if tail is None: + x = Segment(left, right, node) + self.A_head[input_id] = x + self.A_tail[input_id] = x + else: + if tail.right == left and tail.node == node: + tail.right = right + else: + x = Segment(left, right, node) + tail.next = x + self.A_tail[input_id] = x + + def merge_labeled_ancestors(self, S, input_id): + """ + All ancestry segments in S come together into a new parent. + The new parent must be assigned and any overlapping segments coalesced. + """ + output_id = self.node_id_map[input_id] + is_sample = output_id != -1 + if is_sample: + # Free up the existing ancestry mapping. 
+ x = self.A_tail[input_id] + assert x.left == 0 and x.right == self.sequence_length + self.A_tail[input_id] = None + self.A_head[input_id] = None + + prev_right = 0 + for left, right, X in overlapping_segments(S): + if len(X) == 1: + ancestry_node = X[0].node + if is_sample: + self.record_edge(left, right, output_id, ancestry_node) + ancestry_node = output_id + else: + if output_id == -1: + output_id = self.record_node(input_id) + ancestry_node = output_id + for x in X: + self.record_edge(left, right, output_id, x.node) + if is_sample and left != prev_right: + # Fill in any gaps in the ancestry for the sample + self.add_ancestry(input_id, prev_right, left, output_id) + self.add_ancestry(input_id, left, right, ancestry_node) + prev_right = right + + if is_sample and prev_right != self.sequence_length: + # If a trailing gap exists in the sample ancestry, fill it in. + self.add_ancestry(input_id, prev_right, self.sequence_length, output_id) + if output_id != -1: + num_edges = self.flush_edges() + if num_edges == 0 and not is_sample: + self.rewind_node(input_id, output_id) + + def process_parent_edges(self, edges): + """ + Process all of the edges for a given parent. + """ + assert len(set(e.parent for e in edges)) == 1 + parent = edges[0].parent + S = [] + for edge in edges: + x = self.A_head[edge.child] + while x is not None: + if x.right > edge.left and edge.right > x.left: + y = Segment(max(x.left, edge.left), min(x.right, edge.right), x.node) + S.append(y) + x = x.next + self.merge_labeled_ancestors(S, parent) + self.check_state() + # self.print_state() + + def finalise_sites(self): + # Build a map from the old mutation IDs to new IDs. Any mutation that + # has not been mapped to a node in the new tree sequence will be removed. 
+ mutation_id_map = [-1 for _ in range(self.num_mutations)] + num_output_mutations = 0 + + for site in self.ts.sites(): + num_output_site_mutations = 0 + for mut in site.mutations: + mapped_node = self.mutation_node_map[mut.id] + mapped_parent = -1 + if mut.parent != -1: + mapped_parent = mutation_id_map[mut.parent] + if mapped_node != -1: + mutation_id_map[mut.id] = num_output_mutations + num_output_mutations += 1 + num_output_site_mutations += 1 + output_site = True + if self.filter_sites and num_output_site_mutations == 0: + output_site = False + + if output_site: + for mut in site.mutations: + if mutation_id_map[mut.id] != -1: + mapped_parent = -1 + if mut.parent != -1: + mapped_parent = mutation_id_map[mut.parent] + self.tables.mutations.add_row( + site=len(self.tables.sites), + node=self.mutation_node_map[mut.id], + parent=mapped_parent, + derived_state=mut.derived_state, + metadata=mut.metadata) + self.tables.sites.add_row( + position=site.position, ancestral_state=site.ancestral_state, + metadata=site.metadata) + + def map_mutation_nodes(self): + for input_node in range(len(self.mutation_map)): + mutations = self.mutation_map[input_node] + seg = self.A_head[input_node] + m_index = 0 + while seg is not None and m_index < len(mutations): + x, mutation_id = mutations[m_index] + if seg.left <= x < seg.right: + self.mutation_node_map[mutation_id] = seg.node + m_index += 1 + elif x >= seg.right: + seg = seg.next + else: + assert x < seg.left + m_index += 1 + + def finalise_references(self): + input_populations = self.ts.tables.populations + population_id_map = np.arange(len(input_populations) + 1, dtype=np.int32) + # Trick to ensure the null population gets mapped to itself. + population_id_map[-1] = -1 + input_individuals = self.ts.tables.individuals + individual_id_map = np.arange(len(input_individuals) + 1, dtype=np.int32) + # Trick to ensure the null individual gets mapped to itself. 
+ individual_id_map[-1] = -1 + + population_ref_count = np.ones(len(input_populations), dtype=int) + if self.filter_populations: + population_ref_count[:] = 0 + population_id_map[:] = -1 + individual_ref_count = np.ones(len(input_individuals), dtype=int) + if self.filter_individuals: + individual_ref_count[:] = 0 + individual_id_map[:] = -1 + + for node in self.tables.nodes: + if self.filter_populations and node.population != tskit.NULL: + population_ref_count[node.population] += 1 + if self.filter_individuals and node.individual != tskit.NULL: + individual_ref_count[node.individual] += 1 + + for input_id, count in enumerate(population_ref_count): + if count > 0: + row = input_populations[input_id] + output_id = self.tables.populations.add_row(metadata=row.metadata) + population_id_map[input_id] = output_id + for input_id, count in enumerate(individual_ref_count): + if count > 0: + row = input_individuals[input_id] + output_id = self.tables.individuals.add_row( + flags=row.flags, location=row.location, metadata=row.metadata) + individual_id_map[input_id] = output_id + + # Remap the population ID references for nodes. + nodes = self.tables.nodes + nodes.set_columns( + flags=nodes.flags, + time=nodes.time, + metadata=nodes.metadata, + metadata_offset=nodes.metadata_offset, + individual=individual_id_map[nodes.individual], + population=population_id_map[nodes.population]) + + # We don't support migrations for now. We'll need to remap these as well. 
+ assert self.ts.num_migrations == 0 + + def simplify(self): + # self.print_state() + if self.ts.num_edges > 0: + all_edges = list(self.ts.edges()) + edges = all_edges[:1] + for e in all_edges[1:]: + if e.parent != edges[0].parent: + self.process_parent_edges(edges) + edges = [] + edges.append(e) + self.process_parent_edges(edges) + # self.print_state() + self.map_mutation_nodes() + self.finalise_sites() + self.finalise_references() + ts = self.tables.tree_sequence() + return ts, self.node_id_map + + def check_state(self): + num_nodes = len(self.A_head) + for j in range(num_nodes): + head = self.A_head[j] + tail = self.A_tail[j] + if head is None: + assert tail is None + else: + x = head + while x.next is not None: + x = x.next + assert x == tail + x = head.next + while x is not None: + assert x.left < x.right + if x.next is not None: + assert x.right <= x.next.left + # We should also not have any squashable segments. + if x.right == x.next.left: + assert x.node != x.next.node + x = x.next + + +if __name__ == "__main__": + # Simple CLI for running simplifier above. + ts = tskit.load(sys.argv[1]) + samples = list(map(int, sys.argv[2:])) + s = Simplifier(ts, samples) + # s.print_state() + tss, _ = s.simplify() + tables = tss.dump_tables() + print("Output:") + print(tables.nodes) + print(tables.edges) + print(tables.sites) + print(tables.mutations) diff --git a/python/tests/test_dict_encoding.py b/python/tests/test_dict_encoding.py new file mode 100644 index 0000000000..25d948c065 --- /dev/null +++ b/python/tests/test_dict_encoding.py @@ -0,0 +1,334 @@ +""" +Test cases for the low-level dictionary encoding used to move +data around in C. +""" +from __future__ import print_function +from __future__ import division + +import unittest + +import msprime +import numpy as np + +import _tskit as c_module +import tskit + + +def get_example_tables(): + """ + Return a tree sequence that has data in all fields. 
+ """ + pop_configs = [msprime.PopulationConfiguration(5) for _ in range(2)] + migration_matrix = [[0, 1], [1, 0]] + ts = msprime.simulate( + population_configurations=pop_configs, + migration_matrix=migration_matrix, + mutation_rate=1, + record_migrations=True, + random_seed=1) + + tables = ts.dump_tables() + for j in range(ts.num_samples): + tables.individuals.add_row(flags=j, location=np.arange(j), metadata=b"x" * j) + tables.nodes.clear() + for node in ts.nodes(): + tables.nodes.add_row( + flags=node.flags, time=node.time, population=node.population, + individual=node.id if node.id < ts.num_samples else -1, + metadata=b"y" * node.id) + tables.sites.clear() + for site in ts.sites(): + tables.sites.add_row( + position=site.position, ancestral_state="A" * site.id, + metadata=b"q" * site.id) + tables.mutations.clear() + for mutation in ts.mutations(): + mut_id = tables.mutations.add_row( + site=mutation.site, node=mutation.node, parent=-1, + derived_state="C" * mutation.id, metadata=b"x" * mutation.id) + # Add another mutation on the same branch. + tables.mutations.add_row( + site=mutation.site, node=mutation.node, parent=mut_id, + derived_state="G" * mutation.id, metadata=b"y" * mutation.id) + for j in range(10): + tables.populations.add_row(metadata=b"p" * j) + tables.provenances.add_row(timestamp="x" * j, record="y" * j) + return tables + + +class TestRoundTrip(unittest.TestCase): + """ + Tests if we can do a simple round trip on simulated data. 
+ """ + def verify(self, tables): + lwt = c_module.LightweightTableCollection() + lwt.fromdict(tables.asdict()) + other_tables = tskit.TableCollection.fromdict(lwt.asdict()) + self.assertEqual(tables, other_tables) + + def test_simple(self): + ts = msprime.simulate(10, mutation_rate=1, random_seed=2) + self.verify(ts.tables) + + def test_empty(self): + tables = tskit.TableCollection(sequence_length=1) + self.verify(tables) + + def test_individuals(self): + n = 10 + ts = msprime.simulate(n, mutation_rate=1, random_seed=2) + tables = ts.dump_tables() + for j in range(n): + tables.individuals.add_row(flags=j, location=(j, j), metadata=b"x" * j) + self.verify(tables) + + def test_sequence_length(self): + ts = msprime.simulate( + 10, recombination_rate=0.1, mutation_rate=1, length=0.99, random_seed=2) + self.verify(ts.tables) + + def test_migration(self): + pop_configs = [msprime.PopulationConfiguration(5) for _ in range(2)] + migration_matrix = [[0, 1], [1, 0]] + ts = msprime.simulate( + population_configurations=pop_configs, + migration_matrix=migration_matrix, + mutation_rate=1, + record_migrations=True, + random_seed=1) + self.verify(ts.tables) + + def test_example(self): + self.verify(get_example_tables()) + + +class TestMissingData(unittest.TestCase): + """ + Tests what happens when we have missing data in the encoded dict. 
+ """ + def test_missing_sequence_length(self): + tables = get_example_tables() + d = tables.asdict() + del d["sequence_length"] + lwt = c_module.LightweightTableCollection() + with self.assertRaises(ValueError): + lwt.fromdict(d) + + def test_missing_tables(self): + tables = get_example_tables() + d = tables.asdict() + table_names = set(d.keys()) - set(["sequence_length"]) + for table_name in table_names: + d = tables.asdict() + del d[table_name] + lwt = c_module.LightweightTableCollection() + with self.assertRaises(ValueError): + lwt.fromdict(d) + + def test_missing_columns(self): + tables = get_example_tables() + d = tables.asdict() + table_names = set(d.keys()) - set(["sequence_length"]) + for table_name in table_names: + table_dict = d[table_name] + for colname in table_dict.keys(): + copy = dict(table_dict) + del copy[colname] + lwt = c_module.LightweightTableCollection() + d = tables.asdict() + d[table_name] = copy + with self.assertRaises(ValueError): + lwt.fromdict(d) + + +class TestBadTypes(unittest.TestCase): + """ + Tests for setting each column to a type that can't be converted to 1D numpy array. 
+ """ + + def verify_columns(self, value): + tables = get_example_tables() + d = tables.asdict() + table_names = set(d.keys()) - set(["sequence_length"]) + for table_name in table_names: + table_dict = d[table_name] + for colname in table_dict.keys(): + copy = dict(table_dict) + copy[colname] = value + lwt = c_module.LightweightTableCollection() + d = tables.asdict() + d[table_name] = copy + with self.assertRaises(ValueError): + lwt.fromdict(d) + + def test_2d_array(self): + self.verify_columns([[1, 2], [3, 4]]) + + def test_str(self): + self.verify_columns("aserg") + + def test_bad_top_level_types(self): + tables = get_example_tables() + d = tables.asdict() + for key in d.keys(): + bad_type_dict = tables.asdict() + # A list should be a ValueError for both the tables and sequence_length + bad_type_dict[key] = ["12345"] + lwt = c_module.LightweightTableCollection() + with self.assertRaises(TypeError): + lwt.fromdict(bad_type_dict) + + +class TestBadLengths(unittest.TestCase): + """ + Tests for setting each column to a length incompatible with the table. + """ + def verify(self, num_rows): + + tables = get_example_tables() + d = tables.asdict() + table_names = set(d.keys()) - set(["sequence_length"]) + for table_name in sorted(table_names): + table_dict = d[table_name] + for colname in sorted(table_dict.keys()): + copy = dict(table_dict) + copy[colname] = table_dict[colname][:num_rows].copy() + lwt = c_module.LightweightTableCollection() + d = tables.asdict() + d[table_name] = copy + with self.assertRaises(ValueError): + lwt.fromdict(d) + + def test_two_rows(self): + self.verify(2) + + def test_zero_rows(self): + self.verify(0) + + +class TestRequiredAndOptionalColumns(unittest.TestCase): + """ + Tests that specifying None for some columns will give the intended + outcome. 
+ """ + def verify_required_columns(self, tables, table_name, required_cols): + d = tables.asdict() + table_dict = {col: None for col in d[table_name].keys()} + for col in required_cols: + table_dict[col] = d[table_name][col] + lwt = c_module.LightweightTableCollection() + d[table_name] = table_dict + lwt.fromdict(d) + other = lwt.asdict() + for col in required_cols: + self.assertTrue(np.array_equal(other[table_name][col], table_dict[col])) + + # Removing any one of these required columns gives an error. + for col in required_cols: + d = tables.asdict() + copy = dict(table_dict) + copy[col] = None + d[table_name] = copy + lwt = c_module.LightweightTableCollection() + with self.assertRaises(TypeError): + lwt.fromdict(d) + + def verify_optional_column(self, tables, table_len, table_name, col_name): + d = tables.asdict() + table_dict = d[table_name] + table_dict[col_name] = None + lwt = c_module.LightweightTableCollection() + lwt.fromdict(d) + out = lwt.asdict() + self.assertTrue(np.array_equal( + out[table_name][col_name], + np.zeros(table_len, dtype=np.int32) - 1)) + + def verify_offset_pair(self, tables, table_len, table_name, col_name): + offset_col = col_name + "_offset" + + d = tables.asdict() + table_dict = d[table_name] + table_dict[col_name] = None + table_dict[offset_col] = None + lwt = c_module.LightweightTableCollection() + lwt.fromdict(d) + out = lwt.asdict() + self.assertEqual(out[table_name][col_name].shape, (0,)) + self.assertTrue(np.array_equal( + out[table_name][offset_col], + np.zeros(table_len + 1, dtype=np.uint32))) + + # Setting one or the other raises a ValueError + d = tables.asdict() + table_dict = d[table_name] + table_dict[col_name] = None + lwt = c_module.LightweightTableCollection() + with self.assertRaises(TypeError): + lwt.fromdict(d) + + d = tables.asdict() + table_dict = d[table_name] + table_dict[offset_col] = None + lwt = c_module.LightweightTableCollection() + with self.assertRaises(TypeError): + lwt.fromdict(d) + + d = 
tables.asdict() + table_dict = d[table_name] + bad_offset = np.zeros_like(table_dict[offset_col]) + bad_offset[:-1] = table_dict[offset_col][:-1][::-1] + bad_offset[-1] = table_dict[offset_col][-1] + table_dict[offset_col] = bad_offset + lwt = c_module.LightweightTableCollection() + with self.assertRaises(c_module.LibraryError): + lwt.fromdict(d) + + def test_individuals(self): + tables = get_example_tables() + self.verify_required_columns(tables, "individuals", ["flags"]) + self.verify_offset_pair( + tables, len(tables.individuals), "individuals", "location") + self.verify_offset_pair( + tables, len(tables.individuals), "individuals", "metadata") + + def test_nodes(self): + tables = get_example_tables() + self.verify_offset_pair(tables, len(tables.nodes), "nodes", "metadata") + self.verify_optional_column(tables, len(tables.nodes), "nodes", "population") + self.verify_optional_column(tables, len(tables.nodes), "nodes", "individual") + self.verify_required_columns(tables, "nodes", ["flags", "time"]) + + def test_edges(self): + tables = get_example_tables() + self.verify_required_columns( + tables, "edges", ["left", "right", "parent", "child"]) + + def test_migrations(self): + tables = get_example_tables() + self.verify_required_columns( + tables, "migrations", ["left", "right", "node", "source", "dest", "time"]) + + def test_sites(self): + tables = get_example_tables() + self.verify_required_columns( + tables, "sites", ["position", "ancestral_state", "ancestral_state_offset"]) + self.verify_offset_pair(tables, len(tables.sites), "sites", "metadata") + + def test_mutations(self): + tables = get_example_tables() + self.verify_required_columns( + tables, "mutations", + ["site", "node", "derived_state", "derived_state_offset"]) + self.verify_offset_pair(tables, len(tables.mutations), "mutations", "metadata") + + def test_populations(self): + tables = get_example_tables() + self.verify_required_columns( + tables, "populations", ["metadata", "metadata_offset"]) + + def 
test_provenances(self): + tables = get_example_tables() + self.verify_required_columns( + tables, "provenances", + ["record", "record_offset", "timestamp", "timestamp_offset"]) diff --git a/python/tests/test_drawing.py b/python/tests/test_drawing.py new file mode 100644 index 0000000000..dd6b2125ac --- /dev/null +++ b/python/tests/test_drawing.py @@ -0,0 +1,551 @@ +# -*- coding: utf-8 -*- +""" +Test cases for visualisation in tskit. +""" +from __future__ import print_function +from __future__ import division + +import os +import sys +import tempfile +import unittest +import xml.etree + +import msprime +import six +import tskit +import tests.tsutil as tsutil + +IS_PY2 = sys.version_info[0] < 3 + + +class TestTreeDraw(unittest.TestCase): + """ + Tests for the tree drawing functionality. + """ + def get_binary_tree(self): + ts = msprime.simulate(10, random_seed=1, mutation_rate=1) + return next(ts.trees()) + + def get_nonbinary_tree(self): + demographic_events = [ + msprime.SimpleBottleneck(time=0.1, population=0, proportion=0.5)] + ts = msprime.simulate( + 10, recombination_rate=5, mutation_rate=10, + demographic_events=demographic_events, random_seed=1) + for t in ts.trees(): + for u in t.nodes(): + if len(t.children(u)) > 2: + return t + assert False + + def get_zero_edge_tree(self): + tables = tskit.TableCollection(sequence_length=2) + # These must be samples or we will have zero roots. 
+ tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0) + tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0) + tables.sites.add_row(position=0, ancestral_state="0") + tables.mutations.add_row(site=0, node=0, derived_state="1") + tables.mutations.add_row(site=0, node=1, derived_state="1") + return tables.tree_sequence().first() + + def get_zero_roots_tree(self): + tables = tskit.TableCollection(sequence_length=2) + # If we have no samples we have zero roots + tables.nodes.add_row(time=0) + tables.nodes.add_row(time=0) + tables.nodes.add_row(time=1) + tables.edges.add_row(0, 2, 2, 0) + tables.edges.add_row(0, 2, 2, 1) + tree = tables.tree_sequence().first() + self.assertEqual(tree.num_roots, 0) + return tree + + def get_multiroot_tree(self): + ts = msprime.simulate(15, random_seed=1) + # Take off the top quarter of edges + tables = ts.dump_tables() + edges = tables.edges + n = len(edges) - len(edges) // 4 + edges.set_columns( + left=edges.left[:n], right=edges.right[:n], + parent=edges.parent[:n], child=edges.child[:n]) + ts = tables.tree_sequence() + for t in ts.trees(): + if t.num_roots > 1: + return t + assert False + + def get_mutations_over_roots_tree(self): + ts = msprime.simulate(15, random_seed=1) + ts = tsutil.decapitate(ts, 20) + tables = ts.dump_tables() + delta = 1.0 / (ts.num_nodes + 1) + x = 0 + for node in range(ts.num_nodes): + site_id = tables.sites.add_row(x, ancestral_state="0") + x += delta + tables.mutations.add_row(site_id, node=node, derived_state="1") + ts = tables.tree_sequence() + tree = ts.first() + assert any( + tree.parent(mut.node) == tskit.NULL + for mut in tree.mutations()) + return tree + + def get_unary_node_tree(self): + ts = msprime.simulate(2, random_seed=1) + tables = ts.dump_tables() + edges = tables.edges + # Take out all the edges except 1 + n = 1 + edges.set_columns( + left=edges.left[:n], right=edges.right[:n], + parent=edges.parent[:n], child=edges.child[:n]) + ts = tables.tree_sequence() + for t in ts.trees(): + 
for u in t.nodes(): + if len(t.children(u)) == 1: + return t + assert False + + def get_empty_tree(self): + tables = tskit.TableCollection(sequence_length=1) + ts = tables.tree_sequence() + return next(ts.trees()) + + +class TestFormats(TestTreeDraw): + """ + Tests that formats are recognised correctly. + """ + def test_svg_variants(self): + t = self.get_binary_tree() + for svg in ["svg", "SVG", "sVg"]: + output = t.draw(format=svg) + root = xml.etree.ElementTree.fromstring(output) + self.assertEqual(root.tag, "{http://www.w3.org/2000/svg}svg") + + def test_default(self): + # Default is SVG + t = self.get_binary_tree() + output = t.draw(format=None) + root = xml.etree.ElementTree.fromstring(output) + self.assertEqual(root.tag, "{http://www.w3.org/2000/svg}svg") + output = t.draw() + root = xml.etree.ElementTree.fromstring(output) + self.assertEqual(root.tag, "{http://www.w3.org/2000/svg}svg") + + def test_ascii_variants(self): + t = self.get_binary_tree() + for fmt in ["ascii", "ASCII", "AScii"]: + output = t.draw(format=fmt) + self.assertRaises( + xml.etree.ElementTree.ParseError, xml.etree.ElementTree.fromstring, + output) + + def test_unicode_variants(self): + t = self.get_binary_tree() + for fmt in ["unicode", "UNICODE", "uniCODE"]: + if IS_PY2: + self.assertRaises(ValueError, t.draw, format=fmt) + else: + output = t.draw(format=fmt) + self.assertRaises( + xml.etree.ElementTree.ParseError, xml.etree.ElementTree.fromstring, + output) + + def test_bad_formats(self): + t = self.get_binary_tree() + for bad_format in ["", "ASC", "SV", "jpeg"]: + self.assertRaises(ValueError, t.draw, format=bad_format) + + +# TODO we should gather some of these tests into a superclass as they are +# very similar for SVG and ASCII. + +class TestDrawText(TestTreeDraw): + """ + Tests the ASCII tree drawing method. 
+ """ + drawing_format = "ascii" + example_label = "XXX" + + def verify_basic_text(self, text): + self.assertTrue(isinstance(text, str)) + # TODO surely something else we can verify about this... + + def test_draw_defaults(self): + t = self.get_binary_tree() + text = t.draw(format=self.drawing_format) + self.verify_basic_text(text) + + def test_draw_nonbinary(self): + t = self.get_nonbinary_tree() + text = t.draw(format=self.drawing_format) + self.verify_basic_text(text) + + def test_draw_multiroot(self): + t = self.get_multiroot_tree() + text = t.draw(format=self.drawing_format) + self.verify_basic_text(text) + + def test_draw_mutations_over_roots(self): + t = self.get_mutations_over_roots_tree() + text = t.draw(format=self.drawing_format) + self.verify_basic_text(text) + + def test_draw_unary(self): + t = self.get_unary_node_tree() + text = t.draw(format=self.drawing_format) + self.verify_basic_text(text) + + def test_draw_empty_tree(self): + t = self.get_empty_tree() + self.assertRaises(ValueError, t.draw, format=self.drawing_format) + + def test_draw_zero_roots_tree(self): + t = self.get_zero_roots_tree() + self.assertRaises(ValueError, t.draw, format=self.drawing_format) + + def test_draw_zero_edge_tree(self): + t = self.get_zero_edge_tree() + text = t.draw(format=self.drawing_format) + self.verify_basic_text(text) + + def test_even_num_children_tree(self): + nodes = six.StringIO("""\ + id is_sample time + 0 1 0 + 1 1 1 + 2 1 2 + 3 1 1 + 4 1 4 + 5 1 5 + 6 1 7 + """) + edges = six.StringIO("""\ + left right parent child + 0 1 6 0 + 0 1 6 1 + 0 1 6 2 + 0 1 6 3 + 0 1 6 4 + 0 1 6 5 + """) + ts = tskit.load_text(nodes, edges, strict=False) + t = next(ts.trees()) + text = t.draw(format=self.drawing_format) + self.verify_basic_text(text) + + def test_odd_num_children_tree(self): + nodes = six.StringIO("""\ + id is_sample time + 0 1 0 + 1 1 1 + 2 1 2 + 3 1 1 + 4 1 4 + 5 1 5 + """) + edges = six.StringIO("""\ + left right parent child + 0 1 5 0 + 0 1 5 1 + 0 1 5 2 + 0 
1 5 3 + 0 1 5 4 + """) + ts = tskit.load_text(nodes, edges, strict=False) + t = next(ts.trees()) + text = t.draw(format=self.drawing_format) + self.verify_basic_text(text) + + def test_node_labels(self): + t = self.get_binary_tree() + labels = {u: self.example_label for u in t.nodes()} + text = t.draw(format=self.drawing_format, node_labels=labels) + self.verify_basic_text(text) + j = 0 + for _ in t.nodes(): + j = text[j:].find(self.example_label) + self.assertNotEqual(j, -1) + + def test_no_node_labels(self): + t = self.get_binary_tree() + labels = {} + text = t.draw(format=self.drawing_format, node_labels=labels) + self.verify_basic_text(text) + for u in t.nodes(): + self.assertEqual(text.find(str(u)), -1) + + +@unittest.skipIf(IS_PY2, "Unicode tree drawing not supported on Python 2") +class TestDrawUnicode(TestDrawText): + """ + Tests the Unicode tree drawing method + """ + drawing_format = "unicode" + example_label = "\u20ac" * 10 # euro symbol + + def verify_text_rendering(self, drawn, drawn_tree, debug=False): + if debug: + print("Drawn:") + print(drawn) + print("Expected:") + print(drawn_tree) + tree_lines = drawn_tree.splitlines() + drawn_lines = drawn.splitlines() + self.assertEqual(len(tree_lines), len(drawn_lines)) + for l1, l2 in zip(tree_lines, drawn_lines): + # Trailing white space isn't significant. 
+ self.assertEqual(l1.rstrip(), l2.rstrip()) + + def test_simple_tree(self): + nodes = six.StringIO("""\ + id is_sample time + 0 1 0 + 1 1 0 + 2 1 2 + """) + edges = six.StringIO("""\ + left right parent child + 0 1 2 0 + 0 1 2 1 + """) + tree = ( + " 2 \n" + "┏┻┓\n" + "0 1") + ts = tskit.load_text(nodes, edges, strict=False) + t = next(ts.trees()) + drawn = t.draw(format="unicode") + self.verify_text_rendering(drawn, tree) + + def test_trident_tree(self): + nodes = six.StringIO("""\ + id is_sample time + 0 1 0 + 1 1 0 + 2 1 0 + 3 1 2 + """) + edges = six.StringIO("""\ + left right parent child + 0 1 3 0 + 0 1 3 1 + 0 1 3 2 + """) + tree = ( + " 3 \n" + "┏━╋━┓\n" + "0 1 2\n") + ts = tskit.load_text(nodes, edges, strict=False) + t = next(ts.trees()) + drawn = t.draw(format="unicode") + self.verify_text_rendering(drawn, tree) + + def test_pitchfork_tree(self): + nodes = six.StringIO("""\ + id is_sample time + 0 1 0 + 1 1 0 + 2 1 0 + 3 1 0 + 4 1 2 + """) + edges = six.StringIO("""\ + left right parent child + 0 1 4 0 + 0 1 4 1 + 0 1 4 2 + 0 1 4 3 + """) + tree = ( + " 4 \n" + "┏━┳┻┳━┓\n" + "0 1 2 3\n") + ts = tskit.load_text(nodes, edges, strict=False) + t = next(ts.trees()) + # No labels + tree = ( + " ┃ \n" + "┏━┳┻┳━┓\n" + "┃ ┃ ┃ ┃\n") + drawn = t.draw(format="unicode", node_labels={}) + self.verify_text_rendering(drawn, tree) + # Some lables + tree = ( + " ┃ \n" + "┏━┳┻┳━┓\n" + "0 ┃ ┃ 3\n") + drawn = t.draw(format="unicode", node_labels={0: "0", 3: "3"}) + self.verify_text_rendering(drawn, tree) + + def test_stick_tree(self): + nodes = six.StringIO("""\ + id is_sample time + 0 1 0 + 1 1 1 + 2 1 2 + """) + edges = six.StringIO("""\ + left right parent child + 0 1 1 0 + 0 1 2 1 + """) + tree = ( + "2\n" + "┃\n" + "1\n" + "┃\n" + "0\n") + ts = tskit.load_text(nodes, edges, strict=False) + t = next(ts.trees()) + drawn = t.draw(format="unicode") + self.verify_text_rendering(drawn, tree) + + +class TestDrawSvg(TestTreeDraw): + """ + Tests the SVG tree drawing. 
+ """ + def verify_basic_svg(self, svg, width=200, height=200): + root = xml.etree.ElementTree.fromstring(svg) + self.assertEqual(root.tag, "{http://www.w3.org/2000/svg}svg") + self.assertEqual(width, int(root.attrib["width"])) + self.assertEqual(height, int(root.attrib["height"])) + + def test_draw_file(self): + t = self.get_binary_tree() + fd, filename = tempfile.mkstemp(prefix="tskit_viz_") + try: + os.close(fd) + svg = t.draw(path=filename) + self.assertGreater(os.path.getsize(filename), 0) + with open(filename) as tmp: + other_svg = tmp.read() + self.assertEqual(svg, other_svg) + finally: + os.unlink(filename) + + def test_draw_defaults(self): + t = self.get_binary_tree() + svg = t.draw() + self.verify_basic_svg(svg) + + def test_draw_nonbinary(self): + t = self.get_nonbinary_tree() + svg = t.draw() + self.verify_basic_svg(svg) + + def test_draw_multiroot(self): + t = self.get_multiroot_tree() + svg = t.draw() + self.verify_basic_svg(svg) + + def test_draw_mutations_over_roots(self): + t = self.get_mutations_over_roots_tree() + svg = t.draw() + self.verify_basic_svg(svg) + + def test_draw_unary(self): + t = self.get_unary_node_tree() + svg = t.draw() + self.verify_basic_svg(svg) + + def test_draw_empty(self): + t = self.get_empty_tree() + self.assertRaises(ValueError, t.draw) + + def test_draw_zero_roots(self): + t = self.get_zero_roots_tree() + self.assertRaises(ValueError, t.draw) + + def test_draw_zero_edge(self): + t = self.get_zero_edge_tree() + svg = t.draw() + self.verify_basic_svg(svg) + + def test_width_height(self): + t = self.get_binary_tree() + w = 123 + h = 456 + svg = t.draw(width=w, height=h) + self.verify_basic_svg(svg, w, h) + + def test_node_labels(self): + t = self.get_binary_tree() + labels = {u: "XXX" for u in t.nodes()} + svg = t.draw(format="svg", node_labels=labels) + self.verify_basic_svg(svg) + self.assertEqual(svg.count("XXX"), t.num_nodes) + + def test_one_node_label(self): + t = self.get_binary_tree() + labels = {0: "XXX"} + svg = 
t.draw(format="svg", node_labels=labels) + self.verify_basic_svg(svg) + self.assertEqual(svg.count("XXX"), 1) + + def test_no_node_labels(self): + t = self.get_binary_tree() + labels = {} + svg = t.draw(format="svg", node_labels=labels) + self.verify_basic_svg(svg) + # Can't really test for much here if we don't understand the SVG + + def test_one_node_colour(self): + t = self.get_binary_tree() + colour = "rgb(0, 1, 2)" + colours = {0: colour} + svg = t.draw(format="svg", node_colours=colours) + self.verify_basic_svg(svg) + self.assertEqual(svg.count('fill="{}"'.format(colour)), 1) + + def test_all_nodes_colour(self): + t = self.get_binary_tree() + colours = {u: "rgb({}, {}, {})".format(u, u, u) for u in t.nodes()} + svg = t.draw(format="svg", node_colours=colours) + self.verify_basic_svg(svg) + for colour in colours.values(): + self.assertEqual(svg.count('fill="{}"'.format(colour)), 1) + + def test_mutation_labels(self): + t = self.get_binary_tree() + labels = {u.id: "XXX" for u in t.mutations()} + svg = t.draw(format="svg", mutation_labels=labels) + self.verify_basic_svg(svg) + self.assertEqual(svg.count("XXX"), t.num_mutations) + + def test_one_mutation_label(self): + t = self.get_binary_tree() + labels = {0: "XXX"} + svg = t.draw(format="svg", mutation_labels=labels) + self.verify_basic_svg(svg) + self.assertEqual(svg.count("XXX"), 1) + + def test_no_mutation_labels(self): + t = self.get_binary_tree() + labels = {} + svg = t.draw(format="svg", mutation_labels=labels) + self.verify_basic_svg(svg) + # Can't really test for much here if we don't understand the SVG + + def test_one_mutation_colour(self): + t = self.get_binary_tree() + colour = "rgb(0, 1, 2)" + colours = {0: colour} + svg = t.draw(format="svg", mutation_colours=colours) + self.verify_basic_svg(svg) + self.assertEqual(svg.count('fill="{}"'.format(colour)), 1) + + def test_all_mutations_colour(self): + t = self.get_binary_tree() + colours = { + mut.id: "rgb({}, {}, {})".format(mut.id, mut.id, mut.id) 
+ for mut in t.mutations()} + svg = t.draw(format="svg", mutation_colours=colours) + self.verify_basic_svg(svg) + for colour in colours.values(): + self.assertEqual(svg.count('fill="{}"'.format(colour)), 1) diff --git a/python/tests/test_file_format.py b/python/tests/test_file_format.py new file mode 100644 index 0000000000..4381619db1 --- /dev/null +++ b/python/tests/test_file_format.py @@ -0,0 +1,675 @@ +""" +Test cases for tskit's file format. +""" +from __future__ import print_function +from __future__ import division + +import os +import tempfile +import unittest +import uuid as _uuid +import json + +import h5py +import kastore +import numpy as np +import msprime + +import tskit +import tskit.exceptions as exceptions +import tests.tsutil as tsutil + + +CURRENT_FILE_MAJOR = 12 + + +def single_locus_no_mutation_example(): + return msprime.simulate(10, random_seed=10) + + +def single_locus_with_mutation_example(): + return msprime.simulate(10, mutation_rate=10, random_seed=11) + + +def multi_locus_with_mutation_example(): + return msprime.simulate( + 10, recombination_rate=1, length=10, mutation_rate=10, + random_seed=2) + + +def recurrent_mutation_example(): + ts = msprime.simulate(10, recombination_rate=1, length=10, random_seed=2) + return tsutil.insert_branch_mutations(ts) + + +def general_mutation_example(): + ts = msprime.simulate(10, recombination_rate=1, length=10, random_seed=2) + tables = ts.dump_tables() + tables.sites.add_row(position=0, ancestral_state="A", metadata=b"{}") + tables.sites.add_row(position=1, ancestral_state="C", metadata=b"{'id':1}") + tables.mutations.add_row(site=0, node=0, derived_state="T") + tables.mutations.add_row(site=1, node=0, derived_state="G") + return tables.tree_sequence() + + +def multichar_mutation_example(): + ts = msprime.simulate(10, recombination_rate=1, length=10, random_seed=2) + return tsutil.insert_multichar_mutations(ts) + + +def migration_example(): + n = 10 + t = 1 + population_configurations = [ + 
        msprime.PopulationConfiguration(n // 2),
        msprime.PopulationConfiguration(n // 2),
        msprime.PopulationConfiguration(0),
    ]
    demographic_events = [
        msprime.MassMigration(time=t, source=0, destination=2),
        msprime.MassMigration(time=t, source=1, destination=2),
    ]
    ts = msprime.simulate(
        population_configurations=population_configurations,
        demographic_events=demographic_events,
        random_seed=1, record_migrations=True)
    return ts


def bottleneck_example():
    # A bottleneck with proportion 0.75 on a 100-sample simulation; used to
    # exercise examples that the binary-only legacy v2 format cannot store
    # (see TestErrors.test_v2_non_binary_records below).
    return msprime.simulate(
        random_seed=1,
        sample_size=100,
        recombination_rate=0.1,
        length=10,
        demographic_events=[
            msprime.SimpleBottleneck(time=0.01, population=0, proportion=0.75)])


def historical_sample_example():
    # Samples taken at ten distinct times in the past.
    return msprime.simulate(
        recombination_rate=0.1,
        length=10,
        random_seed=1,
        samples=[(0, j) for j in range(10)])


def no_provenance_example():
    # A tree sequence whose provenance table has been emptied.
    ts = msprime.simulate(10, random_seed=1)
    tables = ts.dump_tables()
    tables.provenances.clear()
    return tables.tree_sequence()


def provenance_timestamp_only_example():
    # Intended: a tree sequence whose provenance row has a timestamp but an
    # empty record.
    ts = msprime.simulate(10, random_seed=1)
    tables = ts.dump_tables()
    # NOTE(review): this ProvenanceTable is built but never attached to
    # `tables`, so the returned tree sequence keeps the simulation's own
    # provenance and the example is not actually "timestamp only".
    # Presumably tables.provenances should be cleared and the row added
    # there instead — confirm against verify_round_trip, which
    # json-decodes every provenance record.
    provenances = tskit.ProvenanceTable()
    provenances.add_row(timestamp="12345", record="")
    return tables.tree_sequence()


def node_metadata_example():
    # Attach a string metadata value ("n_<id>") to every node.
    ts = msprime.simulate(
        sample_size=100, recombination_rate=0.1, length=10, random_seed=1)
    tables = ts.dump_tables()
    metadatas = ["n_{}".format(u) for u in range(ts.num_nodes)]
    packed, offset = tskit.pack_strings(metadatas)
    tables.nodes.set_columns(
        metadata=packed, metadata_offset=offset,
        flags=tables.nodes.flags, time=tables.nodes.time)
    return tables.tree_sequence()


def site_metadata_example():
    # Ten sites, each carrying byte metadata.
    ts = msprime.simulate(10, length=10, random_seed=2)
    tables = ts.dump_tables()
    for j in range(10):
        tables.sites.add_row(j, ancestral_state="a", metadata=b"1234")
    return tables.tree_sequence()


def mutation_metadata_example():
    # Ten mutations on a single site, each carrying byte metadata.
    ts = msprime.simulate(10, length=10, random_seed=2)
    tables = 
ts.dump_tables() + tables.sites.add_row(0, ancestral_state="a") + for j in range(10): + tables.mutations.add_row( + site=0, node=j, derived_state="t", metadata=b"1234") + return tables.tree_sequence() + + +class TestFileFormat(unittest.TestCase): + """ + Superclass of file format tests. + """ + def setUp(self): + fd, self.temp_file = tempfile.mkstemp(prefix="msp_file_test_") + os.close(fd) + + def tearDown(self): + os.unlink(self.temp_file) + + +class TestLoadLegacyExamples(TestFileFormat): + """ + Tests using the saved legacy file examples to ensure we can load them. + """ + def verify_tree_sequence(self, ts): + # Just some quick checks to make sure the tree sequence makes sense. + self.assertGreater(ts.sample_size, 0) + self.assertGreater(ts.num_edges, 0) + self.assertGreater(ts.num_sites, 0) + self.assertGreater(ts.num_mutations, 0) + self.assertGreater(ts.sequence_length, 0) + for t in ts.trees(): + left, right = t.interval + self.assertGreater(right, left) + for site in t.sites(): + self.assertTrue(left <= site.position < right) + for mut in site.mutations: + self.assertEqual(mut.site, site.id) + + def test_format_too_old_raised_for_hdf5(self): + files = [ + "tests/data/hdf5-formats/msprime-0.3.0_v2.0.hdf5", + "tests/data/hdf5-formats/msprime-0.4.0_v3.1.hdf5", + "tests/data/hdf5-formats/msprime-0.5.0_v10.0.hdf5"] + for filename in files: + self.assertRaises(exceptions.VersionTooOldError, tskit.load, filename) + + def test_msprime_v_0_5_0(self): + ts = tskit.load_legacy("tests/data/hdf5-formats/msprime-0.5.0_v10.0.hdf5") + self.verify_tree_sequence(ts) + + def test_msprime_v_0_4_0(self): + ts = tskit.load_legacy("tests/data/hdf5-formats/msprime-0.4.0_v3.1.hdf5") + self.verify_tree_sequence(ts) + + def test_msprime_v_0_3_0(self): + ts = tskit.load_legacy("tests/data/hdf5-formats/msprime-0.3.0_v2.0.hdf5") + self.verify_tree_sequence(ts) + + +class TestRoundTrip(TestFileFormat): + """ + Tests if we can round trip convert a tree sequence in memory + through a V2 
file format and a V3 format. + """ + def verify_tree_sequences_equal(self, ts, tsp, simplify=True): + self.assertEqual(ts.sequence_length, tsp.sequence_length) + t1 = ts.tables + # We need to sort and squash the edges in the new format because it + # has gone through an edgesets representation. Simplest way to do this + # is to call simplify. + if simplify: + t2 = tsp.simplify().tables + else: + t2 = tsp.tables + self.assertEqual(t1.nodes, t2.nodes) + self.assertEqual(t1.edges, t2.edges) + self.assertEqual(t1.sites, t2.sites) + self.assertEqual(t1.mutations, t2.mutations) + + def verify_round_trip(self, ts, version): + tskit.dump_legacy(ts, self.temp_file, version=version) + tsp = tskit.load_legacy(self.temp_file) + simplify = version < 10 + self.verify_tree_sequences_equal(ts, tsp, simplify=simplify) + tsp.dump(self.temp_file) + tsp = tskit.load(self.temp_file) + self.verify_tree_sequences_equal(ts, tsp, simplify=simplify) + for provenance in tsp.provenances(): + tskit.validate_provenance(json.loads(provenance.record)) + + def verify_malformed_json_v2(self, ts, group_name, attr, bad_json): + tskit.dump_legacy(ts, self.temp_file, 2) + # Write some bad JSON to the provenance string. 
+ root = h5py.File(self.temp_file, "r+") + group = root[group_name] + group.attrs[attr] = bad_json + root.close() + tsp = tskit.load_legacy(self.temp_file) + self.verify_tree_sequences_equal(ts, tsp) + + def test_malformed_json_v2(self): + ts = multi_locus_with_mutation_example() + for group_name in ["trees", "mutations"]: + for attr in ["environment", "parameters"]: + for bad_json in ["", "{", "{},"]: + self.verify_malformed_json_v2(ts, group_name, attr, bad_json) + + def test_single_locus_no_mutation(self): + self.verify_round_trip(single_locus_no_mutation_example(), 2) + self.verify_round_trip(single_locus_no_mutation_example(), 3) + self.verify_round_trip(single_locus_no_mutation_example(), 10) + + def test_single_locus_with_mutation(self): + self.verify_round_trip(single_locus_with_mutation_example(), 2) + self.verify_round_trip(single_locus_with_mutation_example(), 3) + self.verify_round_trip(single_locus_with_mutation_example(), 10) + + def test_multi_locus_with_mutation(self): + self.verify_round_trip(multi_locus_with_mutation_example(), 2) + self.verify_round_trip(multi_locus_with_mutation_example(), 3) + self.verify_round_trip(multi_locus_with_mutation_example(), 10) + + def test_migration_example(self): + self.verify_round_trip(migration_example(), 2) + self.verify_round_trip(migration_example(), 3) + self.verify_round_trip(migration_example(), 10) + + def test_bottleneck_example(self): + self.verify_round_trip(migration_example(), 3) + self.verify_round_trip(migration_example(), 10) + + def test_no_provenance(self): + self.verify_round_trip(no_provenance_example(), 10) + + def test_provenance_timestamp_only(self): + self.verify_round_trip(provenance_timestamp_only_example(), 10) + + def test_recurrent_mutation_example(self): + ts = recurrent_mutation_example() + for version in [2, 3]: + self.assertRaises( + ValueError, tskit.dump_legacy, ts, self.temp_file, version) + self.verify_round_trip(ts, 10) + + def test_general_mutation_example(self): + ts = 
general_mutation_example() + for version in [2, 3]: + self.assertRaises( + ValueError, tskit.dump_legacy, ts, self.temp_file, version) + self.verify_round_trip(ts, 10) + + def test_node_metadata_example(self): + self.verify_round_trip(node_metadata_example(), 10) + + def test_site_metadata_example(self): + self.verify_round_trip(site_metadata_example(), 10) + + def test_mutation_metadata_example(self): + self.verify_round_trip(mutation_metadata_example(), 10) + + def test_multichar_mutation_example(self): + self.verify_round_trip(multichar_mutation_example(), 10) + + def test_empty_file(self): + tables = tskit.TableCollection(sequence_length=3) + self.verify_round_trip(tables.tree_sequence(), 10) + + def test_zero_edges(self): + tables = tskit.TableCollection(sequence_length=3) + tables.nodes.add_row(time=0) + self.verify_round_trip(tables.tree_sequence(), 10) + + def test_v2_no_samples(self): + ts = multi_locus_with_mutation_example() + tskit.dump_legacy(ts, self.temp_file, version=2) + root = h5py.File(self.temp_file, "r+") + del root['samples'] + root.close() + tsp = tskit.load_legacy(self.temp_file) + self.verify_tree_sequences_equal(ts, tsp) + + def test_duplicate_mutation_positions_single_value(self): + ts = multi_locus_with_mutation_example() + for version in [2, 3]: + tskit.dump_legacy(ts, self.temp_file, version=version) + root = h5py.File(self.temp_file, "r+") + root['mutations/position'][:] = 0 + root.close() + self.assertRaises( + tskit.DuplicatePositionsError, tskit.load_legacy, self.temp_file) + tsp = tskit.load_legacy( + self.temp_file, remove_duplicate_positions=True) + self.assertEqual(tsp.num_sites, 1) + sites = list(tsp.sites()) + self.assertEqual(sites[0].position, 0) + + def test_duplicate_mutation_positions(self): + ts = multi_locus_with_mutation_example() + for version in [2, 3]: + tskit.dump_legacy(ts, self.temp_file, version=version) + root = h5py.File(self.temp_file, "r+") + position = np.array(root['mutations/position']) + position[0] = 
position[1] + root['mutations/position'][:] = position + root.close() + self.assertRaises( + tskit.DuplicatePositionsError, tskit.load_legacy, self.temp_file) + tsp = tskit.load_legacy( + self.temp_file, remove_duplicate_positions=True) + self.assertEqual(tsp.num_sites, position.shape[0] - 1) + position_after = list(s.position for s in tsp.sites()) + self.assertEqual(list(position[1:]), position_after) + + +class TestErrors(TestFileFormat): + """ + Test various API errors. + """ + def test_v2_non_binary_records(self): + demographic_events = [ + msprime.SimpleBottleneck(time=0.01, population=0, proportion=1) + ] + ts = msprime.simulate( + sample_size=10, + demographic_events=demographic_events, + random_seed=1) + self.assertRaises(ValueError, tskit.dump_legacy, ts, self.temp_file, 2) + + def test_unsupported_version(self): + ts = msprime.simulate(10) + self.assertRaises(ValueError, tskit.dump_legacy, ts, self.temp_file, version=4) + # Cannot read current files. + ts.dump(self.temp_file) + # Catch Exception here because h5py throws different exceptions on py2 and py3 + self.assertRaises(Exception, tskit.load_legacy, self.temp_file) + + def test_no_version_number(self): + root = h5py.File(self.temp_file, "w") + root.attrs["x"] = 0 + root.close() + self.assertRaises(ValueError, tskit.load_legacy, self.temp_file) + + +class TestDumpFormat(TestFileFormat): + """ + Tests on the on-disk file format. 
+ """ + def verify_keys(self, ts): + keys = [ + "edges/child", + "edges/left", + "edges/parent", + "edges/right", + "format/name", + "format/version", + "indexes/edge_insertion_order", + "indexes/edge_removal_order", + "individuals/flags", + "individuals/location", + "individuals/location_offset", + "individuals/metadata", + "individuals/metadata_offset", + "migrations/dest", + "migrations/left", + "migrations/node", + "migrations/right", + "migrations/source", + "migrations/time", + "mutations/derived_state", + "mutations/derived_state_offset", + "mutations/metadata", + "mutations/metadata_offset", + "mutations/node", + "mutations/parent", + "mutations/site", + "nodes/flags", + "nodes/individual", + "nodes/metadata", + "nodes/metadata_offset", + "nodes/population", + "nodes/time", + "populations/metadata", + "populations/metadata_offset", + "provenances/record", + "provenances/record_offset", + "provenances/timestamp", + "provenances/timestamp_offset", + "sequence_length", + "sites/ancestral_state", + "sites/ancestral_state_offset", + "sites/metadata", + "sites/metadata_offset", + "sites/position", + "uuid", + ] + ts.dump(self.temp_file) + store = kastore.load(self.temp_file) + self.assertEqual(sorted(list(store.keys())), keys) + + def verify_uuid(self, ts, uuid): + self.assertEqual(len(uuid), 36) + # Check that the UUID is well-formed. 
+ parsed = _uuid.UUID("{" + uuid + "}") + self.assertEqual(str(parsed), uuid) + self.assertEqual(uuid, ts.file_uuid) + + def verify_dump_format(self, ts): + ts.dump(self.temp_file) + self.assertTrue(os.path.exists(self.temp_file)) + self.assertGreater(os.path.getsize(self.temp_file), 0) + self.verify_keys(ts) + + store = kastore.load(self.temp_file) + # Check the basic root attributes + format_name = store['format/name'] + self.assertTrue(np.array_equal( + np.array(bytearray(b"tskit.trees"), dtype=np.int8), format_name)) + format_version = store['format/version'] + self.assertEqual(format_version[0], CURRENT_FILE_MAJOR) + self.assertEqual(format_version[1], 0) + self.assertEqual(ts.sequence_length, store['sequence_length'][0]) + # Load another copy from file so we can check the uuid. + other_ts = tskit.load(self.temp_file) + self.verify_uuid(other_ts, store["uuid"].tobytes().decode()) + + tables = ts.tables + + self.assertTrue(np.array_equal( + tables.individuals.flags, store["individuals/flags"])) + self.assertTrue(np.array_equal( + tables.individuals.location, store["individuals/location"])) + self.assertTrue(np.array_equal( + tables.individuals.location_offset, store["individuals/location_offset"])) + self.assertTrue(np.array_equal( + tables.individuals.metadata, store["individuals/metadata"])) + self.assertTrue(np.array_equal( + tables.individuals.metadata_offset, store["individuals/metadata_offset"])) + + self.assertTrue(np.array_equal(tables.nodes.flags, store["nodes/flags"])) + self.assertTrue(np.array_equal(tables.nodes.time, store["nodes/time"])) + self.assertTrue(np.array_equal( + tables.nodes.population, store["nodes/population"])) + self.assertTrue(np.array_equal( + tables.nodes.individual, store["nodes/individual"])) + self.assertTrue(np.array_equal( + tables.nodes.metadata, store["nodes/metadata"])) + self.assertTrue(np.array_equal( + tables.nodes.metadata_offset, store["nodes/metadata_offset"])) + + self.assertTrue(np.array_equal(tables.edges.left, 
store["edges/left"])) + self.assertTrue(np.array_equal(tables.edges.right, store["edges/right"])) + self.assertTrue(np.array_equal(tables.edges.parent, store["edges/parent"])) + self.assertTrue(np.array_equal(tables.edges.child, store["edges/child"])) + + left = tables.edges.left + right = tables.edges.right + parent = tables.edges.parent + child = tables.edges.child + time = tables.nodes.time + in_order = sorted( + range(ts.num_edges), + key=lambda j: (left[j], time[parent[j]], parent[j], child[j])) + out_order = sorted( + range(ts.num_edges), + key=lambda j: (right[j], -time[parent[j]], -parent[j], -child[j])) + self.assertTrue(np.array_equal( + np.array(in_order, dtype=np.int32), store["indexes/edge_insertion_order"])) + self.assertTrue(np.array_equal( + np.array(out_order, dtype=np.int32), store["indexes/edge_removal_order"])) + + self.assertTrue( + np.array_equal(tables.migrations.left, store["migrations/left"])) + self.assertTrue( + np.array_equal(tables.migrations.right, store["migrations/right"])) + self.assertTrue( + np.array_equal(tables.migrations.node, store["migrations/node"])) + self.assertTrue( + np.array_equal(tables.migrations.source, store["migrations/source"])) + self.assertTrue( + np.array_equal(tables.migrations.dest, store["migrations/dest"])) + self.assertTrue( + np.array_equal(tables.migrations.time, store["migrations/time"])) + + self.assertTrue(np.array_equal(tables.sites.position, store["sites/position"])) + self.assertTrue(np.array_equal( + tables.sites.ancestral_state, store["sites/ancestral_state"])) + self.assertTrue(np.array_equal( + tables.sites.ancestral_state_offset, store["sites/ancestral_state_offset"])) + self.assertTrue(np.array_equal( + tables.sites.metadata, store["sites/metadata"])) + self.assertTrue(np.array_equal( + tables.sites.metadata_offset, store["sites/metadata_offset"])) + + self.assertTrue(np.array_equal(tables.mutations.site, store["mutations/site"])) + self.assertTrue(np.array_equal(tables.mutations.node, 
store["mutations/node"])) + self.assertTrue(np.array_equal( + tables.mutations.parent, store["mutations/parent"])) + self.assertTrue(np.array_equal( + tables.mutations.derived_state, store["mutations/derived_state"])) + self.assertTrue(np.array_equal( + tables.mutations.derived_state_offset, + store["mutations/derived_state_offset"])) + self.assertTrue(np.array_equal( + tables.mutations.metadata, store["mutations/metadata"])) + self.assertTrue(np.array_equal( + tables.mutations.metadata_offset, store["mutations/metadata_offset"])) + + self.assertTrue(np.array_equal( + tables.populations.metadata, store["populations/metadata"])) + self.assertTrue(np.array_equal( + tables.populations.metadata_offset, store["populations/metadata_offset"])) + + self.assertTrue(np.array_equal( + tables.provenances.record, store["provenances/record"])) + self.assertTrue(np.array_equal( + tables.provenances.record_offset, store["provenances/record_offset"])) + self.assertTrue(np.array_equal( + tables.provenances.timestamp, store["provenances/timestamp"])) + self.assertTrue(np.array_equal( + tables.provenances.timestamp_offset, store["provenances/timestamp_offset"])) + + store.close() + + def test_single_locus_no_mutation(self): + self.verify_dump_format(single_locus_no_mutation_example()) + + def test_single_locus_with_mutation(self): + self.verify_dump_format(single_locus_with_mutation_example()) + + def test_multi_locus_with_mutation(self): + self.verify_dump_format(multi_locus_with_mutation_example()) + + def test_migration_example(self): + self.verify_dump_format(migration_example()) + + def test_bottleneck_example(self): + self.verify_dump_format(bottleneck_example()) + + def test_historical_sample_example(self): + self.verify_dump_format(historical_sample_example()) + + def test_node_metadata_example(self): + self.verify_dump_format(node_metadata_example()) + + def test_site_metadata_example(self): + self.verify_dump_format(site_metadata_example()) + + def 
test_mutation_metadata_example(self): + self.verify_dump_format(mutation_metadata_example()) + + def test_general_mutation_example(self): + self.verify_dump_format(general_mutation_example()) + + def test_multichar_mutation_example(self): + self.verify_dump_format(multichar_mutation_example()) + + +class TestUuid(TestFileFormat): + """ + Basic tests for the UUID generation. + """ + def test_different_files_same_ts(self): + ts = msprime.simulate(10) + uuids = [] + for _ in range(10): + ts.dump(self.temp_file) + with kastore.load(self.temp_file) as store: + uuids.append(store["uuid"].tobytes().decode()) + self.assertEqual(len(uuids), len(set(uuids))) + + +class TestFileFormatErrors(TestFileFormat): + """ + Tests for errors in the HDF5 format. + """ + current_major_version = 12 + + def verify_fields(self, ts): + ts.dump(self.temp_file) + with kastore.load(self.temp_file) as store: + all_data = dict(store) + for key in all_data.keys(): + data = dict(all_data) + del data[key] + kastore.dump(data, self.temp_file) + self.assertRaises(exceptions.FileFormatError, tskit.load, self.temp_file) + + def test_missing_fields(self): + self.verify_fields(migration_example()) + + def test_load_empty_kastore(self): + kastore.dump({}, self.temp_file) + self.assertRaises(exceptions.FileFormatError, tskit.load, self.temp_file) + + def test_load_non_tskit_hdf5(self): + with h5py.File(self.temp_file, "w") as root: + root["x"] = np.zeros(10) + self.assertRaises(exceptions.FileFormatError, tskit.load, self.temp_file) + + def test_old_version_load_error(self): + ts = msprime.simulate(10, random_seed=1) + for bad_version in [(0, 1), (0, 8), (2, 0), (CURRENT_FILE_MAJOR - 1, 0)]: + ts.dump(self.temp_file) + with kastore.load(self.temp_file) as store: + data = dict(store) + data["format/version"] = np.array(bad_version, dtype=np.uint32) + kastore.dump(data, self.temp_file) + self.assertRaises(tskit.VersionTooOldError, tskit.load, self.temp_file) + + def test_new_version_load_error(self): + ts = 
msprime.simulate(10, random_seed=1) + for bad_version in [(CURRENT_FILE_MAJOR + j, 0) for j in range(1, 5)]: + ts.dump(self.temp_file) + with kastore.load(self.temp_file) as store: + data = dict(store) + data["format/version"] = np.array(bad_version, dtype=np.uint32) + kastore.dump(data, self.temp_file) + self.assertRaises(tskit.VersionTooNewError, tskit.load, self.temp_file) + + def test_format_name_error(self): + ts = msprime.simulate(10) + for bad_name in ["tskit.tree", "tskit.treesAndOther", "", "x"*100]: + ts.dump(self.temp_file) + with kastore.load(self.temp_file) as store: + data = dict(store) + data["format/name"] = np.array(bytearray(bad_name.encode()), dtype=np.int8) + kastore.dump(data, self.temp_file) + self.assertRaises(exceptions.FileFormatError, tskit.load, self.temp_file) + + def test_load_bad_formats(self): + # try loading a bunch of files in various formats. + # First, check the emtpy file. + self.assertRaises(exceptions.FileFormatError, tskit.load, self.temp_file) + # Now some ascii text + with open(self.temp_file, "wb") as f: + f.write(b"Some ASCII text") + self.assertRaises(exceptions.FileFormatError, tskit.load, self.temp_file) + # Now write 8k of random bytes + with open(self.temp_file, "wb") as f: + f.write(os.urandom(8192)) + self.assertRaises(exceptions.FileFormatError, tskit.load, self.temp_file) diff --git a/python/tests/test_highlevel.py b/python/tests/test_highlevel.py new file mode 100644 index 0000000000..59093191b3 --- /dev/null +++ b/python/tests/test_highlevel.py @@ -0,0 +1,2046 @@ +""" +Test cases for the high level interface to msprime. +""" +from __future__ import print_function +from __future__ import division + +try: + # We use the zip as iterator functionality here. + from future_builtins import zip +except ImportError: + # This fails for Python 3.x, but that's fine. 
+ pass + +import collections +import itertools +import json +import math +import os +import random +import shutil +import six +import tempfile +import unittest +import warnings +import uuid as _uuid + +import numpy as np +import msprime + +import tskit +import _tskit +import tests as tests +import tests.tsutil as tsutil +import tests.simplify as simplify + + +def insert_uniform_mutations(tables, num_mutations, nodes): + """ + Returns n evenly mutations over the specified list of nodes. + """ + for j in range(num_mutations): + tables.sites.add_row( + position=j * (tables.sequence_length / num_mutations), ancestral_state='0', + metadata=json.dumps({"index": j}).encode()) + tables.mutations.add_row( + site=j, derived_state='1', node=nodes[j % len(nodes)], + metadata=json.dumps({"index": j}).encode()) + + +def get_table_collection_copy(tables, sequence_length): + """ + Returns a copy of the specified table collection with the specified + sequence length. + """ + table_dict = tables.asdict() + table_dict["sequence_length"] = sequence_length + return tskit.TableCollection.fromdict(table_dict) + + +def insert_gap(ts, position, length): + """ + Inserts a gap of the specified size into the specified tree sequence. + This involves: (1) breaking all edges that intersect with this point; + and (2) shifting all coordinates greater than this value up by the + gap length. + """ + new_edges = [] + for e in ts.edges(): + if e.left < position < e.right: + new_edges.append([e.left, position, e.parent, e.child]) + new_edges.append([position, e.right, e.parent, e.child]) + else: + new_edges.append([e.left, e.right, e.parent, e.child]) + + # Now shift up all coordinates. 
+ for e in new_edges: + # Left coordinates == position get shifted + if e[0] >= position: + e[0] += length + # Right coordinates == position do not get shifted + if e[1] > position: + e[1] += length + tables = ts.dump_tables() + L = ts.sequence_length + length + tables = get_table_collection_copy(tables, L) + tables.edges.clear() + tables.sites.clear() + tables.mutations.clear() + for left, right, parent, child in new_edges: + tables.edges.add_row(left, right, parent, child) + tables.sort() + # Throw in a bunch of mutations over the whole sequence on the samples. + insert_uniform_mutations(tables, 100, list(ts.samples())) + return tables.tree_sequence() + + +def get_gap_examples(): + """ + Returns example tree sequences that contain gaps within the list of + edges. + """ + ts = msprime.simulate(20, random_seed=56, recombination_rate=1) + + assert ts.num_trees > 1 + + gap = 0.0125 + for x in [0, 0.1, 0.5, 0.75]: + ts = insert_gap(ts, x, gap) + found = False + for t in ts.trees(): + if t.interval[0] == x: + assert t.interval[1] == x + gap + assert len(t.parent_dict) == 0 + found = True + assert found + yield ts + # Give an example with a gap at the end. + ts = msprime.simulate(10, random_seed=5, recombination_rate=1) + tables = get_table_collection_copy(ts.dump_tables(), 2) + tables.sites.clear() + tables.mutations.clear() + insert_uniform_mutations(tables, 100, list(ts.samples())) + yield tables.tree_sequence() + + +def get_internal_samples_examples(): + """ + Returns example tree sequences with internal samples. + """ + n = 5 + ts = msprime.simulate(n, random_seed=10, mutation_rate=5) + assert ts.num_mutations > 0 + tables = ts.dump_tables() + nodes = tables.nodes + flags = nodes.flags + # Set all nodes to be samples. + flags[:] = tskit.NODE_IS_SAMPLE + nodes.set_columns(flags=flags, time=nodes.time, population=nodes.population) + yield tables.tree_sequence() + + # Set just internal nodes to be samples. 
+ flags[:] = 0 + flags[n:] = tskit.NODE_IS_SAMPLE + nodes.set_columns(flags=flags, time=nodes.time, population=nodes.population) + yield tables.tree_sequence() + + # Set a mixture of internal and leaf samples. + flags[:] = 0 + flags[n // 2: n + n // 2] = tskit.NODE_IS_SAMPLE + nodes.set_columns(flags=flags, time=nodes.time, population=nodes.population) + yield tables.tree_sequence() + + +def get_decapitated_examples(): + """ + Returns example tree sequences in which the oldest edges have been removed. + """ + ts = msprime.simulate(10, random_seed=1234) + yield tsutil.decapitate(ts, ts.num_edges // 2) + + ts = msprime.simulate(20, recombination_rate=1, random_seed=1234) + assert ts.num_trees > 2 + yield tsutil.decapitate(ts, ts.num_edges // 4) + + +def get_example_tree_sequences(back_mutations=True, gaps=True, internal_samples=True): + if gaps: + for ts in get_decapitated_examples(): + yield ts + for ts in get_gap_examples(): + yield ts + if internal_samples: + for ts in get_internal_samples_examples(): + yield ts + seed = 1 + for n in [2, 3, 10, 100]: + for m in [1, 2, 32]: + for rho in [0, 0.1, 0.5]: + recomb_map = msprime.RecombinationMap.uniform_map(m, rho, num_loci=m) + ts = msprime.simulate( + recombination_map=recomb_map, mutation_rate=0.1, + random_seed=seed, + population_configurations=[ + msprime.PopulationConfiguration(n), + msprime.PopulationConfiguration(0)], + migration_matrix=[[0, 1], [1, 0]]) + ts = tsutil.insert_random_ploidy_individuals(ts, 4, seed=seed) + yield tsutil.add_random_metadata(ts, seed=seed) + seed += 1 + for ts in get_bottleneck_examples(): + yield msprime.mutate( + ts, rate=0.1, random_seed=seed, + model=msprime.InfiniteSites(msprime.NUCLEOTIDES)) + ts = msprime.simulate(15, length=4, recombination_rate=1) + assert ts.num_trees > 1 + if back_mutations: + yield tsutil.insert_branch_mutations(ts, mutations_per_branch=2) + ts = tsutil.insert_multichar_mutations(ts) + yield ts + yield tsutil.add_random_metadata(ts) + + +def 
get_bottleneck_examples(): + """ + Returns an iterator of example tree sequences with nonbinary + trees. + """ + bottlenecks = [ + msprime.SimpleBottleneck(0.01, 0, proportion=0.05), + msprime.SimpleBottleneck(0.02, 0, proportion=0.25), + msprime.SimpleBottleneck(0.03, 0, proportion=1)] + for n in [3, 10, 100]: + ts = msprime.simulate( + n, length=100, recombination_rate=1, + demographic_events=bottlenecks, + random_seed=n) + yield ts + + +def get_back_mutation_examples(): + """ + Returns an iterator of example tree sequences with nonbinary + trees. + """ + ts = msprime.simulate(10, random_seed=1) + for j in [1, 2, 3]: + yield tsutil.insert_branch_mutations(ts, mutations_per_branch=j) + for ts in get_bottleneck_examples(): + yield tsutil.insert_branch_mutations(ts) + + +def simple_get_pairwise_diversity(haplotypes): + """ + Returns the value of pi for the specified haplotypes. + """ + # Very simplistic algorithm... + n = len(haplotypes) + pi = 0 + for k in range(n): + for j in range(k): + for u, v in zip(haplotypes[j], haplotypes[k]): + pi += u != v + return 2 * pi / (n * (n - 1)) + + +def get_pairwise_diversity(tree_sequence, samples=None): + """ + This is the exact algorithm used by the low-level C code + and should return identical results. + """ + if samples is None: + tracked_samples = tree_sequence.get_samples() + else: + tracked_samples = list(samples) + if len(tracked_samples) < 2: + raise ValueError("len(samples) must be >= 2") + pi = 0 + k = len(tracked_samples) + denom = k * (k - 1) / 2 + for t in tree_sequence.trees(tracked_samples=tracked_samples): + for mutation in t.mutations(): + j = t.get_num_tracked_samples(mutation.node) + pi += j * (k - j) / denom + return pi + + +def simplify_tree_sequence(ts, samples, filter_sites=True): + """ + Simple tree-by-tree algorithm to get a simplify of a tree sequence. 
+ """ + s = simplify.Simplifier( + ts, samples, filter_sites=filter_sites) + return s.simplify() + + +def oriented_forests(n): + """ + Implementation of Algorithm O from TAOCP section 7.2.1.6. + Generates all canonical n-node oriented forests. + """ + p = [k - 1 for k in range(0, n + 1)] + k = 1 + while k != 0: + yield p + if p[n] > 0: + p[n] = p[p[n]] + yield p + k = n + while k > 0 and p[k] == 0: + k -= 1 + if k != 0: + j = p[k] + d = k - j + not_done = True + while not_done: + if p[k - d] == p[j]: + p[k] = p[j] + else: + p[k] = p[k - d] + d + if k == n: + not_done = False + else: + k += 1 + + +def get_mrca(pi, x, y): + """ + Returns the most recent common ancestor of nodes x and y in the + oriented forest pi. + """ + x_parents = [x] + j = x + while j != 0: + j = pi[j] + x_parents.append(j) + y_parents = {y: None} + j = y + while j != 0: + j = pi[j] + y_parents[j] = None + # We have the complete list of parents for x and y back to root. + mrca = 0 + j = 0 + while x_parents[j] not in y_parents: + j += 1 + mrca = x_parents[j] + return mrca + + +class TestMRCACalculator(unittest.TestCase): + """ + Class to test the Schieber-Vishkin algorithm. + + These tests are included here as we use the MRCA calculator below in + our tests. + """ + def test_all_oriented_forests(self): + # Runs through all possible oriented forests and checks all possible + # node pairs using an inferior algorithm. + for n in range(2, 9): + for pi in oriented_forests(n): + sv = tests.MRCACalculator(pi) + for j in range(1, n + 1): + for k in range(1, j + 1): + mrca = get_mrca(pi, j, k) + self.assertEqual(mrca, sv.get_mrca(j, k)) + + +class HighLevelTestCase(unittest.TestCase): + """ + Superclass of tests on the high level interface. 
+ """ + def setUp(self): + self.temp_dir = tempfile.mkdtemp(prefix="tsk_hl_testcase_") + self.temp_file = os.path.join(self.temp_dir, "generic") + + def tearDown(self): + shutil.rmtree(self.temp_dir) + + def verify_tree_mrcas(self, st): + # Check the mrcas + oriented_forest = [st.get_parent(j) for j in range(st.num_nodes)] + mrca_calc = tests.MRCACalculator(oriented_forest) + # We've done exhaustive tests elsewhere, no need to go + # through the combinations. + for j in range(st.num_nodes): + mrca = st.get_mrca(0, j) + self.assertEqual(mrca, mrca_calc.get_mrca(0, j)) + if mrca != tskit.NULL: + self.assertEqual(st.get_time(mrca), st.get_tmrca(0, j)) + + def verify_tree_branch_lengths(self, st): + for j in range(st.get_sample_size()): + u = j + while st.get_parent(u) != tskit.NULL: + length = st.get_time(st.get_parent(u)) - st.get_time(u) + self.assertGreater(length, 0.0) + self.assertEqual(st.get_branch_length(u), length) + u = st.get_parent(u) + + def verify_tree_structure(self, st): + roots = set() + for u in st.samples(): + # verify the path to root + self.assertTrue(st.is_sample(u)) + times = [] + while st.get_parent(u) != tskit.NULL: + v = st.get_parent(u) + times.append(st.get_time(v)) + self.assertGreaterEqual(st.get_time(v), 0.0) + self.assertIn(u, st.get_children(v)) + u = v + roots.add(u) + self.assertEqual(times, sorted(times)) + self.assertEqual(sorted(list(roots)), sorted(st.roots)) + self.assertEqual(len(st.roots), st.num_roots) + u = st.left_root + roots = [] + while u != tskit.NULL: + roots.append(u) + u = st.right_sib(u) + self.assertEqual(roots, st.roots) + # To a top-down traversal, and make sure we meet all the samples. 
+ samples = [] + for root in st.roots: + stack = [root] + while len(stack) > 0: + u = stack.pop() + self.assertNotEqual(u, tskit.NULL) + if st.is_sample(u): + samples.append(u) + if st.is_leaf(u): + self.assertEqual(len(st.get_children(u)), 0) + else: + for c in reversed(st.get_children(u)): + stack.append(c) + # Check that we get the correct number of samples at each + # node. + self.assertEqual(st.get_num_samples(u), len(list(st.samples(u)))) + self.assertEqual(st.get_num_tracked_samples(u), 0) + self.assertEqual(sorted(samples), sorted(st.samples())) + # Check the parent dict + pi = st.get_parent_dict() + for root in st.roots: + self.assertNotIn(root, pi) + for k, v in pi.items(): + self.assertEqual(st.get_parent(k), v) + self.assertEqual(st.num_samples(), len(samples)) + self.assertEqual(sorted(st.samples()), sorted(samples)) + + def verify_tree(self, st): + self.verify_tree_mrcas(st) + self.verify_tree_branch_lengths(st) + self.verify_tree_structure(st) + + def verify_trees(self, ts): + pts = tests.PythonTreeSequence(ts.get_ll_tree_sequence()) + iter1 = ts.trees() + iter2 = pts.trees() + length = 0 + num_trees = 0 + breakpoints = [0] + for st1, st2 in zip(iter1, iter2): + self.assertEqual(st1.get_sample_size(), ts.get_sample_size()) + roots = set() + for u in ts.samples(): + root = u + while st1.get_parent(root) != tskit.NULL: + root = st1.get_parent(root) + roots.add(root) + self.assertEqual(sorted(list(roots)), sorted(st1.roots)) + if len(roots) > 1: + with self.assertRaises(ValueError): + st1.root + else: + self.assertEqual(st1.root, list(roots)[0]) + self.assertEqual(st2, st1) + self.assertFalse(st2 != st1) + l, r = st1.get_interval() + breakpoints.append(r) + self.assertAlmostEqual(l, length) + self.assertGreaterEqual(l, 0) + self.assertGreater(r, l) + self.assertLessEqual(r, ts.get_sequence_length()) + length += r - l + self.verify_tree(st1) + num_trees += 1 + self.assertRaises(StopIteration, next, iter1) + self.assertRaises(StopIteration, next, iter2) + 
self.assertEqual(ts.get_num_trees(), num_trees) + self.assertEqual(breakpoints, list(ts.breakpoints())) + self.assertAlmostEqual(length, ts.get_sequence_length()) + + def verify_haplotype_statistics(self, ts): + """ + Verifies the statistics calculated for the haplotypes + in the specified tree sequence. + """ + haplotypes = list(ts.haplotypes()) + pi1 = ts.get_pairwise_diversity() + pi2 = simple_get_pairwise_diversity(haplotypes) + pi3 = get_pairwise_diversity(ts) + self.assertAlmostEqual(pi1, pi2) + self.assertAlmostEqual(pi1, pi3) + self.assertGreaterEqual(pi1, 0.0) + self.assertFalse(math.isnan(pi1)) + # Check for a subsample. + num_samples = ts.get_sample_size() // 2 + 1 + samples = list(ts.samples())[:num_samples] + pi1 = ts.get_pairwise_diversity(samples) + pi2 = simple_get_pairwise_diversity([haplotypes[j] for j in range(num_samples)]) + pi3 = get_pairwise_diversity(ts, samples) + self.assertAlmostEqual(pi1, pi2) + self.assertAlmostEqual(pi1, pi3) + self.assertGreaterEqual(pi1, 0.0) + self.assertFalse(math.isnan(pi1)) + + def verify_mutations(self, ts): + """ + Verify the mutations on this tree sequence make sense. 
+ """ + self.verify_haplotype_statistics(ts) + all_mutations = list(ts.mutations()) + # Mutations must be sorted by position + self.assertEqual( + all_mutations, sorted(all_mutations, key=lambda x: x[0])) + self.assertEqual(len(all_mutations), ts.get_num_mutations()) + all_tree_mutations = [] + j = 0 + for st in ts.trees(): + tree_mutations = list(st.mutations()) + self.assertEqual(st.get_num_mutations(), len(tree_mutations)) + all_tree_mutations.extend(tree_mutations) + for mutation in tree_mutations: + left, right = st.get_interval() + self.assertTrue(left <= mutation.position < right) + self.assertEqual(mutation.index, j) + j += 1 + self.assertEqual(all_tree_mutations, all_mutations) + pts = tests.PythonTreeSequence(ts.get_ll_tree_sequence()) + iter1 = ts.trees() + iter2 = pts.trees() + for st1, st2 in zip(iter1, iter2): + self.assertEqual(st1, st2) + + +class TestVariantGenerator(HighLevelTestCase): + """ + Tests the variants() method to ensure the output is consistent. + """ + def get_tree_sequence(self): + ts = msprime.simulate( + 10, length=10, recombination_rate=1, mutation_rate=10, random_seed=3) + self.assertGreater(ts.get_num_mutations(), 10) + return ts + + def test_as_bytes(self): + ts = self.get_tree_sequence() + n = ts.get_sample_size() + m = ts.get_num_mutations() + A = np.zeros((m, n), dtype='u1') + B = np.zeros((m, n), dtype='u1') + for variant in ts.variants(): + A[variant.index] = variant.genotypes + for variant in ts.variants(as_bytes=True): + self.assertIsInstance(variant.genotypes, bytes) + B[variant.index] = np.fromstring(variant.genotypes, np.uint8) - ord('0') + self.assertTrue(np.all(A == B)) + bytes_variants = list(ts.variants(as_bytes=True)) + for j, variant in enumerate(bytes_variants): + self.assertEqual(j, variant.index) + row = np.fromstring(variant.genotypes, np.uint8) - ord('0') + self.assertTrue(np.all(A[j] == row)) + + def test_as_bytes_fails(self): + ts = tsutil.insert_multichar_mutations(self.get_tree_sequence()) + 
self.assertRaises(ValueError, list, ts.variants(as_bytes=True)) + + def test_multichar_alleles(self): + ts = tsutil.insert_multichar_mutations(self.get_tree_sequence()) + for var in ts.variants(): + self.assertEqual(len(var.alleles), 2) + self.assertEqual(var.site.ancestral_state, var.alleles[0]) + self.assertEqual(var.site.mutations[0].derived_state, var.alleles[1]) + self.assertTrue(all(0 <= var.genotypes)) + self.assertTrue(all(var.genotypes <= 1)) + + def test_many_alleles(self): + ts = self.get_tree_sequence() + tables = ts.dump_tables() + tables.sites.clear() + tables.mutations.clear() + # This gives us a total of 360 permutations. + alleles = list(map("".join, itertools.permutations('ABCDEF', 4))) + self.assertGreater(len(alleles), 255) + tables.sites.add_row(0, alleles[0]) + parent = -1 + num_alleles = 1 + for allele in alleles[1:]: + ts = tables.tree_sequence() + if num_alleles > 255: + self.assertRaises(_tskit.LibraryError, next, ts.variants()) + else: + var = next(ts.variants()) + self.assertEqual(len(var.alleles), num_alleles) + self.assertEqual(list(var.alleles), alleles[:num_alleles]) + self.assertEqual( + var.alleles[var.genotypes[0]], alleles[num_alleles - 1]) + for u in ts.samples(): + if u != 0: + self.assertEqual(var.alleles[var.genotypes[u]], alleles[0]) + tables.mutations.add_row(0, 0, allele, parent=parent) + parent += 1 + num_alleles += 1 + + def test_site_information(self): + ts = self.get_tree_sequence() + for site, variant in zip(ts.sites(), ts.variants()): + self.assertEqual(site.position, variant.position) + self.assertEqual(site, variant.site) + + def test_no_mutations(self): + ts = msprime.simulate(10) + self.assertEqual(ts.get_num_mutations(), 0) + variants = list(ts.variants()) + self.assertEqual(len(variants), 0) + + def test_genotype_matrix(self): + ts = self.get_tree_sequence() + G = np.empty((ts.num_sites, ts.num_samples), dtype=np.uint8) + for v in ts.variants(): + G[v.index, :] = v.genotypes + self.assertTrue(np.array_equal(G, 
ts.genotype_matrix())) + + def test_recurrent_mutations_over_samples(self): + ts = self.get_tree_sequence() + tables = ts.dump_tables() + tables.sites.clear() + tables.mutations.clear() + num_sites = 5 + for j in range(num_sites): + tables.sites.add_row( + position=j * ts.sequence_length / num_sites, + ancestral_state="0") + for u in range(ts.sample_size): + tables.mutations.add_row(site=j, node=u, derived_state="1") + ts = tables.tree_sequence() + variants = list(ts.variants()) + self.assertEqual(len(variants), num_sites) + for site, variant in zip(ts.sites(), variants): + self.assertEqual(site.position, variant.position) + self.assertEqual(site, variant.site) + self.assertEqual(site.id, variant.index) + self.assertEqual(variant.alleles, ("0", "1")) + self.assertTrue(np.all(variant.genotypes == np.ones(ts.sample_size))) + + def test_recurrent_mutations_errors(self): + ts = self.get_tree_sequence() + tree = next(ts.trees()) + tables = ts.dump_tables() + for u in tree.nodes(): + for sample in tree.samples(u): + if sample != u: + tables.sites.clear() + tables.mutations.clear() + site = tables.sites.add_row(position=0, ancestral_state="0") + tables.mutations.add_row(site=site, node=u, derived_state="1") + tables.mutations.add_row(site=site, node=sample, derived_state="1") + ts_new = tables.tree_sequence() + self.assertRaises(_tskit.LibraryError, list, ts_new.variants()) + + def test_zero_samples(self): + ts = self.get_tree_sequence() + for var1, var2 in zip(ts.variants(), ts.variants(samples=[])): + self.assertEqual(var1.site, var2.site) + self.assertEqual(var1.alleles, var2.alleles) + self.assertEqual(var2.genotypes.shape[0], 0) + + def test_samples(self): + n = 4 + ts = msprime.simulate( + n, length=5, recombination_rate=1, mutation_rate=5, random_seed=2) + self.assertGreater(ts.num_sites, 1) + samples = list(range(n)) + # Generate all possible sample lists. 
+ for j in range(n + 1): + for s in itertools.permutations(samples, j): + s = np.array(s, dtype=np.int32) + count = 0 + for var1, var2 in zip(ts.variants(), ts.variants(samples=s)): + self.assertEqual(var1.site, var2.site) + self.assertEqual(var1.alleles, var2.alleles) + self.assertEqual(var2.genotypes.shape, (len(s),)) + self.assertTrue(np.array_equal(var1.genotypes[s], var2.genotypes)) + count += 1 + self.assertEqual(count, ts.num_sites) + + def test_non_sample_samples(self): + # We don't have to use sample nodes. This does make the terminology confusing + # but it's probably still the best option. + ts = msprime.simulate( + 10, length=5, recombination_rate=1, mutation_rate=5, random_seed=2) + tables = ts.dump_tables() + tables.nodes.set_columns( + flags=np.zeros_like(tables.nodes.flags) + tskit.NODE_IS_SAMPLE, + time=tables.nodes.time) + all_samples_ts = tables.tree_sequence() + self.assertEqual(all_samples_ts.num_samples, ts.num_nodes) + + count = 0 + samples = range(ts.num_nodes) + for var1, var2 in zip(all_samples_ts.variants(), ts.variants(samples=samples)): + self.assertEqual(var1.site, var2.site) + self.assertEqual(var1.alleles, var2.alleles) + self.assertEqual(var2.genotypes.shape, (len(samples),)) + self.assertTrue(np.array_equal(var1.genotypes, var2.genotypes)) + count += 1 + self.assertEqual(count, ts.num_sites) + + +class TestHaplotypeGenerator(HighLevelTestCase): + """ + Tests the haplotype generation code. + """ + + def verify_haplotypes(self, n, haplotypes): + """ + Verify that the specified set of haplotypes is consistent. 
+ """ + self.assertEqual(len(haplotypes), n) + m = len(haplotypes[0]) + for h in haplotypes: + self.assertEqual(len(h), m) + # Examine each column in H; we must have a mixture of 0s and 1s + for k in range(m): + zeros = 0 + ones = 0 + col = "" + for j in range(n): + b = haplotypes[j][k] + zeros += b == '0' + ones += b == '1' + col += b + self.assertEqual(zeros + ones, n) + + def verify_tree_sequence(self, tree_sequence): + n = tree_sequence.sample_size + m = tree_sequence.num_sites + haplotypes = list(tree_sequence.haplotypes()) + A = np.zeros((n, m), dtype='u1') + B = np.zeros((n, m), dtype='u1') + for j, h in enumerate(haplotypes): + self.assertEqual(len(h), m) + A[j] = np.fromstring(h, np.uint8) - ord('0') + for variant in tree_sequence.variants(): + B[:, variant.index] = variant.genotypes + self.assertTrue(np.all(A == B)) + self.verify_haplotypes(n, haplotypes) + + def verify_simulation(self, n, m, r, theta): + """ + Verifies a simulation for the specified parameters. + """ + recomb_map = msprime.RecombinationMap.uniform_map(m, r, m) + tree_sequence = msprime.simulate( + n, recombination_map=recomb_map, mutation_rate=theta) + self.verify_tree_sequence(tree_sequence) + + def test_random_parameters(self): + num_random_sims = 10 + for j in range(num_random_sims): + n = random.randint(2, 100) + m = random.randint(10, 1000) + r = random.random() + theta = random.uniform(0, 2) + self.verify_simulation(n, m, r, theta) + + def test_nonbinary_trees(self): + for ts in get_bottleneck_examples(): + self.verify_tree_sequence(ts) + + def test_acgt_mutations(self): + ts = msprime.simulate(10, mutation_rate=10) + self.assertGreater(ts.num_sites, 0) + tables = ts.tables + sites = tables.sites + mutations = tables.mutations + sites.set_columns( + position=sites.position, + ancestral_state=np.zeros(ts.num_sites, dtype=np.int8) + ord("A"), + ancestral_state_offset=np.arange(ts.num_sites + 1, dtype=np.uint32)) + mutations.set_columns( + site=mutations.site, + node=mutations.node, + 
derived_state=np.zeros(ts.num_sites, dtype=np.int8) + ord("T"), + derived_state_offset=np.arange(ts.num_sites + 1, dtype=np.uint32)) + tsp = tables.tree_sequence() + H = [h.replace("0", "A").replace("1", "T") for h in ts.haplotypes()] + self.assertEqual(H, list(tsp.haplotypes())) + + def test_multiletter_mutations(self): + ts = msprime.simulate(10) + tables = ts.tables + tables.sites.add_row(0, "ACTG") + tsp = tables.tree_sequence() + self.assertRaises(_tskit.LibraryError, list, tsp.haplotypes()) + + def test_recurrent_mutations_over_samples(self): + for ts in get_bottleneck_examples(): + num_sites = 5 + tables = ts.dump_tables() + for j in range(num_sites): + tables.sites.add_row( + position=j * ts.sequence_length / num_sites, + ancestral_state="0") + for u in range(ts.sample_size): + tables.mutations.add_row(site=j, node=u, derived_state="1") + ts_new = tables.tree_sequence() + ones = "1" * num_sites + for h in ts_new.haplotypes(): + self.assertEqual(ones, h) + + def test_recurrent_mutations_errors(self): + for ts in get_bottleneck_examples(): + tables = ts.dump_tables() + tree = next(ts.trees()) + for u in tree.children(tree.root): + tables.sites.clear() + tables.mutations.clear() + site = tables.sites.add_row(position=0, ancestral_state="0") + tables.mutations.add_row(site=site, node=u, derived_state="1") + tables.mutations.add_row(site=site, node=tree.root, derived_state="1") + ts_new = tables.tree_sequence() + self.assertRaises(_tskit.LibraryError, list, ts_new.haplotypes()) + ts_new.haplotypes() + + def test_back_mutations(self): + for ts in get_back_mutation_examples(): + self.verify_tree_sequence(ts) + + +class TestNumpySamples(unittest.TestCase): + """ + Tests that we correctly handle samples as numpy arrays when passed to + various methods. 
+ """ + def get_tree_sequence(self, num_demes=4): + n = 40 + return msprime.simulate( + samples=[ + msprime.Sample(time=0, population=j % num_demes) for j in range(n)], + population_configurations=[ + msprime.PopulationConfiguration() for _ in range(num_demes)], + migration_matrix=[ + [int(j != k) for j in range(num_demes)] for k in range(num_demes)], + random_seed=1, + mutation_rate=10) + + def test_samples(self): + d = 4 + ts = self.get_tree_sequence(d) + self.assertTrue(np.array_equal( + ts.samples(), np.arange(ts.num_samples, dtype=np.int32))) + total = 0 + for pop in range(d): + subsample = ts.samples(pop) + total += subsample.shape[0] + self.assertTrue(np.array_equal(subsample, ts.samples(population=pop))) + self.assertEqual( + list(subsample), + [node.id for node in ts.nodes() + if node.population == pop and node.is_sample()]) + self.assertEqual(total, ts.num_samples) + + def test_genotype_matrix_indexing(self): + num_demes = 4 + ts = self.get_tree_sequence(num_demes) + G = ts.genotype_matrix() + for d in range(num_demes): + samples = ts.samples(population=d) + total = 0 + for tree in ts.trees(tracked_samples=samples): + for mutation in tree.mutations(): + total += tree.num_tracked_samples(mutation.node) + self.assertEqual(total, np.sum(G[:, samples])) + + def test_genotype_indexing(self): + num_demes = 6 + ts = self.get_tree_sequence(num_demes) + for d in range(num_demes): + samples = ts.samples(population=d) + total = 0 + for tree in ts.trees(tracked_samples=samples): + for mutation in tree.mutations(): + total += tree.num_tracked_samples(mutation.node) + other_total = 0 + for variant in ts.variants(): + other_total += np.sum(variant.genotypes[samples]) + self.assertEqual(total, other_total) + + def test_pairwise_diversity(self): + num_demes = 6 + ts = self.get_tree_sequence(num_demes) + pi1 = ts.pairwise_diversity(ts.samples()) + pi2 = ts.pairwise_diversity() + self.assertEqual(pi1, pi2) + for d in range(num_demes): + samples = ts.samples(population=d) + 
pi1 = ts.pairwise_diversity(samples) + pi2 = ts.pairwise_diversity(list(samples)) + self.assertEqual(pi1, pi2) + + def test_simplify(self): + num_demes = 3 + ts = self.get_tree_sequence(num_demes) + sts = ts.simplify(samples=ts.samples()) + self.assertEqual(ts.num_samples, sts.num_samples) + for d in range(num_demes): + samples = ts.samples(population=d) + sts = ts.simplify(samples=samples) + self.assertEqual(sts.num_samples, samples.shape[0]) + + +class TestTreeSequence(HighLevelTestCase): + """ + Tests for the tree sequence object. + """ + + def test_trees(self): + for ts in get_example_tree_sequences(): + self.verify_trees(ts) + + def test_mutations(self): + for ts in get_example_tree_sequences(): + self.verify_mutations(ts) + + def verify_edge_diffs(self, ts): + pts = tests.PythonTreeSequence(ts.get_ll_tree_sequence()) + d1 = list(ts.edge_diffs()) + d2 = list(pts.edge_diffs()) + self.assertEqual(d1, d2) + + # check that we have the correct set of children at all nodes. + children = collections.defaultdict(set) + trees = iter(ts.trees()) + tree = next(trees) + last_right = 0 + for (left, right), edges_out, edges_in in ts.edge_diffs(): + assert left == last_right + last_right = right + for edge in edges_out: + children[edge.parent].remove(edge.child) + for edge in edges_in: + children[edge.parent].add(edge.child) + while tree.interval[1] <= left: + tree = next(trees) + # print(left, right, tree.interval) + self.assertTrue(left >= tree.interval[0]) + self.assertTrue(right <= tree.interval[1]) + for u in tree.nodes(): + if tree.is_internal(u): + self.assertIn(u, children) + self.assertEqual(children[u], set(tree.children(u))) + + def test_edge_diffs(self): + for ts in get_example_tree_sequences(): + self.verify_edge_diffs(ts) + + def verify_edgesets(self, ts): + """ + Verifies that the edgesets we return are equivalent to the original edges. 
+ """ + new_edges = [] + for edgeset in ts.edgesets(): + self.assertEqual(edgeset.children, sorted(edgeset.children)) + self.assertGreater(len(edgeset.children), 0) + for child in edgeset.children: + new_edges.append(tskit.Edge( + edgeset.left, edgeset.right, edgeset.parent, child)) + # squash the edges. + t = ts.dump_tables().nodes.time + new_edges.sort(key=lambda e: (t[e.parent], e.parent, e.child, e.left)) + + squashed = [] + last_e = new_edges[0] + for e in new_edges[1:]: + condition = ( + e.parent != last_e.parent or + e.child != last_e.child or + e.left != last_e.right) + if condition: + squashed.append(last_e) + last_e = e + last_e.right = e.right + squashed.append(last_e) + edges = list(ts.edges()) + self.assertEqual(len(squashed), len(edges)) + self.assertEqual(edges, squashed) + + def test_edgesets(self): + for ts in get_example_tree_sequences(): + self.verify_edgesets(ts) + + def verify_coalescence_records(self, ts): + """ + Checks that the coalescence records we output are correct. + """ + edgesets = list(ts.edgesets()) + records = list(ts.records()) + self.assertEqual(len(edgesets), len(records)) + for edgeset, record in zip(edgesets, records): + self.assertEqual(edgeset.left, record.left) + self.assertEqual(edgeset.right, record.right) + self.assertEqual(edgeset.parent, record.node) + self.assertEqual(edgeset.children, record.children) + parent = ts.node(edgeset.parent) + self.assertEqual(parent.time, record.time) + self.assertEqual(parent.population, record.population) + + def test_coalescence_records(self): + for ts in get_example_tree_sequences(): + self.verify_coalescence_records(ts) + + def test_compute_mutation_parent(self): + for ts in get_example_tree_sequences(): + tables = ts.dump_tables() + before = tables.mutations.parent[:] + tables.compute_mutation_parents() + parent = ts.tables.mutations.parent + self.assertTrue(np.array_equal(parent, before)) + + def verify_tracked_samples(self, ts): + # Should be empty list by default. 
+ for tree in ts.trees(): + self.assertEqual(tree.get_num_tracked_samples(), 0) + for u in tree.nodes(): + self.assertEqual(tree.get_num_tracked_samples(u), 0) + samples = list(ts.samples()) + tracked_samples = samples[:2] + for tree in ts.trees(tracked_samples): + if len(tree.parent_dict) == 0: + # This is a crude way of checking if we have multiple roots. + # We'll need to fix this code up properly when we support multiple + # roots and remove this check + break + nu = [0 for j in range(ts.get_num_nodes())] + self.assertEqual(tree.get_num_tracked_samples(), len(tracked_samples)) + for j in tracked_samples: + u = j + while u != tskit.NULL: + nu[u] += 1 + u = tree.get_parent(u) + for u, count in enumerate(nu): + self.assertEqual(tree.get_num_tracked_samples(u), count) + + def test_tracked_samples(self): + for ts in get_example_tree_sequences(): + self.verify_tracked_samples(ts) + + def test_deprecated_sample_aliases(self): + for ts in get_example_tree_sequences(): + # Ensure that we get the same results from the various combinations + # of leaf_lists, sample_lists etc. 
+ samples = list(ts.samples())[:2] + # tracked leaves/samples + trees_new = ts.trees(tracked_samples=samples) + trees_old = ts.trees(tracked_leaves=samples) + for t_new, t_old in zip(trees_new, trees_old): + for u in t_new.nodes(): + self.assertEqual( + t_new.num_tracked_samples(u), t_old.get_num_tracked_leaves(u)) + for on in [True, False]: + # sample/leaf counts + trees_new = ts.trees(sample_counts=on) + trees_old = ts.trees(leaf_counts=on) + for t_new, t_old in zip(trees_new, trees_old): + for u in t_new.nodes(): + self.assertEqual(t_new.num_samples(u), t_old.get_num_leaves(u)) + self.assertEqual( + list(t_new.samples(u)), list(t_old.get_leaves(u))) + trees_new = ts.trees(sample_lists=on) + trees_old = ts.trees(leaf_lists=on) + for t_new, t_old in zip(trees_new, trees_old): + for u in t_new.nodes(): + self.assertEqual(t_new.num_samples(u), t_old.get_num_leaves(u)) + self.assertEqual( + list(t_new.samples(u)), list(t_old.get_leaves(u))) + + def verify_samples(self, ts): + # We should get the same list of samples if we use the low-level + # sample lists or a simple traversal. 
+ samples1 = [] + for t in ts.trees(sample_lists=False): + samples1.append(list(t.samples())) + samples2 = [] + for t in ts.trees(sample_lists=True): + samples2.append(list(t.samples())) + self.assertEqual(samples1, samples2) + + def test_samples(self): + for ts in get_example_tree_sequences(): + self.verify_samples(ts) + pops = set(node.population for node in ts.nodes()) + for pop in pops: + subsample = ts.samples(pop) + self.assertTrue(np.array_equal(subsample, ts.samples(population=pop))) + self.assertTrue(np.array_equal(subsample, ts.samples(population_id=pop))) + self.assertEqual( + list(subsample), + [node.id for node in ts.nodes() + if node.population == pop and node.is_sample()]) + self.assertRaises(ValueError, ts.samples, population=0, population_id=0) + + def test_first(self): + for ts in get_example_tree_sequences(): + t1 = ts.first() + t2 = next(ts.trees()) + self.assertFalse(t1 is t2) + self.assertEqual(t1.parent_dict, t2.parent_dict) + + def test_trees_interface(self): + ts = list(get_example_tree_sequences())[0] + # The defaults should make sense and count samples. 
+        # get_num_tracked_samples
+        for t in ts.trees():
+            self.assertEqual(t.get_num_samples(0), 1)
+            self.assertEqual(t.get_num_tracked_samples(0), 0)
+            self.assertEqual(list(t.samples(0)), [0])
+            self.assertIs(t.tree_sequence, ts)
+
+        for t in ts.trees(sample_counts=False):
+            self.assertEqual(t.get_num_samples(0), 1)
+            self.assertRaises(RuntimeError, t.get_num_tracked_samples, 0)
+            self.assertEqual(list(t.samples(0)), [0])
+
+        for t in ts.trees(sample_counts=True):
+            self.assertEqual(t.get_num_samples(0), 1)
+            self.assertEqual(t.get_num_tracked_samples(0), 0)
+            self.assertEqual(list(t.samples(0)), [0])
+
+        for t in ts.trees(sample_counts=True, tracked_samples=[0]):
+            self.assertEqual(t.get_num_samples(0), 1)
+            self.assertEqual(t.get_num_tracked_samples(0), 1)
+            self.assertEqual(list(t.samples(0)), [0])
+
+        for t in ts.trees(sample_lists=True, sample_counts=True):
+            self.assertEqual(t.get_num_samples(0), 1)
+            self.assertEqual(t.get_num_tracked_samples(0), 0)
+            self.assertEqual(list(t.samples(0)), [0])
+
+        for t in ts.trees(sample_lists=True, sample_counts=False):
+            self.assertEqual(t.get_num_samples(0), 1)
+            self.assertRaises(RuntimeError, t.get_num_tracked_samples, 0)
+            self.assertEqual(list(t.samples(0)), [0])
+
+        # This is a bit weird as we don't seem to actually execute the
+        # method until it is iterated.
+        self.assertRaises(
+            ValueError, list, ts.trees(sample_counts=False, tracked_samples=[0]))
+
+    def test_get_pairwise_diversity(self):
+        for ts in get_example_tree_sequences():
+            n = ts.get_sample_size()
+            self.assertRaises(ValueError, ts.get_pairwise_diversity, [])
+            self.assertRaises(ValueError, ts.get_pairwise_diversity, [1])
+            self.assertRaises(ValueError, ts.get_pairwise_diversity, [1, n])
+            samples = list(ts.samples())
+            if any(len(site.mutations) > 1 for site in ts.sites()):
+                # Multi-mutations are not currently supported when computing pi.
+ self.assertRaises(_tskit.LibraryError, ts.get_pairwise_diversity) + else: + self.assertEqual( + ts.get_pairwise_diversity(), + ts.get_pairwise_diversity(samples)) + self.assertEqual( + ts.get_pairwise_diversity(samples[:2]), + ts.get_pairwise_diversity(reversed(samples[:2]))) + + def test_populations(self): + more_than_zero = False + for ts in get_example_tree_sequences(): + N = ts.num_populations + if N > 0: + more_than_zero = True + pops = list(ts.populations()) + self.assertEqual(len(pops), N) + for j in range(N): + self.assertEqual(pops[j], ts.population(j)) + self.assertEqual(pops[j].id, j) + self.assertTrue(isinstance(pops[j].metadata, bytes)) + self.assertTrue(more_than_zero) + + def test_individuals(self): + more_than_zero = False + mapped_to_nodes = False + for ts in get_example_tree_sequences(): + ind_node_map = collections.defaultdict(list) + for node in ts.nodes(): + if node.individual != tskit.NULL: + ind_node_map[node.individual].append(node.id) + if len(ind_node_map) > 0: + mapped_to_nodes = True + N = ts.num_individuals + if N > 0: + more_than_zero = True + inds = list(ts.individuals()) + self.assertEqual(len(inds), N) + for j in range(N): + self.assertEqual(inds[j], ts.individual(j)) + self.assertEqual(inds[j].id, j) + self.assertTrue(isinstance(inds[j].metadata, bytes)) + self.assertTrue(isinstance(inds[j].location, np.ndarray)) + self.assertTrue(isinstance(inds[j].nodes, np.ndarray)) + self.assertEqual(ind_node_map[j], list(inds[j].nodes)) + + self.assertTrue(more_than_zero) + self.assertTrue(mapped_to_nodes) + + def test_get_population(self): + # Deprecated interface for ts.node(id).population + for ts in get_example_tree_sequences(): + N = ts.get_num_nodes() + self.assertRaises(ValueError, ts.get_population, -1) + self.assertRaises(ValueError, ts.get_population, N) + self.assertRaises(ValueError, ts.get_population, N + 1) + for node in [0, N - 1]: + self.assertEqual(ts.get_population(node), ts.node(node).population) + + def 
test_get_time(self): + # Deprecated interface for ts.node(id).time + for ts in get_example_tree_sequences(): + N = ts.get_num_nodes() + self.assertRaises(ValueError, ts.get_time, -1) + self.assertRaises(ValueError, ts.get_time, N) + self.assertRaises(ValueError, ts.get_time, N + 1) + for u in range(N): + self.assertEqual(ts.get_time(u), ts.node(u).time) + + def test_write_vcf_interface(self): + for ts in get_example_tree_sequences(): + n = ts.get_sample_size() + for bad_ploidy in [-1, 0, n + 1]: + self.assertRaises(ValueError, ts.write_vcf, self.temp_file, bad_ploidy) + + def verify_simplify_provenance(self, ts): + new_ts = ts.simplify() + self.assertEqual(new_ts.num_provenances, ts.num_provenances + 1) + old = list(ts.provenances()) + new = list(new_ts.provenances()) + self.assertEqual(old, new[:-1]) + # TODO call verify_provenance on this. + self.assertGreater(len(new[-1].timestamp), 0) + self.assertGreater(len(new[-1].record), 0) + + new_ts = ts.simplify(record_provenance=False) + self.assertEqual(new_ts.tables.provenances, ts.tables.provenances) + + def verify_simplify_topology(self, ts, sample): + new_ts, node_map = ts.simplify(sample, map_nodes=True) + if len(sample) == 0: + self.assertEqual(new_ts.num_nodes, 0) + self.assertEqual(new_ts.num_edges, 0) + self.assertEqual(new_ts.num_sites, 0) + self.assertEqual(new_ts.num_mutations, 0) + elif len(sample) == 1: + self.assertEqual(new_ts.num_nodes, 1) + self.assertEqual(new_ts.num_edges, 0) + # The output samples should be 0...n + self.assertEqual(new_ts.num_samples, len(sample)) + self.assertEqual(list(range(len(sample))), list(new_ts.samples())) + for j in range(new_ts.num_samples): + self.assertEqual(node_map[sample[j]], j) + for u in range(ts.num_nodes): + old_node = ts.node(u) + if node_map[u] != tskit.NULL: + new_node = new_ts.node(node_map[u]) + self.assertEqual(old_node.time, new_node.time) + self.assertEqual(old_node.population, new_node.population) + self.assertEqual(old_node.metadata, 
new_node.metadata) + for u in sample: + old_node = ts.node(u) + new_node = new_ts.node(node_map[u]) + self.assertEqual(old_node.flags, new_node.flags) + self.assertEqual(old_node.time, new_node.time) + self.assertEqual(old_node.population, new_node.population) + self.assertEqual(old_node.metadata, new_node.metadata) + old_trees = ts.trees() + old_tree = next(old_trees) + self.assertGreaterEqual(ts.get_num_trees(), new_ts.get_num_trees()) + for new_tree in new_ts.trees(): + new_left, new_right = new_tree.get_interval() + old_left, old_right = old_tree.get_interval() + # Skip ahead on the old tree until new_left is within its interval + while old_right <= new_left: + old_tree = next(old_trees) + old_left, old_right = old_tree.get_interval() + # If the MRCA of all pairs of samples is the same, then we have the + # same information. We limit this to at most 500 pairs + pairs = itertools.islice(itertools.combinations(sample, 2), 500) + for pair in pairs: + mapped_pair = [node_map[u] for u in pair] + mrca1 = old_tree.get_mrca(*pair) + mrca2 = new_tree.get_mrca(*mapped_pair) + if mrca1 == tskit.NULL: + self.assertEqual(mrca2, mrca1) + else: + self.assertEqual(mrca2, node_map[mrca1]) + self.assertEqual(old_tree.get_time(mrca1), new_tree.get_time(mrca2)) + self.assertEqual( + old_tree.get_population(mrca1), new_tree.get_population(mrca2)) + + def verify_simplify_equality(self, ts, sample): + for filter_sites in [False, True]: + s1, node_map1 = ts.simplify( + sample, map_nodes=True, filter_sites=filter_sites) + t1 = s1.dump_tables() + s2, node_map2 = simplify_tree_sequence(ts, sample, filter_sites=filter_sites) + t2 = s2.dump_tables() + self.assertEqual(s1.num_samples, len(sample)) + self.assertEqual(s2.num_samples, len(sample)) + self.assertTrue(all(node_map1 == node_map2)) + self.assertEqual(t1.individuals, t2.individuals) + self.assertEqual(t1.nodes, t2.nodes) + self.assertEqual(t1.edges, t2.edges) + self.assertEqual(t1.migrations, t2.migrations) + 
self.assertEqual(t1.sites, t2.sites) + self.assertEqual(t1.mutations, t2.mutations) + self.assertEqual(t1.populations, t2.populations) + + def verify_simplify_variants(self, ts, sample): + subset = ts.simplify(sample) + sample_map = {u: j for j, u in enumerate(ts.samples())} + # Need to map IDs back to their sample indexes + s = np.array([sample_map[u] for u in sample]) + # Build a map of genotypes by position + full_genotypes = {} + for variant in ts.variants(): + alleles = [variant.alleles[g] for g in variant.genotypes] + full_genotypes[variant.position] = alleles + for variant in subset.variants(): + if variant.position in full_genotypes: + a1 = [full_genotypes[variant.position][u] for u in s] + a2 = [variant.alleles[g] for g in variant.genotypes] + self.assertEqual(a1, a2) + + def test_simplify(self): + num_mutations = 0 + for ts in get_example_tree_sequences(): + self.verify_simplify_provenance(ts) + n = ts.get_sample_size() + num_mutations += ts.get_num_mutations() + sample_sizes = {0, 1} + if n > 2: + sample_sizes |= set([2, max(2, n // 2), n - 1]) + for k in sample_sizes: + subset = random.sample(list(ts.samples()), k) + self.verify_simplify_topology(ts, subset) + self.verify_simplify_equality(ts, subset) + self.verify_simplify_variants(ts, subset) + self.assertGreater(num_mutations, 0) + + def test_simplify_bugs(self): + prefix = "tests/data/simplify-bugs/" + j = 1 + while True: + nodes_file = os.path.join(prefix, "{:02d}-nodes.txt".format(j)) + if not os.path.exists(nodes_file): + break + edges_file = os.path.join(prefix, "{:02d}-edges.txt".format(j)) + sites_file = os.path.join(prefix, "{:02d}-sites.txt".format(j)) + mutations_file = os.path.join(prefix, "{:02d}-mutations.txt".format(j)) + with open(nodes_file) as nodes, \ + open(edges_file) as edges,\ + open(sites_file) as sites,\ + open(mutations_file) as mutations: + ts = tskit.load_text( + nodes=nodes, edges=edges, sites=sites, mutations=mutations, + strict=False) + samples = list(ts.samples()) + 
self.verify_simplify_equality(ts, samples) + j += 1 + self.assertGreater(j, 1) + + def test_simplify_migrations_fails(self): + ts = msprime.simulate( + population_configurations=[ + msprime.PopulationConfiguration(10), + msprime.PopulationConfiguration(10)], + migration_matrix=[[0, 1], [1, 0]], + random_seed=2, + record_migrations=True) + self.assertGreater(ts.num_migrations, 0) + # We don't support simplify with migrations, so should fail. + with self.assertRaises(_tskit.LibraryError): + ts.simplify() + + def test_deprecated_apis(self): + ts = msprime.simulate(10, random_seed=1) + self.assertEqual(ts.get_ll_tree_sequence(), ts.ll_tree_sequence) + self.assertEqual(ts.get_sample_size(), ts.sample_size) + self.assertEqual(ts.get_sample_size(), ts.num_samples) + self.assertEqual(ts.get_sequence_length(), ts.sequence_length) + self.assertEqual(ts.get_num_trees(), ts.num_trees) + self.assertEqual(ts.get_num_mutations(), ts.num_mutations) + self.assertEqual(ts.get_num_nodes(), ts.num_nodes) + self.assertEqual(ts.get_pairwise_diversity(), ts.pairwise_diversity()) + samples = ts.samples() + self.assertEqual( + ts.get_pairwise_diversity(samples), ts.pairwise_diversity(samples)) + self.assertTrue(np.array_equal(ts.get_samples(), ts.samples())) + + def test_sites(self): + some_sites = False + for ts in get_example_tree_sequences(): + tables = ts.dump_tables() + sites = tables.sites + mutations = tables.mutations + self.assertEqual(ts.num_sites, len(sites)) + self.assertEqual(ts.num_mutations, len(mutations)) + previous_pos = -1 + mutation_index = 0 + ancestral_state = tskit.unpack_strings( + sites.ancestral_state, sites.ancestral_state_offset) + derived_state = tskit.unpack_strings( + mutations.derived_state, mutations.derived_state_offset) + + for index, site in enumerate(ts.sites()): + s2 = ts.site(site.id) + self.assertEqual(s2, site) + self.assertEqual(site.position, sites.position[index]) + self.assertGreater(site.position, previous_pos) + previous_pos = site.position + 
self.assertEqual(ancestral_state[index], site.ancestral_state) + self.assertEqual(site.id, index) + for mutation in site.mutations: + m2 = ts.mutation(mutation.id) + self.assertEqual(m2, mutation) + self.assertEqual(mutation.site, site.id) + self.assertEqual(mutation.site, mutations.site[mutation_index]) + self.assertEqual(mutation.node, mutations.node[mutation_index]) + self.assertEqual(mutation.parent, mutations.parent[mutation_index]) + self.assertEqual(mutation.id, mutation_index) + self.assertEqual( + derived_state[mutation_index], mutation.derived_state) + mutation_index += 1 + some_sites = True + total_sites = 0 + for tree in ts.trees(): + self.assertEqual(len(list(tree.sites())), tree.num_sites) + total_sites += tree.num_sites + self.assertEqual(ts.num_sites, total_sites) + self.assertEqual(mutation_index, len(mutations)) + self.assertTrue(some_sites) + + def verify_mutations(self, ts): + other_mutations = [] + for site in ts.sites(): + for mutation in site.mutations: + other_mutations.append(mutation) + mutations = list(ts.mutations()) + self.assertEqual(ts.num_mutations, len(other_mutations)) + self.assertEqual(ts.num_mutations, len(mutations)) + for mut, other_mut in zip(mutations, other_mutations): + # We cannot compare these directly as the mutations obtained + # from the mutations iterator will have extra deprecated + # attributes. + self.assertEqual(mut.id, other_mut.id) + self.assertEqual(mut.site, other_mut.site) + self.assertEqual(mut.parent, other_mut.parent) + self.assertEqual(mut.node, other_mut.node) + self.assertEqual(mut.metadata, other_mut.metadata) + # Check the deprecated attrs. + self.assertEqual(mut.position, ts.site(mut.site).position) + self.assertEqual(mut.index, mut.site) + + def test_sites_mutations(self): + # Check that the mutations iterator returns the correct values. 
+ for ts in get_example_tree_sequences(): + self.verify_mutations(ts) + + def test_removed_methods(self): + ts = next(get_example_tree_sequences()) + self.assertRaises(NotImplementedError, ts.get_num_records) + self.assertRaises(NotImplementedError, ts.diffs) + self.assertRaises(NotImplementedError, ts.newick_trees) + + def test_zlib_compression_warning(self): + ts = msprime.simulate(5, random_seed=1) + with warnings.catch_warnings(record=True) as w: + ts.dump(self.temp_file, zlib_compression=True) + self.assertEqual(len(w), 1) + self.assertTrue(issubclass(w[0].category, RuntimeWarning)) + with warnings.catch_warnings(record=True) as w: + ts.dump(self.temp_file, zlib_compression=False) + self.assertEqual(len(w), 0) + + def test_tables_sequence_length_round_trip(self): + for sequence_length in [0.1, 1, 10, 100]: + ts = msprime.simulate(5, length=sequence_length, random_seed=1) + self.assertEqual(ts.sequence_length, sequence_length) + tables = ts.tables + self.assertEqual(tables.sequence_length, sequence_length) + new_ts = tables.tree_sequence() + self.assertEqual(new_ts.sequence_length, sequence_length) + + +class TestFileUuid(HighLevelTestCase): + """ + Tests that the file UUID attribute is handled correctly. + """ + def validate(self, ts): + self.assertIsNone(ts.file_uuid) + ts.dump(self.temp_file) + other_ts = tskit.load(self.temp_file) + self.assertIsNotNone(other_ts.file_uuid) + self.assertTrue(len(other_ts.file_uuid), 36) + uuid = other_ts.file_uuid + other_ts = tskit.load(self.temp_file) + self.assertEqual(other_ts.file_uuid, uuid) + self.assertEqual(ts.tables, other_ts.tables) + + # Check that the UUID is well-formed. + parsed = _uuid.UUID("{" + uuid + "}") + self.assertEqual(str(parsed), uuid) + + # Save the same tree sequence to the file. We should get a different UUID. 
+        ts.dump(self.temp_file)
+        other_ts = tskit.load(self.temp_file)
+        self.assertIsNotNone(other_ts.file_uuid)
+        self.assertNotEqual(other_ts.file_uuid, uuid)
+
+        # Even saving a ts that has a UUID to another file changes the UUID
+        old_uuid = other_ts.file_uuid
+        other_ts.dump(self.temp_file)
+        self.assertEqual(other_ts.file_uuid, old_uuid)
+        other_ts = tskit.load(self.temp_file)
+        self.assertIsNotNone(other_ts.file_uuid)
+        self.assertNotEqual(other_ts.file_uuid, old_uuid)
+
+        # Tables dumped from this ts are a deep copy, so they don't have
+        # the file_uuid.
+        tables = other_ts.dump_tables()
+        self.assertIsNone(tables.file_uuid)
+
+        # For now, ts.tables also returns a deep copy. This will hopefully
+        # change in the future though.
+        self.assertIsNone(ts.tables.file_uuid)
+
+    def test_simple_simulation(self):
+        ts = msprime.simulate(2, random_seed=1)
+        self.validate(ts)
+
+    def test_empty_tables(self):
+        tables = tskit.TableCollection(1)
+        self.validate(tables.tree_sequence())
+
+
+class TestTreeSequenceTextIO(HighLevelTestCase):
+    """
+    Tests for the tree sequence text IO.
+    """
+
+    def verify_nodes_format(self, ts, nodes_file, precision):
+        """
+        Verifies that the nodes we output have the correct form.
+ """ + def convert(v): + return "{:.{}f}".format(v, precision) + output_nodes = nodes_file.read().splitlines() + self.assertEqual(len(output_nodes) - 1, ts.num_nodes) + self.assertEqual( + list(output_nodes[0].split()), + ["id", "is_sample", "time", "population", "individual", "metadata"]) + for node, line in zip(ts.nodes(), output_nodes[1:]): + splits = line.split("\t") + self.assertEqual(str(node.id), splits[0]) + self.assertEqual(str(node.is_sample()), splits[1]) + self.assertEqual(convert(node.time), splits[2]) + self.assertEqual(str(node.population), splits[3]) + self.assertEqual(str(node.individual), splits[4]) + self.assertEqual(tests.base64_encode(node.metadata), splits[5]) + + def verify_edges_format(self, ts, edges_file, precision): + """ + Verifies that the edges we output have the correct form. + """ + def convert(v): + return "{:.{}f}".format(v, precision) + output_edges = edges_file.read().splitlines() + self.assertEqual(len(output_edges) - 1, ts.num_edges) + self.assertEqual( + list(output_edges[0].split()), + ["left", "right", "parent", "child"]) + for edge, line in zip(ts.edges(), output_edges[1:]): + splits = line.split("\t") + self.assertEqual(convert(edge.left), splits[0]) + self.assertEqual(convert(edge.right), splits[1]) + self.assertEqual(str(edge.parent), splits[2]) + self.assertEqual(str(edge.child), splits[3]) + + def verify_sites_format(self, ts, sites_file, precision): + """ + Verifies that the sites we output have the correct form. 
+ """ + def convert(v): + return "{:.{}f}".format(v, precision) + output_sites = sites_file.read().splitlines() + self.assertEqual(len(output_sites) - 1, ts.num_sites) + self.assertEqual( + list(output_sites[0].split()), + ["position", "ancestral_state", "metadata"]) + for site, line in zip(ts.sites(), output_sites[1:]): + splits = line.split("\t") + self.assertEqual(convert(site.position), splits[0]) + self.assertEqual(site.ancestral_state, splits[1]) + self.assertEqual(tests.base64_encode(site.metadata), splits[2]) + + def verify_mutations_format(self, ts, mutations_file, precision): + """ + Verifies that the mutations we output have the correct form. + """ + def convert(v): + return "{:.{}f}".format(v, precision) + output_mutations = mutations_file.read().splitlines() + self.assertEqual(len(output_mutations) - 1, ts.num_mutations) + self.assertEqual( + list(output_mutations[0].split()), + ["site", "node", "derived_state", "parent", "metadata"]) + mutations = [mut for site in ts.sites() for mut in site.mutations] + for mutation, line in zip(mutations, output_mutations[1:]): + splits = line.split("\t") + self.assertEqual(str(mutation.site), splits[0]) + self.assertEqual(str(mutation.node), splits[1]) + self.assertEqual(str(mutation.derived_state), splits[2]) + self.assertEqual(str(mutation.parent), splits[3]) + self.assertEqual(tests.base64_encode(mutation.metadata), splits[4]) + + def test_output_format(self): + for ts in get_example_tree_sequences(): + for precision in [2, 7]: + nodes_file = six.StringIO() + edges_file = six.StringIO() + sites_file = six.StringIO() + mutations_file = six.StringIO() + ts.dump_text( + nodes=nodes_file, edges=edges_file, sites=sites_file, + mutations=mutations_file, precision=precision) + nodes_file.seek(0) + edges_file.seek(0) + sites_file.seek(0) + mutations_file.seek(0) + self.verify_nodes_format(ts, nodes_file, precision) + self.verify_edges_format(ts, edges_file, precision) + self.verify_sites_format(ts, sites_file, precision) 
+ self.verify_mutations_format(ts, mutations_file, precision) + + def verify_approximate_equality(self, ts1, ts2): + """ + Verifies that the specified tree sequences are approximately + equal, taking into account the error incurred in exporting to text. + """ + self.assertEqual(ts1.sample_size, ts2.sample_size) + self.assertAlmostEqual(ts1.sequence_length, ts2.sequence_length) + self.assertEqual(ts1.num_nodes, ts2.num_nodes) + self.assertEqual(ts1.num_edges, ts2.num_edges) + self.assertEqual(ts1.num_sites, ts2.num_sites) + self.assertEqual(ts1.num_mutations, ts2.num_mutations) + + checked = 0 + for n1, n2 in zip(ts1.nodes(), ts2.nodes()): + self.assertEqual(n1.population, n2.population) + self.assertEqual(n1.metadata, n2.metadata) + self.assertAlmostEqual(n1.time, n2.time) + checked += 1 + self.assertEqual(checked, ts1.num_nodes) + + checked = 0 + for r1, r2 in zip(ts1.edges(), ts2.edges()): + checked += 1 + self.assertAlmostEqual(r1.left, r2.left) + self.assertAlmostEqual(r1.right, r2.right) + self.assertEqual(r1.parent, r2.parent) + self.assertEqual(r1.child, r2.child) + self.assertEqual(ts1.num_edges, checked) + + checked = 0 + for s1, s2 in zip(ts1.sites(), ts2.sites()): + checked += 1 + self.assertAlmostEqual(s1.position, s2.position) + self.assertAlmostEqual(s1.ancestral_state, s2.ancestral_state) + self.assertEqual(s1.metadata, s2.metadata) + self.assertEqual(s1.mutations, s2.mutations) + self.assertEqual(ts1.num_sites, checked) + + # Check the trees + check = 0 + for t1, t2 in zip(ts1.trees(), ts2.trees()): + self.assertEqual(list(t1.nodes()), list(t2.nodes())) + check += 1 + self.assertEqual(check, ts1.get_num_trees()) + + def test_text_record_round_trip(self): + for ts1 in get_example_tree_sequences(): + nodes_file = six.StringIO() + edges_file = six.StringIO() + sites_file = six.StringIO() + mutations_file = six.StringIO() + individuals_file = six.StringIO() + populations_file = six.StringIO() + ts1.dump_text( + nodes=nodes_file, edges=edges_file, 
sites=sites_file, + mutations=mutations_file, individuals=individuals_file, + populations=populations_file, precision=16) + nodes_file.seek(0) + edges_file.seek(0) + sites_file.seek(0) + mutations_file.seek(0) + individuals_file.seek(0) + populations_file.seek(0) + ts2 = tskit.load_text( + nodes=nodes_file, edges=edges_file, sites=sites_file, + mutations=mutations_file, individuals=individuals_file, + populations=populations_file, + sequence_length=ts1.sequence_length, + strict=True) + self.verify_approximate_equality(ts1, ts2) + + def test_empty_files(self): + nodes_file = six.StringIO("is_sample\ttime\n") + edges_file = six.StringIO("left\tright\tparent\tchild\n") + sites_file = six.StringIO("position\tancestral_state\n") + mutations_file = six.StringIO("site\tnode\tderived_state\n") + self.assertRaises( + _tskit.LibraryError, tskit.load_text, + nodes=nodes_file, edges=edges_file, sites=sites_file, + mutations=mutations_file) + + def test_empty_files_sequence_length(self): + nodes_file = six.StringIO("is_sample\ttime\n") + edges_file = six.StringIO("left\tright\tparent\tchild\n") + sites_file = six.StringIO("position\tancestral_state\n") + mutations_file = six.StringIO("site\tnode\tderived_state\n") + ts = tskit.load_text( + nodes=nodes_file, edges=edges_file, sites=sites_file, + mutations=mutations_file, sequence_length=100) + self.assertEqual(ts.sequence_length, 100) + self.assertEqual(ts.num_nodes, 0) + self.assertEqual(ts.num_edges, 0) + self.assertEqual(ts.num_sites, 0) + self.assertEqual(ts.num_edges, 0) + + +class TestTree(HighLevelTestCase): + """ + Some simple tests on the API for the sparse tree. 
+ """ + def get_tree(self, sample_lists=False): + ts = msprime.simulate(10, random_seed=1, mutation_rate=1) + return next(ts.trees(sample_lists=sample_lists)) + + def verify_mutations(self, tree): + self.assertGreater(tree.num_mutations, 0) + other_mutations = [] + for site in tree.sites(): + for mutation in site.mutations: + other_mutations.append(mutation) + mutations = list(tree.mutations()) + self.assertEqual(tree.num_mutations, len(other_mutations)) + self.assertEqual(tree.num_mutations, len(mutations)) + for mut, other_mut in zip(mutations, other_mutations): + # We cannot compare these directly as the mutations obtained + # from the mutations iterator will have extra deprecated + # attributes. + self.assertEqual(mut.id, other_mut.id) + self.assertEqual(mut.site, other_mut.site) + self.assertEqual(mut.parent, other_mut.parent) + self.assertEqual(mut.node, other_mut.node) + self.assertEqual(mut.metadata, other_mut.metadata) + # Check the deprecated attrs. + self.assertEqual(mut.position, tree.tree_sequence.site(mut.site).position) + self.assertEqual(mut.index, mut.site) + + def test_simple_mutations(self): + tree = self.get_tree() + self.verify_mutations(tree) + + def test_complex_mutations(self): + ts = tsutil.insert_branch_mutations(msprime.simulate(10, random_seed=1)) + self.verify_mutations(ts.first()) + + def test_str(self): + t = self.get_tree() + self.assertIsInstance(str(t), str) + self.assertEqual(str(t), str(t.get_parent_dict())) + + def test_samples(self): + for sample_lists in [True, False]: + t = self.get_tree(sample_lists) + n = t.get_sample_size() + all_samples = list(t.samples(t.get_root())) + self.assertEqual(sorted(all_samples), list(range(n))) + for j in range(n): + self.assertEqual(list(t.samples(j)), [j]) + + def test_func(t, u): + """ + Simple test definition of the traversal. 
+ """ + stack = [u] + while len(stack) > 0: + v = stack.pop() + if t.is_sample(v): + yield v + if t.is_internal(v): + for c in reversed(t.get_children(v)): + stack.append(c) + for u in t.nodes(): + l1 = list(t.samples(u)) + l2 = list(test_func(t, u)) + self.assertEqual(l1, l2) + self.assertEqual(t.get_num_samples(u), len(l1)) + + def verify_newick(self, tree): + """ + Verifies that we output the newick tree as expected. + """ + # TODO to make this work we may need to clamp the precision of node + # times because Python and C float printing algorithms work slightly + # differently. Seems to work OK now, so leaving alone. + if tree.num_roots == 1: + py_tree = tests.PythonTree.from_tree(tree) + newick1 = tree.newick(precision=16) + newick2 = py_tree.newick() + self.assertEqual(newick1, newick2) + + # Make sure we get the same results for a leaf root. + newick1 = tree.newick(root=0, precision=16) + newick2 = py_tree.newick(root=0) + self.assertEqual(newick1, newick2) + + # When we specify the node_labels we should get precisely the + # same result as we are using Python code now. 
+ for precision in [0, 3, 19]: + newick1 = tree.newick(precision=precision, node_labels={}) + newick2 = py_tree.newick(precision=precision, node_labels={}) + self.assertEqual(newick1, newick2) + else: + self.assertRaises(ValueError, tree.newick) + for root in tree.roots: + py_tree = tests.PythonTree.from_tree(tree) + newick1 = tree.newick(precision=16, root=root) + newick2 = py_tree.newick(root=root) + self.assertEqual(newick1, newick2) + + def test_newick(self): + for ts in get_example_tree_sequences(): + for tree in ts.trees(): + self.verify_newick(tree) + + def test_traversals(self): + for ts in get_example_tree_sequences(): + tree = next(ts.trees()) + self.verify_traversals(tree) + + def verify_traversals(self, tree): + t1 = tree + t2 = tests.PythonTree.from_tree(t1) + self.assertEqual(list(t1.nodes()), list(t2.nodes())) + orders = ["inorder", "postorder", "levelorder", "breadthfirst"] + if tree.num_roots == 1: + self.assertRaises(ValueError, list, t1.nodes(order="bad order")) + self.assertEqual(list(t1.nodes()), list(t1.nodes(t1.get_root()))) + self.assertEqual( + list(t1.nodes()), + list(t1.nodes(t1.get_root(), "preorder"))) + for u in t1.nodes(): + self.assertEqual(list(t1.nodes(u)), list(t2.nodes(u))) + for test_order in orders: + self.assertEqual( + sorted(list(t1.nodes())), + sorted(list(t1.nodes(order=test_order)))) + self.assertEqual( + list(t1.nodes(order=test_order)), + list(t1.nodes(t1.get_root(), order=test_order))) + self.assertEqual( + list(t1.nodes(order=test_order)), + list(t1.nodes(t1.get_root(), test_order))) + self.assertEqual( + list(t1.nodes(order=test_order)), + list(t2.nodes(order=test_order))) + for u in t1.nodes(): + self.assertEqual( + list(t1.nodes(u, test_order)), + list(t2.nodes(u, test_order))) + else: + for test_order in orders: + all_nodes = [] + for root in t1.roots: + self.assertEqual( + list(t1.nodes(root, order=test_order)), + list(t2.nodes(root, order=test_order))) + all_nodes.extend(t1.nodes(root, order=test_order)) + 
self.assertEqual(all_nodes, list(t1.nodes(order=test_order))) + + def test_total_branch_length(self): + t1 = self.get_tree() + bl = 0 + root = t1.get_root() + for node in t1.nodes(): + if node != root: + bl += t1.get_branch_length(node) + self.assertGreater(bl, 0) + self.assertEqual(t1.get_total_branch_length(), bl) + + def test_apis(self): + # tree properties + t1 = self.get_tree() + self.assertEqual(t1.get_root(), t1.root) + self.assertEqual(t1.get_index(), t1.index) + self.assertEqual(t1.get_interval(), t1.interval) + self.assertEqual(t1.get_length(), t1.length) + self.assertEqual(t1.get_sample_size(), t1.sample_size) + self.assertEqual(t1.get_num_mutations(), t1.num_mutations) + self.assertEqual(t1.get_parent_dict(), t1.parent_dict) + self.assertEqual(t1.get_total_branch_length(), t1.total_branch_length) + # node properties + root = t1.get_root() + for node in t1.nodes(): + if node != root: + self.assertEqual(t1.get_time(node), t1.time(node)) + self.assertEqual(t1.get_parent(node), t1.parent(node)) + self.assertEqual(t1.get_children(node), t1.children(node)) + self.assertEqual(t1.get_population(node), t1.population(node)) + self.assertEqual(t1.get_num_samples(node), t1.num_samples(node)) + self.assertEqual(t1.get_branch_length(node), + t1.branch_length(node)) + self.assertEqual(t1.get_num_tracked_samples(node), + t1.num_tracked_samples(node)) + + pairs = itertools.islice(itertools.combinations(t1.nodes(), 2), 50) + for pair in pairs: + self.assertEqual(t1.get_mrca(*pair), t1.mrca(*pair)) + self.assertEqual(t1.get_tmrca(*pair), t1.tmrca(*pair)) + + +class TestNodeOrdering(HighLevelTestCase): + """ + Verify that we can use any node ordering for internal nodes + and get the same topologies. 
+ """ + num_random_permutations = 10 + + def verify_tree_sequences_equal(self, ts1, ts2, approx=False): + self.assertEqual(ts1.get_num_trees(), ts2.get_num_trees()) + self.assertEqual(ts1.get_sample_size(), ts2.get_sample_size()) + self.assertEqual(ts1.get_num_nodes(), ts2.get_num_nodes()) + j = 0 + for r1, r2 in zip(ts1.edges(), ts2.edges()): + self.assertEqual(r1.parent, r2.parent) + self.assertEqual(r1.child, r2.child) + if approx: + self.assertAlmostEqual(r1.left, r2.left) + self.assertAlmostEqual(r1.right, r2.right) + else: + self.assertEqual(r1.left, r2.left) + self.assertEqual(r1.right, r2.right) + j += 1 + self.assertEqual(ts1.num_edges, j) + j = 0 + for n1, n2 in zip(ts1.nodes(), ts2.nodes()): + self.assertEqual(n1.metadata, n2.metadata) + self.assertEqual(n1.population, n2.population) + if approx: + self.assertAlmostEqual(n1.time, n2.time) + else: + self.assertEqual(n1.time, n2.time) + j += 1 + self.assertEqual(ts1.num_nodes, j) + + def verify_random_permutation(self, ts): + n = ts.sample_size + node_map = {} + for j in range(n): + node_map[j] = j + internal_nodes = list(range(n, ts.num_nodes)) + random.shuffle(internal_nodes) + for j, node in enumerate(internal_nodes): + node_map[n + j] = node + other_tables = tskit.TableCollection(ts.sequence_length) + # Insert the new nodes into the table. 
+ inv_node_map = {v: k for k, v in node_map.items()} + for j in range(ts.num_nodes): + node = ts.node(inv_node_map[j]) + other_tables.nodes.add_row( + flags=node.flags, time=node.time, population=node.population) + for e in ts.edges(): + other_tables.edges.add_row( + left=e.left, right=e.right, parent=node_map[e.parent], + child=node_map[e.child]) + for _ in range(ts.num_populations): + other_tables.populations.add_row() + other_tables.sort() + other_ts = other_tables.tree_sequence() + + self.assertEqual(ts.get_num_trees(), other_ts.get_num_trees()) + self.assertEqual(ts.get_sample_size(), other_ts.get_sample_size()) + self.assertEqual(ts.get_num_nodes(), other_ts.get_num_nodes()) + j = 0 + for t1, t2 in zip(ts.trees(), other_ts.trees()): + # Verify the topologies are identical. We do this by traversing + # upwards to the root for every sample and checking if we map to + # the correct node and time. + for u in range(n): + v_orig = u + v_map = u + while v_orig != tskit.NULL: + self.assertEqual(node_map[v_orig], v_map) + self.assertEqual( + t1.get_time(v_orig), + t2.get_time(v_map)) + v_orig = t1.get_parent(v_orig) + v_map = t2.get_parent(v_map) + self.assertEqual(v_orig, tskit.NULL) + self.assertEqual(v_map, tskit.NULL) + j += 1 + self.assertEqual(j, ts.get_num_trees()) + # Verify we can dump this new tree sequence OK. + other_ts.dump(self.temp_file) + ts3 = tskit.load(self.temp_file) + self.verify_tree_sequences_equal(other_ts, ts3) + nodes_file = six.StringIO() + edges_file = six.StringIO() + # Also verify we can read the text version. 
+ other_ts.dump_text(nodes=nodes_file, edges=edges_file, precision=14) + nodes_file.seek(0) + edges_file.seek(0) + ts3 = tskit.load_text(nodes_file, edges_file) + self.verify_tree_sequences_equal(other_ts, ts3, True) + + def test_single_locus(self): + ts = msprime.simulate(7) + for _ in range(self.num_random_permutations): + self.verify_random_permutation(ts) + + def test_multi_locus(self): + ts = msprime.simulate(20, recombination_rate=10) + for _ in range(self.num_random_permutations): + self.verify_random_permutation(ts) + + def test_nonbinary(self): + ts = msprime.simulate( + sample_size=20, recombination_rate=10, + demographic_events=[ + msprime.SimpleBottleneck(time=0.5, population=0, proportion=1)]) + # Make sure this really has some non-binary nodes + found = False + for t in ts.trees(): + for u in t.nodes(): + if len(t.children(u)) > 2: + found = True + break + if found: + break + self.assertTrue(found) + for _ in range(self.num_random_permutations): + self.verify_random_permutation(ts) diff --git a/python/tests/test_lowlevel.py b/python/tests/test_lowlevel.py new file mode 100644 index 0000000000..b12cc072a0 --- /dev/null +++ b/python/tests/test_lowlevel.py @@ -0,0 +1,786 @@ +""" +Test cases for the low level C interface to tskit. +""" +from __future__ import print_function +from __future__ import division +from __future__ import unicode_literals + +import collections +import itertools +import os +import platform +import random +import sys +import tempfile +import unittest + +import msprime + +import _tskit + +IS_PY2 = sys.version_info[0] < 3 +IS_WINDOWS = platform.system() == "Windows" + + +def get_tracked_sample_counts(st, tracked_samples): + """ + Returns a list giving the number of samples in the specified list + that are in the subtree rooted at each node. + """ + nu = [0 for j in range(st.get_num_nodes())] + for j in tracked_samples: + # Duplicates not permitted. 
+        assert nu[j] == 0
+        u = j
+        while u != _tskit.NULL:
+            nu[u] += 1
+            u = st.get_parent(u)
+    return nu
+
+
+def get_sample_counts(tree_sequence, st):
+    """
+    Returns a list of the sample node counts for the specified sparse tree.
+    """
+    nu = [0 for j in range(st.get_num_nodes())]
+    for j in range(tree_sequence.get_num_samples()):
+        u = j
+        while u != _tskit.NULL:
+            nu[u] += 1
+            u = st.get_parent(u)
+    return nu
+
+
+class LowLevelTestCase(unittest.TestCase):
+    """
+    Superclass of tests for the low-level interface.
+    """
+    def verify_tree_dict(self, n, pi):
+        """
+        Verifies that the specified sparse tree in dict format is a
+        consistent coalescent history for a sample of size n.
+        """
+        self.assertLessEqual(len(pi), 2 * n - 1)
+        # _tskit.NULL should not be a node
+        self.assertNotIn(_tskit.NULL, pi)
+        # verify the root is equal for all samples
+        root = 0
+        while pi[root] != _tskit.NULL:
+            root = pi[root]
+        for j in range(n):
+            k = j
+            while pi[k] != _tskit.NULL:
+                k = pi[k]
+            self.assertEqual(k, root)
+        # 0 to n - 1 inclusive should always be nodes
+        for j in range(n):
+            self.assertIn(j, pi)
+        num_children = collections.defaultdict(int)
+        for j in pi.keys():
+            num_children[pi[j]] += 1
+        # nodes 0 to n - 1 are samples.
+ for j in range(n): + self.assertNotEqual(pi[j], 0) + self.assertEqual(num_children[j], 0) + # All non-sample nodes should be binary + for j in pi.keys(): + if j > n: + self.assertGreaterEqual(num_children[j], 2) + + def get_example_tree_sequence(self): + ts = msprime.simulate(10, recombination_rate=0.1, random_seed=1) + return ts.ll_tree_sequence + + def get_example_tree_sequences(self): + yield self.get_example_tree_sequence() + yield self.get_example_migration_tree_sequence() + + def get_example_migration_tree_sequence(self): + pop_configs = [msprime.PopulationConfiguration(5) for _ in range(2)] + migration_matrix = [[0, 1], [1, 0]] + ts = msprime.simulate( + population_configurations=pop_configs, + migration_matrix=migration_matrix, + mutation_rate=1, + record_migrations=True, + random_seed=1) + return ts.ll_tree_sequence + + def verify_iterator(self, iterator): + """ + Checks that the specified non-empty iterator implements the + iterator protocol correctly. + """ + list_ = list(iterator) + self.assertGreater(len(list_), 0) + for j in range(10): + self.assertRaises(StopIteration, next, iterator) + + +class TestTableCollection(LowLevelTestCase): + """ + Tests for the low-level TableCollection class + """ + + def test_reference_deletion(self): + ts = msprime.simulate(10, mutation_rate=1, random_seed=1) + tc = ts.tables.ll_tables + # Get references to all the tables + tables = [ + tc.individuals, tc.nodes, tc.edges, tc.migrations, tc.sites, tc.mutations, + tc.populations, tc.provenances] + del tc + for _ in range(10): + for table in tables: + self.assertGreater(len(str(table)), 0) + + +class TestTreeSequence(LowLevelTestCase): + """ + Tests for the low-level interface for the TreeSequence. 
+ """ + def setUp(self): + fd, self.temp_file = tempfile.mkstemp(prefix="msp_ll_ts_") + os.close(fd) + + def tearDown(self): + os.unlink(self.temp_file) + + @unittest.skipIf(IS_WINDOWS, "File permissions on Windows") + def test_file_errors(self): + ts1 = self.get_example_tree_sequence() + + def loader(*args): + ts2 = _tskit.TreeSequence() + ts2.load(*args) + + for func in [ts1.dump, loader]: + self.assertRaises(TypeError, func) + for bad_type in [1, None, [], {}]: + self.assertRaises(TypeError, func, bad_type) + # Try to dump/load files we don't have access to or don't exist. + for f in ["/", "/test.trees", "/dir_does_not_exist/x.trees"]: + self.assertRaises(_tskit.FileFormatError, func, f) + try: + func(f) + except _tskit.FileFormatError as e: + message = str(e) + self.assertGreater(len(message), 0) + # use a long filename and make sure we don't overflow error + # buffers + f = "/" + 4000 * "x" + self.assertRaises(_tskit.FileFormatError, func, f) + try: + func(f) + except _tskit.FileFormatError as e: + message = str(e) + self.assertLess(len(message), 1024) + + def test_initial_state(self): + # Check the initial state to make sure that it is empty. 
+ ts = _tskit.TreeSequence() + self.assertRaises(ValueError, ts.get_num_samples) + self.assertRaises(ValueError, ts.get_sequence_length) + self.assertRaises(ValueError, ts.get_num_trees) + self.assertRaises(ValueError, ts.get_num_edges) + self.assertRaises(ValueError, ts.get_num_mutations) + self.assertRaises(ValueError, ts.get_num_migrations) + self.assertRaises(ValueError, ts.get_num_migrations) + self.assertRaises(ValueError, ts.get_genotype_matrix) + self.assertRaises(ValueError, ts.dump) + + def test_num_nodes(self): + for ts in self.get_example_tree_sequences(): + max_node = 0 + for j in range(ts.get_num_edges()): + _, _, parent, child = ts.get_edge(j) + for node in [parent, child]: + if node > max_node: + max_node = node + self.assertEqual(max_node + 1, ts.get_num_nodes()) + + def verify_dump_equality(self, ts): + """ + Verifies that we can dump a copy of the specified tree sequence + to the specified file, and load an identical copy. + """ + ts.dump(self.temp_file) + ts2 = _tskit.TreeSequence() + ts2.load(self.temp_file) + self.assertEqual(ts.get_num_samples(), ts2.get_num_samples()) + self.assertEqual(ts.get_sequence_length(), ts2.get_sequence_length()) + self.assertEqual(ts.get_num_mutations(), ts2.get_num_mutations()) + self.assertEqual(ts.get_num_nodes(), ts2.get_num_nodes()) + records1 = [ts.get_edge(j) for j in range(ts.get_num_edges())] + records2 = [ts2.get_edge(j) for j in range(ts2.get_num_edges())] + self.assertEqual(records1, records2) + mutations1 = [ts.get_mutation(j) for j in range(ts.get_num_mutations())] + mutations2 = [ts2.get_mutation(j) for j in range(ts2.get_num_mutations())] + self.assertEqual(mutations1, mutations2) + provenances1 = [ts.get_provenance(j) for j in range(ts.get_num_provenances())] + provenances2 = [ts2.get_provenance(j) for j in range(ts2.get_num_provenances())] + self.assertEqual(provenances1, provenances2) + + def test_dump_equality(self): + for ts in self.get_example_tree_sequences(): + self.verify_dump_equality(ts) 
+ + def verify_mutations(self, ts): + mutations = [ts.get_mutation(j) for j in range(ts.get_num_mutations())] + self.assertGreater(ts.get_num_mutations(), 0) + self.assertEqual(len(mutations), ts.get_num_mutations()) + # Check the form of the mutations + for j, (position, nodes, index) in enumerate(mutations): + self.assertEqual(j, index) + for node in nodes: + self.assertIsInstance(node, int) + self.assertGreaterEqual(node, 0) + self.assertLessEqual(node, ts.get_num_nodes()) + self.assertIsInstance(position, float) + self.assertGreater(position, 0) + self.assertLess(position, ts.get_sequence_length()) + # mutations must be sorted by position order. + self.assertEqual(mutations, sorted(mutations)) + + def test_get_edge_interface(self): + for ts in self.get_example_tree_sequences(): + num_edges = ts.get_num_edges() + # We don't accept Python negative indexes here. + self.assertRaises(IndexError, ts.get_edge, -1) + for j in [0, 10, 10**6]: + self.assertRaises(IndexError, ts.get_edge, num_edges + j) + for x in [None, "", {}, []]: + self.assertRaises(TypeError, ts.get_edge, x) + + def test_get_node_interface(self): + for ts in self.get_example_tree_sequences(): + num_nodes = ts.get_num_nodes() + # We don't accept Python negative indexes here. 
+ self.assertRaises(IndexError, ts.get_node, -1) + for j in [0, 10, 10**6]: + self.assertRaises(IndexError, ts.get_node, num_nodes + j) + for x in [None, "", {}, []]: + self.assertRaises(TypeError, ts.get_node, x) + + def test_get_genotype_matrix_interface(self): + for ts in self.get_example_tree_sequences(): + num_samples = ts.get_num_samples() + num_sites = ts.get_num_sites() + G = ts.get_genotype_matrix() + self.assertEqual(G.shape, (num_sites, num_samples)) + + def test_get_migration_interface(self): + ts = self.get_example_migration_tree_sequence() + for bad_type in ["", None, {}]: + self.assertRaises(TypeError, ts.get_migration, bad_type) + num_records = ts.get_num_migrations() + # We don't accept Python negative indexes here. + self.assertRaises(IndexError, ts.get_migration, -1) + for j in [0, 10, 10**6]: + self.assertRaises(IndexError, ts.get_migration, num_records + j) + + def test_get_samples(self): + ts = self.get_example_migration_tree_sequence() + # get_samples takes no arguments. 
+ self.assertRaises(TypeError, ts.get_samples, 0) + self.assertEqual(list(range(ts.get_num_samples())), ts.get_samples()) + + def test_pairwise_diversity(self): + for ts in self.get_example_tree_sequences(): + for bad_type in ["", None, {}]: + self.assertRaises( + TypeError, ts.get_pairwise_diversity, bad_type) + self.assertRaises( + ValueError, ts.get_pairwise_diversity, []) + self.assertRaises( + ValueError, ts.get_pairwise_diversity, [0]) + self.assertRaises( + ValueError, ts.get_pairwise_diversity, + [0, ts.get_num_samples()]) + self.assertRaises( + _tskit.LibraryError, ts.get_pairwise_diversity, [0, 0]) + samples = list(range(ts.get_num_samples())) + pi1 = ts.get_pairwise_diversity(samples) + self.assertGreaterEqual(pi1, 0) + + def test_genealogical_nearest_neighbours(self): + for ts in self.get_example_tree_sequences(): + self.assertRaises(TypeError, ts.genealogical_nearest_neighbours) + self.assertRaises( + TypeError, ts.genealogical_nearest_neighbours, focal=None) + self.assertRaises( + TypeError, ts.genealogical_nearest_neighbours, focal=ts.get_samples(), + reference_sets={}) + self.assertRaises( + ValueError, ts.genealogical_nearest_neighbours, focal=ts.get_samples(), + reference_sets=[]) + + bad_array_values = ["", {}, "x", [[[0], [1, 2]]]] + for bad_array_value in bad_array_values: + self.assertRaises( + ValueError, ts.genealogical_nearest_neighbours, + focal=bad_array_value, reference_sets=[[0], [1]]) + self.assertRaises( + ValueError, ts.genealogical_nearest_neighbours, + focal=ts.get_samples(), reference_sets=[[0], bad_array_value]) + self.assertRaises( + ValueError, ts.genealogical_nearest_neighbours, + focal=ts.get_samples(), reference_sets=[bad_array_value]) + focal = ts.get_samples() + A = ts.genealogical_nearest_neighbours(focal, [focal[2:], focal[:2]]) + self.assertEqual(A.shape, (len(focal), 2)) + + def test_mean_descendants(self): + for ts in self.get_example_tree_sequences(): + self.assertRaises(TypeError, ts.mean_descendants) + 
self.assertRaises(TypeError, ts.mean_descendants, reference_sets={}) + self.assertRaises(ValueError, ts.mean_descendants, reference_sets=[]) + + bad_array_values = ["", {}, "x", [[[0], [1, 2]]]] + for bad_array_value in bad_array_values: + self.assertRaises( + ValueError, ts.mean_descendants, + reference_sets=[[0], bad_array_value]) + self.assertRaises( + ValueError, ts.mean_descendants, reference_sets=[bad_array_value]) + focal = ts.get_samples() + A = ts.mean_descendants([focal[2:], focal[:2]]) + self.assertEqual(A.shape, (ts.get_num_nodes(), 2)) + + +class TestTreeDiffIterator(LowLevelTestCase): + """ + Tests for the low-level tree diff iterator. + """ + def test_uninitialised_tree_sequence(self): + ts = _tskit.TreeSequence() + self.assertRaises(ValueError, _tskit.TreeDiffIterator, ts) + + def test_constructor(self): + self.assertRaises(TypeError, _tskit.TreeDiffIterator) + self.assertRaises(TypeError, _tskit.TreeDiffIterator, None) + ts = self.get_example_tree_sequence() + before = list(_tskit.TreeDiffIterator(ts)) + iterator = _tskit.TreeDiffIterator(ts) + del ts + # We should keep a reference to the tree sequence. + after = list(iterator) + self.assertEqual(before, after) + + def test_iterator(self): + ts = self.get_example_tree_sequence() + self.verify_iterator(_tskit.TreeDiffIterator(ts)) + + +class TestTreeIterator(LowLevelTestCase): + """ + Tests for the low-level sparse tree iterator. 
+ """ + def test_uninitialised_tree_sequence(self): + ts = _tskit.TreeSequence() + self.assertRaises(ValueError, _tskit.Tree, ts) + + def test_constructor(self): + self.assertRaises(TypeError, _tskit.TreeIterator) + self.assertRaises(TypeError, _tskit.TreeIterator, None) + ts = _tskit.TreeSequence() + self.assertRaises(TypeError, _tskit.TreeIterator, ts) + ts = self.get_example_tree_sequence() + tree = _tskit.Tree(ts) + n_before = 0 + parents_before = [] + for t in _tskit.TreeIterator(tree): + n_before += 1 + self.assertIs(t, tree) + pi = {} + for j in range(t.get_num_nodes()): + pi[j] = t.get_parent(j) + parents_before.append(pi) + self.assertEqual(n_before, len(list(_tskit.TreeDiffIterator(ts)))) + # If we remove the objects, we should get the same results. + iterator = _tskit.TreeIterator(tree) + del tree + del ts + n_after = 0 + parents_after = [] + for index, t in enumerate(iterator): + n_after += 1 + self.assertIsInstance(t, _tskit.Tree) + pi = {} + for j in range(t.get_num_nodes()): + pi[j] = t.get_parent(j) + parents_after.append(pi) + self.assertEqual(index, t.get_index()) + self.assertEqual(parents_before, parents_after) + + def test_iterator(self): + ts = self.get_example_tree_sequence() + tree = _tskit.Tree(ts) + self.verify_iterator(_tskit.TreeIterator(tree)) + + +class TestTree(LowLevelTestCase): + """ + Tests on the low-level sparse tree interface. + """ + + def test_flags(self): + ts = self.get_example_tree_sequence() + st = _tskit.Tree(ts) + self.assertEqual(st.get_flags(), 0) + # We should still be able to count the samples, just inefficiently. 
+ self.assertEqual(st.get_num_samples(0), 1) + self.assertRaises(_tskit.LibraryError, st.get_num_tracked_samples, 0) + all_flags = [ + 0, _tskit.SAMPLE_COUNTS, _tskit.SAMPLE_LISTS, + _tskit.SAMPLE_COUNTS | _tskit.SAMPLE_LISTS] + for flags in all_flags: + st = _tskit.Tree(ts, flags=flags) + self.assertEqual(st.get_flags(), flags) + self.assertEqual(st.get_num_samples(0), 1) + if flags & _tskit.SAMPLE_COUNTS: + self.assertEqual(st.get_num_tracked_samples(0), 0) + else: + self.assertRaises(_tskit.LibraryError, st.get_num_tracked_samples, 0) + if flags & _tskit.SAMPLE_LISTS: + self.assertEqual(0, st.get_left_sample(0)) + self.assertEqual(0, st.get_right_sample(0)) + else: + self.assertRaises(ValueError, st.get_left_sample, 0) + self.assertRaises(ValueError, st.get_right_sample, 0) + self.assertRaises(ValueError, st.get_next_sample, 0) + + def test_sites(self): + for ts in self.get_example_tree_sequences(): + st = _tskit.Tree(ts) + all_sites = [ts.get_site(j) for j in range(ts.get_num_sites())] + all_tree_sites = [] + j = 0 + mutation_id = 0 + for st in _tskit.TreeIterator(st): + tree_sites = st.get_sites() + self.assertEqual(st.get_num_sites(), len(tree_sites)) + all_tree_sites.extend(tree_sites) + for position, ancestral_state, mutations, index, metadata in tree_sites: + self.assertTrue(st.get_left() <= position < st.get_right()) + self.assertEqual(index, j) + self.assertEqual(metadata, b"") + for mut_id in mutations: + site, node, derived_state, parent, metadata = \ + ts.get_mutation(mut_id) + self.assertEqual(site, index) + self.assertEqual(mutation_id, mut_id) + self.assertNotEqual(st.get_parent(node), _tskit.NULL) + self.assertEqual(metadata, b"") + mutation_id += 1 + j += 1 + self.assertEqual(all_tree_sites, all_sites) + + def test_constructor(self): + self.assertRaises(TypeError, _tskit.Tree) + for bad_type in ["", {}, [], None, 0]: + self.assertRaises( + TypeError, _tskit.Tree, bad_type) + ts = self.get_example_tree_sequence() + for bad_type in ["", {}, True, 
1, None]: + self.assertRaises( + TypeError, _tskit.Tree, ts, tracked_samples=bad_type) + for bad_type in ["", {}, None, []]: + self.assertRaises( + TypeError, _tskit.Tree, ts, flags=bad_type) + for ts in self.get_example_tree_sequences(): + st = _tskit.Tree(ts) + self.assertEqual(st.get_num_nodes(), ts.get_num_nodes()) + # An uninitialised sparse tree should always be zero. + self.assertEqual(st.get_left_root(), 0) + self.assertEqual(st.get_left(), 0) + self.assertEqual(st.get_right(), 0) + for j in range(ts.get_num_samples()): + self.assertEqual(st.get_parent(j), _tskit.NULL) + self.assertEqual(st.get_children(j), tuple()) + self.assertEqual(st.get_time(j), 0) + + def test_memory_error(self): + # This provokes a bug where we weren't reference counting + # the tree sequence properly, and the underlying memory for a + # sparse tree was getting corrupted. + for ts in self.get_example_tree_sequences(): + num_nodes = ts.get_num_nodes() + st = _tskit.Tree(ts) + # deleting the tree sequence should still give a well formed + # sparse tree. 
+ st_iter = _tskit.TreeIterator(st) + next(st_iter) + del ts + del st_iter + # Do a quick traversal just to exercise the tree + stack = [st.get_left_root()] + while len(stack) > 0: + u = stack.pop() + self.assertLess(u, num_nodes) + stack.extend(st.get_children(u)) + + def test_bad_tracked_samples(self): + ts = self.get_example_tree_sequence() + flags = _tskit.SAMPLE_COUNTS + for bad_type in ["", {}, [], None]: + self.assertRaises( + TypeError, _tskit.Tree, ts, flags=flags, + tracked_samples=[bad_type]) + self.assertRaises( + TypeError, _tskit.Tree, ts, flags=flags, + tracked_samples=[1, bad_type]) + for bad_sample in [10**6, -1e6]: + self.assertRaises( + ValueError, _tskit.Tree, ts, flags=flags, + tracked_samples=[bad_sample]) + self.assertRaises( + ValueError, _tskit.Tree, ts, flags=flags, + tracked_samples=[1, bad_sample]) + self.assertRaises( + ValueError, _tskit.Tree, ts, + tracked_samples=[1, bad_sample, 1]) + + def test_count_all_samples(self): + for ts in self.get_example_tree_sequences(): + self.verify_iterator(_tskit.TreeDiffIterator(ts)) + st = _tskit.Tree(ts, flags=_tskit.SAMPLE_COUNTS) + # Without initialisation we should be 0 samples for every node + # that is not a sample. + for j in range(st.get_num_nodes()): + count = 1 if j < ts.get_num_samples() else 0 + self.assertEqual(st.get_num_samples(j), count) + self.assertEqual(st.get_num_tracked_samples(j), 0) + # Now, try this for a tree sequence. + for st in _tskit.TreeIterator(st): + nu = get_sample_counts(ts, st) + nu_prime = [ + st.get_num_samples(j) for j in + range(st.get_num_nodes())] + self.assertEqual(nu, nu_prime) + # For tracked samples, this should be all zeros. + nu = [ + st.get_num_tracked_samples(j) for j in + range(st.get_num_nodes())] + self.assertEqual(nu, list([0 for _ in nu])) + + def test_count_tracked_samples(self): + # Ensure that there are some non-binary nodes. 
+ non_binary = False + for ts in self.get_example_tree_sequences(): + st = _tskit.Tree(ts) + for st in _tskit.TreeIterator(st): + for u in range(ts.get_num_nodes()): + if len(st.get_children(u)) > 1: + non_binary = True + samples = [j for j in range(ts.get_num_samples())] + powerset = itertools.chain.from_iterable( + itertools.combinations(samples, r) + for r in range(len(samples) + 1)) + for subset in map(list, powerset): + # Ordering shouldn't make any different. + random.shuffle(subset) + st = _tskit.Tree( + ts, flags=_tskit.SAMPLE_COUNTS, tracked_samples=subset) + for st in _tskit.TreeIterator(st): + nu = get_tracked_sample_counts(st, subset) + nu_prime = [ + st.get_num_tracked_samples(j) for j in + range(st.get_num_nodes())] + self.assertEqual(nu, nu_prime) + # Passing duplicated values should raise an error + sample = 1 + for j in range(2, 20): + tracked_samples = [sample for _ in range(j)] + self.assertRaises( + _tskit.LibraryError, _tskit.Tree, + ts, flags=_tskit.SAMPLE_COUNTS, + tracked_samples=tracked_samples) + self.assertTrue(non_binary) + + def test_bounds_checking(self): + for ts in self.get_example_tree_sequences(): + n = ts.get_num_nodes() + st = _tskit.Tree( + ts, flags=_tskit.SAMPLE_COUNTS | _tskit.SAMPLE_LISTS) + for v in [-100, -1, n + 1, n + 100, n * 100]: + self.assertRaises(ValueError, st.get_parent, v) + self.assertRaises(ValueError, st.get_children, v) + self.assertRaises(ValueError, st.get_time, v) + self.assertRaises(ValueError, st.get_left_sample, v) + self.assertRaises(ValueError, st.get_right_sample, v) + n = ts.get_num_samples() + for v in [-100, -1, n + 1, n + 100, n * 100]: + self.assertRaises(ValueError, st.get_next_sample, v) + + def test_mrca_interface(self): + for ts in self.get_example_tree_sequences(): + num_nodes = ts.get_num_nodes() + st = _tskit.Tree(ts) + for v in [num_nodes, 10**6, _tskit.NULL]: + self.assertRaises(ValueError, st.get_mrca, v, v) + self.assertRaises(ValueError, st.get_mrca, v, 1) + 
self.assertRaises(ValueError, st.get_mrca, 1, v) + # All the mrcas for an uninitialised tree should be _tskit.NULL + for u, v in itertools.combinations(range(num_nodes), 2): + self.assertEqual(st.get_mrca(u, v), _tskit.NULL) + + def test_newick_precision(self): + + def get_times(tree): + """ + Returns the time strings from the specified newick tree. + """ + ret = [] + current_time = None + for c in tree: + if c == ":": + current_time = "" + elif c in [",", ")"]: + ret.append(current_time) + current_time = None + elif current_time is not None: + current_time += c + return ret + + ts = self.get_example_tree_sequence() + st = _tskit.Tree(ts) + for st in _tskit.TreeIterator(st): + self.assertRaises(ValueError, st.get_newick, root=0, precision=-1) + self.assertRaises(ValueError, st.get_newick, root=0, precision=17) + self.assertRaises(ValueError, st.get_newick, root=0, precision=100) + for precision in range(17): + tree = st.get_newick( + root=st.get_left_root(), precision=precision).decode() + times = get_times(tree) + self.assertGreater(len(times), ts.get_num_samples()) + for t in times: + if precision == 0: + self.assertNotIn(".", t) + else: + point = t.find(".") + self.assertEqual(precision, len(t) - point - 1) + + @unittest.skip("Correct initialisation for sparse tree.") + def test_newick_interface(self): + ts = self.get_tree_sequence(num_loci=10, num_samples=10) + st = _tskit.Tree(ts) + # TODO this will break when we correctly handle multiple roots. 
+ self.assertEqual(st.get_newick(), b"1;") + for bad_type in [None, "", [], {}]: + self.assertRaises(TypeError, st.get_newick, precision=bad_type) + self.assertRaises(TypeError, st.get_newick, ts, time_scale=bad_type) + for st in _tskit.TreeIterator(st): + newick = st.get_newick() + self.assertTrue(newick.endswith(b";")) + + def test_index(self): + for ts in self.get_example_tree_sequences(): + st = _tskit.Tree(ts) + for index, st in enumerate(_tskit.TreeIterator(st)): + self.assertEqual(index, st.get_index()) + + def test_bad_mutations(self): + ts = self.get_example_tree_sequence() + tables = _tskit.TableCollection() + ts.dump_tables(tables) + + def f(mutations): + position = [] + node = [] + site = [] + ancestral_state = [] + ancestral_state_offset = [0] + derived_state = [] + derived_state_offset = [0] + for j, (p, n) in enumerate(mutations): + site.append(j) + position.append(p) + ancestral_state.append("0") + ancestral_state_offset.append(ancestral_state_offset[-1] + 1) + derived_state.append("1") + derived_state_offset.append(derived_state_offset[-1] + 1) + node.append(n) + tables.sites.set_columns(dict( + position=position, ancestral_state=ancestral_state, + ancestral_state_offset=ancestral_state_offset, + metadata=None, metadata_offset=None)) + tables.mutations.set_columns(dict( + site=site, node=node, derived_state=derived_state, + derived_state_offset=derived_state_offset, + parent=None, metadata=None, metadata_offset=None)) + ts2 = _tskit.TreeSequence() + ts2.load_tables(tables) + self.assertRaises(_tskit.LibraryError, f, [(0.1, -1)]) + length = ts.get_sequence_length() + u = ts.get_num_nodes() + for bad_node in [u, u + 1, 2 * u]: + self.assertRaises(_tskit.LibraryError, f, [(0.1, bad_node)]) + for bad_pos in [-1, length, length + 1]: + self.assertRaises(_tskit.LibraryError, f, [(length, 0)]) + + def test_free(self): + ts = self.get_example_tree_sequence() + t = _tskit.Tree( + ts, flags=_tskit.SAMPLE_COUNTS | _tskit.SAMPLE_LISTS) + no_arg_methods = [ + 
t.get_left_root, t.get_index, t.get_left, t.get_right, + t.get_num_sites, t.get_flags, t.get_sites, t.get_num_nodes] + node_arg_methods = [ + t.get_parent, t.get_population, t.get_children, t.get_num_samples, + t.get_num_tracked_samples] + two_node_arg_methods = [t.get_mrca] + for method in no_arg_methods: + method() + for method in node_arg_methods: + method(0) + for method in two_node_arg_methods: + method(0, 0) + t.free() + self.assertRaises(RuntimeError, t.free) + for method in no_arg_methods: + self.assertRaises(RuntimeError, method) + for method in node_arg_methods: + self.assertRaises(RuntimeError, method, 0) + for method in two_node_arg_methods: + self.assertRaises(RuntimeError, method, 0, 0) + + def test_sample_list(self): + flags = _tskit.SAMPLE_COUNTS | _tskit.SAMPLE_LISTS + # Note: we're assuming that samples are 0-n here. + for ts in self.get_example_tree_sequences(): + st = _tskit.Tree(ts, flags=flags) + for t in _tskit.TreeIterator(st): + # All sample nodes should have themselves. + for j in range(ts.get_num_samples()): + self.assertEqual(t.get_left_sample(j), j) + self.assertEqual(t.get_right_sample(j), j) + + # All non-tree nodes should have 0 + for j in range(t.get_num_nodes()): + if t.get_parent(j) == _tskit.NULL \ + and t.get_left_child(j) == _tskit.NULL: + self.assertEqual(t.get_left_sample(j), _tskit.NULL) + self.assertEqual(t.get_right_sample(j), _tskit.NULL) + # The roots should have all samples. + u = t.get_left_root() + samples = [] + while u != _tskit.NULL: + sample = t.get_left_sample(u) + end = t.get_right_sample(u) + while True: + samples.append(sample) + if sample == end: + break + sample = t.get_next_sample(sample) + u = t.get_right_sib(u) + self.assertEqual(sorted(samples), list(range(ts.get_num_samples()))) + + +class TestModuleFunctions(unittest.TestCase): + """ + Tests for the module level functions. 
+ """ + def test_kastore_version(self): + version = _tskit.get_kastore_version() + self.assertEqual(version, (0, 1, 0)) diff --git a/python/tests/test_metadata.py b/python/tests/test_metadata.py new file mode 100644 index 0000000000..7d44d9615e --- /dev/null +++ b/python/tests/test_metadata.py @@ -0,0 +1,244 @@ +# -*- coding: utf-8 -*- +""" +Tests for metadata handling. +""" +from __future__ import print_function +from __future__ import division + +import json +import os +import tempfile +import unittest +import pickle + +import numpy as np +import python_jsonschema_objects as pjs +import six +import msprime + +import tskit + + +class TestMetadataHdf5RoundTrip(unittest.TestCase): + """ + Tests that we can encode metadata under various formats and this will + successfully round-trip through the HDF5 format. + """ + def setUp(self): + fd, self.temp_file = tempfile.mkstemp(prefix="msp_hdf5meta_test_") + os.close(fd) + + def tearDown(self): + os.unlink(self.temp_file) + + def test_json(self): + ts = msprime.simulate(10, random_seed=1) + tables = ts.dump_tables() + nodes = tables.nodes + # For each node, we create some Python metadata that can be JSON encoded. 
+ metadata = [ + {"one": j, "two": 2 * j, "three": list(range(j))} for j in range(len(nodes))] + encoded, offset = tskit.pack_strings(map(json.dumps, metadata)) + nodes.set_columns( + flags=nodes.flags, time=nodes.time, population=nodes.population, + metadata_offset=offset, metadata=encoded) + self.assertTrue(np.array_equal(nodes.metadata_offset, offset)) + self.assertTrue(np.array_equal(nodes.metadata, encoded)) + ts1 = tables.tree_sequence() + for j, node in enumerate(ts1.nodes()): + decoded_metadata = json.loads(node.metadata.decode()) + self.assertEqual(decoded_metadata, metadata[j]) + ts1.dump(self.temp_file) + ts2 = tskit.load(self.temp_file) + self.assertEqual(ts1.tables.nodes, ts2.tables.nodes) + + def test_pickle(self): + ts = msprime.simulate(10, random_seed=1) + tables = ts.dump_tables() + # For each node, we create some Python metadata that can be pickled + metadata = [ + {"one": j, "two": 2 * j, "three": list(range(j))} + for j in range(ts.num_nodes)] + encoded, offset = tskit.pack_bytes(list(map(pickle.dumps, metadata))) + tables.nodes.set_columns( + flags=tables.nodes.flags, time=tables.nodes.time, + population=tables.nodes.population, + metadata_offset=offset, metadata=encoded) + self.assertTrue(np.array_equal(tables.nodes.metadata_offset, offset)) + self.assertTrue(np.array_equal(tables.nodes.metadata, encoded)) + ts1 = tables.tree_sequence() + for j, node in enumerate(ts1.nodes()): + decoded_metadata = pickle.loads(node.metadata) + self.assertEqual(decoded_metadata, metadata[j]) + ts1.dump(self.temp_file) + ts2 = tskit.load(self.temp_file) + self.assertEqual(ts1.tables.nodes, ts2.tables.nodes) + + +class ExampleMetadata(object): + """ + Simple class that we can pickle/unpickle in metadata. + """ + def __init__(self, one=None, two=None): + self.one = one + self.two = two + + +class TestMetadataPickleDecoding(unittest.TestCase): + """ + Tests in which use pickle.pickle to decode metadata in nodes, sites and mutations. 
+ """ + + def test_nodes(self): + tables = tskit.TableCollection(sequence_length=1) + metadata = ExampleMetadata(one="node1", two="node2") + pickled = pickle.dumps(metadata) + tables.nodes.add_row(time=0.125, metadata=pickled) + ts = tables.tree_sequence() + node = ts.node(0) + self.assertEqual(node.time, 0.125) + self.assertEqual(node.metadata, pickled) + unpickled = pickle.loads(node.metadata) + self.assertEqual(unpickled.one, metadata.one) + self.assertEqual(unpickled.two, metadata.two) + + def test_sites(self): + tables = tskit.TableCollection(sequence_length=1) + metadata = ExampleMetadata(one="node1", two="node2") + pickled = pickle.dumps(metadata) + tables.sites.add_row(position=0.1, ancestral_state="A", metadata=pickled) + ts = tables.tree_sequence() + site = ts.site(0) + self.assertEqual(site.position, 0.1) + self.assertEqual(site.ancestral_state, "A") + self.assertEqual(site.metadata, pickled) + unpickled = pickle.loads(site.metadata) + self.assertEqual(unpickled.one, metadata.one) + self.assertEqual(unpickled.two, metadata.two) + + def test_mutations(self): + tables = tskit.TableCollection(sequence_length=1) + metadata = ExampleMetadata(one="node1", two="node2") + pickled = pickle.dumps(metadata) + tables.nodes.add_row(time=0) + tables.sites.add_row(position=0.1, ancestral_state="A") + tables.mutations.add_row(site=0, node=0, derived_state="T", metadata=pickled) + ts = tables.tree_sequence() + mutation = ts.site(0).mutations[0] + self.assertEqual(mutation.site, 0) + self.assertEqual(mutation.node, 0) + self.assertEqual(mutation.derived_state, "T") + self.assertEqual(mutation.metadata, pickled) + unpickled = pickle.loads(mutation.metadata) + self.assertEqual(unpickled.one, metadata.one) + self.assertEqual(unpickled.two, metadata.two) + + +class TestJsonSchemaDecoding(unittest.TestCase): + """ + Tests in which use json-schema to decode the metadata. 
+ """ + schema = """{ + "title": "Example Metadata", + "type": "object", + "properties": { + "one": {"type": "string"}, + "two": {"type": "string"} + }, + "required": ["one", "two"] + }""" + + def test_nodes(self): + tables = tskit.TableCollection(sequence_length=1) + builder = pjs.ObjectBuilder(json.loads(self.schema)) + ns = builder.build_classes() + metadata = ns.ExampleMetadata(one="node1", two="node2") + encoded = json.dumps(metadata.as_dict()).encode() + tables.nodes.add_row(time=0.125, metadata=encoded) + ts = tables.tree_sequence() + node = ts.node(0) + self.assertEqual(node.time, 0.125) + self.assertEqual(node.metadata, encoded) + decoded = ns.ExampleMetadata.from_json(node.metadata.decode()) + self.assertEqual(decoded.one, metadata.one) + self.assertEqual(decoded.two, metadata.two) + + +class TestLoadTextMetadata(unittest.TestCase): + """ + Tests that use the load_text interface. + """ + + def test_individuals(self): + individuals = six.StringIO("""\ + id flags location metadata + 0 1 0.0,1.0,0.0 abc + 1 1 1.0,2.0 XYZ+ + 2 0 2.0,3.0,0.0 !@#$%^&*() + """) + i = tskit.parse_individuals( + individuals, strict=False, encoding='utf8', base64_metadata=False) + expected = [(1, [0.0, 1.0, 0.0], 'abc'), + (1, [1.0, 2.0], 'XYZ+'), + (0, [2.0, 3.0, 0.0], '!@#$%^&*()')] + for a, b in zip(expected, i): + self.assertEqual(a[0], b.flags) + self.assertEqual(len(a[1]), len(b.location)) + for x, y in zip(a[1], b.location): + self.assertEqual(x, y) + self.assertEqual(a[2].encode('utf8'), + b.metadata) + + def test_nodes(self): + nodes = six.StringIO("""\ + id is_sample time metadata + 0 1 0 abc + 1 1 0 XYZ+ + 2 0 1 !@#$%^&*() + """) + n = tskit.parse_nodes( + nodes, strict=False, encoding='utf8', base64_metadata=False) + expected = ['abc', 'XYZ+', '!@#$%^&*()'] + for a, b in zip(expected, n): + self.assertEqual(a.encode('utf8'), + b.metadata) + + def test_sites(self): + sites = six.StringIO("""\ + position ancestral_state metadata + 0.1 A abc + 0.5 C XYZ+ + 0.8 G !@#$%^&*() 
+ """) + s = tskit.parse_sites( + sites, strict=False, encoding='utf8', base64_metadata=False) + expected = ['abc', 'XYZ+', '!@#$%^&*()'] + for a, b in zip(expected, s): + self.assertEqual(a.encode('utf8'), + b.metadata) + + def test_mutations(self): + mutations = six.StringIO("""\ + site node derived_state metadata + 0 2 C mno + 0 3 G )(*&^%$#@! + """) + m = tskit.parse_mutations( + mutations, strict=False, encoding='utf8', base64_metadata=False) + expected = ['mno', ')(*&^%$#@!'] + for a, b in zip(expected, m): + self.assertEqual(a.encode('utf8'), + b.metadata) + + def test_populations(self): + populations = six.StringIO("""\ + id metadata + 0 mno + 1 )(*&^%$#@! + """) + p = tskit.parse_populations( + populations, strict=False, encoding='utf8', base64_metadata=False) + expected = ['mno', ')(*&^%$#@!'] + for a, b in zip(expected, p): + self.assertEqual(a.encode('utf8'), + b.metadata) diff --git a/python/tests/test_newick.py b/python/tests/test_newick.py new file mode 100644 index 0000000000..40cb140d72 --- /dev/null +++ b/python/tests/test_newick.py @@ -0,0 +1,121 @@ +""" +Tests for the newick output feature. +""" +from __future__ import print_function +from __future__ import division + +import unittest + +import msprime + +import newick + + +class TestNewick(unittest.TestCase): + """ + Tests that the newick output has the properties that we need using + external Newick parser. 
+ """ + random_seed = 155 + + def verify_newick_topology(self, tree, root=None, node_labels=None): + if root is None: + root = tree.root + ns = tree.newick(precision=16, root=root, node_labels=node_labels) + if node_labels is None: + leaf_labels = {u: str(u + 1) for u in tree.leaves(root)} + else: + leaf_labels = {u: node_labels[u] for u in tree.leaves(root)} + newick_tree = newick.loads(ns)[0] + leaf_names = newick_tree.get_leaf_names() + self.assertEqual(sorted(leaf_names), sorted(leaf_labels.values())) + for u in tree.leaves(root): + name = leaf_labels[u] + node = newick_tree.get_node(name) + while u != root: + self.assertAlmostEqual(node.length, tree.branch_length(u)) + node = node.ancestor + u = tree.parent(u) + self.assertIsNone(node.ancestor) + + def get_nonbinary_example(self): + ts = msprime.simulate( + sample_size=20, recombination_rate=10, random_seed=self.random_seed, + demographic_events=[ + msprime.SimpleBottleneck(time=0.5, population=0, proportion=1)]) + # Make sure this really has some non-binary nodes + found = False + for e in ts.edgesets(): + if len(e.children) > 2: + found = True + break + self.assertTrue(found) + return ts + + def get_binary_example(self): + ts = msprime.simulate( + sample_size=25, recombination_rate=5, random_seed=self.random_seed) + return ts + + def get_multiroot_example(self): + ts = msprime.simulate(sample_size=50, random_seed=self.random_seed) + tables = ts.dump_tables() + edges = tables.edges + n = len(edges) // 2 + edges.set_columns( + left=edges.left[:n], right=edges.right[:n], + parent=edges.parent[:n], child=edges.child[:n]) + return tables.tree_sequence() + + def test_nonbinary_tree(self): + ts = self.get_nonbinary_example() + for t in ts.trees(): + self.verify_newick_topology(t) + + def test_binary_tree(self): + ts = self.get_binary_example() + for t in ts.trees(): + self.verify_newick_topology(t) + + def test_multiroot(self): + ts = self.get_multiroot_example() + t = ts.first() + self.assertRaises(ValueError, 
t.newick) + for root in t.roots: + self.verify_newick_topology(t, root=root) + + def test_all_nodes(self): + ts = msprime.simulate(10, random_seed=5) + tree = ts.first() + for u in tree.nodes(): + self.verify_newick_topology(tree, root=u) + + def test_binary_leaf_labels(self): + tree = self.get_binary_example().first() + labels = {u: "x_{}".format(u) for u in tree.leaves()} + self.verify_newick_topology(tree, node_labels=labels) + + def test_nonbinary_leaf_labels(self): + ts = self.get_nonbinary_example() + for t in ts.trees(): + labels = {u: str(u) for u in t.leaves()} + self.verify_newick_topology(t, node_labels=labels) + + def test_all_node_labels(self): + tree = msprime.simulate(5, random_seed=2).first() + labels = {u: "x_{}".format(u) for u in tree.nodes()} + ns = tree.newick(node_labels=labels) + root = newick.loads(ns)[0] + self.assertEqual(root.name, labels[tree.root]) + self.assertEqual( + sorted([n.name for n in root.walk()]), sorted(labels.values())) + + def test_single_node_label(self): + tree = msprime.simulate(5, random_seed=2).first() + labels = {tree.root: "XXX"} + ns = tree.newick(node_labels=labels) + root = newick.loads(ns)[0] + self.assertEqual(root.name, labels[tree.root]) + self.assertEqual( + [n.name for n in root.walk()], + [labels[tree.root]] + [None for _ in range(len(list(tree.nodes())) - 1)]) diff --git a/python/tests/test_provenance.py b/python/tests/test_provenance.py new file mode 100644 index 0000000000..f468ffea1a --- /dev/null +++ b/python/tests/test_provenance.py @@ -0,0 +1,192 @@ +""" +Tests for the provenance information attached to tree sequences. 
+""" +from __future__ import print_function +from __future__ import unicode_literals +from __future__ import division + +import unittest +import json +import platform +import os + +import msprime + +import _tskit +import tskit +import tskit.provenance as provenance + + +def get_provenance( + software_name="x", software_version="y", schema_version="1", environment=None, + parameters=None): + """ + Utility function to return a provenance document for testing. + """ + document = { + "schema_version": schema_version, + "software": { + "name": software_name, + "version": software_version, + }, + "environment": {} if environment is None else environment, + "parameters": {} if parameters is None else parameters, + } + return document + + +class TestSchema(unittest.TestCase): + """ + Tests for schema validation. + """ + def test_empty(self): + with self.assertRaises(tskit.ProvenanceValidationError): + tskit.validate_provenance({}) + + def test_missing_keys(self): + minimal = get_provenance() + tskit.validate_provenance(minimal) + for key in minimal.keys(): + copy = dict(minimal) + del copy[key] + with self.assertRaises(tskit.ProvenanceValidationError): + tskit.validate_provenance(copy) + copy = dict(minimal) + del copy["software"]["name"] + with self.assertRaises(tskit.ProvenanceValidationError): + tskit.validate_provenance(copy) + copy = dict(minimal) + del copy["software"]["version"] + with self.assertRaises(tskit.ProvenanceValidationError): + tskit.validate_provenance(copy) + + def test_software_types(self): + for bad_type in [0, [1, 2, 3], {}]: + doc = get_provenance(software_name=bad_type) + with self.assertRaises(tskit.ProvenanceValidationError): + tskit.validate_provenance(doc) + doc = get_provenance(software_version=bad_type) + with self.assertRaises(tskit.ProvenanceValidationError): + tskit.validate_provenance(doc) + + def test_schema_version_empth(self): + doc = get_provenance(schema_version="") + with self.assertRaises(tskit.ProvenanceValidationError): + 
tskit.validate_provenance(doc) + + def test_software_empty_strings(self): + doc = get_provenance(software_name="") + with self.assertRaises(tskit.ProvenanceValidationError): + tskit.validate_provenance(doc) + doc = get_provenance(software_version="") + with self.assertRaises(tskit.ProvenanceValidationError): + tskit.validate_provenance(doc) + + def test_minimal(self): + minimal = { + "schema_version": "1", + "software": { + "name": "x", + "version": "y", + }, + "environment": {}, + "parameters": {} + } + tskit.validate_provenance(minimal) + + def test_extra_stuff(self): + extra = { + "you": "can", + "schema_version": "1", + "software": { + "put": "anything", + "name": "x", + "version": "y", + }, + "environment": {"extra": ["you", "want"]}, + "parameters": {"so": ["long", "its", "JSON", 0]} + } + tskit.validate_provenance(extra) + + +class TestOutputProvenance(unittest.TestCase): + """ + Check that the schemas we produce in tskit are valid. + """ + def test_simplify(self): + ts = msprime.simulate(5, random_seed=1) + ts = ts.simplify() + prov = json.loads(ts.provenance(1).record) + tskit.validate_provenance(prov) + self.assertEqual(prov["parameters"]["command"], "simplify") + self.assertEqual( + prov["environment"], provenance.get_environment(include_tskit=False)) + self.assertEqual( + prov["software"], + {"name": "tskit", "version": tskit.__version__}) + + +class TestEnvironment(unittest.TestCase): + """ + Tests for the environment provenance. 
+ """ + def test_os(self): + env = provenance.get_environment() + os = { + "system": platform.system(), + "node": platform.node(), + "release": platform.release(), + "version": platform.version(), + "machine": platform.machine() + } + self.assertEqual(env["os"], os) + + def test_python(self): + env = provenance.get_environment() + python = { + "implementation": platform.python_implementation(), + "version": platform.python_version(), + } + self.assertEqual(env["python"], python) + + def test_libraries(self): + kastore_lib = {"version": ".".join(map(str, _tskit.get_kastore_version()))} + env = provenance.get_environment() + self.assertEqual({ + "kastore": kastore_lib, + "tskit": {"version": tskit.__version__}}, + env["libraries"]) + + env = provenance.get_environment(include_tskit=False) + self.assertEqual({"kastore": kastore_lib}, env["libraries"]) + + extra_libs = {"abc": [], "xyz": {"one": 1}} + env = provenance.get_environment(include_tskit=False, extra_libs=extra_libs) + libs = {"kastore": kastore_lib} + libs.update(extra_libs) + self.assertEqual(libs, env["libraries"]) + + +class TestGetSchema(unittest.TestCase): + """ + Ensure we return the correct JSON schema. + """ + def test_file_equal(self): + s1 = provenance.get_schema() + with open(os.path.join("tskit", "provenance.schema.json")) as f: + s2 = json.load(f) + self.assertEqual(s1, s2) + + def test_caching(self): + n = 10 + schemas = [provenance.get_schema() for _ in range(n)] + # Ensure all the schemas are different objects. 
+ self.assertEqual(len(set(map(id, schemas))), n) + # Ensure the schemas are all equal + for j in range(n): + self.assertEqual(schemas[0], schemas[j]) + + def test_form(self): + s = provenance.get_schema() + self.assertEqual(s["schema"], "http://json-schema.org/draft-07/schema#") + self.assertEqual(s["version"], "1.0.0") diff --git a/python/tests/test_stats.py b/python/tests/test_stats.py new file mode 100644 index 0000000000..29f138c4dc --- /dev/null +++ b/python/tests/test_stats.py @@ -0,0 +1,585 @@ +""" +Test cases for stats calculations in tskit. +""" +from __future__ import print_function +from __future__ import division + +import unittest +import sys + +import numpy as np +import msprime + +import tskit +import _tskit +import tests.tsutil as tsutil +import tests.test_wright_fisher as wf + + +IS_PY2 = sys.version_info[0] < 3 + + +def get_r2_matrix(ts): + """ + Returns the matrix for the specified tree sequence. This is computed + via a straightforward Python algorithm. + """ + n = ts.get_sample_size() + m = ts.get_num_mutations() + A = np.zeros((m, m), dtype=float) + for t1 in ts.trees(): + for sA in t1.sites(): + assert len(sA.mutations) == 1 + mA = sA.mutations[0] + A[sA.id, sA.id] = 1 + fA = t1.get_num_samples(mA.node) / n + samples = list(t1.samples(mA.node)) + for t2 in ts.trees(tracked_samples=samples): + for sB in t2.sites(): + assert len(sB.mutations) == 1 + mB = sB.mutations[0] + if sB.position > sA.position: + fB = t2.get_num_samples(mB.node) / n + fAB = t2.get_num_tracked_samples(mB.node) / n + D = fAB - fA * fB + r2 = D * D / (fA * fB * (1 - fA) * (1 - fB)) + A[sA.id, sB.id] = r2 + A[sB.id, sA.id] = r2 + return A + + +class TestLdCalculator(unittest.TestCase): + """ + Tests for the LdCalculator class. 
+ """ + + num_test_sites = 50 + + def verify_matrix(self, ts): + m = ts.get_num_sites() + ldc = tskit.LdCalculator(ts) + A = ldc.get_r2_matrix() + self.assertEqual(A.shape, (m, m)) + B = get_r2_matrix(ts) + self.assertTrue(np.allclose(A, B)) + + # Now look at each row in turn, and verify it's the same + # when we use get_r2 directly. + for j in range(m): + a = ldc.get_r2_array(j, direction=tskit.FORWARD) + b = A[j, j + 1:] + self.assertEqual(a.shape[0], m - j - 1) + self.assertEqual(b.shape[0], m - j - 1) + self.assertTrue(np.allclose(a, b)) + a = ldc.get_r2_array(j, direction=tskit.REVERSE) + b = A[j, :j] + self.assertEqual(a.shape[0], j) + self.assertEqual(b.shape[0], j) + self.assertTrue(np.allclose(a[::-1], b)) + + # Now check every cell in the matrix in turn. + for j in range(m): + for k in range(m): + self.assertAlmostEqual(ldc.get_r2(j, k), A[j, k]) + + def verify_max_distance(self, ts): + """ + Verifies that the max_distance parameter works as expected. + """ + mutations = list(ts.mutations()) + ldc = tskit.LdCalculator(ts) + A = ldc.get_r2_matrix() + j = len(mutations) // 2 + for k in range(j): + x = mutations[j + k].position - mutations[j].position + a = ldc.get_r2_array(j, max_distance=x) + self.assertEqual(a.shape[0], k) + self.assertTrue(np.allclose(A[j, j + 1: j + 1 + k], a)) + x = mutations[j].position - mutations[j - k].position + a = ldc.get_r2_array(j, max_distance=x, direction=tskit.REVERSE) + self.assertEqual(a.shape[0], k) + self.assertTrue(np.allclose(A[j, j - k: j], a[::-1])) + L = ts.get_sequence_length() + m = len(mutations) + a = ldc.get_r2_array(0, max_distance=L) + self.assertEqual(a.shape[0], m - 1) + self.assertTrue(np.allclose(A[0, 1:], a)) + a = ldc.get_r2_array(m - 1, max_distance=L, direction=tskit.REVERSE) + self.assertEqual(a.shape[0], m - 1) + self.assertTrue(np.allclose(A[m - 1, :-1], a[::-1])) + + def verify_max_mutations(self, ts): + """ + Verifies that the max mutations parameter works as expected. 
+ """ + mutations = list(ts.mutations()) + ldc = tskit.LdCalculator(ts) + A = ldc.get_r2_matrix() + j = len(mutations) // 2 + for k in range(j): + a = ldc.get_r2_array(j, max_mutations=k) + self.assertEqual(a.shape[0], k) + self.assertTrue(np.allclose(A[j, j + 1: j + 1 + k], a)) + a = ldc.get_r2_array(j, max_mutations=k, direction=tskit.REVERSE) + self.assertEqual(a.shape[0], k) + self.assertTrue(np.allclose(A[j, j - k: j], a[::-1])) + + def test_single_tree_simulated_mutations(self): + ts = msprime.simulate(20, mutation_rate=10, random_seed=15) + ts = tsutil.subsample_sites(ts, self.num_test_sites) + self.verify_matrix(ts) + self.verify_max_distance(ts) + + def test_deprecated_aliases(self): + ts = msprime.simulate(20, mutation_rate=10, random_seed=15) + ts = tsutil.subsample_sites(ts, self.num_test_sites) + ldc = tskit.LdCalculator(ts) + A = ldc.get_r2_matrix() + B = ldc.r2_matrix() + self.assertTrue(np.array_equal(A, B)) + a = ldc.get_r2_array(0) + b = ldc.r2_array(0) + self.assertTrue(np.array_equal(a, b)) + self.assertEqual(ldc.get_r2(0, 1), ldc.r2(0, 1)) + + def test_single_tree_regular_mutations(self): + ts = msprime.simulate(self.num_test_sites, length=self.num_test_sites) + ts = tsutil.insert_branch_mutations(ts) + # We don't support back mutations, so this should fail. 
+ self.assertRaises(_tskit.LibraryError, self.verify_matrix, ts) + self.assertRaises(_tskit.LibraryError, self.verify_max_distance, ts) + + def test_tree_sequence_regular_mutations(self): + ts = msprime.simulate( + self.num_test_sites, recombination_rate=1, + length=self.num_test_sites) + self.assertGreater(ts.get_num_trees(), 10) + t = ts.dump_tables() + t.sites.reset() + t.mutations.reset() + for j in range(self.num_test_sites): + site_id = len(t.sites) + t.sites.add_row(position=j, ancestral_state="0") + t.mutations.add_row(site=site_id, derived_state="1", node=j) + ts = t.tree_sequence() + self.verify_matrix(ts) + self.verify_max_distance(ts) + + def test_tree_sequence_simulated_mutations(self): + ts = msprime.simulate(20, mutation_rate=10, recombination_rate=10) + self.assertGreater(ts.get_num_trees(), 10) + ts = tsutil.subsample_sites(ts, self.num_test_sites) + self.verify_matrix(ts) + self.verify_max_distance(ts) + self.verify_max_mutations(ts) + + +def set_partitions(collection): + """ + Returns an ierator over all partitions of the specified set. + + From https://stackoverflow.com/questions/19368375/set-partitions-in-python + """ + if len(collection) == 1: + yield [collection] + else: + first = collection[0] + for smaller in set_partitions(collection[1:]): + for n, subset in enumerate(smaller): + yield smaller[:n] + [[first] + subset] + smaller[n + 1:] + yield [[first]] + smaller + + +def naive_mean_descendants(ts, reference_sets): + """ + Straightforward implementation of mean sample ancestry by iterating + over the trees and nodes in each tree. + """ + # TODO generalise this to allow arbitrary nodes, not just samples. 
+ C = np.zeros((ts.num_nodes, len(reference_sets))) + T = np.zeros(ts.num_nodes) + tree_iters = [ts.trees(tracked_samples=sample_set) for sample_set in reference_sets] + for _ in range(ts.num_trees): + trees = [next(tree_iter) for tree_iter in tree_iters] + left, right = trees[0].interval + length = right - left + for node in trees[0].nodes(): + num_samples = trees[0].num_samples(node) + if num_samples > 0: + for j, tree in enumerate(trees): + C[node, j] += length * tree.num_tracked_samples(node) + T[node] += length + for node in range(ts.num_nodes): + if T[node] > 0: + C[node] /= T[node] + return C + + +class TestMeanDescendants(unittest.TestCase): + """ + Tests the TreeSequence.mean_descendants method. + """ + def verify(self, ts, reference_sets): + C1 = naive_mean_descendants(ts, reference_sets) + C2 = tsutil.mean_descendants(ts, reference_sets) + C3 = ts.mean_descendants(reference_sets) + self.assertEqual(C1.shape, C2.shape) + self.assertTrue(np.allclose(C1, C2)) + self.assertTrue(np.allclose(C1, C3)) + return C1 + + def test_two_populations_high_migration(self): + ts = msprime.simulate( + population_configurations=[ + msprime.PopulationConfiguration(8), + msprime.PopulationConfiguration(8)], + migration_matrix=[[0, 1], [1, 0]], + recombination_rate=3, + random_seed=5) + self.assertGreater(ts.num_trees, 1) + self.verify(ts, [ts.samples(0), ts.samples(1)]) + + def test_single_tree(self): + ts = msprime.simulate(6, random_seed=1) + S = [range(3), range(3, 6)] + C = self.verify(ts, S) + for j, samples in enumerate(S): + tree = next(ts.trees(tracked_samples=samples)) + for u in tree.nodes(): + self.assertEqual(tree.num_tracked_samples(u), C[u, j]) + + def test_single_tree_partial_samples(self): + ts = msprime.simulate(6, random_seed=1) + S = [range(3), range(3, 4)] + C = self.verify(ts, S) + for j, samples in enumerate(S): + tree = next(ts.trees(tracked_samples=samples)) + for u in tree.nodes(): + self.assertEqual(tree.num_tracked_samples(u), C[u, j]) + + def 
test_single_tree_all_sample_sets(self): + ts = msprime.simulate(6, random_seed=1) + for S in set_partitions(list(range(ts.num_samples))): + C = self.verify(ts, S) + for j, samples in enumerate(S): + tree = next(ts.trees(tracked_samples=samples)) + for u in tree.nodes(): + self.assertEqual(tree.num_tracked_samples(u), C[u, j]) + + def test_many_trees_all_sample_sets(self): + ts = msprime.simulate(6, recombination_rate=2, random_seed=1) + self.assertGreater(ts.num_trees, 2) + for S in set_partitions(list(range(ts.num_samples))): + self.verify(ts, S) + + def test_wright_fisher_unsimplified_all_sample_sets(self): + tables = wf.wf_sim( + 4, 5, seed=1, deep_history=False, initial_generation_samples=False, + num_loci=10) + tables.sort() + ts = tables.tree_sequence() + for S in set_partitions(list(ts.samples())): + self.verify(ts, S) + + def test_wright_fisher_unsimplified(self): + tables = wf.wf_sim( + 20, 15, seed=1, deep_history=False, initial_generation_samples=False, + num_loci=20) + tables.sort() + ts = tables.tree_sequence() + samples = ts.samples() + self.verify(ts, [samples[:10], samples[10:]]) + + def test_wright_fisher_simplified(self): + tables = wf.wf_sim( + 30, 10, seed=1, deep_history=False, initial_generation_samples=False, + num_loci=5) + tables.sort() + ts = tables.tree_sequence() + samples = ts.samples() + self.verify(ts, [samples[:10], samples[10:]]) + + +def naive_genealogical_nearest_neighbours(ts, focal, reference_sets): + # Make sure everyhing is a sample so we can use the tracked_samples option. + # This is a limitation of the current API. 
+ tables = ts.dump_tables() + tables.nodes.set_columns( + flags=np.ones_like(tables.nodes.flags), + time=tables.nodes.time) + ts = tables.tree_sequence() + + A = np.zeros((len(focal), len(reference_sets))) + L = np.zeros(len(focal)) + reference_set_map = np.zeros(ts.num_nodes, dtype=int) - 1 + for k, ref_set in enumerate(reference_sets): + for u in ref_set: + reference_set_map[u] = k + tree_iters = [ + ts.trees(tracked_samples=reference_nodes) for reference_nodes in reference_sets] + for _ in range(ts.num_trees): + trees = list(map(next, tree_iters)) + length = trees[0].interval[1] - trees[0].interval[0] + for j, u in enumerate(focal): + v = trees[0].parent(u) + while v != tskit.NULL: + total = sum(tree.num_tracked_samples(v) for tree in trees) + if total > 1: + break + v = trees[0].parent(v) + if v != tskit.NULL: + focal_node_set = reference_set_map[u] + for k, tree in enumerate(trees): + # If the focal node is in the current set, we subtract its + # contribution from the numerator + n = tree.num_tracked_samples(v) - (k == focal_node_set) + # If the focal node is in *any* reference set, we subtract its + # contribution from the demoninator. + A[j, k] += length * n / (total - int(focal_node_set != -1)) + L[j] += length + # Normalise by the accumulated value for each focal node. + index = L > 0 + L = L[index] + L = L.reshape((L.shape[0], 1)) + A[index, :] /= L + return A + + +class TestGenealogicalNearestNeighbours(unittest.TestCase): + """ + Tests the TreeSequence.genealogical_nearest_neighbours method. 
+ """ + def verify(self, ts, reference_sets, focal=None): + if focal is None: + focal = [u for refset in reference_sets for u in refset] + A1 = naive_genealogical_nearest_neighbours(ts, focal, reference_sets) + A2 = tsutil.genealogical_nearest_neighbours(ts, focal, reference_sets) + A3 = ts.genealogical_nearest_neighbours(focal, reference_sets) + if IS_PY2: + # Threads not supported on PY2 + self.assertRaises( + ValueError, ts.genealogical_nearest_neighbours, focal, + reference_sets, num_threads=3) + else: + A4 = ts.genealogical_nearest_neighbours(focal, reference_sets, num_threads=3) + self.assertTrue(np.array_equal(A3, A4)) + self.assertEqual(A1.shape, A2.shape) + self.assertEqual(A1.shape, A3.shape) + self.assertTrue(np.allclose(A1, A2)) + self.assertTrue(np.allclose(A1, A3)) + if all(ts.node(u).is_sample() for u in focal): + # When the focal nodes are samples, we can assert some stronger properties. + fully_rooted = True + for tree in ts.trees(): + if tree.num_roots > 1: + fully_rooted = False + break + if fully_rooted: + self.assertTrue(np.allclose(np.sum(A1, axis=1), 1)) + else: + all_refs = [u for refset in reference_sets for u in refset] + # Any node that hits a root before meeting a descendent of the reference + # nodes must have total zero. + coalescence_found = np.array([False for _ in all_refs]) + for tree in ts.trees(tracked_samples=all_refs): + for j, u in enumerate(focal): + while u != tskit.NULL: + if tree.num_tracked_samples(u) > 1: + coalescence_found[j] = True + break + u = tree.parent(u) + self.assertTrue(np.allclose(np.sum(A1[coalescence_found], axis=1), 1)) + # Anything where there's no coalescence, ever is zero by convention. 
+ self.assertTrue( + np.allclose( + np.sum(A1[np.logical_not(coalescence_found)], axis=1), 0)) + return A1 + + def test_two_populations_high_migration(self): + ts = msprime.simulate( + population_configurations=[ + msprime.PopulationConfiguration(18), + msprime.PopulationConfiguration(18)], + migration_matrix=[[0, 1], [1, 0]], + recombination_rate=8, + random_seed=5) + self.assertGreater(ts.num_trees, 1) + self.verify(ts, [ts.samples(0), ts.samples(1)]) + + def test_single_tree(self): + ts = msprime.simulate(6, random_seed=1) + S = [range(3), range(3, 6)] + self.verify(ts, S) + + def test_single_tree_internal_reference_sets(self): + ts = msprime.simulate(10, random_seed=1) + tree = ts.first() + S = [[u] for u in tree.children(tree.root)] + self.verify(ts, S, ts.samples()) + + def test_single_tree_all_nodes(self): + ts = msprime.simulate(10, random_seed=1) + S = [np.arange(ts.num_nodes, dtype=np.int32)] + self.verify(ts, S, np.arange(ts.num_nodes, dtype=np.int32)) + + def test_single_tree_partial_samples(self): + ts = msprime.simulate(6, random_seed=1) + S = [range(3), range(3, 4)] + self.verify(ts, S) + + def test_single_tree_all_sample_sets(self): + ts = msprime.simulate(6, random_seed=1) + for S in set_partitions(list(range(ts.num_samples))): + self.verify(ts, S) + + def test_many_trees_all_sample_sets(self): + ts = msprime.simulate(6, recombination_rate=2, random_seed=1) + self.assertGreater(ts.num_trees, 2) + for S in set_partitions(list(range(ts.num_samples))): + self.verify(ts, S) + + def test_many_trees_sequence_length(self): + for L in [0.5, 1.5, 3.3333]: + ts = msprime.simulate(6, length=L, recombination_rate=2, random_seed=1) + self.verify(ts, [range(3), range(3, 6)]) + + def test_many_trees_all_nodes(self): + ts = msprime.simulate(6, length=4, recombination_rate=2, random_seed=1) + S = [np.arange(ts.num_nodes, dtype=np.int32)] + self.verify(ts, S, np.arange(ts.num_nodes, dtype=np.int32)) + + def test_wright_fisher_unsimplified_all_sample_sets(self): + 
tables = wf.wf_sim( + 4, 5, seed=1, deep_history=True, initial_generation_samples=False, + num_loci=10) + tables.sort() + ts = tables.tree_sequence() + for S in set_partitions(list(ts.samples())): + self.verify(ts, S) + + def test_wright_fisher_unsimplified(self): + tables = wf.wf_sim( + 20, 15, seed=1, deep_history=True, initial_generation_samples=False, + num_loci=20) + tables.sort() + ts = tables.tree_sequence() + samples = ts.samples() + self.verify(ts, [samples[:10], samples[10:]]) + + def test_wright_fisher_initial_generation(self): + tables = wf.wf_sim( + 20, 15, seed=1, deep_history=True, initial_generation_samples=True, + num_loci=20) + tables.sort() + tables.simplify() + ts = tables.tree_sequence() + samples = ts.samples() + founders = [u for u in samples if ts.node(u).time > 0] + samples = [u for u in samples if ts.node(u).time == 0] + self.verify(ts, [founders[:10], founders[10:]], samples) + + def test_wright_fisher_initial_generation_no_deep_history(self): + tables = wf.wf_sim( + 20, 15, seed=2, deep_history=False, initial_generation_samples=True, + num_loci=20) + tables.sort() + tables.simplify() + ts = tables.tree_sequence() + samples = ts.samples() + founders = [u for u in samples if ts.node(u).time > 0] + samples = [u for u in samples if ts.node(u).time == 0] + A = self.verify(ts, [founders[:10], founders[10:]], samples) + # Because the founders are all isolated, the stat must be zero. 
+ self.assertTrue(np.all(A == 0)) + + def test_wright_fisher_unsimplified_multiple_roots(self): + tables = wf.wf_sim( + 20, 15, seed=1, deep_history=False, initial_generation_samples=False, + num_loci=20) + tables.sort() + ts = tables.tree_sequence() + samples = ts.samples() + self.verify(ts, [samples[:10], samples[10:]]) + + def test_wright_fisher_simplified(self): + tables = wf.wf_sim( + 31, 10, seed=1, deep_history=True, initial_generation_samples=False, + num_loci=5) + tables.sort() + ts = tables.tree_sequence().simplify() + samples = ts.samples() + self.verify(ts, [samples[:10], samples[10:]]) + + def test_wright_fisher_simplified_multiple_roots(self): + tables = wf.wf_sim( + 31, 10, seed=1, deep_history=False, initial_generation_samples=False, + num_loci=5) + tables.sort() + ts = tables.tree_sequence() + samples = ts.samples() + self.verify(ts, [samples[:10], samples[10:]]) + + def test_empty_ts(self): + tables = tskit.TableCollection(1.0) + tables.nodes.add_row(1, 0) + tables.nodes.add_row(1, 0) + ts = tables.tree_sequence() + self.verify(ts, [[0], [1]]) + + +def exact_genealogical_nearest_neighbours(ts, focal, reference_sets): + # Same as above, except we return the per-tree value for a single node. + + # Make sure everything is a sample so we can use the tracked_samples option. + # This is a limitation of the current API. 
+ tables = ts.dump_tables() + tables.nodes.set_columns( + flags=np.ones_like(tables.nodes.flags), + time=tables.nodes.time) + ts = tables.tree_sequence() + + A = np.zeros((len(reference_sets), ts.num_trees)) + L = np.zeros(ts.num_trees) + reference_set_map = np.zeros(ts.num_nodes, dtype=int) - 1 + for k, ref_set in enumerate(reference_sets): + for u in ref_set: + reference_set_map[u] = k + tree_iters = [ + ts.trees(tracked_samples=reference_nodes) for reference_nodes in reference_sets] + u = focal + for _ in range(ts.num_trees): + trees = list(map(next, tree_iters)) + v = trees[0].parent(u) + while v != tskit.NULL: + total = sum(tree.num_tracked_samples(v) for tree in trees) + if total > 1: + break + v = trees[0].parent(v) + if v != tskit.NULL: + # The length is only reported where the statistic is defined. + L[trees[0].index] = trees[0].interval[1] - trees[0].interval[0] + focal_node_set = reference_set_map[u] + for k, tree in enumerate(trees): + # If the focal node is in the current set, we subtract its + # contribution from the numerator + n = tree.num_tracked_samples(v) - (k == focal_node_set) + # If the focal node is in *any* reference set, we subtract its + # contribution from the denominator. 
+ A[k, tree.index] = n / (total - int(focal_node_set != -1)) + return A, L + + +class TestExactGenealogicalNearestNeighbours(TestGenealogicalNearestNeighbours): + + def verify(self, ts, reference_sets, focal=None): + if focal is None: + focal = [u for refset in reference_sets for u in refset] + A = ts.genealogical_nearest_neighbours(focal, reference_sets) + + for j, u in enumerate(focal): + T, L = exact_genealogical_nearest_neighbours(ts, u, reference_sets) + # Ignore the cases where the node has no GNNs + if np.sum(L) > 0: + mean = np.sum(T * L, axis=1) / np.sum(L) + self.assertTrue(np.allclose(mean, A[j])) + return A diff --git a/python/tests/test_tables.py b/python/tests/test_tables.py new file mode 100644 index 0000000000..49bc1ffb97 --- /dev/null +++ b/python/tests/test_tables.py @@ -0,0 +1,1660 @@ +# -*- coding: utf-8 -*- +""" +Test cases for the low-level tables used to transfer information +between simulations and the tree sequence. +""" +from __future__ import print_function +from __future__ import division + +import pickle +import random +import string +import unittest +import warnings +import sys + +import numpy as np +import six + +import tskit +import _tskit +import msprime + +import tests.tsutil as tsutil + +IS_PY2 = sys.version_info[0] < 3 + + +def random_bytes(max_length): + """ + Returns a random bytearray of the specified maximum length. + """ + length = random.randint(0, max_length) + return bytearray(random.randint(0, 255) for _ in range(length)) + + +def random_strings(max_length): + """ + Returns a random bytearray of the specified maximum length. 
+ """ + length = random.randint(0, max_length) + return "".join(random.choice(string.printable) for _ in range(length)) + + +class Column(object): + def __init__(self, name): + self.name = name + + +class Int32Column(Column): + def get_input(self, n): + return 1 + np.arange(n, dtype=np.int32) + + +class UInt8Column(Column): + def get_input(self, n): + return 2 + np.arange(n, dtype=np.uint8) + + +class UInt32Column(Column): + def get_input(self, n): + return 3 + np.arange(n, dtype=np.uint32) + + +class CharColumn(Column): + def get_input(self, n): + return np.zeros(n, dtype=np.int8) + + +class DoubleColumn(Column): + def get_input(self, n): + return 4 + np.arange(n, dtype=np.float64) + + +class CommonTestsMixin(object): + """ + Abstract base class for common table tests. Because of the design of unittest, + we have to make this a mixin. + """ + def test_max_rows_increment(self): + for bad_value in [-1, -2**10]: + self.assertRaises(ValueError, self.table_class, max_rows_increment=bad_value) + for v in [1, 100, 256]: + table = self.table_class(max_rows_increment=v) + self.assertEqual(table.max_rows_increment, v) + # Setting zero or not argument both denote the default. 
+ table = self.table_class() + self.assertEqual(table.max_rows_increment, 1024) + table = self.table_class(max_rows_increment=0) + self.assertEqual(table.max_rows_increment, 1024) + + def test_input_parameters_errors(self): + self.assertGreater(len(self.input_parameters), 0) + for param, _ in self.input_parameters: + for bad_value in [-1, -2**10]: + self.assertRaises(ValueError, self.table_class, **{param: bad_value}) + for bad_type in [None, ValueError, "ser"]: + self.assertRaises(TypeError, self.table_class, **{param: bad_type}) + + def test_input_parameter_values(self): + self.assertGreater(len(self.input_parameters), 0) + for param, _ in self.input_parameters: + for v in [1, 100, 256]: + table = self.table_class(**{param: v}) + self.assertEqual(getattr(table, param), v) + + def test_set_columns_string_errors(self): + inputs = {c.name: c.get_input(1) for c in self.columns} + for list_col, offset_col in self.ragged_list_columns: + value = list_col.get_input(1) + inputs[list_col.name] = value + inputs[offset_col.name] = [0, 1] + # Make sure this works. + table = self.table_class() + table.set_columns(**inputs) + for list_col, offset_col in self.ragged_list_columns: + kwargs = dict(inputs) + del kwargs[list_col.name] + self.assertRaises(TypeError, table.set_columns, **kwargs) + kwargs = dict(inputs) + del kwargs[offset_col.name] + self.assertRaises(TypeError, table.set_columns, **kwargs) + + def test_set_columns_interface(self): + kwargs = {c.name: c.get_input(1) for c in self.columns} + for list_col, offset_col in self.ragged_list_columns: + value = list_col.get_input(1) + kwargs[list_col.name] = value + kwargs[offset_col.name] = [0, 1] + # Make sure this works. 
+ table = self.table_class() + table.set_columns(**kwargs) + table.append_columns(**kwargs) + for focal_col in self.columns: + table = self.table_class() + for bad_type in [Exception, tskit]: + error_kwargs = dict(kwargs) + error_kwargs[focal_col.name] = bad_type + self.assertRaises(ValueError, table.set_columns, **error_kwargs) + self.assertRaises(ValueError, table.append_columns, **error_kwargs) + for bad_value in ["qwer", [0, "sd"]]: + error_kwargs = dict(kwargs) + error_kwargs[focal_col.name] = bad_value + self.assertRaises(ValueError, table.set_columns, **error_kwargs) + self.assertRaises(ValueError, table.append_columns, **error_kwargs) + + def test_set_columns_from_dict(self): + kwargs = {c.name: c.get_input(1) for c in self.columns} + for list_col, offset_col in self.ragged_list_columns: + value = list_col.get_input(1) + kwargs[list_col.name] = value + kwargs[offset_col.name] = [0, 1] + # Make sure this works. + t1 = self.table_class() + t1.set_columns(**kwargs) + t2 = self.table_class() + t2.set_columns(**t1.asdict()) + self.assertEqual(t1, t2) + + def test_set_columns_dimension(self): + kwargs = {c.name: c.get_input(1) for c in self.columns} + for list_col, offset_col in self.ragged_list_columns: + value = list_col.get_input(1) + kwargs[list_col.name] = value + kwargs[offset_col.name] = [0, 1] + table = self.table_class() + table.set_columns(**kwargs) + table.append_columns(**kwargs) + for focal_col in self.columns: + table = self.table_class() + for bad_dims in [5, [[1], [1]], np.zeros((2, 2))]: + error_kwargs = dict(kwargs) + error_kwargs[focal_col.name] = bad_dims + self.assertRaises(ValueError, table.set_columns, **error_kwargs) + self.assertRaises(ValueError, table.append_columns, **error_kwargs) + for list_col, offset_col in self.ragged_list_columns: + value = list_col.get_input(1) + error_kwargs = dict(kwargs) + for bad_dims in [5, [[1], [1]], np.zeros((2, 2))]: + error_kwargs[offset_col.name] = bad_dims + self.assertRaises(ValueError, 
table.set_columns, **error_kwargs) + self.assertRaises(ValueError, table.append_columns, **error_kwargs) + # Empty offset columns are caught also + error_kwargs[offset_col.name] = [] + self.assertRaises(ValueError, table.set_columns, **error_kwargs) + + def test_set_columns_input_sizes(self): + num_rows = 100 + input_data = {col.name: col.get_input(num_rows) for col in self.columns} + col_map = {col.name: col for col in self.columns} + for list_col, offset_col in self.ragged_list_columns: + value = list_col.get_input(num_rows) + input_data[list_col.name] = value + input_data[offset_col.name] = np.arange(num_rows + 1, dtype=np.uint32) + col_map[list_col.name] = list_col + col_map[offset_col.name] = offset_col + table = self.table_class() + table.set_columns(**input_data) + table.append_columns(**input_data) + for equal_len_col_set in self.equal_len_columns: + if len(equal_len_col_set) > 1: + for col in equal_len_col_set: + kwargs = dict(input_data) + kwargs[col] = col_map[col].get_input(1) + self.assertRaises(ValueError, table.set_columns, **kwargs) + self.assertRaises(ValueError, table.append_columns, **kwargs) + + @unittest.skip("Fix or remove when column setter done. 
#492") + def test_set_read_only_attributes(self): + table = self.table_class() + with self.assertRaises(AttributeError): + table.num_rows = 10 + with self.assertRaises(AttributeError): + table.max_rows = 10 + for param, default in self.input_parameters: + with self.assertRaises(AttributeError): + setattr(table, param, 2) + for col in self.columns: + with self.assertRaises(AttributeError): + setattr(table, col.name, np.zeros(5)) + self.assertEqual(table.num_rows, 0) + self.assertEqual(len(table), 0) + + def test_defaults(self): + table = self.table_class() + self.assertEqual(table.num_rows, 0) + self.assertEqual(len(table), 0) + for param, default in self.input_parameters: + self.assertEqual(getattr(table, param), default) + for col in self.columns: + array = getattr(table, col.name) + self.assertEqual(array.shape, (0,)) + + def test_add_row_data(self): + for num_rows in [0, 10, 100]: + input_data = {col.name: col.get_input(num_rows) for col in self.columns} + table = self.table_class() + for j in range(num_rows): + kwargs = {col: data[j] for col, data in input_data.items()} + for col in self.string_colnames: + kwargs[col] = "x" + for col in self.binary_colnames: + kwargs[col] = b"x" + k = table.add_row(**kwargs) + self.assertEqual(k, j) + for colname, input_array in input_data.items(): + output_array = getattr(table, colname) + self.assertEqual(input_array.shape, output_array.shape) + self.assertTrue(np.all(input_array == output_array)) + table.clear() + self.assertEqual(table.num_rows, 0) + self.assertEqual(len(table), 0) + + def test_add_row_round_trip(self): + for num_rows in [0, 10, 100]: + input_data = {col.name: col.get_input(num_rows) for col in self.columns} + for list_col, offset_col in self.ragged_list_columns: + value = list_col.get_input(num_rows) + input_data[list_col.name] = value + input_data[offset_col.name] = np.arange(num_rows + 1, dtype=np.uint32) + t1 = self.table_class() + t1.set_columns(**input_data) + for colname, input_array in 
input_data.items(): + output_array = getattr(t1, colname) + self.assertEqual(input_array.shape, output_array.shape) + self.assertTrue(np.all(input_array == output_array)) + t2 = self.table_class() + for row in list(t1): + t2.add_row(**row._asdict()) + self.assertEqual(t1, t2) + + def test_set_columns_data(self): + for num_rows in [0, 10, 100, 1000]: + input_data = {col.name: col.get_input(num_rows) for col in self.columns} + offset_cols = set() + for list_col, offset_col in self.ragged_list_columns: + value = list_col.get_input(num_rows) + input_data[list_col.name] = value + input_data[offset_col.name] = np.arange(num_rows + 1, dtype=np.uint32) + offset_cols.add(offset_col.name) + table = self.table_class() + for _ in range(5): + table.set_columns(**input_data) + for colname, input_array in input_data.items(): + output_array = getattr(table, colname) + self.assertEqual(input_array.shape, output_array.shape) + self.assertTrue(np.all(input_array == output_array)) + table.clear() + self.assertEqual(table.num_rows, 0) + self.assertEqual(len(table), 0) + for colname in input_data.keys(): + if colname in offset_cols: + self.assertEqual(list(getattr(table, colname)), [0]) + else: + self.assertEqual(list(getattr(table, colname)), []) + + def test_truncate(self): + num_rows = 100 + input_data = {col.name: col.get_input(num_rows) for col in self.columns} + for list_col, offset_col in self.ragged_list_columns: + value = list_col.get_input(2 * num_rows) + input_data[list_col.name] = value + input_data[offset_col.name] = 2 * np.arange(num_rows + 1, dtype=np.uint32) + table = self.table_class() + table.set_columns(**input_data) + + copy = table.copy() + table.truncate(num_rows) + self.assertEqual(copy, table) + + for num_rows in [100, 10, 1]: + table.truncate(num_rows) + self.assertEqual(table.num_rows, num_rows) + self.assertEqual(len(table), num_rows) + used = set() + for list_col, offset_col in self.ragged_list_columns: + offset = getattr(table, offset_col.name) + 
self.assertEqual(offset.shape, (num_rows + 1,)) + self.assertTrue(np.array_equal( + input_data[offset_col.name][:num_rows + 1], offset)) + list_data = getattr(table, list_col.name) + self.assertTrue(np.array_equal( + list_data, input_data[list_col.name][:offset[-1]])) + used.add(offset_col.name) + used.add(list_col.name) + for name, data in input_data.items(): + if name not in used: + self.assertTrue(np.array_equal( + data[:num_rows], getattr(table, name))) + + def test_truncate_errors(self): + num_rows = 10 + input_data = {col.name: col.get_input(num_rows) for col in self.columns} + for list_col, offset_col in self.ragged_list_columns: + value = list_col.get_input(2 * num_rows) + input_data[list_col.name] = value + input_data[offset_col.name] = 2 * np.arange(num_rows + 1, dtype=np.uint32) + table = self.table_class() + table.set_columns(**input_data) + for bad_type in [None, 0.001, {}]: + self.assertRaises(TypeError, table.truncate, bad_type) + for bad_num_rows in [-1, num_rows + 1, 10**6]: + self.assertRaises(ValueError, table.truncate, bad_num_rows) + + def test_append_columns_data(self): + for num_rows in [0, 10, 100, 1000]: + input_data = {col.name: col.get_input(num_rows) for col in self.columns} + offset_cols = set() + for list_col, offset_col in self.ragged_list_columns: + value = list_col.get_input(num_rows) + input_data[list_col.name] = value + input_data[offset_col.name] = np.arange(num_rows + 1, dtype=np.uint32) + offset_cols.add(offset_col.name) + table = self.table_class() + for j in range(1, 10): + table.append_columns(**input_data) + for colname, values in input_data.items(): + output_array = getattr(table, colname) + if colname in offset_cols: + input_array = np.zeros(j * num_rows + 1, dtype=np.uint32) + for k in range(j): + input_array[k * num_rows: (k + 1) * num_rows + 1] = ( + k * values[-1]) + values + self.assertEqual(input_array.shape, output_array.shape) + else: + input_array = np.hstack([values for _ in range(j)]) + 
self.assertEqual(input_array.shape, output_array.shape) + self.assertTrue(np.array_equal(input_array, output_array)) + self.assertEqual(table.num_rows, j * num_rows) + self.assertEqual(len(table), j * num_rows) + + def test_append_columns_max_rows(self): + for num_rows in [0, 10, 100, 1000]: + input_data = {col.name: col.get_input(num_rows) for col in self.columns} + for list_col, offset_col in self.ragged_list_columns: + value = list_col.get_input(num_rows) + input_data[list_col.name] = value + input_data[offset_col.name] = np.arange(num_rows + 1, dtype=np.uint32) + for max_rows in [0, 1, 8192]: + table = self.table_class(max_rows_increment=max_rows) + for j in range(1, 10): + table.append_columns(**input_data) + self.assertEqual(table.num_rows, j * num_rows) + self.assertEqual(len(table), j * num_rows) + self.assertGreater(table.max_rows, table.num_rows) + + def test_str(self): + for num_rows in [0, 10]: + input_data = {col.name: col.get_input(num_rows) for col in self.columns} + for list_col, offset_col in self.ragged_list_columns: + value = list_col.get_input(num_rows) + input_data[list_col.name] = value + input_data[offset_col.name] = np.arange(num_rows + 1, dtype=np.uint32) + table = self.table_class() + table.set_columns(**input_data) + s = str(table) + self.assertEqual(len(s.splitlines()), num_rows + 1) + + def test_copy(self): + for num_rows in [0, 10]: + input_data = {col.name: col.get_input(num_rows) for col in self.columns} + for list_col, offset_col in self.ragged_list_columns: + value = list_col.get_input(num_rows) + input_data[list_col.name] = value + input_data[offset_col.name] = np.arange(num_rows + 1, dtype=np.uint32) + table = self.table_class() + table.set_columns(**input_data) + for _ in range(10): + copy = table.copy() + self.assertNotEqual(id(copy), id(table)) + self.assertIsInstance(copy, self.table_class) + self.assertEqual(copy, table) + table = copy + + def test_pickle(self): + for num_rows in [0, 10, 100]: + input_data = {col.name: 
col.get_input(num_rows) for col in self.columns} + for list_col, offset_col in self.ragged_list_columns: + value = list_col.get_input(num_rows) + input_data[list_col.name] = value + input_data[offset_col.name] = np.arange(num_rows + 1, dtype=np.uint32) + table = self.table_class() + table.set_columns(**input_data) + pkl = pickle.dumps(table) + new_table = pickle.loads(pkl) + self.assertEqual(table, new_table) + for protocol in range(pickle.HIGHEST_PROTOCOL + 1): + pkl = pickle.dumps(table, protocol=protocol) + new_table = pickle.loads(pkl) + self.assertEqual(table, new_table) + + def test_equality(self): + for num_rows in [1, 10, 100]: + input_data = {col.name: col.get_input(num_rows) for col in self.columns} + for list_col, offset_col in self.ragged_list_columns: + value = list_col.get_input(num_rows) + input_data[list_col.name] = value + input_data[offset_col.name] = np.arange(num_rows + 1, dtype=np.uint32) + t1 = self.table_class() + t2 = self.table_class() + self.assertEqual(t1, t1) + self.assertEqual(t1, t2) + self.assertTrue(t1 == t2) + self.assertFalse(t1 != t2) + t1.set_columns(**input_data) + self.assertEqual(t1, t1) + self.assertNotEqual(t1, t2) + self.assertNotEqual(t2, t1) + t2.set_columns(**input_data) + self.assertEqual(t1, t2) + self.assertEqual(t2, t2) + t2.clear() + self.assertNotEqual(t1, t2) + self.assertNotEqual(t2, t1) + # Check each column in turn to see if we are correctly checking values. 
+ for col in self.columns: + col_copy = np.copy(input_data[col.name]) + input_data_copy = dict(input_data) + input_data_copy[col.name] = col_copy + t2.set_columns(**input_data_copy) + self.assertEqual(t1, t2) + self.assertFalse(t1 != t2) + col_copy += 1 + t2.set_columns(**input_data_copy) + self.assertNotEqual(t1, t2) + self.assertNotEqual(t2, t1) + for list_col, offset_col in self.ragged_list_columns: + value = list_col.get_input(num_rows) + input_data_copy = dict(input_data) + input_data_copy[list_col.name] = value + 1 + t2.set_columns(**input_data_copy) + self.assertNotEqual(t1, t2) + value = list_col.get_input(num_rows + 1) + input_data_copy = dict(input_data) + input_data_copy[list_col.name] = value + input_data_copy[offset_col.name] = np.arange( + num_rows + 1, dtype=np.uint32) + input_data_copy[offset_col.name][-1] = num_rows + 1 + t2.set_columns(**input_data_copy) + self.assertNotEqual(t1, t2) + self.assertNotEqual(t2, t1) + # Different types should always be unequal. + self.assertNotEqual(t1, None) + self.assertNotEqual(t1, []) + + def test_bad_offsets(self): + for num_rows in [10, 100]: + input_data = {col.name: col.get_input(num_rows) for col in self.columns} + for list_col, offset_col in self.ragged_list_columns: + value = list_col.get_input(num_rows) + input_data[list_col.name] = value + input_data[offset_col.name] = np.arange(num_rows + 1, dtype=np.uint32) + t = self.table_class() + t.set_columns(**input_data) + + for list_col, offset_col in self.ragged_list_columns: + input_data[offset_col.name][0] = -1 + self.assertRaises(_tskit.LibraryError, t.set_columns, **input_data) + input_data[offset_col.name] = np.arange(num_rows + 1, dtype=np.uint32) + t.set_columns(**input_data) + input_data[offset_col.name][-1] = 0 + self.assertRaises(ValueError, t.set_columns, **input_data) + input_data[offset_col.name] = np.arange(num_rows + 1, dtype=np.uint32) + t.set_columns(**input_data) + input_data[offset_col.name][num_rows // 2] = 2**31 + 
self.assertRaises(_tskit.LibraryError, t.set_columns, **input_data) + input_data[offset_col.name] = np.arange(num_rows + 1, dtype=np.uint32) + + input_data[offset_col.name][0] = -1 + self.assertRaises(_tskit.LibraryError, t.append_columns, **input_data) + input_data[offset_col.name] = np.arange(num_rows + 1, dtype=np.uint32) + t.append_columns(**input_data) + input_data[offset_col.name][-1] = 0 + self.assertRaises(ValueError, t.append_columns, **input_data) + input_data[offset_col.name] = np.arange(num_rows + 1, dtype=np.uint32) + t.append_columns(**input_data) + input_data[offset_col.name][num_rows // 2] = 2**31 + self.assertRaises(_tskit.LibraryError, t.append_columns, **input_data) + input_data[offset_col.name] = np.arange(num_rows + 1, dtype=np.uint32) + + +class MetadataTestsMixin(object): + """ + Tests for column that have metadata columns. + """ + def test_random_metadata(self): + for num_rows in [0, 10, 100]: + input_data = {col.name: col.get_input(num_rows) for col in self.columns} + for list_col, offset_col in self.ragged_list_columns: + value = list_col.get_input(num_rows) + input_data[list_col.name] = value + input_data[offset_col.name] = np.arange(num_rows + 1, dtype=np.uint32) + table = self.table_class() + metadatas = [random_bytes(10) for _ in range(num_rows)] + metadata, metadata_offset = tskit.pack_bytes(metadatas) + input_data["metadata"] = metadata + input_data["metadata_offset"] = metadata_offset + table.set_columns(**input_data) + unpacked_metadatas = tskit.unpack_bytes( + table.metadata, table.metadata_offset) + self.assertEqual(metadatas, unpacked_metadatas) + + def test_optional_metadata(self): + for num_rows in [0, 10, 100]: + input_data = {col.name: col.get_input(num_rows) for col in self.columns} + for list_col, offset_col in self.ragged_list_columns: + value = list_col.get_input(num_rows) + input_data[list_col.name] = value + input_data[offset_col.name] = np.arange(num_rows + 1, dtype=np.uint32) + table = self.table_class() + del 
input_data["metadata"] + del input_data["metadata_offset"] + table.set_columns(**input_data) + self.assertEqual(len(list(table.metadata)), 0) + self.assertEqual( + list(table.metadata_offset), [0 for _ in range(num_rows + 1)]) + # Supplying None is the same as not providing the column. + input_data["metadata"] = None + input_data["metadata_offset"] = None + table.set_columns(**input_data) + self.assertEqual(len(list(table.metadata)), 0) + self.assertEqual( + list(table.metadata_offset), [0 for _ in range(num_rows + 1)]) + + +class TestIndividualTable(unittest.TestCase, CommonTestsMixin, MetadataTestsMixin): + + columns = [UInt32Column("flags")] + ragged_list_columns = [ + (DoubleColumn("location"), UInt32Column("location_offset")), + (CharColumn("metadata"), UInt32Column("metadata_offset"))] + string_colnames = [] + binary_colnames = ["metadata"] + input_parameters = [("max_rows_increment", 1024)] + equal_len_columns = [["flags"]] + table_class = tskit.IndividualTable + + def test_simple_example(self): + t = tskit.IndividualTable() + t.add_row(flags=0, location=[], metadata=b"123") + t.add_row(flags=1, location=(0, 1, 2, 3), metadata=b"\xf0") + s = str(t) + self.assertGreater(len(s), 0) + self.assertEqual(len(t), 2) + self.assertEqual(t[0].flags, 0) + self.assertEqual(list(t[0].location), []) + self.assertEqual(t[0].metadata, b"123") + self.assertEqual(t[1].flags, 1) + self.assertEqual(list(t[1].location), [0, 1, 2, 3]) + self.assertEqual(t[1].metadata, b"\xf0") + self.assertRaises(IndexError, t.__getitem__, -3) + + def test_add_row_defaults(self): + t = tskit.IndividualTable() + self.assertEqual(t.add_row(), 0) + self.assertEqual(t.flags[0], 0) + self.assertEqual(len(t.location), 0) + self.assertEqual(t.location_offset[0], 0) + self.assertEqual(len(t.metadata), 0) + self.assertEqual(t.metadata_offset[0], 0) + + +class TestNodeTable(unittest.TestCase, CommonTestsMixin, MetadataTestsMixin): + + columns = [ + UInt32Column("flags"), + DoubleColumn("time"), 
Int32Column("individual"), + Int32Column("population")] + ragged_list_columns = [(CharColumn("metadata"), UInt32Column("metadata_offset"))] + string_colnames = [] + binary_colnames = ["metadata"] + input_parameters = [("max_rows_increment", 1024)] + equal_len_columns = [["time", "flags", "population"]] + table_class = tskit.NodeTable + + def test_simple_example(self): + t = tskit.NodeTable() + t.add_row(flags=0, time=1, population=2, individual=0, metadata=b"123") + t.add_row(flags=1, time=2, population=3, individual=1, metadata=b"\xf0") + s = str(t) + self.assertGreater(len(s), 0) + self.assertEqual(len(t), 2) + self.assertEqual(t[0], (0, 1, 2, 0, b"123")) + self.assertEqual(t[1], (1, 2, 3, 1, b"\xf0")) + self.assertEqual(t[0].flags, 0) + self.assertEqual(t[0].time, 1) + self.assertEqual(t[0].population, 2) + self.assertEqual(t[0].individual, 0) + self.assertEqual(t[0].metadata, b"123") + self.assertEqual(t[0], t[-2]) + self.assertEqual(t[1], t[-1]) + self.assertRaises(IndexError, t.__getitem__, -3) + + def test_add_row_defaults(self): + t = tskit.NodeTable() + self.assertEqual(t.add_row(), 0) + self.assertEqual(t.time[0], 0) + self.assertEqual(t.flags[0], 0) + self.assertEqual(t.population[0], tskit.NULL) + self.assertEqual(t.individual[0], tskit.NULL) + self.assertEqual(len(t.metadata), 0) + self.assertEqual(t.metadata_offset[0], 0) + + def test_optional_population(self): + for num_rows in [0, 10, 100]: + metadatas = [str(j) for j in range(num_rows)] + metadata, metadata_offset = tskit.pack_strings(metadatas) + flags = list(range(num_rows)) + time = list(range(num_rows)) + table = tskit.NodeTable() + table.set_columns( + metadata=metadata, metadata_offset=metadata_offset, + flags=flags, time=time) + self.assertEqual(list(table.population), [-1 for _ in range(num_rows)]) + self.assertEqual(list(table.flags), flags) + self.assertEqual(list(table.time), time) + self.assertEqual(list(table.metadata), list(metadata)) + self.assertEqual(list(table.metadata_offset), 
list(metadata_offset)) + table.set_columns(flags=flags, time=time, population=None) + self.assertEqual(list(table.population), [-1 for _ in range(num_rows)]) + self.assertEqual(list(table.flags), flags) + self.assertEqual(list(table.time), time) + + +class TestEdgeTable(unittest.TestCase, CommonTestsMixin): + + columns = [ + DoubleColumn("left"), + DoubleColumn("right"), + Int32Column("parent"), + Int32Column("child")] + equal_len_columns = [["left", "right", "parent", "child"]] + string_colnames = [] + binary_colnames = [] + ragged_list_columns = [] + input_parameters = [("max_rows_increment", 1024)] + table_class = tskit.EdgeTable + + def test_simple_example(self): + t = tskit.EdgeTable() + t.add_row(left=0, right=1, parent=2, child=3) + t.add_row(1, 2, 3, 4) + self.assertEqual(len(t), 2) + self.assertEqual(t[0], (0, 1, 2, 3)) + self.assertEqual(t[1], (1, 2, 3, 4)) + self.assertEqual(t[0].left, 0) + self.assertEqual(t[0].right, 1) + self.assertEqual(t[0].parent, 2) + self.assertEqual(t[0].child, 3) + self.assertEqual(t[0], t[-2]) + self.assertEqual(t[1], t[-1]) + self.assertRaises(IndexError, t.__getitem__, -3) + + +class TestSiteTable(unittest.TestCase, CommonTestsMixin, MetadataTestsMixin): + columns = [DoubleColumn("position")] + ragged_list_columns = [ + (CharColumn("ancestral_state"), UInt32Column("ancestral_state_offset")), + (CharColumn("metadata"), UInt32Column("metadata_offset"))] + equal_len_columns = [["position"]] + string_colnames = ["ancestral_state"] + binary_colnames = ["metadata"] + input_parameters = [("max_rows_increment", 1024)] + table_class = tskit.SiteTable + + def test_simple_example(self): + t = tskit.SiteTable() + t.add_row(position=0, ancestral_state="1", metadata=b"2") + t.add_row(1, "2", b"\xf0") + s = str(t) + self.assertGreater(len(s), 0) + self.assertEqual(len(t), 2) + self.assertEqual(t[0], (0, "1", b"2")) + self.assertEqual(t[1], (1, "2", b"\xf0")) + self.assertEqual(t[0].position, 0) + self.assertEqual(t[0].ancestral_state, "1") 
+ self.assertEqual(t[0].metadata, b"2") + self.assertEqual(t[0], t[-2]) + self.assertEqual(t[1], t[-1]) + self.assertRaises(IndexError, t.__getitem__, 2) + self.assertRaises(IndexError, t.__getitem__, -3) + + +class TestMutationTable(unittest.TestCase, CommonTestsMixin, MetadataTestsMixin): + columns = [ + Int32Column("site"), + Int32Column("node"), + Int32Column("parent")] + ragged_list_columns = [ + (CharColumn("derived_state"), UInt32Column("derived_state_offset")), + (CharColumn("metadata"), UInt32Column("metadata_offset"))] + equal_len_columns = [["site", "node"]] + string_colnames = ["derived_state"] + binary_colnames = ["metadata"] + input_parameters = [("max_rows_increment", 1024)] + table_class = tskit.MutationTable + + def test_simple_example(self): + t = tskit.MutationTable() + t.add_row(site=0, node=1, derived_state="2", parent=3, metadata=b"4") + t.add_row(1, 2, "3", 4, b"\xf0") + s = str(t) + self.assertGreater(len(s), 0) + self.assertEqual(len(t), 2) + self.assertEqual(t[0], (0, 1, "2", 3, b"4")) + self.assertEqual(t[1], (1, 2, "3", 4, b"\xf0")) + self.assertEqual(t[0].site, 0) + self.assertEqual(t[0].node, 1) + self.assertEqual(t[0].derived_state, "2") + self.assertEqual(t[0].parent, 3) + self.assertEqual(t[0].metadata, b"4") + self.assertEqual(t[0], t[-2]) + self.assertEqual(t[1], t[-1]) + self.assertRaises(IndexError, t.__getitem__, -3) + + +class TestMigrationTable(unittest.TestCase, CommonTestsMixin): + columns = [ + DoubleColumn("left"), + DoubleColumn("right"), + Int32Column("node"), + Int32Column("source"), + Int32Column("dest"), + DoubleColumn("time")] + ragged_list_columns = [] + string_colnames = [] + binary_colnames = [] + input_parameters = [("max_rows_increment", 1024)] + equal_len_columns = [["left", "right", "node", "source", "dest", "time"]] + table_class = tskit.MigrationTable + + def test_simple_example(self): + t = tskit.MigrationTable() + t.add_row(left=0, right=1, node=2, source=3, dest=4, time=5) + t.add_row(1, 2, 3, 4, 5, 6) + 
self.assertEqual(len(t), 2) + self.assertEqual(t[0], (0, 1, 2, 3, 4, 5)) + self.assertEqual(t[1], (1, 2, 3, 4, 5, 6)) + self.assertEqual(t[0].left, 0) + self.assertEqual(t[0].right, 1) + self.assertEqual(t[0].node, 2) + self.assertEqual(t[0].source, 3) + self.assertEqual(t[0].dest, 4) + self.assertEqual(t[0].time, 5) + self.assertEqual(t[0], t[-2]) + self.assertEqual(t[1], t[-1]) + self.assertRaises(IndexError, t.__getitem__, -3) + + +class TestProvenanceTable(unittest.TestCase, CommonTestsMixin): + columns = [] + ragged_list_columns = [ + (CharColumn("timestamp"), UInt32Column("timestamp_offset")), + (CharColumn("record"), UInt32Column("record_offset"))] + equal_len_columns = [[]] + string_colnames = ["record", "timestamp"] + binary_colnames = [] + input_parameters = [("max_rows_increment", 1024)] + table_class = tskit.ProvenanceTable + + def test_simple_example(self): + t = tskit.ProvenanceTable() + t.add_row(timestamp="0", record="1") + t.add_row("2", "1") # The orders are reversed for default timestamp. 
+ self.assertEqual(len(t), 2) + self.assertEqual(t[0], ("0", "1")) + self.assertEqual(t[1], ("1", "2")) + self.assertEqual(t[0].timestamp, "0") + self.assertEqual(t[0].record, "1") + self.assertEqual(t[0], t[-2]) + self.assertEqual(t[1], t[-1]) + self.assertRaises(IndexError, t.__getitem__, -3) + + +class TestPopulationTable(unittest.TestCase, CommonTestsMixin): + columns = [] + ragged_list_columns = [ + (CharColumn("metadata"), UInt32Column("metadata_offset"))] + equal_len_columns = [[]] + string_colnames = [] + binary_colnames = ["metadata"] + input_parameters = [("max_rows_increment", 1024)] + table_class = tskit.PopulationTable + + def test_simple_example(self): + t = tskit.PopulationTable() + t.add_row(metadata=b"\xf0") + t.add_row(b"1") + s = str(t) + self.assertGreater(len(s), 0) + self.assertEqual(len(t), 2) + self.assertEqual(t[0], (b"\xf0",)) + self.assertEqual(t[0].metadata, b"\xf0") + self.assertEqual(t[1], (b"1",)) + self.assertRaises(IndexError, t.__getitem__, -3) + + +class TestStringPacking(unittest.TestCase): + """ + Tests the code for packing and unpacking unicode string data into numpy arrays. 
+ """ + + def test_simple_string_case(self): + strings = ["hello", "world"] + packed, offset = tskit.pack_strings(strings) + self.assertEqual(list(offset), [0, 5, 10]) + self.assertEqual(packed.shape, (10,)) + returned = tskit.unpack_strings(packed, offset) + self.assertEqual(returned, strings) + + def verify_packing(self, strings): + packed, offset = tskit.pack_strings(strings) + self.assertEqual(packed.dtype, np.int8) + self.assertEqual(offset.dtype, np.uint32) + self.assertEqual(packed.shape[0], offset[-1]) + returned = tskit.unpack_strings(packed, offset) + self.assertEqual(strings, returned) + + def test_regular_cases(self): + for n in range(10): + strings = ["a" * j for j in range(n)] + self.verify_packing(strings) + + def test_random_cases(self): + for n in range(100): + strings = [random_strings(10) for _ in range(n)] + self.verify_packing(strings) + + def test_unicode(self): + self.verify_packing([u'abcdé', u'€']) + + +class TestBytePacking(unittest.TestCase): + """ + Tests the code for packing and unpacking binary data into numpy arrays. 
+ """ + + def test_simple_string_case(self): + strings = [b"hello", b"world"] + packed, offset = tskit.pack_bytes(strings) + self.assertEqual(list(offset), [0, 5, 10]) + self.assertEqual(packed.shape, (10,)) + returned = tskit.unpack_bytes(packed, offset) + self.assertEqual(returned, strings) + + def verify_packing(self, data): + packed, offset = tskit.pack_bytes(data) + self.assertEqual(packed.dtype, np.int8) + self.assertEqual(offset.dtype, np.uint32) + self.assertEqual(packed.shape[0], offset[-1]) + returned = tskit.unpack_bytes(packed, offset) + self.assertEqual(data, returned) + return returned + + def test_random_cases(self): + for n in range(100): + data = [random_bytes(10) for _ in range(n)] + self.verify_packing(data) + + def test_pickle_packing(self): + data = [list(range(j)) for j in range(10)] + # Pickle each of these in turn + pickled = [pickle.dumps(d) for d in data] + unpacked = self.verify_packing(pickled) + unpickled = [pickle.loads(p) for p in unpacked] + self.assertEqual(data, unpickled) + + +class TestSortTables(unittest.TestCase): + """ + Tests for the sort_tables method. + """ + random_seed = 12345 + + def verify_randomise_tables(self, ts): + tables = ts.dump_tables() + + # Randomise the tables. + random.seed(self.random_seed) + randomised_edges = list(ts.edges()) + random.shuffle(randomised_edges) + tables.edges.clear() + for e in randomised_edges: + tables.edges.add_row(e.left, e.right, e.parent, e.child) + # Verify that import fails for randomised edges + self.assertRaises(_tskit.LibraryError, tables.tree_sequence) + tables.sort() + self.assertEqual(tables, ts.dump_tables()) + + tables.sites.clear() + tables.mutations.clear() + randomised_sites = list(ts.sites()) + random.shuffle(randomised_sites) + # Maps original IDs into their indexes in the randomised table. 
+ site_id_map = {} + randomised_mutations = [] + for s in randomised_sites: + site_id_map[s.id] = tables.sites.add_row( + s.position, ancestral_state=s.ancestral_state, metadata=s.metadata) + randomised_mutations.extend(s.mutations) + random.shuffle(randomised_mutations) + for m in randomised_mutations: + tables.mutations.add_row( + site=site_id_map[m.site], node=m.node, derived_state=m.derived_state, + parent=m.parent, metadata=m.metadata) + if ts.num_sites > 1: + # Verify that import fails for randomised sites + self.assertRaises(_tskit.LibraryError, tables.tree_sequence) + tables.sort() + self.assertEqual(tables, ts.dump_tables()) + + ts_new = tables.tree_sequence() + self.assertEqual(ts_new.num_edges, ts.num_edges) + self.assertEqual(ts_new.num_trees, ts.num_trees) + self.assertEqual(ts_new.num_sites, ts.num_sites) + self.assertEqual(ts_new.num_mutations, ts.num_mutations) + + def verify_edge_sort_offset(self, ts): + """ + Verifies the behaviour of the edge_start offset value. + """ + tables = ts.dump_tables() + edges = tables.edges.copy() + starts = [0] + if len(edges) > 2: + starts = [0, 1, len(edges) // 2, len(edges) - 2] + random.seed(self.random_seed) + for start in starts: + # Unsort the edges starting from index start + all_edges = list(ts.edges()) + keep = all_edges[:start] + reversed_edges = all_edges[start:][::-1] + all_edges = keep + reversed_edges + tables.edges.clear() + for e in all_edges: + tables.edges.add_row(e.left, e.right, e.parent, e.child) + # Verify that import fails for randomised edges + self.assertRaises(_tskit.LibraryError, tables.tree_sequence) + # If we sort after the start value we should still fail. + tables.sort(edge_start=start + 1) + self.assertRaises(_tskit.LibraryError, tables.tree_sequence) + # Sorting from the correct index should give us back the original table. 
+ tables.edges.clear() + for e in all_edges: + tables.edges.add_row(e.left, e.right, e.parent, e.child) + tables.sort(edge_start=start) + # Verify the new and old edges are equal. + self.assertEqual(edges, tables.edges) + + def test_single_tree_no_mutations(self): + ts = msprime.simulate(10, random_seed=self.random_seed) + self.verify_randomise_tables(ts) + self.verify_edge_sort_offset(ts) + + def test_single_tree_no_mutations_metadata(self): + ts = msprime.simulate(10, random_seed=self.random_seed) + ts = tsutil.add_random_metadata(ts, self.random_seed) + self.verify_randomise_tables(ts) + + def test_many_trees_no_mutations(self): + ts = msprime.simulate(10, recombination_rate=2, random_seed=self.random_seed) + self.assertGreater(ts.num_trees, 2) + self.verify_randomise_tables(ts) + self.verify_edge_sort_offset(ts) + + def test_single_tree_mutations(self): + ts = msprime.simulate(10, mutation_rate=2, random_seed=self.random_seed) + self.assertGreater(ts.num_sites, 2) + self.verify_randomise_tables(ts) + self.verify_edge_sort_offset(ts) + + def test_single_tree_mutations_metadata(self): + ts = msprime.simulate(10, mutation_rate=2, random_seed=self.random_seed) + self.assertGreater(ts.num_sites, 2) + ts = tsutil.add_random_metadata(ts, self.random_seed) + self.verify_randomise_tables(ts) + + def test_single_tree_multichar_mutations(self): + ts = msprime.simulate(10, random_seed=self.random_seed) + ts = tsutil.insert_multichar_mutations(ts, self.random_seed) + self.verify_randomise_tables(ts) + + def test_single_tree_multichar_mutations_metadata(self): + ts = msprime.simulate(10, random_seed=self.random_seed) + ts = tsutil.insert_multichar_mutations(ts, self.random_seed) + ts = tsutil.add_random_metadata(ts, self.random_seed) + self.verify_randomise_tables(ts) + + def test_many_trees_mutations(self): + ts = msprime.simulate( + 10, recombination_rate=2, mutation_rate=2, random_seed=self.random_seed) + self.assertGreater(ts.num_trees, 2) + 
self.assertGreater(ts.num_sites, 2) + self.verify_randomise_tables(ts) + self.verify_edge_sort_offset(ts) + + def test_many_trees_multichar_mutations(self): + ts = msprime.simulate(10, recombination_rate=2, random_seed=self.random_seed) + self.assertGreater(ts.num_trees, 2) + ts = tsutil.insert_multichar_mutations(ts, self.random_seed) + self.verify_randomise_tables(ts) + + def test_many_trees_multichar_mutations_metadata(self): + ts = msprime.simulate(10, recombination_rate=2, random_seed=self.random_seed) + self.assertGreater(ts.num_trees, 2) + ts = tsutil.insert_multichar_mutations(ts, self.random_seed) + ts = tsutil.add_random_metadata(ts, self.random_seed) + self.verify_randomise_tables(ts) + + def get_nonbinary_example(self, mutation_rate): + ts = msprime.simulate( + sample_size=20, recombination_rate=10, random_seed=self.random_seed, + mutation_rate=mutation_rate, demographic_events=[ + msprime.SimpleBottleneck(time=0.5, population=0, proportion=1)]) + # Make sure this really has some non-binary nodes + found = False + for e in ts.edgesets(): + if len(e.children) > 2: + found = True + break + self.assertTrue(found) + return ts + + def test_nonbinary_trees(self): + ts = self.get_nonbinary_example(mutation_rate=0) + self.assertGreater(ts.num_trees, 2) + self.verify_randomise_tables(ts) + self.verify_edge_sort_offset(ts) + + def test_nonbinary_trees_mutations(self): + ts = self.get_nonbinary_example(mutation_rate=2) + self.assertGreater(ts.num_trees, 2) + self.assertGreater(ts.num_sites, 2) + self.verify_randomise_tables(ts) + self.verify_edge_sort_offset(ts) + + def test_incompatible_edges(self): + ts1 = msprime.simulate(10, random_seed=self.random_seed) + ts2 = msprime.simulate(20, random_seed=self.random_seed) + tables1 = ts1.dump_tables() + tables2 = ts2.dump_tables() + tables2.edges.set_columns(**tables1.edges.asdict()) + # The edges in tables2 will refer to nodes that don't exist. 
+ self.assertRaises(_tskit.LibraryError, tables1.sort()) + + def test_incompatible_sites(self): + ts1 = msprime.simulate(10, random_seed=self.random_seed) + ts2 = msprime.simulate(10, mutation_rate=2, random_seed=self.random_seed) + self.assertGreater(ts2.num_sites, 1) + tables1 = ts1.dump_tables() + tables2 = ts2.dump_tables() + # The mutations in tables2 will refer to sites that don't exist. + tables1.mutations.set_columns(**tables2.mutations.asdict()) + self.assertRaises(_tskit.LibraryError, tables1.sort) + + def test_incompatible_mutation_nodes(self): + ts1 = msprime.simulate(2, random_seed=self.random_seed) + ts2 = msprime.simulate(10, mutation_rate=2, random_seed=self.random_seed) + self.assertGreater(ts2.num_sites, 1) + tables1 = ts1.dump_tables() + tables2 = ts2.dump_tables() + # The mutations in tables2 will refer to nodes that don't exist. + # print(tables2.sites.asdict()) + tables1.sites.set_columns(**tables2.sites.asdict()) + tables1.mutations.set_columns(**tables2.mutations.asdict()) + self.assertRaises(_tskit.LibraryError, tables1.sort) + + def test_empty_tables(self): + tables = tskit.TableCollection(1) + tables.sort() + self.assertEqual(tables.nodes.num_rows, 0) + self.assertEqual(tables.edges.num_rows, 0) + self.assertEqual(tables.sites.num_rows, 0) + self.assertEqual(tables.mutations.num_rows, 0) + self.assertEqual(tables.migrations.num_rows, 0) + + +class TestSortMutations(unittest.TestCase): + """ + Tests that mutations are correctly sorted by sort_tables. 
+ """ + + def test_sort_mutations_stability(self): + nodes = six.StringIO("""\ + id is_sample time + 0 1 0 + 1 1 0 + """) + edges = six.StringIO("""\ + left right parent child + """) + sites = six.StringIO("""\ + position ancestral_state + 0.1 0 + 0.2 0 + """) + mutations = six.StringIO("""\ + site node derived_state parent + 1 0 1 -1 + 1 1 1 -1 + 0 1 1 -1 + 0 0 1 -1 + """) + ts = tskit.load_text( + nodes=nodes, edges=edges, sites=sites, mutations=mutations, + sequence_length=1, strict=False) + # Load text automatically calls sort tables, so we can test the + # output directly. + sites = ts.tables.sites + mutations = ts.tables.mutations + self.assertEqual(len(sites), 2) + self.assertEqual(len(mutations), 4) + self.assertEqual(list(mutations.site), [0, 0, 1, 1]) + self.assertEqual(list(mutations.node), [1, 0, 0, 1]) + + def test_sort_mutations_remap_parent_id(self): + nodes = six.StringIO("""\ + id is_sample time + 0 1 0 + """) + edges = six.StringIO("""\ + left right parent child + """) + sites = six.StringIO("""\ + position ancestral_state + 0.1 0 + 0.2 0 + """) + mutations = six.StringIO("""\ + site node derived_state parent + 1 0 1 -1 + 1 0 0 0 + 1 0 1 1 + 0 0 1 -1 + 0 0 0 3 + 0 0 1 4 + """) + ts = tskit.load_text( + nodes=nodes, edges=edges, sites=sites, mutations=mutations, + sequence_length=1, strict=False) + # Load text automatically calls sort tables, so we can test the + # output directly. 
+ sites = ts.tables.sites + mutations = ts.tables.mutations + self.assertEqual(len(sites), 2) + self.assertEqual(len(mutations), 6) + self.assertEqual(list(mutations.site), [0, 0, 0, 1, 1, 1]) + self.assertEqual(list(mutations.node), [0, 0, 0, 0, 0, 0]) + self.assertEqual(list(mutations.parent), [-1, 0, 1, -1, 3, 4]) + + def test_sort_mutations_bad_parent_id(self): + nodes = six.StringIO("""\ + id is_sample time + 0 1 0 + """) + edges = six.StringIO("""\ + left right parent child + """) + sites = six.StringIO("""\ + position ancestral_state + 0.1 0 + """) + mutations = six.StringIO("""\ + site node derived_state parent + 1 0 1 -2 + """) + self.assertRaises( + _tskit.LibraryError, tskit.load_text, + nodes=nodes, edges=edges, sites=sites, mutations=mutations, + sequence_length=1, strict=False) + + +class TestSimplifyTables(unittest.TestCase): + """ + Tests for the simplify_tables function. + """ + random_seed = 42 + + @unittest.skipIf(IS_PY2, "Warnings different in Py2") + def test_deprecated_zero_mutation_sites(self): + ts = msprime.simulate(10, mutation_rate=1, random_seed=self.random_seed) + tables = ts.dump_tables() + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + tables.simplify(ts.samples(), filter_zero_mutation_sites=True) + assert len(w) == 1 + assert issubclass(w[-1].category, DeprecationWarning) + + def test_zero_mutation_sites(self): + ts = msprime.simulate(10, mutation_rate=1, random_seed=self.random_seed) + for filter_sites in [True, False]: + t1 = ts.dump_tables() + t1.simplify([0, 1], filter_zero_mutation_sites=filter_sites) + t2 = ts.dump_tables() + t2.simplify([0, 1], filter_sites=filter_sites) + t1.provenances.clear() + t2.provenances.clear() + self.assertEqual(t1, t2) + if filter_sites: + self.assertGreater(ts.num_sites, len(t1.sites)) + + def test_full_samples(self): + for n in [2, 10, 100, 1000]: + ts = msprime.simulate( + n, recombination_rate=1, mutation_rate=1, random_seed=self.random_seed) + tables = 
ts.dump_tables() + nodes_before = tables.nodes.copy() + edges_before = tables.edges.copy() + sites_before = tables.sites.copy() + mutations_before = tables.mutations.copy() + for samples in [None, list(ts.samples()), ts.samples()]: + node_map = tables.simplify(samples=samples) + self.assertEqual(node_map.shape, (len(nodes_before),)) + self.assertEqual(nodes_before, tables.nodes) + self.assertEqual(edges_before, tables.edges) + self.assertEqual(sites_before, tables.sites) + self.assertEqual(mutations_before, tables.mutations) + + def test_bad_samples(self): + n = 10 + ts = msprime.simulate(n, random_seed=self.random_seed) + tables = ts.dump_tables() + for bad_node in [-1, n, n + 1, ts.num_nodes - 1, ts.num_nodes, 2**31 - 1]: + self.assertRaises( + _tskit.LibraryError, tables.simplify, samples=[0, bad_node]) + + def test_bad_edge_ordering(self): + ts = msprime.simulate(10, random_seed=self.random_seed) + tables = ts.dump_tables() + edges = tables.edges + # Reversing the edges violates the ordering constraints. 
+ edges.set_columns( + left=edges.left[::-1], right=edges.right[::-1], + parent=edges.parent[::-1], child=edges.child[::-1]) + self.assertRaises(_tskit.LibraryError, tables.simplify, samples=[0, 1]) + + def test_bad_edges(self): + ts = msprime.simulate(10, random_seed=self.random_seed) + for bad_node in [-1, ts.num_nodes, ts.num_nodes + 1, 2**31 - 1]: + # Bad parent node + tables = ts.dump_tables() + edges = tables.edges + parent = edges.parent + parent[0] = bad_node + edges.set_columns( + left=edges.left, right=edges.right, parent=parent, child=edges.child) + self.assertRaises(_tskit.LibraryError, tables.simplify, samples=[0, 1]) + # Bad child node + tables = ts.dump_tables() + edges = tables.edges + child = edges.child + child[0] = bad_node + edges.set_columns( + left=edges.left, right=edges.right, parent=edges.parent, child=child) + self.assertRaises(_tskit.LibraryError, tables.simplify, samples=[0, 1]) + # child == parent + tables = ts.dump_tables() + edges = tables.edges + child = edges.child + child[0] = edges.parent[0] + edges.set_columns( + left=edges.left, right=edges.right, parent=edges.parent, child=child) + self.assertRaises(_tskit.LibraryError, tables.simplify, samples=[0, 1]) + # left == right + tables = ts.dump_tables() + edges = tables.edges + left = edges.left + left[0] = edges.right[0] + edges.set_columns( + left=left, right=edges.right, parent=edges.parent, child=edges.child) + self.assertRaises(_tskit.LibraryError, tables.simplify, samples=[0, 1]) + # left > right + tables = ts.dump_tables() + edges = tables.edges + left = edges.left + left[0] = edges.right[0] + 1 + edges.set_columns( + left=left, right=edges.right, parent=edges.parent, child=edges.child) + self.assertRaises(_tskit.LibraryError, tables.simplify, samples=[0, 1]) + + def test_bad_mutation_nodes(self): + ts = msprime.simulate(10, random_seed=self.random_seed, mutation_rate=1) + self.assertGreater(ts.num_mutations, 0) + for bad_node in [-1, ts.num_nodes, 2**31 - 1]: + tables = 
ts.dump_tables() + mutations = tables.mutations + node = mutations.node + node[0] = bad_node + mutations.set_columns( + site=mutations.site, node=node, derived_state=mutations.derived_state, + derived_state_offset=mutations.derived_state_offset) + self.assertRaises(_tskit.LibraryError, tables.simplify, samples=[0, 1]) + + def test_bad_mutation_sites(self): + ts = msprime.simulate(10, random_seed=self.random_seed, mutation_rate=1) + self.assertGreater(ts.num_mutations, 0) + for bad_site in [-1, ts.num_sites, 2**31 - 1]: + tables = ts.dump_tables() + mutations = tables.mutations + site = mutations.site + site[0] = bad_site + mutations.set_columns( + site=site, node=mutations.node, derived_state=mutations.derived_state, + derived_state_offset=mutations.derived_state_offset) + self.assertRaises(_tskit.LibraryError, tables.simplify, samples=[0, 1]) + + def test_bad_site_positions(self): + ts = msprime.simulate(10, random_seed=self.random_seed, mutation_rate=1) + self.assertGreater(ts.num_mutations, 0) + # Positions > sequence_length are valid, as we can have gaps at the end of + # a tree sequence. 
+ for bad_position in [-1, -1e-6]: + tables = ts.dump_tables() + sites = tables.sites + position = sites.position + position[0] = bad_position + sites.set_columns( + position=position, ancestral_state=sites.ancestral_state, + ancestral_state_offset=sites.ancestral_state_offset) + self.assertRaises(_tskit.LibraryError, tables.simplify, samples=[0, 1]) + + def test_duplicate_positions(self): + tables = tskit.TableCollection(sequence_length=1) + tables.sites.add_row(0, ancestral_state="0") + tables.sites.add_row(0, ancestral_state="0") + self.assertRaises(_tskit.LibraryError, tables.simplify, []) + + def test_samples_interface(self): + ts = msprime.simulate(50, random_seed=1) + for good_form in [[], [0, 1], (0, 1), np.array([0, 1], dtype=np.int32)]: + tables = ts.dump_tables() + tables.simplify(good_form) + tables = ts.dump_tables() + for bad_type in [[[[]]], {}]: + self.assertRaises(ValueError, tables.simplify, bad_type) + # We only accept numpy arrays of the right type + for bad_dtype in [np.uint32, np.int64, np.float64]: + self.assertRaises( + TypeError, tables.simplify, np.array([0, 1], dtype=bad_dtype)) + bad_samples = np.array([[0, 1], [2, 3]], dtype=np.int32) + self.assertRaises(ValueError, tables.simplify, bad_samples) + + +class TestTableCollection(unittest.TestCase): + """ + Tests for the convenience wrapper around a collection of related tables. 
+ """ + def test_table_references(self): + ts = msprime.simulate(10, mutation_rate=2, random_seed=1) + tables = ts.tables + before_individuals = str(tables.individuals) + individuals = tables.individuals + before_nodes = str(tables.nodes) + nodes = tables.nodes + before_edges = str(tables.edges) + edges = tables.edges + before_migrations = str(tables.migrations) + migrations = tables.migrations + before_sites = str(tables.sites) + sites = tables.sites + before_mutations = str(tables.mutations) + mutations = tables.mutations + before_populations = str(tables.populations) + populations = tables.populations + before_nodes = str(tables.nodes) + provenances = tables.provenances + before_provenances = str(tables.provenances) + del tables + self.assertEqual(str(individuals), before_individuals) + self.assertEqual(str(nodes), before_nodes) + self.assertEqual(str(edges), before_edges) + self.assertEqual(str(migrations), before_migrations) + self.assertEqual(str(sites), before_sites) + self.assertEqual(str(mutations), before_mutations) + self.assertEqual(str(populations), before_populations) + self.assertEqual(str(provenances), before_provenances) + + def test_str(self): + ts = msprime.simulate(10, random_seed=1) + tables = ts.tables + s = str(tables) + self.assertGreater(len(s), 0) + + def test_asdict(self): + ts = msprime.simulate(10, mutation_rate=1, random_seed=1) + t = ts.tables + d1 = { + "sequence_length": t.sequence_length, + "individuals": t.individuals.asdict(), + "populations": t.populations.asdict(), + "nodes": t.nodes.asdict(), + "edges": t.edges.asdict(), + "sites": t.sites.asdict(), + "mutations": t.mutations.asdict(), + "migrations": t.migrations.asdict(), + "provenances": t.provenances.asdict()} + d2 = t.asdict() + self.assertEqual(set(d1.keys()), set(d2.keys())) + # TODO test the fromdict constructor + + def test_equals_empty(self): + self.assertEqual(tskit.TableCollection(), tskit.TableCollection()) + + def test_equals_sequence_length(self): + 
self.assertNotEqual( + tskit.TableCollection(sequence_length=1), + tskit.TableCollection(sequence_length=2)) + + def test_equals(self): + pop_configs = [msprime.PopulationConfiguration(5) for _ in range(2)] + migration_matrix = [[0, 1], [1, 0]] + t1 = msprime.simulate( + population_configurations=pop_configs, + migration_matrix=migration_matrix, + mutation_rate=1, + record_migrations=True, + random_seed=1).dump_tables() + t2 = msprime.simulate( + population_configurations=pop_configs, + migration_matrix=migration_matrix, + mutation_rate=1, + record_migrations=True, + random_seed=1).dump_tables() + self.assertEqual(t1, t1) + # The provenances may or may not be equal depending on the clock + # precision for record. So clear them first. + t1.provenances.clear() + t2.provenances.clear() + self.assertEqual(t1, t2) + self.assertTrue(t1 == t2) + self.assertFalse(t1 != t2) + + t1.nodes.clear() + self.assertNotEqual(t1, t2) + t2.nodes.clear() + self.assertEqual(t1, t2) + + t1.edges.clear() + self.assertNotEqual(t1, t2) + t2.edges.clear() + self.assertEqual(t1, t2) + + t1.migrations.clear() + self.assertNotEqual(t1, t2) + t2.migrations.clear() + self.assertEqual(t1, t2) + + t1.sites.clear() + self.assertNotEqual(t1, t2) + t2.sites.clear() + self.assertEqual(t1, t2) + + t1.mutations.clear() + self.assertNotEqual(t1, t2) + t2.mutations.clear() + self.assertEqual(t1, t2) + + t1.populations.clear() + self.assertNotEqual(t1, t2) + t2.populations.clear() + self.assertEqual(t1, t2) + + def test_sequence_length(self): + for sequence_length in [0, 1, 100.1234]: + tables = tskit.TableCollection(sequence_length=sequence_length) + self.assertEqual(tables.sequence_length, sequence_length) + + def test_uuid_simulation(self): + ts = msprime.simulate(10, random_seed=1) + tables = ts.tables + self.assertIsNone(tables.file_uuid, None) + + def test_uuid_empty(self): + tables = tskit.TableCollection(sequence_length=1) + self.assertIsNone(tables.file_uuid, None) + + +class 
TestTableCollectionPickle(unittest.TestCase): + """ + Tests that we can round-trip table collections through pickle. + """ + def verify(self, tables): + other_tables = pickle.loads(pickle.dumps(tables)) + self.assertEqual(tables, other_tables) + + def test_simple_simulation(self): + ts = msprime.simulate(2, random_seed=1) + self.verify(ts.dump_tables()) + + def test_simulation_populations(self): + ts = msprime.simulate( + population_configurations=[ + msprime.PopulationConfiguration(10), + msprime.PopulationConfiguration(10)], + migration_matrix=[[0, 1], [1, 0]], + record_migrations=True, + random_seed=1) + self.verify(ts.dump_tables()) + + def test_simulation_sites(self): + ts = msprime.simulate(12, random_seed=1, mutation_rate=5) + self.assertGreater(ts.num_sites, 1) + self.verify(ts.dump_tables()) + + def test_simulation_individuals(self): + ts = msprime.simulate(100, random_seed=1) + ts = tsutil.insert_random_ploidy_individuals(ts, seed=1) + self.assertGreater(ts.num_individuals, 1) + self.verify(ts.dump_tables()) + + def test_empty_tables(self): + self.verify(tskit.TableCollection()) + + +class TestDeduplicateSites(unittest.TestCase): + """ + Tests for the TableCollection.deduplicate_sites method. 
+ """ + def test_empty(self): + tables = tskit.TableCollection(1) + tables.deduplicate_sites() + self.assertEqual(tables, tskit.TableCollection(1)) + + def test_unsorted(self): + tables = msprime.simulate(10, mutation_rate=1, random_seed=1).dump_tables() + self.assertGreater(len(tables.sites), 0) + position = tables.sites.position + for j in range(len(position) - 1): + position = np.roll(position, 1) + tables.sites.set_columns( + position=position, ancestral_state=tables.sites.ancestral_state, + ancestral_state_offset=tables.sites.ancestral_state_offset) + self.assertRaises(_tskit.LibraryError, tables.deduplicate_sites) + + def test_bad_position(self): + for bad_position in [-1, -0.001]: + tables = tskit.TableCollection() + tables.sites.add_row(bad_position, "0") + self.assertRaises(_tskit.LibraryError, tables.deduplicate_sites) + + def test_no_effect(self): + t1 = msprime.simulate(10, mutation_rate=1, random_seed=1).dump_tables() + t2 = msprime.simulate(10, mutation_rate=1, random_seed=1).dump_tables() + self.assertGreater(len(t1.sites), 0) + t1.deduplicate_sites() + t1.provenances.clear() + t2.provenances.clear() + self.assertEqual(t1, t2) + + def test_same_sites(self): + t1 = msprime.simulate(10, mutation_rate=1, random_seed=1).dump_tables() + t2 = msprime.simulate(10, mutation_rate=1, random_seed=1).dump_tables() + self.assertGreater(len(t1.sites), 0) + t1.sites.append_columns( + position=t1.sites.position, + ancestral_state=t1.sites.ancestral_state, + ancestral_state_offset=t1.sites.ancestral_state_offset) + self.assertEqual(len(t1.sites), 2 * len(t2.sites)) + t1.sort() + t1.deduplicate_sites() + t1.provenances.clear() + t2.provenances.clear() + self.assertEqual(t1, t2) + + def test_order_maintained(self): + t1 = tskit.TableCollection(1) + t1.sites.add_row(position=0, ancestral_state="first") + t1.sites.add_row(position=0, ancestral_state="second") + t1.deduplicate_sites() + self.assertEqual(len(t1.sites), 1) + 
self.assertEqual(t1.sites.ancestral_state.tobytes(), b"first") + + def test_multichar_ancestral_state(self): + ts = msprime.simulate(8, random_seed=3, mutation_rate=1) + self.assertGreater(ts.num_sites, 2) + tables = ts.dump_tables() + tables.sites.clear() + tables.mutations.clear() + for site in ts.sites(): + site_id = tables.sites.add_row( + position=site.position, ancestral_state="A" * site.id) + tables.sites.add_row(position=site.position, ancestral_state="0") + for mutation in site.mutations: + tables.mutations.add_row( + site=site_id, node=mutation.node, derived_state="T" * site.id) + tables.deduplicate_sites() + new_ts = tables.tree_sequence() + self.assertEqual(new_ts.num_sites, ts.num_sites) + for site in new_ts.sites(): + self.assertEqual(site.ancestral_state, site.id * "A") + + def test_multichar_metadata(self): + ts = msprime.simulate(8, random_seed=3, mutation_rate=1) + self.assertGreater(ts.num_sites, 2) + tables = ts.dump_tables() + tables.sites.clear() + tables.mutations.clear() + for site in ts.sites(): + site_id = tables.sites.add_row( + position=site.position, ancestral_state="0", metadata=b"A" * site.id) + tables.sites.add_row(position=site.position, ancestral_state="0") + for mutation in site.mutations: + tables.mutations.add_row( + site=site_id, node=mutation.node, derived_state="1", + metadata=b"T" * site.id) + tables.deduplicate_sites() + new_ts = tables.tree_sequence() + self.assertEqual(new_ts.num_sites, ts.num_sites) + for site in new_ts.sites(): + self.assertEqual(site.metadata, site.id * b"A") + + +class TestBaseTable(unittest.TestCase): + """ + Tests of the table superclass. 
+ """ + def test_asdict_not_implemented(self): + t = tskit.BaseTable(None, None) + with self.assertRaises(NotImplementedError): + t.asdict() diff --git a/python/tests/test_threads.py b/python/tests/test_threads.py new file mode 100644 index 0000000000..7b36065480 --- /dev/null +++ b/python/tests/test_threads.py @@ -0,0 +1,257 @@ +""" +Test cases for threading enabled aspects of the API. +""" +from __future__ import print_function +from __future__ import division + +import sys +import threading +import unittest +import platform + +import numpy as np +import msprime + +import tskit +import tests.tsutil as tsutil + +IS_PY2 = sys.version_info[0] < 3 +IS_WINDOWS = platform.system() == "Windows" + + +def run_threads(worker, num_threads): + results = [None for _ in range(num_threads)] + threads = [ + threading.Thread(target=worker, args=(j, results)) + for j in range(num_threads)] + for t in threads: + t.start() + for t in threads: + t.join() + return results + + +class TestLdCalculatorReplicates(unittest.TestCase): + """ + Tests the LdCalculator object to ensure we get correct results + when using threads. + """ + num_test_sites = 25 + + def get_tree_sequence(self): + ts = msprime.simulate(20, mutation_rate=10, recombination_rate=10, random_seed=8) + return tsutil.subsample_sites(ts, self.num_test_sites) + + def test_get_r2_multiple_instances(self): + # This is the nominal case where we have a separate LdCalculator + # instance in each thread. 
+ ts = self.get_tree_sequence() + ld_calc = tskit.LdCalculator(ts) + A = ld_calc.get_r2_matrix() + del ld_calc + m = A.shape[0] + + def worker(thread_index, results): + ld_calc = tskit.LdCalculator(ts) + row = np.zeros(m) + results[thread_index] = row + for j in range(m): + row[j] = ld_calc.get_r2(thread_index, j) + + results = run_threads(worker, m) + for j in range(m): + self.assertTrue(np.allclose(results[j], A[j])) + + def test_get_r2_single_instance(self): + # This is the degenerate case where we have a single LdCalculator + # instance shared by the threads. We should have only one thread + # actually executing get_r2() at one time. + ts = self.get_tree_sequence() + ld_calc = tskit.LdCalculator(ts) + A = ld_calc.get_r2_matrix() + m = A.shape[0] + + def worker(thread_index, results): + row = np.zeros(m) + results[thread_index] = row + for j in range(m): + row[j] = ld_calc.get_r2(thread_index, j) + + results = run_threads(worker, m) + for j in range(m): + self.assertTrue(np.allclose(results[j], A[j])) + + def test_get_r2_array_multiple_instances(self): + # This is the nominal case where we have a separate LdCalculator + # instance in each thread. + ts = self.get_tree_sequence() + ld_calc = tskit.LdCalculator(ts) + A = ld_calc.get_r2_matrix() + m = A.shape[0] + del ld_calc + + def worker(thread_index, results): + ld_calc = tskit.LdCalculator(ts) + results[thread_index] = np.array( + ld_calc.get_r2_array(thread_index)) + + results = run_threads(worker, m) + for j in range(m): + self.assertTrue(np.allclose(results[j], A[j, j + 1:])) + + def test_get_r2_array_single_instance(self): + # This is the degenerate case where we have a single LdCalculator + # instance shared by the threads. We should have only one thread + # actually executing get_r2_array() at one time. Because the buffer + # is shared by many different instances, we can't make any assertions + # about the returned values --- they are essentially gibberish. 
+ # However, we shouldn't crash and burn, which is what this test + # is here to check for. + ts = self.get_tree_sequence() + ld_calc = tskit.LdCalculator(ts) + m = ts.get_num_mutations() + + def worker(thread_index, results): + results[thread_index] = ld_calc.get_r2_array(thread_index).shape + + results = run_threads(worker, m) + for j in range(m): + self.assertEqual(results[j][0], m - j - 1) + + +# @unittest.skipIf(IS_PY2, "Cannot test thread support on Py2.") +# Temporarily skipping these on windows too. See +# https://github.com/jeromekelleher/tskit/issues/344 +@unittest.skipIf(IS_PY2 or IS_WINDOWS, "Cannot test thread support on Py2.") +class TestTables(unittest.TestCase): + """ + Tests to ensure that attempts to access tables in threads correctly + raise an exception. + """ + def get_tables(self): + # TODO include migrations here. + ts = msprime.simulate( + 100, mutation_rate=10, recombination_rate=10, random_seed=8) + return ts.tables + + def run_multiple_writers(self, writer, num_writers=32): + barrier = threading.Barrier(num_writers) + + def writer_proxy(thread_index, results): + barrier.wait() + # Attempts to operate on a table while locked should raise a RuntimeError + try: + writer(thread_index, results) + results[thread_index] = 0 + except RuntimeError: + results[thread_index] = 1 + + results = run_threads(writer_proxy, num_writers) + failures = sum(results) + successes = num_writers - failures + # Note: we would like to insist that #failures is > 0, but this is too + # stochastic to guarantee for test purposes. + self.assertGreaterEqual(failures, 0) + self.assertGreater(successes, 0) + + def run_failing_reader(self, writer, reader, num_readers=32): + """ + Runs a test in which a single writer acceses some tables + and a bunch of other threads try to read the data. 
+ """ + barrier = threading.Barrier(num_readers + 1) + + def writer_proxy(): + barrier.wait() + writer() + + def reader_proxy(thread_index, results): + barrier.wait() + # Attempts to operate on a table while locked should raise a RuntimeError + try: + reader(thread_index, results) + results[thread_index] = 0 + except RuntimeError: + results[thread_index] = 1 + + writer_thread = threading.Thread(target=writer_proxy) + writer_thread.start() + results = run_threads(reader_proxy, num_readers) + writer_thread.join() + + failures = sum(results) + successes = num_readers - failures + # Note: we would like to insist that #failures is > 0, but this is too + # stochastic to guarantee for test purposes. + self.assertGreaterEqual(failures, 0) + self.assertGreater(successes, 0) + + def test_many_simplify_all_tables(self): + tables = self.get_tables() + + def writer(thread_index, results): + tables.simplify([0, 1]) + + self.run_multiple_writers(writer) + + def test_many_sort(self): + tables = self.get_tables() + + def writer(thread_index, results): + tables.sort() + + self.run_multiple_writers(writer) + + def run_simplify_access_table(self, table_name, col_name): + tables = self.get_tables() + + def writer(): + tables.simplify([0, 1]) + + table = getattr(tables, table_name) + + def reader(thread_index, results): + for j in range(100): + x = getattr(table, col_name) + assert x.shape[0] == len(table) + + self.run_failing_reader(writer, reader) + + def run_sort_access_table(self, table_name, col_name): + tables = self.get_tables() + + def writer(): + tables.sort() + + table = getattr(tables, table_name) + + def reader(thread_index, results): + for j in range(100): + x = getattr(table, col_name) + assert x.shape[0] == len(table) + + self.run_failing_reader(writer, reader) + + def test_simplify_access_nodes(self): + self.run_simplify_access_table("nodes", "time") + + def test_simplify_access_edges(self): + self.run_simplify_access_table("edges", "left") + + def 
test_simplify_access_sites(self): + self.run_simplify_access_table("sites", "position") + + def test_simplify_access_mutations(self): + self.run_simplify_access_table("mutations", "site") + + def test_sort_access_nodes(self): + self.run_sort_access_table("nodes", "time") + + def test_sort_access_edges(self): + self.run_sort_access_table("edges", "left") + + def test_sort_access_sites(self): + self.run_sort_access_table("sites", "position") + + def test_sort_access_mutations(self): + self.run_sort_access_table("mutations", "site") diff --git a/python/tests/test_topology.py b/python/tests/test_topology.py new file mode 100644 index 0000000000..d47fbad99c --- /dev/null +++ b/python/tests/test_topology.py @@ -0,0 +1,3814 @@ +""" +Test cases for the supported topological variations and operations. +""" +from __future__ import print_function +from __future__ import division + +try: + # We use the zip as iterator functionality here. + from future_builtins import zip +except ImportError: + # This fails for Python 3.x, but that's fine. + pass + +import unittest +import itertools +import random + +import six +import numpy as np +import msprime + +import tskit +import _tskit +import tests as tests +import tests.tsutil as tsutil +import tests.test_wright_fisher as wf + + +def generate_segments(n, sequence_length=100, seed=None): + rng = random.Random(seed) + segs = [] + for j in range(n): + left = rng.randint(0, sequence_length - 1) + right = rng.randint(left + 1, sequence_length) + assert left < right + segs.append(tests.Segment(left, right, j)) + return segs + + +class TestOverlappingSegments(unittest.TestCase): + """ + Tests for the overlapping segments algorithm required for simplify. + This test probably belongs somewhere else. 
+ """ + + def test_random(self): + segs = generate_segments(10, 20, 1) + for left, right, X in tests.overlapping_segments(segs): + self.assertGreater(right, left) + self.assertGreater(len(X), 0) + + def test_empty(self): + ret = list(tests.overlapping_segments([])) + self.assertEqual(len(ret), 0) + + def test_single_interval(self): + for j in range(1, 10): + segs = [tests.Segment(0, 1, j) for _ in range(j)] + ret = list(tests.overlapping_segments(segs)) + self.assertEqual(len(ret), 1) + left, right, X = ret[0] + self.assertEqual(left, 0) + self.assertEqual(right, 1) + self.assertEqual(sorted(segs), sorted(X)) + + def test_stairs_down(self): + segs = [ + tests.Segment(0, 1, 0), + tests.Segment(0, 2, 1), + tests.Segment(0, 3, 2)] + ret = list(tests.overlapping_segments(segs)) + self.assertEqual(len(ret), 3) + + left, right, X = ret[0] + self.assertEqual(left, 0) + self.assertEqual(right, 1) + self.assertEqual(sorted(X), sorted(segs)) + + left, right, X = ret[1] + self.assertEqual(left, 1) + self.assertEqual(right, 2) + self.assertEqual(sorted(X), sorted(segs[1:])) + + left, right, X = ret[2] + self.assertEqual(left, 2) + self.assertEqual(right, 3) + self.assertEqual(sorted(X), sorted(segs[2:])) + + def test_stairs_up(self): + segs = [ + tests.Segment(0, 3, 0), + tests.Segment(1, 3, 1), + tests.Segment(2, 3, 2)] + ret = list(tests.overlapping_segments(segs)) + self.assertEqual(len(ret), 3) + + left, right, X = ret[0] + self.assertEqual(left, 0) + self.assertEqual(right, 1) + self.assertEqual(X, segs[:1]) + + left, right, X = ret[1] + self.assertEqual(left, 1) + self.assertEqual(right, 2) + self.assertEqual(sorted(X), sorted(segs[:2])) + + left, right, X = ret[2] + self.assertEqual(left, 2) + self.assertEqual(right, 3) + self.assertEqual(sorted(X), sorted(segs)) + + def test_pyramid(self): + segs = [ + tests.Segment(0, 5, 0), + tests.Segment(1, 4, 1), + tests.Segment(2, 3, 2)] + ret = list(tests.overlapping_segments(segs)) + self.assertEqual(len(ret), 5) + + left, 
right, X = ret[0] + self.assertEqual(left, 0) + self.assertEqual(right, 1) + self.assertEqual(X, segs[:1]) + + left, right, X = ret[1] + self.assertEqual(left, 1) + self.assertEqual(right, 2) + self.assertEqual(sorted(X), sorted(segs[:2])) + + left, right, X = ret[2] + self.assertEqual(left, 2) + self.assertEqual(right, 3) + self.assertEqual(sorted(X), sorted(segs)) + + left, right, X = ret[3] + self.assertEqual(left, 3) + self.assertEqual(right, 4) + self.assertEqual(sorted(X), sorted(segs[:2])) + + left, right, X = ret[4] + self.assertEqual(left, 4) + self.assertEqual(right, 5) + self.assertEqual(sorted(X), sorted(segs[:1])) + + def test_gap(self): + segs = [ + tests.Segment(0, 2, 0), + tests.Segment(3, 4, 1)] + ret = list(tests.overlapping_segments(segs)) + self.assertEqual(len(ret), 2) + + left, right, X = ret[0] + self.assertEqual(left, 0) + self.assertEqual(right, 2) + self.assertEqual(X, segs[:1]) + + left, right, X = ret[1] + self.assertEqual(left, 3) + self.assertEqual(right, 4) + self.assertEqual(X, segs[1:]) + + +class TopologyTestCase(unittest.TestCase): + """ + Superclass of test cases containing common utilities. + """ + random_seed = 123456 + + def assert_haplotypes_equal(self, ts1, ts2): + h1 = list(ts1.haplotypes()) + h2 = list(ts2.haplotypes()) + self.assertEqual(h1, h2) + + def assert_variants_equal(self, ts1, ts2): + v1 = list(ts1.variants(as_bytes=True)) + v2 = list(ts2.variants(as_bytes=True)) + self.assertEqual(v1, v2) + + def check_num_samples(self, ts, x): + """ + Compare against x, a list of tuples of the form + `(tree number, parent, number of samples)`. 
+ """ + k = 0 + tss = ts.trees(sample_counts=True) + t = next(tss) + for j, node, nl in x: + while k < j: + t = next(tss) + k += 1 + self.assertEqual(nl, t.num_samples(node)) + + def check_num_tracked_samples(self, ts, tracked_samples, x): + k = 0 + tss = ts.trees(sample_counts=True, tracked_samples=tracked_samples) + t = next(tss) + for j, node, nl in x: + while k < j: + t = next(tss) + k += 1 + self.assertEqual(nl, t.num_tracked_samples(node)) + + def check_sample_iterator(self, ts, x): + """ + Compare against x, a list of tuples of the form + `(tree number, node, sample ID list)`. + """ + k = 0 + tss = ts.trees(sample_lists=True) + t = next(tss) + for j, node, samples in x: + while k < j: + t = next(tss) + k += 1 + for u, v in zip(samples, t.samples(node)): + self.assertEqual(u, v) + + +class TestZeroRoots(unittest.TestCase): + """ + Tests that for the case in which we have zero samples and therefore + zero roots in our trees. + """ + def remove_samples(self, ts): + tables = ts.dump_tables() + tables.nodes.set_columns( + flags=np.zeros_like(tables.nodes.flags), + time=tables.nodes.time) + return tables.tree_sequence() + + def verify(self, ts, no_root_ts): + self.assertEqual(ts.num_trees, no_root_ts.num_trees) + for tree, no_root in zip(ts.trees(), no_root_ts.trees()): + self.assertEqual(no_root.num_roots, 0) + self.assertEqual(no_root.left_root, tskit.NULL) + self.assertEqual(no_root.roots, []) + self.assertEqual(tree.parent_dict, no_root.parent_dict) + + def test_single_tree(self): + ts = msprime.simulate(10, random_seed=1) + no_root_ts = self.remove_samples(ts) + self.assertEqual(ts.num_trees, 1) + self.verify(ts, no_root_ts) + + def test_multiple_trees(self): + ts = msprime.simulate(10, recombination_rate=2, random_seed=1) + no_root_ts = self.remove_samples(ts) + self.assertGreater(ts.num_trees, 1) + self.verify(ts, no_root_ts) + + +class TestEmptyTreeSequences(TopologyTestCase): + """ + Tests covering tree sequences that have zero edges. 
+ """ + def test_zero_nodes(self): + tables = tskit.TableCollection(1) + ts = tables.tree_sequence() + self.assertEqual(ts.sequence_length, 1) + self.assertEqual(ts.num_trees, 1) + self.assertEqual(ts.num_nodes, 0) + self.assertEqual(ts.num_edges, 0) + t = next(ts.trees()) + self.assertEqual(t.index, 0) + self.assertEqual(t.left_root, tskit.NULL) + self.assertEqual(t.interval, (0, 1)) + self.assertEqual(t.roots, []) + self.assertEqual(t.root, tskit.NULL) + self.assertEqual(t.parent_dict, {}) + self.assertEqual(list(t.nodes()), []) + self.assertEqual(list(ts.haplotypes()), []) + self.assertEqual(list(ts.variants()), []) + methods = [t.parent, t.left_child, t.right_child, t.left_sib, t.right_sib] + for method in methods: + for u in [-1, 0, 1, 100]: + self.assertRaises(ValueError, method, u) + tsp = ts.simplify() + self.assertEqual(tsp.num_nodes, 0) + self.assertEqual(tsp.num_edges, 0) + + def test_one_node_zero_samples(self): + tables = tskit.TableCollection(sequence_length=1) + tables.nodes.add_row(time=0, flags=0) + # Without a sequence length this should fail. 
+ ts = tables.tree_sequence() + self.assertEqual(ts.sequence_length, 1) + self.assertEqual(ts.num_trees, 1) + self.assertEqual(ts.num_nodes, 1) + self.assertEqual(ts.sample_size, 0) + self.assertEqual(ts.num_edges, 0) + self.assertEqual(ts.num_sites, 0) + self.assertEqual(ts.num_mutations, 0) + t = next(ts.trees()) + self.assertEqual(t.index, 0) + self.assertEqual(t.left_root, tskit.NULL) + self.assertEqual(t.interval, (0, 1)) + self.assertEqual(t.roots, []) + self.assertEqual(t.root, tskit.NULL) + self.assertEqual(t.parent_dict, {}) + self.assertEqual(list(t.nodes()), []) + self.assertEqual(list(ts.haplotypes()), []) + self.assertEqual(list(ts.variants()), []) + methods = [t.parent, t.left_child, t.right_child, t.left_sib, t.right_sib] + for method in methods: + self.assertEqual(method(0), tskit.NULL) + for u in [-1, 1, 100]: + self.assertRaises(ValueError, method, u) + + def test_one_node_zero_samples_sites(self): + tables = tskit.TableCollection(sequence_length=1) + tables.nodes.add_row(time=0, flags=0) + tables.sites.add_row(position=0.5, ancestral_state='0') + tables.mutations.add_row(site=0, derived_state='1', node=0) + ts = tables.tree_sequence() + self.assertEqual(ts.sequence_length, 1) + self.assertEqual(ts.num_trees, 1) + self.assertEqual(ts.num_nodes, 1) + self.assertEqual(ts.sample_size, 0) + self.assertEqual(ts.num_edges, 0) + self.assertEqual(ts.num_sites, 1) + self.assertEqual(ts.num_mutations, 1) + t = next(ts.trees()) + self.assertEqual(t.index, 0) + self.assertEqual(t.left_root, tskit.NULL) + self.assertEqual(t.interval, (0, 1)) + self.assertEqual(t.roots, []) + self.assertEqual(t.root, tskit.NULL) + self.assertEqual(t.parent_dict, {}) + self.assertEqual(len(list(t.sites())), 1) + self.assertEqual(list(t.nodes()), []) + self.assertEqual(list(ts.haplotypes()), []) + self.assertEqual(len(list(ts.variants())), 1) + tsp = ts.simplify() + self.assertEqual(tsp.num_nodes, 0) + self.assertEqual(tsp.num_edges, 0) + + def test_one_node_one_sample(self): + 
tables = tskit.TableCollection(sequence_length=1) + tables.nodes.add_row(time=0, flags=tskit.NODE_IS_SAMPLE) + ts = tables.tree_sequence() + self.assertEqual(ts.sequence_length, 1) + self.assertEqual(ts.num_trees, 1) + self.assertEqual(ts.num_nodes, 1) + self.assertEqual(ts.sample_size, 1) + self.assertEqual(ts.num_edges, 0) + t = next(ts.trees()) + self.assertEqual(t.index, 0) + self.assertEqual(t.left_root, 0) + self.assertEqual(t.interval, (0, 1)) + self.assertEqual(t.roots, [0]) + self.assertEqual(t.root, 0) + self.assertEqual(t.parent_dict, {}) + self.assertEqual(list(t.nodes()), [0]) + self.assertEqual(list(ts.haplotypes()), [""]) + self.assertEqual(list(ts.variants()), []) + methods = [t.parent, t.left_child, t.right_child, t.left_sib, t.right_sib] + for method in methods: + self.assertEqual(method(0), tskit.NULL) + for u in [-1, 1, 100]: + self.assertRaises(ValueError, method, u) + tsp = ts.simplify() + self.assertEqual(tsp.num_nodes, 1) + self.assertEqual(tsp.num_edges, 0) + + def test_one_node_one_sample_sites(self): + tables = tskit.TableCollection(sequence_length=1) + tables.nodes.add_row(time=0, flags=tskit.NODE_IS_SAMPLE) + tables.sites.add_row(position=0.5, ancestral_state='0') + tables.mutations.add_row(site=0, derived_state='1', node=0) + ts = tables.tree_sequence() + self.assertEqual(ts.sequence_length, 1) + self.assertEqual(ts.num_trees, 1) + self.assertEqual(ts.num_nodes, 1) + self.assertEqual(ts.sample_size, 1) + self.assertEqual(ts.num_edges, 0) + self.assertEqual(ts.num_sites, 1) + self.assertEqual(ts.num_mutations, 1) + t = next(ts.trees()) + self.assertEqual(t.index, 0) + self.assertEqual(t.left_root, 0) + self.assertEqual(t.interval, (0, 1)) + self.assertEqual(t.roots, [0]) + self.assertEqual(t.root, 0) + self.assertEqual(t.parent_dict, {}) + self.assertEqual(list(t.nodes()), [0]) + self.assertEqual(list(ts.haplotypes()), ["1"]) + self.assertEqual(len(list(ts.variants())), 1) + methods = [t.parent, t.left_child, t.right_child, t.left_sib, 
t.right_sib] + for method in methods: + self.assertEqual(method(0), tskit.NULL) + for u in [-1, 1, 100]: + self.assertRaises(ValueError, method, u) + tsp = ts.simplify(filter_sites=False) + self.assertEqual(tsp.num_nodes, 1) + self.assertEqual(tsp.num_edges, 0) + self.assertEqual(tsp.num_sites, 1) + + +class TestHoleyTreeSequences(TopologyTestCase): + """ + Tests for tree sequences in which we have partial (or no) trees defined + over some of the sequence. + """ + def verify_trees(self, ts, expected): + observed = [] + for t in ts.trees(): + observed.append((t.interval, t.parent_dict)) + self.assertEqual(expected, observed) + # Test simple algorithm also. + observed = [] + for interval, parent in tsutil.algorithm_T(ts): + parent_dict = {j: parent[j] for j in range(ts.num_nodes) if parent[j] >= 0} + observed.append((interval, parent_dict)) + self.assertEqual(expected, observed) + + def verify_zero_roots(self, ts): + for tree in ts.trees(): + self.assertEqual(tree.num_roots, 0) + self.assertEqual(tree.left_root, tskit.NULL) + self.assertEqual(tree.roots, []) + + def test_simple_hole(self): + nodes = six.StringIO("""\ + id is_sample time + 0 1 0 + 1 1 0 + 2 0 1 + """) + edges = six.StringIO("""\ + left right parent child + 0 1 2 0 + 2 3 2 0 + 0 1 2 1 + 2 3 2 1 + """) + ts = tskit.load_text(nodes, edges, strict=False) + expected = [ + ((0, 1), {0: 2, 1: 2}), + ((1, 2), {}), + ((2, 3), {0: 2, 1: 2})] + self.verify_trees(ts, expected) + + def test_simple_hole_zero_roots(self): + nodes = six.StringIO("""\ + id is_sample time + 0 0 0 + 1 0 0 + 2 0 1 + """) + edges = six.StringIO("""\ + left right parent child + 0 1 2 0 + 2 3 2 0 + 0 1 2 1 + 2 3 2 1 + """) + ts = tskit.load_text(nodes, edges, strict=False) + expected = [ + ((0, 1), {0: 2, 1: 2}), + ((1, 2), {}), + ((2, 3), {0: 2, 1: 2})] + self.verify_trees(ts, expected) + self.verify_zero_roots(ts) + + def test_initial_gap(self): + nodes = six.StringIO("""\ + id is_sample time + 0 1 0 + 1 1 0 + 2 0 1 + """) + edges = 
six.StringIO("""\ + left right parent child + 1 2 2 0,1 + """) + ts = tskit.load_text(nodes, edges, strict=False) + expected = [ + ((0, 1), {}), + ((1, 2), {0: 2, 1: 2})] + self.verify_trees(ts, expected) + + def test_initial_gap_zero_roots(self): + nodes = six.StringIO("""\ + id is_sample time + 0 0 0 + 1 0 0 + 2 0 1 + """) + edges = six.StringIO("""\ + left right parent child + 1 2 2 0,1 + """) + ts = tskit.load_text(nodes, edges, strict=False) + expected = [ + ((0, 1), {}), + ((1, 2), {0: 2, 1: 2})] + self.verify_trees(ts, expected) + self.verify_zero_roots(ts) + + def test_final_gap(self): + nodes = six.StringIO("""\ + id is_sample time + 0 1 0 + 1 1 0 + 2 0 1 + """) + edges = six.StringIO("""\ + left right parent child + 0 2 2 0,1 + """) + ts = tskit.load_text(nodes, edges, sequence_length=3, strict=False) + expected = [ + ((0, 2), {0: 2, 1: 2}), + ((2, 3), {})] + self.verify_trees(ts, expected) + + def test_final_gap_zero_roots(self): + nodes = six.StringIO("""\ + id is_sample time + 0 0 0 + 1 0 0 + 2 0 1 + """) + edges = six.StringIO("""\ + left right parent child + 0 2 2 0,1 + """) + ts = tskit.load_text(nodes, edges, sequence_length=3, strict=False) + expected = [ + ((0, 2), {0: 2, 1: 2}), + ((2, 3), {})] + self.verify_trees(ts, expected) + self.verify_zero_roots(ts) + + def test_initial_and_final_gap(self): + nodes = six.StringIO("""\ + id is_sample time + 0 1 0 + 1 1 0 + 2 0 1 + """) + edges = six.StringIO("""\ + left right parent child + 1 2 2 0,1 + """) + ts = tskit.load_text(nodes, edges, sequence_length=3, strict=False) + expected = [ + ((0, 1), {}), + ((1, 2), {0: 2, 1: 2}), + ((2, 3), {})] + self.verify_trees(ts, expected) + + def test_initial_and_final_gap_zero_roots(self): + nodes = six.StringIO("""\ + id is_sample time + 0 0 0 + 1 0 0 + 2 0 1 + """) + edges = six.StringIO("""\ + left right parent child + 1 2 2 0,1 + """) + ts = tskit.load_text(nodes, edges, sequence_length=3, strict=False) + expected = [ + ((0, 1), {}), + ((1, 2), {0: 2, 1: 2}), 
+ ((2, 3), {})] + self.verify_trees(ts, expected) + self.verify_zero_roots(ts) + + +class TestTsinferExamples(TopologyTestCase): + """ + Test cases on troublesome topology examples that arose from tsinfer. + """ + def test_no_last_tree(self): + # The last tree was not being generated here because of a bug in + # the low-level tree generation code. + nodes = six.StringIO("""\ + id is_sample population time + 0 1 -1 3.00000000000000 + 1 1 -1 2.00000000000000 + 2 1 -1 2.00000000000000 + 3 1 -1 2.00000000000000 + 4 1 -1 2.00000000000000 + 5 1 -1 1.00000000000000 + 6 1 -1 1.00000000000000 + 7 1 -1 1.00000000000000 + 8 1 -1 1.00000000000000 + 9 1 -1 1.00000000000000 + 10 1 -1 1.00000000000000 + """) + edges = six.StringIO("""\ + id left right parent child + 0 62291.41659631 79679.17408763 1 5 + 1 62291.41659631 62374.60889677 1 6 + 2 122179.36037089 138345.43104411 1 7 + 3 67608.32330402 79679.17408763 1 8 + 4 122179.36037089 138345.43104411 1 8 + 5 62291.41659631 79679.17408763 1 9 + 6 126684.47550333 138345.43104411 1 10 + 7 23972.05905068 62291.41659631 2 5 + 8 79679.17408763 82278.53390076 2 5 + 9 23972.05905068 62291.41659631 2 6 + 10 79679.17408763 110914.43816806 2 7 + 11 145458.28890561 189765.31932273 2 7 + 12 79679.17408763 110914.43816806 2 8 + 13 145458.28890561 200000.00000000 2 8 + 14 23972.05905068 62291.41659631 2 9 + 15 79679.17408763 110914.43816806 2 9 + 16 145458.28890561 145581.18329797 2 10 + 17 4331.62138785 23972.05905068 3 6 + 18 4331.62138785 23972.05905068 3 9 + 19 110914.43816806 122179.36037089 4 7 + 20 138345.43104411 145458.28890561 4 7 + 21 110914.43816806 122179.36037089 4 8 + 22 138345.43104411 145458.28890561 4 8 + 23 110914.43816806 112039.30503475 4 9 + 24 138345.43104411 145458.28890561 4 10 + 25 0.00000000 200000.00000000 0 1 + 26 0.00000000 200000.00000000 0 2 + 27 0.00000000 200000.00000000 0 3 + 28 0.00000000 200000.00000000 0 4 + """) + ts = tskit.load_text(nodes, edges, sequence_length=200000, strict=False) + pts = 
tests.PythonTreeSequence(ts.get_ll_tree_sequence()) + num_trees = 0 + for t in pts.trees(): + num_trees += 1 + self.assertEqual(num_trees, ts.num_trees) + n = 0 + for pt, t in zip(pts.trees(), ts.trees()): + self.assertEqual((pt.left, pt.right), t.interval) + for j in range(ts.num_nodes): + self.assertEqual(pt.parent[j], t.parent(j)) + self.assertEqual(pt.left_child[j], t.left_child(j)) + self.assertEqual(pt.right_child[j], t.right_child(j)) + self.assertEqual(pt.left_sib[j], t.left_sib(j)) + self.assertEqual(pt.right_sib[j], t.right_sib(j)) + n += 1 + self.assertEqual(n, num_trees) + intervals = [t.interval for t in ts.trees()] + self.assertEqual(intervals[0][0], 0) + self.assertEqual(intervals[-1][-1], ts.sequence_length) + + +class TestRecordSquashing(TopologyTestCase): + """ + Tests that we correctly squash adjacent equal records together. + """ + def test_single_record(self): + nodes = six.StringIO("""\ + id is_sample time + 0 1 0 + 1 1 1 + """) + edges = six.StringIO("""\ + left right parent child + 0 1 1 0 + 1 2 1 0 + """) + ts = tskit.load_text(nodes, edges, strict=False) + tss, node_map = ts.simplify(map_nodes=True) + self.assertEqual(list(node_map), [0, 1]) + self.assertEqual(tss.dump_tables().nodes, ts.dump_tables().nodes) + simplified_edges = list(tss.edges()) + self.assertEqual(len(simplified_edges), 1) + e = simplified_edges[0] + self.assertEqual(e.left, 0) + self.assertEqual(e.right, 2) + + def test_single_tree(self): + ts = msprime.simulate(10, random_seed=self.random_seed) + ts_redundant = tsutil.insert_redundant_breakpoints(ts) + tss = ts_redundant.simplify() + self.assertEqual(tss.dump_tables().nodes, ts.dump_tables().nodes) + self.assertEqual(tss.dump_tables().edges, ts.dump_tables().edges) + + def test_many_trees(self): + ts = msprime.simulate( + 20, recombination_rate=5, random_seed=self.random_seed) + self.assertGreater(ts.num_trees, 2) + ts_redundant = tsutil.insert_redundant_breakpoints(ts) + tss = ts_redundant.simplify() + 
self.assertEqual(tss.dump_tables().nodes, ts.dump_tables().nodes) + self.assertEqual(tss.dump_tables().edges, ts.dump_tables().edges) + + +class TestRedundantBreakpoints(TopologyTestCase): + """ + Tests for dealing with redundant breakpoints within the tree sequence. + These are records that may be squashed together into a single record. + """ + def test_single_tree(self): + ts = msprime.simulate(10, random_seed=self.random_seed) + ts_redundant = tsutil.insert_redundant_breakpoints(ts) + self.assertEqual(ts.sample_size, ts_redundant.sample_size) + self.assertEqual(ts.sequence_length, ts_redundant.sequence_length) + self.assertEqual(ts_redundant.num_trees, 2) + trees = [t.parent_dict for t in ts_redundant.trees()] + self.assertEqual(len(trees), 2) + self.assertEqual(trees[0], trees[1]) + self.assertEqual([t.parent_dict for t in ts.trees()][0], trees[0]) + + def test_many_trees(self): + ts = msprime.simulate( + 20, recombination_rate=5, random_seed=self.random_seed) + self.assertGreater(ts.num_trees, 2) + ts_redundant = tsutil.insert_redundant_breakpoints(ts) + self.assertEqual(ts.sample_size, ts_redundant.sample_size) + self.assertEqual(ts.sequence_length, ts_redundant.sequence_length) + self.assertGreater(ts_redundant.num_trees, ts.num_trees) + self.assertGreater(ts_redundant.num_edges, ts.num_edges) + redundant_trees = ts_redundant.trees() + redundant_t = next(redundant_trees) + comparisons = 0 + for t in ts.trees(): + while redundant_t is not None and redundant_t.interval[1] <= t.interval[1]: + self.assertEqual(t.parent_dict, redundant_t.parent_dict) + comparisons += 1 + redundant_t = next(redundant_trees, None) + self.assertEqual(comparisons, ts_redundant.num_trees) + + +class TestUnaryNodes(TopologyTestCase): + """ + Tests for situations in which we have unary nodes in the tree sequence. + """ + def test_simple_case(self): + # Simple case where we have n = 2 and some unary nodes. 
+ nodes = six.StringIO("""\ + id is_sample time + 0 1 0 + 1 1 0 + 2 0 1 + 3 0 1 + 4 0 2 + 5 0 3 + """) + edges = six.StringIO("""\ + left right parent child + 0 1 2 0 + 0 1 3 1 + 0 1 4 2,3 + 0 1 5 4 + """) + sites = "position ancestral_state\n" + mutations = "site node derived_state\n" + for j in range(5): + position = j * 1 / 5 + sites += "{} 0\n".format(position) + mutations += "{} {} 1\n".format(j, j) + ts = tskit.load_text( + nodes=nodes, edges=edges, sites=six.StringIO(sites), + mutations=six.StringIO(mutations), strict=False) + + self.assertEqual(ts.sample_size, 2) + self.assertEqual(ts.num_nodes, 6) + self.assertEqual(ts.num_trees, 1) + self.assertEqual(ts.num_sites, 5) + self.assertEqual(ts.num_mutations, 5) + self.assertEqual(len(list(ts.edge_diffs())), ts.num_trees) + t = next(ts.trees()) + self.assertEqual( + t.parent_dict, {0: 2, 1: 3, 2: 4, 3: 4, 4: 5}) + self.assertEqual(t.mrca(0, 1), 4) + self.assertEqual(t.mrca(0, 2), 2) + self.assertEqual(t.mrca(0, 4), 4) + self.assertEqual(t.mrca(0, 5), 5) + self.assertEqual(t.mrca(0, 3), 4) + H = list(ts.haplotypes()) + self.assertEqual(H[0], "10101") + self.assertEqual(H[1], "01011") + + def test_ladder_tree(self): + # We have a single tree with a long ladder of unary nodes along a path + num_unary_nodes = 30 + n = 2 + nodes = """\ + is_sample time + 1 0 + 1 0 + """ + edges = """\ + left right parent child + 0 1 2 0 + """ + for j in range(num_unary_nodes + 2): + nodes += "0 {}\n".format(j + 2) + for j in range(num_unary_nodes): + edges += "0 1 {} {}\n".format(n + j + 1, n + j) + root = num_unary_nodes + 3 + root_time = num_unary_nodes + 3 + edges += "0 1 {} 1,{}\n".format(root, num_unary_nodes + 2) + ts = tskit.load_text(six.StringIO(nodes), six.StringIO(edges), strict=False) + t = next(ts.trees()) + self.assertEqual(t.mrca(0, 1), root) + self.assertEqual(t.tmrca(0, 1), root_time) + ts_simplified, node_map = ts.simplify(map_nodes=True) + test_map = [tskit.NULL for _ in range(ts.num_nodes)] + test_map[0] = 0 + 
test_map[1] = 1 + test_map[root] = 2 + self.assertEqual(list(node_map), test_map) + self.assertEqual(ts_simplified.num_edges, 2) + t = next(ts_simplified.trees()) + self.assertEqual(t.mrca(0, 1), 2) + self.assertEqual(t.tmrca(0, 1), root_time) + + def verify_unary_tree_sequence(self, ts): + """ + Take the specified tree sequence and produce an equivalent in which + unary records have been interspersed. + """ + self.assertGreater(ts.num_trees, 2) + self.assertGreater(ts.num_mutations, 2) + tables = ts.dump_tables() + next_node = ts.num_nodes + node_times = {j: node.time for j, node in enumerate(ts.nodes())} + edges = [] + for e in ts.edges(): + node = ts.node(e.parent) + t = node.time - 1e-14 # Arbitrary small value. + next_node = len(tables.nodes) + tables.nodes.add_row(time=t, population=node.population) + edges.append(tskit.Edge( + left=e.left, right=e.right, parent=next_node, child=e.child)) + node_times[next_node] = t + edges.append(tskit.Edge( + left=e.left, right=e.right, parent=e.parent, child=next_node)) + edges.sort(key=lambda e: node_times[e.parent]) + tables.edges.reset() + for e in edges: + tables.edges.add_row( + left=e.left, right=e.right, child=e.child, parent=e.parent) + ts_new = tables.tree_sequence() + self.assertGreater(ts_new.num_edges, ts.num_edges) + self.assert_haplotypes_equal(ts, ts_new) + self.assert_variants_equal(ts, ts_new) + ts_simplified = ts_new.simplify() + self.assertEqual(list(ts_simplified.records()), list(ts.records())) + self.assert_haplotypes_equal(ts, ts_simplified) + self.assert_variants_equal(ts, ts_simplified) + self.assertEqual(len(list(ts.edge_diffs())), ts.num_trees) + + def test_binary_tree_sequence_unary_nodes(self): + ts = msprime.simulate( + 20, recombination_rate=5, mutation_rate=5, random_seed=self.random_seed) + self.verify_unary_tree_sequence(ts) + + def test_nonbinary_tree_sequence_unary_nodes(self): + demographic_events = [ + msprime.SimpleBottleneck(time=1.0, population=0, proportion=0.95)] + ts = 
msprime.simulate( + 20, recombination_rate=10, mutation_rate=5, + demographic_events=demographic_events, random_seed=self.random_seed) + found = False + for r in ts.edgesets(): + if len(r.children) > 2: + found = True + self.assertTrue(found) + self.verify_unary_tree_sequence(ts) + + +class TestGeneralSamples(TopologyTestCase): + """ + Test cases in which we have samples at arbitrary nodes (i.e., not at + {0,...,n - 1}). + """ + def test_simple_case(self): + # Simple case where we have n = 3 and samples starting at n. + nodes = six.StringIO("""\ + id is_sample time + 0 0 2 + 1 0 1 + 2 1 0 + 3 1 0 + 4 1 0 + """) + edges = six.StringIO("""\ + left right parent child + 0 1 1 2,3 + 0 1 0 1,4 + """) + sites = six.StringIO("""\ + position ancestral_state + 0.1 0 + 0.2 0 + 0.3 0 + 0.4 0 + """) + mutations = six.StringIO("""\ + site node derived_state + 0 2 1 + 1 3 1 + 2 4 1 + 3 1 1 + """) + ts = tskit.load_text( + nodes=nodes, edges=edges, sites=sites, mutations=mutations, strict=False) + + self.assertEqual(ts.sample_size, 3) + self.assertEqual(list(ts.samples()), [2, 3, 4]) + self.assertEqual(ts.num_nodes, 5) + self.assertEqual(ts.num_nodes, 5) + self.assertEqual(ts.num_sites, 4) + self.assertEqual(ts.num_mutations, 4) + self.assertEqual(len(list(ts.edge_diffs())), ts.num_trees) + t = next(ts.trees()) + self.assertEqual(t.root, 0) + self.assertEqual(t.parent_dict, {1: 0, 2: 1, 3: 1, 4: 0}) + H = list(ts.haplotypes()) + self.assertEqual(H[0], "1001") + self.assertEqual(H[1], "0101") + self.assertEqual(H[2], "0010") + + tss, node_map = ts.simplify(map_nodes=True) + self.assertEqual(list(node_map), [4, 3, 0, 1, 2]) + # We should have the same tree sequence just with canonicalised nodes. 
+ self.assertEqual(tss.sample_size, 3) + self.assertEqual(list(tss.samples()), [0, 1, 2]) + self.assertEqual(tss.num_nodes, 5) + self.assertEqual(tss.num_trees, 1) + self.assertEqual(tss.num_sites, 4) + self.assertEqual(tss.num_mutations, 4) + self.assertEqual(len(list(ts.edge_diffs())), ts.num_trees) + t = next(tss.trees()) + self.assertEqual(t.root, 4) + self.assertEqual(t.parent_dict, {0: 3, 1: 3, 2: 4, 3: 4}) + H = list(tss.haplotypes()) + self.assertEqual(H[0], "1001") + self.assertEqual(H[1], "0101") + self.assertEqual(H[2], "0010") + + def verify_permuted_nodes(self, ts): + """ + Take the specified tree sequence and permute the nodes, verifying that we + get back a tree sequence with the correct properties. + """ + # Mapping from the original nodes into nodes in the new tree sequence. + node_map = list(range(ts.num_nodes)) + random.shuffle(node_map) + # Change the permutation so that the relative order of samples is maintained. + # Then, we should get back exactly the same tree sequence after simplify + # and haplotypes and variants are also equal. 
+ samples = sorted(node_map[:ts.sample_size]) + node_map = samples + node_map[ts.sample_size:] + permuted = tsutil.permute_nodes(ts, node_map) + self.assertEqual(ts.sequence_length, permuted.sequence_length) + self.assertEqual(list(permuted.samples()), samples) + self.assertEqual(list(permuted.haplotypes()), list(ts.haplotypes())) + self.assertEqual( + [v.genotypes for v in permuted.variants(as_bytes=True)], + [v.genotypes for v in ts.variants(as_bytes=True)]) + self.assertEqual(ts.num_trees, permuted.num_trees) + j = 0 + for t1, t2 in zip(ts.trees(), permuted.trees()): + t1_dict = {node_map[k]: node_map[v] for k, v in t1.parent_dict.items()} + self.assertEqual(node_map[t1.root], t2.root) + self.assertEqual(t1_dict, t2.parent_dict) + for u1 in t1.nodes(): + u2 = node_map[u1] + self.assertEqual( + sorted([node_map[v] for v in t1.samples(u1)]), + sorted(list(t2.samples(u2)))) + j += 1 + self.assertEqual(j, ts.num_trees) + + # The simplified version of the permuted tree sequence should be in canonical + # form, and identical to the original. 
+ simplified, s_node_map = permuted.simplify(map_nodes=True) + + original_tables = ts.dump_tables() + simplified_tables = simplified.dump_tables() + original_tables.provenances.clear() + simplified_tables.provenances.clear() + + self.assertEqual( + original_tables.sequence_length, simplified_tables.sequence_length) + self.assertEqual(original_tables.nodes, simplified_tables.nodes) + self.assertEqual(original_tables.edges, simplified_tables.edges) + self.assertEqual(original_tables.sites, simplified_tables.sites) + self.assertEqual(original_tables.mutations, simplified_tables.mutations) + self.assertEqual(original_tables.individuals, simplified_tables.individuals) + self.assertEqual(original_tables.populations, simplified_tables.populations) + + self.assertEqual(original_tables, simplified_tables) + self.assertEqual(ts.sequence_length, simplified.sequence_length) + for tree in simplified.trees(): + pass + + for u, v in enumerate(node_map): + self.assertEqual(s_node_map[v], u) + self.assertTrue(np.array_equal(simplified.samples(), ts.samples())) + self.assertEqual(list(simplified.nodes()), list(ts.nodes())) + self.assertEqual(list(simplified.edges()), list(ts.edges())) + self.assertEqual(list(simplified.sites()), list(ts.sites())) + self.assertEqual(list(simplified.haplotypes()), list(ts.haplotypes())) + self.assertEqual( + list(simplified.variants(as_bytes=True)), list(ts.variants(as_bytes=True))) + + def test_single_tree_permuted_nodes(self): + ts = msprime.simulate(10, mutation_rate=5, random_seed=self.random_seed) + self.verify_permuted_nodes(ts) + + def test_binary_tree_sequence_permuted_nodes(self): + ts = msprime.simulate( + 20, recombination_rate=5, mutation_rate=5, random_seed=self.random_seed) + self.verify_permuted_nodes(ts) + + def test_nonbinary_tree_sequence_permuted_nodes(self): + demographic_events = [ + msprime.SimpleBottleneck(time=1.0, population=0, proportion=0.95)] + ts = msprime.simulate( + 20, recombination_rate=10, mutation_rate=5, + 
demographic_events=demographic_events, random_seed=self.random_seed) + found = False + for e in ts.edgesets(): + if len(e.children) > 2: + found = True + self.assertTrue(found) + self.verify_permuted_nodes(ts) + + +class TestSimplifyExamples(TopologyTestCase): + """ + Tests for simplify where we write out the input and expected output + or we detect expected errors. + """ + def verify_simplify( + self, samples, filter_sites=True, + nodes_before=None, edges_before=None, sites_before=None, + mutations_before=None, nodes_after=None, edges_after=None, + sites_after=None, mutations_after=None, debug=False): + """ + Verifies that if we run simplify on the specified input we get the + required output. + """ + ts = tskit.load_text( + nodes=six.StringIO(nodes_before), + edges=six.StringIO(edges_before), + sites=six.StringIO(sites_before) if sites_before is not None else None, + mutations=( + six.StringIO(mutations_before) + if mutations_before is not None else None), + strict=False) + before = ts.dump_tables() + + ts = tskit.load_text( + nodes=six.StringIO(nodes_after), + edges=six.StringIO(edges_after), + sites=six.StringIO(sites_after) if sites_after is not None else None, + mutations=( + six.StringIO(mutations_after) + if mutations_after is not None else None), + strict=False, + sequence_length=before.sequence_length) + after = ts.dump_tables() + + ts = before.tree_sequence() + # Make sure it's a valid topology. We want to be sure we evaluate the + # whole iterator + for t in ts.trees(): + self.assertTrue(t is not None) + before.simplify( + samples=samples, filter_sites=filter_sites) + if debug: + print("before") + print(before) + print("after") + print(after) + self.assertEqual(before, after) + + def test_unsorted_edges(self): + # We have two nodes at the same time and interleave edges for + # these nodes together. This is an error because all edges for + # a given parent must be contigous. 
+ nodes_before = """\ + id is_sample time + 0 1 0 + 1 1 0 + 2 0 1 + 3 0 1 + """ + edges_before = """\ + left right parent child + 0 1 2 0,1 + 0 1 3 0,1 + 1 2 2 0,1 + 1 2 3 0,1 + """ + nodes = tskit.parse_nodes(six.StringIO(nodes_before), strict=False) + edges = tskit.parse_edges(six.StringIO(edges_before), strict=False) + # Cannot use load_text here because it calls sort() + tables = tskit.TableCollection(sequence_length=2) + tables.nodes.set_columns(**nodes.asdict()) + tables.edges.set_columns(**edges.asdict()) + self.assertRaises(_tskit.LibraryError, tables.simplify, samples=[0, 1]) + + def test_single_binary_tree(self): + # + # 2 4 + # / \ + # 1 3 \ + # / \ \ + # 0 (0)(1) (2) + nodes_before = """\ + id is_sample time + 0 1 0 + 1 1 0 + 2 1 0 + 3 0 1 + 4 0 2 + """ + edges_before = """\ + left right parent child + 0 1 3 0,1 + 0 1 4 2,3 + """ + # We sample 0 and 2, so we get + nodes_after = """\ + id is_sample time + 0 1 0 + 1 1 0 + 2 0 2 + """ + edges_after = """\ + left right parent child + 0 1 2 0,1 + """ + self.verify_simplify( + samples=[0, 2], + nodes_before=nodes_before, edges_before=edges_before, + nodes_after=nodes_after, edges_after=edges_after) + + def test_single_binary_tree_internal_sample(self): + # + # 2 4 + # / \ + # 1 (3) \ + # / \ \ + # 0 (0) 1 (2) + nodes_before = """\ + id is_sample time + 0 1 0 + 1 1 0 + 2 0 0 + 3 1 1 + 4 0 2 + """ + edges_before = """\ + left right parent child + 0 1 3 0,1 + 0 1 4 2,3 + """ + # We sample 0 and 3, so we get + nodes_after = """\ + id is_sample time + 0 1 0 + 1 1 1 + """ + edges_after = """\ + left right parent child + 0 1 1 0 + """ + self.verify_simplify( + samples=[0, 3], + nodes_before=nodes_before, edges_before=edges_before, + nodes_after=nodes_after, edges_after=edges_after) + + def test_single_binary_tree_internal_sample_meet_at_root(self): + # 3 5 + # / \ + # 2 4 (6) + # / \ + # 1 (3) \ + # / \ \ + # 0 (0) 1 2 + nodes_before = """\ + id is_sample time + 0 1 0 + 1 1 0 + 2 0 0 + 3 1 1 + 4 0 2 + 5 0 3 + 6 1 2 
+ """ + edges_before = """\ + left right parent child + 0 1 3 0,1 + 0 1 4 2,3 + 0 1 5 4,6 + """ + # We sample 0 and 3 and 6, so we get + nodes_after = """\ + id is_sample time + 0 1 0 + 1 1 1 + 2 1 2 + 3 0 3 + """ + edges_after = """\ + left right parent child + 0 1 1 0 + 0 1 3 1,2 + """ + self.verify_simplify( + samples=[0, 3, 6], + nodes_before=nodes_before, edges_before=edges_before, + nodes_after=nodes_after, edges_after=edges_after) + + def test_single_binary_tree_simple_mutations(self): + # 3 5 + # / \ + # 2 4 \ + # / \ s0 + # 1 3 s1 \ + # / \ \ \ + # 0 (0) (1) 2 (6) + nodes_before = """\ + id is_sample time + 0 1 0 + 1 1 0 + 2 0 0 + 3 0 1 + 4 0 2 + 5 0 3 + 6 1 0 + """ + edges_before = """\ + left right parent child + 0 1 3 0,1 + 0 1 4 2,3 + 0 1 5 4,6 + """ + sites_before = """\ + id position ancestral_state + 0 0.1 0 + 1 0.2 0 + """ + mutations_before = """\ + site node derived_state + 0 6 1 + 1 2 1 + """ + + # We sample 0 and 2 and 6, so we get + nodes_after = """\ + id is_sample time + 0 1 0 + 1 1 0 + 2 1 0 + 3 0 1 + 3 0 3 + """ + edges_after = """\ + left right parent child + 0 1 3 0,1 + 0 1 4 2,3 + """ + sites_after = """\ + id position ancestral_state + 0 0.1 0 + """ + mutations_after = """\ + site node derived_state + 0 2 1 + """ + self.verify_simplify( + samples=[0, 1, 6], + nodes_before=nodes_before, edges_before=edges_before, + sites_before=sites_before, mutations_before=mutations_before, + nodes_after=nodes_after, edges_after=edges_after, + sites_after=sites_after, mutations_after=mutations_after) + # If we don't filter the fixed sites, we should get the same + # mutations and the original sites table back. 
+ self.verify_simplify( + samples=[0, 1, 6], filter_sites=False, + nodes_before=nodes_before, edges_before=edges_before, + sites_before=sites_before, mutations_before=mutations_before, + nodes_after=nodes_after, edges_after=edges_after, + sites_after=sites_before, mutations_after=mutations_after) + + def test_overlapping_edges(self): + nodes = """\ + id is_sample time + 0 1 0 + 1 1 0 + 2 0 1 + """ + edges_before = """\ + left right parent child + 0 2 2 0 + 1 3 2 1 + """ + # We resolve the overlapping edges here. Since the flanking regions + # have no interesting edges, these are left out of the output. + edges_after = """\ + left right parent child + 1 2 2 0,1 + """ + self.verify_simplify( + samples=[0, 1], + nodes_before=nodes, edges_before=edges_before, + nodes_after=nodes, edges_after=edges_after) + + def test_overlapping_edges_internal_samples(self): + nodes = """\ + id is_sample time + 0 1 0 + 1 1 0 + 2 1 1 + """ + edges = """\ + left right parent child + 0 2 2 0 + 1 3 2 1 + """ + self.verify_simplify( + samples=[0, 1, 2], + nodes_before=nodes, edges_before=edges, nodes_after=nodes, edges_after=edges) + + def test_unary_edges_no_overlap(self): + nodes_before = """\ + id is_sample time + 0 1 0 + 1 1 0 + 2 0 1 + """ + edges_before = """\ + left right parent child + 0 2 2 0 + 2 3 2 1 + """ + # Because there is no overlap between the samples, we just get an + # empty set of output edges. 
+ nodes_after = """\ + id is_sample time + 0 1 0 + 1 1 0 + """ + edges_after = """\ + left right parent child + """ + self.verify_simplify( + samples=[0, 1], + nodes_before=nodes_before, edges_before=edges_before, + nodes_after=nodes_after, edges_after=edges_after) + + def test_unary_edges_no_overlap_internal_sample(self): + nodes_before = """\ + id is_sample time + 0 1 0 + 1 1 0 + 2 1 1 + """ + edges_before = """\ + left right parent child + 0 1 2 0 + 1 2 2 1 + """ + self.verify_simplify( + samples=[0, 1, 2], + nodes_before=nodes_before, edges_before=edges_before, + nodes_after=nodes_before, edges_after=edges_before) + + +class TestNonSampleExternalNodes(TopologyTestCase): + """ + Tests for situations in which we have tips that are not samples. + """ + def test_simple_case(self): + # Simplest case where we have n = 2 and external non-sample nodes. + nodes = six.StringIO("""\ + id is_sample time + 0 1 0 + 1 1 0 + 2 0 1 + 3 0 0 + 4 0 0 + """) + edges = six.StringIO("""\ + left right parent child + 0 1 2 0,1,3,4 + """) + sites = six.StringIO("""\ + id position ancestral_state + 0 0.1 0 + 1 0.2 0 + 2 0.3 0 + 3 0.4 0 + """) + mutations = six.StringIO("""\ + site node derived_state + 0 0 1 + 1 1 1 + 2 3 1 + 3 4 1 + """) + ts = tskit.load_text( + nodes=nodes, edges=edges, sites=sites, mutations=mutations, strict=False) + self.assertEqual(ts.sample_size, 2) + self.assertEqual(ts.num_trees, 1) + self.assertEqual(ts.num_nodes, 5) + self.assertEqual(ts.num_sites, 4) + self.assertEqual(ts.num_mutations, 4) + t = next(ts.trees()) + self.assertEqual(t.parent_dict, {0: 2, 1: 2, 3: 2, 4: 2}) + self.assertEqual(t.root, 2) + ts_simplified, node_map = ts.simplify(map_nodes=True) + self.assertEqual(list(node_map), [0, 1, 2, -1, -1]) + self.assertEqual(ts_simplified.num_nodes, 3) + self.assertEqual(ts_simplified.num_trees, 1) + t = next(ts_simplified.trees()) + self.assertEqual(t.parent_dict, {0: 2, 1: 2}) + self.assertEqual(t.root, 2) + # We should have removed the two non-sample 
mutations. + self.assertEqual([s.position for s in t.sites()], [0.1, 0.2]) + + def test_unary_non_sample_external_nodes(self): + # Take an ordinary tree sequence and put a bunch of external non + # sample nodes on it. + ts = msprime.simulate( + 15, recombination_rate=5, random_seed=self.random_seed, mutation_rate=5) + self.assertGreater(ts.num_trees, 2) + self.assertGreater(ts.num_mutations, 2) + tables = ts.dump_tables() + next_node = ts.num_nodes + tables.edges.reset() + for e in ts.edges(): + tables.edges.add_row(e.left, e.right, e.parent, e.child) + tables.edges.add_row(e.left, e.right, e.parent, next_node) + tables.nodes.add_row(time=0) + next_node += 1 + tables.sort() + ts_new = tables.tree_sequence() + self.assertEqual(ts_new.num_nodes, next_node) + self.assertEqual(ts_new.sample_size, ts.sample_size) + self.assert_haplotypes_equal(ts, ts_new) + self.assert_variants_equal(ts, ts_new) + ts_simplified = ts_new.simplify() + self.assertEqual(ts_simplified.num_nodes, ts.num_nodes) + self.assertEqual(ts_simplified.sample_size, ts.sample_size) + self.assertEqual(list(ts_simplified.records()), list(ts.records())) + self.assert_haplotypes_equal(ts, ts_simplified) + self.assert_variants_equal(ts, ts_simplified) + + +class TestMultipleRoots(TopologyTestCase): + """ + Tests for situations where we have multiple roots for the samples. + """ + + def test_simplest_degenerate_case(self): + # Simplest case where we have n = 2 and no edges. 
+ nodes = six.StringIO("""\ + id is_sample time + 0 1 0 + 1 1 0 + """) + edges = six.StringIO("""\ + left right parent child + """) + sites = six.StringIO("""\ + id position ancestral_state + 0 0.1 0 + 1 0.2 0 + """) + mutations = six.StringIO("""\ + site node derived_state + 0 0 1 + 1 1 1 + """) + ts = tskit.load_text( + nodes=nodes, edges=edges, sites=sites, mutations=mutations, + sequence_length=1, strict=False) + self.assertEqual(ts.num_nodes, 2) + self.assertEqual(ts.num_trees, 1) + self.assertEqual(ts.num_sites, 2) + self.assertEqual(ts.num_mutations, 2) + t = next(ts.trees()) + self.assertEqual(t.parent_dict, {}) + self.assertEqual(sorted(t.roots), [0, 1]) + self.assertEqual(list(ts.haplotypes()), ["10", "01"]) + self.assertEqual( + [v.genotypes for v in ts.variants(as_bytes=True)], [b"10", b"01"]) + simplified = ts.simplify() + t1 = ts.dump_tables() + t2 = simplified.dump_tables() + self.assertEqual(t1.nodes, t2.nodes) + self.assertEqual(t1.edges, t2.edges) + + def test_simplest_non_degenerate_case(self): + # Simplest case where we have n = 4 and two trees. 
+ nodes = six.StringIO("""\ + id is_sample time + 0 1 0 + 1 1 0 + 2 1 0 + 3 1 0 + 4 0 1 + 5 0 2 + """) + edges = six.StringIO("""\ + left right parent child + 0 1 4 0,1 + 0 1 5 2,3 + """) + sites = six.StringIO("""\ + id position ancestral_state + 0 0.1 0 + 1 0.2 0 + 2 0.3 0 + 3 0.4 0 + """) + mutations = six.StringIO("""\ + site node derived_state + 0 0 1 + 1 1 1 + 2 2 1 + 3 3 1 + """) + ts = tskit.load_text( + nodes=nodes, edges=edges, sites=sites, mutations=mutations, strict=False) + self.assertEqual(ts.num_nodes, 6) + self.assertEqual(ts.num_trees, 1) + self.assertEqual(ts.num_sites, 4) + self.assertEqual(ts.num_mutations, 4) + t = next(ts.trees()) + self.assertEqual(t.parent_dict, {0: 4, 1: 4, 2: 5, 3: 5}) + self.assertEqual(list(ts.haplotypes()), ["1000", "0100", "0010", "0001"]) + self.assertEqual( + [v.genotypes for v in ts.variants(as_bytes=True)], + [b"1000", b"0100", b"0010", b"0001"]) + self.assertEqual(t.mrca(0, 1), 4) + self.assertEqual(t.mrca(0, 4), 4) + self.assertEqual(t.mrca(2, 3), 5) + self.assertEqual(t.mrca(0, 2), tskit.NULL) + self.assertEqual(t.mrca(0, 3), tskit.NULL) + self.assertEqual(t.mrca(2, 4), tskit.NULL) + ts_simplified, node_map = ts.simplify(map_nodes=True) + for j in range(4): + self.assertEqual(node_map[j], j) + self.assertEqual(ts_simplified.num_nodes, 6) + self.assertEqual(ts_simplified.num_trees, 1) + self.assertEqual(ts_simplified.num_sites, 4) + self.assertEqual(ts_simplified.num_mutations, 4) + t = next(ts_simplified.trees()) + self.assertEqual(t.parent_dict, {0: 4, 1: 4, 2: 5, 3: 5}) + + def test_two_reducable_trees(self): + # We have n = 4 and two trees, with some unary nodes and non-sample leaves + nodes = six.StringIO("""\ + id is_sample time + 0 1 0 + 1 1 0 + 2 1 0 + 3 1 0 + 4 0 1 + 5 0 1 + 6 0 2 + 7 0 3 + 8 0 0 # Non sample leaf + """) + edges = six.StringIO("""\ + left right parent child + 0 1 4 0 + 0 1 5 1 + 0 1 6 4,5 + 0 1 7 2,3,8 + """) + sites = six.StringIO("""\ + id position ancestral_state + 0 0.1 0 + 1 0.2 0 + 
2 0.3 0 + 3 0.4 0 + 4 0.5 0 + """) + mutations = six.StringIO("""\ + site node derived_state + 0 0 1 + 1 1 1 + 2 2 1 + 3 3 1 + 4 8 1 + """) + ts = tskit.load_text( + nodes=nodes, edges=edges, sites=sites, mutations=mutations, strict=False) + self.assertEqual(ts.num_nodes, 9) + self.assertEqual(ts.num_trees, 1) + self.assertEqual(ts.num_sites, 5) + self.assertEqual(ts.num_mutations, 5) + t = next(ts.trees()) + self.assertEqual(t.parent_dict, {0: 4, 1: 5, 2: 7, 3: 7, 4: 6, 5: 6, 8: 7}) + self.assertEqual(list(ts.haplotypes()), ["10000", "01000", "00100", "00010"]) + self.assertEqual( + [v.genotypes for v in ts.variants(as_bytes=True)], + [b"1000", b"0100", b"0010", b"0001", b"0000"]) + self.assertEqual(t.mrca(0, 1), 6) + self.assertEqual(t.mrca(2, 3), 7) + self.assertEqual(t.mrca(2, 8), 7) + self.assertEqual(t.mrca(0, 2), tskit.NULL) + self.assertEqual(t.mrca(0, 3), tskit.NULL) + self.assertEqual(t.mrca(0, 8), tskit.NULL) + ts_simplified, node_map = ts.simplify(map_nodes=True) + for j in range(4): + self.assertEqual(node_map[j], j) + self.assertEqual(ts_simplified.num_nodes, 6) + self.assertEqual(ts_simplified.num_trees, 1) + t = next(ts_simplified.trees()) + # print(ts_simplified.tables) + self.assertEqual( + list(ts_simplified.haplotypes()), ["1000", "0100", "0010", "0001"]) + self.assertEqual( + [v.genotypes for v in ts_simplified.variants(as_bytes=True)], + [b"1000", b"0100", b"0010", b"0001"]) + # The site over the non-sample external node should have been discarded. + sites = list(t.sites()) + self.assertEqual(sites[-1].position, 0.4) + self.assertEqual(t.parent_dict, {0: 4, 1: 4, 2: 5, 3: 5}) + + def test_one_reducable_tree(self): + # We have n = 4 and two trees. One tree is reducable and the other isn't. 
+ nodes = six.StringIO("""\ + id is_sample time + 0 1 0 + 1 1 0 + 2 1 0 + 3 1 0 + 4 0 1 + 5 0 1 + 6 0 2 + 7 0 3 + 8 0 0 # Non sample leaf + """) + edges = six.StringIO("""\ + left right parent child + 0 1 4 0 + 0 1 5 1 + 0 1 6 4,5 + 0 1 7 2,3,8 + """) + ts = tskit.load_text(nodes=nodes, edges=edges, strict=False) + self.assertEqual(ts.num_nodes, 9) + self.assertEqual(ts.num_trees, 1) + t = next(ts.trees()) + self.assertEqual(t.parent_dict, {0: 4, 1: 5, 2: 7, 3: 7, 4: 6, 5: 6, 8: 7}) + self.assertEqual(t.mrca(0, 1), 6) + self.assertEqual(t.mrca(2, 3), 7) + self.assertEqual(t.mrca(2, 8), 7) + self.assertEqual(t.mrca(0, 2), tskit.NULL) + self.assertEqual(t.mrca(0, 3), tskit.NULL) + self.assertEqual(t.mrca(0, 8), tskit.NULL) + ts_simplified = ts.simplify() + self.assertEqual(ts_simplified.num_nodes, 6) + self.assertEqual(ts_simplified.num_trees, 1) + t = next(ts_simplified.trees()) + self.assertEqual(t.parent_dict, {0: 4, 1: 4, 2: 5, 3: 5}) + + # NOTE: This test has not been checked since updating to the text representation + # so there might be other problems with it. + def test_mutations_over_roots(self): + # Mutations over root nodes should be ok when we have multiple roots. 
+ nodes = six.StringIO("""\ + id is_sample time + 0 1 0 + 1 1 0 + 2 1 0 + 3 0 1 + 4 0 2 + 5 0 2 + """) + edges = six.StringIO("""\ + left right parent child + 0 1 3 0,1 + 0 1 4 3 + 0 1 5 2 + """) + sites = six.StringIO("""\ + id position ancestral_state + 0 0.1 0 + 1 0.2 0 + 2 0.3 0 + 3 0.4 0 + 4 0.5 0 + 5 0.6 0 + """) + mutations = six.StringIO("""\ + site node derived_state + 0 0 1 + 1 1 1 + 2 3 1 + 3 4 1 + 4 2 1 + 5 5 1 + """) + ts = tskit.load_text( + nodes=nodes, edges=edges, sites=sites, mutations=mutations, strict=False) + self.assertEqual(ts.num_nodes, 6) + self.assertEqual(ts.num_trees, 1) + self.assertEqual(ts.num_sites, 6) + self.assertEqual(ts.num_mutations, 6) + t = next(ts.trees()) + self.assertEqual(len(list(t.sites())), 6) + haplotypes = ["101100", "011100", "000011"] + variants = [b"100", b"010", b"110", b"110", b"001", b"001"] + self.assertEqual(list(ts.haplotypes()), haplotypes) + self.assertEqual([v.genotypes for v in ts.variants(as_bytes=True)], variants) + ts_simplified = ts.simplify(filter_sites=False) + self.assertEqual(list(ts_simplified.haplotypes()), haplotypes) + self.assertEqual( + [v.genotypes for v in ts_simplified.variants(as_bytes=True)], variants) + + def test_break_single_tree(self): + # Take a single largish tree from tskit, and remove the oldest record. + # This breaks it into two subtrees. 
+ ts = msprime.simulate(20, random_seed=self.random_seed, mutation_rate=4) + self.assertGreater(ts.num_mutations, 5) + tables = ts.dump_tables() + tables.edges.set_columns( + left=tables.edges.left[:-1], + right=tables.edges.right[:-1], + parent=tables.edges.parent[:-1], + child=tables.edges.child[:-1]) + ts_new = tables.tree_sequence() + self.assertEqual(ts.sample_size, ts_new.sample_size) + self.assertEqual(ts.num_edges, ts_new.num_edges + 1) + self.assertEqual(ts.num_trees, ts_new.num_trees) + self.assert_haplotypes_equal(ts, ts_new) + self.assert_variants_equal(ts, ts_new) + roots = set() + t_new = next(ts_new.trees()) + for u in ts_new.samples(): + while t_new.parent(u) != tskit.NULL: + u = t_new.parent(u) + roots.add(u) + self.assertEqual(len(roots), 2) + self.assertEqual(sorted(roots), sorted(t_new.roots)) + + +class TestWithVisuals(TopologyTestCase): + """ + Some pedantic tests with ascii depictions of what's supposed to happen. + """ + + def verify_simplify_topology(self, ts, sample, haplotypes=False): + # copies from test_highlevel.py + new_ts, node_map = ts.simplify(sample, map_nodes=True) + old_trees = ts.trees() + old_tree = next(old_trees) + self.assertGreaterEqual(ts.get_num_trees(), new_ts.get_num_trees()) + for new_tree in new_ts.trees(): + new_left, new_right = new_tree.get_interval() + old_left, old_right = old_tree.get_interval() + # Skip ahead on the old tree until new_left is within its interval + while old_right <= new_left: + old_tree = next(old_trees) + old_left, old_right = old_tree.get_interval() + # If the TMRCA of all pairs of samples is the same, then we have the + # same information. 
We limit this to at most 500 pairs + pairs = itertools.islice(itertools.combinations(sample, 2), 500) + for pair in pairs: + mapped_pair = [node_map[u] for u in pair] + mrca1 = old_tree.get_mrca(*pair) + mrca2 = new_tree.get_mrca(*mapped_pair) + self.assertEqual(mrca2, node_map[mrca1]) + if haplotypes: + orig_haps = list(ts.haplotypes()) + simp_haps = list(new_ts.haplotypes()) + for i, j in enumerate(sample): + self.assertEqual(orig_haps[j], simp_haps[i]) + + def test_partial_non_sample_external_nodes(self): + # A somewhat more complicated test case with a partially specified, + # non-sampled tip. + # + # Here is the situation: + # + # 1.0 7 + # 0.7 / \ 6 + # / \ / \ + # 0.5 / 5 5 / 5 + # / / \ / \ / / \ + # 0.4 / / 4 / 4 / / 4 + # / / / \ / / \ / / / \ + # / / 3 \ / / \ / / 3 \ + # / / \ / / \ / / \ + # 0.0 0 1 2 1 0 2 0 1 2 + # + # (0.0, 0.2), (0.2, 0.8), (0.8, 1.0) + + nodes = six.StringIO("""\ + id is_sample time + 0 1 0 + 1 1 0 + 2 1 0 + 3 0 0.2 # Non sample leaf + 4 0 0.4 + 5 0 0.5 + 6 0 0.7 + 7 0 1.0 + """) + edges = six.StringIO("""\ + left right parent child + 0.0 0.2 4 2,3 + 0.2 0.8 4 0,2 + 0.8 1.0 4 2,3 + 0.0 1.0 5 1,4 + 0.8 1.0 6 0,5 + 0.0 0.2 7 0,5 + """) + true_trees = [ + {0: 7, 1: 5, 2: 4, 3: 4, 4: 5, 5: 7, 6: -1, 7: -1}, + {0: 4, 1: 5, 2: 4, 3: -1, 4: 5, 5: -1, 6: -1, 7: -1}, + {0: 6, 1: 5, 2: 4, 3: 4, 4: 5, 5: 6, 6: -1, 7: -1}] + ts = tskit.load_text(nodes=nodes, edges=edges, strict=False) + tree_dicts = [t.parent_dict for t in ts.trees()] + self.assertEqual(ts.sample_size, 3) + self.assertEqual(ts.num_trees, 3) + self.assertEqual(ts.num_nodes, 8) + # check topologies agree: + for a, t in zip(true_trees, tree_dicts): + for k in a.keys(): + if k in t.keys(): + self.assertEqual(t[k], a[k]) + else: + self.assertEqual(a[k], tskit.NULL) + # check .simplify() works here + self.verify_simplify_topology(ts, [0, 1, 2]) + + def test_partial_non_sample_external_nodes_2(self): + # The same situation as above, but partial tip is labeled '7' not '3': + # + # 
1.0 6 + # 0.7 / \ 5 + # / \ / \ + # 0.5 / 4 4 / 4 + # / / \ / \ / / \ + # 0.4 / / 3 / 3 / / 3 + # / / / \ / / \ / / / \ + # / / 7 \ / / \ / / 7 \ + # / / \ / / \ / / \ + # 0.0 0 1 2 1 0 2 0 1 2 + # + # (0.0, 0.2), (0.2, 0.8), (0.8, 1.0) + nodes = six.StringIO("""\ + id is_sample time + 0 1 0 + 1 1 0 + 2 1 0 + 3 0 0.4 + 4 0 0.5 + 5 0 0.7 + 6 0 1.0 + 7 0 0 # Non sample leaf + """) + edges = six.StringIO("""\ + left right parent child + 0.0 0.2 3 2,7 + 0.2 0.8 3 0,2 + 0.8 1.0 3 2,7 + 0.0 0.2 4 1,3 + 0.2 0.8 4 1,3 + 0.8 1.0 4 1,3 + 0.8 1.0 5 0,4 + 0.0 0.2 6 0,4 + """) + true_trees = [ + {0: 6, 1: 4, 2: 3, 3: 4, 4: 6, 5: -1, 6: -1, 7: 3}, + {0: 3, 1: 4, 2: 3, 3: 4, 4: -1, 5: -1, 6: -1, 7: -1}, + {0: 5, 1: 4, 2: 3, 3: 4, 4: 5, 5: -1, 6: -1, 7: 3}] + ts = tskit.load_text(nodes=nodes, edges=edges, strict=False) + tree_dicts = [t.parent_dict for t in ts.trees()] + # sample size check works here since 7 > 3 + self.assertEqual(ts.sample_size, 3) + self.assertEqual(ts.num_trees, 3) + self.assertEqual(ts.num_nodes, 8) + # check topologies agree: + for a, t in zip(true_trees, tree_dicts): + for k in a.keys(): + if k in t.keys(): + self.assertEqual(t[k], a[k]) + else: + self.assertEqual(a[k], tskit.NULL) + self.verify_simplify_topology(ts, [0, 1, 2]) + + def test_single_offspring_records(self): + # Here we have inserted a single-offspring record + # (for 6 on the left segment): + # + # 1.0 7 + # 0.7 / 6 6 + # / \ / \ + # 0.5 / 5 5 / 5 + # / / \ / \ / / \ + # 0.4 / / 4 / 4 / / 4 + # 0.3 / / / \ / / \ / / / \ + # / / 3 \ / / \ / / 3 \ + # / / \ / / \ / / \ + # 0.0 0 1 2 1 0 2 0 1 2 + # + # (0.0, 0.2), (0.2, 0.8), (0.8, 1.0) + nodes = six.StringIO("""\ + id is_sample time + 0 1 0 + 1 1 0 + 2 1 0 + 3 0 0 # Non sample leaf + 4 0 0.4 + 5 0 0.5 + 6 0 0.7 + 7 0 1.0 + """) + edges = six.StringIO("""\ + left right parent child + 0.0 0.2 4 2,3 + 0.2 0.8 4 0,2 + 0.8 1.0 4 2,3 + 0.0 1.0 5 1,4 + 0.8 1.0 6 0,5 + 0.0 0.2 6 5 + 0.0 0.2 7 0,6 + """) + ts = tskit.load_text(nodes, edges, 
strict=False) + true_trees = [ + {0: 7, 1: 5, 2: 4, 3: 4, 4: 5, 5: 6, 6: 7, 7: -1}, + {0: 4, 1: 5, 2: 4, 3: -1, 4: 5, 5: -1, 6: -1, 7: -1}, + {0: 6, 1: 5, 2: 4, 3: 4, 4: 5, 5: 6, 6: -1, 7: -1}] + tree_dicts = [t.parent_dict for t in ts.trees()] + self.assertEqual(ts.sample_size, 3) + self.assertEqual(ts.num_trees, 3) + self.assertEqual(ts.num_nodes, 8) + # check topologies agree: + for a, t in zip(true_trees, tree_dicts): + for k in a.keys(): + if k in t.keys(): + self.assertEqual(t[k], a[k]) + else: + self.assertEqual(a[k], tskit.NULL) + self.verify_simplify_topology(ts, [0, 1, 2]) + + def test_many_single_offspring(self): + # a more complex test with single offspring + # With `(i,j,x)->k` denoting that individual `k` inherits from `i` on `[0,x)` + # and from `j` on `[x,1)`: + # 1. Begin with an individual `3` (and another anonymous one) at `t=0`. + # 2. `(3,?,1.0)->4` and `(3,?,1.0)->5` at `t=1` + # 3. `(4,3,0.9)->6` and `(3,5,0.1)->7` and then `3` dies at `t=2` + # 4. `(6,7,0.7)->8` at `t=3` + # 5. `(8,6,0.8)->9` and `(7,8,0.2)->10` at `t=4`. + # 6. `(3,9,0.6)->0` and `(9,10,0.5)->1` and `(10,4,0.4)->2` at `t=5`. + # 7. We sample `0`, `1`, and `2`. + # Here are the trees: + # t | | | | + # + # 0 --3-- | --3-- | --3-- | --3-- | --3-- + # / | \ | / | \ | / \ | / \ | / \ + # 1 4 | 5 | 4 * 5 | 4 5 | 4 5 | 4 5 + # |\ / \ /| | |\ \ | |\ / | |\ / | |\ /| + # 2 | 6 7 | | | 6 7 | | 6 7 | | 6 7 | | 6 7 | + # | |\ /| | | | \ * | | \ | | | * | | * | ... + # 3 | | 8 | | | | 8 | | * 8 * | | 8 | | 8 | + # | |/ \| | | | / | | | / | | | * * | | / \ | + # 4 | 9 10 | | | 9 10 | | 9 10 | | 9 10 | | 9 10 | + # |/ \ / \| | | \ * | | \ \ | | \ * | | \ | + # 5 0 1 2 | 0 1 2 | 0 1 2 | 0 1 2 | 0 1 2 + # + # | 0.0 - 0.1 | 0.1 - 0.2 | 0.2 - 0.4 | 0.4 - 0.5 + # ... 
continued: + # t | | | | + # + # 0 --3-- | --3-- | --3-- | --3-- | --3-- + # / \ | / \ | / \ | / \ | / | \ + # 1 4 5 | 4 5 | 4 5 | 4 5 | 4 | 5 + # |\ /| | \ /| | \ /| | \ /| | / /| + # 2 | 6 7 | | 6 7 | | 6 7 | | 6 7 | | 6 7 | + # | \ | | \ | | / | | | / | | | / | + # 3 ... | 8 | | 8 | | 8 | | | 8 | | | 8 | + # | / \ | | / \ | | / \ | | | \ | | | \ | + # 4 | 9 10 | | 9 10 | | 9 10 | | 9 10 | | 9 10 | + # | / | | / / | | / / | | / / | | / / | + # 5 0 1 2 | 0 1 2 | 0 1 2 | 0 1 2 | 0 1 2 + # + # 0.5 - 0.6 | 0.6 - 0.7 | 0.7 - 0.8 | 0.8 - 0.9 | 0.9 - 1.0 + + true_trees = [ + {0: 4, 1: 9, 2: 10, 3: -1, 4: 3, 5: 3, 6: 4, 7: 3, 8: 6, 9: 8, 10: 7}, + {0: 4, 1: 9, 2: 10, 3: -1, 4: 3, 5: 3, 6: 4, 7: 5, 8: 6, 9: 8, 10: 7}, + {0: 4, 1: 9, 2: 10, 3: -1, 4: 3, 5: 3, 6: 4, 7: 5, 8: 6, 9: 8, 10: 8}, + {0: 4, 1: 9, 2: 5, 3: -1, 4: 3, 5: 3, 6: 4, 7: 5, 8: 6, 9: 8, 10: 8}, + {0: 4, 1: 10, 2: 5, 3: -1, 4: 3, 5: 3, 6: 4, 7: 5, 8: 6, 9: 8, 10: 8}, + {0: 9, 1: 10, 2: 5, 3: -1, 4: 3, 5: 3, 6: 4, 7: 5, 8: 6, 9: 8, 10: 8}, + {0: 9, 1: 10, 2: 5, 3: -1, 4: 3, 5: 3, 6: 4, 7: 5, 8: 7, 9: 8, 10: 8}, + {0: 9, 1: 10, 2: 5, 3: -1, 4: 3, 5: 3, 6: 4, 7: 5, 8: 7, 9: 6, 10: 8}, + {0: 9, 1: 10, 2: 5, 3: -1, 4: 3, 5: 3, 6: 3, 7: 5, 8: 7, 9: 6, 10: 8} + ] + true_haplotypes = ['0100', '0001', '1110'] + nodes = six.StringIO("""\ + id is_sample time + 0 1 0 + 1 1 0 + 2 1 0 + 3 0 5 + 4 0 4 + 5 0 4 + 6 0 3 + 7 0 3 + 8 0 2 + 9 0 1 + 10 0 1 + """) + edges = six.StringIO("""\ + left right parent child + 0.5 1.0 10 1 + 0.0 0.4 10 2 + 0.6 1.0 9 0 + 0.0 0.5 9 1 + 0.8 1.0 8 10 + 0.2 0.8 8 9,10 + 0.0 0.2 8 9 + 0.7 1.0 7 8 + 0.0 0.2 7 10 + 0.8 1.0 6 9 + 0.0 0.7 6 8 + 0.4 1.0 5 2,7 + 0.1 0.4 5 7 + 0.6 0.9 4 6 + 0.0 0.6 4 0,6 + 0.9 1.0 3 4,5,6 + 0.1 0.9 3 4,5 + 0.0 0.1 3 4,5,7 + """) + sites = six.StringIO("""\ + position ancestral_state + 0.05 0 + 0.15 0 + 0.25 0 + 0.4 0 + """) + mutations = six.StringIO("""\ + site node derived_state parent + 0 7 1 -1 + 0 10 0 0 + 0 2 1 1 + 1 0 1 -1 + 1 10 1 -1 + 2 8 1 -1 + 2 9 0 5 + 2 
10 0 5 + 2 2 1 7 + 3 8 1 -1 + """) + ts = tskit.load_text(nodes, edges, sites, mutations, strict=False) + tree_dicts = [t.parent_dict for t in ts.trees()] + self.assertEqual(ts.sample_size, 3) + self.assertEqual(ts.num_trees, len(true_trees)) + self.assertEqual(ts.num_nodes, 11) + self.assertEqual(len(list(ts.edge_diffs())), ts.num_trees) + # check topologies agree: + for a, t in zip(true_trees, tree_dicts): + for k in a.keys(): + if k in t.keys(): + self.assertEqual(t[k], a[k]) + else: + self.assertEqual(a[k], tskit.NULL) + for j, x in enumerate(ts.haplotypes()): + self.assertEqual(x, true_haplotypes[j]) + self.verify_simplify_topology(ts, [0, 1, 2], haplotypes=True) + self.verify_simplify_topology(ts, [1, 0, 2], haplotypes=True) + self.verify_simplify_topology(ts, [0, 1], haplotypes=False) + self.verify_simplify_topology(ts, [1, 2], haplotypes=False) + self.verify_simplify_topology(ts, [2, 0], haplotypes=False) + + def test_tricky_switches(self): + # suppose the topology has: + # left right parent child + # 0.0 0.5 6 0,1 + # 0.5 1.0 6 4,5 + # 0.0 0.4 7 2,3 + # + # -------------------------- + # + # 12 . 12 . 12 . + # / \ . / \ . / \ . + # 11 \ . / \ . / \ . + # / \ \ . / 10 . / 10 . + # / \ \ . / / \ . / / \ . + # 6 7 8 . 6 9 8 . 6 9 8 . + # / \ / \ /\ . / \ / \ /\ . / \ / \ /\ . + # 0 1 2 3 4 5 . 0 1 2 3 4 5 . 4 5 2 3 0 1 . + # . . . 
+ # 0.0 0.4 0.5 1.0 + nodes = six.StringIO("""\ + id is_sample time + 0 1 0 + 1 1 0 + 2 1 0 + 3 1 0 + 4 1 0 + 5 1 0 + 6 0 1 + 7 0 1 + 8 0 1 + 9 0 1 + 10 0 2 + 11 0 3 + 12 0 4 + """) + edges = six.StringIO("""\ + left right parent child + 0.0 0.5 6 0 + 0.0 0.5 6 1 + 0.5 1.0 6 4 + 0.5 1.0 6 5 + 0.0 0.4 7 2,3 + 0.5 1.0 8 0 + 0.5 1.0 8 1 + 0.0 0.5 8 4 + 0.0 0.5 8 5 + 0.4 1.0 9 2,3 + 0.4 1.0 10 8,9 + 0.0 0.4 11 6,7 + 0.4 1.0 12 6 + 0.0 0.4 12 8 + 0.4 1.0 12 10 + 0.0 0.4 12 11 + """) + true_trees = [ + {0: 6, 1: 6, 2: 7, 3: 7, 4: 8, 5: 8, 6: 11, + 7: 11, 8: 12, 9: -1, 10: -1, 11: 12, 12: -1}, + {0: 6, 1: 6, 2: 9, 3: 9, 4: 8, 5: 8, 6: 12, + 7: -1, 8: 10, 9: 10, 10: 12, 11: -1, 12: -1}, + {0: 8, 1: 8, 2: 9, 3: 9, 4: 6, 5: 6, 6: 12, + 7: -1, 8: 10, 9: 10, 10: 12, 11: -1, 12: -1} + ] + ts = tskit.load_text(nodes, edges, strict=False) + tree_dicts = [t.parent_dict for t in ts.trees()] + self.assertEqual(ts.sample_size, 6) + self.assertEqual(ts.num_trees, len(true_trees)) + self.assertEqual(ts.num_nodes, 13) + self.assertEqual(len(list(ts.edge_diffs())), ts.num_trees) + # check topologies agree: + for a, t in zip(true_trees, tree_dicts): + for k in a.keys(): + if k in t.keys(): + self.assertEqual(t[k], a[k]) + else: + self.assertEqual(a[k], tskit.NULL) + self.verify_simplify_topology(ts, [0, 2]) + self.verify_simplify_topology(ts, [0, 4]) + self.verify_simplify_topology(ts, [2, 4]) + + def test_tricky_simplify(self): + # Continue as above but invoke simplfy: + # + # 12 . 12 . + # / \ . / \ . + # 11 \ . 11 \ . + # / \ \ . / \ \ . + # 13 \ \ . / 15 \ . + # / \ \ \ . / / \ \ . + # 6 14 7 8 . 6 14 7 8 . + # / \ / \ /\ . / \ / \ /\ . + # 0 1 2 3 4 5 . 0 1 2 3 4 5 . + # . . + # 0.0 0.1 0.4 + # + # . 12 . 12 . + # . / \ . / \ . + # . / \ . / \ . + # . / 10 . / 10 . + # . / / \ . / / \ . + # . 6 9 8 . 6 9 8 . + # . / \ / \ /\ . / \ / \ /\ . + # . 0 1 2 3 4 5 . 4 5 2 3 0 1 . + # . . . 
+ # 0.4 0.5 1.0 + nodes = six.StringIO("""\ + id is_sample time + 0 1 0 + 1 1 0 + 2 1 0 + 3 1 0 + 4 1 0 + 5 1 0 + 6 0 1 + 7 0 1 + 8 0 1 + 9 0 1 + 10 0 2 + 11 0 3 + 12 0 4 + 13 0 2 + 14 0 1 + 15 0 2 + """) + edges = six.StringIO("""\ + left right parent child + 0.0 0.5 6 0,1 + 0.5 1.0 6 4,5 + 0.0 0.4 7 2,3 + 0.0 0.5 8 4,5 + 0.5 1.0 8 0,1 + 0.4 1.0 9 2,3 + 0.4 1.0 10 8,9 + 0.0 0.1 13 6,14 + 0.1 0.4 15 7,14 + 0.0 0.1 11 7,13 + 0.1 0.4 11 6,15 + 0.0 0.4 12 8,11 + 0.4 1.0 12 6,10 + """) + true_trees = [ + {0: 6, 1: 6, 2: 7, 3: 7, 4: 8, 5: 8, 6: 11, + 7: 11, 8: 12, 9: -1, 10: -1, 11: 12, 12: -1}, + {0: 6, 1: 6, 2: 9, 3: 9, 4: 8, 5: 8, 6: 12, + 7: -1, 8: 10, 9: 10, 10: 12, 11: -1, 12: -1}, + {0: 8, 1: 8, 2: 9, 3: 9, 4: 6, 5: 6, 6: 12, + 7: -1, 8: 10, 9: 10, 10: 12, 11: -1, 12: -1} + ] + big_ts = tskit.load_text(nodes, edges, strict=False) + self.assertEqual(big_ts.num_trees, 1 + len(true_trees)) + self.assertEqual(big_ts.num_nodes, 16) + ts, node_map = big_ts.simplify(map_nodes=True) + self.assertEqual(list(node_map[:6]), list(range(6))) + self.assertEqual(ts.sample_size, 6) + self.assertEqual(ts.num_nodes, 13) + + def test_ancestral_samples(self): + # Check that specifying samples to be not at time 0.0 works. 
+ # + # 1.0 7 + # 0.7 / \ 8 6 + # / \ / \ / \ + # 0.5 / 5 / 5 / 5 + # / / \ / / \ / / \ + # 0.4 / / 4 / / 4 / / 4 + # / / / \ / / / \ / / / \ + # 0.2 / / 3 \ 3 / / \ / / 3 \ + # / / * \ * / / \ / / * \ + # 0.0 0 1 2 1 0 2 0 1 2 + # * * * * * * + # (0.0, 0.2), (0.2, 0.8), (0.8, 1.0) + # + # Simplified, keeping [1,2,3] + # + # 1.0 + # 0.7 5 + # / \ + # 0.5 4 / 4 4 + # / \ / / \ / \ + # 0.4 / 3 / / 3 / 3 + # / / \ / / \ / / \ + # 0.2 / 2 \ 2 / \ / 2 \ + # / * \ * / \ / * \ + # 0.0 0 1 0 1 0 1 + # * * * * * * + # (0.0, 0.2), (0.2, 0.8), (0.8, 1.0) + + nodes = six.StringIO("""\ + id is_sample time + 0 0 0 + 1 1 0 + 2 1 0 + 3 1 0.2 + 4 0 0.4 + 5 0 0.5 + 6 0 0.7 + 7 0 1.0 + 8 0 0.8 + """) + edges = six.StringIO("""\ + left right parent child + 0.0 0.2 4 2,3 + 0.2 0.8 4 0,2 + 0.8 1.0 4 2,3 + 0.0 1.0 5 1,4 + 0.8 1.0 6 0,5 + 0.2 0.8 8 3,5 + 0.0 0.2 7 0,5 + """) + first_ts = tskit.load_text(nodes=nodes, edges=edges, strict=False) + ts, node_map = first_ts.simplify(map_nodes=True) + true_trees = [ + {0: 7, 1: 5, 2: 4, 3: 4, 4: 5, 5: 7, 6: -1, 7: -1}, + {0: 4, 1: 5, 2: 4, 3: 8, 4: 5, 5: 8, 6: -1, 7: -1}, + {0: 6, 1: 5, 2: 4, 3: 4, 4: 5, 5: 6, 6: -1, 7: -1}] + # maps [1,2,3] -> [0,1,2] + self.assertEqual(node_map[1], 0) + self.assertEqual(node_map[2], 1) + self.assertEqual(node_map[3], 2) + true_simplified_trees = [ + {0: 4, 1: 3, 2: 3, 3: 4}, + {0: 4, 1: 4, 2: 5, 4: 5}, + {0: 4, 1: 3, 2: 3, 3: 4}] + self.assertEqual(first_ts.sample_size, 3) + self.assertEqual(ts.sample_size, 3) + self.assertEqual(first_ts.num_trees, 3) + self.assertEqual(ts.num_trees, 3) + self.assertEqual(first_ts.num_nodes, 9) + self.assertEqual(ts.num_nodes, 6) + self.assertEqual(first_ts.node(3).time, 0.2) + self.assertEqual(ts.node(2).time, 0.2) + # check topologies agree: + tree_dicts = [t.parent_dict for t in first_ts.trees()] + for a, t in zip(true_trees, tree_dicts): + for k in a.keys(): + if k in t.keys(): + self.assertEqual(t[k], a[k]) + else: + self.assertEqual(a[k], tskit.NULL) + 
tree_simplified_dicts = [t.parent_dict for t in ts.trees()] + for a, t in zip(true_simplified_trees, tree_simplified_dicts): + for k in a.keys(): + if k in t.keys(): + self.assertEqual(t[k], a[k]) + else: + self.assertEqual(a[k], tskit.NULL) + # check .simplify() works here + self.verify_simplify_topology(first_ts, [1, 2, 3]) + + def test_all_ancestral_samples(self): + # Check that specifying samples all to be not at time 0.0 works. + # + # 1.0 7 + # 0.7 / \ 8 6 + # / \ / \ / \ + # 0.5 / 5 / 5 / 5 + # / / \ / / \ / / \ + # 0.4 / / 4 / / 4 / / 4 + # / / / \ / / / \ / / / \ + # 0.2 / / 3 \ 3 / / \ / / 3 \ + # / 1 * 2 * 1 / 2 / 1 * 2 + # 0.0 0 * * * 0 * 0 * * + # + # (0.0, 0.2), (0.2, 0.8), (0.8, 1.0) + + nodes = six.StringIO("""\ + id is_sample time + 0 0 0 + 1 1 0.1 + 2 1 0.1 + 3 1 0.2 + 4 0 0.4 + 5 0 0.5 + 6 0 0.7 + 7 0 1.0 + 8 0 0.8 + """) + edges = six.StringIO("""\ + left right parent child + 0.0 0.2 4 2,3 + 0.2 0.8 4 0,2 + 0.8 1.0 4 2,3 + 0.0 1.0 5 1,4 + 0.8 1.0 6 0,5 + 0.2 0.8 8 3,5 + 0.0 0.2 7 0,5 + """) + ts = tskit.load_text(nodes=nodes, edges=edges, strict=False) + true_trees = [ + {0: 7, 1: 5, 2: 4, 3: 4, 4: 5, 5: 7, 6: -1, 7: -1}, + {0: 4, 1: 5, 2: 4, 3: 8, 4: 5, 5: 8, 6: -1, 7: -1}, + {0: 6, 1: 5, 2: 4, 3: 4, 4: 5, 5: 6, 6: -1, 7: -1}] + self.assertEqual(ts.sample_size, 3) + self.assertEqual(ts.num_trees, 3) + self.assertEqual(ts.num_nodes, 9) + self.assertEqual(ts.node(0).time, 0.0) + self.assertEqual(ts.node(1).time, 0.1) + self.assertEqual(ts.node(2).time, 0.1) + self.assertEqual(ts.node(3).time, 0.2) + # check topologies agree: + tree_dicts = [t.parent_dict for t in ts.trees()] + for a, t in zip(true_trees, tree_dicts): + for k in a.keys(): + if k in t.keys(): + self.assertEqual(t[k], a[k]) + else: + self.assertEqual(a[k], tskit.NULL) + # check .simplify() works here + self.verify_simplify_topology(ts, [1, 2, 3]) + + def test_internal_sampled_node(self): + # 1.0 7 + # 0.7 / \ 8 6 + # / \ / \ / \ + # 0.5 / 5 / 5 / 5 + # / /*\ / /*\ / /*\ + # 0.4 / / 
4 / / 4 / / 4 + # / / / \ / / / \ / / / \ + # 0.2 / / 3 \ 3 / / \ / / 3 \ + # / 1 * 2 * 1 / 2 / 1 * 2 + # 0.0 0 * * * 0 * 0 * * + # + # (0.0, 0.2), (0.2, 0.8), (0.8, 1.0) + nodes = six.StringIO("""\ + id is_sample time + 0 0 0 + 1 1 0.1 + 2 1 0.1 + 3 1 0.2 + 4 0 0.4 + 5 1 0.5 + 6 0 0.7 + 7 0 1.0 + 8 0 0.8 + """) + edges = six.StringIO("""\ + left right parent child + 0.0 0.2 4 2,3 + 0.2 0.8 4 0,2 + 0.8 1.0 4 2,3 + 0.0 1.0 5 1,4 + 0.8 1.0 6 0,5 + 0.2 0.8 8 3,5 + 0.0 0.2 7 0,5 + """) + ts = tskit.load_text(nodes=nodes, edges=edges, strict=False) + true_trees = [ + {0: 7, 1: 5, 2: 4, 3: 4, 4: 5, 5: 7, 6: -1, 7: -1}, + {0: 4, 1: 5, 2: 4, 3: 8, 4: 5, 5: 8, 6: -1, 7: -1}, + {0: 6, 1: 5, 2: 4, 3: 4, 4: 5, 5: 6, 6: -1, 7: -1}] + self.assertEqual(ts.sample_size, 4) + self.assertEqual(ts.num_trees, 3) + self.assertEqual(ts.num_nodes, 9) + self.assertEqual(ts.node(0).time, 0.0) + self.assertEqual(ts.node(1).time, 0.1) + self.assertEqual(ts.node(2).time, 0.1) + self.assertEqual(ts.node(3).time, 0.2) + # check topologies agree: + tree_dicts = [t.parent_dict for t in ts.trees()] + for a, t in zip(true_trees, tree_dicts): + for k in a.keys(): + if k in t.keys(): + self.assertEqual(t[k], a[k]) + else: + self.assertEqual(a[k], tskit.NULL) + # check .simplify() works here + self.verify_simplify_topology(ts, [1, 2, 3]) + self.check_num_samples( + ts, + [(0, 5, 4), (0, 2, 1), (0, 7, 4), (0, 4, 2), + (1, 4, 1), (1, 5, 3), (1, 8, 4), (1, 0, 0), + (2, 5, 4), (2, 1, 1)]) + self.check_num_tracked_samples( + ts, [1, 2, 5], + [(0, 5, 3), (0, 2, 1), (0, 7, 3), (0, 4, 1), + (1, 4, 1), (1, 5, 3), (1, 8, 3), (1, 0, 0), + (2, 5, 3), (2, 1, 1)]) + self.check_sample_iterator( + ts, + [(0, 0, []), (0, 5, [5, 1, 2, 3]), (0, 4, [2, 3]), + (1, 5, [5, 1, 2]), (2, 4, [2, 3])]) + # pedantically check the SparseTree methods on the second tree + tst = ts.trees() + t = next(tst) + t = next(tst) + self.assertEqual(t.branch_length(1), 0.4) + self.assertEqual(t.is_internal(0), False) + 
self.assertEqual(t.is_leaf(0), True) + self.assertEqual(t.is_sample(0), False) + self.assertEqual(t.is_internal(1), False) + self.assertEqual(t.is_leaf(1), True) + self.assertEqual(t.is_sample(1), True) + self.assertEqual(t.is_internal(5), True) + self.assertEqual(t.is_leaf(5), False) + self.assertEqual(t.is_sample(5), True) + self.assertEqual(t.is_internal(4), True) + self.assertEqual(t.is_leaf(4), False) + self.assertEqual(t.is_sample(4), False) + self.assertEqual(t.root, 8) + self.assertEqual(t.mrca(0, 1), 5) + self.assertEqual(t.sample_size, 4) + + +class TestBadTrees(unittest.TestCase): + """ + Tests for bad tree sequence topologies that can only be detected when we + try to create trees. + """ + def test_simplest_contradictory_children(self): + nodes = six.StringIO("""\ + id is_sample time + 0 1 0 + 1 1 0 + 2 0 1 + 3 0 2 + """) + edges = six.StringIO("""\ + left right parent child + 0.0 1.0 2 0 + 0.0 1.0 3 0 + """) + ts = tskit.load_text(nodes=nodes, edges=edges, strict=False) + self.assertRaises(_tskit.LibraryError, list, ts.trees()) + + def test_partial_overlap_contradictory_children(self): + nodes = six.StringIO("""\ + id is_sample time + 0 1 0 + 1 1 0 + 2 0 1 + 3 0 2 + """) + edges = six.StringIO("""\ + left right parent child + 0.0 1.0 2 0,1 + 0.5 1.0 3 0 + """) + ts = tskit.load_text(nodes=nodes, edges=edges, strict=False) + self.assertRaises(_tskit.LibraryError, list, ts.trees()) + + +class TestSimplify(unittest.TestCase): + """ + Tests that the implementations of simplify() do what they are supposed to. 
+ """ + random_seed = 23 + # + # 8 + # / \ + # / \ + # / \ + # 7 \ + # / \ 6 + # / 5 / \ + # / / \ / \ + # 4 0 1 2 3 + small_tree_ex_nodes = """\ + id is_sample population time + 0 1 0 0.00000000000000 + 1 1 0 0.00000000000000 + 2 1 0 0.00000000000000 + 3 1 0 0.00000000000000 + 4 1 0 0.00000000000000 + 5 0 0 0.14567111023387 + 6 0 0 0.21385545626353 + 7 0 0 0.43508024345063 + 8 0 0 1.60156352971203 + """ + small_tree_ex_edges = """\ + id left right parent child + 0 0.00000000 1.00000000 5 0,1 + 1 0.00000000 1.00000000 6 2,3 + 2 0.00000000 1.00000000 7 4,5 + 3 0.00000000 1.00000000 8 6,7 + """ + + def do_simplify( + self, ts, samples=None, compare_lib=True, filter_sites=True, + filter_populations=True, filter_individuals=True): + """ + Runs the Python test implementation of simplify. + """ + if samples is None: + samples = ts.samples() + s = tests.Simplifier( + ts, samples, filter_sites=filter_sites, + filter_populations=filter_populations, filter_individuals=filter_individuals) + new_ts, node_map = s.simplify() + if compare_lib: + sts, lib_node_map1 = ts.simplify( + samples, + filter_sites=filter_sites, + filter_individuals=filter_individuals, + filter_populations=filter_populations, + map_nodes=True) + lib_tables1 = sts.dump_tables() + + lib_tables2 = ts.dump_tables() + lib_node_map2 = lib_tables2.simplify( + samples, + filter_sites=filter_sites, + filter_individuals=filter_individuals, + filter_populations=filter_populations) + + py_tables = new_ts.dump_tables() + for lib_tables, lib_node_map in [ + (lib_tables1, lib_node_map1), (lib_tables2, lib_node_map2)]: + # print("lib = ") + # print(lib_tables.nodes) + # print(lib_tables.edges) + # print("py = ") + # print(py_tables.nodes) + # print(py_tables.edges) + + self.assertEqual(lib_tables.nodes, py_tables.nodes) + self.assertEqual(lib_tables.edges, py_tables.edges) + self.assertEqual(lib_tables.migrations, py_tables.migrations) + self.assertEqual(lib_tables.sites, py_tables.sites) + 
self.assertEqual(lib_tables.mutations, py_tables.mutations) + self.assertEqual(lib_tables.individuals, py_tables.individuals) + self.assertEqual(lib_tables.populations, py_tables.populations) + self.assertTrue(all(node_map == lib_node_map)) + return new_ts, node_map + + def verify_single_childified(self, ts): + """ + Modify the specified tree sequence so that it has lots of unary + nodes. Run simplify and verify we get the same tree sequence back. + """ + ts_single = tsutil.single_childify(ts) + tss, node_map = self.do_simplify(ts_single) + # All original nodes should still be present. + for u in range(ts.num_samples): + self.assertEqual(u, node_map[u]) + # All introduced nodes should be mapped to null. + for u in range(ts.num_samples, ts_single.num_samples): + self.assertEqual(node_map[u], tskit.NULL) + t1 = ts.dump_tables() + t2 = tss.dump_tables() + self.assertEqual(t1.nodes, t2.nodes) + self.assertEqual(t1.edges, t2.edges) + self.assertEqual(t1.sites, t2.sites) + self.assertEqual(t1.mutations, t2.mutations) + + def verify_multiroot_internal_samples(self, ts): + ts_multiroot = tsutil.decapitate(ts, ts.num_edges // 2) + ts1 = tsutil.jiggle_samples(ts_multiroot) + ts2, node_map = self.do_simplify(ts1) + self.assertGreaterEqual(ts1.num_trees, ts2.num_trees) + trees2 = ts2.trees() + t2 = next(trees2) + for t1 in ts1.trees(): + self.assertTrue(t2.interval[0] <= t1.interval[0]) + self.assertTrue(t2.interval[1] >= t1.interval[1]) + pairs = itertools.combinations(ts1.samples(), 2) + for pair in pairs: + mapped_pair = [node_map[u] for u in pair] + mrca1 = t1.get_mrca(*pair) + mrca2 = t2.get_mrca(*mapped_pair) + if mrca1 == tskit.NULL: + assert mrca2 == tskit.NULL + else: + self.assertEqual(node_map[mrca1], mrca2) + if t2.interval[1] == t1.interval[1]: + t2 = next(trees2, None) + + def test_single_tree(self): + ts = msprime.simulate(10, random_seed=self.random_seed) + self.verify_single_childified(ts) + self.verify_multiroot_internal_samples(ts) + + def 
test_single_tree_mutations(self): + ts = msprime.simulate(10, mutation_rate=1, random_seed=self.random_seed) + self.assertGreater(ts.num_sites, 1) + self.do_simplify(ts) + self.verify_single_childified(ts) + + def test_many_trees_mutations(self): + ts = msprime.simulate( + 10, recombination_rate=1, mutation_rate=10, random_seed=self.random_seed) + self.assertGreater(ts.num_trees, 2) + self.assertGreater(ts.num_sites, 2) + self.do_simplify(ts) + self.verify_single_childified(ts) + + def test_many_trees(self): + ts = msprime.simulate(5, recombination_rate=4, random_seed=self.random_seed) + self.assertGreater(ts.num_trees, 2) + self.verify_single_childified(ts) + self.verify_multiroot_internal_samples(ts) + + def test_small_tree_internal_samples(self): + ts = tskit.load_text( + nodes=six.StringIO(self.small_tree_ex_nodes), + edges=six.StringIO(self.small_tree_ex_edges), strict=False) + tables = ts.dump_tables() + nodes = tables.nodes + flags = nodes.flags + # The parent of samples 0 and 1 is 5. Change this to an internal sample + # and set 0 and 1 to be unsampled. + flags[0] = 0 + flags[1] = 0 + flags[5] = tskit.NODE_IS_SAMPLE + tables.nodes.set_columns(flags=flags, time=nodes.time) + ts = tables.tree_sequence() + self.assertEqual(ts.sample_size, 4) + tss, node_map = self.do_simplify(ts, [3, 5]) + self.assertEqual(node_map[3], 0) + self.assertEqual(node_map[5], 1) + self.assertEqual(tss.num_nodes, 3) + self.assertEqual(tss.num_edges, 2) + + def test_small_tree_linear_samples(self): + ts = tskit.load_text( + nodes=six.StringIO(self.small_tree_ex_nodes), + edges=six.StringIO(self.small_tree_ex_edges), strict=False) + tables = ts.dump_tables() + nodes = tables.nodes + flags = nodes.flags + # 7 is above 0. 
These are the only two samples + flags[:] = 0 + flags[0] = tskit.NODE_IS_SAMPLE + flags[7] = tskit.NODE_IS_SAMPLE + nodes.set_columns(flags=flags, time=nodes.time) + ts = tables.tree_sequence() + self.assertEqual(ts.sample_size, 2) + tss, node_map = self.do_simplify(ts, [0, 7]) + self.assertEqual(node_map[0], 0) + self.assertEqual(node_map[7], 1) + self.assertEqual(tss.num_nodes, 2) + self.assertEqual(tss.num_edges, 1) + t = next(tss.trees()) + self.assertEqual(t.parent_dict, {0: 1}) + + def test_small_tree_internal_and_external_samples(self): + ts = tskit.load_text( + nodes=six.StringIO(self.small_tree_ex_nodes), + edges=six.StringIO(self.small_tree_ex_edges), strict=False) + tables = ts.dump_tables() + nodes = tables.nodes + flags = nodes.flags + # 7 is above 0 and 1. + flags[:] = 0 + flags[0] = tskit.NODE_IS_SAMPLE + flags[1] = tskit.NODE_IS_SAMPLE + flags[7] = tskit.NODE_IS_SAMPLE + nodes.set_columns(flags=flags, time=nodes.time) + ts = tables.tree_sequence() + self.assertEqual(ts.sample_size, 3) + tss, node_map = self.do_simplify(ts, [0, 1, 7]) + self.assertEqual(node_map[0], 0) + self.assertEqual(node_map[1], 1) + self.assertEqual(node_map[7], 2) + self.assertEqual(tss.num_nodes, 4) + self.assertEqual(tss.num_edges, 3) + t = next(tss.trees()) + self.assertEqual(t.parent_dict, {0: 3, 1: 3, 3: 2}) + + def test_small_tree_mutations(self): + ts = tskit.load_text( + nodes=six.StringIO(self.small_tree_ex_nodes), + edges=six.StringIO(self.small_tree_ex_edges), strict=False) + tables = ts.dump_tables() + # Add some simple mutations here above the nodes we're keeping. 
+ tables.sites.add_row(position=0.25, ancestral_state="0") + tables.sites.add_row(position=0.5, ancestral_state="0") + tables.sites.add_row(position=0.75, ancestral_state="0") + tables.sites.add_row(position=0.8, ancestral_state="0") + tables.mutations.add_row(site=0, node=0, derived_state="1") + tables.mutations.add_row(site=1, node=2, derived_state="1") + tables.mutations.add_row(site=2, node=7, derived_state="1") + tables.mutations.add_row(site=3, node=0, derived_state="1") + ts = tables.tree_sequence() + self.assertEqual(ts.num_sites, 4) + self.assertEqual(ts.num_mutations, 4) + tss = self.do_simplify(ts, [0, 2])[0] + self.assertEqual(tss.sample_size, 2) + self.assertEqual(tss.num_mutations, 4) + self.assertEqual(list(tss.haplotypes()), ["1011", "0100"]) + + def test_small_tree_filter_zero_mutations(self): + ts = tskit.load_text( + nodes=six.StringIO(self.small_tree_ex_nodes), + edges=six.StringIO(self.small_tree_ex_edges), strict=False) + ts = tsutil.insert_branch_sites(ts) + self.assertEqual(ts.num_sites, 8) + self.assertEqual(ts.num_mutations, 8) + tss, _ = self.do_simplify(ts, [4, 0, 1], filter_sites=True) + self.assertEqual(tss.num_sites, 5) + self.assertEqual(tss.num_mutations, 5) + tss, _ = self.do_simplify(ts, [4, 0, 1], filter_sites=False) + self.assertEqual(tss.num_sites, 8) + self.assertEqual(tss.num_mutations, 5) + + def test_small_tree_fixed_sites(self): + ts = tskit.load_text( + nodes=six.StringIO(self.small_tree_ex_nodes), + edges=six.StringIO(self.small_tree_ex_edges), strict=False) + tables = ts.dump_tables() + # Add some simple mutations that will be fixed after simplify + tables.sites.add_row(position=0.25, ancestral_state="0") + tables.sites.add_row(position=0.5, ancestral_state="0") + tables.sites.add_row(position=0.75, ancestral_state="0") + tables.mutations.add_row(site=0, node=2, derived_state="1") + tables.mutations.add_row(site=1, node=3, derived_state="1") + tables.mutations.add_row(site=2, node=6, derived_state="1") + ts = 
tables.tree_sequence() + self.assertEqual(ts.num_sites, 3) + self.assertEqual(ts.num_mutations, 3) + tss, _ = self.do_simplify(ts, [4, 1]) + self.assertEqual(tss.sample_size, 2) + self.assertEqual(tss.num_mutations, 0) + self.assertEqual(list(tss.haplotypes()), ["", ""]) + + def test_small_tree_mutations_over_root(self): + ts = tskit.load_text( + nodes=six.StringIO(self.small_tree_ex_nodes), + edges=six.StringIO(self.small_tree_ex_edges), strict=False) + tables = ts.dump_tables() + tables.sites.add_row(position=0.25, ancestral_state="0") + tables.mutations.add_row(site=0, node=8, derived_state="1") + ts = tables.tree_sequence() + self.assertEqual(ts.num_sites, 1) + self.assertEqual(ts.num_mutations, 1) + for filt in [True, False]: + tss, _ = self.do_simplify(ts, [0, 1], filter_sites=filt) + self.assertEqual(tss.num_sites, 1) + self.assertEqual(tss.num_mutations, 1) + + def test_small_tree_recurrent_mutations(self): + ts = tskit.load_text( + nodes=six.StringIO(self.small_tree_ex_nodes), + edges=six.StringIO(self.small_tree_ex_edges), strict=False) + tables = ts.dump_tables() + # Add recurrent mutation on the root branches + tables.sites.add_row(position=0.25, ancestral_state="0") + tables.mutations.add_row(site=0, node=6, derived_state="1") + tables.mutations.add_row(site=0, node=7, derived_state="1") + ts = tables.tree_sequence() + self.assertEqual(ts.num_sites, 1) + self.assertEqual(ts.num_mutations, 2) + tss = self.do_simplify(ts, [4, 3])[0] + self.assertEqual(tss.sample_size, 2) + self.assertEqual(tss.num_sites, 1) + self.assertEqual(tss.num_mutations, 2) + self.assertEqual(list(tss.haplotypes()), ["1", "1"]) + + def best_small_tree_back_mutations(self): + ts = tskit.load_text( + nodes=six.StringIO(self.small_tree_ex_nodes), + edges=six.StringIO(self.small_tree_ex_edges), strict=False) + tables = ts.dump_tables() + # Add a chain of mutations + tables.sites.add_row(position=0.25, ancestral_state="0") + tables.mutations.add_row(site=0, node=7, derived_state="1") + 
tables.mutations.add_row(site=0, node=5, derived_state="0") + tables.mutations.add_row(site=0, node=1, derived_state="1") + ts = tables.tree_sequence() + self.assertEqual(ts.num_sites, 1) + self.assertEqual(ts.num_mutations, 3) + self.assertEqual(list(ts.haplotypes()), ["0", "1", "0", "0", "1"]) + # First check if we simplify for all samples and keep original state. + tss = self.do_simplify(ts, [0, 1, 2, 3, 4]) + self.assertEqual(tss.sample_size, 5) + self.assertEqual(tss.num_sites, 1) + self.assertEqual(tss.num_mutations, 3) + self.assertEqual(list(tss.haplotypes()), ["0", "1", "0", "0", "1"]) + + # The ancestral state above 5 should be 0. + tss = self.do_simplify(ts, [0, 1]) + self.assertEqual(tss.sample_size, 2) + self.assertEqual(tss.num_sites, 1) + self.assertEqual(tss.num_mutations, 1) + self.assertEqual(list(tss.haplotypes()), ["0", "1"]) + + # The ancestral state above 7 should be 1. + tss = self.do_simplify(ts, [4, 0, 1]) + self.assertEqual(tss.sample_size, 3) + self.assertEqual(tss.num_sites, 1) + self.assertEqual(tss.num_mutations, 2) + self.assertEqual(list(tss.haplotypes()), ["1", "0", "1"]) + + def test_overlapping_unary_edges(self): + nodes = six.StringIO("""\ + id is_sample time + 0 1 0 + 1 1 0 + 2 0 1 + """) + edges = six.StringIO("""\ + left right parent child + 0 2 2 0 + 1 3 2 1 + """) + ts = tskit.load_text(nodes, edges, strict=False) + self.assertEqual(ts.sample_size, 2) + self.assertEqual(ts.num_trees, 3) + self.assertEqual(ts.sequence_length, 3) + tss, node_map = self.do_simplify(ts) + self.assertEqual(list(node_map), [0, 1, 2]) + trees = [{}, {0: 2, 1: 2}, {}] + for t in tss.trees(): + self.assertEqual(t.parent_dict, trees[t.index]) + + def test_overlapping_unary_edges_internal_samples(self): + nodes = six.StringIO("""\ + id is_sample time + 0 1 0 + 1 1 0 + 2 1 1 + """) + edges = six.StringIO("""\ + left right parent child + 0 2 2 0 + 1 3 2 1 + """) + ts = tskit.load_text(nodes, edges, strict=False) + self.assertEqual(ts.sample_size, 3) + 
self.assertEqual(ts.num_trees, 3) + trees = [{0: 2}, {0: 2, 1: 2}, {1: 2}] + for t in ts.trees(): + self.assertEqual(t.parent_dict, trees[t.index]) + tss, node_map = self.do_simplify(ts) + self.assertEqual(list(node_map), [0, 1, 2]) + + def test_isolated_samples(self): + nodes = six.StringIO("""\ + id is_sample time + 0 1 0 + 1 1 1 + 2 1 2 + """) + edges = six.StringIO("""\ + left right parent child + """) + ts = tskit.load_text(nodes, edges, sequence_length=1, strict=False) + self.assertEqual(ts.num_samples, 3) + self.assertEqual(ts.num_trees, 1) + self.assertEqual(ts.num_nodes, 3) + tss, node_map = self.do_simplify(ts, compare_lib=True) + self.assertEqual(ts.tables.nodes, tss.tables.nodes) + self.assertEqual(ts.tables.edges, tss.tables.edges) + self.assertEqual(list(node_map), [0, 1, 2]) + + def test_internal_samples(self): + nodes = six.StringIO("""\ + id is_sample population time + 0 1 -1 1.00000000000000 + 1 0 -1 1.00000000000000 + 2 1 -1 1.00000000000000 + 3 0 -1 1.31203521181726 + 4 0 -1 2.26776380586006 + 5 1 -1 0.00000000000000 + + """) + edges = six.StringIO("""\ + id left right parent child + 0 0.62185118 1.00000000 1 5 + 1 0.00000000 0.62185118 2 5 + 2 0.00000000 1.00000000 3 0,2 + 3 0.00000000 1.00000000 4 1,3 + """) + + ts = tskit.load_text(nodes, edges, strict=False) + tss, node_map = self.do_simplify(ts, [5, 2, 0], compare_lib=True) + self.assertEqual(node_map[5], 0) + self.assertEqual(node_map[2], 1) + self.assertEqual(node_map[0], 2) + self.assertEqual(node_map[1], -1) + self.assertEqual(node_map[3], 3) + self.assertEqual(node_map[4], 4) + self.assertEqual(tss.sample_size, 3) + self.assertEqual(tss.num_trees, 2) + trees = [{0: 1, 1: 3, 2: 3}, {0: 4, 1: 3, 2: 3, 3: 4}] + for t in tss.trees(): + self.assertEqual(t.parent_dict, trees[t.index]) + + def test_many_mutations_over_single_sample_ancestral_state(self): + nodes = six.StringIO("""\ + id is_sample time + 0 1 0 + 1 0 1 + """) + edges = six.StringIO("""\ + left right parent child + 0 1 1 0 + 
""") + sites = six.StringIO("""\ + position ancestral_state + 0 0 + """) + mutations = six.StringIO("""\ + site node derived_state parent + 0 0 1 -1 + 0 0 0 0 + """) + ts = tskit.load_text( + nodes, edges, sites=sites, mutations=mutations, strict=False) + self.assertEqual(ts.sample_size, 1) + self.assertEqual(ts.num_trees, 1) + self.assertEqual(ts.num_sites, 1) + self.assertEqual(ts.num_mutations, 2) + tss, node_map = self.do_simplify(ts) + self.assertEqual(tss.num_sites, 1) + self.assertEqual(tss.num_mutations, 2) + self.assertEqual(list(tss.haplotypes()), ["0"]) + + def test_many_mutations_over_single_sample_derived_state(self): + nodes = six.StringIO("""\ + id is_sample time + 0 1 0 + 1 0 1 + """) + edges = six.StringIO("""\ + left right parent child + 0 1 1 0 + """) + sites = six.StringIO("""\ + position ancestral_state + 0 0 + """) + mutations = six.StringIO("""\ + site node derived_state parent + 0 0 1 -1 + 0 0 0 0 + 0 0 1 1 + """) + ts = tskit.load_text( + nodes, edges, sites=sites, mutations=mutations, strict=False) + self.assertEqual(ts.sample_size, 1) + self.assertEqual(ts.num_trees, 1) + self.assertEqual(ts.num_sites, 1) + self.assertEqual(ts.num_mutations, 3) + tss, node_map = self.do_simplify(ts) + self.assertEqual(tss.num_sites, 1) + self.assertEqual(tss.num_mutations, 3) + self.assertEqual(list(tss.haplotypes()), ["1"]) + + def test_many_trees_filter_zero_mutations(self): + ts = msprime.simulate(5, recombination_rate=1, random_seed=10) + self.assertGreater(ts.num_trees, 3) + ts = tsutil.insert_branch_sites(ts) + self.assertEqual(ts.num_sites, ts.num_mutations) + self.assertGreater(ts.num_sites, ts.num_trees) + for filter_sites in [True, False]: + tss, _ = self.do_simplify( + ts, samples=None, filter_sites=filter_sites) + self.assertEqual(ts.num_sites, tss.num_sites) + self.assertEqual(ts.num_mutations, tss.num_mutations) + + def test_many_trees_filter_zero_multichar_mutations(self): + ts = msprime.simulate(5, recombination_rate=1, random_seed=10) + 
self.assertGreater(ts.num_trees, 3) + ts = tsutil.insert_multichar_mutations(ts) + self.assertEqual(ts.num_sites, ts.num_trees) + self.assertEqual(ts.num_mutations, ts.num_trees) + for filter_sites in [True, False]: + tss, _ = self.do_simplify(ts, samples=None, filter_sites=filter_sites) + self.assertEqual(ts.num_sites, tss.num_sites) + self.assertEqual(ts.num_mutations, tss.num_mutations) + + def test_simple_population_filter(self): + ts = msprime.simulate(10, random_seed=2) + tables = ts.dump_tables() + tables.populations.add_row(metadata=b"unreferenced") + self.assertEqual(len(tables.populations), 2) + tss, _ = self.do_simplify(tables.tree_sequence(), filter_populations=True) + self.assertEqual(tss.num_populations, 1) + tss, _ = self.do_simplify(tables.tree_sequence(), filter_populations=False) + self.assertEqual(tss.num_populations, 2) + + def test_interleaved_populations_filter(self): + ts = msprime.simulate( + population_configurations=[ + msprime.PopulationConfiguration(), + msprime.PopulationConfiguration(10), + msprime.PopulationConfiguration(), + msprime.PopulationConfiguration()], + random_seed=2) + self.assertEqual(ts.num_populations, 4) + tables = ts.dump_tables() + # Edit the populations so we can identify the rows. 
+ tables.populations.clear() + for j in range(4): + tables.populations.add_row(metadata=bytes([j])) + ts = tables.tree_sequence() + id_map = np.array([-1, 0, -1, -1], dtype=np.int32) + tss, _ = self.do_simplify(ts, filter_populations=True) + self.assertEqual(tss.num_populations, 1) + population = tss.population(0) + self.assertEqual(population.metadata, bytes([1])) + self.assertTrue(np.array_equal( + id_map[ts.tables.nodes.population], tss.tables.nodes.population)) + tss, _ = self.do_simplify(ts, filter_populations=False) + self.assertEqual(tss.num_populations, 4) + + def test_removed_node_population_filter(self): + tables = tskit.TableCollection(1) + tables.populations.add_row(metadata=bytes(0)) + tables.populations.add_row(metadata=bytes(1)) + tables.populations.add_row(metadata=bytes(2)) + tables.nodes.add_row(flags=1, population=0) + # Because flags=0 here, this node will be simplified out and the node + # will disappear. + tables.nodes.add_row(flags=0, population=1) + tables.nodes.add_row(flags=1, population=2) + tss, _ = self.do_simplify(tables.tree_sequence(), filter_populations=True) + self.assertEqual(tss.num_nodes, 2) + self.assertEqual(tss.num_populations, 2) + self.assertEqual(tss.population(0).metadata, bytes(0)) + self.assertEqual(tss.population(1).metadata, bytes(2)) + self.assertEqual(tss.node(0).population, 0) + self.assertEqual(tss.node(1).population, 1) + + tss, _ = self.do_simplify( + tables.tree_sequence(), filter_populations=False) + self.assertEqual(tss.tables.populations, tables.populations) + + def test_simple_individual_filter(self): + tables = tskit.TableCollection(1) + tables.individuals.add_row(flags=0) + tables.individuals.add_row(flags=1) + tables.nodes.add_row(flags=1, individual=0) + tables.nodes.add_row(flags=1, individual=0) + tss, _ = self.do_simplify(tables.tree_sequence(), filter_individuals=True) + self.assertEqual(tss.num_nodes, 2) + self.assertEqual(tss.num_individuals, 1) + self.assertEqual(tss.individual(0).flags, 0) + + 
tss, _ = self.do_simplify(tables.tree_sequence(), filter_individuals=False) + self.assertEqual(tss.tables.individuals, tables.individuals) + + def test_interleaved_individual_filter(self): + tables = tskit.TableCollection(1) + tables.individuals.add_row(flags=0) + tables.individuals.add_row(flags=1) + tables.individuals.add_row(flags=2) + tables.nodes.add_row(flags=1, individual=1) + tables.nodes.add_row(flags=1, individual=-1) + tables.nodes.add_row(flags=1, individual=1) + tss, _ = self.do_simplify(tables.tree_sequence(), filter_individuals=True) + self.assertEqual(tss.num_nodes, 3) + self.assertEqual(tss.num_individuals, 1) + self.assertEqual(tss.individual(0).flags, 1) + + tss, _ = self.do_simplify(tables.tree_sequence(), filter_individuals=False) + self.assertEqual(tss.tables.individuals, tables.individuals) + + def test_removed_node_individual_filter(self): + tables = tskit.TableCollection(1) + tables.individuals.add_row(flags=0) + tables.individuals.add_row(flags=1) + tables.individuals.add_row(flags=2) + tables.nodes.add_row(flags=1, individual=0) + # Because flags=0 here, this node will be simplified out and the node + # will disappear. 
+ tables.nodes.add_row(flags=0, individual=1) + tables.nodes.add_row(flags=1, individual=2) + tss, _ = self.do_simplify(tables.tree_sequence(), filter_individuals=True) + self.assertEqual(tss.num_nodes, 2) + self.assertEqual(tss.num_individuals, 2) + self.assertEqual(tss.individual(0).flags, 0) + self.assertEqual(tss.individual(1).flags, 2) + self.assertEqual(tss.node(0).individual, 0) + self.assertEqual(tss.node(1).individual, 1) + + tss, _ = self.do_simplify(tables.tree_sequence(), filter_individuals=False) + self.assertEqual(tss.tables.individuals, tables.individuals) + + def verify_simplify_haplotypes(self, ts, samples): + sub_ts, node_map = self.do_simplify(ts, samples, filter_sites=False) + self.assertEqual(ts.num_sites, sub_ts.num_sites) + sub_haplotypes = list(sub_ts.haplotypes()) + all_samples = list(ts.samples()) + k = 0 + for j, h in enumerate(ts.haplotypes()): + if k == len(samples): + break + if samples[k] == all_samples[j]: + self.assertEqual(h, sub_haplotypes[k]) + k += 1 + + def test_single_tree_recurrent_mutations(self): + ts = msprime.simulate(6, random_seed=10) + for mutations_per_branch in [1, 2, 3]: + ts = tsutil.insert_branch_mutations(ts, mutations_per_branch) + for num_samples in range(1, ts.num_samples): + for samples in itertools.combinations(ts.samples(), num_samples): + self.verify_simplify_haplotypes(ts, samples) + + def test_many_trees_recurrent_mutations(self): + ts = msprime.simulate(5, recombination_rate=1, random_seed=10) + self.assertGreater(ts.num_trees, 3) + for mutations_per_branch in [1, 2, 3]: + ts = tsutil.insert_branch_mutations(ts, mutations_per_branch) + for num_samples in range(1, ts.num_samples): + for samples in itertools.combinations(ts.samples(), num_samples): + self.verify_simplify_haplotypes(ts, samples) + + def test_single_multiroot_tree_recurrent_mutations(self): + ts = msprime.simulate(6, random_seed=10) + ts = tsutil.decapitate(ts, ts.num_edges // 2) + for mutations_per_branch in [1, 2, 3]: + ts = 
tsutil.insert_branch_mutations(ts, mutations_per_branch) + for num_samples in range(1, ts.num_samples): + for samples in itertools.combinations(ts.samples(), num_samples): + self.verify_simplify_haplotypes(ts, samples) + + def test_many_multiroot_trees_recurrent_mutations(self): + ts = msprime.simulate(7, recombination_rate=1, random_seed=10) + self.assertGreater(ts.num_trees, 3) + ts = tsutil.decapitate(ts, ts.num_edges // 2) + for mutations_per_branch in [1, 2, 3]: + ts = tsutil.insert_branch_mutations(ts, mutations_per_branch) + for num_samples in range(1, ts.num_samples): + for samples in itertools.combinations(ts.samples(), num_samples): + self.verify_simplify_haplotypes(ts, samples) + + def test_single_tree_recurrent_mutations_internal_samples(self): + ts = msprime.simulate(6, random_seed=10) + ts = tsutil.jiggle_samples(ts) + for mutations_per_branch in [1, 2, 3]: + ts = tsutil.insert_branch_mutations(ts, mutations_per_branch) + for num_samples in range(1, ts.num_samples): + for samples in itertools.combinations(ts.samples(), num_samples): + self.verify_simplify_haplotypes(ts, samples) + + def test_many_trees_recurrent_mutations_internal_samples(self): + ts = msprime.simulate(5, recombination_rate=1, random_seed=10) + ts = tsutil.jiggle_samples(ts) + self.assertGreater(ts.num_trees, 3) + for mutations_per_branch in [1, 2, 3]: + ts = tsutil.insert_branch_mutations(ts, mutations_per_branch) + for num_samples in range(1, ts.num_samples): + for samples in itertools.combinations(ts.samples(), num_samples): + self.verify_simplify_haplotypes(ts, samples) + + +class TestMutationParent(unittest.TestCase): + """ + Tests that mutation parent is correctly specified, and that we correctly + recompute it with compute_mutation_parent. 
+ """ + seed = 42 + + def verify_parents(self, ts): + parent = tsutil.compute_mutation_parent(ts) + tables = ts.tables + self.assertTrue(np.array_equal(parent, tables.mutations.parent)) + mutations = tables.mutations + mutations.set_columns( + site=mutations.site, node=mutations.node, + derived_state=mutations.derived_state, + derived_state_offset=mutations.derived_state_offset) + self.assertTrue(np.all(mutations.parent == tskit.NULL)) + tables.compute_mutation_parents() + self.assertTrue(np.array_equal(parent, tables.mutations.parent)) + + def test_example(self): + nodes = six.StringIO("""\ + id is_sample time + 0 0 2.0 + 1 0 1.0 + 2 0 1.0 + 3 1 0 + 4 1 0 + """) + edges = six.StringIO("""\ + left right parent child + 0.0 0.5 2 3 + 0.0 0.8 2 4 + 0.5 1.0 1 3 + 0.0 1.0 0 1 + 0.0 1.0 0 2 + 0.8 1.0 0 4 + """) + sites = six.StringIO("""\ + position ancestral_state + 0.1 0 + 0.5 0 + 0.9 0 + """) + mutations = six.StringIO("""\ + site node derived_state parent + 0 1 1 -1 + 0 2 1 -1 + 0 3 2 1 + 1 0 1 -1 + 1 1 1 3 + 1 3 2 4 + 1 2 1 3 + 1 4 2 6 + 2 0 1 -1 + 2 1 1 8 + 2 2 1 8 + 2 4 1 8 + """) + ts = tskit.load_text( + nodes=nodes, edges=edges, sites=sites, mutations=mutations, strict=False) + self.verify_parents(ts) + + def test_single_muts(self): + ts = msprime.simulate(10, random_seed=self.seed, mutation_rate=3.0, + recombination_rate=1.0) + self.verify_parents(ts) + + def test_with_jukes_cantor(self): + ts = msprime.simulate(10, random_seed=self.seed, mutation_rate=0.0, + recombination_rate=1.0) + # make *lots* of recurrent mutations + mut_ts = tsutil.jukes_cantor(ts, num_sites=10, mu=1, + multiple_per_node=False, seed=self.seed) + self.verify_parents(mut_ts) + + def test_with_jukes_cantor_multiple_per_node(self): + ts = msprime.simulate(10, random_seed=self.seed, mutation_rate=0.0, + recombination_rate=1.0) + # make *lots* of recurrent mutations + mut_ts = tsutil.jukes_cantor(ts, num_sites=10, mu=1, + multiple_per_node=True, seed=self.seed) + self.verify_parents(mut_ts) + 
+ def verify_branch_mutations(self, ts, mutations_per_branch): + ts = tsutil.insert_branch_mutations(ts, mutations_per_branch) + self.assertGreater(ts.num_mutations, 1) + self.verify_parents(ts) + + def test_single_tree_one_mutation_per_branch(self): + ts = msprime.simulate(6, random_seed=10) + self.verify_branch_mutations(ts, 1) + + def test_single_tree_two_mutations_per_branch(self): + ts = msprime.simulate(10, random_seed=9) + self.verify_branch_mutations(ts, 2) + + def test_single_tree_three_mutations_per_branch(self): + ts = msprime.simulate(8, random_seed=9) + self.verify_branch_mutations(ts, 3) + + def test_single_multiroot_tree_recurrent_mutations(self): + ts = msprime.simulate(6, random_seed=10) + ts = tsutil.decapitate(ts, ts.num_edges // 2) + for mutations_per_branch in [1, 2, 3]: + self.verify_branch_mutations(ts, mutations_per_branch) + + def test_many_multiroot_trees_recurrent_mutations(self): + ts = msprime.simulate(7, recombination_rate=1, random_seed=10) + self.assertGreater(ts.num_trees, 3) + ts = tsutil.decapitate(ts, ts.num_edges // 2) + for mutations_per_branch in [1, 2, 3]: + self.verify_branch_mutations(ts, mutations_per_branch) + + +class TestSimpleTreeAlgorithm(unittest.TestCase): + """ + Tests for the direct implementation of Algorithm T in tsutil.py. + + See TestHoleyTreeSequences above for further tests on wacky topologies. + """ + def test_zero_nodes(self): + tables = tskit.TableCollection(1) + ts = tables.tree_sequence() + self.assertEqual(ts.sequence_length, 1) + self.assertEqual(ts.num_trees, 1) + # Test the simple tree iterator. 
+ trees = list(tsutil.algorithm_T(ts)) + self.assertEqual(len(trees), 1) + (left, right), parent = trees[0] + self.assertEqual(left, 0) + self.assertEqual(right, 1) + self.assertEqual(parent, []) + + def test_one_node(self): + tables = tskit.TableCollection(1) + tables.nodes.add_row() + ts = tables.tree_sequence() + self.assertEqual(ts.sequence_length, 1) + self.assertEqual(ts.num_trees, 1) + # Test the simple tree iterator. + trees = list(tsutil.algorithm_T(ts)) + self.assertEqual(len(trees), 1) + (left, right), parent = trees[0] + self.assertEqual(left, 0) + self.assertEqual(right, 1) + self.assertEqual(parent, [-1]) + + def test_single_coalescent_tree(self): + ts = msprime.simulate(10, random_seed=1, length=10) + tree = ts.first() + p1 = [tree.parent(j) for j in range(ts.num_nodes)] + interval, p2 = next(tsutil.algorithm_T(ts)) + self.assertEqual(interval, tree.interval) + self.assertEqual(p1, p2) + + def test_coalescent_trees(self): + ts = msprime.simulate(8, recombination_rate=5, random_seed=1, length=2) + self.assertGreater(ts.num_trees, 2) + new_trees = tsutil.algorithm_T(ts) + for tree in ts.trees(): + interval, p2 = next(new_trees) + p1 = [tree.parent(j) for j in range(ts.num_nodes)] + self.assertEqual(interval, tree.interval) + self.assertEqual(p1, p2) + self.assertRaises(StopIteration, next, new_trees) + + +class TestSampleLists(unittest.TestCase): + """ + Tests for the sample lists algorithm. 
+ """ + def verify(self, ts): + tree1 = tsutil.LinkedTree(ts) + s = str(tree1) + self.assertIsNotNone(s) + trees = ts.trees(sample_lists=True) + for left, right in tree1.sample_lists(): + tree2 = next(trees) + assert (left, right) == tree2.interval + for u in tree2.nodes(): + self.assertEqual(tree1.left_sample[u], tree2.left_sample(u)) + self.assertEqual(tree1.right_sample[u], tree2.right_sample(u)) + for j in range(ts.num_samples): + self.assertEqual(tree1.next_sample[j], tree2.next_sample(j)) + assert right == ts.sequence_length + + tree1 = tsutil.LinkedTree(ts) + trees = ts.trees(sample_lists=False) + sample_index_map = ts.samples() + for left, right in tree1.sample_lists(): + tree2 = next(trees) + for u in range(ts.num_nodes): + samples2 = list(tree2.samples(u)) + samples1 = [] + index = tree1.left_sample[u] + if index != tskit.NULL: + self.assertEqual( + sample_index_map[tree1.left_sample[u]], samples2[0]) + self.assertEqual( + sample_index_map[tree1.right_sample[u]], samples2[-1]) + stop = tree1.right_sample[u] + while True: + assert index != -1 + samples1.append(sample_index_map[index]) + if index == stop: + break + index = tree1.next_sample[index] + self.assertEqual(samples1, samples2) + assert right == ts.sequence_length + + def test_single_coalescent_tree(self): + ts = msprime.simulate(10, random_seed=1, length=10) + self.verify(ts) + + def test_coalescent_trees(self): + ts = msprime.simulate(8, recombination_rate=5, random_seed=1, length=2) + self.assertGreater(ts.num_trees, 2) + self.verify(ts) + + def test_coalescent_trees_internal_samples(self): + ts = msprime.simulate(8, recombination_rate=5, random_seed=10, length=2) + self.assertGreater(ts.num_trees, 2) + self.verify(tsutil.jiggle_samples(ts)) + + def test_coalescent_trees_all_samples(self): + ts = msprime.simulate(8, recombination_rate=5, random_seed=10, length=2) + self.assertGreater(ts.num_trees, 2) + tables = ts.dump_tables() + tables.nodes.set_columns( + flags=np.zeros_like(tables.nodes.flags) 
+ tskit.NODE_IS_SAMPLE, + time=tables.nodes.time) + self.verify(tables.tree_sequence()) + + def test_wright_fisher_trees_unsimplified(self): + tables = wf.wf_sim(10, 5, deep_history=False, seed=2) + tables.sort() + ts = tables.tree_sequence() + self.verify(ts) + + def test_wright_fisher_trees_simplified(self): + tables = wf.wf_sim(10, 5, deep_history=False, seed=1) + tables.sort() + ts = tables.tree_sequence() + ts = ts.simplify() + self.verify(ts) + + def test_wright_fisher_trees_simplified_one_gen(self): + tables = wf.wf_sim(10, 1, deep_history=False, seed=1) + tables.sort() + ts = tables.tree_sequence() + ts = ts.simplify() + self.verify(ts) + + def test_nonbinary_trees(self): + demographic_events = [ + msprime.SimpleBottleneck(time=1.0, population=0, proportion=0.95)] + ts = msprime.simulate( + 20, recombination_rate=10, mutation_rate=5, + demographic_events=demographic_events, random_seed=7) + found = False + for e in ts.edgesets(): + if len(e.children) > 2: + found = True + self.assertTrue(found) + self.verify(ts) + + def test_many_multiroot_trees(self): + ts = msprime.simulate(7, recombination_rate=1, random_seed=10) + self.assertGreater(ts.num_trees, 3) + ts = tsutil.decapitate(ts, ts.num_edges // 2) + self.verify(ts) + + +def squash_edges(ts): + """ + Returns the edges in the tree sequence squashed. + """ + t = ts.tables.nodes.time + edges = list(ts.edges()) + edges.sort(key=lambda e: (t[e.parent], e.parent, e.child, e.left)) + if len(edges) == 0: + return [] + + squashed = [] + last_e = edges[0] + for e in edges[1:]: + condition = ( + e.parent != last_e.parent or + e.child != last_e.child or + e.left != last_e.right) + if condition: + squashed.append(last_e) + last_e = e + last_e.right = e.right + squashed.append(last_e) + return squashed + + +def reduce_topology(ts): + """ + Returns a tree sequence with the minimal information required to represent + the tree topologies at its sites. Uses a left-to-right algorithm. 
+ """ + tables = ts.dump_tables() + edge_map = {} + + def add_edge(left, right, parent, child): + new_edge = tskit.Edge(left, right, parent, child) + if child not in edge_map: + edge_map[child] = new_edge + else: + edge = edge_map[child] + if edge.right == left and edge.parent == parent: + # Squash + edge.right = right + else: + tables.edges.add_row(edge.left, edge.right, edge.parent, edge.child) + edge_map[child] = new_edge + + tables.edges.clear() + + edge_buffer = [] + first_site = True + for tree in ts.trees(): + # print(tree.interval) + # print(tree.draw(format="unicode")) + if tree.num_sites > 0: + sites = list(tree.sites()) + if first_site: + x = 0 + # print("First site", sites) + first_site = False + else: + x = sites[0].position + # Flush the edge buffer. + for left, parent, child in edge_buffer: + add_edge(left, x, parent, child) + # Add edges for each node in the tree. + edge_buffer = [] + for root in tree.roots: + for u in tree.nodes(root): + if u != root: + edge_buffer.append((x, tree.parent(u), u)) + # Add the final edges. + for left, parent, child in edge_buffer: + add_edge(left, tables.sequence_length, parent, child) + # Flush the remaining edges to the table + for edge in edge_map.values(): + tables.edges.add_row(edge.left, edge.right, edge.parent, edge.child) + tables.sort() + ts = tables.tree_sequence() + # Now simplify to remove redundant nodes. + return ts.simplify(map_nodes=True, filter_sites=False) + + +class TestReduceTopology(unittest.TestCase): + """ + Tests to ensure that reduce topology in simplify is equivalent to the + reduce_topology function above. 
+ """ + + def verify(self, ts): + source_tables = ts.tables + X = source_tables.sites.position + position_count = {x: 0 for x in X} + position_count[0] = 0 + position_count[ts.sequence_length] = 0 + mts, node_map = reduce_topology(ts) + for edge in mts.edges(): + self.assertIn(edge.left, position_count) + self.assertIn(edge.right, position_count) + position_count[edge.left] += 1 + position_count[edge.right] += 1 + if ts.num_sites == 0: + # We should have zero edges output. + self.assertEqual(mts.num_edges, 0) + elif X[0] != 0: + # The first site (if it's not zero) should be mapped to zero so + # this never occurs in edges. + self.assertEqual(position_count[X[0]], 0) + + minimised_trees = mts.trees() + minimised_tree = next(minimised_trees) + minimised_tree_sites = minimised_tree.sites() + for tree in ts.trees(): + for site in tree.sites(): + minimised_site = next(minimised_tree_sites, None) + if minimised_site is None: + minimised_tree = next(minimised_trees) + minimised_tree_sites = minimised_tree.sites() + minimised_site = next(minimised_tree_sites) + self.assertEqual(site.position, minimised_site.position) + self.assertEqual(site.ancestral_state, minimised_site.ancestral_state) + self.assertEqual(site.metadata, minimised_site.metadata) + self.assertEqual(len(site.mutations), len(minimised_site.mutations)) + + for mutation, minimised_mutation in zip( + site.mutations, minimised_site.mutations): + self.assertEqual( + mutation.derived_state, minimised_mutation.derived_state) + self.assertEqual(mutation.metadata, minimised_mutation.metadata) + self.assertEqual(mutation.parent, minimised_mutation.parent) + self.assertEqual(node_map[mutation.node], minimised_mutation.node) + if tree.num_sites > 0: + mapped_dict = { + node_map[u]: node_map[v] for u, v in tree.parent_dict.items()} + self.assertEqual(mapped_dict, minimised_tree.parent_dict) + self.assertTrue(np.array_equal(ts.genotype_matrix(), mts.genotype_matrix())) + + edges = list(mts.edges()) + squashed = 
squash_edges(mts) + self.assertEqual(len(edges), len(squashed)) + self.assertEqual(edges, squashed) + + # Verify against simplify implementations. + s = tests.Simplifier( + ts, ts.samples(), reduce_to_site_topology=True, filter_sites=False) + sts1, _ = s.simplify() + sts2 = ts.simplify(reduce_to_site_topology=True, filter_sites=False) + t1 = mts.tables + for sts in [sts1, sts2]: + t2 = sts.tables + self.assertEqual(t1.nodes, t2.nodes) + self.assertEqual(t1.edges, t2.edges) + self.assertEqual(t1.sites, t2.sites) + self.assertEqual(t1.mutations, t2.mutations) + self.assertEqual(t1.populations, t2.populations) + self.assertEqual(t1.individuals, t2.individuals) + return mts + + def test_no_recombination_one_site(self): + ts = msprime.simulate(15, random_seed=1) + tables = ts.dump_tables() + tables.sites.add_row(position=0.25, ancestral_state="0") + mts = self.verify(tables.tree_sequence()) + self.assertEqual(mts.num_trees, 1) + + def test_simple_recombination_one_site(self): + ts = msprime.simulate(15, random_seed=1, recombination_rate=2) + tables = ts.dump_tables() + tables.sites.add_row(position=0.25, ancestral_state="0") + mts = self.verify(tables.tree_sequence()) + self.assertEqual(mts.num_trees, 1) + + def test_simple_recombination_fixed_sites(self): + ts = msprime.simulate(5, random_seed=1, recombination_rate=2) + tables = ts.dump_tables() + for x in [0.25, 0.5, 0.75]: + tables.sites.add_row(position=x, ancestral_state="0") + self.verify(tables.tree_sequence()) + + def get_integer_edge_ts(self, n, m): + recombination_map = msprime.RecombinationMap.uniform_map(m, 1, num_loci=m) + ts = msprime.simulate(n, random_seed=1, recombination_map=recombination_map) + self.assertGreater(ts.num_trees, 1) + for edge in ts.edges(): + self.assertEqual(int(edge.left), edge.left) + self.assertEqual(int(edge.right), edge.right) + return ts + + def test_integer_edges_one_site(self): + ts = self.get_integer_edge_ts(5, 10) + tables = ts.dump_tables() + tables.sites.add_row(position=1,
ancestral_state="0") + mts = self.verify(tables.tree_sequence()) + self.assertEqual(mts.num_trees, 1) + + def test_integer_edges_all_sites(self): + ts = self.get_integer_edge_ts(5, 10) + tables = ts.dump_tables() + for x in range(10): + tables.sites.add_row(position=x, ancestral_state="0") + mts = self.verify(tables.tree_sequence()) + self.assertEqual(mts.num_trees, ts.num_trees) + + def test_simple_recombination_site_at_zero(self): + ts = msprime.simulate(5, random_seed=1, recombination_rate=2) + tables = ts.dump_tables() + tables.sites.add_row(position=0, ancestral_state="0") + mts = self.verify(tables.tree_sequence()) + self.assertEqual(mts.num_trees, 1) + + def test_simple_recombination(self): + ts = msprime.simulate(5, random_seed=1, recombination_rate=2, mutation_rate=2) + self.verify(ts) + + def test_large_recombination(self): + ts = msprime.simulate(25, random_seed=12, recombination_rate=5, mutation_rate=15) + self.verify(ts) + + def test_no_recombination(self): + ts = msprime.simulate(5, random_seed=1, mutation_rate=2) + self.verify(ts) + + def test_no_mutation(self): + ts = msprime.simulate(5, random_seed=1) + self.verify(ts) + + def test_zero_sites(self): + ts = msprime.simulate(5, random_seed=2) + self.assertEqual(ts.num_sites, 0) + mts = ts.simplify(reduce_to_site_topology=True) + self.assertEqual(mts.num_trees, 1) + self.assertEqual(mts.num_edges, 0) + + def test_many_roots(self): + ts = msprime.simulate(25, random_seed=12, recombination_rate=2, length=10) + tables = tsutil.decapitate(ts, ts.num_edges // 2).dump_tables() + for x in range(10): + tables.sites.add_row(x, "0") + self.verify(tables.tree_sequence()) + + def test_branch_sites(self): + ts = msprime.simulate(15, random_seed=12, recombination_rate=2, length=10) + ts = tsutil.insert_branch_sites(ts) + self.verify(ts) + + def test_jiggled_samples(self): + ts = msprime.simulate(8, random_seed=13, recombination_rate=2, length=10) + ts = tsutil.jiggle_samples(ts) + self.verify(ts) + + +def 
search_sorted(a, v): + """ + Implementation of searchsorted based on binary search with the same + semantics as numpy's searchsorted. Used as the basis of the C + implementation which we use in the simplify algorithm. + """ + upper = len(a) + if upper == 0: + return 0 + lower = 0 + while upper - lower > 1: + mid = (upper + lower) // 2 + if (v >= a[mid]): + lower = mid + else: + upper = mid + offset = 0 + if a[lower] < v: + offset = 1 + return lower + offset + + +class TestSearchSorted(unittest.TestCase): + """ + Tests for the basic implementation of search_sorted. + """ + def verify(self, a): + a = np.array(a) + start, end = a[0], a[-1] + # Check random values. + np.random.seed(43) + for v in np.random.uniform(start, end, 10): + self.assertEqual(search_sorted(a, v), np.searchsorted(a, v)) + # Check equal values. + for v in a: + self.assertEqual(search_sorted(a, v), np.searchsorted(a, v)) + # Check values outside bounds. + for v in [start - 2, start - 1, end, end + 1, end + 2]: + self.assertEqual(search_sorted(a, v), np.searchsorted(a, v)) + + def test_range(self): + for j in range(1, 20): + self.verify(range(j)) + + def test_negative_range(self): + for j in range(1, 20): + self.verify(-1 * np.arange(j)[::-1]) + + def test_random_unit_interval(self): + np.random.seed(143) + for size in range(1, 100): + a = np.random.random(size=size) + a.sort() + self.verify(a) + + def test_random_interval(self): + np.random.seed(143) + for _ in range(10): + interval = np.random.random(2) * 10 + interval.sort() + a = np.random.uniform(*interval, size=100) + a.sort() + self.verify(a) + + def test_random_negative(self): + np.random.seed(143) + for _ in range(10): + interval = np.random.random(2) * 5 + interval.sort() + a = -1 * np.random.uniform(*interval, size=100) + a.sort() + self.verify(a) + + def test_edge_cases(self): + for v in [0, 1]: + self.assertEqual(search_sorted([], v), np.searchsorted([], v)) + self.assertEqual(search_sorted([1], v), np.searchsorted([1], v)) diff --git 
a/python/tests/test_tree_stats.py b/python/tests/test_tree_stats.py new file mode 100644 index 0000000000..aa308730fb --- /dev/null +++ b/python/tests/test_tree_stats.py @@ -0,0 +1,1474 @@ +""" +Test cases for generalized statistic computation. +""" +from __future__ import print_function +from __future__ import division + + +import unittest +import random + +import numpy as np +import numpy.testing as nt + +import six +import msprime + +import tskit +import tests.tsutil as tsutil + + +def path_length(tr, x, y): + L = 0 + mrca = tr.mrca(x, y) + for u in x, y: + while u != mrca: + L += tr.branch_length(u) + u = tr.parent(u) + return L + + +class PythonBranchLengthStatCalculator(object): + """ + Python implementations of various ("tree") branch-length statistics - + inefficient but more clear what they are doing. + """ + + def __init__(self, tree_sequence): + self.tree_sequence = tree_sequence + + def divergence(self, X, Y, begin=0.0, end=None): + ''' + Computes average pairwise diversity between a random choice from x + and a random choice from y over the window specified. + ''' + if end is None: + end = self.tree_sequence.sequence_length + S = 0 + for tr in self.tree_sequence.trees(): + if tr.interval[1] <= begin: + continue + if tr.interval[0] >= end: + break + SS = 0 + for x in X: + for y in Y: + SS += path_length(tr, x, y) / 2.0 + S += SS*(min(end, tr.interval[1]) - max(begin, tr.interval[0])) + return S/((end-begin)*len(X)*len(Y)) + + def tree_length_diversity(self, X, Y, begin=0.0, end=None): + ''' + Computes average pairwise diversity between a random choice from x + and a random choice from y over the window specified. 
+ ''' + if end is None: + end = self.tree_sequence.sequence_length + S = 0 + for tr in self.tree_sequence.trees(): + if tr.interval[1] <= begin: + continue + if tr.interval[0] >= end: + break + SS = 0 + for x in X: + for y in Y: + SS += path_length(tr, x, y) + S += SS*(min(end, tr.interval[1]) - max(begin, tr.interval[0])) + return S/((end-begin)*len(X)*len(Y)) + + def Y3(self, X, Y, Z, begin=0.0, end=None): + if end is None: + end = self.tree_sequence.sequence_length + S = 0 + for tr in self.tree_sequence.trees(): + if tr.interval[1] <= begin: + continue + if tr.interval[0] >= end: + break + this_length = min(end, tr.interval[1]) - max(begin, tr.interval[0]) + for x in X: + for y in Y: + for z in Z: + xy_mrca = tr.mrca(x, y) + xz_mrca = tr.mrca(x, z) + yz_mrca = tr.mrca(y, z) + if xy_mrca == xz_mrca: + # /\ + # / /\ + # x y z + S += path_length(tr, x, yz_mrca) * this_length + elif xy_mrca == yz_mrca: + # /\ + # / /\ + # y x z + S += path_length(tr, x, xz_mrca) * this_length + elif xz_mrca == yz_mrca: + # /\ + # / /\ + # z x y + S += path_length(tr, x, xy_mrca) * this_length + return S/((end - begin) * len(X) * len(Y) * len(Z)) + + def Y2(self, X, Y, begin=0.0, end=None): + if end is None: + end = self.tree_sequence.sequence_length + S = 0 + for tr in self.tree_sequence.trees(): + if tr.interval[1] <= begin: + continue + if tr.interval[0] >= end: + break + this_length = min(end, tr.interval[1]) - max(begin, tr.interval[0]) + for x in X: + for y in Y: + for z in set(Y) - set([y]): + xy_mrca = tr.mrca(x, y) + xz_mrca = tr.mrca(x, z) + yz_mrca = tr.mrca(y, z) + if xy_mrca == xz_mrca: + # /\ + # / /\ + # x y z + S += path_length(tr, x, yz_mrca) * this_length + elif xy_mrca == yz_mrca: + # /\ + # / /\ + # y x z + S += path_length(tr, x, xz_mrca) * this_length + elif xz_mrca == yz_mrca: + # /\ + # / /\ + # z x y + S += path_length(tr, x, xy_mrca) * this_length + return S/((end - begin) * len(X) * len(Y) * (len(Y)-1)) + + def Y1(self, X, begin=0.0, end=None): + if end is 
None: + end = self.tree_sequence.sequence_length + S = 0 + for tr in self.tree_sequence.trees(): + if tr.interval[1] <= begin: + continue + if tr.interval[0] >= end: + break + this_length = min(end, tr.interval[1]) - max(begin, tr.interval[0]) + for x in X: + for y in set(X) - set([x]): + for z in set(X) - set([x, y]): + xy_mrca = tr.mrca(x, y) + xz_mrca = tr.mrca(x, z) + yz_mrca = tr.mrca(y, z) + if xy_mrca == xz_mrca: + # /\ + # / /\ + # x y z + S += path_length(tr, x, yz_mrca) * this_length + elif xy_mrca == yz_mrca: + # /\ + # / /\ + # y x z + S += path_length(tr, x, xz_mrca) * this_length + elif xz_mrca == yz_mrca: + # /\ + # / /\ + # z x y + S += path_length(tr, x, xy_mrca) * this_length + return S/((end - begin) * len(X) * (len(X)-1) * (len(X)-2)) + + def f4(self, A, B, C, D, begin=0.0, end=None): + if end is None: + end = self.tree_sequence.sequence_length + for U in A, B, C, D: + if max([U.count(x) for x in set(U)]) > 1: + raise ValueError("A,B,C, and D cannot contain repeated elements.") + S = 0 + for tr in self.tree_sequence.trees(): + if tr.interval[1] <= begin: + continue + if tr.interval[0] >= end: + break + this_length = min(end, tr.interval[1]) - max(begin, tr.interval[0]) + SS = 0 + for a in A: + for b in B: + for c in C: + for d in D: + SS += path_length(tr, tr.mrca(a, c), tr.mrca(b, d)) + SS -= path_length(tr, tr.mrca(a, d), tr.mrca(b, c)) + S += SS * this_length + return S / ((end - begin) * len(A) * len(B) * len(C) * len(D)) + + def f3(self, A, B, C, begin=0.0, end=None): + # this is f4(A,B;A,C) but drawing distinct samples from A + if end is None: + end = self.tree_sequence.sequence_length + assert(len(A) > 1) + for U in A, B, C: + if max([U.count(x) for x in set(U)]) > 1: + raise ValueError("A, B and C cannot contain repeated elements.") + S = 0 + for tr in self.tree_sequence.trees(): + if tr.interval[1] <= begin: + continue + if tr.interval[0] >= end: + break + this_length = min(end, tr.interval[1]) - max(begin, tr.interval[0]) + SS = 0 + 
for a in A: + for b in B: + for c in set(A) - set([a]): + for d in C: + SS += path_length(tr, tr.mrca(a, c), tr.mrca(b, d)) + SS -= path_length(tr, tr.mrca(a, d), tr.mrca(b, c)) + S += SS * this_length + return S / ((end - begin) * len(A) * (len(A) - 1) * len(B) * len(C)) + + def f2(self, A, B, begin=0.0, end=None): + # this is f4(A,B;A,B) but drawing distinct samples from A and B + if end is None: + end = self.tree_sequence.sequence_length + assert(len(A) > 1) + for U in A, B: + if max([U.count(x) for x in set(U)]) > 1: + raise ValueError("A and B cannot contain repeated elements.") + S = 0 + for tr in self.tree_sequence.trees(): + if tr.interval[1] <= begin: + continue + if tr.interval[0] >= end: + break + this_length = min(end, tr.interval[1]) - max(begin, tr.interval[0]) + SS = 0 + for a in A: + for b in B: + for c in set(A) - set([a]): + for d in set(B) - set([b]): + SS += path_length(tr, tr.mrca(a, c), tr.mrca(b, d)) + SS -= path_length(tr, tr.mrca(a, d), tr.mrca(b, c)) + S += SS * this_length + return S / ((end - begin) * len(A) * (len(A) - 1) * len(B) * (len(B) - 1)) + + def tree_stat(self, sample_sets, weight_fun, begin=0.0, end=None): + ''' + Here sample_sets is a list of lists of samples, and weight_fun is a function + whose argument is a list of integers of the same length as sample_sets + that returns a number. Each branch in a tree is weighted by weight_fun(x), + where x[i] is the number of samples in sample_sets[i] below that + branch. This finds the sum of all counted branches for each tree, + and averages this across the tree sequence ts, weighted by genomic length. + + This version is inefficient as it iterates over all nodes in each tree. 
+ ''' + out = self.tree_stat_vector(sample_sets, + lambda x: [weight_fun(x)], + begin=begin, end=end) + if len(out) > 1: + raise ValueError("Expecting output of length 1.") + return out[0] + + def tree_stat_vector(self, sample_sets, weight_fun, begin=0.0, end=None): + ''' + Here sample_sets is a list of lists of samples, and weight_fun is a function + whose argument is a list of integers of the same length as sample_sets + that returns a list of numbers; there will be one output for each element. + For each value, each branch in a tree is weighted by weight_fun(x), + where x[i] is the number of samples in sample_sets[i] below that + branch. This finds the sum of all counted branches for each tree, + and averages this across the tree sequence ts, weighted by genomic length. + + This version is inefficient as it iterates over all nodes in each tree. + ''' + for U in sample_sets: + if max([U.count(x) for x in set(U)]) > 1: + raise ValueError("elements of sample_sets", + "cannot contain repeated elements.") + if end is None: + end = self.tree_sequence.sequence_length + tr_its = [ + self.tree_sequence.trees(tracked_samples=x, sample_counts=True) + for x in sample_sets] + n = [len(U) for U in sample_sets] + n_out = len(weight_fun([0 for a in sample_sets])) + S = [0.0 for j in range(n_out)] + for k in range(self.tree_sequence.num_trees): + trs = [next(x) for x in tr_its] + root = trs[0].root + tr_len = min(end, trs[0].interval[1]) - max(begin, trs[0].interval[0]) + if tr_len > 0: + for node in trs[0].nodes(): + if node != root: + x = [tr.num_tracked_samples(node) for tr in trs] + nx = [a - b for a, b in zip(n, x)] + w = [a + b for a, b in zip(weight_fun(x), weight_fun(nx))] + for j in range(n_out): + S[j] += w[j] * trs[0].branch_length(node) * tr_len + for j in range(n_out): + S[j] /= (end-begin) + return S + + def site_frequency_spectrum(self, sample_set, begin=0.0, end=None): + if end is None: + end = self.tree_sequence.sequence_length + n_out = len(sample_set) + S = 
[0.0 for j in range(n_out)] + for t in self.tree_sequence.trees(tracked_samples=sample_set, + sample_counts=True): + root = t.root + tr_len = min(end, t.interval[1]) - max(begin, t.interval[0]) + if tr_len > 0: + for node in t.nodes(): + if node != root: + x = t.num_tracked_samples(node) + if x > 0: + S[x - 1] += t.branch_length(node) * tr_len + for j in range(n_out): + S[j] /= (end-begin) + return S + + +class PythonSiteStatCalculator(object): + """ + Python implementations of various single-site statistics - + inefficient but more clear what they are doing. + """ + + def __init__(self, tree_sequence): + self.tree_sequence = tree_sequence + + def divergence(self, X, Y, begin=0.0, end=None): + if end is None: + end = self.tree_sequence.sequence_length + haps = list(self.tree_sequence.haplotypes()) + site_positions = [x.position for x in self.tree_sequence.sites()] + S = 0 + for k in range(self.tree_sequence.num_sites): + if (site_positions[k] >= begin) and (site_positions[k] < end): + for x in X: + for y in Y: + if (haps[x][k] != haps[y][k]): + # x|y + S += 1 + return S/((end - begin) * len(X) * len(Y)) + + def Y3(self, X, Y, Z, begin=0.0, end=None): + if end is None: + end = self.tree_sequence.sequence_length + haps = list(self.tree_sequence.haplotypes()) + site_positions = [x.position for x in self.tree_sequence.sites()] + S = 0 + for k in range(self.tree_sequence.num_sites): + if (site_positions[k] >= begin) and (site_positions[k] < end): + for x in X: + for y in Y: + for z in Z: + if ((haps[x][k] != haps[y][k]) + and (haps[x][k] != haps[z][k])): + # x|yz + S += 1 + return S/((end - begin) * len(X) * len(Y) * len(Z)) + + def Y2(self, X, Y, begin=0.0, end=None): + if end is None: + end = self.tree_sequence.sequence_length + haps = list(self.tree_sequence.haplotypes()) + site_positions = [x.position for x in self.tree_sequence.sites()] + S = 0 + for k in range(self.tree_sequence.num_sites): + if (site_positions[k] >= begin) and (site_positions[k] < end): + for x 
in X: + for y in Y: + for z in set(Y) - set([y]): + if ((haps[x][k] != haps[y][k]) + and (haps[x][k] != haps[z][k])): + # x|yz + S += 1 + return S/((end - begin) * len(X) * len(Y) * (len(Y) - 1)) + + def Y1(self, X, begin=0.0, end=None): + if end is None: + end = self.tree_sequence.sequence_length + haps = list(self.tree_sequence.haplotypes()) + site_positions = [x.position for x in self.tree_sequence.sites()] + S = 0 + for k in range(self.tree_sequence.num_sites): + if (site_positions[k] >= begin) and (site_positions[k] < end): + for x in X: + for y in set(X) - set([x]): + for z in set(X) - set([x, y]): + if ((haps[x][k] != haps[y][k]) + and (haps[x][k] != haps[z][k])): + # x|yz + S += 1 + return S/((end - begin) * len(X) * (len(X) - 1) * (len(X) - 2)) + + def f4(self, A, B, C, D, begin=0.0, end=None): + if end is None: + end = self.tree_sequence.sequence_length + for U in A, B, C, D: + if max([U.count(x) for x in set(U)]) > 1: + raise ValueError("A,B,C, and D cannot contain repeated elements.") + haps = list(self.tree_sequence.haplotypes()) + site_positions = [x.position for x in self.tree_sequence.sites()] + S = 0 + for k in range(self.tree_sequence.num_sites): + if (site_positions[k] >= begin) and (site_positions[k] < end): + for a in A: + for b in B: + for c in C: + for d in D: + if ((haps[a][k] == haps[c][k]) + and (haps[a][k] != haps[d][k]) + and (haps[a][k] != haps[b][k])): + # ac|bd + S += 1 + elif ((haps[a][k] == haps[d][k]) + and (haps[a][k] != haps[c][k]) + and (haps[a][k] != haps[b][k])): + # ad|bc + S -= 1 + return S / ((end - begin) * len(A) * len(B) * len(C) * len(D)) + + def f3(self, A, B, C, begin=0.0, end=None): + if end is None: + end = self.tree_sequence.sequence_length + for U in A, B, C: + if max([U.count(x) for x in set(U)]) > 1: + raise ValueError("A,B,and C cannot contain repeated elements.") + haps = list(self.tree_sequence.haplotypes()) + site_positions = [x.position for x in self.tree_sequence.sites()] + S = 0 + for k in 
range(self.tree_sequence.num_sites): + if (site_positions[k] >= begin) and (site_positions[k] < end): + for a in A: + for b in B: + for c in set(A) - set([a]): + for d in C: + if ((haps[a][k] == haps[c][k]) + and (haps[a][k] != haps[d][k]) + and (haps[a][k] != haps[b][k])): + # ac|bd + S += 1 + elif ((haps[a][k] == haps[d][k]) + and (haps[a][k] != haps[c][k]) + and (haps[a][k] != haps[b][k])): + # ad|bc + S -= 1 + return S / ((end - begin) * len(A) * len(B) * len(C) * (len(A) - 1)) + + def f2(self, A, B, begin=0.0, end=None): + if end is None: + end = self.tree_sequence.sequence_length + for U in A, B: + if max([U.count(x) for x in set(U)]) > 1: + raise ValueError("A,and B cannot contain repeated elements.") + haps = list(self.tree_sequence.haplotypes()) + site_positions = [x.position for x in self.tree_sequence.sites()] + S = 0 + for k in range(self.tree_sequence.num_sites): + if (site_positions[k] >= begin) and (site_positions[k] < end): + for a in A: + for b in B: + for c in set(A) - set([a]): + for d in set(B) - set([b]): + if ((haps[a][k] == haps[c][k]) + and (haps[a][k] != haps[d][k]) + and (haps[a][k] != haps[b][k])): + # ac|bd + S += 1 + elif ((haps[a][k] == haps[d][k]) + and (haps[a][k] != haps[c][k]) + and (haps[a][k] != haps[b][k])): + # ad|bc + S -= 1 + return S / ((end - begin) * len(A) * len(B) + * (len(A) - 1) * (len(B) - 1)) + + def tree_stat_vector(self, sample_sets, weight_fun, begin=0.0, end=None): + ''' + Here sample_sets is a list of lists of samples, and weight_fun is a function + whose argument is a list of integers of the same length as sample_sets + that returns a list of numbers; there will be one output for each element. + For each value, each allele in a tree is weighted by weight_fun(x), where + x[i] is the number of samples in sample_sets[i] that inherit that allele. + This finds the sum of this value for all alleles at all polymorphic sites, + and across the tree sequence ts, weighted by genomic length. 
+ + This version is inefficient as it works directly with haplotypes. + ''' + for U in sample_sets: + if max([U.count(x) for x in set(U)]) > 1: + raise ValueError("elements of sample_sets", + "cannot contain repeated elements.") + if end is None: + end = self.tree_sequence.sequence_length + haps = list(self.tree_sequence.haplotypes()) + n_out = len(weight_fun([0 for a in sample_sets])) + site_positions = [x.position for x in self.tree_sequence.sites()] + S = [0.0 for j in range(n_out)] + for k in range(self.tree_sequence.num_sites): + if (site_positions[k] >= begin) and (site_positions[k] < end): + all_g = [haps[j][k] for j in range(self.tree_sequence.num_samples)] + g = [[haps[j][k] for j in u] for u in sample_sets] + for a in set(all_g): + x = [h.count(a) for h in g] + w = weight_fun(x) + for j in range(n_out): + S[j] += w[j] + for j in range(n_out): + S[j] /= (end - begin) + return S + + def tree_stat(self, sample_sets, weight_fun, begin=0.0, end=None): + ''' + This provides a non-vectorized interface to `tree_stat_vector()`. 
+ ''' + out = self.tree_stat_vector(sample_sets, lambda x: [weight_fun(x)], + begin=begin, end=end) + if len(out) > 1: + raise ValueError("Expecting output of length 1.") + return out[0] + + def site_frequency_spectrum(self, sample_set, begin=0.0, end=None): + ''' + ''' + if end is None: + end = self.tree_sequence.sequence_length + haps = list(self.tree_sequence.haplotypes()) + n_out = len(sample_set) + site_positions = [x.position for x in self.tree_sequence.sites()] + S = [0.0 for j in range(n_out)] + for k in range(self.tree_sequence.num_sites): + if (site_positions[k] >= begin) and (site_positions[k] < end): + all_g = [haps[j][k] for j in range(self.tree_sequence.num_samples)] + g = [haps[j][k] for j in sample_set] + for a in set(all_g): + x = g.count(a) + if x > 0: + S[x - 1] += 1.0 + for j in range(n_out): + S[j] /= (end - begin) + return S + + +def upper_tri_to_matrix(x): + """ + Given x, a vector of entries of the upper triangle of a matrix + in row-major order, including the diagonal, return the corresponding matrix. + """ + # n^2 + n = 2 u => n = (-1 + sqrt(1 + 8*u))/2 + n = int((np.sqrt(1 + 8 * len(x)) - 1)/2.0) + out = np.ones((n, n)) + k = 0 + for i in range(n): + for j in range(i, n): + out[i, j] = out[j, i] = x[k] + k += 1 + return out + + +class TestStatsInterface(unittest.TestCase): + """ + Tests basic stat calculator interface. + """ + + def test_interface(self): + self.assertRaises(TypeError, tskit.GeneralStatCalculator) + self.assertRaises(TypeError, tskit.SiteStatCalculator) + self.assertRaises(TypeError, tskit.BranchLengthStatCalculator) + + +class GeneralStatsTestCase(unittest.TestCase): + """ + Tests of statistic computation. Derived classes should have attributes + `stat_class` and `py_stat_class`. 
+ """ + random_seed = 123456 + + def assertListAlmostEqual(self, x, y): + self.assertEqual(len(x), len(y)) + for a, b in zip(x, y): + self.assertAlmostEqual(a, b) + + def assertArrayEqual(self, x, y): + nt.assert_equal(x, y) + + def assertArrayAlmostEqual(self, x, y): + nt.assert_array_almost_equal(x, y) + + def compare_stats(self, ts, tree_fn, sample_sets, index_length, + tsc_fn=None, tsc_vector_fn=None): + """ + Use to compare a tree sequence method tsc_vector_fn to a single-window-based + implementation tree_fn that takes index_length leaf sets at once. Pass + index_length=0 to signal that tsc_fn does not take an 'indices' argument; + otherwise, gives the length of each of the tuples. + + Here are the arguments these functions will get: + tree_fn(sample_set[i], ... , sample_set[k], begin=left, end=right) + tsc_vector_fn(sample_sets, windows, indices) + ... or tsc_vector_fn(sample_sets, windows) + tsc_fn(sample_sets, windows) + """ + assert(len(sample_sets) >= index_length) + nl = len(sample_sets) + windows = [k * ts.sequence_length / 20 for k in + [0] + sorted(random.sample(range(1, 20), 4)) + [20]] + indices = [random.sample(range(nl), max(1, index_length)) for _ in range(5)] + leafset_args = [[sample_sets[i] for i in ii] for ii in indices] + win_args = [{'begin': windows[i], 'end': windows[i+1]} + for i in range(len(windows)-1)] + tree_vals = [[tree_fn(*a, **b) for a in leafset_args] for b in win_args] + # flatten if necessary + if isinstance(tree_vals[0][0], list): + tree_vals = [[x for a in b for x in a] for b in tree_vals] + + if tsc_vector_fn is not None: + if index_length > 0: + tsc_vector_vals = tsc_vector_fn(sample_sets, windows, indices) + else: + tsc_vector_vals = tsc_vector_fn([sample_sets[i[0]] for i in indices], + windows) + # print("vector:") + # print(tsc_vector_vals) + # print(tree_vals) + self.assertEqual(len(tree_vals), len(windows)-1) + self.assertEqual(len(tsc_vector_vals), len(windows)-1) + for i in range(len(windows)-1): + 
self.assertListAlmostEqual(tsc_vector_vals[i], tree_vals[i]) + + if tsc_fn is not None: + tsc_vals_orig = [tsc_fn(*([ls] + [windows])) for ls in leafset_args] + tsc_vals = [[x[k][0] for x in tsc_vals_orig] for k in range(len(windows)-1)] + # print("not:") + # print(tsc_vals) + # print(tree_vals) + self.assertEqual(len(tsc_vals), len(windows)-1) + for i in range(len(windows)-1): + self.assertListAlmostEqual(tsc_vals[i], tree_vals[i]) + + def compare_sfs(self, ts, tree_fn, sample_sets, tsc_fn): + """ + """ + for sample_set in sample_sets: + windows = [k * ts.sequence_length / 20 for k in + [0] + sorted(random.sample(range(1, 20), 4)) + [20]] + win_args = [{'begin': windows[i], 'end': windows[i+1]} + for i in range(len(windows)-1)] + tree_vals = [tree_fn(sample_set, **b) for b in win_args] + + tsc_vals = tsc_fn(sample_set, windows) + self.assertEqual(len(tsc_vals), len(windows) - 1) + for i in range(len(windows) - 1): + self.assertListAlmostEqual(tsc_vals[i], tree_vals[i]) + + def check_tree_stat_interface(self, ts): + samples = list(ts.samples()) + tsc = self.stat_class(ts) + + def wfn(x): + return [1] + + # empty sample sets will raise an error + self.assertRaises(ValueError, tsc.tree_stat_vector, + samples[0:2] + [], wfn) + # sample_sets must be lists without repeated elements + self.assertRaises(ValueError, tsc.tree_stat_vector, + samples[0:2], wfn) + self.assertRaises(ValueError, tsc.tree_stat_vector, + [samples[0:2], [samples[2], samples[2]]], wfn) + # and must all be samples + self.assertRaises(ValueError, tsc.tree_stat_vector, + [samples[0:2], [max(samples)+1]], wfn) + # windows must start at 0.0, be increasing, and extend to the end + self.assertRaises(ValueError, tsc.tree_stat_vector, + [samples[0:2], samples[2:4]], wfn, + [0.1, ts.sequence_length]) + self.assertRaises(ValueError, tsc.tree_stat_vector, + [samples[0:2], samples[2:4]], wfn, + [0.0, 0.8*ts.sequence_length]) + self.assertRaises(ValueError, tsc.tree_stat_vector, + [samples[0:2], samples[2:4]], 
wfn, + [0.0, 0.8*ts.sequence_length, 0.4*ts.sequence_length, + ts.sequence_length]) + + def check_sfs_interface(self, ts): + samples = ts.samples() + tsc = self.stat_class(ts) + + # empty sample sets will raise an error + self.assertRaises(ValueError, tsc.site_frequency_spectrum, []) + # sample_sets must be lists without repeated elements + self.assertRaises(ValueError, tsc.site_frequency_spectrum, + [samples[2], samples[2]]) + # and must all be samples + self.assertRaises(ValueError, tsc.site_frequency_spectrum, + [samples[0], max(samples)+1]) + # windows must start at 0.0, be increasing, and extend to the end + self.assertRaises(ValueError, tsc.site_frequency_spectrum, + samples[0:2], [0.1, ts.sequence_length]) + self.assertRaises(ValueError, tsc.site_frequency_spectrum, + samples[0:2], [0.0, 0.8*ts.sequence_length]) + self.assertRaises(ValueError, tsc.site_frequency_spectrum, + samples[0:2], + [0.0, 0.8*ts.sequence_length, 0.4*ts.sequence_length, + ts.sequence_length]) + + def check_tree_stat_vector(self, ts): + # test the general tree_stat_vector() machinery + self.check_tree_stat_interface(ts) + samples = random.sample(list(ts.samples()), 12) + A = [[samples[0], samples[1], samples[6]], + [samples[2], samples[3], samples[7]], + [samples[4], samples[5], samples[8]], + [samples[9], samples[10], samples[11]]] + tsc = self.stat_class(ts) + py_tsc = self.py_stat_class(ts) + + # a made-up example + def tsf(sample_sets, windows, indices): + def f(x): + return [x[i] + 2.0 * x[j] + 3.5 * x[k] for i, j, k in indices] + return tsc.tree_stat_vector(sample_sets, weight_fun=f, windows=windows) + + def py_tsf(X, Y, Z, begin, end): + def f(x): + return x[0] + 2.0 * x[1] + 3.5 * x[2] + return py_tsc.tree_stat([X, Y, Z], weight_fun=f, + begin=begin, end=end) + + self.compare_stats(ts, py_tsf, A, 3, tsc_vector_fn=tsf) + + def check_sfs(self, ts): + # check site frequency spectrum + self.check_sfs_interface(ts) + A = [random.sample(list(ts.samples()), 2), + 
random.sample(list(ts.samples()), 4), + random.sample(list(ts.samples()), 8), + random.sample(list(ts.samples()), 10), + random.sample(list(ts.samples()), 12)] + tsc = self.stat_class(ts) + py_tsc = self.py_stat_class(ts) + + self.compare_sfs(ts, py_tsc.site_frequency_spectrum, A, + tsc.site_frequency_spectrum) + + def check_f_interface(self, ts): + tsc = self.stat_class(ts) + # sample sets must have at least two samples + self.assertRaises(ValueError, tsc.f2_vector, + [[0, 1], [3]], [0, ts.sequence_length], [(0, 1)]) + + def check_f_stats(self, ts): + self.check_f_interface(ts) + samples = random.sample(list(ts.samples()), 12) + A = [[samples[0], samples[1], samples[2]], + [samples[3], samples[4]], + [samples[5], samples[6]], + [samples[7], samples[8]], + [samples[9], samples[10], samples[11]]] + tsc = self.stat_class(ts) + py_tsc = self.py_stat_class(ts) + self.compare_stats(ts, py_tsc.f2, A, 2, + tsc_fn=tsc.f2, tsc_vector_fn=tsc.f2_vector) + self.compare_stats(ts, py_tsc.f3, A, 3, + tsc_fn=tsc.f3, tsc_vector_fn=tsc.f3_vector) + self.compare_stats(ts, py_tsc.f4, A, 4, + tsc_fn=tsc.f4, tsc_vector_fn=tsc.f4_vector) + + def check_Y_stat(self, ts): + samples = random.sample(list(ts.samples()), 12) + A = [[samples[0], samples[1], samples[6]], + [samples[2], samples[3], samples[7]], + [samples[4], samples[5], samples[8]], + [samples[9], samples[10], samples[11]]] + tsc = self.stat_class(ts) + py_tsc = self.py_stat_class(ts) + self.compare_stats(ts, py_tsc.Y3, A, 3, + tsc_fn=tsc.Y3, tsc_vector_fn=tsc.Y3_vector) + self.compare_stats(ts, py_tsc.Y2, A, 2, + tsc_fn=tsc.Y2, tsc_vector_fn=tsc.Y2_vector) + self.compare_stats(ts, py_tsc.Y1, A, 0, + tsc_vector_fn=tsc.Y1_vector) + + +class SpecificTreesTestCase(GeneralStatsTestCase): + seed = 21 + + def test_case_1(self): + # With mutations: + # + # 1.0 6 + # 0.7 / \ 5 + # / X / \ + # 0.5 X 4 4 / 4 + # / / \ / \ / X X + # 0.4 X X \ X 3 X / \ + # / / X / / X / / \ + # 0.0 0 1 2 1 0 2 0 1 2 + # (0.0, 0.2), (0.2, 0.8), (0.8, 1.0) + 
# + branch_true_diversity_01 = 2*(1 * (0.2-0) + 0.5 * (0.8-0.2) + 0.7 * (1.0-0.8)) + branch_true_diversity_02 = 2*(1 * (0.2-0) + 0.4 * (0.8-0.2) + 0.7 * (1.0-0.8)) + branch_true_diversity_12 = 2*(0.5 * (0.2-0) + 0.5 * (0.8-0.2) + 0.5 * (1.0-0.8)) + branch_true_Y = 0.2*(1 + 0.5) + 0.6*(0.4) + 0.2*(0.7+0.2) + site_true_Y = 3 + 0 + 1 + + nodes = six.StringIO("""\ + id is_sample time + 0 1 0 + 1 1 0 + 2 1 0 + 3 0 0.4 + 4 0 0.5 + 5 0 0.7 + 6 0 1.0 + """) + edges = six.StringIO("""\ + left right parent child + 0.2 0.8 3 0,2 + 0.0 0.2 4 1,2 + 0.2 0.8 4 1,3 + 0.8 1.0 4 1,2 + 0.8 1.0 5 0,4 + 0.0 0.2 6 0,4 + """) + sites = six.StringIO("""\ + id position ancestral_state + 0 0.05 0 + 1 0.1 0 + 2 0.11 0 + 3 0.15 0 + 4 0.151 0 + 5 0.3 0 + 6 0.6 0 + 7 0.9 0 + 8 0.95 0 + 9 0.951 0 + """) + mutations = six.StringIO("""\ + site node derived_state + 0 4 1 + 1 0 1 + 2 2 1 + 3 0 1 + 4 1 1 + 5 1 1 + 6 2 1 + 7 0 1 + 8 1 1 + 9 2 1 + """) + ts = tskit.load_text( + nodes=nodes, edges=edges, sites=sites, mutations=mutations, strict=False) + branch_tsc = tskit.BranchLengthStatCalculator(ts) + py_branch_tsc = PythonBranchLengthStatCalculator(ts) + site_tsc = tskit.SiteStatCalculator(ts) + py_site_tsc = PythonSiteStatCalculator(ts) + + # diversity between 0 and 1 + A = [[0], [1]] + n = [len(a) for a in A] + + def f(x): + return float(x[0]*(n[1]-x[1]) + (n[0]-x[0])*x[1])/float(2*n[0]*n[1]) + + # tree lengths: + self.assertAlmostEqual(py_branch_tsc.tree_length_diversity([0], [1]), + branch_true_diversity_01) + self.assertAlmostEqual(branch_tsc.tree_stat(A, f), + branch_true_diversity_01) + self.assertAlmostEqual(py_branch_tsc.tree_stat(A, f), + branch_true_diversity_01) + + # mean diversity between [0, 1] and [0, 2]: + branch_true_mean_diversity = (0 + branch_true_diversity_02 + + branch_true_diversity_01 + + branch_true_diversity_12)/4 + A = [[0, 1], [0, 2]] + n = [len(a) for a in A] + + def f(x): + return float(x[0]*(n[1]-x[1]) + (n[0]-x[0])*x[1])/8.0 + + # tree lengths: + 
self.assertAlmostEqual(py_branch_tsc.tree_length_diversity(A[0], A[1]), + branch_true_mean_diversity) + self.assertAlmostEqual(branch_tsc.tree_stat(A, f), + branch_true_mean_diversity) + self.assertAlmostEqual(py_branch_tsc.tree_stat(A, f), + branch_true_mean_diversity) + + # Y-statistic for (0/12) + A = [[0], [1, 2]] + + def f(x): + return float(((x[0] == 1) and (x[1] == 0)) + or ((x[0] == 0) and (x[1] == 2)))/2.0 + + # tree lengths: + branch_tsc_Y = branch_tsc.Y3([[0], [1], [2]], [0.0, 1.0])[0][0] + py_branch_tsc_Y = py_branch_tsc.Y3([0], [1], [2], 0.0, 1.0) + self.assertAlmostEqual(branch_tsc_Y, branch_true_Y) + self.assertAlmostEqual(py_branch_tsc_Y, branch_true_Y) + self.assertAlmostEqual(branch_tsc.tree_stat(A, f), branch_true_Y) + self.assertAlmostEqual(py_branch_tsc.tree_stat(A, f), branch_true_Y) + + # sites, Y: + site_tsc_Y = site_tsc.Y3([[0], [1], [2]], [0.0, 1.0])[0][0] + py_site_tsc_Y = py_site_tsc.Y3([0], [1], [2], 0.0, 1.0) + self.assertAlmostEqual(site_tsc_Y, site_true_Y) + self.assertAlmostEqual(py_site_tsc_Y, site_true_Y) + self.assertAlmostEqual(site_tsc.tree_stat(A, f), site_true_Y) + self.assertAlmostEqual(py_site_tsc.tree_stat(A, f), site_true_Y) + + def test_case_odds_and_ends(self): + # Tests having (a) the first site after the first window, and + # (b) no samples having the ancestral state. 
+ nodes = six.StringIO("""\ + id is_sample time + 0 1 0 + 1 1 0 + 2 0 0.5 + 3 0 1.0 + """) + edges = six.StringIO("""\ + left right parent child + 0.0 0.5 2 0,1 + 0.5 1.0 3 0,1 + """) + sites = six.StringIO("""\ + id position ancestral_state + 0 0.65 0 + """) + mutations = six.StringIO("""\ + site node derived_state parent + 0 0 1 -1 + 0 1 2 -1 + """) + ts = tskit.load_text( + nodes=nodes, edges=edges, sites=sites, mutations=mutations, strict=False) + site_tsc = tskit.SiteStatCalculator(ts) + py_site_tsc = PythonSiteStatCalculator(ts) + + # recall that divergence returns the upper triangle + # with nans on the diag in this case + py_div = [[np.nan, py_site_tsc.divergence([0], [1], 0.0, 0.5), np.nan], + [np.nan, py_site_tsc.divergence([0], [1], 0.5, 1.0), np.nan]] + div = site_tsc.divergence([[0], [1]], [0.0, 0.5, 1.0]) + self.assertListEqual(py_div[0], div[0]) + self.assertListEqual(py_div[1], div[1]) + + def test_case_recurrent_muts(self): + # With mutations: + # + # 1.0 6 + # 0.7 / \ 5 + # (0) \ /(6) + # 0.5 (1) 4 4 / 4 + # / / \ / \ / (7|8) + # 0.4 (2) (3) \ (4) 3 / / \ + # / / \ / /(5) / / \ + # 0.0 0 1 2 1 0 2 0 1 2 + # (0.0, 0.2), (0.2, 0.8), (0.8, 1.0) + # genotypes: + # 0 2 0 1 0 1 0 2 3 + site_true_Y = 0 + 1 + 1 + + nodes = six.StringIO("""\ + id is_sample time + 0 1 0 + 1 1 0 + 2 1 0 + 3 0 0.4 + 4 0 0.5 + 5 0 0.7 + 6 0 1.0 + """) + edges = six.StringIO("""\ + left right parent child + 0.2 0.8 3 0,2 + 0.0 0.2 4 1,2 + 0.2 0.8 4 1,3 + 0.8 1.0 4 1,2 + 0.8 1.0 5 0,4 + 0.0 0.2 6 0,4 + """) + sites = six.StringIO("""\ + id position ancestral_state + 0 0.05 0 + 1 0.3 0 + 2 0.9 0 + """) + mutations = six.StringIO("""\ + site node derived_state parent + 0 0 1 -1 + 0 0 2 0 + 0 0 0 1 + 0 1 2 -1 + 1 1 1 -1 + 1 2 1 -1 + 2 4 1 -1 + 2 1 2 6 + 2 2 3 6 + """) + ts = tskit.load_text( + nodes=nodes, edges=edges, sites=sites, mutations=mutations, strict=False) + site_tsc = tskit.SiteStatCalculator(ts) + py_site_tsc = PythonSiteStatCalculator(ts) + + # Y3: + site_tsc_Y = 
site_tsc.Y3([[0], [1], [2]], [0.0, 1.0])[0][0] + py_site_tsc_Y = py_site_tsc.Y3([0], [1], [2], 0.0, 1.0) + self.assertAlmostEqual(site_tsc_Y, site_true_Y) + self.assertAlmostEqual(py_site_tsc_Y, site_true_Y) + + def test_case_2(self): + # Here are the trees: + # t | | | | + # + # 0 --3-- | --3-- | --3-- | --3-- | --3-- + # / | \ | / | \ | / \ | / \ | / \ + # 1 4 | 5 | 4 | 5 | 4 5 | 4 5 | 4 5 + # |\ / \ /| | |\ \ | |\ / | |\ / | |\ /| + # 2 | 6 7 | | | 6 7 | | 6 7 | | 6 7 | | 6 7 | + # | |\ /| | | * \ | | | \ | | | \ | | \ | ... + # 3 | | 8 | | | | 8 * | | 8 | | | 8 | | 8 | + # | |/ \| | | | / | | | / | | | / \ | | / \ | + # 4 | 9 10 | | * 9 10 | | 9 10 | | 9 10 | | 9 10 | + # |/ \ / \| | | \ \ | | \ \ | | \ \ | | \ | + # 5 0 1 2 | 0 1 2 | 0 1 2 | 0 1 2 | 0 1 2 + # + # | 0.0 - 0.1 | 0.1 - 0.2 | 0.2 - 0.4 | 0.4 - 0.5 + # ... continued: + # t | | | | + # + # 0 --3-- | --3-- | --3-- | --3-- | --3-- + # / \ | / \ | / \ | / \ | / | \ + # 1 4 5 | 4 5 | 4 5 | 4 5 | 4 | 5 + # |\ /| | \ /| | \ /| | \ /| | / /| + # 2 | 6 7 | | 6 7 | | 6 7 | | 6 7 | | 6 7 | + # | * * | \ | | * | | | / | | | / | + # 3 ... | 8 | | 8 | | 8 | | | 8 | | | 8 | + # | / \ | | / \ | | * \ | | | \ | | | \ | + # 4 | 9 10 | | 9 10 | | 9 10 | | 9 10 | | 9 10 | + # | / | | / / | | / / | | / / | | / / | + # 5 0 1 2 | 0 1 2 | 0 1 2 | 0 1 2 | 0 1 2 + # + # 0.5 - 0.6 | 0.6 - 0.7 | 0.7 - 0.8 | 0.8 - 0.9 | 0.9 - 1.0 + # + # Above, subsequent mutations are backmutations. 
+ + # divergence betw 0 and 1 + branch_true_diversity_01 = 2*(0.6*4 + 0.2*2 + 0.2*5) + # divergence betw 1 and 2 + branch_true_diversity_12 = 2*(0.2*5 + 0.2*2 + 0.3*5 + 0.3*4) + # divergence betw 0 and 2 + branch_true_diversity_02 = 2*(0.2*5 + 0.2*4 + 0.3*5 + 0.1*4 + 0.2*5) + # mean divergence between 0, 1 and 0, 2 + branch_true_mean_diversity = ( + 0 + branch_true_diversity_02 + branch_true_diversity_01 + + branch_true_diversity_12) / 4 + # Y(0;1, 2) + branch_true_Y = 0.2*4 + 0.2*(4+2) + 0.2*4 + 0.2*2 + 0.2*(5+1) + + # site stats + # Y(0;1, 2) + site_true_Y = 1 + + nodes = six.StringIO("""\ + is_sample time population + 1 0.000000 0 + 1 0.000000 0 + 1 0.000000 0 + 0 5.000000 0 + 0 4.000000 0 + 0 4.000000 0 + 0 3.000000 0 + 0 3.000000 0 + 0 2.000000 0 + 0 1.000000 0 + 0 1.000000 0 + """) + edges = six.StringIO("""\ + left right parent child + 0.500000 1.000000 10 1 + 0.000000 0.400000 10 2 + 0.600000 1.000000 9 0 + 0.000000 0.500000 9 1 + 0.800000 1.000000 8 10 + 0.200000 0.800000 8 9,10 + 0.000000 0.200000 8 9 + 0.700000 1.000000 7 8 + 0.000000 0.200000 7 10 + 0.800000 1.000000 6 9 + 0.000000 0.700000 6 8 + 0.400000 1.000000 5 2,7 + 0.100000 0.400000 5 7 + 0.600000 0.900000 4 6 + 0.000000 0.600000 4 0,6 + 0.900000 1.000000 3 4,5,6 + 0.100000 0.900000 3 4,5 + 0.000000 0.100000 3 4,5,7 + """) + sites = six.StringIO("""\ + id position ancestral_state + 0 0.0 0 + 1 0.55 0 + 2 0.75 0 + 3 0.85 0 + """) + mutations = six.StringIO("""\ + site node derived_state parent + 0 0 1 -1 + 0 10 1 -1 + 0 0 0 0 + 1 8 1 -1 + 1 2 1 -1 + 2 8 1 -1 + 2 9 0 5 + """) + ts = tskit.load_text( + nodes=nodes, edges=edges, sites=sites, mutations=mutations, strict=False) + branch_tsc = tskit.BranchLengthStatCalculator(ts) + py_branch_tsc = PythonBranchLengthStatCalculator(ts) + site_tsc = tskit.SiteStatCalculator(ts) + py_site_tsc = PythonSiteStatCalculator(ts) + + # divergence between 0 and 1 + A = [[0], [1]] + + def f(x): + return float((x[0] > 0) != (x[1] > 0))/2.0 + + # tree lengths: + 
self.assertAlmostEqual(py_branch_tsc.tree_length_diversity([0], [1]), + branch_true_diversity_01) + self.assertAlmostEqual(branch_tsc.tree_stat(A, f), + branch_true_diversity_01) + self.assertAlmostEqual(py_branch_tsc.tree_stat(A, f), + branch_true_diversity_01) + + # mean divergence between 0, 1 and 0, 2 + A = [[0, 1], [0, 2]] + n = [len(a) for a in A] + + def f(x): + return float(x[0]*(n[1]-x[1]) + (n[0]-x[0])*x[1])/8.0 + + # tree lengths: + self.assertAlmostEqual(py_branch_tsc.tree_length_diversity(A[0], A[1]), + branch_true_mean_diversity) + self.assertAlmostEqual(branch_tsc.tree_stat(A, f), + branch_true_mean_diversity) + self.assertAlmostEqual(py_branch_tsc.tree_stat(A, f), + branch_true_mean_diversity) + + # Y-statistic for (0/12) + A = [[0], [1, 2]] + + def f(x): + return float(((x[0] == 1) and (x[1] == 0)) + or ((x[0] == 0) and (x[1] == 2)))/2.0 + + # tree lengths: + self.assertAlmostEqual(py_branch_tsc.Y3([0], [1], [2]), branch_true_Y) + self.assertAlmostEqual(branch_tsc.tree_stat(A, f), branch_true_Y) + self.assertAlmostEqual(py_branch_tsc.tree_stat(A, f), branch_true_Y) + + # sites: + site_tsc_Y = site_tsc.Y3([[0], [1], [2]], [0.0, 1.0])[0][0] + py_site_tsc_Y = py_site_tsc.Y3([0], [1], [2], 0.0, 1.0) + self.assertAlmostEqual(site_tsc_Y, site_true_Y) + self.assertAlmostEqual(py_site_tsc_Y, site_true_Y) + self.assertAlmostEqual(site_tsc.tree_stat(A, f), site_true_Y) + self.assertAlmostEqual(py_site_tsc.tree_stat(A, f), site_true_Y) + + def test_small_sim(self): + orig_ts = msprime.simulate(4, random_seed=self.random_seed, + mutation_rate=0.0, + recombination_rate=3.0) + ts = tsutil.jukes_cantor(orig_ts, num_sites=3, mu=3, + multiple_per_node=True, seed=self.seed) + branch_tsc = tskit.BranchLengthStatCalculator(ts) + py_branch_tsc = PythonBranchLengthStatCalculator(ts) + site_tsc = tskit.SiteStatCalculator(ts) + py_site_tsc = PythonSiteStatCalculator(ts) + + A = [[0], [1], [2]] + self.assertAlmostEqual(branch_tsc.Y3(A, [0.0, 1.0])[0][0], + 
py_branch_tsc.Y3(*A)) + self.assertAlmostEqual(site_tsc.Y3(A, [0.0, 1.0])[0][0], + py_site_tsc.Y3(*A)) + + A = [[0], [1, 2]] + self.assertAlmostEqual(branch_tsc.Y2(A, [0.0, 1.0])[0][0], + py_branch_tsc.Y2(*A)) + self.assertAlmostEqual(site_tsc.Y2(A, [0.0, 1.0])[0][0], + py_site_tsc.Y2(*A)) + + +class BranchLengthStatsTestCase(GeneralStatsTestCase): + """ + Tests of tree statistic computation. + """ + stat_class = tskit.BranchLengthStatCalculator + py_stat_class = PythonBranchLengthStatCalculator + + def get_ts(self): + for N in [12, 15, 20]: + yield msprime.simulate(N, random_seed=self.random_seed, + recombination_rate=10) + + def check_pairwise_diversity(self, ts): + samples = random.sample(list(ts.samples()), 2) + tsc = tskit.BranchLengthStatCalculator(ts) + py_tsc = PythonBranchLengthStatCalculator(ts) + A_one = [[samples[0]], [samples[1]]] + A_many = [random.sample(list(ts.samples()), 2), + random.sample(list(ts.samples()), 2)] + for A in (A_one, A_many): + n = [len(a) for a in A] + + def f(x): + return float(x[0]*(n[1]-x[1]))/float(n[0]*n[1]) + + self.assertAlmostEqual( + py_tsc.tree_stat(A, f), + py_tsc.tree_length_diversity(A[0], A[1])) + self.assertAlmostEqual( + tsc.tree_stat(A, f), + py_tsc.tree_length_diversity(A[0], A[1])) + + def check_divergence_matrix(self, ts): + # nonoverlapping samples + samples = random.sample(list(ts.samples()), 6) + tsc = tskit.BranchLengthStatCalculator(ts) + py_tsc = PythonBranchLengthStatCalculator(ts) + A = [samples[0:3], samples[3:5], samples[5:6]] + windows = [0.0, ts.sequence_length/2, ts.sequence_length] + ts_values = tsc.divergence(A, windows) + ts_matrix_values = tsc.divergence_matrix(A, windows) + self.assertListEqual([len(x) for x in ts_values], [len(samples), len(samples)]) + assert(len(A[2]) == 1) + self.assertListEqual([x[5] for x in ts_values], [np.nan, np.nan]) + self.assertEqual(len(ts_values), len(ts_matrix_values)) + for w in range(len(ts_values)): + self.assertArrayEqual( + ts_matrix_values[w, :, :], 
upper_tri_to_matrix(ts_values[w])) + here_values = np.array([[[py_tsc.tree_length_diversity(A[i], A[j], + begin=windows[k], + end=windows[k+1]) + for i in range(len(A))] + for j in range(len(A))] + for k in range(len(windows)-1)]) + for k in range(len(windows)-1): + for i in range(len(A)): + for j in range(len(A)): + if i == j: + if len(A[i]) == 1: + here_values[k, i, i] = np.nan + else: + here_values[k, i, i] /= (len(A[i])-1)/len(A[i]) + else: + here_values[k, j, i] + for k in range(len(windows)-1): + self.assertArrayAlmostEqual(here_values[k], ts_matrix_values[k]) + + def test_errors(self): + ts = msprime.simulate(10, random_seed=self.random_seed, recombination_rate=10) + tsc = tskit.BranchLengthStatCalculator(ts) + self.assertRaises(ValueError, + tsc.divergence, [[0], [11]], [0, ts.sequence_length]) + self.assertRaises(ValueError, + tsc.divergence, [[0], [1]], [0, ts.sequence_length/2]) + self.assertRaises(ValueError, + tsc.divergence, [[0], [1]], [ts.sequence_length/2, + ts.sequence_length]) + self.assertRaises(ValueError, + tsc.divergence, [[0], [1]], [0.0, 2.0, 1.0, + ts.sequence_length]) + # errors for not enough sample_sets + self.assertRaises(ValueError, + tsc.f4, [[0, 1], [2], [3]], [0, ts.sequence_length]) + self.assertRaises(ValueError, + tsc.f3, [[0], [2]], [0, ts.sequence_length]) + self.assertRaises(ValueError, + tsc.f2, [[0], [1], [2]], [0, ts.sequence_length]) + # errors if indices aren't of the right length + self.assertRaises(ValueError, + tsc.Y3_vector, [[0], [1], [2]], [0, ts.sequence_length], + [[0, 1]]) + self.assertRaises(ValueError, + tsc.f4_vector, [[0], [1], [2], [3]], [0, ts.sequence_length], + [[0, 1]]) + self.assertRaises(ValueError, + tsc.f3_vector, [[0], [1], [2], [3]], [0, ts.sequence_length], + [[0, 1]]) + self.assertRaises(ValueError, + tsc.f2_vector, [[0], [1], [2], [3]], [0, ts.sequence_length], + [[0, 1, 2]]) + + def test_windowization(self): + ts = msprime.simulate(10, random_seed=self.random_seed, recombination_rate=100) + 
samples = random.sample(list(ts.samples()), 2) + tsc = tskit.BranchLengthStatCalculator(ts) + py_tsc = PythonBranchLengthStatCalculator(ts) + A_one = [[samples[0]], [samples[1]]] + A_many = [random.sample(list(ts.samples()), 2), + random.sample(list(ts.samples()), 2)] + some_breaks = list(set([0.0, ts.sequence_length/2, ts.sequence_length] + + random.sample(list(ts.breakpoints()), 5))) + some_breaks.sort() + tiny_breaks = ([(k / 4) * list(ts.breakpoints())[1] for k in range(4)] + + [ts.sequence_length]) + wins = [[0.0, ts.sequence_length], + [0.0, ts.sequence_length/2, ts.sequence_length], + tiny_breaks, + some_breaks] + + with self.assertRaises(ValueError): + tsc.tree_stat_vector(A_one, lambda x: 1.0, + windows=[0.0, 1.0, ts.sequence_length+1.1]) + + for A in (A_one, A_many): + for windows in wins: + n = [len(a) for a in A] + + def f(x): + return float(x[0]*(n[1]-x[1]) + (n[0]-x[0])*x[1])/float(2*n[0]*n[1]) + + def g(x): + return [f(x)] + + tsdiv_v = tsc.tree_stat_vector(A, g, windows) + tsdiv_vx = [x[0] for x in tsdiv_v] + tsdiv = tsc.tree_stat_windowed(A, f, windows) + pydiv = [py_tsc.tree_length_diversity(A[0], A[1], windows[k], + windows[k+1]) + for k in range(len(windows)-1)] + self.assertEqual(len(tsdiv), len(windows)-1) + self.assertListAlmostEqual(tsdiv, pydiv) + self.assertListEqual(tsdiv, tsdiv_vx) + + def test_tree_stat_vector_interface(self): + ts = msprime.simulate(10) + tsc = tskit.BranchLengthStatCalculator(ts) + + def f(x): + return [1.0] + + # Duplicated samples raise an error + self.assertRaises(ValueError, tsc.tree_stat_vector, [[1, 1]], f) + self.assertRaises(ValueError, tsc.tree_stat_vector, [[1], [2, 2]], f) + # Make sure the basic call doesn't throw an exception + tsc.tree_stat_vector([[1, 2]], f) + # Check for bad windows + for bad_start in [-1, 1, 1e-7]: + self.assertRaises( + ValueError, tsc.tree_stat_vector, [[1, 2]], f, + [bad_start, ts.sequence_length]) + for bad_end in [0, ts.sequence_length - 1, ts.sequence_length + 1]: + 
self.assertRaises( + ValueError, tsc.tree_stat_vector, [[1, 2]], f, + [0, bad_end]) + # Windows must be increasing. + self.assertRaises( + ValueError, tsc.tree_stat_vector, [[1, 2]], f, [0, 1, 1]) + + def test_sfs_interface(self): + ts = msprime.simulate(10) + tsc = tskit.BranchLengthStatCalculator(ts) + + # Duplicated samples raise an error + self.assertRaises(ValueError, tsc.site_frequency_spectrum, [1, 1]) + self.assertRaises(ValueError, tsc.site_frequency_spectrum, []) + self.assertRaises(ValueError, tsc.site_frequency_spectrum, [0, 11]) + # Check for bad windows + for bad_start in [-1, 1, 1e-7]: + self.assertRaises( + ValueError, tsc.site_frequency_spectrum, [1, 2], + [bad_start, ts.sequence_length]) + for bad_end in [0, ts.sequence_length - 1, ts.sequence_length + 1]: + self.assertRaises( + ValueError, tsc.site_frequency_spectrum, [1, 2], + [0, bad_end]) + # Windows must be increasing. + self.assertRaises( + ValueError, tsc.site_frequency_spectrum, [1, 2], [0, 1, 1]) + + def test_branch_general_stats(self): + for ts in self.get_ts(): + self.check_tree_stat_vector(ts) + + def test_branch_f_stats(self): + for ts in self.get_ts(): + self.check_f_stats(ts) + + def test_branch_Y_stats(self): + for ts in self.get_ts(): + self.check_Y_stat(ts) + + def test_diversity(self): + for ts in self.get_ts(): + self.check_pairwise_diversity(ts) + self.check_divergence_matrix(ts) + + def test_branch_sfs(self): + for ts in self.get_ts(): + self.check_sfs(ts) + + +class SiteStatsTestCase(GeneralStatsTestCase): + """ + Tests of site statistic computation. 
+ """ + stat_class = tskit.SiteStatCalculator + py_stat_class = PythonSiteStatCalculator + seed = 23 + + def get_ts(self): + for mut in [0.0, 3.0]: + yield msprime.simulate(20, random_seed=self.random_seed, + mutation_rate=mut, + recombination_rate=3.0) + ts = msprime.simulate(20, random_seed=self.random_seed, + mutation_rate=0.0, + recombination_rate=3.0) + for mpn in [False, True]: + for num_sites in [10, 100]: + mut_ts = tsutil.jukes_cantor(ts, num_sites=num_sites, mu=3, + multiple_per_node=mpn, seed=self.seed) + yield mut_ts + + def check_pairwise_diversity_mutations(self, ts): + py_tsc = PythonSiteStatCalculator(ts) + samples = random.sample(list(ts.samples()), 2) + A = [[samples[0]], [samples[1]]] + n = [len(a) for a in A] + + def f(x): + return float(x[0]*(n[1]-x[1]) + (n[0]-x[0])*x[1])/float(2*n[0]*n[1]) + + self.assertAlmostEqual( + py_tsc.tree_stat(A, f), ts.pairwise_diversity(samples=samples)) + + def test_pairwise_diversity(self): + ts = msprime.simulate(20, random_seed=self.random_seed, recombination_rate=100) + self.check_pairwise_diversity_mutations(ts) + + def test_site_general_stats(self): + for ts in self.get_ts(): + self.check_tree_stat_vector(ts) + + def test_site_f_stats(self): + for ts in self.get_ts(): + self.check_f_stats(ts) + + def test_site_Y_stats(self): + for ts in self.get_ts(): + self.check_Y_stat(ts) + + def test_site_sfs(self): + for ts in self.get_ts(): + self.check_sfs(ts) diff --git a/python/tests/test_vcf.py b/python/tests/test_vcf.py new file mode 100644 index 0000000000..b0a201ebfd --- /dev/null +++ b/python/tests/test_vcf.py @@ -0,0 +1,241 @@ +""" +Test cases for VCF output in tskit. +""" +from __future__ import print_function +from __future__ import division + +import collections +import math +import os +import tempfile +import unittest + +import msprime +import vcf +import tskit + +# Pysam is not available on windows, so we don't make it mandatory here. 
+_pysam_imported = False +try: + import pysam + _pysam_imported = True +except ImportError: + pass + + +test_data = [] + + +def setUp(): + Datum = collections.namedtuple( + "Datum", + ["tree_sequence", "ploidy", "contig_id", "vcf_file", "sample_names"]) + L = 100 + for ploidy in [1, 2, 3, 5]: + for contig_id in ["1", "x" * 8]: + for n in [2, 10]: + for rho in [0, 0.5]: + for mu in [0, 1.0]: + ts = msprime.simulate( + n * ploidy, length=L, recombination_rate=rho, + mutation_rate=mu) + fd, file_name = tempfile.mkstemp(prefix="tskit_vcf_") + os.close(fd) + with open(file_name, "w") as f: + ts.write_vcf(f, ploidy, contig_id) + sample_names = ["msp_{}".format(j) for j in range(n)] + test_data.append( + Datum(ts, ploidy, contig_id, file_name, sample_names)) + + +def tearDown(): + for datum in test_data: + os.unlink(datum.vcf_file) + + +def write_vcf(tree_sequence, output, ploidy, contig_id): + """ + Writes a VCF using the sample algorithm as the low level code. + """ + if tree_sequence.get_sample_size() % ploidy != 0: + raise ValueError("Sample size must a multiple of ploidy") + n = tree_sequence.get_sample_size() // ploidy + sample_names = ["msp_{}".format(j) for j in range(n)] + last_pos = 0 + positions = [] + for variant in tree_sequence.variants(): + pos = int(round(variant.position)) + if pos <= last_pos: + pos = last_pos + 1 + positions.append(pos) + last_pos = pos + contig_length = int(math.ceil(tree_sequence.get_sequence_length())) + if len(positions) > 0: + contig_length = max(positions[-1], contig_length) + print("##fileformat=VCFv4.2", file=output) + print("##source=tskit {}".format(tskit.__version__), file=output) + print( + '##FILTER=', + file=output) + print("##contig=".format(contig_id, contig_length), file=output) + print( + '##FORMAT=', + file=output) + print( + "#CHROM", "POS", "ID", "REF", "ALT", "QUAL", "FILTER", "INFO", + "FORMAT", sep="\t", end="", file=output) + for sample_name in sample_names: + print("\t", sample_name, sep="", end="", 
file=output) + print(file=output) + for variant in tree_sequence.variants(): + pos = positions[variant.index] + print( + contig_id, pos, ".", "A", "T", ".", "PASS", ".", "GT", + sep="\t", end="", file=output) + for j in range(n): + genotype = "|".join( + str(g) for g in + variant.genotypes[j * ploidy: j * ploidy + ploidy]) + print("\t", genotype, end="", sep="", file=output) + print(file=output) + + +@unittest.skip("Skipping until version headers sorted out") +class TestEquality(unittest.TestCase): + """ + Tests if the VCF file produced by the low level code is the + same as one we generate here. + """ + def test_equal(self): + for datum in test_data: + with tempfile.TemporaryFile("w+") as f: + write_vcf(datum.tree_sequence, f, datum.ploidy, datum.contig_id) + f.seek(0) + vcf1 = f.read() + with open(datum.vcf_file) as f: + vcf2 = f.read() + self.assertEqual(vcf1, vcf2) + + +class TestHeaderParsers(unittest.TestCase): + """ + Tests if we can parse the headers with various tools. + """ + def test_pyvcf(self): + for datum in test_data: + reader = vcf.Reader(filename=datum.vcf_file) + self.assertEqual(len(reader.contigs), 1) + contig = reader.contigs[datum.contig_id] + self.assertEqual(contig.id, datum.contig_id) + self.assertGreater(contig.length, 0) + self.assertEqual(len(reader.alts), 0) + self.assertEqual(len(reader.filters), 1) + p = reader.filters["PASS"] + self.assertEqual(p.id, "PASS") + self.assertEqual(len(reader.formats), 1) + f = reader.formats["GT"] + self.assertEqual(f.id, "GT") + self.assertEqual(len(reader.infos), 0) + + @unittest.skipIf(not _pysam_imported, "pysam not available") + def test_pysam(self): + for datum in test_data: + bcf_file = pysam.VariantFile(datum.vcf_file) + self.assertEqual(bcf_file.format, "VCF") + self.assertEqual(bcf_file.version, (4, 2)) + header = bcf_file.header + self.assertEqual(len(header.contigs), 1) + contig = header.contigs[0] + self.assertEqual(contig.name, datum.contig_id) + self.assertGreater(contig.length, 0) + 
self.assertEqual(len(header.filters), 1) + p = header.filters["PASS"] + self.assertEqual(p.name, "PASS") + self.assertEqual(p.description, "All filters passed") + self.assertEqual(len(header.info), 0) + self.assertEqual(len(header.formats), 1) + fmt = header.formats["GT"] + self.assertEqual(fmt.name, "GT") + self.assertEqual(fmt.number, 1) + self.assertEqual(fmt.type, "String") + self.assertEqual(fmt.description, "Genotype") + self.assertEqual(len(header.samples), len(datum.sample_names)) + for s1, s2 in zip(header.samples, datum.sample_names): + self.assertEqual(s1, s2) + bcf_file.close() + + +@unittest.skipIf(not _pysam_imported, "pysam not available") +class TestRecordsEqual(unittest.TestCase): + """ + Tests where we parse the input using PyVCF and Pysam + """ + def verify_records(self, datum, pyvcf_records, pysam_records): + self.assertEqual(len(pyvcf_records), len(pysam_records)) + for pyvcf_record, pysam_record in zip(pyvcf_records, pysam_records): + self.assertEqual(pyvcf_record.CHROM, pysam_record.chrom) + self.assertEqual(pyvcf_record.POS, pysam_record.pos) + self.assertEqual(pyvcf_record.ID, pysam_record.id) + self.assertEqual(pyvcf_record.ALT, list(pysam_record.alts)) + self.assertEqual(pyvcf_record.REF, pysam_record.ref) + self.assertEqual(pysam_record.filter[0].name, "PASS") + self.assertEqual(pyvcf_record.FORMAT, "GT") + self.assertEqual( + datum.sample_names, list(pysam_record.samples.keys())) + for value in pysam_record.samples.values(): + self.assertEqual(len(value.alleles), datum.ploidy) + for j, sample in enumerate(pyvcf_record.samples): + self.assertEqual(sample.sample, datum.sample_names[j]) + if datum.ploidy > 1: + self.assertTrue(sample.phased) + for call in sample.data.GT.split("|"): + self.assertIn(call, ["0", "1"]) + + def test_all_records(self): + for datum in test_data: + vcf_reader = vcf.Reader(filename=datum.vcf_file) + bcf_file = pysam.VariantFile(datum.vcf_file) + pyvcf_records = list(vcf_reader) + pysam_records = list(bcf_file) + 
self.verify_records(datum, pyvcf_records, pysam_records) + bcf_file.close() + + +class TestContigLengths(unittest.TestCase): + """ + Tests that we create sensible contig lengths under a variety of conditions. + """ + def setUp(self): + fd, self.temp_file = tempfile.mkstemp(prefix="msprime_vcf_") + os.close(fd) + + def tearDown(self): + os.unlink(self.temp_file) + + def get_contig_length(self, ts): + with open(self.temp_file, "w") as f: + ts.write_vcf(f) + reader = vcf.Reader(filename=self.temp_file) + contig = reader.contigs["1"] + return contig.length + + def test_no_mutations(self): + ts = msprime.simulate(10, length=1) + self.assertEqual(ts.num_mutations, 0) + contig_length = self.get_contig_length(ts) + self.assertEqual(contig_length, 1) + + def test_long_sequence(self): + # Nominal case where we expect the positions to map within the original + # sequence length + ts = msprime.simulate(10, length=100, mutation_rate=0.01, random_seed=3) + self.assertGreater(ts.num_mutations, 0) + contig_length = self.get_contig_length(ts) + self.assertEqual(contig_length, 100) + + def test_short_sequence(self): + # Degenerate case where the positions cannot map into the sequence length + ts = msprime.simulate(10, length=1, mutation_rate=10) + self.assertGreater(ts.num_mutations, 1) + contig_length = self.get_contig_length(ts) + self.assertEqual(contig_length, ts.num_mutations) diff --git a/python/tests/test_wright_fisher.py b/python/tests/test_wright_fisher.py new file mode 100644 index 0000000000..e88f3f9f87 --- /dev/null +++ b/python/tests/test_wright_fisher.py @@ -0,0 +1,414 @@ +""" +Test various functions using messy tables output by a forwards-time simulator. 
+""" +from __future__ import print_function +from __future__ import division + +import itertools +import random +import unittest + +import numpy as np +import numpy.testing as nt +import msprime + +import tskit +import tests as tests +import tests.tsutil as tsutil + + +class WrightFisherSimulator(object): + """ + SIMPLE simulation of a bisexual, haploid Wright-Fisher population of size N + for ngens generations, in which each individual survives with probability + survival and only those who die are replaced. If num_loci is None, + the chromosome is 1.0 Morgans long, and the mutation rate is in units of + mutations/Morgan/generation. If num_loci not None, a discrete recombination + model is used where breakpoints are chosen uniformly from 1 to num_loci - 1. + """ + def __init__( + self, N, survival=0.0, seed=None, deep_history=True, debug=False, + initial_generation_samples=False, num_loci=None): + self.N = N + self.num_loci = num_loci + self.survival = survival + self.deep_history = deep_history + self.debug = debug + self.initial_generation_samples = initial_generation_samples + self.seed = seed + self.rng = random.Random(seed) + + def random_breakpoint(self): + if self.num_loci is None: + return min(1.0, max(0.0, 2 * self.rng.random() - 0.5)) + else: + return self.rng.randint(1, self.num_loci - 1) + + def run(self, ngens): + L = 1 + if self.num_loci is not None: + L = self.num_loci + tables = tskit.TableCollection(sequence_length=L) + tables.populations.add_row() + if self.deep_history: + # initial population + init_ts = msprime.simulate( + self.N, recombination_rate=1.0, length=L, random_seed=self.seed) + init_tables = init_ts.dump_tables() + flags = init_tables.nodes.flags + if not self.initial_generation_samples: + flags = np.zeros_like(init_tables.nodes.flags) + tables.nodes.set_columns( + time=init_tables.nodes.time + ngens, + flags=flags) + tables.edges.set_columns( + left=init_tables.edges.left, right=init_tables.edges.right, + 
parent=init_tables.edges.parent, child=init_tables.edges.child) + else: + flags = 0 + if self.initial_generation_samples: + flags = tskit.NODE_IS_SAMPLE + for _ in range(self.N): + tables.nodes.add_row(flags=flags, time=ngens, population=0) + + pop = list(range(self.N)) + for t in range(ngens - 1, -1, -1): + if self.debug: + print("t:", t) + print("pop:", pop) + + dead = [self.rng.random() > self.survival for k in pop] + # sample these first so that all parents are from the previous gen + new_parents = [ + (self.rng.choice(pop), self.rng.choice(pop)) for k in range(sum(dead))] + k = 0 + if self.debug: + print("Replacing", sum(dead), "individuals.") + for j in range(self.N): + if dead[j]: + # this is: offspring ID, lparent, rparent, breakpoint + offspring = len(tables.nodes) + tables.nodes.add_row(time=t, population=0) + lparent, rparent = new_parents[k] + k += 1 + bp = self.random_breakpoint() + if self.debug: + print("--->", offspring, lparent, rparent, bp) + pop[j] = offspring + if bp > 0.0: + tables.edges.add_row( + left=0.0, right=bp, parent=lparent, child=offspring) + if bp < L: + tables.edges.add_row( + left=bp, right=L, parent=rparent, child=offspring) + + if self.debug: + print("Done! Final pop:") + print(pop) + flags = tables.nodes.flags + flags[pop] = tskit.NODE_IS_SAMPLE + tables.nodes.set_columns( + flags=flags, + time=tables.nodes.time, + population=tables.nodes.population) + return tables + + +def wf_sim( + N, ngens, survival=0.0, deep_history=True, debug=False, seed=None, + initial_generation_samples=False, num_loci=None): + sim = WrightFisherSimulator( + N, survival=survival, deep_history=deep_history, debug=debug, seed=seed, + initial_generation_samples=initial_generation_samples, num_loci=num_loci) + return sim.run(ngens) + + +class TestSimulation(unittest.TestCase): + """ + Tests that the simulations produce the output we expect. 
+ """ + random_seed = 5678 + + def test_non_overlapping_generations(self): + tables = wf_sim(N=10, ngens=10, survival=0.0, seed=self.random_seed) + self.assertGreater(tables.nodes.num_rows, 0) + self.assertGreater(tables.edges.num_rows, 0) + self.assertEqual(tables.sites.num_rows, 0) + self.assertEqual(tables.mutations.num_rows, 0) + self.assertEqual(tables.migrations.num_rows, 0) + tables.sort() + tables.simplify() + ts = tables.tree_sequence() + # All trees should have exactly one root and all internal nodes should + # have arity > 1 + for tree in ts.trees(): + self.assertEqual(tree.num_roots, 1) + leaves = set(tree.leaves(tree.root)) + self.assertEqual(leaves, set(ts.samples())) + for u in tree.nodes(): + if tree.is_internal(u): + self.assertGreater(len(tree.children(u)), 1) + + def test_overlapping_generations(self): + tables = wf_sim(N=30, ngens=10, survival=0.85, seed=self.random_seed) + self.assertGreater(tables.nodes.num_rows, 0) + self.assertGreater(tables.edges.num_rows, 0) + self.assertEqual(tables.sites.num_rows, 0) + self.assertEqual(tables.mutations.num_rows, 0) + self.assertEqual(tables.migrations.num_rows, 0) + tables.sort() + tables.simplify() + ts = tables.tree_sequence() + for tree in ts.trees(): + self.assertEqual(tree.num_roots, 1) + + def test_one_generation_no_deep_history(self): + N = 20 + tables = wf_sim(N=N, ngens=1, deep_history=False, seed=self.random_seed) + self.assertEqual(tables.nodes.num_rows, 2 * N) + self.assertGreater(tables.edges.num_rows, 0) + self.assertEqual(tables.sites.num_rows, 0) + self.assertEqual(tables.mutations.num_rows, 0) + self.assertEqual(tables.migrations.num_rows, 0) + tables.sort() + tables.simplify() + ts = tables.tree_sequence() + self.assertGreater(tables.nodes.num_rows, 0) + self.assertGreater(tables.edges.num_rows, 0) + ts = tables.tree_sequence() + for tree in ts.trees(): + all_samples = set() + for root in tree.roots: + root_samples = set(tree.samples(root)) + self.assertEqual(len(root_samples & 
all_samples), 0) + all_samples |= root_samples + self.assertEqual(all_samples, set(ts.samples())) + + def test_many_generations_no_deep_history(self): + N = 10 + ngens = 100 + tables = wf_sim(N=N, ngens=ngens, deep_history=False, seed=self.random_seed) + self.assertEqual(tables.nodes.num_rows, N * (ngens + 1)) + self.assertGreater(tables.edges.num_rows, 0) + self.assertEqual(tables.sites.num_rows, 0) + self.assertEqual(tables.mutations.num_rows, 0) + self.assertEqual(tables.migrations.num_rows, 0) + tables.sort() + tables.simplify() + ts = tables.tree_sequence() + self.assertGreater(tables.nodes.num_rows, 0) + self.assertGreater(tables.edges.num_rows, 0) + ts = tables.tree_sequence() + # We are assuming that everything has coalesced and we have single-root trees + for tree in ts.trees(): + self.assertEqual(tree.num_roots, 1) + + def test_with_mutations(self): + N = 10 + ngens = 100 + tables = wf_sim(N=N, ngens=ngens, deep_history=False, seed=self.random_seed) + tables.sort() + ts = tables.tree_sequence() + ts = tsutil.jukes_cantor(ts, 10, 0.1, seed=self.random_seed) + tables = ts.tables + self.assertGreater(tables.sites.num_rows, 0) + self.assertGreater(tables.mutations.num_rows, 0) + samples = np.where( + tables.nodes.flags == tskit.NODE_IS_SAMPLE)[0].astype(np.int32) + tables.sort() + tables.simplify(samples) + self.assertGreater(tables.nodes.num_rows, 0) + self.assertGreater(tables.edges.num_rows, 0) + self.assertGreater(tables.nodes.num_rows, 0) + self.assertGreater(tables.edges.num_rows, 0) + self.assertGreater(tables.sites.num_rows, 0) + self.assertGreater(tables.mutations.num_rows, 0) + ts = tables.tree_sequence() + self.assertEqual(ts.sample_size, N) + for hap in ts.haplotypes(): + self.assertEqual(len(hap), ts.num_sites) + + def test_with_recurrent_mutations(self): + # actually with only ONE site, at 0.0 + N = 10 + ngens = 100 + tables = wf_sim(N=N, ngens=ngens, deep_history=False, seed=self.random_seed) + tables.sort() + ts = tables.tree_sequence() + ts = 
tsutil.jukes_cantor(ts, 1, 10, seed=self.random_seed) + tables = ts.tables + self.assertEqual(tables.sites.num_rows, 1) + self.assertGreater(tables.mutations.num_rows, 0) + # before simplify + for h in ts.haplotypes(): + self.assertEqual(len(h), 1) + # after simplify + tables.sort() + tables.simplify() + self.assertGreater(tables.nodes.num_rows, 0) + self.assertGreater(tables.edges.num_rows, 0) + self.assertEqual(tables.sites.num_rows, 1) + self.assertGreater(tables.mutations.num_rows, 0) + ts = tables.tree_sequence() + self.assertEqual(ts.sample_size, N) + for hap in ts.haplotypes(): + self.assertEqual(len(hap), ts.num_sites) + + +class TestIncrementalBuild(unittest.TestCase): + """ + Tests for incrementally building a tree sequence from forward time + simulations. + """ + + +class TestSimplify(unittest.TestCase): + """ + Tests for simplify on cases generated by the Wright-Fisher simulator. + """ + def assertArrayEqual(self, x, y): + nt.assert_equal(x, y) + + def assertTreeSequencesEqual(self, ts1, ts2): + self.assertEqual(list(ts1.samples()), list(ts2.samples())) + self.assertEqual(ts1.sequence_length, ts2.sequence_length) + ts1_tables = ts1.dump_tables() + ts2_tables = ts2.dump_tables() + # print("compare") + # print(ts1_tables.nodes) + # print(ts2_tables.nodes) + self.assertEqual(ts1_tables.nodes, ts2_tables.nodes) + self.assertEqual(ts1_tables.edges, ts2_tables.edges) + self.assertEqual(ts1_tables.sites, ts2_tables.sites) + self.assertEqual(ts1_tables.mutations, ts2_tables.mutations) + + def get_wf_sims(self, seed): + """ + Returns an iterator of example tree sequences produced by the WF simulator. 
+ """ + for N in [5, 10, 20]: + for surv in [0.0, 0.5, 0.9]: + for mut in [0.01, 1.0]: + for nloci in [1, 2, 3]: + tables = wf_sim(N=N, ngens=N, survival=surv, seed=seed) + tables.sort() + ts = tables.tree_sequence() + ts = tsutil.jukes_cantor(ts, num_sites=nloci, mu=mut, seed=seed) + self.verify_simulation(ts, ngens=N) + yield ts + + def verify_simulation(self, ts, ngens): + """ + Verify that in the full set of returned tables there is parentage + information for every individual, except those initially present. + """ + tables = ts.dump_tables() + for u in range(tables.nodes.num_rows): + if tables.nodes.time[u] <= ngens: + lefts = [] + rights = [] + k = 0 + for edge in ts.edges(): + if u == edge.child: + lefts.append(edge.left) + rights.append(edge.right) + k += 1 + lefts.sort() + rights.sort() + self.assertEqual(lefts[0], 0.0) + self.assertEqual(rights[-1], 1.0) + for k in range(len(lefts) - 1): + self.assertEqual(lefts[k + 1], rights[k]) + + def verify_simplify(self, ts, new_ts, samples, node_map): + """ + Check that trees in `ts` match `new_ts` using the specified node_map. + Modified from `verify_simplify_topology`. Also check that the `parent` + column in the MutationTable is correct. 
+ """ + # check trees agree at these points + locs = [random.random() for _ in range(20)] + locs += random.sample(list(ts.breakpoints())[:-1], min(20, ts.num_trees)) + locs.sort() + old_trees = ts.trees() + new_trees = new_ts.trees() + old_right = -1 + new_right = -1 + for loc in locs: + while old_right <= loc: + old_tree = next(old_trees) + old_left, old_right = old_tree.get_interval() + assert old_left <= loc < old_right + while new_right <= loc: + new_tree = next(new_trees) + new_left, new_right = new_tree.get_interval() + assert new_left <= loc < new_right + # print("comparing trees") + # print("interval:", old_tree.interval) + # print(old_tree.draw(format="unicode")) + # print("interval:", new_tree.interval) + # print(new_tree.draw(format="unicode")) + pairs = itertools.islice(itertools.combinations(samples, 2), 500) + for pair in pairs: + mapped_pair = [node_map[u] for u in pair] + mrca1 = old_tree.get_mrca(*pair) + self.assertNotEqual(mrca1, tskit.NULL) + mrca2 = new_tree.get_mrca(*mapped_pair) + self.assertNotEqual(mrca2, tskit.NULL) + self.assertEqual(node_map[mrca1], mrca2) + mut_parent = tsutil.compute_mutation_parent(ts=ts) + self.assertArrayEqual(mut_parent, ts.tables.mutations.parent) + + def verify_haplotypes(self, ts, samples): + """ + Check that haplotypes are unchanged by simplify. 
+ """ + sub_ts, node_map = ts.simplify( + samples, map_nodes=True, filter_zero_mutation_sites=False) + # Sites tables should be equal + self.assertEqual(ts.tables.sites, sub_ts.tables.sites) + sub_haplotypes = dict(zip(sub_ts.samples(), sub_ts.haplotypes())) + all_haplotypes = dict(zip(ts.samples(), ts.haplotypes())) + mapped_ids = [] + for node_id, h in all_haplotypes.items(): + mapped_node_id = node_map[node_id] + if mapped_node_id in sub_haplotypes: + self.assertEqual(h, sub_haplotypes[mapped_node_id]) + mapped_ids.append(mapped_node_id) + self.assertEqual(sorted(mapped_ids), sorted(sub_ts.samples())) + + def test_simplify(self): + # check that simplify(big set) -> simplify(subset) equals simplify(subset) + seed = 23 + random.seed(seed) + for ts in self.get_wf_sims(seed=seed): + s = tests.Simplifier(ts, ts.samples()) + py_full_ts, py_full_map = s.simplify() + full_ts, full_map = ts.simplify(ts.samples(), map_nodes=True) + self.assertTrue(all(py_full_map == full_map)) + self.assertTreeSequencesEqual(full_ts, py_full_ts) + + for nsamples in [2, 5, 10]: + sub_samples = random.sample( + list(ts.samples()), min(nsamples, ts.sample_size)) + s = tests.Simplifier(ts, sub_samples) + py_small_ts, py_small_map = s.simplify() + small_ts, small_map = ts.simplify(samples=sub_samples, map_nodes=True) + self.assertTreeSequencesEqual(small_ts, py_small_ts) + self.verify_simplify(ts, small_ts, sub_samples, small_map) + self.verify_haplotypes(ts, samples=sub_samples) + + def test_simplify_tables(self): + seed = 71 + for ts in self.get_wf_sims(seed=seed): + for nsamples in [2, 5, 10]: + tables = ts.dump_tables() + sub_samples = random.sample( + list(ts.samples()), min(nsamples, ts.num_samples)) + node_map = tables.simplify(samples=sub_samples) + small_ts = tables.tree_sequence() + other_tables = small_ts.dump_tables() + tables.provenances.clear() + other_tables.provenances.clear() + self.assertEqual(tables, other_tables) + self.verify_simplify(ts, small_ts, sub_samples, node_map) 
diff --git a/python/tests/tsutil.py b/python/tests/tsutil.py new file mode 100644 index 0000000000..df021b66bc --- /dev/null +++ b/python/tests/tsutil.py @@ -0,0 +1,715 @@ +""" +A collection of utilities to edit and construct tree sequences. +""" +from __future__ import print_function +from __future__ import unicode_literals +from __future__ import division + +import json +import random + +import numpy as np + +import tskit.provenance as provenance +import tskit + + +def add_provenance(provenance_table, method_name): + d = provenance.get_provenance_dict({"command": "tsutil.{}".format(method_name)}) + provenance_table.add_row(json.dumps(d)) + + +def subsample_sites(ts, num_sites): + """ + Returns a copy of the specified tree sequence with a random subsample of the + specified number of sites. + """ + t = ts.dump_tables() + t.sites.reset() + t.mutations.reset() + sites_to_keep = set(random.sample(list(range(ts.num_sites)), num_sites)) + for site in ts.sites(): + if site.id in sites_to_keep: + site_id = len(t.sites) + t.sites.add_row( + position=site.position, ancestral_state=site.ancestral_state) + for mutation in site.mutations: + t.mutations.add_row( + site=site_id, derived_state=mutation.derived_state, + node=mutation.node, parent=mutation.parent) + add_provenance(t.provenances, "subsample_sites") + return t.tree_sequence() + + +def decapitate(ts, num_edges): + """ + Returns a copy of the specified tree sequence in which the specified number of + edges have been retained. + """ + t = ts.dump_tables() + t.edges.set_columns( + left=t.edges.left[:num_edges], right=t.edges.right[:num_edges], + parent=t.edges.parent[:num_edges], child=t.edges.child[:num_edges]) + add_provenance(t.provenances, "decapitate") + return t.tree_sequence() + + +def insert_branch_mutations(ts, mutations_per_branch=1): + """ + Returns a copy of the specified tree sequence with a mutation on every branch + in every tree. 
+ """ + tables = ts.dump_tables() + tables.sites.clear() + tables.mutations.clear() + for tree in ts.trees(): + site = tables.sites.add_row(position=tree.interval[0], ancestral_state='0') + for root in tree.roots: + state = {root: 0} + mutation = {root: -1} + stack = [root] + while len(stack) > 0: + u = stack.pop() + stack.extend(tree.children(u)) + v = tree.parent(u) + if v != tskit.NULL: + state[u] = state[v] + parent = mutation[v] + for j in range(mutations_per_branch): + state[u] = (state[u] + 1) % 2 + mutation[u] = tables.mutations.add_row( + site=site, node=u, derived_state=str(state[u]), + parent=parent) + parent = mutation[u] + add_provenance(tables.provenances, "insert_branch_mutations") + return tables.tree_sequence() + + +def insert_branch_sites(ts): + """ + Returns a copy of the specified tree sequence with a site on every branch + of every tree. + """ + tables = ts.dump_tables() + tables.sites.clear() + tables.mutations.clear() + for tree in ts.trees(): + left, right = tree.interval + delta = (right - left) / len(list(tree.nodes())) + x = left + for u in tree.nodes(): + if tree.parent(u) != tskit.NULL: + site = tables.sites.add_row(position=x, ancestral_state='0') + tables.mutations.add_row(site=site, node=u, derived_state='1') + x += delta + add_provenance(tables.provenances, "insert_branch_sites") + return tables.tree_sequence() + + +def insert_multichar_mutations(ts, seed=1, max_len=10): + """ + Returns a copy of the specified tree sequence with multiple chararacter + mutations on a randomly chosen branch in every tree. 
+ """ + rng = random.Random(seed) + letters = ["A", "C", "T", "G"] + tables = ts.dump_tables() + tables.sites.clear() + tables.mutations.clear() + for tree in ts.trees(): + ancestral_state = rng.choice(letters) * rng.randint(0, max_len) + site = tables.sites.add_row( + position=tree.interval[0], ancestral_state=ancestral_state) + nodes = list(tree.nodes()) + nodes.remove(tree.root) + u = rng.choice(nodes) + derived_state = ancestral_state + while ancestral_state == derived_state: + derived_state = rng.choice(letters) * rng.randint(0, max_len) + tables.mutations.add_row(site=site, node=u, derived_state=derived_state) + add_provenance(tables.provenances, "insert_multichar_mutations") + return tables.tree_sequence() + + +def insert_random_ploidy_individuals(ts, max_ploidy=5, max_dimension=3, seed=1): + """ + Takes random contiguous subsets of the samples an assigns them to individuals. + Also creates random locations in variable dimensions in the unit interval. + """ + rng = random.Random(seed) + samples = np.array(ts.samples(), dtype=int) + j = 0 + tables = ts.dump_tables() + tables.individuals.clear() + individual = tables.nodes.individual[:] + individual[:] = tskit.NULL + while j < len(samples): + ploidy = rng.randint(0, max_ploidy) + nodes = samples[j: min(j + ploidy, len(samples))] + dimension = rng.randint(0, max_dimension) + location = [rng.random() for _ in range(dimension)] + ind_id = tables.individuals.add_row(location=location) + individual[nodes] = ind_id + j += ploidy + tables.nodes.individual = individual + return tables.tree_sequence() + + +def permute_nodes(ts, node_map): + """ + Returns a copy of the specified tree sequence such that the nodes are + permuted according to the specified map. 
+ """ + tables = ts.dump_tables() + tables.nodes.clear() + tables.edges.clear() + tables.mutations.clear() + # Mapping from nodes in the new tree sequence back to nodes in the original + reverse_map = [0 for _ in node_map] + for j in range(ts.num_nodes): + reverse_map[node_map[j]] = j + old_nodes = list(ts.nodes()) + for j in range(ts.num_nodes): + old_node = old_nodes[reverse_map[j]] + tables.nodes.add_row( + flags=old_node.flags, metadata=old_node.metadata, + population=old_node.population, time=old_node.time) + for edge in ts.edges(): + tables.edges.add_row( + left=edge.left, right=edge.right, parent=node_map[edge.parent], + child=node_map[edge.child]) + for site in ts.sites(): + for mutation in site.mutations: + tables.mutations.add_row( + site=site.id, derived_state=mutation.derived_state, + node=node_map[mutation.node], metadata=mutation.metadata) + tables.sort() + add_provenance(tables.provenances, "permute_nodes") + return tables.tree_sequence() + + +def insert_redundant_breakpoints(ts): + """ + Builds a new tree sequence containing redundant breakpoints. + """ + tables = ts.dump_tables() + tables.edges.reset() + for r in ts.edges(): + x = r.left + (r.right - r.left) / 2 + tables.edges.add_row(left=r.left, right=x, child=r.child, parent=r.parent) + tables.edges.add_row(left=x, right=r.right, child=r.child, parent=r.parent) + add_provenance(tables.provenances, "insert_redundant_breakpoints") + new_ts = tables.tree_sequence() + assert new_ts.num_edges == 2 * ts.num_edges + return new_ts + + +def single_childify(ts): + """ + Builds a new equivalent tree sequence which contains an extra node in the + middle of all exising branches. + """ + tables = ts.dump_tables() + + time = tables.nodes.time[:] + tables.edges.reset() + for edge in ts.edges(): + # Insert a new node in between the parent and child. 
+ t = time[edge.child] + (time[edge.parent] - time[edge.child]) / 2 + u = tables.nodes.add_row(time=t) + tables.edges.add_row( + left=edge.left, right=edge.right, parent=u, child=edge.child) + tables.edges.add_row( + left=edge.left, right=edge.right, parent=edge.parent, child=u) + tables.sort() + add_provenance(tables.provenances, "insert_redundant_breakpoints") + return tables.tree_sequence() + + +def add_random_metadata(ts, seed=1, max_length=10): + """ + Returns a copy of the specified tree sequence with random metadata assigned + to the nodes, sites and mutations. + """ + tables = ts.dump_tables() + np.random.seed(seed) + + length = np.random.randint(0, max_length, ts.num_nodes) + offset = np.cumsum(np.hstack(([0], length)), dtype=np.uint32) + # Older versions of numpy didn't have a dtype argument for randint, so + # must use astype instead. + metadata = np.random.randint(-127, 127, offset[-1]).astype(np.int8) + nodes = tables.nodes + nodes.set_columns( + flags=nodes.flags, population=nodes.population, time=nodes.time, + metadata_offset=offset, metadata=metadata, + individual=nodes.individual) + + length = np.random.randint(0, max_length, ts.num_sites) + offset = np.cumsum(np.hstack(([0], length)), dtype=np.uint32) + metadata = np.random.randint(-127, 127, offset[-1]).astype(np.int8) + sites = tables.sites + sites.set_columns( + position=sites.position, + ancestral_state=sites.ancestral_state, + ancestral_state_offset=sites.ancestral_state_offset, + metadata_offset=offset, metadata=metadata) + + length = np.random.randint(0, max_length, ts.num_mutations) + offset = np.cumsum(np.hstack(([0], length)), dtype=np.uint32) + metadata = np.random.randint(-127, 127, offset[-1]).astype(np.int8) + mutations = tables.mutations + mutations.set_columns( + site=mutations.site, + node=mutations.node, + parent=mutations.parent, + derived_state=mutations.derived_state, + derived_state_offset=mutations.derived_state_offset, + metadata_offset=offset, metadata=metadata) + + length 
= np.random.randint(0, max_length, ts.num_individuals) + offset = np.cumsum(np.hstack(([0], length)), dtype=np.uint32) + metadata = np.random.randint(-127, 127, offset[-1]).astype(np.int8) + individuals = tables.individuals + individuals.set_columns( + flags=individuals.flags, + location=individuals.location, + location_offset=individuals.location_offset, + metadata_offset=offset, metadata=metadata) + + length = np.random.randint(0, max_length, ts.num_populations) + offset = np.cumsum(np.hstack(([0], length)), dtype=np.uint32) + metadata = np.random.randint(-127, 127, offset[-1]).astype(np.int8) + populations = tables.populations + populations.set_columns(metadata_offset=offset, metadata=metadata) + + add_provenance(tables.provenances, "add_random_metadata") + ts = tables.tree_sequence() + return ts + + +def jiggle_samples(ts): + """ + Returns a copy of the specified tree sequence with the sample nodes switched + around. The first n / 2 existing samples become non samples, and the last + n / 2 node become samples. + """ + tables = ts.dump_tables() + nodes = tables.nodes + flags = nodes.flags + oldest_parent = tables.edges.parent[-1] + n = ts.sample_size + flags[:n // 2] = 0 + flags[oldest_parent - n // 2: oldest_parent] = 1 + nodes.set_columns(flags, nodes.time) + add_provenance(tables.provenances, "jiggle_samples") + return tables.tree_sequence() + + +def generate_site_mutations(tree, position, mu, site_table, mutation_table, + multiple_per_node=True): + """ + Generates mutations for the site at the specified position on the specified + tree. Mutations happen at rate mu along each branch. The site and mutation + information are recorded in the specified tables. Note that this records + more than one mutation per edge. 
+ """ + assert tree.interval[0] <= position < tree.interval[1] + states = {"A", "C", "G", "T"} + state = random.choice(list(states)) + site_table.add_row(position, state) + site = site_table.num_rows - 1 + stack = [(tree.root, state, tskit.NULL)] + while len(stack) != 0: + u, state, parent = stack.pop() + if u != tree.root: + branch_length = tree.branch_length(u) + x = random.expovariate(mu) + new_state = state + while x < branch_length: + new_state = random.choice(list(states - set(state))) + if multiple_per_node and (state != new_state): + mutation_table.add_row(site, u, new_state, parent) + parent = mutation_table.num_rows - 1 + state = new_state + x += random.expovariate(mu) + else: + if (not multiple_per_node) and (state != new_state): + mutation_table.add_row(site, u, new_state, parent) + parent = mutation_table.num_rows - 1 + state = new_state + stack.extend(reversed([(v, state, parent) for v in tree.children(u)])) + + +def jukes_cantor(ts, num_sites, mu, multiple_per_node=True, seed=None): + """ + Returns a copy of the specified tree sequence with Jukes-Cantor mutations + applied at the specfied rate at the specifed number of sites. Site positions + are chosen uniformly. + """ + random.seed(seed) + positions = [ts.sequence_length * random.random() for _ in range(num_sites)] + positions.sort() + tables = ts.dump_tables() + tables.sites.clear() + tables.mutations.clear() + trees = ts.trees() + t = next(trees) + for position in positions: + while position >= t.interval[1]: + t = next(trees) + generate_site_mutations(t, position, mu, tables.sites, tables.mutations, + multiple_per_node=multiple_per_node) + add_provenance(tables.provenances, "jukes_cantor") + new_ts = tables.tree_sequence() + return new_ts + + +def compute_mutation_parent(ts): + """ + Compute the `parent` column of a MutationTable. 
Correct computation uses + topological information in the nodes and edges, as well as the fact that + each mutation must be listed after the mutation on whose background it + occurred (i.e., its parent). + + :param TreeSequence ts: The tree sequence to compute for. Need not + have a valid mutation parent column. + """ + mutation_parent = np.zeros(ts.num_mutations, dtype=np.int32) - 1 + # Maps nodes to the bottom mutation on each branch + bottom_mutation = np.zeros(ts.num_nodes, dtype=np.int32) - 1 + for tree in ts.trees(): + for site in tree.sites(): + # Go forward through the mutations creating a mapping from the + # mutations to the nodes. If we see more than one mutation + # at a node, then these must be parents since we're assuming + # they are in order. + for mutation in site.mutations: + if bottom_mutation[mutation.node] != tskit.NULL: + mutation_parent[mutation.id] = bottom_mutation[mutation.node] + bottom_mutation[mutation.node] = mutation.id + # There's no point in checking the first mutation, since this cannot + # have a parent. + for mutation in site.mutations[1:]: + if mutation_parent[mutation.id] == tskit.NULL: + v = tree.parent(mutation.node) + # Traverse upwards until we find a another mutation or root. + while v != tskit.NULL and bottom_mutation[v] == tskit.NULL: + v = tree.parent(v) + if v != tskit.NULL: + mutation_parent[mutation.id] = bottom_mutation[v] + # Reset the maps for the next site. + for mutation in site.mutations: + bottom_mutation[mutation.node] = tskit.NULL + assert np.all(bottom_mutation == -1) + return mutation_parent + + +def algorithm_T(ts): + """ + Simple implementation of algorithm T from the PLOS paper, taking into + account tree sequences with gaps and other complexities. 
+ """ + sequence_length = ts.sequence_length + edges = list(ts.edges()) + M = len(edges) + time = [ts.node(edge.parent).time for edge in edges] + in_order = sorted(range(M), key=lambda j: ( + edges[j].left, time[j], edges[j].parent, edges[j].child)) + out_order = sorted(range(M), key=lambda j: ( + edges[j].right, -time[j], -edges[j].parent, -edges[j].child)) + j = 0 + k = 0 + left = 0 + parent = [-1 for _ in range(ts.num_nodes)] + while j < M or left < sequence_length: + while k < M and edges[out_order[k]].right == left: + edge = edges[out_order[k]] + parent[edge.child] = -1 + k += 1 + while j < M and edges[in_order[j]].left == left: + edge = edges[in_order[j]] + parent[edge.child] = edge.parent + j += 1 + right = sequence_length + if j < M: + right = min(right, edges[in_order[j]].left) + if k < M: + right = min(right, edges[out_order[k]].right) + yield (left, right), parent + left = right + + +class LinkedTree(object): + """ + Straightforward implementation of the quintuply linked tree for developing + and testing the sample lists feature. + + NOTE: The interface is pretty awkward; it's not intended for anything other + than testing. + """ + def __init__(self, tree_sequence, tracked_samples=None): + self.tree_sequence = tree_sequence + num_nodes = tree_sequence.num_nodes + # Quintuply linked tree. + self.parent = [-1 for _ in range(num_nodes)] + self.left_sib = [-1 for _ in range(num_nodes)] + self.right_sib = [-1 for _ in range(num_nodes)] + self.left_child = [-1 for _ in range(num_nodes)] + self.right_child = [-1 for _ in range(num_nodes)] + self.left_sample = [-1 for _ in range(num_nodes)] + self.right_sample = [-1 for _ in range(num_nodes)] + # This is too long, but it's convenient for printing. 
+ self.next_sample = [-1 for _ in range(num_nodes)] + + self.sample_index_map = [-1 for _ in range(num_nodes)] + samples = tracked_samples + if tracked_samples is None: + samples = list(tree_sequence.samples()) + for j in range(len(samples)): + u = samples[j] + self.sample_index_map[u] = j + self.left_sample[u] = j + self.right_sample[u] = j + + def __str__(self): + fmt = "{:<5}{:>8}{:>8}{:>8}{:>8}{:>8}{:>8}{:>8}{:>8}\n" + s = fmt.format( + "node", "parent", "lsib", "rsib", "lchild", "rchild", + "nsamp", "lsamp", "rsamp") + for u in range(self.tree_sequence.num_nodes): + s += fmt.format( + u, self.parent[u], + self.left_sib[u], self.right_sib[u], + self.left_child[u], self.right_child[u], + self.next_sample[u], self.left_sample[u], self.right_sample[u]) + # Strip off trailing newline + return s[:-1] + + def remove_edge(self, edge): + p = edge.parent + c = edge.child + lsib = self.left_sib[c] + rsib = self.right_sib[c] + if lsib == -1: + self.left_child[p] = rsib + else: + self.right_sib[lsib] = rsib + if rsib == -1: + self.right_child[p] = lsib + else: + self.left_sib[rsib] = lsib + self.parent[c] = -1 + self.left_sib[c] = -1 + self.right_sib[c] = -1 + + def insert_edge(self, edge): + p = edge.parent + c = edge.child + assert self.parent[c] == -1, "contradictory edges" + self.parent[c] = p + u = self.right_child[p] + if u == -1: + self.left_child[p] = c + self.left_sib[c] = -1 + self.right_sib[c] = -1 + else: + self.right_sib[u] = c + self.left_sib[c] = u + self.right_sib[c] = -1 + self.right_child[p] = c + + def update_sample_list(self, parent): + # This can surely be done more efficiently and elegantly. We are iterating + # up the tree and iterating over all the siblings of the nodes we visit, + # rebuilding the links as we go. This results in visiting the same nodes + # over again, which if we have nodes with many siblings will surely be + # expensive. 
Another consequence of the current approach is that the + # next pointer contains an arbitrary value for the rightmost sample of + # every root. This should point to NULL ideally, but it's quite tricky + # to do in practise. It's easier to have a slightly uglier iteration + # over samples. + # + # In the future it would be good have a more efficient version of this + # algorithm using next and prev pointers that we keep up to date at all + # times, and which we use to patch the lists together more efficiently. + u = parent + while u != -1: + sample_index = self.sample_index_map[u] + if sample_index != -1: + self.right_sample[u] = self.left_sample[u] + else: + self.right_sample[u] = -1 + self.left_sample[u] = -1 + v = self.left_child[u] + while v != -1: + if self.left_sample[v] != -1: + assert self.right_sample[v] != -1 + if self.left_sample[u] == -1: + self.left_sample[u] = self.left_sample[v] + self.right_sample[u] = self.right_sample[v] + else: + self.next_sample[self.right_sample[u]] = self.left_sample[v] + self.right_sample[u] = self.right_sample[v] + v = self.right_sib[v] + u = self.parent[u] + + def sample_lists(self): + """ + Iterate over the the trees in this tree sequence, yielding the (left, right) + interval tuples. The tree state is maintained internally. + + See note above about the cruddiness of this interface. 
+ """ + ts = self.tree_sequence + sequence_length = ts.sequence_length + edges = list(ts.edges()) + M = len(edges) + time = [ts.node(edge.parent).time for edge in edges] + in_order = sorted(range(M), key=lambda j: ( + edges[j].left, time[j], edges[j].parent, edges[j].child)) + out_order = sorted(range(M), key=lambda j: ( + edges[j].right, -time[j], -edges[j].parent, -edges[j].child)) + j = 0 + k = 0 + left = 0 + + while j < M or left < sequence_length: + while k < M and edges[out_order[k]].right == left: + edge = edges[out_order[k]] + self.remove_edge(edge) + self.update_sample_list(edge.parent) + k += 1 + while j < M and edges[in_order[j]].left == left: + edge = edges[in_order[j]] + self.insert_edge(edge) + self.update_sample_list(edge.parent) + j += 1 + right = sequence_length + if j < M: + right = min(right, edges[in_order[j]].left) + if k < M: + right = min(right, edges[out_order[k]].right) + yield left, right + left = right + + +def mean_descendants(ts, reference_sets): + """ + Returns the mean number of nodes from the specified reference sets + where the node is ancestral to at least one of the reference nodes. Returns a + ``(ts.num_nodes, len(reference_sets))`` dimensional numpy array. + """ + # Check the inputs (could be done more efficiently here) + all_reference_nodes = set() + for reference_set in reference_sets: + U = set(reference_set) + if len(U) != len(reference_set): + raise ValueError("Cannot have duplicate values within set") + if len(all_reference_nodes & U) != 0: + raise ValueError("Sample sets must be disjoint") + all_reference_nodes |= U + + K = len(reference_sets) + C = np.zeros((ts.num_nodes, K)) + parent = np.zeros(ts.num_nodes, dtype=int) - 1 + # The -1th element of ref_count is for all nodes in the reference set. + ref_count = np.zeros((ts.num_nodes, K + 1), dtype=int) + last_update = np.zeros(ts.num_nodes) + total_length = np.zeros(ts.num_nodes) + + def update_counts(edge, sign): + # Update the counts and statistics for a given node. 
Before we change the + # node counts in the given direction, check to see if we need to update + # statistics for that node. When a node count changes, we add the + # accumulated statistic value for the span since that node was last updated. + v = edge.parent + while v != -1: + if last_update[v] != left: + if ref_count[v, K] > 0: + length = left - last_update[v] + C[v] += length * ref_count[v, :K] + total_length[v] += length + last_update[v] = left + ref_count[v] += sign * ref_count[edge.child] + v = parent[v] + + # Set the intitial conditions. + for j in range(K): + ref_count[reference_sets[j], j] = 1 + ref_count[ts.samples(), K] = 1 + + for (left, right), edges_out, edges_in in ts.edge_diffs(): + for edge in edges_out: + parent[edge.child] = -1 + update_counts(edge, -1) + for edge in edges_in: + parent[edge.child] = edge.parent + update_counts(edge, +1) + + # Finally, add the stats for the last tree and divide by the total + # length that each node was an ancestor to > 0 samples. + for v in range(ts.num_nodes): + if ref_count[v, K] > 0: + length = ts.sequence_length - last_update[v] + total_length[v] += length + C[v] += length * ref_count[v, :K] + if total_length[v] != 0: + C[v] /= total_length[v] + return C + + +def genealogical_nearest_neighbours(ts, focal, reference_sets): + + reference_set_map = np.zeros(ts.num_nodes, dtype=int) - 1 + for k, reference_set in enumerate(reference_sets): + for u in reference_set: + if reference_set_map[u] != -1: + raise ValueError("Duplicate value in reference sets") + reference_set_map[u] = k + + K = len(reference_sets) + A = np.zeros((len(focal), K)) + L = np.zeros(len(focal)) + parent = np.zeros(ts.num_nodes, dtype=int) - 1 + sample_count = np.zeros((ts.num_nodes, K), dtype=int) + + # Set the intitial conditions. 
+ for j in range(K): + sample_count[reference_sets[j], j] = 1 + + for (left, right), edges_out, edges_in in ts.edge_diffs(): + for edge in edges_out: + parent[edge.child] = -1 + v = edge.parent + while v != -1: + sample_count[v] -= sample_count[edge.child] + v = parent[v] + for edge in edges_in: + parent[edge.child] = edge.parent + v = edge.parent + while v != -1: + sample_count[v] += sample_count[edge.child] + v = parent[v] + + # Process this tree. + for j, u in enumerate(focal): + focal_reference_set = reference_set_map[u] + p = parent[u] + while p != tskit.NULL: + total = np.sum(sample_count[p]) + if total > 1: + break + p = parent[p] + if p != tskit.NULL: + length = right - left + L[j] += length + scale = length / (total - int(focal_reference_set != -1)) + for k, reference_set in enumerate(reference_sets): + n = sample_count[p, k] - int(focal_reference_set == k) + A[j, k] += n * scale + + # Avoid division by zero + L[L == 0] = 1 + A /= L.reshape((len(focal), 1)) + return A diff --git a/python/tskit/__init__.py b/python/tskit/__init__.py new file mode 100644 index 0000000000..30dc101945 --- /dev/null +++ b/python/tskit/__init__.py @@ -0,0 +1,15 @@ + +from __future__ import print_function +from __future__ import division + +import _tskit +FORWARD = _tskit.FORWARD +REVERSE = _tskit.REVERSE + +from tskit.provenance import __version__ # NOQA +from tskit.provenance import validate_provenance # NOQA +from tskit.formats import * # NOQA +from tskit.trees import * # NOQA +from tskit.tables import * # NOQA +from tskit.stats import * # NOQA +from tskit.exceptions import * # NOQA diff --git a/python/tskit/_version.py b/python/tskit/_version.py new file mode 100644 index 0000000000..3616cd1a3a --- /dev/null +++ b/python/tskit/_version.py @@ -0,0 +1,2 @@ +# Definitive location for the version number. 
+tskit_version = "0.1.0a1" diff --git a/python/tskit/drawing.py b/python/tskit/drawing.py new file mode 100644 index 0000000000..a0aa3dd7f6 --- /dev/null +++ b/python/tskit/drawing.py @@ -0,0 +1,395 @@ +""" +Module responsible for visualisations. +""" +from __future__ import division +from __future__ import print_function + +import array +import collections +import sys + +try: + import svgwrite + _svgwrite_imported = True +except ImportError: + _svgwrite_imported = False + +IS_PY2 = sys.version_info[0] < 3 + +NULL_NODE = -1 + + +def draw_tree( + tree, width=None, height=None, node_labels=None, node_colours=None, + mutation_labels=None, mutation_colours=None, format=None): + # See tree.draw() for documentation on these arguments. + if format is None: + format = "SVG" + fmt = format.lower() + supported_formats = ["svg", "ascii", "unicode"] + if fmt not in supported_formats: + raise ValueError("Unknown format '{}'. Supported formats are {}".format( + format, supported_formats)) + if fmt == "svg": + if not _svgwrite_imported: + raise ImportError( + "svgwrite is not installed. try `pip install svgwrite`") + if width is None: + width = 200 + if height is None: + height = 200 + cls = SvgTreeDrawer + elif fmt == "ascii": + cls = AsciiTreeDrawer + elif fmt == "unicode": + if IS_PY2: + raise ValueError("Unicode tree drawing not supported on Python 2") + cls = UnicodeTreeDrawer + + # We can't draw trees with zero roots. + if tree.num_roots == 0: + raise ValueError("Cannot draw a tree with zero roots") + + td = cls( + tree, width=width, height=height, + node_labels=node_labels, node_colours=node_colours, + mutation_labels=mutation_labels, mutation_colours=mutation_colours) + return td.draw() + + +class TreeDrawer(object): + """ + A class to draw sparse trees in SVG format. + """ + + discretise_coordinates = False + + def _discretise(self, x): + """ + Discetises the specified value, if necessary. 
+ """ + ret = x + if self.discretise_coordinates: + ret = int(round(x)) + return ret + + def __init__( + self, tree, width=None, height=None, node_labels=None, node_colours=None, + mutation_labels=None, mutation_colours=None): + self._tree = tree + self._num_leaves = len(list(tree.leaves())) + self._width = width + self._height = height + self._x_coords = {} + self._y_coords = {} + self._node_labels = {} + self._node_colours = {} + self._mutation_labels = {} + self._mutation_colours = {} + + # Set the node labels and colours. + for u in tree.nodes(): + if node_labels is None: + self._node_labels[u] = str(u) + else: + self._node_labels[u] = None + if node_labels is not None: + for node, label in node_labels.items(): + self._node_labels[node] = label + if node_colours is not None: + for node, colour in node_colours.items(): + self._node_colours[node] = colour + + # Set the mutation labels. + for site in tree.sites(): + for mutation in site.mutations: + if mutation_labels is None: + self._mutation_labels[mutation.id] = str(mutation.id) + else: + self._mutation_labels[mutation.id] = None + if mutation_labels is not None: + for mutation, label in mutation_labels.items(): + self._mutation_labels[mutation] = label + if mutation_colours is not None: + for mutation, colour in mutation_colours.items(): + self._mutation_colours[mutation] = colour + + self._assign_coordinates() + + +class SvgTreeDrawer(TreeDrawer): + """ + Draws trees in SVG format using the svgwrite library. + """ + + def _assign_coordinates(self): + y_padding = 20 + t = 1 + if self._tree.num_roots > 0: + t = max(self._tree.time(root) for root in self._tree.roots) + # In pathological cases, all the roots are at time 0 + if t == 0: + t = 1 + # Do we have any mutations over a root? 
+ mutations_over_root = any( + self._tree.parent(mut.node) == NULL_NODE for mut in self._tree.mutations()) + root_branch_length = 0 + if mutations_over_root: + # Allocate a fixed about of space to show the mutations on the + # 'root branch' + root_branch_length = self._height / 10 + self._y_scale = (self._height - root_branch_length - 2 * y_padding) / t + self._y_coords[-1] = y_padding + for u in self._tree.nodes(): + scaled_t = self._tree.get_time(u) * self._y_scale + self._y_coords[u] = self._height - scaled_t - y_padding + self._x_scale = self._width / (self._num_leaves + 2) + self._leaf_x = 1 + for root in self._tree.roots: + self._assign_x_coordinates(root) + self._mutations = [] + node_mutations = collections.defaultdict(list) + for site in self._tree.sites(): + for mutation in site.mutations: + node_mutations[mutation.node].append(mutation) + for child, mutations in node_mutations.items(): + n = len(mutations) + parent = self._tree.parent(child) + # Ignore any mutations that are above non-roots that are + # not in the current tree. + if child in self._x_coords: + x = self._x_coords[child] + y1 = self._y_coords[child] + y2 = self._y_coords[parent] + chunk = (y2 - y1) / (n + 1) + for k, mutation in enumerate(mutations): + z = x, self._discretise(y1 + (k + 1) * chunk) + self._mutations.append((z, mutation)) + + def _assign_x_coordinates(self, node): + """ + Assign x coordinates to all nodes underneath this node. + """ + if self._tree.is_internal(node): + children = self._tree.children(node) + for c in children: + self._assign_x_coordinates(c) + coords = [self._x_coords[c] for c in children] + a = min(coords) + b = max(coords) + self._x_coords[node] = self._discretise(a + (b - a) / 2) + else: + self._x_coords[node] = self._discretise(self._leaf_x * self._x_scale) + self._leaf_x += 1 + + def draw(self): + """ + Writes the SVG description of this tree and returns the resulting XML + code as text. 
+ """ + dwg = svgwrite.Drawing(size=(self._width, self._height), debug=True) + lines = dwg.add(dwg.g(id='lines', stroke='black')) + left_labels = dwg.add(dwg.g(font_size=14, text_anchor="start")) + right_labels = dwg.add(dwg.g(font_size=14, text_anchor="end")) + mid_labels = dwg.add(dwg.g(font_size=14, text_anchor="middle")) + for u in self._tree.nodes(): + v = self._tree.get_parent(u) + x = self._x_coords[u], self._y_coords[u] + colour = "black" + if self._node_colours.get(u, None) is not None: + colour = self._node_colours[u] + dwg.add(dwg.circle(center=x, r=3, fill=colour)) + dx = 0 + dy = -5 + labels = mid_labels + if self._tree.is_leaf(u): + dy = 20 + elif self._tree.parent(u) != NULL_NODE: + dx = 5 + if self._tree.left_sib(u) == NULL_NODE: + dx *= -1 + labels = right_labels + else: + labels = left_labels + if self._node_labels[u] is not None: + labels.add(dwg.text(self._node_labels[u], (x[0] + dx, x[1] + dy))) + if self._tree.parent(u) != NULL_NODE: + y = self._x_coords[v], self._y_coords[v] + lines.add(dwg.line(x, (x[0], y[1]))) + lines.add(dwg.line((x[0], y[1]), y)) + + # Experimental stuff to render the mutation labels. Not working very + # well at the moment. 
+ left_labels = dwg.add(dwg.g( + font_size=14, text_anchor="start", font_style="italic", + alignment_baseline="middle")) + right_labels = dwg.add(dwg.g( + font_size=14, text_anchor="end", font_style="italic", + alignment_baseline="middle")) + for x, mutation in self._mutations: + r = 3 + colour = "red" + if self._mutation_colours.get(mutation.id, None) is not None: + colour = self._mutation_colours[mutation.id] + dwg.add(dwg.rect( + insert=(x[0] - r, x[1] - r), size=(2 * r, 2 * r), fill=colour)) + dx = 5 + if self._tree.left_sib(mutation.node) == NULL_NODE: + dx *= -1 + labels = right_labels + else: + labels = left_labels + if self._mutation_labels[mutation.id] is not None: + dy = 1.5 * r + labels.add(dwg.text( + self._mutation_labels[mutation.id], (x[0] + dx, x[1] + dy))) + return dwg.tostring() + + +class TextTreeDrawer(TreeDrawer): + """ + Abstract superclass of TreeDrawers that draw trees in a text buffer. + """ + discretise_coordinates = True + + array_type = None # the type used for the array.array canvas + background_char = None # The fill char + eol_char = None # End of line + left_down_char = None # left corner of a horizontal line + right_down_char = None # right corner of a horizontal line + horizontal_line_char = None # horizontal line fill + vertical_line_char = None # vertial line fill + mid_up_char = None # char in a horizontal line going up + mid_down_char = None # char in a horizontal line going down + mid_up_down_char = None # char in a horizontal line going down and up + + def _convert_text(self, text): + """ + Converts the specified string into an array representation that can be + filled into the text buffer. + """ + raise NotImplementedError() + + def _assign_coordinates(self): + # Get the age of each node and rank them. 
+ times = set(self._tree.time(u) for u in self._tree.nodes()) + depth = {t: 2 * j for j, t in enumerate(sorted(times, reverse=True))} + for u in self._tree.nodes(): + self._y_coords[u] = depth[self._tree.time(u)] + self._height = 0 + if len(self._y_coords) > 0: + self._height = max(self._y_coords.values()) + 1 + # Get the overall width and assign x coordinates. + x = 0 + for root in self._tree.roots: + for u in self._tree.nodes(root, order="postorder"): + if self._tree.is_leaf(u): + label_size = 1 + if self._node_labels[u] is not None: + label_size = len(self._node_labels[u]) + self._x_coords[u] = x + x += label_size + 1 + else: + coords = [self._x_coords[c] for c in self._tree.children(u)] + if len(coords) == 1: + self._x_coords[u] = coords[0] + else: + a = min(coords) + b = max(coords) + assert b - a > 1 + self._x_coords[u] = int(round((a + (b - a) / 2))) + x += 1 + self._width = x + 1 + + def _draw(self): + w = self._width + h = self._height + + # Create a width * height canvas of spaces. 
+ canvas = array.array(self.array_type, (w * h) * [self.background_char]) + for u in self._tree.nodes(): + col = self._x_coords[u] + row = self._y_coords[u] + j = row * w + col + label = self._convert_text(self._node_labels[u]) + n = len(label) + canvas[j: j + n] = label + if self._tree.is_internal(u): + children = self._tree.children(u) + row += 1 + left = min(self._x_coords[v] for v in children) + right = max(self._x_coords[v] for v in children) + for col in range(left + 1, right): + canvas[row * w + col] = self.horizontal_line_char + if len(self._tree.children(u)) == 1: + canvas[row * w + self._x_coords[u]] = self.vertical_line_char + else: + canvas[row * w + self._x_coords[u]] = self.mid_up_char + for v in children: + col = self._x_coords[v] + canvas[row * w + col] = self.mid_down_char + if col == self._x_coords[u]: + canvas[row * w + col] = self.mid_up_down_char + for j in range(row + 1, self._y_coords[v]): + canvas[j * w + col] = self.vertical_line_char + if left == right: + canvas[row * w + left] = self.vertical_line_char + else: + canvas[row * w + left] = self.left_down_char + canvas[row * w + right] = self.right_down_char + + # Put in the EOLs last so that if we can't overwrite them. + for row in range(h): + canvas[row * w + w - 1] = self.eol_char + return canvas + + +class AsciiTreeDrawer(TextTreeDrawer): + """ + Draws an ASCII rendering of a tree. 
+ """ + array_type = 'b' + background_char = ord(' ') + eol_char = ord('\n') + left_down_char = ord('+') + right_down_char = ord('+') + horizontal_line_char = ord('-') + vertical_line_char = ord('|') + mid_up_char = ord('+') + mid_down_char = ord('+') + mid_up_down_char = ord('+') + + def _convert_text(self, text): + if text is None: + text = "|" # vertical line char + return array.array(self.array_type, text.encode()) + + def draw(self): + s = self._draw().tostring() + if not IS_PY2: + s = s.decode() + return s + + +class UnicodeTreeDrawer(TextTreeDrawer): + """ + Draws an Unicode rendering of a tree using box drawing characters. + """ + array_type = 'u' + background_char = ' ' + eol_char = '\n' + left_down_char = "\u250F" + right_down_char = "\u2513" + horizontal_line_char = "\u2501" + vertical_line_char = "\u2503" + mid_up_char = "\u253b" + mid_down_char = "\u2533" + mid_up_down_char = "\u254b" + + def _convert_text(self, text): + if text is None: + text = self.vertical_line_char + return array.array(self.array_type, text) + + def draw(self): + return self._draw().tounicode() diff --git a/python/tskit/exceptions.py b/python/tskit/exceptions.py new file mode 100644 index 0000000000..468c51e773 --- /dev/null +++ b/python/tskit/exceptions.py @@ -0,0 +1,41 @@ +""" +Exceptions defined in tskit. +""" +from _tskit import TskitException +from _tskit import LibraryError +from _tskit import FileFormatError +from _tskit import VersionTooNewError +from _tskit import VersionTooOldError + +# Some exceptions are defined in the low-level module. In particular, the +# superclass of all exceptions for tskit is defined here. We define the +# docstrings here to avoid difficulties with compiling C code on +# readthedocs. + +# TODO finalise this when working out the docs structure for tskit on rtd. + +try: + TskitException.__doc__ = "Superclass of all exceptions defined in tskit." + LibraryError.__doc__ = "Generic low-level error raised by the C library." 
+ FileFormatError.__doc__ = "An error was detected in the file format." + VersionTooNewError.__doc__ = """ + The version of the file is too new and cannot be read by the library. + """ + VersionTooOldError.__doc__ = """ + The version of the file is too old and cannot be read by the library. + """ +except AttributeError: + # Python2 throws attribute error. Ignore. + pass + + +class DuplicatePositionsError(TskitException): + """ + Duplicate positions in the list of sites. + """ + + +class ProvenanceValidationError(TskitException): + """ + A JSON document did non validate against the provenance schema. + """ diff --git a/python/tskit/formats.py b/python/tskit/formats.py new file mode 100644 index 0000000000..70ce16b210 --- /dev/null +++ b/python/tskit/formats.py @@ -0,0 +1,547 @@ +""" +Module responsible for converting tree sequence files from older +formats. +""" +from __future__ import division +from __future__ import print_function + +import datetime +import json +import logging + +import h5py +import numpy as np + +import tskit +import tskit.provenance as provenance +import tskit.exceptions as exceptions + + +def _get_v2_provenance(command, attrs): + """ + Returns the V2 tree provenance attributes reformatted as a provenance record. + """ + environment = {} + parameters = {} + # Try to get the provenance strings. Malformed JSON should not prevent us + # from finishing the conversion. 
+ try: + environment = json.loads(str(attrs["environment"])) + except ValueError: + logging.warn("Failed to convert environment provenance") + try: + parameters = json.loads(str(attrs["parameters"])) + except ValueError: + logging.warn("Failed to convert parameters provenance") + parameters["command"] = command + provenance_dict = provenance.get_provenance_dict(parameters) + provenance_dict["version"] = environment.get("msprime_version", "Unknown_version") + provenance_dict["environment"] = environment + return json.dumps(provenance_dict).encode() + + +def _get_upgrade_provenance(root): + """ + Returns the provenance string from upgrading the specified HDF5 file. + """ + # TODO add more parameters here like filename, etc. + parameters = { + "command": "upgrade", + "source_version": list(map(int, root.attrs["format_version"])) + } + s = json.dumps(provenance.get_provenance_dict(parameters)) + return s.encode() + + +def _convert_hdf5_mutations( + mutations_group, sites, mutations, remove_duplicate_positions): + """ + Loads the v2/v3 into the specified tables. + """ + position = np.array(mutations_group["position"]) + node = np.array(mutations_group["node"], dtype=np.int32) + unique_position, index = np.unique(position, return_index=True) + if unique_position.shape != position.shape: + if remove_duplicate_positions: + position = position[index] + node = node[index] + else: + # TODO add the number of duplicates so that we can improve the + # error message. 
+ raise exceptions.DuplicatePositionsError() + num_mutations = position.shape[0] + sites.set_columns( + position=position, + ancestral_state=ord("0") * np.ones(num_mutations, dtype=np.int8), + ancestral_state_offset=np.arange(num_mutations + 1, dtype=np.uint32)) + mutations.set_columns( + node=node, + site=np.arange(num_mutations, dtype=np.int32), + derived_state=ord("1") * np.ones(num_mutations, dtype=np.int8), + derived_state_offset=np.arange(num_mutations + 1, dtype=np.uint32)) + + +def _set_populations(tables): + """ + Updates PopulationTable suitable to represent the populations referred to + in the node table. + """ + if len(tables.nodes) > 0: + for _ in range(np.max(tables.nodes.population) + 1): + tables.populations.add_row() + + +def _load_legacy_hdf5_v2(root, remove_duplicate_positions): + # Get the coalescence records + trees_group = root["trees"] + old_timestamp = datetime.datetime.min.isoformat() + provenances = tskit.ProvenanceTable() + provenances.add_row( + timestamp=old_timestamp, + record=_get_v2_provenance("generate_trees", trees_group.attrs)) + num_rows = trees_group["node"].shape[0] + index = np.arange(num_rows, dtype=int) + parent = np.zeros(2 * num_rows, dtype=np.int32) + parent[2 * index] = trees_group["node"] + parent[2 * index + 1] = trees_group["node"] + left = np.zeros(2 * num_rows, dtype=np.float64) + left[2 * index] = trees_group["left"] + left[2 * index + 1] = trees_group["left"] + right = np.zeros(2 * num_rows, dtype=np.float64) + right[2 * index] = trees_group["right"] + right[2 * index + 1] = trees_group["right"] + child = np.array(trees_group["children"], dtype=np.int32).flatten() + + tables = tskit.TableCollection(np.max(right)) + tables.edges.set_columns(left=left, right=right, parent=parent, child=child) + + cr_node = np.array(trees_group["node"], dtype=np.int32) + num_nodes = max(np.max(child), np.max(cr_node)) + 1 + sample_size = np.min(cr_node) + flags = np.zeros(num_nodes, dtype=np.uint32) + population = np.zeros(num_nodes, 
dtype=np.int32) + time = np.zeros(num_nodes, dtype=np.float64) + flags[:sample_size] = tskit.NODE_IS_SAMPLE + cr_population = np.array(trees_group["population"], dtype=np.int32) + cr_time = np.array(trees_group["time"]) + time[cr_node] = cr_time + population[cr_node] = cr_population + if "samples" in root: + samples_group = root["samples"] + population[:sample_size] = samples_group["population"] + if "time" in samples_group: + time[:sample_size] = samples_group["time"] + tables.nodes.set_columns(flags=flags, population=population, time=time) + _set_populations(tables) + + if "mutations" in root: + mutations_group = root["mutations"] + _convert_hdf5_mutations( + mutations_group, tables.sites, tables.mutations, remove_duplicate_positions) + provenances.add_row( + timestamp=old_timestamp, + record=_get_v2_provenance("generate_mutations", mutations_group.attrs)) + tables.provenances.add_row(_get_upgrade_provenance(root)) + tables.sort() + return tables.tree_sequence() + + +def _load_legacy_hdf5_v3(root, remove_duplicate_positions): + # get the trees group for the records and samples + trees_group = root["trees"] + nodes_group = trees_group["nodes"] + time = np.array(nodes_group["time"]) + + breakpoints = np.array(trees_group["breakpoints"]) + records_group = trees_group["records"] + left_indexes = np.array(records_group["left"]) + right_indexes = np.array(records_group["right"]) + record_node = np.array(records_group["node"], dtype=np.int32) + num_nodes = time.shape[0] + sample_size = np.min(record_node) + flags = np.zeros(num_nodes, dtype=np.uint32) + flags[:sample_size] = tskit.NODE_IS_SAMPLE + + children_length = np.array(records_group["num_children"], dtype=np.uint32) + total_rows = np.sum(children_length) + left = np.zeros(total_rows, dtype=np.float64) + right = np.zeros(total_rows, dtype=np.float64) + parent = np.zeros(total_rows, dtype=np.int32) + record_left = breakpoints[left_indexes] + record_right = breakpoints[right_indexes] + k = 0 + for j in 
range(left_indexes.shape[0]): + for _ in range(children_length[j]): + left[k] = record_left[j] + right[k] = record_right[j] + parent[k] = record_node[j] + k += 1 + tables = tskit.TableCollection(np.max(right)) + tables.nodes.set_columns( + flags=flags, + time=nodes_group["time"], + population=nodes_group["population"]) + _set_populations(tables) + tables.edges.set_columns( + left=left, right=right, parent=parent, child=records_group["children"]) + if "mutations" in root: + _convert_hdf5_mutations( + root["mutations"], tables.sites, tables.mutations, + remove_duplicate_positions) + old_timestamp = datetime.datetime.min.isoformat() + if "provenance" in root: + for record in root["provenance"]: + tables.provenances.add_row(timestamp=old_timestamp, record=record) + tables.provenances.add_row(_get_upgrade_provenance(root)) + tables.sort() + return tables.tree_sequence() + + +def load_legacy(filename, remove_duplicate_positions=False): + """ + Reads the specified msprime HDF5 file and returns a tree sequence. This + method is only intended to be used to read old format HDF5 files. + + If remove_duplicate_positions is True, remove all sites (except the + first) that contain duplicate positions. If this is False, any input + files that contain duplicate positions will raise an DuplicatePositionsError. + """ + loaders = { + 2: _load_legacy_hdf5_v2, + 3: _load_legacy_hdf5_v3, + 10: _load_legacy_hdf5_v10, + } + root = h5py.File(filename, "r") + if 'format_version' not in root.attrs: + raise ValueError("HDF5 file not in msprime format") + format_version = root.attrs['format_version'] + if format_version[0] not in loaders: + raise ValueError("Version {} not supported for loading".format(format_version)) + try: + ts = loaders[format_version[0]](root, remove_duplicate_positions) + finally: + root.close() + return ts + + +def raise_hdf5_format_error(filename, original_exception): + """ + Tries to open the specified file as a legacy HDF5 file. 
If it looks like + an msprime format HDF5 file, raise an error advising to run tskit upgrade. + """ + try: + with h5py.File(filename, "r") as root: + version = tuple(root.attrs["format_version"]) + raise exceptions.VersionTooOldError( + "File format {} is too old. Please use the ``tskit upgrade`` command " + "to upgrade this file to the latest version".format(version)) + except (IOError, OSError, KeyError): + raise exceptions.FileFormatError(str(original_exception)) + + +def _dump_legacy_hdf5_v2(tree_sequence, root): + root.attrs["format_version"] = (2, 999) + root.attrs["sample_size"] = tree_sequence.get_sample_size() + root.attrs["sequence_length"] = tree_sequence.get_sequence_length(), + left = [] + right = [] + node = [] + children = [] + time = [] + population = [] + for record in tree_sequence.records(): + left.append(record.left) + right.append(record.right) + node.append(record.node) + if len(record.children) != 2: + raise ValueError("V2 files only support binary records") + children.append(record.children) + time.append(record.time) + population.append(record.population) + length = len(time) + trees = root.create_group("trees") + trees.attrs["environment"] = json.dumps({"msprime_version": 0}) + trees.attrs["parameters"] = "{}" + trees.create_dataset("left", (length, ), data=left, dtype=float) + trees.create_dataset("right", (length, ), data=right, dtype=float) + trees.create_dataset("time", (length, ), data=time, dtype=float) + trees.create_dataset("node", (length, ), data=node, dtype="u4") + trees.create_dataset("population", (length, ), data=population, dtype="u1") + trees.create_dataset( + "children", (length, 2), data=children, dtype="u4") + samples = root.create_group("samples") + population = [] + time = [] + length = tree_sequence.get_sample_size() + for u in range(length): + time.append(tree_sequence.get_time(u)) + population.append(tree_sequence.get_population(u)) + samples.create_dataset("time", (length, ), data=time, dtype=float) + 
samples.create_dataset("population", (length, ), data=population, dtype="u1") + if tree_sequence.get_num_mutations() > 0: + node = [] + position = [] + for site in tree_sequence.sites(): + if len(site.mutations) != 1: + raise ValueError("v2 does not support recurrent mutations") + if site.ancestral_state != "0" or site.mutations[0].derived_state != "1": + raise ValueError("v2 does not support non-binary mutations") + position.append(site.position) + node.append(site.mutations[0].node) + length = len(node) + mutations = root.create_group("mutations") + mutations.attrs["environment"] = json.dumps({"msprime_version": 0}) + mutations.attrs["parameters"] = "{}" + mutations.create_dataset("position", (length, ), data=position, dtype=float) + mutations.create_dataset("node", (length, ), data=node, dtype="u4") + + +def _dump_legacy_hdf5_v3(tree_sequence, root): + root.attrs["format_version"] = (3, 999) + root.attrs["sample_size"] = 0, + root.attrs["sequence_length"] = 0, + trees = root.create_group("trees") + # Get the breakpoints from the records. 
+ left = [cr.left for cr in tree_sequence.records()] + breakpoints = np.unique(left + [tree_sequence.sequence_length]) + trees.create_dataset( + "breakpoints", (len(breakpoints), ), data=breakpoints, dtype=float) + + left = [] + right = [] + node = [] + children = [] + num_children = [] + time = [] + for cr in tree_sequence.records(): + node.append(cr.node) + left.append(np.searchsorted(breakpoints, cr.left)) + right.append(np.searchsorted(breakpoints, cr.right)) + children.extend(cr.children) + num_children.append(len(cr.children)) + time.append(cr.time) + records_group = trees.create_group("records") + length = len(num_children) + records_group.create_dataset("left", (length, ), data=left, dtype="u4") + records_group.create_dataset("right", (length, ), data=right, dtype="u4") + records_group.create_dataset("node", (length, ), data=node, dtype="u4") + records_group.create_dataset( + "num_children", (length, ), data=num_children, dtype="u4") + records_group.create_dataset( + "children", (len(children), ), data=children, dtype="u4") + + indexes_group = trees.create_group("indexes") + left_index = sorted(range(length), key=lambda j: (left[j], time[j])) + right_index = sorted(range(length), key=lambda j: (right[j], -time[j])) + indexes_group.create_dataset( + "insertion_order", (length, ), data=left_index, dtype="u4") + indexes_group.create_dataset( + "removal_order", (length, ), data=right_index, dtype="u4") + + nodes_group = trees.create_group("nodes") + population = np.zeros(tree_sequence.num_nodes, dtype="u4") + time = np.zeros(tree_sequence.num_nodes, dtype=float) + tree = next(tree_sequence.trees()) + for u in range(tree_sequence.sample_size): + population[u] = tree.population(u) + time[u] = tree.time(u) + for cr in tree_sequence.records(): + population[cr.node] = cr.population + time[cr.node] = cr.time + length = tree_sequence.num_nodes + nodes_group.create_dataset("time", (length, ), data=time, dtype=float) + nodes_group.create_dataset("population", (length, 
), data=population, dtype="u4") + + node = [] + position = [] + for site in tree_sequence.sites(): + if len(site.mutations) != 1: + raise ValueError("v3 does not support recurrent mutations") + if site.ancestral_state != "0" or site.mutations[0].derived_state != "1": + raise ValueError("v3 does not support non-binary mutations") + position.append(site.position) + node.append(site.mutations[0].node) + length = len(position) + if length > 0: + mutations = root.create_group("mutations") + mutations.create_dataset("position", (length, ), data=position, dtype=float) + mutations.create_dataset("node", (length, ), data=node, dtype="u4") + + +def _add_dataset(group, name, data): + # In the HDF5 format any zero-d arrays must be excluded. + if data.shape[0] > 0: + group.create_dataset(name, data=data) + + +def _dump_legacy_hdf5_v10(tree_sequence, root): + root.attrs["format_version"] = (10, 999) + root.attrs["sample_size"] = 0, + root.attrs["sequence_length"] = tree_sequence.sequence_length, + tables = tree_sequence.dump_tables() + + nodes = root.create_group("nodes") + _add_dataset(nodes, "time", tables.nodes.time) + _add_dataset(nodes, "flags", tables.nodes.flags) + _add_dataset(nodes, "population", tables.nodes.population) + _add_dataset(nodes, "metadata", tables.nodes.metadata) + _add_dataset(nodes, "metadata_offset", tables.nodes.metadata_offset) + + edges = root.create_group("edges") + if len(tables.edges) > 0: + edges.create_dataset("left", data=tables.edges.left) + edges.create_dataset("right", data=tables.edges.right) + edges.create_dataset("parent", data=tables.edges.parent) + edges.create_dataset("child", data=tables.edges.child) + + left = tables.edges.left + right = tables.edges.right + time = tables.nodes.time[tables.edges.parent] + # We can do this more efficiently if we ever need to do it for anything + # other than testing. 
+ indexes_group = edges.create_group("indexes") + length = len(tables.edges) + left_index = sorted(range(length), key=lambda j: (left[j], time[j])) + right_index = sorted(range(length), key=lambda j: (right[j], -time[j])) + indexes_group.create_dataset( + "insertion_order", data=left_index, dtype="u4") + indexes_group.create_dataset( + "removal_order", data=right_index, dtype="u4") + + migrations = root.create_group("migrations") + if len(tables.migrations) > 0: + migrations.create_dataset("left", data=tables.migrations.left) + migrations.create_dataset("right", data=tables.migrations.right) + migrations.create_dataset("node", data=tables.migrations.node) + migrations.create_dataset("source", data=tables.migrations.source) + migrations.create_dataset("dest", data=tables.migrations.dest) + migrations.create_dataset("time", data=tables.migrations.time) + + sites = root.create_group("sites") + _add_dataset(sites, "position", tables.sites.position) + _add_dataset(sites, "ancestral_state", tables.sites.ancestral_state) + _add_dataset(sites, "ancestral_state_offset", tables.sites.ancestral_state_offset) + _add_dataset(sites, "metadata", tables.sites.metadata) + _add_dataset(sites, "metadata_offset", tables.sites.metadata_offset) + + mutations = root.create_group("mutations") + _add_dataset(mutations, "site", tables.mutations.site) + _add_dataset(mutations, "node", tables.mutations.node) + _add_dataset(mutations, "parent", tables.mutations.parent) + _add_dataset(mutations, "derived_state", tables.mutations.derived_state) + _add_dataset( + mutations, "derived_state_offset", tables.mutations.derived_state_offset) + _add_dataset(mutations, "metadata", tables.mutations.metadata) + _add_dataset(mutations, "metadata_offset", tables.mutations.metadata_offset) + + provenances = root.create_group("provenances") + _add_dataset(provenances, "timestamp", tables.provenances.timestamp) + _add_dataset(provenances, "timestamp_offset", tables.provenances.timestamp_offset) + 
_add_dataset(provenances, "record", tables.provenances.record) + _add_dataset(provenances, "record_offset", tables.provenances.record_offset) + + +def _load_legacy_hdf5_v10(root, remove_duplicate_positions=False): + # We cannot have duplicate positions in v10, so this parameter is ignored + sequence_length = root.attrs["sequence_length"] + tables = tskit.TableCollection(sequence_length) + + nodes_group = root["nodes"] + metadata = None + metadata_offset = None + if "metadata" in nodes_group: + metadata = nodes_group["metadata"] + metadata_offset = nodes_group["metadata_offset"] + if "flags" in nodes_group: + tables.nodes.set_columns( + flags=nodes_group["flags"], + population=nodes_group["population"], + time=nodes_group["time"], + metadata=metadata, + metadata_offset=metadata_offset) + + edges_group = root["edges"] + if "left" in edges_group: + tables.edges.set_columns( + left=edges_group["left"], + right=edges_group["right"], + parent=edges_group["parent"], + child=edges_group["child"]) + + migrations_group = root["migrations"] + if "left" in migrations_group: + tables.migrations.set_columns( + left=migrations_group["left"], + right=migrations_group["right"], + node=migrations_group["node"], + source=migrations_group["source"], + dest=migrations_group["dest"], + time=migrations_group["time"]) + + sites_group = root["sites"] + if "position" in sites_group: + metadata = None + metadata_offset = None + if "metadata" in sites_group: + metadata = sites_group["metadata"] + metadata_offset = sites_group["metadata_offset"] + tables.sites.set_columns( + position=sites_group["position"], + ancestral_state=sites_group["ancestral_state"], + ancestral_state_offset=sites_group["ancestral_state_offset"], + metadata=metadata, + metadata_offset=metadata_offset) + + mutations_group = root["mutations"] + if "site" in mutations_group: + metadata = None + metadata_offset = None + if "metadata" in mutations_group: + metadata = mutations_group["metadata"] + metadata_offset = 
mutations_group["metadata_offset"] + tables.mutations.set_columns( + site=mutations_group["site"], + node=mutations_group["node"], + parent=mutations_group["parent"], + derived_state=mutations_group["derived_state"], + derived_state_offset=mutations_group["derived_state_offset"], + metadata=metadata, + metadata_offset=metadata_offset) + + provenances_group = root["provenances"] + if "timestamp" in provenances_group: + timestamp = provenances_group["timestamp"] + timestamp_offset = provenances_group["timestamp_offset"] + if "record" in provenances_group: + record = provenances_group["record"] + record_offset = provenances_group["record_offset"] + else: + record = np.empty_like(timestamp) + record_offset = np.zeros_like(timestamp_offset) + tables.provenances.set_columns( + timestamp=timestamp, + timestamp_offset=timestamp_offset, + record=record, + record_offset=record_offset) + tables.provenances.add_row(_get_upgrade_provenance(root)) + _set_populations(tables) + return tables.tree_sequence() + + +def dump_legacy(tree_sequence, filename, version=3): + """ + Writes the specified tree sequence to a HDF5 file in the specified + legacy file format version. + """ + dumpers = { + 2: _dump_legacy_hdf5_v2, + 3: _dump_legacy_hdf5_v3, + 10: _dump_legacy_hdf5_v10, + } + if version not in dumpers: + raise ValueError("Version {} file format is supported".format(version)) + root = h5py.File(filename, "w") + try: + dumpers[version](tree_sequence, root) + finally: + root.close() diff --git a/python/tskit/provenance.py b/python/tskit/provenance.py new file mode 100644 index 0000000000..fb41833644 --- /dev/null +++ b/python/tskit/provenance.py @@ -0,0 +1,115 @@ +""" +Common provenance methods used to determine the state and versions +of various dependencies and the OS. +""" +from __future__ import print_function +from __future__ import division + +import platform +import json +import os.path + +import jsonschema + +import tskit.exceptions as exceptions +import _tskit + +from . 
import _version + +__version__ = _version.tskit_version + + +# NOTE: the APIs here are all preliminary. We should have a class that encapsulates +# all of the required functionality, including parsing and printing out provenance +# records. This will replace the current functions. + + +def get_environment(extra_libs=None, include_tskit=True): + """ + Returns a dictionary describing the environment in which tskit + is currently running. + + This API is tentative and will change in the future when a more + comprehensive provenance API is implemented. + """ + env = { + "os": { + "system": platform.system(), + "node": platform.node(), + "release": platform.release(), + "version": platform.version(), + "machine": platform.machine(), + }, + "python": { + "implementation": platform.python_implementation(), + "version": platform.python_version(), + } + } + libs = { + "kastore": { + "version": ".".join(map(str, _tskit.get_kastore_version())) + } + } + if include_tskit: + libs["tskit"] = {"version": __version__} + if extra_libs is not None: + libs.update(extra_libs) + env["libraries"] = libs + return env + + +def get_provenance_dict(parameters=None): + """ + Returns a dictionary encoding an execution of tskit conforming to the + provenance schema. + """ + document = { + "schema_version": "1.0.0", + "software": { + "name": "tskit", + "version": __version__ + }, + "parameters": parameters, + "environment": get_environment(include_tskit=False) + } + return document + + +# Cache the schema +_schema = None + + +def get_schema(): + """ + Returns the tskit provenance :ref:`provenance schema ` as + a dict. + + :return: The provenance schema. 
+ :rtype: dict + """ + global _schema + if _schema is None: + base = os.path.dirname(__file__) + schema_file = os.path.join(base, "provenance.schema.json") + with open(schema_file) as f: + _schema = json.load(f) + # Return a copy to avoid issues with modifying the cached schema + return dict(_schema) + + +def validate_provenance(provenance): + """ + Validates the specified dict-like object against the tskit + :ref:`provenance schema `. If the input does + not represent a valid instance of the schema an exception is + raised. + + :param dict provenance: The dictionary representing a JSON document + to be validated against the schema. + :raises: :class:`.ProvenanceValidationError` + """ + schema = get_schema() + try: + jsonschema.validate(provenance, schema) + except jsonschema.exceptions.ValidationError as ve: + raise exceptions.ProvenanceValidationError(str(ve)) diff --git a/python/tskit/provenance.schema.json b/python/tskit/provenance.schema.json new file mode 100644 index 0000000000..fd683fff9e --- /dev/null +++ b/python/tskit/provenance.schema.json @@ -0,0 +1,50 @@ +{ + "schema": "http://json-schema.org/draft-07/schema#", + "version": "1.0.0", + "title": "tskit provenance", + "description": "The combination of software, parameters and environment that produced a tree sequence", + "type": "object", + "required": ["schema_version", "software", "parameters", "environment"], + "properties": { + "schema_version": { + "description": "The version of this schema used.", + "type": "string", + "minLength": 1 + }, + "software": { + "description": "The primary software used to produce the tree sequence.", + "type": "object", + "required": ["name", "version"], + "properties": { + "name": { + "description": "The name of the primary software.", + "type": "string", + "minLength": 1 + }, + "version": { + "description": "The version of primary software.", + "type": "string", + "minLength": 1 + } + } + }, + "parameters": { + "description": "The parameters used to produce the tree 
sequence.", + "type": "object" + }, + "environment": { + "description": "The computational environment within which the primary software ran.", + "type": "object", + "properties": { + "os": { + "description": "Operating system.", + "type": "object" + }, + "libraries": { + "description": "Details of libraries the primary software linked against.", + "type": "object" + } + } + } + } +} diff --git a/python/tskit/stats.py b/python/tskit/stats.py new file mode 100644 index 0000000000..397cce96ee --- /dev/null +++ b/python/tskit/stats.py @@ -0,0 +1,1066 @@ +""" +Module responsible for computing various statistics on tree sequences. +""" +from __future__ import division +from __future__ import print_function + +import threading +import struct +import sys + +import numpy as np + +import _tskit + + +class LdCalculator(object): + """ + Class for calculating `linkage disequilibrium + `_ coefficients + between pairs of mutations in a :class:`.TreeSequence`. This class requires + the `numpy `_ library. + + This class supports multithreaded access using the Python :mod:`threading` + module. Separate instances of :class:`.LdCalculator` referencing the + same tree sequence can operate in parallel in multiple threads. + See the :ref:`sec_tutorial_threads` section in the :ref:`sec_tutorial` + for an example of how use multiple threads to calculate LD values + efficiently. + + .. note:: This class does not currently support sites that have more than one + mutation. Using it on such a tree sequence will raise a LibraryError with + an "Unsupported operation" message. + + :param TreeSequence tree_sequence: The tree sequence containing the + mutations we are interested in. 
+ """ + + def __init__(self, tree_sequence): + self._tree_sequence = tree_sequence + self._ll_ld_calculator = _tskit.LdCalculator( + tree_sequence.get_ll_tree_sequence()) + item_size = struct.calcsize('d') + self._buffer = bytearray( + tree_sequence.get_num_mutations() * item_size) + # To protect low-level C code, only one method may execute on the + # low-level objects at one time. + self._instance_lock = threading.Lock() + + def get_r2(self, a, b): + # Deprecated alias for r2(a, b) + return self.r2(a, b) + + def r2(self, a, b): + """ + Returns the value of the :math:`r^2` statistic between the pair of + mutations at the specified indexes. This method is *not* an efficient + method for computing large numbers of pairwise; please use either + :meth:`.r2_array` or :meth:`.r2_matrix` for this purpose. + + :param int a: The index of the first mutation. + :param int b: The index of the second mutation. + :return: The value of :math:`r^2` between the mutations at indexes + ``a`` and ``b``. + :rtype: float + """ + with self._instance_lock: + return self._ll_ld_calculator.get_r2(a, b) + + def get_r2_array(self, a, direction=1, max_mutations=None, max_distance=None): + # Deprecated alias for r2_array + return self.r2_array(a, direction, max_mutations, max_distance) + + def r2_array(self, a, direction=1, max_mutations=None, max_distance=None): + """ + Returns the value of the :math:`r^2` statistic between the focal + mutation at index :math:`a` and a set of other mutations. The method + operates by starting at the focal mutation and iterating over adjacent + mutations (in either the forward or backwards direction) until either a + maximum number of other mutations have been considered (using the + ``max_mutations`` parameter), a maximum distance in sequence + coordinates has been reached (using the ``max_distance`` parameter) or + the start/end of the sequence has been reached. 
For every mutation + :math:`b` considered, we then insert the value of :math:`r^2` between + :math:`a` and :math:`b` at the corresponding index in an array, and + return the entire array. If the returned array is :math:`x` and + ``direction`` is :const:`tskit.FORWARD` then :math:`x[0]` is the + value of the statistic for :math:`a` and :math:`a + 1`, :math:`x[1]` + the value for :math:`a` and :math:`a + 2`, etc. Similarly, if + ``direction`` is :const:`tskit.REVERSE` then :math:`x[0]` is the + value of the statistic for :math:`a` and :math:`a - 1`, :math:`x[1]` + the value for :math:`a` and :math:`a - 2`, etc. + + :param int a: The index of the focal mutation. + :param int direction: The direction in which to travel when + examining other mutations. Must be either + :const:`tskit.FORWARD` or :const:`tskit.REVERSE`. Defaults + to :const:`tskit.FORWARD`. + :param int max_mutations: The maximum number of mutations to return + :math:`r^2` values for. Defaults to as many mutations as + possible. + :param float max_distance: The maximum absolute distance between + the focal mutation and those for which :math:`r^2` values + are returned. + :return: An array of double precision floating point values + representing the :math:`r^2` values for mutations in the + specified direction. + :rtype: numpy.ndarray + :warning: For efficiency reasons, the underlying memory used to + store the returned array is shared between calls. Therefore, + if you wish to store the results of a single call to + ``get_r2_array()`` for later processing you **must** take a + copy of the array! 
+ """ + if max_mutations is None: + max_mutations = -1 + if max_distance is None: + max_distance = sys.float_info.max + with self._instance_lock: + num_values = self._ll_ld_calculator.get_r2_array( + self._buffer, a, direction=direction, + max_mutations=max_mutations, max_distance=max_distance) + return np.frombuffer(self._buffer, "d", num_values) + + def get_r2_matrix(self): + # Deprecated alias for r2_matrix + return self.r2_matrix() + + def r2_matrix(self): + """ + Returns the complete :math:`m \\times m` matrix of pairwise + :math:`r^2` values in a tree sequence with :math:`m` mutations. + + :return: An 2 dimensional square array of double precision + floating point values representing the :math:`r^2` values for + all pairs of mutations. + :rtype: numpy.ndarray + """ + m = self._tree_sequence.get_num_mutations() + A = np.ones((m, m), dtype=float) + for j in range(m - 1): + a = self.get_r2_array(j) + A[j, j + 1:] = a + A[j + 1:, j] = a + return A + + +class GeneralStatCalculator(object): + """ + A common class for BranchLengthStatCalculator and SiteStatCalculator -- those + implemment different `tree_stat_vector()` methods, but given that + general-purpose function, many statistics are computed in the same way. + + .. warning:: + This interface is still in beta, and may change in the future. + """ + + def __init__(self, tree_sequence): + self.tree_sequence = tree_sequence + + def divergence(self, sample_sets, windows): + """ + Finds the divergence between pairs of samples as described in + mean_pairwise_tmrca_matrix (which uses this function). 
Returns the + upper triangle (including the diagonal) in row-major order, so if the + output is `x`, then: + + >>> k=0 + >>> for w in range(len(windows)-1): + >>> for i in range(len(sample_sets)): + >>> for j in range(i,len(sample_sets)): + >>> trmca[i,j] = tmrca[j,i] = x[w][k]/2.0 + >>> k += 1 + + will fill out the matrix of mean TMRCAs in the `i`th window between (and + within) each group of samples in `sample_sets` in the matrix `tmrca`. + (This is because divergence is one-half TMRCA.) Alternatively, if + `names` labels the sample_sets, the output labels are: + + >>> [".".join(names[i],names[j]) for i in range(len(names)) + >>> for j in range(i,len(names))] + + :param list sample_sets: A list of sets of IDs of samples. + :param iterable windows: The breakpoints of the windows (including start + and end, so has one more entry than number of windows). + :return: A list of the upper triangle of divergences in row-major + order, including the diagonal. + """ + ns = len(sample_sets) + n = [len(x) for x in sample_sets] + + def f(x): + return [float(x[i]*(n[j]-x[j])) + for i in range(ns) for j in range(i, ns)] + + out = self.tree_stat_vector(sample_sets, weight_fun=f, windows=windows) + # move this division outside of f(x) so it only has to happen once + # corrects the diagonal for self comparisons + # and note factor of two for tree length -> real time + for w in range(len(windows)-1): + k = 0 + for i in range(ns): + for j in range(i, ns): + if i == j: + if n[i] == 1: + out[w][k] = np.nan + else: + out[w][k] /= float(n[i] * (n[i] - 1)) + else: + out[w][k] /= float(n[i] * n[j]) + k += 1 + + return out + + def divergence_matrix(self, sample_sets, windows): + """ + Finds the mean divergence between pairs of samples from each set of + samples and in each window. Returns a numpy array indexed by (window, + sample_set, sample_set). 
:return: A numpy array indexed by (window, sample_set, sample_set) of
    mean divergences, as described above.
+ :param iterable windows: The breakpoints of the windows (including start + and end, so has one more entry than number of windows). + :param list indices: A list of triples of indices of sample_sets. + :return: A list of numeric vectors of length equal to the length of + indices, computed separately on each window. + """ + for u in indices: + if not len(u) == 3: + raise ValueError("All indices should be of length 3.") + n = [len(x) for x in sample_sets] + + def f(x): + return [float(x[i] * (n[j] - x[j]) * (n[k] - x[k])) + for i, j, k in indices] + + out = self.tree_stat_vector(sample_sets, weight_fun=f, windows=windows) + + # move this division outside of f(x) so it only has to happen once + # corrects the diagonal for self comparisons + for w in range(len(windows)-1): + for u in range(len(indices)): + out[w][u] /= float(n[indices[u][0]] * n[indices[u][1]] + * n[indices[u][2]]) + + return out + + def Y2_vector(self, sample_sets, windows, indices): + """ + Finds the 'Y' statistic for two groups of samples in sample_sets. + The sample_sets should be disjoint (the computation works fine, but if + not the result depends on the amount of overlap). + If the sample_sets are A and B then the result gives the mean total length + of any edge in the tree between a and the most recent common ancestor of + b and c, where a, b, and c are random draws from A, B, and B + respectively (without replacement). + + The result is, for each window, a vector whose k-th entry is + Y2(sample_sets[indices[k][0]], sample_sets[indices[k][1]]). + + :param list sample_sets: A list of lists of IDs of leaves. + :param iterable windows: The breakpoints of the windows (including start + and end, so has one more entry than number of windows). + :param list indices: A list of pairs of indices of sample_sets. + :return: A list of numeric vectors of length equal to the length of + indices, computed separately on each window. 
+ """ + for u in indices: + if not len(u) == 2: + raise ValueError("All indices should be of length 2.") + n = [len(x) for x in sample_sets] + + def f(x): + return [float(x[i] * (n[j] - x[j]) * (n[j] - x[j] - 1)) + for i, j in indices] + + out = self.tree_stat_vector(sample_sets, weight_fun=f, windows=windows) + for w in range(len(windows)-1): + for u in range(len(indices)): + out[w][u] /= float(n[indices[u][0]] * n[indices[u][1]] + * (n[indices[u][1]]-1)) + + return out + + def Y1_vector(self, sample_sets, windows): + """ + Finds the 'Y1' statistic within each set of samples in sample_sets. The + sample_sets should be disjoint (the computation works fine, but if not + the result depends on the amount of overlap). For the sample set A, the + result gives the mean total length of any edge in the tree between a + and the most recent common ancestor of b and c, where a, b, and c are + random draws from A, without replacement. + + The result is, for each window, a vector whose k-th entry is + Y1(sample_sets[k]). + + :param list sample_sets: A list of sets of IDs of samples, each of length + at least 3. + :param iterable windows: The breakpoints of the windows (including + start and end, so has one more entry than number of windows). + :return: A list of numeric vectors of length equal to the length of + sample_sets, computed separately on each window. 
+ """ + for x in sample_sets: + if len(x) < 3: + raise ValueError("All sample_sets should be of length at least 3.") + n = [len(x) for x in sample_sets] + + def f(x): + return [float(z * (m - z) * (m - z - 1)) for m, z in zip(n, x)] + + out = self.tree_stat_vector(sample_sets, weight_fun=f, windows=windows) + for w in range(len(windows)-1): + for u in range(len(sample_sets)): + out[w][u] /= float(n[u] * (n[u]-1) * (n[u]-2)) + + return out + + def Y2(self, sample_sets, windows): + return self.Y2_vector(sample_sets, windows, indices=[(0, 1)]) + + def Y3(self, sample_sets, windows): + """ + Finds the 'Y' statistic between the three groups of samples in + sample_sets. The sample_sets should be disjoint (the computation works + fine, but if not the result depends on the amount of overlap). If the + sample_sets are A, B, and C, then the result gives the mean total + length of any edge in the tree between a and the most recent common + ancestor of b and c, where a, b, and c are random draws from A, B, and + C respectively. + + :param list sample_sets: A list of *three* sets of IDs of samples: (A,B,C). + :param iterable windows: The breakpoints of the windows (including start + and end, so has one more entry than number of windows). + :return: A list of numeric values computed separately on each window. + """ + return self.Y3_vector(sample_sets, windows, indices=[(0, 1, 2)]) + + def f4_vector(self, sample_sets, windows, indices): + """ + Finds the Patterson's f4 statistics between multiple subsets of four + groups of sample_sets. The sample_sets should be disjoint (the computation + works fine, but if not the result depends on the amount of overlap). + + :param list sample_sets: A list of four sets of IDs of samples: (A,B,C,D) + :param iterable windows: The breakpoints of the windows (including start + and end, so has one more entry than number of windows). + :param list indices: A list of 4-tuples of indices of sample_sets. 
If A does not contain at least two samples, the result is NaN.
+ + :param list sample_sets: A list of sets of IDs of samples. + :param iterable windows: The breakpoints of the windows (including start + and end, so has one more entry than number of windows). + :param list indices: A list of triples of indices of sample_sets. + :return: A list of values of f3(A,B,C) computed separately on each window. + """ + for u in indices: + if not len(u) == 3: + raise ValueError("All tuples in indices should be of length 3.") + n = [len(x) for x in sample_sets] + + def f(x): + return [float(x[i] * (x[i] - 1) * (n[j] - x[j]) * (n[k] - x[k]) + - x[i] * (n[i] - x[i]) * (n[j] - x[j]) * x[k]) + for i, j, k in indices] + + out = self.tree_stat_vector(sample_sets, weight_fun=f, windows=windows) + # move this division outside of f(x) so it only has to happen once + for w in range(len(windows)-1): + for u in range(len(indices)): + if n[indices[u][0]] == 1: + out[w][u] = np.nan + else: + out[w][u] /= float(n[indices[u][0]] * (n[indices[u][0]]-1) + * n[indices[u][1]] * n[indices[u][2]]) + + return out + + def f3(self, sample_sets, windows): + """ + Finds the Patterson's f3 statistics between the three groups of samples + in sample_sets. The sample_sets should be disjoint (the computation works + fine, but if not the result depends on the amount of overlap). + + f3(A;B,C) is f4(A,B;A,C) corrected to not include self comparisons. + + :param list sample_sets: A list of *three* sets of IDs of samples: (A,B,C), + with the first set having at least two samples. + :param iterable windows: The breakpoints of the windows (including start + and end, so has one more entry than number of windows). + :return: A list of values of f3(A,B,C) computed separately on each window. 
Finds the Patterson's f2 statistics between the two groups of samples
that returns a weight. A branch in a tree is weighted by weight_fun(x),
but overlapping windows can be achieved by (a) computing statistics on a
+ ''' + if windows is None: + windows = (0, self.tree_sequence.sequence_length) + for U in sample_sets: + if ((not isinstance(U, list)) or + len(U) != len(set(U))): + raise ValueError( + "elements of sample_sets must be lists without repeated elements.") + if len(U) == 0: + raise ValueError("elements of sample_sets cannot be empty.") + for u in U: + if not self.tree_sequence.node(u).is_sample(): + raise ValueError("Not all elements of sample_sets are samples.") + num_windows = len(windows) - 1 + if windows[0] != 0.0: + raise ValueError( + "Windows must start at the start of the sequence (at 0.0).") + if windows[-1] != self.tree_sequence.sequence_length: + raise ValueError("Windows must extend to the end of the sequence.") + for k in range(num_windows): + if windows[k + 1] <= windows[k]: + raise ValueError("Windows must be increasing.") + # below we actually just keep track of x, not (x,xbar), so here's the + # weighting function we actually use of just x: + num_sample_sets = len(sample_sets) + + # this how we apply the weight function to both below the branch and + # above it + n = [len(x) for x in sample_sets] + + def wfn(x): + ax = [nn - xx for nn, xx in zip(n, x)] + return [a + b for a, b in zip(weight_fun(x), weight_fun(ax))] + + # initialize + n_out = len(wfn([0 for a in range(num_sample_sets)])) + + S = [[0.0 for j in range(n_out)] for _ in range(num_windows)] + L = [0.0 for j in range(n_out)] + N = self.tree_sequence.num_nodes + X = [[int(u in a) for a in sample_sets] for u in range(N)] + # we will essentially construct the tree + pi = [-1 for j in range(N)] + node_time = [self.tree_sequence.node(u).time for u in range(N)] + # keep track of where we are for the windows + chrom_pos = 0.0 + # index of *left-hand* end of the current window + window_num = 0 + for interval, records_out, records_in in self.tree_sequence.edge_diffs(): + length = interval[1] - interval[0] + for sign, records in ((-1, records_out), (+1, records_in)): + for edge in records: + dx = [0 
for k in range(num_sample_sets)] + if sign == +1: + pi[edge.child] = edge.parent + for k in range(num_sample_sets): + dx[k] += sign * X[edge.child][k] + w = wfn(X[edge.child]) + dt = (node_time[pi[edge.child]] - node_time[edge.child]) + for j in range(n_out): + L[j] += sign * dt * w[j] + if sign == -1: + pi[edge.child] = -1 + old_w = wfn(X[edge.parent]) + for k in range(num_sample_sets): + X[edge.parent][k] += dx[k] + if pi[edge.parent] != -1: + w = wfn(X[edge.parent]) + dt = (node_time[pi[edge.parent]] - node_time[edge.parent]) + for j in range(n_out): + L[j] += dt * (w[j]-old_w[j]) + # propagate change up the tree + u = pi[edge.parent] + if u != -1: + next_u = pi[u] + while u != -1: + old_w = wfn(X[u]) + for k in range(num_sample_sets): + X[u][k] += dx[k] + # need to update X for the root, + # but the root does not have a branch length + if next_u != -1: + w = wfn(X[u]) + dt = (node_time[pi[u]] - node_time[u]) + for j in range(n_out): + L[j] += dt*(w[j] - old_w[j]) + u = next_u + next_u = pi[next_u] + while chrom_pos + length >= windows[window_num + 1]: + # wrap up the last window + this_length = windows[window_num + 1] - chrom_pos + window_length = windows[window_num + 1] - windows[window_num] + for j in range(n_out): + S[window_num][j] += L[j] * this_length + S[window_num][j] /= window_length + length -= this_length + # start the next + if window_num < num_windows - 1: + window_num += 1 + chrom_pos = windows[window_num] + else: + # skips the else statement below + break + else: + for j in range(n_out): + S[window_num][j] += L[j] * length + chrom_pos += length + return S + + def site_frequency_spectrum(self, sample_set, windows=None): + ''' + Computes the expected *derived* (unfolded) site frequency spectrum, + based on tree lengths, separately in each window. + + :param list sample_set: A list of IDs of samples of length n. + :param iterable windows: The breakpoints of the windows (including start + and end, so has one more entry than number of windows). 
+ :return: A list of lists of length n, one for each window, whose kth + entry gives the total length of any branches in the marginal trees + over that window that are ancestral to exactly k of the samples, + divided by the length of the window. + ''' + if windows is None: + windows = (0, self.tree_sequence.sequence_length) + if ((not isinstance(sample_set, list)) or + len(sample_set) != len(set(sample_set))): + raise ValueError( + "elements of sample_sets must be lists without repeated elements.") + if len(sample_set) == 0: + raise ValueError("elements of sample_sets cannot be empty.") + for u in sample_set: + if not self.tree_sequence.node(u).is_sample(): + raise ValueError("Not all elements of sample_sets are samples.") + num_windows = len(windows) - 1 + if windows[0] != 0.0: + raise ValueError( + "Windows must start at the start of the sequence (at 0.0).") + if windows[-1] != self.tree_sequence.sequence_length: + raise ValueError("Windows must extend to the end of the sequence.") + for k in range(num_windows): + if windows[k + 1] <= windows[k]: + raise ValueError("Windows must be increasing.") + n_out = len(sample_set) + S = [[0.0 for j in range(n_out)] for _ in range(num_windows)] + L = [0.0 for j in range(n_out)] + N = self.tree_sequence.num_nodes + X = [int(u in sample_set) for u in range(N)] + # we will essentially construct the tree + pi = [-1 for j in range(N)] + node_time = [self.tree_sequence.node(u).time for u in range(N)] + # keep track of where we are for the windows + chrom_pos = 0.0 + # index of *left-hand* end of the current window + window_num = 0 + for interval, records_out, records_in in self.tree_sequence.edge_diffs(): + length = interval[1] - interval[0] + for sign, records in ((-1, records_out), (+1, records_in)): + for edge in records: + dx = 0 + if sign == +1: + pi[edge.child] = edge.parent + dx += sign * X[edge.child] + dt = (node_time[pi[edge.child]] - node_time[edge.child]) + if X[edge.child] > 0: + L[X[edge.child] - 1] += sign * dt + 
underlying engine. This class requires the `numpy
<http://www.numpy.org/>`_ library.
but overlapping windows can be achieved by (a) computing statistics on a
range(num_windows)] + if num_sites == 0: + return S + N = self.tree_sequence.num_nodes + # initialize: with no tree, each node is either in a sample set or not + X = [[int(u in a) for a in sample_sets] for u in range(N)] + # we will construct the tree here + pi = [-1 for j in range(N)] + # keep track of which site we're looking at + sites = self.tree_sequence.sites() + ns = 0 # this will record number of sites seen so far + s = next(sites) + # index of *left-hand* end of the current window + window_num = 0 + while s.position > windows[window_num + 1]: + window_num += 1 + for interval, records_out, records_in in self.tree_sequence.edge_diffs(): + # if we've done all the sites then stop + if ns == num_sites: + break + # update the tree + for sign, records in ((-1, records_out), (+1, records_in)): + for edge in records: + dx = [0 for k in range(num_sample_sets)] + if sign == +1: + pi[edge.child] = edge.parent + for k in range(num_sample_sets): + dx[k] += sign * X[edge.child][k] + if sign == -1: + pi[edge.child] = -1 + for k in range(num_sample_sets): + X[edge.parent][k] += dx[k] + # propagate change up the tree + u = pi[edge.parent] + if u != -1: + next_u = pi[u] + while u != -1: + for k in range(num_sample_sets): + X[u][k] += dx[k] + u = next_u + next_u = pi[next_u] + # loop over sites in this tree + while s.position < interval[1]: + if s.position > windows[window_num + 1]: + # finalize this window and move to the next + window_length = windows[window_num + 1] - windows[window_num] + for j in range(n_out): + S[window_num][j] /= window_length + # may need to advance through empty windows + while s.position > windows[window_num + 1]: + window_num += 1 + nm = len(s.mutations) + if nm > 0: + U = {s.ancestral_state: list(n)} + for mut in s.mutations: + if mut.derived_state not in U: + U[mut.derived_state] = [0 for _ in range(num_sample_sets)] + for k in range(num_sample_sets): + U[mut.derived_state][k] += X[mut.node][k] + parent_state = get_derived_state(s, mut.parent) + 
if parent_state not in U: + U[parent_state] = [0 for _ in range(num_sample_sets)] + for k in range(num_sample_sets): + U[parent_state][k] -= X[mut.node][k] + for a in U: + w = weight_fun(U[a]) + for j in range(n_out): + S[window_num][j] += w[j] + ns += 1 + if ns == num_sites: + break + s = next(sites) + # wrap up the final window + window_length = windows[window_num + 1] - windows[window_num] + for j in range(n_out): + S[window_num][j] /= window_length + return S + + def site_frequency_spectrum(self, sample_set, windows=None): + ''' + Computes the folded site frequency spectrum in sample_set, + independently in windows. + + :param list sample_set: A list of IDs of samples of length n. + :param iterable windows: The breakpoints of the windows (including start + and end, so has one more entry than number of windows). + :return: A list of lists of length n, one for each window, whose kth + entry gives the number of mutations in that window at which a mutation + is seen by exactly k of the samples, divided by the window length. 
+ ''' + if windows is None: + windows = (0, self.tree_sequence.sequence_length) + if ((not isinstance(sample_set, list)) or + len(sample_set) != len(set(sample_set))): + raise ValueError( + "sample_set must not contain repeated elements.") + if len(sample_set) == 0: + raise ValueError("sample_set cannot be empty.") + for u in sample_set: + if not self.tree_sequence.node(u).is_sample(): + raise ValueError("Not all elements of sample_set are samples.") + num_windows = len(windows) - 1 + if windows[0] != 0.0: + raise ValueError( + "Windows must start at the start of the sequence (at 0.0).") + if windows[-1] != self.tree_sequence.sequence_length: + raise ValueError("Windows must extend to the end of the sequence.") + for k in range(num_windows): + if windows[k + 1] <= windows[k]: + raise ValueError("Windows must be increasing.") + num_sites = self.tree_sequence.num_sites + n = len(sample_set) + n_out = n + # we store the final answers here + S = [[0.0 for j in range(n_out)] for _ in range(num_windows)] + if num_sites == 0: + return S + N = self.tree_sequence.num_nodes + # initialize: with no tree, each node is either in a sample set or not + X = [int(u in sample_set) for u in range(N)] + # we will construct the tree here + pi = [-1 for j in range(N)] + # keep track of which site we're looking at + sites = self.tree_sequence.sites() + ns = 0 # this will record number of sites seen so far + s = next(sites) + # index of *left-hand* end of the current window + window_num = 0 + while s.position > windows[window_num + 1]: + window_num += 1 + for interval, records_out, records_in in self.tree_sequence.edge_diffs(): + # if we've done all the sites then stop + if ns == num_sites: + break + # update the tree + for sign, records in ((-1, records_out), (+1, records_in)): + for edge in records: + dx = 0 + if sign == +1: + pi[edge.child] = edge.parent + dx += sign * X[edge.child] + if sign == -1: + pi[edge.child] = -1 + X[edge.parent] += dx + # propagate change up the tree + u = 
Find the derived state of the mutation with id `mut_id` at site `site`. NOTE(review): `mut_id` must be -1 or the id of a mutation present at `site`; otherwise `state` is never bound and an UnboundLocalError is raised -- confirm callers guarantee this.
Unless they're in the same module they
+# need to import each other. In Py3 at least we can import the modules but we
+# can't do this in Py2.
+import tskit
+
+
+IndividualTableRow = collections.namedtuple(
+    "IndividualTableRow",
+    ["flags", "location", "metadata"])
+
+
+NodeTableRow = collections.namedtuple(
+    "NodeTableRow",
+    ["flags", "time", "population", "individual", "metadata"])
+
+
+EdgeTableRow = collections.namedtuple(
+    "EdgeTableRow",
+    ["left", "right", "parent", "child"])
+
+
+MigrationTableRow = collections.namedtuple(
+    "MigrationTableRow",
+    ["left", "right", "node", "source", "dest", "time"])
+
+
+SiteTableRow = collections.namedtuple(
+    "SiteTableRow",
+    ["position", "ancestral_state", "metadata"])
+
+
+MutationTableRow = collections.namedtuple(
+    "MutationTableRow",
+    ["site", "node", "derived_state", "parent", "metadata"])
+
+
+PopulationTableRow = collections.namedtuple(
+    "PopulationTableRow",
+    ["metadata"])
+
+
+ProvenanceTableRow = collections.namedtuple(
+    "ProvenanceTableRow",
+    ["timestamp", "record"])
+
+
+# TODO We could abstract quite a lot more functionality up into this baseclass
+# if each class kept a list of its columns. Then it would be pretty simple to
+# define generic implementation of copy, etc.
+
+
+class BaseTable(object):
+    """
+    Superclass of high-level tables. Not intended for direct instantiation.
+ """ + def __init__(self, ll_table, row_class): + self.ll_table = ll_table + self.row_class = row_class + + def _check_required_args(self, **kwargs): + for k, v in kwargs.items(): + if v is None: + raise TypeError("{} is required".format(k)) + + @property + def num_rows(self): + return self.ll_table.num_rows + + @property + def max_rows(self): + return self.ll_table.max_rows + + @property + def max_rows_increment(self): + return self.ll_table.max_rows_increment + + def __eq__(self, other): + ret = False + if type(other) is type(self): + ret = bool(self.ll_table.equals(other.ll_table)) + return ret + + def __ne__(self, other): + return not self.__eq__(other) + + def __len__(self): + return self.num_rows + + def __getitem__(self, index): + if index < 0: + index += len(self) + if index < 0 or index >= len(self): + raise IndexError("Index out of bounds") + return self.row_class(*self.ll_table.get_row(index)) + + def clear(self): + """ + Deletes all rows in this table. + """ + self.ll_table.clear() + + def reset(self): + # Deprecated alias for clear + self.clear() + + def truncate(self, num_rows): + """ + Truncates this table so that the only the first ``num_rows`` are retained. + + :param int num_rows: The number of rows to retain in this table. + """ + return self.ll_table.truncate(num_rows) + + # Unpickle support + def __setstate__(self, state): + self.set_columns(**state) + + def asdict(self): + """ + Returns a dictionary mapping the names of the columns in this table + to the corresponding numpy arrays. + """ + raise NotImplementedError() + + +class IndividualTable(BaseTable): + """ + A table defining the individuals in a tree sequence. Note that although + each Individual has associated nodes, reference to these is not stored in + the individual table, but rather reference to the individual is stored for + each node in the :class:`NodeTable`. This is similar to the way in which + the relationship between sites and mutations is modelled. 
+ + :warning: The numpy arrays returned by table attribute accesses are **copies** + of the underlying data. In particular, this means that you cannot edit + the values in the columns by updating the attribute arrays. + + **NOTE:** this behaviour may change in future. + + :ivar flags: The array of flags values. + :vartype flags: numpy.ndarray, dtype=np.uint32 + :ivar location: The flattened array of floating point location values. See + :ref:`sec_encoding_ragged_columns` for more details. + :vartype location: numpy.ndarray, dtype=np.float64 + :ivar location_offset: The array of offsets into the location column. See + :ref:`sec_encoding_ragged_columns` for more details. + :vartype location_offset: numpy.ndarray, dtype=np.uint32 + :ivar metadata: The flattened array of binary metadata values. See + :ref:`sec_tables_api_binary_columns` for more details. + :vartype metadata: numpy.ndarray, dtype=np.int8 + :ivar metadata_offset: The array of offsets into the metadata column. See + :ref:`sec_tables_api_binary_columns` for more details. 
+ :vartype metadata_offset: numpy.ndarray, dtype=np.uint32 + """ + def __init__(self, max_rows_increment=0, ll_table=None): + if ll_table is None: + ll_table = _tskit.IndividualTable(max_rows_increment=max_rows_increment) + super(IndividualTable, self).__init__(ll_table, IndividualTableRow) + + @property + def flags(self): + return self.ll_table.flags + + @property + def location(self): + return self.ll_table.location + + @property + def location_offset(self): + return self.ll_table.location_offset + + @property + def metadata(self): + return self.ll_table.metadata + + @property + def metadata_offset(self): + return self.ll_table.metadata_offset + + def __str__(self): + flags = self.flags + location = self.location + location_offset = self.location_offset + metadata = unpack_bytes(self.metadata, self.metadata_offset) + ret = "id\tflags\tlocation\tmetadata\n" + for j in range(self.num_rows): + md = base64.b64encode(metadata[j]).decode('utf8') + location_str = ",".join(map( + str, location[location_offset[j]: location_offset[j + 1]])) + ret += "{}\t{}\t{}\t{}\n".format(j, flags[j], location_str, md) + return ret[:-1] + + def copy(self): + """ + Returns a deep copy of this table. + """ + copy = IndividualTable() + copy.set_columns( + flags=self.flags, + location=self.location, location_offset=self.location_offset, + metadata=self.metadata, metadata_offset=self.metadata_offset) + return copy + + def add_row(self, flags=0, location=None, metadata=None): + """ + Adds a new row to this :class:`IndividualTable` and returns the ID of the + corresponding individual. + + :param int flags: The bitwise flags for the new node. + :param array-like location: A list of numeric values or one-dimensional numpy + array describing the location of this individual. If not specified + or None, a zero-dimensional location is stored. + :param bytes metadata: The binary-encoded metadata for the new node. If not + specified or None, a zero-length byte string is stored. 
+        :return: The ID of the newly added individual.
+        :rtype: int
+        """
+        return self.ll_table.add_row(flags=flags, location=location, metadata=metadata)
+
+    def set_columns(
+            self, flags=None, location=None, location_offset=None,
+            metadata=None, metadata_offset=None):
+        """
+        Sets the values for each column in this :class:`.IndividualTable` using the
+        values in the specified arrays. Overwrites any data currently stored in
+        the table.
+
+        The ``flags`` array is mandatory and defines the number of individuals
+        the table will contain.
+        The ``location`` and ``location_offset`` parameters must be supplied
+        together, and meet the requirements for :ref:`sec_encoding_ragged_columns`.
+        The ``metadata`` and ``metadata_offset`` parameters must be supplied
+        together, and meet the requirements for :ref:`sec_encoding_ragged_columns`.
+        See :ref:`sec_tables_api_binary_columns` for more information.
+
+        :param flags: The bitwise flags for each individual. Required.
+        :type flags: numpy.ndarray, dtype=np.uint32
+        :param location: The flattened location array. Must be specified along
+            with ``location_offset``. If not specified or None, an empty location
+            value is stored for each individual.
+        :type location: numpy.ndarray, dtype=np.float64
+        :param location_offset: The offsets into the ``location`` array.
+        :type location_offset: numpy.ndarray, dtype=np.uint32.
+        :param metadata: The flattened metadata array. Must be specified along
+            with ``metadata_offset``. If not specified or None, an empty metadata
+            value is stored for each individual.
+        :type metadata: numpy.ndarray, dtype=np.int8
+        :param metadata_offset: The offsets into the ``metadata`` array.
+        :type metadata_offset: numpy.ndarray, dtype=np.uint32.
+ """ + self._check_required_args(flags=flags) + self.ll_table.set_columns(dict( + flags=flags, location=location, location_offset=location_offset, + metadata=metadata, metadata_offset=metadata_offset)) + + def append_columns( + self, flags=None, location=None, location_offset=None, metadata=None, + metadata_offset=None): + """ + Appends the specified arrays to the end of the columns in this + :class:`IndividualTable`. This allows many new rows to be added at once. + + The ``flags`` array is mandatory and defines the number of + extra individuals to add to the table. + The ``location`` and ``location_offset`` parameters must be supplied + together, and meet the requirements for :ref:`sec_encoding_ragged_columns`. + The ``metadata`` and ``metadata_offset`` parameters must be supplied + together, and meet the requirements for :ref:`sec_encoding_ragged_columns`. + See :ref:`sec_tables_api_binary_columns` for more information. + + :param flags: The bitwise flags for each individual. Required. + :type flags: numpy.ndarray, dtype=np.uint32 + :param location: The flattened location array. Must be specified along + with ``location_offset``. If not specified or None, an empty location + value is stored for each individual. + :type location: numpy.ndarray, dtype=np.float64 + :param location_offset: The offsets into the ``location`` array. + :type location_offset: numpy.ndarray, dtype=np.uint32. + :param metadata: The flattened metadata array. Must be specified along + with ``metadata_offset``. If not specified or None, an empty metadata + value is stored for each individual. + :type metadata: numpy.ndarray, dtype=np.int8 + :param metadata_offset: The offsets into the ``metadata`` array. + :type metadata_offset: numpy.ndarray, dtype=np.uint32. 
+ """ + self._check_required_args(flags=flags) + self.ll_table.append_columns(dict( + flags=flags, location=location, location_offset=location_offset, + metadata=metadata, metadata_offset=metadata_offset)) + + def asdict(self): + return { + "flags": self.flags, + "location": self.location, + "location_offset": self.location_offset, + "metadata": self.metadata, + "metadata_offset": self.metadata_offset, + } + + +# Pickle support. See copyreg registration for this function below. +def _pickle_individual_table(table): + return IndividualTable, tuple(), table.asdict() + + +class NodeTable(BaseTable): + """ + A table defining the nodes in a tree sequence. See the + :ref:`definitions ` for details on the columns + in this table and the + :ref:`tree sequence requirements ` section + for the properties needed for a node table to be a part of a valid tree sequence. + + :warning: The numpy arrays returned by table attribute accesses are **copies** + of the underlying data. In particular, this means that you cannot edit + the values in the columns by updating the attribute arrays. + + **NOTE:** this behaviour may change in future. + + :ivar time: The array of time values. + :vartype time: numpy.ndarray, dtype=np.float64 + :ivar flags: The array of flags values. + :vartype flags: numpy.ndarray, dtype=np.uint32 + :ivar population: The array of population IDs. + :vartype population: numpy.ndarray, dtype=np.int32 + :ivar individual: The array of individual IDs that each node belongs to. + :vartype individual: numpy.ndarray, dtype=np.int32 + :ivar metadata: The flattened array of binary metadata values. See + :ref:`sec_tables_api_binary_columns` for more details. + :vartype metadata: numpy.ndarray, dtype=np.int8 + :ivar metadata_offset: The array of offsets into the metadata column. See + :ref:`sec_tables_api_binary_columns` for more details. 
+ :vartype metadata_offset: numpy.ndarray, dtype=np.uint32 + """ + def __init__(self, max_rows_increment=0, ll_table=None): + if ll_table is None: + ll_table = _tskit.NodeTable(max_rows_increment=max_rows_increment) + super(NodeTable, self).__init__(ll_table, NodeTableRow) + + @property + def time(self): + return self.ll_table.time + + @property + def flags(self): + return self.ll_table.flags + + @property + def population(self): + return self.ll_table.population + + @property + def individual(self): + return self.ll_table.individual + + # EXPERIMENTAL interface for setting a single column. This is done + # quite a bit in tests. Not part of the public API as yet, but we + # probably will want to allow something like this in general. + @individual.setter + def individual(self, individual): + self.set_columns( + flags=self.flags, time=self.time, population=self.population, + metadata=self.metadata, metadata_offset=self.metadata_offset, + individual=individual) + + @property + def metadata(self): + return self.ll_table.metadata + + @property + def metadata_offset(self): + return self.ll_table.metadata_offset + + def __str__(self): + time = self.time + flags = self.flags + population = self.population + individual = self.individual + metadata = unpack_bytes(self.metadata, self.metadata_offset) + ret = "id\tflags\tpopulation\tindividual\ttime\tmetadata\n" + for j in range(self.num_rows): + md = base64.b64encode(metadata[j]).decode('utf8') + ret += "{}\t{}\t{}\t{}\t{:.14f}\t{}\n".format( + j, flags[j], population[j], individual[j], time[j], md) + return ret[:-1] + + def copy(self): + """ + Returns a deep copy of this table. 
+ """ + copy = NodeTable() + copy.set_columns( + flags=self.flags, time=self.time, population=self.population, + individual=self.individual, metadata=self.metadata, + metadata_offset=self.metadata_offset) + return copy + + def add_row(self, flags=0, time=0, population=-1, individual=-1, metadata=None): + """ + Adds a new row to this :class:`NodeTable` and returns the ID of the + corresponding node. + + :param int flags: The bitwise flags for the new node. + :param float time: The birth time for the new node. + :param int population: The ID of the population in which the new node was born. + Defaults to :const:`.NULL`. + :param int individual: The ID of the individual in which the new node was born. + Defaults to :const:`.NULL`. + :param bytes metadata: The binary-encoded metadata for the new node. If not + specified or None, a zero-length byte string is stored. + :return: The ID of the newly added node. + :rtype: int + """ + return self.ll_table.add_row(flags, time, population, individual, metadata) + + def set_columns( + self, flags=None, time=None, population=None, individual=None, metadata=None, + metadata_offset=None): + """ + Sets the values for each column in this :class:`.NodeTable` using the values in + the specified arrays. Overwrites any data currently stored in the table. + + The ``flags``, ``time`` and ``population`` arrays must all be of the same length, + which is equal to the number of nodes the table will contain. The + ``metadata`` and ``metadata_offset`` parameters must be supplied together, and + meet the requirements for :ref:`sec_encoding_ragged_columns`. + See :ref:`sec_tables_api_binary_columns` for more information. + + :param flags: The bitwise flags for each node. Required. + :type flags: numpy.ndarray, dtype=np.uint32 + :param time: The time values for each node. Required. + :type time: numpy.ndarray, dtype=np.float64 + :param population: The population values for each node. 
If not specified + or None, the :const:`.NULL` value is stored for each node. + :type population: numpy.ndarray, dtype=np.int32 + :param individual: The individual values for each node. If not specified + or None, the :const:`.NULL` value is stored for each node. + :type individual: numpy.ndarray, dtype=np.int32 + :param metadata: The flattened metadata array. Must be specified along + with ``metadata_offset``. If not specified or None, an empty metadata + value is stored for each node. + :type metadata: numpy.ndarray, dtype=np.int8 + :param metadata_offset: The offsets into the ``metadata`` array. + :type metadata_offset: numpy.ndarray, dtype=np.uint32. + """ + self._check_required_args(flags=flags, time=time) + self.ll_table.set_columns(dict( + flags=flags, time=time, population=population, individual=individual, + metadata=metadata, metadata_offset=metadata_offset)) + + def append_columns( + self, flags=None, time=None, population=None, individual=None, metadata=None, + metadata_offset=None): + """ + Appends the specified arrays to the end of the columns in this + :class:`NodeTable`. This allows many new rows to be added at once. + + The ``flags``, ``time`` and ``population`` arrays must all be of the same length, + which is equal to the number of nodes that will be added to the table. The + ``metadata`` and ``metadata_offset`` parameters must be supplied together, and + meet the requirements for :ref:`sec_encoding_ragged_columns`. + See :ref:`sec_tables_api_binary_columns` for more information. + + :param flags: The bitwise flags for each node. Required. + :type flags: numpy.ndarray, dtype=np.uint32 + :param time: The time values for each node. Required. + :type time: numpy.ndarray, dtype=np.float64 + :param population: The population values for each node. If not specified + or None, the :const:`.NULL` value is stored for each node. + :type population: numpy.ndarray, dtype=np.int32 + :param individual: The individual values for each node. 
If not specified + or None, the :const:`.NULL` value is stored for each node. + :type individual: numpy.ndarray, dtype=np.int32 + :param metadata: The flattened metadata array. Must be specified along + with ``metadata_offset``. If not specified or None, an empty metadata + value is stored for each node. + :type metadata: numpy.ndarray, dtype=np.int8 + :param metadata_offset: The offsets into the ``metadata`` array. + :type metadata_offset: numpy.ndarray, dtype=np.uint32. + """ + self._check_required_args(flags=flags, time=time) + self.ll_table.append_columns(dict( + flags=flags, time=time, population=population, individual=individual, + metadata=metadata, metadata_offset=metadata_offset)) + + def asdict(self): + return { + "time": self.time, + "flags": self.flags, + "population": self.population, + "individual": self.individual, + "metadata": self.metadata, + "metadata_offset": self.metadata_offset, + } + + +# Pickle support. See copyreg registration for this function below. +def _pickle_node_table(table): + return NodeTable, tuple(), table.asdict() + + +class EdgeTable(BaseTable): + """ + A table defining the edges in a tree sequence. See the + :ref:`definitions ` for details on the columns + in this table and the + :ref:`tree sequence requirements ` section + for the properties needed for an edge table to be a part of a valid tree sequence. + + :warning: The numpy arrays returned by table attribute accesses are **copies** + of the underlying data. In particular, this means that you cannot edit + the values in the columns by updating the attribute arrays. + + **NOTE:** this behaviour may change in future. + + :ivar left: The array of left coordinates. + :vartype left: numpy.ndarray, dtype=np.float64 + :ivar right: The array of right coordinates. + :vartype right: numpy.ndarray, dtype=np.float64 + :ivar parent: The array of parent node IDs. + :vartype parent: numpy.ndarray, dtype=np.int32 + :ivar child: The array of child node IDs. 
+ :vartype child: numpy.ndarray, dtype=np.int32 + """ + def __init__(self, max_rows_increment=0, ll_table=None): + if ll_table is None: + ll_table = _tskit.EdgeTable(max_rows_increment=max_rows_increment) + super(EdgeTable, self).__init__(ll_table, EdgeTableRow) + + @property + def left(self): + return self.ll_table.left + + @property + def right(self): + return self.ll_table.right + + @property + def parent(self): + return self.ll_table.parent + + @property + def child(self): + return self.ll_table.child + + def __str__(self): + left = self.left + right = self.right + parent = self.parent + child = self.child + ret = "id\tleft\t\tright\t\tparent\tchild\n" + for j in range(self.num_rows): + ret += "{}\t{:.8f}\t{:.8f}\t{}\t{}\n".format( + j, left[j], right[j], parent[j], child[j]) + return ret[:-1] + + def copy(self): + """ + Returns a deep copy of this table. + """ + copy = EdgeTable() + copy.set_columns( + left=self.left, right=self.right, parent=self.parent, child=self.child) + return copy + + def add_row(self, left, right, parent, child): + """ + Adds a new row to this :class:`EdgeTable` and returns the ID of the + corresponding edge. + + :param float left: The left coordinate (inclusive). + :param float right: The right coordinate (exclusive). + :param int parent: The ID of parent node. + :param int child: The ID of child node. + :return: The ID of the newly added edge. + :rtype: int + """ + return self.ll_table.add_row(left, right, parent, child) + + def set_columns(self, left=None, right=None, parent=None, child=None): + """ + Sets the values for each column in this :class:`.EdgeTable` using the values + in the specified arrays. Overwrites any data currently stored in the table. + + All four parameters are mandatory, and must be numpy arrays of the + same length (which is equal to the number of edges the table will contain). + + :param left: The left coordinates (inclusive). 
+ :type left: numpy.ndarray, dtype=np.float64 + :param right: The right coordinates (exclusive). + :type right: numpy.ndarray, dtype=np.float64 + :param parent: The parent node IDs. + :type parent: numpy.ndarray, dtype=np.int32 + :param child: The child node IDs. + :type child: numpy.ndarray, dtype=np.int32 + """ + self._check_required_args(left=left, right=right, parent=parent, child=child) + self.ll_table.set_columns(dict( + left=left, right=right, parent=parent, child=child)) + + def append_columns(self, left, right, parent, child): + """ + Appends the specified arrays to the end of the columns of this + :class:`EdgeTable`. This allows many new rows to be added at once. + + All four parameters are mandatory, and must be numpy arrays of the + same length (which is equal to the number of additional edges to + add to the table). + + :param left: The left coordinates (inclusive). + :type left: numpy.ndarray, dtype=np.float64 + :param right: The right coordinates (exclusive). + :type right: numpy.ndarray, dtype=np.float64 + :param parent: The parent node IDs. + :type parent: numpy.ndarray, dtype=np.int32 + :param child: The child node IDs. + :type child: numpy.ndarray, dtype=np.int32 + """ + self.ll_table.append_columns(dict( + left=left, right=right, parent=parent, child=child)) + + def asdict(self): + return { + "left": self.left, + "right": self.right, + "parent": self.parent, + "child": self.child, + } + + +# Pickle support. See copyreg registration for this function below. +def _edge_table_pickle(table): + return EdgeTable, tuple(), table.asdict() + + +class MigrationTable(BaseTable): + """ + A table defining the migrations in a tree sequence. See the + :ref:`definitions ` for details on the columns + in this table and the + :ref:`tree sequence requirements ` section + for the properties needed for a migration table to be a part of a valid tree + sequence. 
+ + :warning: The numpy arrays returned by table attribute accesses are **copies** + of the underlying data. In particular, this means that you cannot edit + the values in the columns by updating the attribute arrays. + + **NOTE:** this behaviour may change in future. + + :ivar left: The array of left coordinates. + :vartype left: numpy.ndarray, dtype=np.float64 + :ivar right: The array of right coordinates. + :vartype right: numpy.ndarray, dtype=np.float64 + :ivar node: The array of node IDs. + :vartype node: numpy.ndarray, dtype=np.int32 + :ivar source: The array of source population IDs. + :vartype source: numpy.ndarray, dtype=np.int32 + :ivar dest: The array of destination population IDs. + :vartype dest: numpy.ndarray, dtype=np.int32 + :ivar time: The array of time values. + :vartype time: numpy.ndarray, dtype=np.float64 + """ + def __init__(self, max_rows_increment=0, ll_table=None): + if ll_table is None: + ll_table = _tskit.MigrationTable(max_rows_increment=max_rows_increment) + super(MigrationTable, self).__init__(ll_table, MigrationTableRow) + + @property + def left(self): + return self.ll_table.left + + @property + def right(self): + return self.ll_table.right + + @property + def node(self): + return self.ll_table.node + + @property + def source(self): + return self.ll_table.source + + @property + def dest(self): + return self.ll_table.dest + + @property + def time(self): + return self.ll_table.time + + def __str__(self): + left = self.left + right = self.right + node = self.node + source = self.source + dest = self.dest + time = self.time + ret = "id\tleft\tright\tnode\tsource\tdest\ttime\n" + for j in range(self.num_rows): + ret += "{}\t{:.8f}\t{:.8f}\t{}\t{}\t{}\t{:.8f}\n".format( + j, left[j], right[j], node[j], source[j], dest[j], time[j]) + return ret[:-1] + + def copy(self): + """ + Returns a deep copy of this table. 
+ """ + copy = MigrationTable() + copy.set_columns( + left=self.left, right=self.right, node=self.node, source=self.source, + dest=self.dest, time=self.time) + return copy + + def add_row(self, left, right, node, source, dest, time): + """ + Adds a new row to this :class:`MigrationTable` and returns the ID of the + corresponding migration. + + :param float left: The left coordinate (inclusive). + :param float right: The right coordinate (exclusive). + :param int node: The node ID. + :param int source: The ID of the source population. + :param int dest: The ID of the destination population. + :param float time: The time of the migration event. + :return: The ID of the newly added migration. + :rtype: int + """ + return self.ll_table.add_row(left, right, node, source, dest, time) + + def set_columns( + self, left=None, right=None, node=None, source=None, dest=None, time=None): + """ + Sets the values for each column in this :class:`.MigrationTable` using the values + in the specified arrays. Overwrites any data currently stored in the table. + + All six parameters are mandatory, and must be numpy arrays of the + same length (which is equal to the number of migrations the table will contain). + + :param left: The left coordinates (inclusive). + :type left: numpy.ndarray, dtype=np.float64 + :param right: The right coordinates (exclusive). + :type right: numpy.ndarray, dtype=np.float64 + :param node: The node IDs. + :type node: numpy.ndarray, dtype=np.int32 + :param source: The source population IDs. + :type source: numpy.ndarray, dtype=np.int32 + :param dest: The destination population IDs. + :type dest: numpy.ndarray, dtype=np.int32 + :param time: The time of each migration. 
+        :type time: numpy.ndarray, dtype=np.float64
+        """
+        self._check_required_args(
+            left=left, right=right, node=node, source=source, dest=dest, time=time)
+        self.ll_table.set_columns(dict(
+            left=left, right=right, node=node, source=source, dest=dest, time=time))
+
+    def append_columns(self, left, right, node, source, dest, time):
+        """
+        Appends the specified arrays to the end of the columns of this
+        :class:`MigrationTable`. This allows many new rows to be added at once.
+
+        All six parameters are mandatory, and must be numpy arrays of the
+        same length (which is equal to the number of additional migrations
+        to add to the table).
+
+        :param left: The left coordinates (inclusive).
+        :type left: numpy.ndarray, dtype=np.float64
+        :param right: The right coordinates (exclusive).
+        :type right: numpy.ndarray, dtype=np.float64
+        :param node: The node IDs.
+        :type node: numpy.ndarray, dtype=np.int32
+        :param source: The source population IDs.
+        :type source: numpy.ndarray, dtype=np.int32
+        :param dest: The destination population IDs.
+        :type dest: numpy.ndarray, dtype=np.int32
+        :param time: The time of each migration.
+        :type time: numpy.ndarray, dtype=np.float64
+        """
+        self.ll_table.append_columns(dict(
+            left=left, right=right, node=node, source=source, dest=dest, time=time))
+
+    def asdict(self):
+        return {
+            "left": self.left,
+            "right": self.right,
+            "node": self.node,
+            "source": self.source,
+            "dest": self.dest,
+            "time": self.time,
+        }
+
+
+# Pickle support. See copyreg registration for this function below.
+def _migration_table_pickle(table):
+    return MigrationTable, tuple(), table.asdict()
+
+
+class SiteTable(BaseTable):
+    """
+    A table defining the sites in a tree sequence. See the
+    :ref:`definitions <sec_site_table_definition>` for details on the columns
+    in this table and the
+    :ref:`tree sequence requirements <sec_valid_tree_sequence_requirements>` section
+    for the properties needed for a site table to be a part of a valid tree
+    sequence.
+ + :warning: The numpy arrays returned by table attribute accesses are **copies** + of the underlying data. In particular, this means that you cannot edit + the values in the columns by updating the attribute arrays. + + **NOTE:** this behaviour may change in future. + + :ivar position: The array of site position coordinates. + :vartype position: numpy.ndarray, dtype=np.float64 + :ivar ancestral_state: The flattened array of ancestral state strings. + See :ref:`sec_tables_api_text_columns` for more details. + :vartype ancestral_state: numpy.ndarray, dtype=np.int8 + :ivar ancestral_state_offset: The offsets of rows in the ancestral_state + array. See :ref:`sec_tables_api_text_columns` for more details. + :vartype ancestral_state_offset: numpy.ndarray, dtype=np.uint32 + :ivar metadata: The flattened array of binary metadata values. See + :ref:`sec_tables_api_binary_columns` for more details. + :vartype metadata: numpy.ndarray, dtype=np.int8 + :ivar metadata_offset: The array of offsets into the metadata column. See + :ref:`sec_tables_api_binary_columns` for more details. 
+ :vartype metadata_offset: numpy.ndarray, dtype=np.uint32 + """ + def __init__(self, max_rows_increment=0, ll_table=None): + if ll_table is None: + ll_table = _tskit.SiteTable(max_rows_increment=max_rows_increment) + super(SiteTable, self).__init__(ll_table, SiteTableRow) + + @property + def position(self): + return self.ll_table.position + + @property + def ancestral_state(self): + return self.ll_table.ancestral_state + + @property + def ancestral_state_offset(self): + return self.ll_table.ancestral_state_offset + + @property + def metadata(self): + return self.ll_table.metadata + + @property + def metadata_offset(self): + return self.ll_table.metadata_offset + + def __str__(self): + position = self.position + ancestral_state = unpack_strings( + self.ancestral_state, self.ancestral_state_offset) + metadata = unpack_bytes(self.metadata, self.metadata_offset) + ret = "id\tposition\tancestral_state\tmetadata\n" + for j in range(self.num_rows): + md = base64.b64encode(metadata[j]).decode('utf8') + ret += "{}\t{:.8f}\t{}\t{}\n".format( + j, position[j], ancestral_state[j], md) + return ret[:-1] + + def copy(self): + """ + Returns a deep copy of this table. + """ + copy = SiteTable() + copy.set_columns( + position=self.position, + ancestral_state=self.ancestral_state, + ancestral_state_offset=self.ancestral_state_offset, + metadata=self.metadata, + metadata_offset=self.metadata_offset) + return copy + + def add_row(self, position, ancestral_state, metadata=None): + """ + Adds a new row to this :class:`SiteTable` and returns the ID of the + corresponding site. + + :param float position: The position of this site in genome coordinates. + :param str ancestral_state: The state of this site at the root of the tree. + :param bytes metadata: The binary-encoded metadata for the new node. If not + specified or None, a zero-length byte string is stored. + :return: The ID of the newly added site. 
+ :rtype: int + """ + return self.ll_table.add_row(position, ancestral_state, metadata) + + def set_columns( + self, position=None, ancestral_state=None, ancestral_state_offset=None, + metadata=None, metadata_offset=None): + """ + Sets the values for each column in this :class:`.SiteTable` using the values + in the specified arrays. Overwrites any data currently stored in the table. + + The ``position``, ``ancestral_state`` and ``ancestral_state_offset`` + parameters are mandatory, and must be 1D numpy arrays. The length + of the ``position`` array determines the number of rows in table. + The ``ancestral_state`` and ``ancestral_state_offset`` parameters must + be supplied together, and meet the requirements for + :ref:`sec_encoding_ragged_columns` (see + :ref:`sec_tables_api_text_columns` for more information). The + ``metadata`` and ``metadata_offset`` parameters must be supplied + together, and meet the requirements for + :ref:`sec_encoding_ragged_columns` (see + :ref:`sec_tables_api_binary_columns` for more information). + + :param position: The position of each site in genome coordinates. + :type position: numpy.ndarray, dtype=np.float64 + :param ancestral_state: The flattened ancestral_state array. Required. + :type ancestral_state: numpy.ndarray, dtype=np.int8 + :param ancestral_state_offset: The offsets into the ``ancestral_state`` array. + :type ancestral_state_offset: numpy.ndarray, dtype=np.uint32. + :param metadata: The flattened metadata array. Must be specified along + with ``metadata_offset``. If not specified or None, an empty metadata + value is stored for each node. + :type metadata: numpy.ndarray, dtype=np.int8 + :param metadata_offset: The offsets into the ``metadata`` array. + :type metadata_offset: numpy.ndarray, dtype=np.uint32. 
+ """ + self._check_required_args( + position=position, ancestral_state=ancestral_state, + ancestral_state_offset=ancestral_state_offset) + self.ll_table.set_columns(dict( + position=position, ancestral_state=ancestral_state, + ancestral_state_offset=ancestral_state_offset, + metadata=metadata, metadata_offset=metadata_offset)) + + def append_columns( + self, position, ancestral_state, ancestral_state_offset, + metadata=None, metadata_offset=None): + """ + Appends the specified arrays to the end of the columns of this + :class:`SiteTable`. This allows many new rows to be added at once. + + The ``position``, ``ancestral_state`` and ``ancestral_state_offset`` + parameters are mandatory, and must be 1D numpy arrays. The length + of the ``position`` array determines the number of additional rows + to add the table. + The ``ancestral_state`` and ``ancestral_state_offset`` parameters must + be supplied together, and meet the requirements for + :ref:`sec_encoding_ragged_columns` (see + :ref:`sec_tables_api_text_columns` for more information). The + ``metadata`` and ``metadata_offset`` parameters must be supplied + together, and meet the requirements for + :ref:`sec_encoding_ragged_columns` (see + :ref:`sec_tables_api_binary_columns` for more information). + + :param position: The position of each site in genome coordinates. + :type position: numpy.ndarray, dtype=np.float64 + :param ancestral_state: The flattened ancestral_state array. Required. + :type ancestral_state: numpy.ndarray, dtype=np.int8 + :param ancestral_state_offset: The offsets into the ``ancestral_state`` array. + :type ancestral_state_offset: numpy.ndarray, dtype=np.uint32. + :param metadata: The flattened metadata array. Must be specified along + with ``metadata_offset``. If not specified or None, an empty metadata + value is stored for each node. + :type metadata: numpy.ndarray, dtype=np.int8 + :param metadata_offset: The offsets into the ``metadata`` array. 
+ :type metadata_offset: numpy.ndarray, dtype=np.uint32. + """ + self.ll_table.append_columns(dict( + position=position, ancestral_state=ancestral_state, + ancestral_state_offset=ancestral_state_offset, + metadata=metadata, metadata_offset=metadata_offset)) + + def asdict(self): + return { + "position": self.position, + "ancestral_state": self.ancestral_state, + "ancestral_state_offset": self.ancestral_state_offset, + "metadata": self.metadata, + "metadata_offset": self.metadata_offset, + } + + +# Pickle support. See copyreg registration for this function below. +def _site_table_pickle(table): + return SiteTable, tuple(), table.asdict() + + +class MutationTable(BaseTable): + """ + A table defining the mutations in a tree sequence. See the + :ref:`definitions ` for details on the columns + in this table and the + :ref:`tree sequence requirements ` section + for the properties needed for a mutation table to be a part of a valid tree + sequence. + + :warning: The numpy arrays returned by table attribute accesses are **copies** + of the underlying data. In particular, this means that you cannot edit + the values in the columns by updating the attribute arrays. + + **NOTE:** this behaviour may change in future. + + :ivar site: The array of site IDs. + :vartype site: numpy.ndarray, dtype=np.int32 + :ivar node: The array of node IDs. + :vartype node: numpy.ndarray, dtype=np.int32 + :ivar derived_state: The flattened array of derived state strings. + See :ref:`sec_tables_api_text_columns` for more details. + :vartype derived_state: numpy.ndarray, dtype=np.int8 + :ivar derived_state_offset: The offsets of rows in the derived_state + array. See :ref:`sec_tables_api_text_columns` for more details. + :vartype derived_state_offset: numpy.ndarray, dtype=np.uint32 + :ivar parent: The array of parent mutation IDs. + :vartype parent: numpy.ndarray, dtype=np.int32 + :ivar metadata: The flattened array of binary metadata values. 
See + :ref:`sec_tables_api_binary_columns` for more details. + :vartype metadata: numpy.ndarray, dtype=np.int8 + :ivar metadata_offset: The array of offsets into the metadata column. See + :ref:`sec_tables_api_binary_columns` for more details. + :vartype metadata_offset: numpy.ndarray, dtype=np.uint32 + """ + def __init__(self, max_rows_increment=0, ll_table=None): + if ll_table is None: + ll_table = _tskit.MutationTable(max_rows_increment=max_rows_increment) + super(MutationTable, self).__init__(ll_table, MutationTableRow) + + @property + def site(self): + return self.ll_table.site + + @property + def node(self): + return self.ll_table.node + + @property + def parent(self): + return self.ll_table.parent + + @property + def derived_state(self): + return self.ll_table.derived_state + + @property + def derived_state_offset(self): + return self.ll_table.derived_state_offset + + @property + def metadata(self): + return self.ll_table.metadata + + @property + def metadata_offset(self): + return self.ll_table.metadata_offset + + def __str__(self): + site = self.site + node = self.node + parent = self.parent + derived_state = unpack_strings(self.derived_state, self.derived_state_offset) + metadata = unpack_bytes(self.metadata, self.metadata_offset) + ret = "id\tsite\tnode\tderived_state\tparent\tmetadata\n" + for j in range(self.num_rows): + md = base64.b64encode(metadata[j]).decode('utf8') + ret += "{}\t{}\t{}\t{}\t{}\t{}\n".format( + j, site[j], node[j], derived_state[j], parent[j], md) + return ret[:-1] + + def copy(self): + """ + Returns a deep copy of this table. 
+ """ + copy = MutationTable() + copy.set_columns( + site=self.site, node=self.node, parent=self.parent, + derived_state=self.derived_state, + derived_state_offset=self.derived_state_offset, + metadata=self.metadata, metadata_offset=self.metadata_offset) + return copy + + def add_row(self, site, node, derived_state, parent=-1, metadata=None): + """ + Adds a new row to this :class:`MutationTable` and returns the ID of the + corresponding mutation. + + :param int site: The ID of the site that this mutation occurs at. + :param int node: The ID of the first node inheriting this mutation. + :param str derived_state: The state of the site at this mutation's node. + :param int parent: The ID of the parent mutation. If not specified, + defaults to :attr:`NULL`. + :param bytes metadata: The binary-encoded metadata for the new node. If not + specified or None, a zero-length byte string is stored. + :return: The ID of the newly added mutation. + :rtype: int + """ + return self.ll_table.add_row( + site, node, derived_state, parent, metadata) + + def set_columns( + self, site=None, node=None, derived_state=None, derived_state_offset=None, + parent=None, metadata=None, metadata_offset=None): + """ + Sets the values for each column in this :class:`.MutationTable` using the values + in the specified arrays. Overwrites any data currently stored in the table. + + The ``site``, ``node``, ``derived_state`` and ``derived_state_offset`` + parameters are mandatory, and must be 1D numpy arrays. The + ``site`` and ``node`` (also ``parent``, if supplied) arrays + must be of equal length, and determine the number of rows in the table. + The ``derived_state`` and ``derived_state_offset`` parameters must + be supplied together, and meet the requirements for + :ref:`sec_encoding_ragged_columns` (see + :ref:`sec_tables_api_text_columns` for more information). 
The + ``metadata`` and ``metadata_offset`` parameters must be supplied + together, and meet the requirements for + :ref:`sec_encoding_ragged_columns` (see + :ref:`sec_tables_api_binary_columns` for more information). + + :param site: The ID of the site each mutation occurs at. + :type site: numpy.ndarray, dtype=np.int32 + :param node: The ID of the node each mutation is associated with. + :type node: numpy.ndarray, dtype=np.int32 + :param derived_state: The flattened derived_state array. Required. + :type derived_state: numpy.ndarray, dtype=np.int8 + :param derived_state_offset: The offsets into the ``derived_state`` array. + :type derived_state_offset: numpy.ndarray, dtype=np.uint32. + :param parent: The ID of the parent mutation for each mutation. + :type parent: numpy.ndarray, dtype=np.int32 + :param metadata: The flattened metadata array. Must be specified along + with ``metadata_offset``. If not specified or None, an empty metadata + value is stored for each node. + :type metadata: numpy.ndarray, dtype=np.int8 + :param metadata_offset: The offsets into the ``metadata`` array. + :type metadata_offset: numpy.ndarray, dtype=np.uint32. + """ + self._check_required_args( + site=site, node=node, derived_state=derived_state, + derived_state_offset=derived_state_offset) + self.ll_table.set_columns(dict( + site=site, node=node, parent=parent, + derived_state=derived_state, derived_state_offset=derived_state_offset, + metadata=metadata, metadata_offset=metadata_offset)) + + def append_columns( + self, site, node, derived_state, derived_state_offset, + parent=None, metadata=None, metadata_offset=None): + """ + Appends the specified arrays to the end of the columns of this + :class:`MutationTable`. This allows many new rows to be added at once. + + The ``site``, ``node``, ``derived_state`` and ``derived_state_offset`` + parameters are mandatory, and must be 1D numpy arrays. 
The + ``site`` and ``node`` (also ``parent``, if supplied) arrays + must be of equal length, and determine the number of additional + rows to add to the table. + The ``derived_state`` and ``derived_state_offset`` parameters must + be supplied together, and meet the requirements for + :ref:`sec_encoding_ragged_columns` (see + :ref:`sec_tables_api_text_columns` for more information). The + ``metadata`` and ``metadata_offset`` parameters must be supplied + together, and meet the requirements for + :ref:`sec_encoding_ragged_columns` (see + :ref:`sec_tables_api_binary_columns` for more information). + + :param site: The ID of the site each mutation occurs at. + :type site: numpy.ndarray, dtype=np.int32 + :param node: The ID of the node each mutation is associated with. + :type node: numpy.ndarray, dtype=np.int32 + :param derived_state: The flattened derived_state array. Required. + :type derived_state: numpy.ndarray, dtype=np.int8 + :param derived_state_offset: The offsets into the ``derived_state`` array. + :type derived_state_offset: numpy.ndarray, dtype=np.uint32. + :param parent: The ID of the parent mutation for each mutation. + :type parent: numpy.ndarray, dtype=np.int32 + :param metadata: The flattened metadata array. Must be specified along + with ``metadata_offset``. If not specified or None, an empty metadata + value is stored for each node. + :type metadata: numpy.ndarray, dtype=np.int8 + :param metadata_offset: The offsets into the ``metadata`` array. + :type metadata_offset: numpy.ndarray, dtype=np.uint32. 
+ """ + self.ll_table.append_columns(dict( + site=site, node=node, parent=parent, + derived_state=derived_state, derived_state_offset=derived_state_offset, + metadata=metadata, metadata_offset=metadata_offset)) + + def asdict(self): + return { + "site": self.site, + "node": self.node, + "parent": self.parent, + "derived_state": self.derived_state, + "derived_state_offset": self.derived_state_offset, + "metadata": self.metadata, + "metadata_offset": self.metadata_offset, + } + + +# Pickle support. See copyreg registration for this function below. +def _mutation_table_pickle(table): + return MutationTable, tuple(), table.asdict() + + +class PopulationTable(BaseTable): + """ + A table defining the populations referred to in a tree sequence. + The PopulationTable stores metadata for populations that may be referred to + in the NodeTable and MigrationTable". Note that although nodes + may be associated with populations, this association is stored in + the :class:`NodeTable`: only metadata on each population is stored + in the population table. + + :warning: The numpy arrays returned by table attribute accesses are **copies** + of the underlying data. In particular, this means that you cannot edit + the values in the columns by updating the attribute arrays. + + **NOTE:** this behaviour may change in future. + + :ivar metadata: The flattened array of binary metadata values. See + :ref:`sec_tables_api_binary_columns` for more details. + :vartype metadata: numpy.ndarray, dtype=np.int8 + :ivar metadata_offset: The array of offsets into the metadata column. See + :ref:`sec_tables_api_binary_columns` for more details. 
+ :vartype metadata_offset: numpy.ndarray, dtype=np.uint32 + """ + def __init__(self, max_rows_increment=0, ll_table=None): + if ll_table is None: + ll_table = _tskit.PopulationTable(max_rows_increment=max_rows_increment) + super(PopulationTable, self).__init__(ll_table, PopulationTableRow) + + @property + def metadata(self): + return self.ll_table.metadata + + @property + def metadata_offset(self): + return self.ll_table.metadata_offset + + def add_row(self, metadata=None): + """ + Adds a new row to this :class:`PopulationTable` and returns the ID of the + corresponding population. + + :param bytes metadata: The binary-encoded metadata for the new population. + If not specified or None, a zero-length byte string is stored. + :return: The ID of the newly added population. + :rtype: int + """ + return self.ll_table.add_row(metadata=metadata) + + def __str__(self): + metadata = unpack_bytes(self.metadata, self.metadata_offset) + ret = "id\tmetadata\n" + for j in range(self.num_rows): + md = base64.b64encode(metadata[j]).decode('utf8') + ret += "{}\t{}\n".format(j, md) + return ret[:-1] + + def copy(self): + """ + Returns a deep copy of this table. + """ + copy = PopulationTable() + copy.set_columns( + metadata=self.metadata, + metadata_offset=self.metadata_offset) + return copy + + def set_columns(self, metadata=None, metadata_offset=None): + self.ll_table.set_columns( + dict(metadata=metadata, metadata_offset=metadata_offset)) + + def append_columns(self, metadata=None, metadata_offset=None): + self.ll_table.append_columns( + dict(metadata=metadata, metadata_offset=metadata_offset)) + + def asdict(self): + return { + "metadata": self.metadata, + "metadata_offset": self.metadata_offset, + } + + +# Pickle support. See copyreg registration for this function below. 
+def _population_table_pickle(table): + return PopulationTable, tuple(), table.asdict() + + +class ProvenanceTable(BaseTable): + """ + A table recording the provenance (i.e., history) of this table, so that the + origin of the underlying data and sequence of subsequent operations can be + traced. Each row contains a "record" string (recommended format: JSON) and + a timestamp. + + .. todo:: + The format of the `record` field will be more precisely specified in + the future. + + :ivar record: The flattened array containing the record strings. + :ref:`sec_tables_api_text_columns` for more details. + :vartype record: numpy.ndarray, dtype=np.int8 + :ivar record_offset: The array of offsets into the record column. See + :ref:`sec_tables_api_text_columns` for more details. + :vartype record_offset: numpy.ndarray, dtype=np.uint32 + :ivar timestamp: The flattened array containing the timestamp strings. + :ref:`sec_tables_api_text_columns` for more details. + :vartype timestamp: numpy.ndarray, dtype=np.int8 + :ivar timestamp_offset: The array of offsets into the timestamp column. See + :ref:`sec_tables_api_text_columns` for more details. + :vartype timestamp_offset: numpy.ndarray, dtype=np.uint32 + + """ + def __init__(self, max_rows_increment=0, ll_table=None): + if ll_table is None: + ll_table = _tskit.ProvenanceTable(max_rows_increment=max_rows_increment) + super(ProvenanceTable, self).__init__(ll_table, ProvenanceTableRow) + + @property + def record(self): + return self.ll_table.record + + @property + def record_offset(self): + return self.ll_table.record_offset + + @property + def timestamp(self): + return self.ll_table.timestamp + + @property + def timestamp_offset(self): + return self.ll_table.timestamp_offset + + def add_row(self, record, timestamp=None): + """ + Adds a new row to this ProvenanceTable consisting of the specified record and + timestamp. If timestamp is not specified, it is automatically generated from + the current time. 
+ + :param str record: A provenance record, describing the parameters and + environment used to generate the current set of tables. + :param str timestamp: A string timestamp. This should be in ISO8601 form. + """ + if timestamp is None: + timestamp = datetime.datetime.now().isoformat() + # Note that the order of the positional arguments has been reversed + # from the low-level module, which is a bit confusing. However, we + # want the default behaviour here to be to add a row to the table at + # the current time as simply as possible. + return self.ll_table.add_row(record=record, timestamp=timestamp) + + def set_columns( + self, timestamp=None, timestamp_offset=None, + record=None, record_offset=None): + self.ll_table.set_columns(dict( + timestamp=timestamp, timestamp_offset=timestamp_offset, + record=record, record_offset=record_offset)) + + def append_columns( + self, timestamp=None, timestamp_offset=None, + record=None, record_offset=None): + self.ll_table.append_columns(dict( + timestamp=timestamp, timestamp_offset=timestamp_offset, + record=record, record_offset=record_offset)) + + def __str__(self): + timestamp = unpack_strings(self.timestamp, self.timestamp_offset) + record = unpack_strings(self.record, self.record_offset) + ret = "id\ttimestamp\trecord\n" + for j in range(self.num_rows): + ret += "{}\t{}\t{}\n".format(j, timestamp[j], record[j]) + return ret[:-1] + + # Unpickle support + def __setstate__(self, state): + self.set_columns( + timestamp=state["timestamp"], + timestamp_offset=state["timestamp_offset"], + record=state["record"], + record_offset=state["record_offset"]) + + def copy(self): + """ + Returns a deep copy of this table. 
+ """ + copy = ProvenanceTable() + copy.set_columns( + timestamp=self.timestamp, + timestamp_offset=self.timestamp_offset, + record=self.record, + record_offset=self.record_offset) + return copy + + def asdict(self): + return { + "timestamp": self.timestamp, + "timestamp_offset": self.timestamp_offset, + "record": self.record, + "record_offset": self.record_offset, + } + + +# Pickle support. See copyreg registration for this function below. +def _provenance_table_pickle(table): + return ProvenanceTable, tuple(), table.asdict() + + +class TableCollection(object): + """ + A collection of mutable tables defining a tree sequence. See the + :ref:`sec_data_model` section for definition on the various tables + and how they together define a :class:`TreeSequence`. Arbitrary + data can be stored in a TableCollection, but there are certain + :ref:`requirements ` that must be + satisfied for these tables to be interpreted as a tree sequence. + + To obtain a :class:`TreeSequence` instance corresponding to the current + state of a ``TableCollection``, please use the :meth:`.tree_sequence` + method. + + :ivar individuals: The individual table. + :vartype individuals: IndividualTable + :ivar nodes: The node table. + :vartype nodes: NodeTable + :ivar edges: The edge table. + :vartype edges: EdgeTable + :ivar migrations: The migration table. + :vartype migrations: MigrationTable + :ivar sites: The site table. + :vartype sites: SiteTable + :ivar mutations: The mutation table. + :vartype mutations: MutationTable + :ivar populations: The population table. + :vartype populations: PopulationTable + :ivar provenances: The provenance table. + :vartype provenances: ProvenanceTable + :ivar sequence_length: The sequence length defining the coordinate + space. + :vartype sequence_length: float + :ivar file_uuid: The UUID for the file this TableCollection is derived + from, or None if not derived from a file. 
+ :vartype file_uuid: str + """ + def __init__(self, sequence_length=0): + self.ll_tables = _tskit.TableCollection(sequence_length) + + @property + def individuals(self): + return IndividualTable(ll_table=self.ll_tables.individuals) + + @property + def nodes(self): + return NodeTable(ll_table=self.ll_tables.nodes) + + @property + def edges(self): + return EdgeTable(ll_table=self.ll_tables.edges) + + @property + def migrations(self): + return MigrationTable(ll_table=self.ll_tables.migrations) + + @property + def sites(self): + return SiteTable(ll_table=self.ll_tables.sites) + + @property + def mutations(self): + return MutationTable(ll_table=self.ll_tables.mutations) + + @property + def populations(self): + return PopulationTable(ll_table=self.ll_tables.populations) + + @property + def provenances(self): + return ProvenanceTable(ll_table=self.ll_tables.provenances) + + @property + def sequence_length(self): + return self.ll_tables.sequence_length + + @property + def file_uuid(self): + return self.ll_tables.file_uuid + + def asdict(self): + """ + Returns a dictionary representation of this TableCollection. + + Note: the semantics of this method changed at tskit 1.0.0. Previously a + map of table names to the tables themselves was returned. 
+ """ + return { + "sequence_length": self.sequence_length, + "individuals": self.individuals.asdict(), + "nodes": self.nodes.asdict(), + "edges": self.edges.asdict(), + "migrations": self.migrations.asdict(), + "sites": self.sites.asdict(), + "mutations": self.mutations.asdict(), + "populations": self.populations.asdict(), + "provenances": self.provenances.asdict(), + } + + def __banner(self, title): + width = 60 + line = "#" * width + title_line = "# {}".format(title) + title_line += " " * (width - len(title_line) - 1) + title_line += "#" + return line + "\n" + title_line + "\n" + line + "\n" + + def __str__(self): + s = self.__banner("Individuals") + s += str(self.individuals) + "\n" + s += self.__banner("Nodes") + s += str(self.nodes) + "\n" + s += self.__banner("Edges") + s += str(self.edges) + "\n" + s += self.__banner("Sites") + s += str(self.sites) + "\n" + s += self.__banner("Mutations") + s += str(self.mutations) + "\n" + s += self.__banner("Migrations") + s += str(self.migrations) + "\n" + s += self.__banner("Populations") + s += str(self.populations) + "\n" + s += self.__banner("Provenances") + s += str(self.provenances) + return s + + def __eq__(self, other): + ret = False + if type(other) is type(self): + ret = bool(self.ll_tables.equals(other.ll_tables)) + return ret + + def __ne__(self, other): + return not self.__eq__(other) + + # Unpickle support + def __setstate__(self, state): + self.__init__(state["sequence_length"]) + self.individuals.set_columns(**state["individuals"]) + self.nodes.set_columns(**state["nodes"]) + self.edges.set_columns(**state["edges"]) + self.migrations.set_columns(**state["migrations"]) + self.sites.set_columns(**state["sites"]) + self.mutations.set_columns(**state["mutations"]) + self.populations.set_columns(**state["populations"]) + self.provenances.set_columns(**state["provenances"]) + + @classmethod + def fromdict(self, tables_dict): + tables = TableCollection(tables_dict["sequence_length"]) + 
tables.individuals.set_columns(**tables_dict["individuals"]) + tables.nodes.set_columns(**tables_dict["nodes"]) + tables.edges.set_columns(**tables_dict["edges"]) + tables.migrations.set_columns(**tables_dict["migrations"]) + tables.sites.set_columns(**tables_dict["sites"]) + tables.mutations.set_columns(**tables_dict["mutations"]) + tables.populations.set_columns(**tables_dict["populations"]) + tables.provenances.set_columns(**tables_dict["provenances"]) + return tables + + def tree_sequence(self): + """ + Returns a :class:`TreeSequence` instance with the structure defined by the + tables in this :class:`TableCollection`. If the table collection is not + in canonical form (i.e., does not meet sorting requirements) or cannot be + interpreted as a tree sequence an exception is raised. The + :meth:`.sort` method may be used to ensure that input sorting requirements + are met. + + :return: A :class:`TreeSequence` instance reflecting the structures + defined in this set of tables. + :rtype: .TreeSequence + """ + return tskit.TreeSequence.load_tables(self) + + def simplify( + self, samples=None, + filter_zero_mutation_sites=None, # Deprecated alias for filter_sites + reduce_to_site_topology=False, + filter_populations=True, filter_individuals=True, filter_sites=True): + """ + Simplifies the tables in place to retain only the information necessary + to reconstruct the tree sequence describing the given ``samples``. + This will change the ID of the nodes, so that the node + ``samples[k]`` will have ID ``k`` in the result. The resulting + NodeTable will have only the first ``len(samples)`` individuals marked + as samples. The mapping from node IDs in the current set of tables to + their equivalent values in the simplified tables is also returned as a + numpy array. If an array ``a`` is returned by this function and ``u`` + is the ID of a node in the input table, then ``a[u]`` is the ID of this + node in the output table. 
For any node ``u`` that is not mapped into + the output tables, this mapping will equal ``-1``. + + Tables operated on by this function must: be sorted (see + :meth:`TableCollection.sort`)), have children be born strictly after their + parents, and the intervals on which any individual is a child must be + disjoint. Other than this the tables need not satisfy remaining + requirements to specify a valid tree sequence (but the resulting tables + will). + + Please see the :meth:`TreeSequence.simplify` method for a description + of the remaining parameters. + + :param list[int] samples: A list of node IDs to retain as samples. If + not specified or None, use all nodes marked with the IS_SAMPLE flag. + :param bool filter_zero_mutation_sites: Deprecated alias for ``filter_sites``. + :param bool reduce_to_site_topology: Whether to reduce the topology down + to the trees that are present at sites. (default: False). + :param bool filter_populations: If True, remove any populations that are + not referenced by nodes after simplification; new population IDs are + allocated sequentially from zero. If False, the population table will + not be altered in any way. (Default: True) + :param bool filter_individuals: If True, remove any individuals that are + not referenced by nodes after simplification; new individual IDs are + allocated sequentially from zero. If False, the individual table will + not be altered in any way. (Default: True) + :param bool filter_sites: If True, remove any sites that are + not referenced by mutations after simplification; new site IDs are + allocated sequentially from zero. If False, the site table will not + be altered in any way. (Default: True) + :return: A numpy array mapping node IDs in the input tables to their + corresponding node IDs in the output tables. + :rtype: numpy array (dtype=np.int32). + """ + if filter_zero_mutation_sites is not None: + # Deprecated in 0.6.1. 
+ warnings.warn( + "filter_zero_mutation_sites is deprecated; use filter_sites instead", + DeprecationWarning) + filter_sites = filter_zero_mutation_sites + if samples is None: + flags = self.nodes.flags + samples = np.where( + np.bitwise_and(flags, _tskit.NODE_IS_SAMPLE) != 0)[0].astype(np.int32) + return self.ll_tables.simplify( + samples, filter_sites=filter_sites, + filter_individuals=filter_individuals, + filter_populations=filter_populations, + reduce_to_site_topology=reduce_to_site_topology) + + def sort(self, edge_start=0): + """ + Sorts the tables in place. This ensures that all tree sequence ordering + requirements listed in the + :ref:`sec_valid_tree_sequence_requirements` section are met, as long + as each site has at most one mutation (see below). + + If the ``edge_start`` parameter is provided, this specifies the index + in the edge table where sorting should start. Only rows with index + greater than or equal to ``edge_start`` are sorted; rows before this index + are not affected. This parameter is provided to allow for efficient sorting + when the user knows that the edges up to a given index are already sorted. + + The individual, node, population and provenance tables are not affected + by this method. + + Edges are sorted as follows: + + - time of parent, then + - parent node ID, then + - child node ID, then + - left endpoint. + + Note that this sorting order exceeds the + :ref:`edge sorting requirements ` for a valid + tree sequence. For a valid tree sequence, we require that all edges for a + given parent ID are adjacent, but we do not require that they be listed in + sorted order. + + Sites are sorted by position, and sites with the same position retain + their relative ordering. + + Mutations are sorted by site ID, and mutations with the same site retain + their relative ordering. This does not currently rearrange tables so that + mutations occur after their mutation parents, which is a requirement for + valid tree sequences. 
+ + :param int edge_start: The index in the edge table where sorting starts + (default=0; must be <= len(edges)). + """ + self.ll_tables.sort(edge_start) + # TODO add provenance + + def compute_mutation_parents(self): + """ + Modifies the tables in place, computing the ``parent`` column of the + mutation table. For this to work, the node and edge tables must be + valid, and the site and mutation tables must be sorted (see + :meth:`TableCollection.sort`). This will produce an error if mutations + are not sorted (i.e., if a mutation appears before its mutation parent) + *unless* the two mutations occur on the same branch, in which case + there is no way to detect the error. + + The ``parent`` of a given mutation is the ID of the next mutation + encountered traversing the tree upwards from that mutation, or + ``NULL`` if there is no such mutation. + """ + self.ll_tables.compute_mutation_parents() + # TODO add provenance + + def deduplicate_sites(self): + """ + Modifies the tables in place, removing entries in the site table with + duplicate ``position`` (and keeping only the *first* entry for each + site), and renumbering the ``site`` column of the mutation table + appropriately. This requires the site table to be sorted by position. + """ + self.ll_tables.deduplicate_sites() + # TODO add provenance + + +# Pickle support. See copyreg registration for this function below. +def _table_collection_pickle(tables): + return TableCollection, tuple(), tables.asdict() + + +# Pickle support for the various tables. We are forced to use copyreg.pickle +# here to support Python 2. For Python 3, we can just use the __setstate__. +# It would be cleaner to attach the pickle_*_table functions to the classes +# themselves, but this causes issues with Mocking on readthedocs. Sigh. 
+copyreg.pickle(IndividualTable, _pickle_individual_table) +copyreg.pickle(NodeTable, _pickle_node_table) +copyreg.pickle(EdgeTable, _edge_table_pickle) +copyreg.pickle(MigrationTable, _migration_table_pickle) +copyreg.pickle(SiteTable, _site_table_pickle) +copyreg.pickle(MutationTable, _mutation_table_pickle) +copyreg.pickle(PopulationTable, _population_table_pickle) +copyreg.pickle(ProvenanceTable, _provenance_table_pickle) +copyreg.pickle(TableCollection, _table_collection_pickle) + + +############################################# +# Table functions. +############################################# + +def pack_bytes(data): + """ + Packs the specified list of bytes into a flattened numpy array of 8 bit integers + and corresponding offsets. See :ref:`sec_encoding_ragged_columns` for details + of this encoding. + + :param list[bytes] data: The list of bytes values to encode. + :return: The tuple (packed, offset) of numpy arrays representing the flattened + input data and offsets. + :rtype: numpy.array (dtype=np.int8), numpy.array (dtype=np.uint32). + """ + n = len(data) + offsets = np.zeros(n + 1, dtype=np.uint32) + for j in range(n): + offsets[j + 1] = offsets[j] + len(data[j]) + column = np.zeros(offsets[-1], dtype=np.int8) + for j, value in enumerate(data): + column[offsets[j]: offsets[j + 1]] = bytearray(value) + return column, offsets + + +def unpack_bytes(packed, offset): + """ + Unpacks a list of bytes from the specified numpy arrays of packed byte + data and corresponding offsets. See :ref:`sec_encoding_ragged_columns` for details + of this encoding. + + :param numpy.ndarray packed: The flattened array of byte values. + :param numpy.ndarray offset: The array of offsets into the ``packed`` array. + :return: The list of bytes values unpacked from the parameter arrays. + :rtype: list[bytes] + """ + # This could be done a lot more efficiently... 
+    ret = []
+    for j in range(offset.shape[0] - 1):
+        raw = packed[offset[j]: offset[j + 1]].tobytes()
+        ret.append(raw)
+    return ret
+
+
+def pack_strings(strings, encoding="utf8"):
+    """
+    Packs the specified list of strings into a flattened numpy array of 8 bit integers
+    and corresponding offsets using the specified text encoding.
+    See :ref:`sec_encoding_ragged_columns` for details of this encoding of
+    columns of variable length data.
+
+    :param list[str] strings: The list of strings to encode.
+    :param str encoding: The text encoding to use when converting string data
+        to bytes. See the :mod:`codecs` module for information on available
+        string encodings.
+    :return: The tuple (packed, offset) of numpy arrays representing the flattened
+        input data and offsets.
+    :rtype: numpy.array (dtype=np.int8), numpy.array (dtype=np.uint32).
+    """
+    return pack_bytes([bytearray(s.encode(encoding)) for s in strings])
+
+
+def unpack_strings(packed, offset, encoding="utf8"):
+    """
+    Unpacks a list of strings from the specified numpy arrays of packed byte
+    data and corresponding offsets using the specified text encoding.
+    See :ref:`sec_encoding_ragged_columns` for details of this encoding of
+    columns of variable length data.
+
+    :param numpy.ndarray packed: The flattened array of byte values.
+    :param numpy.ndarray offset: The array of offsets into the ``packed`` array.
+    :param str encoding: The text encoding to use when converting string data
+        to bytes. See the :mod:`codecs` module for information on available
+        string encodings.
+    :return: The list of strings unpacked from the parameter arrays.
+    :rtype: list[str]
+    """
+    return [b.decode(encoding) for b in unpack_bytes(packed, offset)]
diff --git a/python/tskit/trees.py b/python/tskit/trees.py
new file mode 100644
index 0000000000..73ae0ecf67
--- /dev/null
+++ b/python/tskit/trees.py
@@ -0,0 +1,2753 @@
+# -*- coding: utf-8 -*-
+"""
+Module responsible for managing trees and tree sequences.
+""" +from __future__ import division +from __future__ import print_function + +import collections +import itertools +import json +import sys +import base64 +import warnings +import functools +try: + import concurrent.futures +except ImportError: + # We're on Python2; any attempts to use futures are dealt with below. + pass + +import numpy as np + +import _tskit +import tskit.drawing as drawing +import tskit.exceptions as exceptions +import tskit.provenance as provenance +import tskit.tables as tables +import tskit.formats as formats + +from _tskit import NODE_IS_SAMPLE +from _tskit import NULL + + +IS_PY2 = sys.version_info[0] < 3 + + +CoalescenceRecord = collections.namedtuple( + "CoalescenceRecord", + ["left", "right", "node", "children", "time", "population"]) + + +# TODO this interface is rubbish. Should have much better printing options. +# TODO we should be use __slots__ here probably. +class SimpleContainer(object): + + def __eq__(self, other): + return self.__dict__ == other.__dict__ + + def __ne__(self, other): + return not (self == other) + + def __repr__(self): + return repr(self.__dict__) + + +class Individual(SimpleContainer): + """ + An :ref:`individual ` in a tree sequence. + Since nodes correspond to genomes, individuals are associated with a collection + of nodes (e.g., two nodes per diploid). See :ref:`sec_nodes_or_individuals` + for more discussion of this distinction. + + Modifying the attributes in this class will have **no effect** on the + underlying tree sequence data. + + :ivar id: The integer ID of this individual. Varies from 0 to + :attr:`.TreeSequence.num_individuals` - 1. + :vartype id: int + :ivar flags: The bitwise flags for this individual. + :vartype flags: int + :ivar location: The spatial location of this individual as a numpy array. The + location is an empty array if no spatial location is defined. 
+    :vartype location: numpy.ndarray
+    :ivar nodes: The IDs of the nodes that are associated with this individual as
+        a numpy array (dtype=np.int32). If no nodes are associated with the
+        individual this array will be empty.
+    :vartype nodes: numpy.ndarray
+    :ivar metadata: The :ref:`metadata ` for this individual.
+    :vartype metadata: bytes
+    """
+    def __init__(self, id_=None, flags=0, location=None, nodes=None, metadata=""):
+        self.id = id_
+        self.flags = flags
+        self.location = location
+        self.metadata = metadata
+        self.nodes = nodes
+
+    def __eq__(self, other):
+        return (
+            self.id == other.id and
+            self.flags == other.flags and
+            self.metadata == other.metadata and
+            np.array_equal(self.nodes, other.nodes) and
+            np.array_equal(self.location, other.location))
+
+
+class Node(SimpleContainer):
+    """
+    A :ref:`node ` in a tree sequence, corresponding
+    to a single genome. The ``time`` and ``population`` are attributes of the
+    ``Node``, rather than the ``Individual``, as discussed in
+    :ref:`sec_nodes_or_individuals`.
+
+    Modifying the attributes in this class will have **no effect** on the
+    underlying tree sequence data.
+
+    :ivar id: The integer ID of this node. Varies from 0 to
+        :attr:`.TreeSequence.num_nodes` - 1.
+    :vartype id: int
+    :ivar flags: The bitwise flags for this node.
+    :vartype flags: int
+    :ivar time: The birth time of this node.
+    :vartype time: float
+    :ivar population: The integer ID of the population that this node was born in.
+    :vartype population: int
+    :ivar individual: The integer ID of the individual that this node was a part of.
+    :vartype individual: int
+    :ivar metadata: The :ref:`metadata ` for this node.
+ :vartype metadata: bytes + """ + def __init__( + self, id_=None, flags=0, time=0, population=NULL, + individual=NULL, metadata=""): + self.id = id_ + self.time = time + self.population = population + self.individual = individual + self.metadata = metadata + self.flags = flags + + def is_sample(self): + """ + Returns True if this node is a sample. This value is derived from the + ``flag`` variable. + + :rtype: bool + """ + return self.flags & NODE_IS_SAMPLE + + +class Edge(SimpleContainer): + """ + An :ref:`edge ` in a tree sequence. + + Modifying the attributes in this class will have **no effect** on the + underlying tree sequence data. + + :ivar left: The left coordinate of this edge. + :vartype left: float + :ivar right: The right coordinate of this edge. + :vartype right: float + :ivar parent: The integer ID of the parent node for this edge. + To obtain further information about a node with a given ID, use + :meth:`.TreeSequence.node`. + :vartype parent: int + :ivar child: The integer ID of the child node for this edge. + To obtain further information about a node with a given ID, use + :meth:`.TreeSequence.node`. + :vartype child: int + """ + def __init__(self, left, right, parent, child): + self.left = left + self.right = right + self.parent = parent + self.child = child + + def __repr__(self): + return "{{left={:.3f}, right={:.3f}, parent={}, child={}}}".format( + self.left, self.right, self.parent, self.child) + + +class Site(SimpleContainer): + """ + A :ref:`site ` in a tree sequence. + + Modifying the attributes in this class will have **no effect** on the + underlying tree sequence data. + + :ivar id: The integer ID of this site. Varies from 0 to + :attr:`.TreeSequence.num_sites` - 1. + :vartype id: int + :ivar position: The floating point location of this site in genome coordinates. + Ranges from 0 (inclusive) to :attr:`.TreeSequence.sequence_length` + (exclusive). 
+ :vartype position: float + :ivar ancestral_state: The ancestral state at this site (i.e., the state + inherited by nodes, unless mutations occur). + :vartype ancestral_state: str + :ivar metadata: The :ref:`metadata ` for this site. + :vartype metadata: bytes + :ivar mutations: The list of mutations at this site. Mutations + within a site are returned in the order they are specified in the + underlying :class:`.MutationTable`. + :vartype mutations: list[:class:`.Mutation`] + """ + def __init__(self, id_, position, ancestral_state, mutations, metadata): + self.id = id_ + self.position = position + self.ancestral_state = ancestral_state + self.mutations = mutations + self.metadata = metadata + + +class Mutation(SimpleContainer): + """ + A :ref:`mutation ` in a tree sequence. + + Modifying the attributes in this class will have **no effect** on the + underlying tree sequence data. + + :ivar id: The integer ID of this mutation. Varies from 0 to + :attr:`.TreeSequence.num_mutations` - 1. + :vartype id: int + :ivar site: The integer ID of the site that this mutation occurs at. To obtain + further information about a site with a given ID use + :meth:`.TreeSequence.site`. + :vartype site: int + :ivar node: The integer ID of the first node that inherits this mutation. + To obtain further information about a node with a given ID, use + :meth:`.TreeSequence.node`. + :vartype node: int + :ivar derived_state: The derived state for this mutation. This is the state + inherited by nodes in the subtree rooted at this mutation's node, unless + another mutation occurs. + :vartype derived_state: str + :ivar parent: The integer ID of this mutation's parent mutation. When multiple + mutations occur at a site along a path in the tree, mutations must + record the mutation that is immediately above them. If the mutation does + not have a parent, this is equal to the :const:`NULL` (-1). + To obtain further information about a mutation with a given ID, use + :meth:`.TreeSequence.mutation`. 
+    :vartype parent: int
+    :ivar metadata: The :ref:`metadata ` for this mutation.
+    :vartype metadata: bytes
+    """
+    def __init__(self, id_, site, node, derived_state, parent, metadata):
+        self.id = id_
+        self.site = site
+        self.node = node
+        self.derived_state = derived_state
+        self.parent = parent
+        self.metadata = metadata
+
+
+class Migration(SimpleContainer):
+    """
+    A :ref:`migration ` in a tree sequence.
+
+    Modifying the attributes in this class will have **no effect** on the
+    underlying tree sequence data.
+
+    :ivar left: The left end of the genomic interval covered by this
+        migration (inclusive).
+    :vartype left: float
+    :ivar right: The right end of the genomic interval covered by this migration
+        (exclusive).
+    :vartype right: float
+    :ivar node: The integer ID of the node involved in this migration event.
+        To obtain further information about a node with a given ID, use
+        :meth:`.TreeSequence.node`.
+    :vartype node: int
+    :ivar source: The source population ID.
+    :vartype source: int
+    :ivar dest: The destination population ID.
+    :vartype dest: int
+    :ivar time: The time at which this migration occurred.
+    :vartype time: float
+    """
+    def __init__(self, left, right, node, source, dest, time):
+        self.left = left
+        self.right = right
+        self.node = node
+        self.source = source
+        self.dest = dest
+        self.time = time
+
+
+class Population(SimpleContainer):
+    """
+    A :ref:`population ` in a tree sequence.
+
+    Modifying the attributes in this class will have **no effect** on the
+    underlying tree sequence data.
+
+    :ivar id: The integer ID of this population. Varies from 0 to
+        :attr:`.TreeSequence.num_populations` - 1.
+    :vartype id: int
+    :ivar metadata: The :ref:`metadata ` for this population.
+    :vartype metadata: bytes
+    """
+    def __init__(self, id_, metadata=""):
+        self.id = id_
+        self.metadata = metadata
+
+
+class Variant(SimpleContainer):
+    """
+    A variant represents the observed variation among the samples
+    for a given site.
A variant consists (a) of a reference to the
+    :class:`.Site` instance in question; (b) the **alleles** that may be
+    observed at the samples for this site; and (c) the **genotypes**
+    mapping sample IDs to the observed alleles.
+
+    Each element in the ``alleles`` tuple is a string, representing the
+    actual observed state for a given sample. The first element of this
+    tuple is guaranteed to be the same as the site's ``ancestral_state`` value.
+    The list of alleles is also guaranteed not to contain any duplicates.
+    However, allelic values may be listed that are not referred to by any
+    samples. For example, if we have a site that is fixed for the derived state
+    (i.e., we have a mutation over the tree root), all genotypes will be 1, but
+    the alleles list will be equal to ``('0', '1')``. Other than the
+    ancestral state being the first allele, the alleles are listed in
+    no particular order, and the ordering should not be relied upon.
+
+    The ``genotypes`` represent the observed allelic states for each sample,
+    such that ``var.alleles[var.genotypes[j]]`` gives the string allele
+    for sample ID ``j``. Thus, the elements of the genotypes array are
+    indexes into the ``alleles`` list. The genotypes are provided in this
+    way via a numpy array to enable efficient calculations.
+
+    Modifying the attributes in this class will have **no effect** on the
+    underlying tree sequence data.
+
+    :ivar site: The site object for this variant.
+    :vartype site: :class:`.Site`
+    :ivar alleles: A tuple of the allelic values that may be observed at the
+        samples at the current site. The first element of this tuple is always
+        the site's ancestral state.
+    :vartype alleles: tuple(str)
+    :ivar genotypes: An array of indexes into the list ``alleles``, giving the
+        state of each sample at the current site.
+ :vartype genotypes: numpy.ndarray + """ + def __init__(self, site, alleles, genotypes): + self.site = site + self.alleles = alleles + self.genotypes = genotypes + # Deprecated aliases to avoid breaking existing code. + self.position = site.position + self.index = site.id + + +class Edgeset(SimpleContainer): + def __init__(self, left, right, parent, children): + self.left = left + self.right = right + self.parent = parent + self.children = children + + def __repr__(self): + return "{{left={:.3f}, right={:.3f}, parent={}, children={}}}".format( + self.left, self.right, self.parent, self.children) + + +class Provenance(SimpleContainer): + def __init__(self, id_=None, timestamp=None, record=None): + self.id = id_ + self.timestamp = timestamp + self.record = record + + +def add_deprecated_mutation_attrs(site, mutation): + """ + Add in attributes for the older deprecated way of defining + mutations. These attributes will be removed in future releases + and are deliberately undocumented in version 0.5.0. + """ + mutation.position = site.position + mutation.index = site.id + return mutation + + +class Tree(object): + """ + A Tree is a single tree in a :class:`.TreeSequence`. The Tree + implementation differs from most tree implementations by using **integer + node IDs** to refer to nodes rather than objects. Thus, when we wish to + find the parent of the node with ID '0', we use ``tree.parent(0)``, which + returns another integer. If '0' does not have a parent in the current tree + (e.g., if it is a root), then the special value :const:`.NULL` + (:math:`-1`) is returned. The children of a node are found using the + :meth:`.children` method. To obtain information about a particular node, + one may either use ``tree.tree_sequence.node(u)`` to obtain the + corresponding :class:`Node` instance, or use the :meth:`.time` or + :meth:`.population` shorthands. Tree traversals in various orders + is possible using the :meth:`.Tree.nodes` iterator. 
+ + Trees are not intended to be instantiated directly, and are + obtained as part of a :class:`.TreeSequence` using the + :meth:`.trees` method. + """ + def __init__(self, ll_tree, tree_sequence): + self._ll_tree = ll_tree + self._tree_sequence = tree_sequence + + @property + def tree_sequence(self): + """ + Returns the tree sequence that this tree is from. + + :return: The parent tree sequence for this tree. + :rtype: :class:`.TreeSequence` + """ + return self._tree_sequence + + def get_branch_length(self, u): + # Deprecated alias for branch_length + return self.branch_length(u) + + def branch_length(self, u): + """ + Returns the length of the branch (in generations) joining the + specified node to its parent. This is equivalent to + + >>> tree.time(tree.parent(u)) - tree.time(u) + + Note that this is not related to the value returned by + :attr:`.length`, which describes the length of the interval + covered by the tree in genomic coordinates. + + :param int u: The node of interest. + :return: The branch length from u to its parent. + :rtype: float + """ + return self.time(self.get_parent(u)) - self.time(u) + + def get_total_branch_length(self): + # Deprecated alias for total_branch_length + return self.total_branch_length + + @property + def total_branch_length(self): + """ + Returns the sum of all the branch lengths in this tree (in + units of generations). This is equivalent to + + >>> sum( + >>> tree.branch_length(u) for u in tree.nodes() + >>> if u not in self.roots) + + :return: The sum of all the branch lengths in this tree. + :rtype: float + """ + return sum( + self.get_branch_length(u) for u in self.nodes() if u not in self.roots) + + def get_mrca(self, u, v): + # Deprecated alias for mrca + return self.mrca(u, v) + + def mrca(self, u, v): + """ + Returns the most recent common ancestor of the specified nodes. + + :param int u: The first node. + :param int v: The second node. + :return: The most recent common ancestor of u and v. 
+ :rtype: int + """ + return self._ll_tree.get_mrca(u, v) + + def get_tmrca(self, u, v): + # Deprecated alias for tmrca + return self.tmrca(u, v) + + def tmrca(self, u, v): + """ + Returns the time of the most recent common ancestor of the specified + nodes. This is equivalent to:: + + >>> tree.time(tree.mrca(u, v)) + + :param int u: The first node. + :param int v: The second node. + :return: The time of the most recent common ancestor of u and v. + :rtype: float + """ + return self.get_time(self.get_mrca(u, v)) + + def get_parent(self, u): + # Deprecated alias for parent + return self.parent(u) + + def parent(self, u): + """ + Returns the parent of the specified node. Returns + the :const:`.NULL` if u is the root or is not a node in + the current tree. + + :param int u: The node of interest. + :return: The parent of u. + :rtype: int + """ + return self._ll_tree.get_parent(u) + + # Quintuply linked tree structure. + + def left_child(self, u): + return self._ll_tree.get_left_child(u) + + def right_child(self, u): + return self._ll_tree.get_right_child(u) + + def left_sib(self, u): + return self._ll_tree.get_left_sib(u) + + def right_sib(self, u): + return self._ll_tree.get_right_sib(u) + + # Sample list. + + def left_sample(self, u): + return self._ll_tree.get_left_sample(u) + + def right_sample(self, u): + return self._ll_tree.get_right_sample(u) + + def next_sample(self, u): + return self._ll_tree.get_next_sample(u) + + # TODO do we also have right_root? + @property + def left_root(self): + return self._ll_tree.get_left_root() + + def get_children(self, u): + # Deprecated alias for self.children + return self.children(u) + + def children(self, u): + """ + Returns the children of the specified node ``u`` as a tuple of integer node IDs. + If ``u`` is a leaf, return the empty tuple. + + :param int u: The node of interest. 
+ :return: The children of ``u`` as a tuple of integers + :rtype: tuple(int) + """ + return self._ll_tree.get_children(u) + + def get_time(self, u): + # Deprecated alias for self.time + return self.time(u) + + def time(self, u): + """ + Returns the time of the specified node in generations. + Equivalent to ``tree.tree_sequence.node(u).time``. + + :param int u: The node of interest. + :return: The time of u. + :rtype: float + """ + return self._ll_tree.get_time(u) + + def get_population(self, u): + # Deprecated alias for self.population + return self.population(u) + + def population(self, u): + """ + Returns the population associated with the specified node. + Equivalent to ``tree.tree_sequence.node(u).population``. + + :param int u: The node of interest. + :return: The ID of the population associated with node u. + :rtype: int + """ + return self._ll_tree.get_population(u) + + def is_internal(self, u): + """ + Returns True if the specified node is not a leaf. A node is internal + if it has one or more children in the current tree. + + :param int u: The node of interest. + :return: True if u is not a leaf node. + :rtype: bool + """ + return not self.is_leaf(u) + + def is_leaf(self, u): + """ + Returns True if the specified node is a leaf. A node :math:`u` is a + leaf if it has zero children. + + :param int u: The node of interest. + :return: True if u is a leaf node. + :rtype: bool + """ + return len(self.children(u)) == 0 + + def is_sample(self, u): + """ + Returns True if the specified node is a sample. A node :math:`u` is a + sample if it has been marked as a sample in the parent tree sequence. + + :param int u: The node of interest. + :return: True if u is a sample. + :rtype: bool + """ + return bool(self._ll_tree.is_sample(u)) + + @property + def num_nodes(self): + """ + Returns the number of nodes in the :class:`.TreeSequence` this tree is in. + Equivalent to ``tree.tree_sequence.num_nodes``. 
To find the number of + nodes that are reachable from all roots use ``len(list(tree.nodes()))``. + + :rtype: int + """ + return self._ll_tree.get_num_nodes() + + @property + def num_roots(self): + """ + The number of roots in this tree, as defined in the :attr:`.roots` attribute. + + Requires O(number of roots) time. + + :rtype: int + """ + return self._ll_tree.get_num_roots() + + @property + def roots(self): + """ + The list of roots in this tree. A root is defined as a unique endpoint of + the paths starting at samples. We can define the set of roots as follows: + + .. code-block:: python + + roots = set() + for u in tree_sequence.samples(): + while tree.parent(u) != tskit.NULL: + u = tree.parent(u) + roots.add(u) + # roots is now the set of all roots in this tree. + assert sorted(roots) == sorted(tree.roots) + + The roots of the tree are returned in a list, in no particular order. + + Requires O(number of roots) time. + + :return: The list of roots in this tree. + :rtype: list + """ + roots = [] + u = self.left_root + while u != NULL: + roots.append(u) + u = self.right_sib(u) + return roots + + def get_root(self): + # Deprecated alias for self.root + return self.root + + @property + def root(self): + """ + The root of this tree. If the tree contains multiple roots, a ValueError is + raised indicating that the :attr:`.roots` attribute should be used instead. + + :return: The root node. + :rtype: int + :raises: :class:`ValueError` if this tree contains more than one root. + """ + root = self.left_root + if root != NULL and self.right_sib(root) != NULL: + raise ValueError("More than one root exists. Use tree.roots instead") + return root + + def get_index(self): + # Deprecated alias for self.index + return self.index + + @property + def index(self): + """ + Returns the index this tree occupies in the parent tree sequence. + This index is zero based, so the first tree in the sequence has index 0. + + :return: The index of this tree. 
+        :rtype: int
+        """
+        return self._ll_tree.get_index()
+
+    def get_interval(self):
+        # Deprecated alias for self.interval
+        return self.interval
+
+    @property
+    def interval(self):
+        """
+        Returns the coordinates of the genomic interval that this tree
+        represents the history of. The interval is returned as a tuple
+        :math:`(l, r)` and is a half-open interval such that the left
+        coordinate is inclusive and the right coordinate is exclusive. This
+        tree therefore applies to all genomic locations :math:`x` such that
+        :math:`l \\leq x < r`.
+
+        :return: A tuple (l, r) representing the left-most (inclusive)
+            and right-most (exclusive) coordinates of the genomic region
+            covered by this tree.
+        :rtype: tuple
+        """
+        return self._ll_tree.get_left(), self._ll_tree.get_right()
+
+    def get_length(self):
+        # Deprecated alias for self.length
+        return self.length
+
+    @property
+    def length(self):
+        """
+        Returns the length of the genomic interval that this tree represents.
+        This is defined as :math:`r - l`, where :math:`(l, r)` is the genomic
+        interval returned by :attr:`.interval`.
+
+        :return: The length of the genomic interval covered by this tree.
+        :rtype: float
+        """
+        left, right = self.get_interval()
+        return right - left
+
+    # The sample_size (or num_samples) is really a property of the tree sequence,
+    # and so we should provide access to this via a tree.tree_sequence.num_samples
+    # property access. However, we can't just remove the method as a lot of code
+    # may depend on it. To complicate things a bit more, sample_size has been
+    # changed to num_samples elsewhere for consistency. We can't do this here
+    # because there is already a num_samples method which returns the number of
+    # samples below a particular node. The best thing to do is probably to
+    # undocument the sample_size property, but keep it around for ever.
+ + def get_sample_size(self): + # Deprecated alias for self.sample_size + return self.sample_size + + @property + def sample_size(self): + """ + Returns the sample size for this tree. This is the number of sample + nodes in the tree. + + :return: The number of sample nodes in the tree. + :rtype: int + """ + return self._ll_tree.get_sample_size() + + def draw( + self, path=None, width=None, height=None, + node_labels=None, node_colours=None, + mutation_labels=None, mutation_colours=None, + format=None): + """ + Returns a drawing of this tree. + + When working in a Jupyter notebook, use the ``IPython.display.SVG`` + function to display the SVG output from this function inline in the notebook:: + + >>> SVG(tree.draw()) + + The unicode format uses unicode `box drawing characters + `_ to render the tree. + This allows rendered trees to be printed out to the terminal:: + + >>> print(tree.draw(format="unicode")) + 6 + ┏━┻━┓ + ┃ 5 + ┃ ┏━┻┓ + ┃ ┃ 4 + ┃ ┃ ┏┻┓ + 3 0 1 2 + + The ``node_labels`` argument allows the user to specify custom labels + for nodes, or no labels at all:: + + >>> print(tree.draw(format="unicode", node_labels={})) + ┃ + ┏━┻━┓ + ┃ ┃ + ┃ ┏━┻┓ + ┃ ┃ ┃ + ┃ ┃ ┏┻┓ + ┃ ┃ ┃ ┃ + + :param str path: The path to the file to write the output. If None, do not + write to file. + :param int width: The width of the image in pixels. If not specified, either + defaults to the minimum size required to depict the tree (text formats) + or 200 pixels. + :param int height: The height of the image in pixels. If not specified, either + defaults to the minimum size required to depict the tree (text formats) + or 200 pixels. + :param map node_labels: If specified, show custom labels for the nodes + that are present in the map. Any nodes not specified in the map will + not have a node label. + :param map node_colours: If specified, show custom colours for nodes. (Only + supported in the SVG format.) + :param str format: The format of the returned image. 
Currently supported + are 'svg', 'ascii' and 'unicode'. + :return: A representation of this tree in the requested format. + :rtype: str + """ + output = drawing.draw_tree( + self, format=format, width=width, height=height, + node_labels=node_labels, node_colours=node_colours, + mutation_labels=mutation_labels, mutation_colours=mutation_colours) + if path is not None: + with open(path, "w") as f: + f.write(output) + return output + + def get_num_mutations(self): + return self.num_mutations + + @property + def num_mutations(self): + """ + Returns the total number of mutations across all sites on this tree. + + :return: The total number of mutations over all sites on this tree. + :rtype: int + """ + return sum(len(site.mutations) for site in self.sites()) + + @property + def num_sites(self): + """ + Returns the number of sites on this tree. + + :return: The number of sites on this tree. + :rtype: int + """ + return self._ll_tree.get_num_sites() + + def sites(self): + """ + Returns an iterator over all the :ref:`sites ` + in this tree. Sites are returned in order of increasing ID + (and also position). See the :class:`Site` class for details on + the available fields for each site. + + :return: An iterator over all sites in this tree. + :rtype: iter(:class:`.Site`) + """ + # TODO change the low-level API to just return the IDs of the sites. + for ll_site in self._ll_tree.get_sites(): + _, _, _, id_, _ = ll_site + yield self.tree_sequence.site(id_) + + def mutations(self): + """ + Returns an iterator over all the + :ref:`mutations ` in this tree. + Mutations are returned in order of nondecreasing site ID. + See the :class:`Mutation` class for details on the available fields for + each mutation. + + The returned iterator is equivalent to iterating over all sites + and all mutations in each site, i.e.:: + + >>> for site in tree.sites(): + >>> for mutation in site.mutations: + >>> yield mutation + + :return: An iterator over all mutations in this tree. 
+ :rtype: iter(:class:`.Mutation`) + """ + for site in self.sites(): + for mutation in site.mutations: + yield add_deprecated_mutation_attrs(site, mutation) + + def get_leaves(self, u): + # Deprecated alias for samples. See the discussion in the get_num_leaves + # method for why this method is here and why it is semantically incorrect. + # The 'leaves' iterator below correctly returns the leaves below a given + # node. + return self.samples(u) + + def leaves(self, u=None): + """ + Returns an iterator over all the leaves in this tree that are + underneath the specified node. If u is not specified, return all leaves + in the tree. + + :param int u: The node of interest. + :return: An iterator over all leaves in the subtree rooted at u. + :rtype: iterator + """ + roots = [u] + if u is None: + roots = self.roots + for root in roots: + for v in self.nodes(root): + if self.is_leaf(v): + yield v + + def _sample_generator(self, u): + if self._ll_tree.get_flags() & _tskit.SAMPLE_LISTS: + samples = self.tree_sequence.samples() + index = self.left_sample(u) + if index != NULL: + stop = self.right_sample(u) + while True: + yield samples[index] + if index == stop: + break + index = self.next_sample(index) + else: + # Fall back on iterating over all nodes in the tree, yielding + # samples as we see them. + for v in self.nodes(u): + if self.is_sample(v): + yield v + + def samples(self, u=None): + """ + Returns an iterator over all the samples in this tree that are + underneath the specified node. If u is a sample, it is included in the + returned iterator. If u is not specified, return all samples in the tree. + + If the :meth:`.TreeSequence.trees` method is called with + ``sample_lists=True``, this method uses an efficient algorithm to find + the samples. If not, a simple traversal based method is used. + + :param int u: The node of interest. + :return: An iterator over all samples in the subtree rooted at u. 
+ :rtype: iterator + """ + roots = [u] + if u is None: + roots = self.roots + for root in roots: + for v in self._sample_generator(root): + yield v + + def get_num_leaves(self, u): + # Deprecated alias for num_samples. The method name is inaccurate + # as this will count the number of tracked _samples_. This is only provided to + # avoid breaking existing code and should not be used in new code. We could + # change this method to be semantically correct and just count the + # number of leaves we hit in the leaves() iterator. However, this would + # have the undesirable effect of making code that depends on the constant + # time performance of get_num_leaves many times slower. So, the best option + # is to leave this method as is, and to slowly deprecate it out. Once this + # has been removed, we might add in a ``num_leaves`` method that returns the + # length of the leaves() iterator as one would expect. + return self.num_samples(u) + + def get_num_samples(self, u=None): + # Deprecated alias for num_samples. + return self.num_samples(u) + + def num_samples(self, u=None): + """ + Returns the number of samples in this tree underneath the specified + node (including the node itself). If u is not specified return + the total number of samples in the tree. + + If the :meth:`.TreeSequence.trees` method is called with + ``sample_counts=True`` this method is a constant time operation. If not, + a slower traversal based algorithm is used to count the samples. + + :param int u: The node of interest. + :return: The number of samples in the subtree rooted at u. + :rtype: int + """ + if u is None: + return sum(self._ll_tree.get_num_samples(u) for u in self.roots) + else: + return self._ll_tree.get_num_samples(u) + + def get_num_tracked_leaves(self, u): + # Deprecated alias for num_tracked_samples. The method name is inaccurate + # as this will count the number of tracked _samples_. This is only provided to + # avoid breaking existing code and should not be used in new code. 
+ return self.num_tracked_samples(u) + + def get_num_tracked_samples(self, u=None): + # Deprecated alias for num_tracked_samples + return self.num_tracked_samples(u) + + def num_tracked_samples(self, u=None): + """ + Returns the number of samples in the set specified in the + ``tracked_samples`` parameter of the :meth:`.TreeSequence.trees` method + underneath the specified node. If the input node is not specified, + return the total number of tracked samples in the tree. + + This is a constant time operation. + + :param int u: The node of interest. + :return: The number of samples within the set of tracked samples in + the subtree rooted at u. + :rtype: int + :raises RuntimeError: if the :meth:`.TreeSequence.trees` + method is not called with ``sample_counts=True``. + """ + roots = [u] + if u is None: + roots = self.roots + if not (self._ll_tree.get_flags() & _tskit.SAMPLE_COUNTS): + raise RuntimeError( + "The get_num_tracked_samples method is only supported " + "when sample_counts=True.") + return sum(self._ll_tree.get_num_tracked_samples(root) for root in roots) + + def _preorder_traversal(self, u): + stack = [u] + while len(stack) > 0: + v = stack.pop() + if self.is_internal(v): + stack.extend(reversed(self.get_children(v))) + yield v + + def _postorder_traversal(self, u): + stack = [u] + k = NULL + while stack: + v = stack[-1] + if self.is_internal(v) and v != k: + stack.extend(reversed(self.get_children(v))) + else: + k = self.get_parent(v) + yield stack.pop() + + def _inorder_traversal(self, u): + # TODO add a nonrecursive version of the inorder traversal. 
+ children = self.get_children(u) + mid = len(children) // 2 + for c in children[:mid]: + for v in self._inorder_traversal(c): + yield v + yield u + for c in children[mid:]: + for v in self._inorder_traversal(c): + yield v + + def _levelorder_traversal(self, u): + queue = collections.deque([u]) + while queue: + v = queue.popleft() + if self.is_internal(v): + queue.extend(self.get_children(v)) + yield v + + def nodes(self, root=None, order="preorder"): + """ + Returns an iterator over the nodes in this tree. If the root parameter + is provided, iterate over the nodes in the subtree rooted at this + node. If this is None, iterate over all nodes. If the order parameter + is provided, iterate over the nodes in required tree traversal order. + + :param int root: The root of the subtree we are traversing. + :param str order: The traversal ordering. Currently 'preorder', + 'inorder', 'postorder' and 'levelorder' ('breadthfirst') + are supported. + :return: An iterator over the nodes in the tree in some traversal order. + :rtype: iterator + """ + methods = { + "preorder": self._preorder_traversal, + "inorder": self._inorder_traversal, + "postorder": self._postorder_traversal, + "levelorder": self._levelorder_traversal, + "breadthfirst": self._levelorder_traversal + } + try: + iterator = methods[order] + except KeyError: + raise ValueError("Traversal ordering '{}' not supported".format(order)) + roots = [root] + if root is None: + roots = self.roots + for u in roots: + for v in iterator(u): + yield v + + # TODO make this a bit less embarrassing by using an iterative method. + def __build_newick(self, node, precision, node_labels): + """ + Simple recursive version of the newick generator used when non-default + node labels are needed. 
+ """ + label = node_labels.get(node, "") + if self.is_leaf(node): + s = "{0}".format(label) + else: + s = "(" + for child in self.children(node): + branch_length = self.branch_length(child) + subtree = self.__build_newick(child, precision, node_labels) + s += subtree + ":{0:.{1}f},".format(branch_length, precision) + s = s[:-1] + "){}".format(label) + return s + + def newick(self, precision=14, root=None, node_labels=None): + """ + Returns a `newick encoding `_ + of this tree. If the ``root`` argument is specified, return a representation + of the specified subtree, otherwise the full tree is returned. If the tree + has multiple roots then seperate newick strings for each rooted subtree + must be found (i.e., we do not attempt to concatenate the different trees). + + By default, leaf nodes are labelled with their numerical ID + 1, + and internal nodes are not labelled. Arbitrary node labels can be specified + using the ``node_labels`` argument, which maps node IDs to the desired + labels. + + .. warning:: Node labels are **not** Newick escaped, so care must be taken + to provide labels that will not break the encoding. + + :param int precision: The numerical precision with which branch lengths are + printed. + :param int root: If specified, return the tree rooted at this node. + :param map node_labels: If specified, show custom labels for the nodes + that are present in the map. Any nodes not specified in the map will + not have a node label. + :return: A newick representation of this tree. + :rtype: str + """ + if root is None: + if self.num_roots > 1: + raise ValueError( + "Cannot get newick for multiroot trees. 
Try " + "[t.newick(root) for root in t.roots] to get a list of " + "newick trees, one for each root.") + root = self.root + if node_labels is None: + s = self._ll_tree.get_newick(precision=precision, root=root) + if not IS_PY2: + s = s.decode() + else: + return self.__build_newick(root, precision, node_labels) + ";" + return s + + @property + def parent_dict(self): + return self.get_parent_dict() + + def get_parent_dict(self): + pi = { + u: self.parent(u) for u in range(self.num_nodes) + if self.parent(u) != NULL} + return pi + + def __str__(self): + return str(self.get_parent_dict()) + + +def load(path): + """ + Loads a tree sequence from the specified file path. This file must be in the + :ref:`tree sequence file format ` produced by the + :meth:`.TreeSequence.dump` method. + + :param str path: The file path of the ``.trees`` file containing the + tree sequence we wish to load. + :return: The tree sequence object containing the information + stored in the specified file path. + :rtype: :class:`tskit.TreeSequence` + """ + try: + return TreeSequence.load(path) + except exceptions.FileFormatError as e: + formats.raise_hdf5_format_error(path, e) + + +def __load_tables( + nodes, edges, migrations=None, sites=None, mutations=None, + provenances=None, individuals=None, populations=None, sequence_length=0): + """ + **This method is now deprecated. Please use TableCollection.tree_sequence() + instead** + + Loads the tree sequence data from the specified table objects, and + returns the resulting :class:`.TreeSequence` object. These tables + must fulfil the properties required for an input tree sequence as + described in the :ref:`sec_valid_tree_sequence_requirements` section. + + The ``sequence_length`` parameter determines the + :attr:`.TreeSequence.sequence_length` of the returned tree sequence. If it + is 0 or not specified, the value is taken to be the maximum right + coordinate of the input edges. 
This parameter is useful in degenerate + situations (such as when there are zero edges), but can usually be ignored. + + :param NodeTable nodes: The :ref:`node table ` + (required). + :param EdgeTable edges: The :ref:`edge table ` + (required). + :param MigrationTable migrations: The :ref:`migration table + ` (optional). + :param SiteTable sites: The :ref:`site table ` + (optional; but if supplied, ``mutations`` must also be specified). + :param MutationTable mutations: The :ref:`mutation table + ` (optional; but if supplied, ``sites`` + must also be specified). + :param ProvenanceTable provenances: The :ref:`provenance table + ` (optional). + :param IndividualTable individuals: The :ref:`individual table + ` (optional). + :param PopulationTable populations: The :ref:`population table + ` (optional). + :param float sequence_length: The sequence length of the returned tree sequence. If + not supplied or zero this will be inferred from the set of edges. + :return: A :class:`.TreeSequence` consistent with the specified tables. 
+ :rtype: TreeSequence + """ + if sequence_length is None: + sequence_length = 0 + if sequence_length == 0 and len(edges) > 0: + sequence_length = edges.right.max() + kwargs = { + "nodes": nodes.ll_table, "edges": edges.ll_table, + "sequence_length": sequence_length} + if migrations is not None: + kwargs["migrations"] = migrations.ll_table + else: + kwargs["migrations"] = _tskit.MigrationTable() + if sites is not None: + kwargs["sites"] = sites.ll_table + else: + kwargs["sites"] = _tskit.SiteTable() + if mutations is not None: + kwargs["mutations"] = mutations.ll_table + else: + kwargs["mutations"] = _tskit.MutationTable() + if provenances is not None: + kwargs["provenances"] = provenances.ll_table + else: + kwargs["provenances"] = _tskit.ProvenanceTable() + if individuals is not None: + kwargs["individuals"] = individuals.ll_table + else: + kwargs["individuals"] = _tskit.IndividualTable() + if populations is not None: + kwargs["populations"] = populations.ll_table + else: + kwargs["populations"] = _tskit.PopulationTable() + + ll_tables = _tskit.TableCollection(**kwargs) + return TreeSequence.load_tables(tables.TableCollection(ll_tables=ll_tables)) + + +def parse_individuals( + source, strict=True, encoding='utf8', base64_metadata=True, table=None): + """ + Parse the specified file-like object containing a whitespace delimited + description of an individual table and returns the corresponding + :class:`IndividualTable` instance. See the :ref:`individual text format + ` section for the details of the required + format and the :ref:`individual table definition + ` section for the required properties of + the contents. + + See :func:`.load_text` for a detailed explanation of the ``strict`` + parameter. + + :param stream source: The file-like object containing the text. + :param bool strict: If True, require strict tab delimiting (default). If + False, a relaxed whitespace splitting algorithm is used. + :param string encoding: Encoding used for text representation. 
+ :param bool base64_metadata: If True, metadata is encoded using Base64 + encoding; otherwise, as plain text. + :param IndividualTable table: If specified write into this table. If not, + create a new :class:`.IndividualTable` instance. + """ + sep = None + if strict: + sep = "\t" + if table is None: + table = tables.IndividualTable() + # Read the header and find the indexes of the required fields. + header = source.readline().strip("\n").split(sep) + flags_index = header.index("flags") + location_index = None + metadata_index = None + try: + location_index = header.index("location") + except ValueError: + pass + try: + metadata_index = header.index("metadata") + except ValueError: + pass + for line in source: + tokens = line.split(sep) + if len(tokens) >= 1: + flags = int(tokens[flags_index]) + location = () + if location_index is not None: + location_string = tokens[location_index] + if len(location_string) > 0: + location = tuple(map(float, location_string.split(","))) + metadata = b'' + if metadata_index is not None and metadata_index < len(tokens): + metadata = tokens[metadata_index].encode(encoding) + if base64_metadata: + metadata = base64.b64decode(metadata) + table.add_row( + flags=flags, location=location, metadata=metadata) + return table + + +def parse_nodes( + source, strict=True, encoding='utf8', base64_metadata=True, table=None): + """ + Parse the specified file-like object containing a whitespace delimited + description of a node table and returns the corresponding :class:`NodeTable` + instance. See the :ref:`node text format ` section + for the details of the required format and the + :ref:`node table definition ` section for the + required properties of the contents. + + See :func:`.load_text` for a detailed explanation of the ``strict`` + parameter. + + :param stream source: The file-like object containing the text. + :param bool strict: If True, require strict tab delimiting (default). 
If + False, a relaxed whitespace splitting algorithm is used. + :param string encoding: Encoding used for text representation. + :param bool base64_metadata: If True, metadata is encoded using Base64 + encoding; otherwise, as plain text. + :param NodeTable table: If specified write into this table. If not, + create a new :class:`.NodeTable` instance. + """ + sep = None + if strict: + sep = "\t" + if table is None: + table = tables.NodeTable() + # Read the header and find the indexes of the required fields. + header = source.readline().strip("\n").split(sep) + is_sample_index = header.index("is_sample") + time_index = header.index("time") + population_index = None + individual_index = None + metadata_index = None + try: + population_index = header.index("population") + except ValueError: + pass + try: + individual_index = header.index("individual") + except ValueError: + pass + try: + metadata_index = header.index("metadata") + except ValueError: + pass + for line in source: + tokens = line.split(sep) + if len(tokens) >= 2: + is_sample = int(tokens[is_sample_index]) + time = float(tokens[time_index]) + flags = 0 + if is_sample != 0: + flags |= NODE_IS_SAMPLE + population = NULL + if population_index is not None: + population = int(tokens[population_index]) + individual = NULL + if individual_index is not None: + individual = int(tokens[individual_index]) + metadata = b'' + if metadata_index is not None and metadata_index < len(tokens): + metadata = tokens[metadata_index].encode(encoding) + if base64_metadata: + metadata = base64.b64decode(metadata) + table.add_row( + flags=flags, time=time, population=population, + individual=individual, metadata=metadata) + return table + + +def parse_edges(source, strict=True, table=None): + """ + Parse the specified file-like object containing a whitespace delimited + description of a edge table and returns the corresponding :class:`EdgeTable` + instance. 
See the :ref:`edge text format ` section + for the details of the required format and the + :ref:`edge table definition ` section for the + required properties of the contents. + + See :func:`.load_text` for a detailed explanation of the ``strict`` parameter. + + :param stream source: The file-like object containing the text. + :param bool strict: If True, require strict tab delimiting (default). If + False, a relaxed whitespace splitting algorithm is used. + :param EdgeTable table: If specified, write the edges into this table. If + not, create a new :class:`.EdgeTable` instance and return. + """ + sep = None + if strict: + sep = "\t" + if table is None: + table = tables.EdgeTable() + header = source.readline().strip("\n").split(sep) + left_index = header.index("left") + right_index = header.index("right") + parent_index = header.index("parent") + children_index = header.index("child") + for line in source: + tokens = line.split(sep) + if len(tokens) >= 4: + left = float(tokens[left_index]) + right = float(tokens[right_index]) + parent = int(tokens[parent_index]) + children = tuple(map(int, tokens[children_index].split(","))) + for child in children: + table.add_row(left=left, right=right, parent=parent, child=child) + return table + + +def parse_sites( + source, strict=True, encoding='utf8', base64_metadata=True, table=None): + """ + Parse the specified file-like object containing a whitespace delimited + description of a site table and returns the corresponding :class:`SiteTable` + instance. See the :ref:`site text format ` section + for the details of the required format and the + :ref:`site table definition ` section for the + required properties of the contents. + + See :func:`.load_text` for a detailed explanation of the ``strict`` + parameter. + + :param stream source: The file-like object containing the text. + :param bool strict: If True, require strict tab delimiting (default). If + False, a relaxed whitespace splitting algorithm is used. 
+ :param string encoding: Encoding used for text representation. + :param bool base64_metadata: If True, metadata is encoded using Base64 + encoding; otherwise, as plain text. + :param SiteTable table: If specified write site into this table. If not, + create a new :class:`.SiteTable` instance. + """ + sep = None + if strict: + sep = "\t" + if table is None: + table = tables.SiteTable() + header = source.readline().strip("\n").split(sep) + position_index = header.index("position") + ancestral_state_index = header.index("ancestral_state") + metadata_index = None + try: + metadata_index = header.index("metadata") + except ValueError: + pass + for line in source: + tokens = line.split(sep) + if len(tokens) >= 2: + position = float(tokens[position_index]) + ancestral_state = tokens[ancestral_state_index] + metadata = b'' + if metadata_index is not None and metadata_index < len(tokens): + metadata = tokens[metadata_index].encode(encoding) + if base64_metadata: + metadata = base64.b64decode(metadata) + table.add_row( + position=position, ancestral_state=ancestral_state, metadata=metadata) + return table + + +def parse_mutations( + source, strict=True, encoding='utf8', base64_metadata=True, table=None): + """ + Parse the specified file-like object containing a whitespace delimited + description of a mutation table and returns the corresponding :class:`MutationTable` + instance. See the :ref:`mutation text format ` section + for the details of the required format and the + :ref:`mutation table definition ` section for the + required properties of the contents. + + See :func:`.load_text` for a detailed explanation of the ``strict`` + parameter. + + :param stream source: The file-like object containing the text. + :param bool strict: If True, require strict tab delimiting (default). If + False, a relaxed whitespace splitting algorithm is used. + :param string encoding: Encoding used for text representation. 
+ :param bool base64_metadata: If True, metadata is encoded using Base64 + encoding; otherwise, as plain text. + :param MutationTable table: If specified, write mutations into this table. + If not, create a new :class:`.MutationTable` instance. + """ + sep = None + if strict: + sep = "\t" + if table is None: + table = tables.MutationTable() + header = source.readline().strip("\n").split(sep) + site_index = header.index("site") + node_index = header.index("node") + derived_state_index = header.index("derived_state") + parent_index = None + parent = NULL + try: + parent_index = header.index("parent") + except ValueError: + pass + metadata_index = None + try: + metadata_index = header.index("metadata") + except ValueError: + pass + for line in source: + tokens = line.split(sep) + if len(tokens) >= 3: + site = int(tokens[site_index]) + node = int(tokens[node_index]) + derived_state = tokens[derived_state_index] + if parent_index is not None: + parent = int(tokens[parent_index]) + metadata = b'' + if metadata_index is not None and metadata_index < len(tokens): + metadata = tokens[metadata_index].encode(encoding) + if base64_metadata: + metadata = base64.b64decode(metadata) + table.add_row( + site=site, node=node, derived_state=derived_state, parent=parent, + metadata=metadata) + return table + + +def parse_populations( + source, strict=True, encoding='utf8', base64_metadata=True, table=None): + """ + Parse the specified file-like object containing a whitespace delimited + description of a population table and returns the corresponding + :class:`PopulationTable` instance. See the :ref:`population text format + ` section for the details of the required + format and the :ref:`population table definition + ` section for the required properties of + the contents. + + See :func:`.load_text` for a detailed explanation of the ``strict`` + parameter. + + :param stream source: The file-like object containing the text. 
+ :param bool strict: If True, require strict tab delimiting (default). If + False, a relaxed whitespace splitting algorithm is used. + :param string encoding: Encoding used for text representation. + :param bool base64_metadata: If True, metadata is encoded using Base64 + encoding; otherwise, as plain text. + :param PopulationTable table: If specified write into this table. If not, + create a new :class:`.PopulationTable` instance. + """ + sep = None + if strict: + sep = "\t" + if table is None: + table = tables.PopulationTable() + # Read the header and find the indexes of the required fields. + header = source.readline().strip("\n").split(sep) + metadata_index = header.index("metadata") + for line in source: + tokens = line.split(sep) + if len(tokens) >= 1: + metadata = tokens[metadata_index].encode(encoding) + if base64_metadata: + metadata = base64.b64decode(metadata) + table.add_row(metadata=metadata) + return table + + +def load_text(nodes, edges, sites=None, mutations=None, individuals=None, + populations=None, sequence_length=0, strict=True, + encoding='utf8', base64_metadata=True): + """ + Parses the tree sequence data from the specified file-like objects, and + returns the resulting :class:`.TreeSequence` object. The format + for these files is documented in the :ref:`sec_text_file_format` section, + and is produced by the :meth:`.TreeSequence.dump_text` method. Further + properties required for an input tree sequence are described in the + :ref:`sec_valid_tree_sequence_requirements` section. This method is intended as a + convenient interface for importing external data into tskit; the binary + file format using by :meth:`tskit.load` is many times more efficient than + this text format. + + The ``nodes`` and ``edges`` parameters are mandatory and must be file-like + objects containing text with whitespace delimited columns, parsable by + :func:`parse_nodes` and :func:`parse_edges`, respectively. 
``sites``,
+ ``mutations``, ``individuals`` and ``populations`` are optional, and must
+ be parsable by :func:`parse_sites`, :func:`parse_mutations`,
+ :func:`parse_individuals`, and :func:`parse_populations`, respectively.
+
+ TODO: there is no method to parse the remaining tables at present, so
+ only tree sequences not requiring Population and Individual tables can
+ be loaded. This will be fixed: https://github.com/tskit-dev/msprime/issues/498
+
+ The ``sequence_length`` parameter determines the
+ :attr:`.TreeSequence.sequence_length` of the returned tree sequence. If it
+ is 0 or not specified, the value is taken to be the maximum right
+ coordinate of the input edges. This parameter is useful in degenerate
+ situations (such as when there are zero edges), but can usually be ignored.
+
+ The ``strict`` parameter controls the field delimiting algorithm that
+ is used. If ``strict`` is True (the default), we require exactly one
+ tab character separating each field. If ``strict`` is False, a more relaxed
+ whitespace delimiting algorithm is used, such that any run of whitespace
+ is regarded as a field separator. In most situations, ``strict=False``
+ is more convenient, but it can lead to error in certain situations. For
+ example, if a deletion is encoded in the mutation table this will not
+ be parseable when ``strict=False``.
+
+ After parsing the tables, :func:`sort_tables` is called to ensure that
+ the loaded tables satisfy the tree sequence :ref:`ordering requirements
+ `. Note that this may result in the
+ IDs of various entities changing from their positions in the input file.
+
+ :param stream nodes: The file-like object containing text describing a
+ :class:`.NodeTable`.
+ :param stream edges: The file-like object containing text
+ describing an :class:`.EdgeTable`.
+ :param stream sites: The file-like object containing text describing a
+ :class:`.SiteTable`. 
+ :param stream mutations: The file-like object containing text
+ describing a :class:`MutationTable`.
+ :param stream individuals: The file-like object containing text
+ describing a :class:`IndividualTable`.
+ :param stream populations: The file-like object containing text
+ describing a :class:`PopulationTable`.
+ :param float sequence_length: The sequence length of the returned tree sequence. If
+ not supplied or zero this will be inferred from the set of edges.
+ :param bool strict: If True, require strict tab delimiting (default). If
+ False, a relaxed whitespace splitting algorithm is used.
+ :param string encoding: Encoding used for text representation.
+ :param bool base64_metadata: If True, metadata is encoded using Base64
+ encoding; otherwise, as plain text.
+ :return: The tree sequence object containing the information
+ stored in the specified file paths.
+ :rtype: :class:`tskit.TreeSequence`
+ """
+ # We need to parse the edges so we can figure out the sequence length, and
+ # TableCollection.sequence_length is immutable so we need to create a temporary
+ # edge table.
+ edge_table = parse_edges(edges, strict=strict)
+ if sequence_length == 0 and len(edge_table) > 0:
+ sequence_length = edge_table.right.max()
+ tc = tables.TableCollection(sequence_length)
+ tc.edges.set_columns(
+ left=edge_table.left, right=edge_table.right, parent=edge_table.parent,
+ child=edge_table.child)
+ parse_nodes(
+ nodes, strict=strict, encoding=encoding, base64_metadata=base64_metadata,
+ table=tc.nodes)
+ # We need to add any populations referenced in the node table. 
+ if len(tc.nodes) > 0: + max_population = tc.nodes.population.max() + if max_population != NULL: + for _ in range(max_population + 1): + tc.populations.add_row() + if sites is not None: + parse_sites( + sites, strict=strict, encoding=encoding, base64_metadata=base64_metadata, + table=tc.sites) + if mutations is not None: + parse_mutations( + mutations, strict=strict, encoding=encoding, + base64_metadata=base64_metadata, table=tc.mutations) + if individuals is not None: + parse_individuals( + individuals, strict=strict, encoding=encoding, + base64_metadata=base64_metadata, table=tc.individuals) + if populations is not None: + parse_populations( + populations, strict=strict, encoding=encoding, + base64_metadata=base64_metadata, table=tc.populations) + tc.sort() + return tc.tree_sequence() + + +class TreeSequence(object): + """ + A single tree sequence, as defined by the :ref:`data model `. + A TreeSequence instance can be created from a set of + :ref:`tables ` using + :meth:`.TableCollection.tree_sequence`; or loaded from a set of text files + using :func:`.load_text`; or, loaded from a native binary file using + :func:`load`. + + TreeSequences are immutable. To change the data held in a particular + tree sequence, first get the table information as a :class:`.TableCollection` + instance (using :meth:`.dump_tables`), edit those tables using the + :ref:`tables api `, and create a new tree sequence using + :meth:`.TableCollection.tree_sequence`. + + The :meth:`.trees` method iterates over all trees in a tree sequence, and + the :meth:`.variants` method iterates over all sites and their genotypes. 
+ """ + + def __init__(self, ll_tree_sequence): + self._ll_tree_sequence = ll_tree_sequence + + @property + def ll_tree_sequence(self): + return self.get_ll_tree_sequence() + + def get_ll_tree_sequence(self): + return self._ll_tree_sequence + + @classmethod + def load(cls, path): + ts = _tskit.TreeSequence() + ts.load(path) + return TreeSequence(ts) + + @classmethod + def load_tables(cls, tables): + ts = _tskit.TreeSequence() + ts.load_tables(tables.ll_tables) + return TreeSequence(ts) + + def dump(self, path, zlib_compression=False): + """ + Writes the tree sequence to the specified file path. + + :param str path: The file path to write the TreeSequence to. + :param bool zlib_compression: This parameter is deprecated and ignored. + """ + if zlib_compression: + warnings.warn( + "The zlib_compression option is no longer supported and is ignored", + RuntimeWarning) + self._ll_tree_sequence.dump(path) + + @property + def tables(self): + """ + A copy of the tables underlying this tree sequence. See also + :meth:`.dump_tables`. + + .. warning:: This propery currently returns a copy of the tables + underlying a tree sequence but it may return a read-only + **view** in the future. Thus, if the tables will subsequently be + updated, please use the :meth:`.dump_tables` method instead as + this will always return a new copy of the TableCollection. + + :return: A :class:`.TableCollection` containing all a copy of the + tables underlying this tree sequence. + :rtype: TableCollection + """ + return self.dump_tables() + + def dump_tables(self): + """ + A copy of the tables defining this tree sequence. + + :return: A :class:`.TableCollection` containing all tables underlying + the tree sequence. 
+ :rtype: TableCollection + """ + t = tables.TableCollection(sequence_length=self.sequence_length) + self._ll_tree_sequence.dump_tables(t.ll_tables) + return t + + def dump_text( + self, nodes=None, edges=None, sites=None, mutations=None, individuals=None, + populations=None, provenances=None, precision=6, encoding='utf8', + base64_metadata=True): + """ + Writes a text representation of the tables underlying the tree sequence + to the specified connections. + + If Base64 encoding is not used, then metadata will be saved directly, possibly + resulting in errors reading the tables back in if metadata includes whitespace. + + :param stream nodes: The file-like object (having a .write() method) to write + the NodeTable to. + :param stream edges: The file-like object to write the EdgeTable to. + :param stream sites: The file-like object to write the SiteTable to. + :param stream mutations: The file-like object to write the MutationTable to. + :param stream individuals: The file-like object to write the IndividualTable to. + :param stream populations: The file-like object to write the PopulationTable to. + :param stream provenances: The file-like object to write the ProvenanceTable to. + :param int precision: The number of digits of precision. + :param string encoding: Encoding used for text representation. + :param bool base64_metadata: If True, metadata is encoded using Base64 + encoding; otherwise, as plain text. 
+ """ + + if nodes is not None: + print( + "id", "is_sample", "time", "population", "individual", "metadata", + sep="\t", file=nodes) + for node in self.nodes(): + metadata = node.metadata + if base64_metadata: + metadata = base64.b64encode(metadata).decode(encoding) + row = ( + "{id:d}\t" + "{is_sample:d}\t" + "{time:.{precision}f}\t" + "{population:d}\t" + "{individual:d}\t" + "{metadata}").format( + precision=precision, id=node.id, + is_sample=node.is_sample(), time=node.time, + population=node.population, + individual=node.individual, + metadata=metadata) + print(row, file=nodes) + + if edges is not None: + print("left", "right", "parent", "child", sep="\t", file=edges) + for edge in self.edges(): + row = ( + "{left:.{precision}f}\t" + "{right:.{precision}f}\t" + "{parent:d}\t" + "{child:d}").format( + precision=precision, left=edge.left, right=edge.right, + parent=edge.parent, child=edge.child) + print(row, file=edges) + + if sites is not None: + print("position", "ancestral_state", "metadata", sep="\t", file=sites) + for site in self.sites(): + metadata = site.metadata + if base64_metadata: + metadata = base64.b64encode(metadata).decode(encoding) + row = ( + "{position:.{precision}f}\t" + "{ancestral_state}\t" + "{metadata}").format( + precision=precision, position=site.position, + ancestral_state=site.ancestral_state, + metadata=metadata) + print(row, file=sites) + + if mutations is not None: + print( + "site", "node", "derived_state", "parent", "metadata", + sep="\t", file=mutations) + for site in self.sites(): + for mutation in site.mutations: + metadata = mutation.metadata + if base64_metadata: + metadata = base64.b64encode(metadata).decode(encoding) + row = ( + "{site}\t" + "{node}\t" + "{derived_state}\t" + "{parent}\t" + "{metadata}").format( + site=mutation.site, node=mutation.node, + derived_state=mutation.derived_state, + parent=mutation.parent, + metadata=metadata) + print(row, file=mutations) + + if individuals is not None: + print( + "id", 
"flags", "location", "metadata", + sep="\t", file=individuals) + for individual in self.individuals(): + metadata = individual.metadata + if base64_metadata: + metadata = base64.b64encode(metadata).decode(encoding) + location = ",".join(map(str, individual.location)) + row = ( + "{id}\t" + "{flags}\t" + "{location}\t" + "{metadata}").format( + id=individual.id, flags=individual.flags, + location=location, metadata=metadata) + print(row, file=individuals) + + if populations is not None: + print( + "id", "metadata", + sep="\t", file=populations) + for population in self.populations(): + metadata = population.metadata + if base64_metadata: + metadata = base64.b64encode(metadata).decode(encoding) + row = ( + "{id}\t" + "{metadata}").format(id=population.id, metadata=metadata) + print(row, file=populations) + + if provenances is not None: + print("id", "timestamp", "record", sep="\t", file=provenances) + for provenance in self.provenances(): + row = ( + "{id}\t" + "{timestamp}\t" + "{record}\t").format( + id=provenance.id, + timestamp=provenance.timestamp, + record=provenance.record) + print(row, file=provenances) + + # num_samples was originally called sample_size, and so we must keep sample_size + # around as a deprecated alias. + @property + def num_samples(self): + """ + Returns the number of samples in this tree sequence. This is the number + of sample nodes in each tree. + + :return: The number of sample nodes in this tree sequence. + :rtype: int + """ + return self._ll_tree_sequence.get_num_samples() + + @property + def sample_size(self): + # Deprecated alias for num_samples + return self.num_samples + + def get_sample_size(self): + # Deprecated alias for num_samples + return self.num_samples + + @property + def file_uuid(self): + return self._ll_tree_sequence.get_file_uuid() + + @property + def sequence_length(self): + """ + Returns the sequence length in this tree sequence. This defines the + genomic scale over which tree coordinates are defined. 
Given a + tree sequence with a sequence length :math:`L`, the constituent + trees will be defined over the half-closed interval + :math:`[0, L)`. Each tree then covers some subset of this + interval --- see :meth:`tskit.Tree.get_interval` for details. + + :return: The length of the sequence in this tree sequence in bases. + :rtype: float + """ + return self.get_sequence_length() + + def get_sequence_length(self): + return self._ll_tree_sequence.get_sequence_length() + + @property + def num_edges(self): + """ + Returns the number of :ref:`edges ` in this + tree sequence. + + :return: The number of edges in this tree sequence. + :rtype: int + """ + return self._ll_tree_sequence.get_num_edges() + + def get_num_trees(self): + # Deprecated alias for self.num_trees + return self.num_trees + + @property + def num_trees(self): + """ + Returns the number of distinct trees in this tree sequence. This + is equal to the number of trees returned by the :meth:`.trees` + method. + + :return: The number of trees in this tree sequence. + :rtype: int + """ + return self._ll_tree_sequence.get_num_trees() + + def get_num_sites(self): + # Deprecated alias for self.num_sites + return self._ll_tree_sequence.get_num_sites() + + @property + def num_sites(self): + """ + Returns the number of :ref:`sites ` in + this tree sequence. + + :return: The number of sites in this tree sequence. + :rtype: int + """ + return self.get_num_sites() + + def get_num_mutations(self): + # Deprecated alias for self.num_mutations + return self.num_mutations + + @property + def num_mutations(self): + """ + Returns the number of :ref:`mutations ` + in this tree sequence. + + :return: The number of mutations in this tree sequence. + :rtype: int + """ + return self._ll_tree_sequence.get_num_mutations() + + def get_num_nodes(self): + # Deprecated alias for self.num_nodes + return self.num_nodes + + @property + def num_individuals(self): + """ + Returns the number of :ref:`individuals ` in + this tree sequence. 
+ + :return: The number of individuals in this tree sequence. + :rtype: int + """ + return self._ll_tree_sequence.get_num_individuals() + + @property + def num_nodes(self): + """ + Returns the number of :ref:`nodes ` in + this tree sequence. + + :return: The number of nodes in this tree sequence. + :rtype: int + """ + return self._ll_tree_sequence.get_num_nodes() + + @property + def num_provenances(self): + """ + Returns the number of :ref:`provenances ` + in this tree sequence. + + :return: The number of provenances in this tree sequence. + :rtype: int + """ + return self._ll_tree_sequence.get_num_provenances() + + @property + def num_populations(self): + """ + Returns the number of :ref:`populations ` + in this tree sequence. + + :return: The number of populations in this tree sequence. + :rtype: int + """ + return self._ll_tree_sequence.get_num_populations() + + @property + def num_migrations(self): + """ + Returns the number of :ref:`migrations ` + in this tree sequence. + + :return: The number of migrations in this tree sequence. + :rtype: int + """ + return self._ll_tree_sequence.get_num_migrations() + + def migrations(self): + """ + Returns an iterator over all the + :ref:`migrations ` in this tree sequence. + + Migrations are returned in nondecreasing order of the ``time`` value. + + :return: An iterator over all migrations. + :rtype: iter(:class:`.Migration`) + """ + for j in range(self._ll_tree_sequence.get_num_migrations()): + yield Migration(*self._ll_tree_sequence.get_migration(j)) + + def individuals(self): + """ + Returns an iterator over all the + :ref:`individuals ` in this tree sequence. + + :return: An iterator over all individuals. + :rtype: iter(:class:`.Individual`) + """ + for j in range(self.num_individuals): + yield self.individual(j) + + def nodes(self): + """ + Returns an iterator over all the :ref:`nodes ` + in this tree sequence. + + :return: An iterator over all nodes. 
+ :rtype: iter(:class:`.Node`) + """ + for j in range(self.num_nodes): + yield self.node(j) + + def edges(self): + """ + Returns an iterator over all the :ref:`edges ` + in this tree sequence. Edges are returned in the order required + for a :ref:`valid tree sequence `. So, + edges are guaranteed to be ordered such that (a) all parents with a + given ID are contiguous; (b) edges are returned in non-descreasing + order of parent time ago; (c) within the edges for a given parent, edges + are sorted first by child ID and then by left coordinate. + + :return: An iterator over all edges. + :rtype: iter(:class:`.Edge`) + """ + for j in range(self.num_edges): + left, right, parent, child = self._ll_tree_sequence.get_edge(j) + yield Edge(left=left, right=right, parent=parent, child=child) + + def edgesets(self): + # TODO the order that these records are returned in is not well specified. + # Hopefully this does not matter, and we can just state that the ordering + # should not be depended on. + children = collections.defaultdict(set) + active_edgesets = {} + for (left, right), edges_out, edges_in in self.edge_diffs(): + # Complete and return any edgesets that are affected by this tree + # transition + parents = iter(edge.parent for edge in itertools.chain(edges_out, edges_in)) + for parent in parents: + if parent in active_edgesets: + edgeset = active_edgesets.pop(parent) + edgeset.right = left + edgeset.children = sorted(children[parent]) + yield edgeset + for edge in edges_out: + children[edge.parent].remove(edge.child) + for edge in edges_in: + children[edge.parent].add(edge.child) + # Update the active edgesets + for edge in itertools.chain(edges_out, edges_in): + if len(children[edge.parent]) > 0 and edge.parent not in active_edgesets: + active_edgesets[edge.parent] = Edgeset(left, right, edge.parent, []) + + for parent in active_edgesets.keys(): + edgeset = active_edgesets[parent] + edgeset.right = self.sequence_length + edgeset.children = 
sorted(children[edgeset.parent]) + yield edgeset + + def edge_diffs(self): + """ + Returns an iterator over all the edges that are inserted and removed to + build the trees as we move from left-to-right along the tree sequence. + The iterator yields a sequence of 3-tuples, ``(interval, edges_out, + edges_in)``. The ``interval`` is a pair ``(left, right)`` representing + the genomic interval (see :attr:`Tree.interval`). The + ``edges_out`` value is a tuple of the edges that were just-removed to + create the tree covering the interval (hence, ``edges_out`` will always + be empty for the first tree). The ``edges_in`` value is a tuple of + edges that were just inserted to contruct the tree convering the + current interval. + + :return: An iterator over the (interval, edges_out, edges_in) tuples. + :rtype: iter(tuple, tuple, tuple) + """ + iterator = _tskit.TreeDiffIterator(self._ll_tree_sequence) + for interval, edge_tuples_out, edge_tuples_in in iterator: + edges_out = [Edge(*e) for e in edge_tuples_out] + edges_in = [Edge(*e) for e in edge_tuples_in] + yield interval, edges_out, edges_in + + def sites(self): + """ + Returns an iterator over all the :ref:`sites ` + in this tree sequence. Sites are returned in order of increasing ID + (and also position). See the :class:`Site` class for details on + the available fields for each site. + + :return: An iterator over all sites. + :rtype: iter(:class:`.Site`) + """ + for j in range(self.num_sites): + yield self.site(j) + + def mutations(self): + """ + Returns an iterator over all the + :ref:`mutations ` in this tree sequence. + Mutations are returned in order of nondecreasing site ID. + See the :class:`Mutation` class for details on the available fields for + each mutation. 
+ + The returned iterator is equivalent to iterating over all sites + and all mutations in each site, i.e.:: + + >>> for site in tree_sequence.sites(): + >>> for mutation in site.mutations: + >>> yield mutation + + :return: An iterator over all mutations in this tree sequence. + :rtype: iter(:class:`.Mutation`) + """ + for site in self.sites(): + for mutation in site.mutations: + yield add_deprecated_mutation_attrs(site, mutation) + + def populations(self): + """ + Returns an iterator over all the + :ref:`populations ` in this tree sequence. + + :return: An iterator over all populations. + :rtype: iter(:class:`.Population`) + """ + for j in range(self.num_populations): + yield self.population(j) + + def provenances(self): + """ + Returns an iterator over all the + :ref:`provenances ` in this tree sequence. + + :return: An iterator over all provenances. + :rtype: iter(:class:`.Provenance`) + """ + for j in range(self.num_provenances): + yield self.provenance(j) + + def breakpoints(self): + """ + Returns an iterator over the breakpoints along the chromosome, + including the two extreme points 0 and L. This is equivalent to + + >>> [0] + [t.get_interval()[1] for t in self.trees()] + + although we do not build an explicit list. + + :return: An iterator over all the breakpoints along the simulated + sequence. + :rtype: iter + """ + yield 0 + for t in self.trees(): + yield t.get_interval()[1] + + def first(self): + """ + Returns the first tree in this :class:`.TreeSequence`. To iterate over all + trees in the sequence, use the :meth:`.trees` method. + + Currently does not support the extra options for the :meth:`.trees` method. + + :return: The first tree in this tree sequence. + :rtype: :class:`.Tree`. + """ + return next(self.trees()) + + def trees( + self, tracked_samples=None, sample_counts=True, sample_lists=False, + tracked_leaves=None, leaf_counts=None, leaf_lists=None): + """ + Returns an iterator over the trees in this tree sequence. 
Each value + returned in this iterator is an instance of :class:`.Tree`. + + The ``sample_counts`` and ``sample_lists`` parameters control the + features that are enabled for the resulting trees. If ``sample_counts`` + is True, then it is possible to count the number of samples underneath + a particular node in constant time using the :meth:`.num_samples` + method. If ``sample_lists`` is True a more efficient algorithm is + used in the :meth:`.Tree.samples` method. + + The ``tracked_samples`` parameter can be used to efficiently count the + number of samples in a given set that exist in a particular subtree + using the :meth:`.Tree.get_num_tracked_samples` method. It is an + error to use the ``tracked_samples`` parameter when the ``sample_counts`` + flag is False. + + :warning: Do not store the results of this iterator in a list! + For performance reasons, the same underlying object is used + for every tree returned which will most likely lead to unexpected + behaviour. + + :param list tracked_samples: The list of samples to be tracked and + counted using the :meth:`.Tree.get_num_tracked_samples` + method. + :param bool sample_counts: If True, support constant time sample counts + via the :meth:`.Tree.num_samples` and + :meth:`.Tree.get_num_tracked_samples` methods. + :param bool sample_lists: If True, provide more efficient access + to the samples beneath a give node using the + :meth:`.Tree.samples` method. + :return: An iterator over the sparse trees in this tree sequence. + :rtype: iter + """ + # tracked_leaves, leaf_counts and leaf_lists are deprecated aliases + # for tracked_samples, sample_counts and sample_lists respectively. + # These are left over from an older version of the API when leaves + # and samples were synonymous. 
+ if tracked_leaves is not None: + tracked_samples = tracked_leaves + if leaf_counts is not None: + sample_counts = leaf_counts + if leaf_lists is not None: + sample_lists = leaf_lists + flags = 0 + if sample_counts: + flags |= _tskit.SAMPLE_COUNTS + elif tracked_samples is not None: + raise ValueError("Cannot set tracked_samples without sample_counts") + if sample_lists: + flags |= _tskit.SAMPLE_LISTS + kwargs = {"flags": flags} + if tracked_samples is not None: + # TODO remove this when we allow numpy arrays in the low-level API. + kwargs["tracked_samples"] = list(tracked_samples) + ll_tree = _tskit.Tree(self._ll_tree_sequence, **kwargs) + iterator = _tskit.TreeIterator(ll_tree) + tree = Tree(ll_tree, self) + for _ in iterator: + yield tree + + def haplotypes(self): + """ + Returns an iterator over the haplotypes resulting from the trees + and mutations in this tree sequence as a string. + The iterator returns a total of :math:`n` strings, each of which + contains :math:`s` characters (:math:`n` is the sample size + returned by :attr:`tskit.TreeSequence.num_samples` and + :math:`s` is the number of sites returned by + :attr:`tskit.TreeSequence.num_sites`). The first + string returned is the haplotype for sample `0`, and so on. + For a given haplotype ``h``, the value of ``h[j]`` is the observed + allelic state at site ``j``. + + See also the :meth:`variants` iterator for site-centric access + to sample genotypes. + + This method is only supported for single-letter alleles. + + :return: An iterator over the haplotype strings for the samples in + this tree sequence. + :rtype: iter + :raises: LibraryError if called on a tree sequence containing + multiletter alleles. + """ + hapgen = _tskit.HaplotypeGenerator(self._ll_tree_sequence) + j = 0 + # Would use range here except for Python 2. + while j < self.num_samples: + yield hapgen.get_haplotype(j) + j += 1 + + # Samples is experimental for now, so we don't document it. 
+ def variants(self, as_bytes=False, samples=None): + """ + Returns an iterator over the variants in this tree sequence. See the + :class:`Variant` class for details on the fields of each returned + object. By default the ``genotypes`` for the variants are numpy arrays, + corresponding to indexes into the ``alleles`` array. If the + ``as_bytes`` parameter is true, these allelic values are recorded + directly into a bytes array. + + .. note:: + The ``as_bytes`` parameter is kept as a compatibility + option for older code. It is not the recommended way of + accessing variant data, and will be deprecated in a later + release. Another method will be provided to obtain the allelic + states for each site directly. + + :param bool as_bytes: If True, the genotype values will be returned + as a Python bytes object. This is useful in certain situations + (i.e., directly printing the genotypes) or when numpy is + not available. Otherwise, genotypes are returned as a numpy + array (the default). + :return: An iterator of all variants this tree sequence. + :rtype: iter(:class:`Variant`) + """ + # See comments for the Variant type for discussion on why the + # present form was chosen. + iterator = _tskit.VariantGenerator(self._ll_tree_sequence, samples=samples) + for site_id, genotypes, alleles in iterator: + site = self.site(site_id) + if as_bytes: + if any(len(allele) > 1 for allele in alleles): + raise ValueError( + "as_bytes only supported for single-letter alleles") + bytes_genotypes = np.empty(self.num_samples, dtype=np.uint8) + lookup = np.array([ord(a[0]) for a in alleles], dtype=np.uint8) + bytes_genotypes[:] = lookup[genotypes] + genotypes = bytes_genotypes.tobytes() + yield Variant(site, alleles, genotypes) + + def genotype_matrix(self): + """ + Returns an :math:`m \\times n` numpy array of the genotypes in this + tree sequence, where :math:`m` is the number of sites and :math:`n` + the number of samples. 
The genotypes are the indexes into the array + of ``alleles``, as described for the :class:`Variant` class. The value + 0 always corresponds to the ancestal state, and values > 0 represent + distinct derived states. + + .. warning:: + This method can consume a **very large** amount of memory! If + all genotypes are not needed at once, it is usually better to + access them sequentially using the :meth:`.variants` iterator. + + :return: The full matrix of genotypes. + :rtype: numpy.ndarray (dtype=np.uint8) + """ + return self._ll_tree_sequence.get_genotype_matrix() + + def get_pairwise_diversity(self, samples=None): + # Deprecated alias for self.pairwise_diversity + return self.pairwise_diversity(samples) + + def pairwise_diversity(self, samples=None): + """ + Returns the value of :math:`\\pi`, the pairwise nucleotide site + diversity, the average number of mutations per unit of genome length + that differ between a randomly chosen pair of samples. If `samples` is + specified, calculate the diversity within this set. + + .. note:: This method does not currently support sites that have more + than one mutation. Using it on such a tree sequence will raise + a LibraryError with an "Unsupported operation" message. + + :param iterable samples: The set of samples within which we calculate + the diversity. If None, calculate diversity within the entire sample. + :return: The pairwise nucleotide site diversity. + :rtype: float + """ + if samples is None: + samples = self.samples() + return self._ll_tree_sequence.get_pairwise_diversity(list(samples)) + + def mean_descendants(self, reference_sets): + """ + Computes for every node the mean number of samples in each of the + `reference_sets` that descend from that node, averaged over the + portions of the genome for which the node is ancestral to *any* sample. 
+ The output is an array, `C[node, j]`, which reports the total length of + all genomes in `reference_sets[j]` that inherit from `node`, divided by + the total length of the genome on which `node` is an ancestor to any + sample in the tree sequence. + + .. note:: This interface *may change*, particularly the normalization by + proportion of the genome that `node` is an ancestor to anyone. + + :param iterable reference sets: A list of lists of node IDs. + :return: An array with dimensions (number of nodes in the tree sequence, + number of reference sets) + """ + return self._ll_tree_sequence.mean_descendants(reference_sets) + + def genealogical_nearest_neighbours(self, focal, reference_sets, num_threads=0): + # TODO this may not be a good name because there is another version of the + # statistic which may be occasionally useful where we return the tree-by-tree + # value. We could do this by adding an extra dimension to the returned array + # which would give the values tree-by-tree. The tree lengths can be computed + # easily enough, *but* there may be occasions when the statistic isn't + # defined over particular trees. + # + # Probably the best thing to do is to add an option which allows us to compute + # the tree-wise GNNs, returning the values in a higher dimensional array + # rather than have another function entirely. 
+ if num_threads <= 0: + return self._ll_tree_sequence.genealogical_nearest_neighbours( + focal, reference_sets) + else: + if IS_PY2: + raise ValueError("Threads not supported on Python 2.") + worker = functools.partial( + self._ll_tree_sequence.genealogical_nearest_neighbours, + reference_sets=reference_sets) + focal = np.array(focal).astype(np.int32) + splits = np.array_split(focal, num_threads) + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as pool: + arrays = pool.map(worker, splits) + return np.vstack(arrays) + + def individual(self, id_): + """ + Returns the :ref:`individual ` + in this tree sequence with the specified ID. + + :rtype: :class:`.Individual` + """ + flags, location, metadata, nodes = self._ll_tree_sequence.get_individual(id_) + return Individual( + id_=id_, flags=flags, location=location, metadata=metadata, nodes=nodes) + + def node(self, id_): + """ + Returns the :ref:`node ` in this tree sequence + with the specified ID. + + :rtype: :class:`.Node` + """ + (flags, time, population, individual, + metadata) = self._ll_tree_sequence.get_node(id_) + return Node( + id_=id_, flags=flags, time=time, population=population, + individual=individual, metadata=metadata) + + def mutation(self, id_): + """ + Returns the :ref:`mutation ` in this tree sequence + with the specified ID. + + :rtype: :class:`.Mutation` + """ + ll_mut = self._ll_tree_sequence.get_mutation(id_) + return Mutation( + id_=id_, site=ll_mut[0], node=ll_mut[1], derived_state=ll_mut[2], + parent=ll_mut[3], metadata=ll_mut[4]) + + def site(self, id_): + """ + Returns the :ref:`site ` in this tree sequence + with the specified ID. 
+ + :rtype: :class:`.Site` + """ + ll_site = self._ll_tree_sequence.get_site(id_) + pos, ancestral_state, ll_mutations, _, metadata = ll_site + mutations = [self.mutation(mut_id) for mut_id in ll_mutations] + return Site( + id_=id_, position=pos, ancestral_state=ancestral_state, + mutations=mutations, metadata=metadata) + + def population(self, id_): + """ + Returns the :ref:`population ` + in this tree sequence with the specified ID. + + :rtype: :class:`.Population` + """ + metadata, = self._ll_tree_sequence.get_population(id_) + return Population(id_=id_, metadata=metadata) + + def provenance(self, id_): + timestamp, record = self._ll_tree_sequence.get_provenance(id_) + return Provenance(id_=id_, timestamp=timestamp, record=record) + + def get_samples(self, population_id=None): + # Deprecated alias for samples() + return self.samples(population_id) + + def samples(self, population=None, population_id=None): + """ + Returns an array of the sample node IDs in this tree sequence. If the + ``population`` parameter is specified, only return sample IDs from this + population. + + :param int population: The population of interest. If None, + return all samples. + :param int population_id: Deprecated alias for ``population``. + :return: A numpy array of the node IDs for the samples of interest. + :rtype: numpy.ndarray (dtype=np.int32) + """ + if population is not None and population_id is not None: + raise ValueError( + "population_id and population are aliases. Cannot specify both") + if population_id is not None: + population = population_id + # TODO the low-level tree sequence should perform this operation natively + # and return a numpy array. + samples = self._ll_tree_sequence.get_samples() + if population is not None: + samples = [ + u for u in samples if self.node(u).population == population] + return np.array(samples, dtype=np.int32) + + def write_vcf(self, output, ploidy=1, contig_id="1"): + """ + Writes a VCF formatted file to the specified file-like object. 
If a + ploidy value is supplied, allele values are combined among adjacent + samples to form a phased genotype of the required ploidy. For example, + if we have a ploidy of 2 and a sample of size 6, then we will have + 3 diploid samples in the output, consisting of the combined alleles + for samples [0, 1], [2, 3] and [4, 5]. If we had alleles 011110 at + a particular variant, then we would output the genotypes 0|1, 1|1 + and 1|0 in VCF. Sample names are generated by appending the index + to the prefix ``msp_`` such that we would have the sample names + ``msp_0``, ``msp_1`` and ``msp_2`` in the running example. + + Example usage: + + >>> with open("output.vcf", "w") as vcf_file: + >>> tree_sequence.write_vcf(vcf_file, 2) + + .. warning:: + This output function does not currently use information in the + :class:`IndividualTable`, and so will only correctly produce + non-haploid output if the nodes corresponding to each individual + are contiguous as described above. + + :param File output: The file-like object to write the VCF output. + :param int ploidy: The ploidy of the individuals to be written to + VCF. This sample size must be evenly divisible by ploidy. + :param str contig_id: The value of the CHROM column in the output VCF. + """ + if ploidy < 1: + raise ValueError("Ploidy must be >= sample size") + if self.get_sample_size() % ploidy != 0: + raise ValueError("Sample size must be divisible by ploidy") + converter = _tskit.VcfConverter( + self._ll_tree_sequence, ploidy=ploidy, contig_id=contig_id) + output.write(converter.get_header()) + for record in converter: + output.write(record) + + def simplify( + self, samples=None, + filter_zero_mutation_sites=None, # Deprecated alias for filter_sites + map_nodes=False, + reduce_to_site_topology=False, + filter_populations=True, filter_individuals=True, filter_sites=True, + record_provenance=True): + """ + Returns a simplified tree sequence that retains only the history of + the nodes given in the list ``samples``. 
If ``map_nodes`` is true, + also return a numpy array mapping the node IDs in this tree sequence to + their node IDs in the simplified tree tree sequence. If a node ``u`` is not + present in the new tree sequence, the value of this mapping will be + NULL (-1). + + In the returned tree sequence, the node with ID ``0`` corresponds to + ``samples[0]``, node ``1`` corresponds to ``samples[1]``, and so on. + Besides the samples, node IDs in the returned tree sequence are then + allocated sequentially in time order. + + If you wish to simplify a set of tables that do not satisfy all + requirements for building a TreeSequence, then use + :meth:`TableCollection.simplify`. + + If the ``reduce_to_site_topology`` parameter is True, the returned tree + sequence will contain only topological information that is necessary to + represent the trees that contain sites. If there are zero sites in this + tree sequence, this will result in an output tree sequence with zero edges. + When the number of sites is greater than zero, every tree in the output + tree sequence will contain at least one site. For a given site, the + topology of the tree containing that site will be identical + (up to node ID remapping) to the topology of the corresponding tree + in the input tree sequence. + + If ``filter_populations``, ``filter_individuals`` or ``filter_sites`` is + True, any of the corresponding objects that are not referenced elsewhere + are filtered out. As this is the default behaviour, it is important to + realise IDs for these objects may change through simplification. By setting + these parameters to False, however, the corresponding tables can be preserved + without changes. + + :param list samples: The list of nodes for which to retain information. This + may be a numpy array (or array-like) object (dtype=np.int32). + :param bool filter_zero_mutation_sites: Deprecated alias for ``filter_sites``. 
+ :param bool map_nodes: If True, return a tuple containing the resulting + tree sequence and a numpy array mapping node IDs in the current tree + sequence to their corresponding node IDs in the returned tree sequence. + If False (the default), return only the tree sequence object itself. + :param bool reduce_to_site_topology: Whether to reduce the topology down + to the trees that are present at sites. (Default: False) + :param bool filter_populations: If True, remove any populations that are + not referenced by nodes after simplification; new population IDs are + allocated sequentially from zero. If False, the population table will + not be altered in any way. (Default: True) + :param bool filter_individuals: If True, remove any individuals that are + not referenced by nodes after simplification; new individual IDs are + allocated sequentially from zero. If False, the individual table will + not be altered in any way. (Default: True) + :param bool filter_sites: If True, remove any sites that are + not referenced by mutations after simplification; new site IDs are + allocated sequentially from zero. If False, the site table will not + be altered in any way. (Default: True) + :param bool record_provenance: If True, record details of this call to + simplify in the returned tree sequence's provenance information + (Default: True). + :return: The simplified tree sequence, or (if ``map_nodes`` is True) + a tuple consisting of the simplified tree sequence and a numpy array + mapping source node IDs to their corresponding IDs in the new tree + sequence. 
+ :rtype: .TreeSequence or a (.TreeSequence, numpy.array) tuple + """ + tables = self.dump_tables() + if samples is None: + samples = self.get_samples() + assert tables.sequence_length == self.sequence_length + node_map = tables.simplify( + samples=samples, + filter_zero_mutation_sites=filter_zero_mutation_sites, + reduce_to_site_topology=reduce_to_site_topology, + filter_populations=filter_populations, + filter_individuals=filter_individuals, + filter_sites=filter_sites) + if record_provenance: + # TODO add simplify arguments here + # TODO also make sure we convert all the arguments so that they are + # definitely JSON encodable. + parameters = { + "command": "simplify", + "TODO": "add simplify parameters" + } + tables.provenances.add_row(record=json.dumps( + provenance.get_provenance_dict(parameters))) + new_ts = tables.tree_sequence() + assert new_ts.sequence_length == self.sequence_length + if map_nodes: + return new_ts, node_map + else: + return new_ts + + ############################################ + # + # Deprecated APIs. These are either already unsupported, or will be unsupported in a + # later release. + # + ############################################ + + def get_time(self, u): + # Deprecated. Use ts.node(u).time + if u < 0 or u >= self.get_num_nodes(): + raise ValueError("ID out of bounds") + node = self.node(u) + return node.time + + def get_population(self, u): + # Deprecated. Use ts.node(u).population + if u < 0 or u >= self.get_num_nodes(): + raise ValueError("ID out of bounds") + node = self.node(u) + return node.population + + def records(self): + # Deprecated. Use either ts.edges() or ts.edgesets(). + t = [node.time for node in self.nodes()] + pop = [node.population for node in self.nodes()] + for e in self.edgesets(): + yield CoalescenceRecord( + e.left, e.right, e.parent, e.children, t[e.parent], pop[e.parent]) + + # Unsupported old methods. + + def get_num_records(self): + raise NotImplementedError( + "This method is no longer supported. 
Please use the " + "TreeSequence.num_edges if possible to work with edges rather " + "than coalescence records. If not, please use len(list(ts.edgesets())) " + "which should return the number of coalescence records, as previously " + "defined. Please open an issue on GitHub if this is " + "important for your workflow.") + + def diffs(self): + raise NotImplementedError( + "This method is no longer supported. Please use the " + "TreeSequence.edge_diffs() method instead") + + def newick_trees(self, precision=3, breakpoints=None, Ne=1): + raise NotImplementedError( + "This method is no longer supported. Please use the Tree.newick" + " method instead") diff --git a/setup.py b/setup.py deleted file mode 100644 index 61c3967f23..0000000000 --- a/setup.py +++ /dev/null @@ -1,13 +0,0 @@ -from setuptools import setup - -setup( - name="tskit", - version="0.0.0", - description="The tree sequence toolkit", - long_description="**HOLDING PAGE**; real package coming soon!", - url="https://github.com/tskit-dev/tskit", - author="The tskit developers", - author_email="jerome.kelleher@well.ox.uk", - packages=["tskit"], - zip_safe=False -) diff --git a/tskit/__init__.py b/tskit/__init__.py deleted file mode 100644 index 0221df9c88..0000000000 --- a/tskit/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -""" -Skeleton tskit module. Doesn't do anything at the moment. -"""